diff --git a/src/pkg/exp/ssh/channel.go b/src/pkg/exp/ssh/channel.go index 922584f631..f69b735fd4 100644 --- a/src/pkg/exp/ssh/channel.go +++ b/src/pkg/exp/ssh/channel.go @@ -68,7 +68,7 @@ type channel struct { weClosed bool dead bool - serverConn *ServerConnection + serverConn *ServerConn myId, theirId uint32 myWindow, theirWindow uint32 maxPacketSize uint32 diff --git a/src/pkg/exp/ssh/client.go b/src/pkg/exp/ssh/client.go index edb95eccc6..b3d7708a26 100644 --- a/src/pkg/exp/ssh/client.go +++ b/src/pkg/exp/ssh/client.go @@ -115,7 +115,7 @@ func (c *ClientConn) handshake() os.Error { dhGroup14Once.Do(initDHGroup14) H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) default: - fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo) + err = fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo) } if err != nil { return err diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go index b5a5e017d3..3a640fc081 100644 --- a/src/pkg/exp/ssh/server.go +++ b/src/pkg/exp/ssh/server.go @@ -10,19 +10,23 @@ import ( "crypto" "crypto/rand" "crypto/rsa" - _ "crypto/sha1" "crypto/x509" "encoding/pem" + "io" "net" "os" "sync" ) -// Server represents an SSH server. A Server may have several ServerConnections. -type Server struct { +type ServerConfig struct { rsa *rsa.PrivateKey rsaSerialized []byte + // Rand provides the source of entropy for key exchange. If Rand is + // nil, the cryptographic random reader in package crypto/rand will + // be used. + Rand io.Reader + // NoClientAuth is true if clients are allowed to connect without // authenticating. NoClientAuth bool @@ -38,11 +42,18 @@ type Server struct { PubKeyCallback func(user, algo string, pubkey []byte) bool } +func (c *ServerConfig) rand() io.Reader { + if c.Rand == nil { + return rand.Reader + } + return c.Rand +} + // SetRSAPrivateKey sets the private key for a Server. A Server must have a // private key configured in order to accept connections. The private key must // be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa" // typically contains such a key. -func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error { +func (s *ServerConfig) SetRSAPrivateKey(pemBytes []byte) os.Error { block, _ := pem.Decode(pemBytes) if block == nil { return os.NewError("ssh: no key found") @@ -109,7 +120,7 @@ func parseRSASig(in []byte) (sig []byte, ok bool) { } // cachedPubKey contains the results of querying whether a public key is -// acceptable for a user. The cache only applies to a single ServerConnection. +// acceptable for a user. The cache only applies to a single ServerConn. type cachedPubKey struct { user, algo string pubKey []byte @@ -118,11 +129,10 @@ type cachedPubKey struct { const maxCachedPubKeys = 16 -// ServerConnection represents an incomming connection to a Server. -type ServerConnection struct { - Server *Server - +// A ServerConn represents an incomming connection. +type ServerConn struct { *transport + config *ServerConfig channels map[uint32]*channel nextChanId uint32 @@ -139,9 +149,20 @@ type ServerConnection struct { cachedPubKeys []cachedPubKey } +// Server returns a new SSH server connection +// using c as the underlying transport. +func Server(c net.Conn, config *ServerConfig) *ServerConn { + conn := &ServerConn{ + transport: newTransport(c, config.rand()), + channels: make(map[uint32]*channel), + config: config, + } + return conn +} + // kexDH performs Diffie-Hellman key agreement on a ServerConnection. The // returned values are given the same names as in RFC 4253, section 8. -func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { +func (s *ServerConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) { packet, err := s.readPacket() if err != nil { return @@ -155,7 +176,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h return nil, nil, os.NewError("client DH parameter out of bounds") } - y, err := rand.Int(rand.Reader, group.p) + y, err := rand.Int(s.config.rand(), group.p) if err != nil { return } @@ -166,7 +187,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h var serializedHostKey []byte switch hostKeyAlgo { case hostAlgoRSA: - serializedHostKey = s.Server.rsaSerialized + serializedHostKey = s.config.rsaSerialized default: return nil, nil, os.NewError("internal error") } @@ -192,7 +213,7 @@ func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *h var sig []byte switch hostKeyAlgo { case hostAlgoRSA: - sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh) + sig, err = rsa.SignPKCS1v15(s.config.rand(), s.config.rsa, hashFunc, hh) if err != nil { return } @@ -257,17 +278,18 @@ func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubK return ret } -// Handshake performs an SSH transport and client authentication on the given ServerConnection. -func (s *ServerConnection) Handshake(conn net.Conn) os.Error { +// Handshake performs an SSH transport and client authentication on the given ServerConn. +func (s *ServerConn) Handshake() os.Error { var magics handshakeMagics - s.transport = newTransport(conn, rand.Reader) - - if _, err := conn.Write(serverVersion); err != nil { + if _, err := s.Write(serverVersion); err != nil { + return err + } + if err := s.Flush(); err != nil { return err } magics.serverVersion = serverVersion[:len(serverVersion)-2] - version, err := readVersion(s.transport) + version, err := readVersion(s) if err != nil { return err } @@ -310,8 +332,7 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] { // The client sent a Kex message for the wrong algorithm, // which we have to ignore. - _, err := s.readPacket() - if err != nil { + if _, err := s.readPacket(); err != nil { return err } } @@ -324,32 +345,27 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { dhGroup14Once.Do(initDHGroup14) H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo) default: - err = os.NewError("ssh: internal error") + err = os.NewError("ssh: unexpected key exchange algorithm " + kexAlgo) } - if err != nil { return err } - packet = []byte{msgNewKeys} - if err = s.writePacket(packet); err != nil { + if err = s.writePacket([]byte{msgNewKeys}); err != nil { return err } if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil { return err } - if packet, err = s.readPacket(); err != nil { return err } + if packet[0] != msgNewKeys { return UnexpectedMessageError{msgNewKeys, packet[0]} } - s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc) - - packet, err = s.readPacket() - if err != nil { + if packet, err = s.readPacket(); err != nil { return err } @@ -360,20 +376,16 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error { if serviceRequest.Service != serviceUserAuth { return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating") } - serviceAccept := serviceAcceptMsg{ Service: serviceUserAuth, } - packet = marshal(msgServiceAccept, serviceAccept) - if err = s.writePacket(packet); err != nil { + if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil { return err } if err = s.authenticate(H); err != nil { return err } - - s.channels = make(map[uint32]*channel) return nil } @@ -382,8 +394,8 @@ func isAcceptableAlgo(algo string) bool { } // testPubKey returns true if the given public key is acceptable for the user. -func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { - if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) { +func (s *ServerConn) testPubKey(user, algo string, pubKey []byte) bool { + if s.config.PubKeyCallback == nil || !isAcceptableAlgo(algo) { return false } @@ -393,7 +405,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { } } - result := s.Server.PubKeyCallback(user, algo, pubKey) + result := s.config.PubKeyCallback(user, algo, pubKey) if len(s.cachedPubKeys) < maxCachedPubKeys { c := cachedPubKey{ user: user, @@ -408,7 +420,7 @@ func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool { return result } -func (s *ServerConnection) authenticate(H []byte) os.Error { +func (s *ServerConn) authenticate(H []byte) os.Error { var userAuthReq userAuthRequestMsg var err os.Error var packet []byte @@ -428,11 +440,11 @@ userAuthLoop: switch userAuthReq.Method { case "none": - if s.Server.NoClientAuth { + if s.config.NoClientAuth { break userAuthLoop } case "password": - if s.Server.PasswordCallback == nil { + if s.config.PasswordCallback == nil { break } payload := userAuthReq.Payload @@ -445,11 +457,11 @@ userAuthLoop: return ParseError{msgUserAuthRequest} } - if s.Server.PasswordCallback(userAuthReq.User, string(password)) { + if s.config.PasswordCallback(userAuthReq.User, string(password)) { break userAuthLoop } case "publickey": - if s.Server.PubKeyCallback == nil { + if s.config.PubKeyCallback == nil { break } payload := userAuthReq.Payload @@ -520,10 +532,10 @@ userAuthLoop: } var failureMsg userAuthFailureMsg - if s.Server.PasswordCallback != nil { + if s.config.PasswordCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "password") } - if s.Server.PubKeyCallback != nil { + if s.config.PubKeyCallback != nil { failureMsg.Methods = append(failureMsg.Methods, "publickey") } @@ -546,9 +558,9 @@ userAuthLoop: const defaultWindowSize = 32768 -// Accept reads and processes messages on a ServerConnection. It must be called +// Accept reads and processes messages on a ServerConn. It must be called // in order to demultiplex messages to any resulting Channels. -func (s *ServerConnection) Accept() (Channel, os.Error) { +func (s *ServerConn) Accept() (Channel, os.Error) { if s.err != nil { return nil, s.err } @@ -643,3 +655,44 @@ func (s *ServerConnection) Accept() (Channel, os.Error) { panic("unreachable") } + +// A Listener implements a network listener (net.Listener) for SSH connections. +type Listener struct { + listener net.Listener + config *ServerConfig +} + +// Accept waits for and returns the next incoming SSH connection. +// The receiver should call Handshake() in another goroutine +// to avoid blocking the accepter. +func (l *Listener) Accept() (*ServerConn, os.Error) { + c, err := l.listener.Accept() + if err != nil { + return nil, err + } + conn := Server(c, l.config) + return conn, nil +} + +// Addr returns the listener's network address. +func (l *Listener) Addr() net.Addr { + return l.listener.Addr() +} + +// Close closes the listener. +func (l *Listener) Close() os.Error { + return l.listener.Close() +} + +// Listen creates an SSH listener accepting connections on +// the given network address using net.Listen. +func Listen(network, addr string, config *ServerConfig) (*Listener, os.Error) { + l, err := net.Listen(network, addr) + if err != nil { + return nil, err + } + return &Listener{ + l, + config, + }, nil +}