diff --git a/connection.go b/connection.go index 06b3286d..fe24d99a 100644 --- a/connection.go +++ b/connection.go @@ -10,32 +10,35 @@ import ( "net" ) +type ConnectionParameters struct { + socket string // path to unix domain socket (e.g. /private/tmp/.s.PGSQL.5432) + database string + user string + password string +} + type Connection struct { - conn net.Conn // the underlying TCP or unix domain socket connection - buf []byte // work buffer to avoid constant alloc and dealloc - pid int32 // backend pid - secretKey int32 // key to use to send a cancel query message to the server - runtimeParams map[string]string // parameters that have been reported by the server - options map[string]string // options used when establishing connection + conn net.Conn // the underlying TCP or unix domain socket connection + buf []byte // work buffer to avoid constant alloc and dealloc + pid int32 // backend pid + secretKey int32 // key to use to send a cancel query message to the server + runtimeParams map[string]string // parameters that have been reported by the server + parameters ConnectionParameters // parameters used when establishing this connection txStatus byte } // options: // socket: path to unix domain socket +// host: TCP address +// port: // database: name of database -func Connect(options map[string]string) (c *Connection, err error) { +func Connect(paramaters ConnectionParameters) (c *Connection, err error) { c = new(Connection) - c.options = make(map[string]string) - for k, v := range options { - c.options[k] = v - } + c.parameters = paramaters - var present bool - var socket string - - if socket, present = options["socket"]; present { - c.conn, err = net.Dial("unix", socket) + if c.parameters.socket != "" { + c.conn, err = net.Dial("unix", c.parameters.socket) if err != nil { return nil, err } @@ -46,12 +49,10 @@ func Connect(options map[string]string) (c *Connection, err error) { // conn, err := net.Dial("tcp", "localhost:5432") - var database string - msg := newStartupMessage() - msg.options["user"], _ = options["user"] - if database, present = options["database"]; present { - msg.options["database"] = database + msg.options["user"] = c.parameters.user + if c.parameters.database != "" { + msg.options["database"] = c.parameters.database } c.txStartupMessage(msg) @@ -240,10 +241,10 @@ func (c *Connection) rxAuthenticationX(r *messageReader) (err error) { switch code { case 0: // AuthenticationOk case 3: // AuthenticationCleartextPassword - c.txPasswordMessage(c.options["password"]) + c.txPasswordMessage(c.parameters.password) case 5: // AuthenticationMD5Password salt := r.readByteString(4) - digestedPassword := "md5" + hexMD5(hexMD5(c.options["password"]+c.options["user"])+salt) + digestedPassword := "md5" + hexMD5(hexMD5(c.parameters.password+c.parameters.user)+salt) c.txPasswordMessage(digestedPassword) default: err = errors.New("Received unknown authentication message") diff --git a/connection_pool.go b/connection_pool.go index 7a668c1b..256c72bb 100644 --- a/connection_pool.go +++ b/connection_pool.go @@ -2,25 +2,22 @@ package pgx type ConnectionPool struct { connectionChannel chan *Connection - options map[string]string // options used when establishing connection + parameters ConnectionParameters // options used when establishing connection MaxConnections int } // options: options used by Connect // MaxConnections: max simultaneous connections to use (currently all are immediately connected) -func NewConnectionPool(options map[string]string, MaxConnections int) (p *ConnectionPool, err error) { +func NewConnectionPool(parameters ConnectionParameters, MaxConnections int) (p *ConnectionPool, err error) { p = new(ConnectionPool) p.connectionChannel = make(chan *Connection, MaxConnections) p.MaxConnections = MaxConnections - p.options = make(map[string]string) - for k, v := range options { - p.options[k] = v - } + p.parameters = parameters for i := 0; i < p.MaxConnections; i++ { var c *Connection - c, err = Connect(options) + c, err = Connect(p.parameters) if err != nil { return } diff --git a/connection_pool_test.go b/connection_pool_test.go index d00a017e..d0be865d 100644 --- a/connection_pool_test.go +++ b/connection_pool_test.go @@ -6,7 +6,7 @@ import ( ) func createConnectionPool(maxConnections int) *ConnectionPool { - connectionOptions := map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"} + connectionOptions := ConnectionParameters{socket: "/private/tmp/.s.PGSQL.5432", user: "pgx_none", database: "pgx_test"} pool, err := NewConnectionPool(connectionOptions, maxConnections) if err != nil { panic("Unable to create connection pool") @@ -15,7 +15,7 @@ func createConnectionPool(maxConnections int) *ConnectionPool { } func TestNewConnectionPool(t *testing.T) { - connectionOptions := map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"} + connectionOptions := ConnectionParameters{socket: "/private/tmp/.s.PGSQL.5432", user: "pgx_none", database: "pgx_test"} pool, err := NewConnectionPool(connectionOptions, 5) if err != nil { t.Fatal("Unable to establish connection pool") diff --git a/connection_test.go b/connection_test.go index b32c7911..0e43ab7d 100644 --- a/connection_test.go +++ b/connection_test.go @@ -9,7 +9,7 @@ var SharedConnection *Connection func getSharedConnection() (c *Connection) { if SharedConnection == nil { var err error - SharedConnection, err = Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + SharedConnection, err = Connect(ConnectionParameters{socket: "/private/tmp/.s.PGSQL.5432", user: "pgx_none", database: "pgx_test"}) if err != nil { panic("Unable to establish connection") } @@ -19,7 +19,7 @@ func getSharedConnection() (c *Connection) { } func TestConnect(t *testing.T) { - conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_none", "database": "pgx_test"}) + conn, err := Connect(ConnectionParameters{socket: "/private/tmp/.s.PGSQL.5432", user: "pgx_none", database: "pgx_test"}) if err != nil { t.Fatal("Unable to establish connection") } @@ -54,7 +54,7 @@ func TestConnect(t *testing.T) { } func TestConnectWithInvalidUser(t *testing.T) { - _, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "invalid_user", "database": "pgx_test"}) + _, err := Connect(ConnectionParameters{socket: "/private/tmp/.s.PGSQL.5432", user: "invalid_user", database: "pgx_test"}) pgErr := err.(PgError) if pgErr.Code != "28000" { t.Fatal("Did not receive expected error when connecting with invalid user") @@ -62,7 +62,7 @@ func TestConnectWithInvalidUser(t *testing.T) { } func TestConnectWithPlainTextPassword(t *testing.T) { - conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_pw", "password": "secret", "database": "pgx_test"}) + conn, err := Connect(ConnectionParameters{socket: "/private/tmp/.s.PGSQL.5432", user: "pgx_pw", password: "secret", database: "pgx_test"}) if err != nil { t.Fatal("Unable to establish connection: " + err.Error()) } @@ -74,7 +74,7 @@ func TestConnectWithPlainTextPassword(t *testing.T) { } func TestConnectWithMD5Password(t *testing.T) { - conn, err := Connect(map[string]string{"socket": "/private/tmp/.s.PGSQL.5432", "user": "pgx_md5", "password": "secret", "database": "pgx_test"}) + conn, err := Connect(ConnectionParameters{socket: "/private/tmp/.s.PGSQL.5432", user: "pgx_md5", password: "secret", database: "pgx_test"}) if err != nil { t.Fatal("Unable to establish connection: " + err.Error()) }