diff --git a/.travis.yml b/.travis.yml index 6d4b3cd2..060fac8a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -11,6 +11,11 @@ before_install: env: global: - PGX_TEST_DATABASE=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_UNIX_SOCKET_CONN_STRING="host=/var/run/postgresql database=pgx_test" + - PGX_TEST_TCP_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_TLS_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test?sslmode=require + - PGX_TEST_MD5_PASSWORD_CONN_STRING=postgres://pgx_md5:secret@127.0.0.1/pgx_test + - PGX_TEST_PLAIN_PASSWORD_CONN_STRING=postgres://pgx_pw:secret@127.0.0.1/pgx_test matrix: - CRATEVERSION=2.1 - PGVERSION=10 diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 096e1354..bbe14438 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -14,66 +14,7 @@ import ( var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // To skip tests for specific connection / authentication types set that connection param to nil -var tcpConnConfig *pgx.ConnConfig = nil -var unixSocketConnConfig *pgx.ConnConfig = nil -var md5ConnConfig *pgx.ConnConfig = nil -var plainPasswordConnConfig *pgx.ConnConfig = nil -var invalidUserConnConfig *pgx.ConnConfig = nil -var tlsConnConfig *pgx.ConnConfig = nil -var customDialerConnConfig *pgx.ConnConfig = nil var replicationConnConfig *pgx.ConnConfig = nil var cratedbConnConfig *pgx.ConnConfig = nil -// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} -// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -// var plainPasswordConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} -// var invalidUserConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} -// var customDialerConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var replicationConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_replication", Password: "secret", Database: "pgx_test"} - -// var tlsConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} -// -//// or to test client certs: -// -// var tlsConnConfig *pgx.ConnConfig -// -// func init() { -// homeDir := build.Default.GOPATH -// tlsConnConfig = &pgx.ConnConfig{ -// Host: "127.0.0.1", -// User: "pgx_md5", -// Password: "secret", -// Database: "pgx_test", -// TLSConfig: &tls.Config{ -// InsecureSkipVerify: true, -// }, -// } -// caCertPool := x509.NewCertPool() -// -// caPath := path.Join(homeDir, "/src/github.com/jackc/pgx/rootCA.pem") -// caCert, err := ioutil.ReadFile(caPath) -// if err != nil { -// panic(fmt.Sprintf("unable to read CA file: %v", err)) -// } -// -// if !caCertPool.AppendCertsFromPEM(caCert) { -// panic("unable to add CA to cert pool") -// } -// -// tlsConnConfig.TLSConfig.RootCAs = caCertPool -// tlsConnConfig.TLSConfig.ClientCAs = caCertPool -// -// sslCert := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.crt") -// sslKey := path.Join(homeDir, "/src/github.com/jackc/pgx/pg_md5.key") -// if (sslCert != "" && sslKey == "") || (sslCert == "" && sslKey != "") { -// panic(`both "sslcert" and "sslkey" are required`) -// } -// -// cert, err := tls.LoadX509KeyPair(sslCert, sslKey) -// if err != nil { -// panic(fmt.Sprintf("unable to read cert: %v", err)) -// } -// -// tlsConnConfig.TLSConfig.Certificates = []tls.Certificate{cert} -// } diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index cf29a743..d67f6887 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -8,13 +8,6 @@ import ( ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} -var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} -var plainPasswordConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_pw", Password: "secret", Database: "pgx_test"} -var invalidUserConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "invalid", Database: "pgx_test"} -var tlsConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_ssl", Password: "secret", Database: "pgx_test", TLSConfig: &tls.Config{InsecureSkipVerify: true}} -var customDialerConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} var replicationConnConfig *pgx.ConnConfig = nil var cratedbConnConfig *pgx.ConnConfig = nil diff --git a/conn_test.go b/conn_test.go index 90da4a7d..4b4c8562 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,7 +3,6 @@ package pgx_test import ( "context" "fmt" - "net" "strconv" "strings" "sync" @@ -81,128 +80,6 @@ func TestConnect(t *testing.T) { } } -func TestConnectWithUnixSocketDirectory(t *testing.T) { - t.Parallel() - - // /.s.PGSQL.5432 - if unixSocketConnConfig == nil { - t.Skip("Skipping due to undefined unixSocketConnConfig") - } - - conn, err := pgx.ConnectConfig(context.Background(), unixSocketConnConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } -} - -func TestConnectWithTcp(t *testing.T) { - t.Parallel() - - if tcpConnConfig == nil { - t.Skip("Skipping due to undefined tcpConnConfig") - } - - conn, err := pgx.ConnectConfig(context.Background(), tcpConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } -} - -func TestConnectWithTLS(t *testing.T) { - t.Parallel() - - if tlsConnConfig == nil { - t.Skip("Skipping due to undefined tlsConnConfig") - } - - conn, err := pgx.ConnectConfig(context.Background(), tlsConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } -} - -func TestConnectWithInvalidUser(t *testing.T) { - t.Parallel() - - if invalidUserConnConfig == nil { - t.Skip("Skipping due to undefined invalidUserConnConfig") - } - - _, err := pgx.ConnectConfig(context.Background(), invalidUserConnConfig) - pgErr, ok := err.(pgx.PgError) - if !ok { - t.Fatalf("Expected to receive a PgError with code 28000, instead received: %v", err) - } - if pgErr.Code != "28000" && pgErr.Code != "28P01" { - t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) - } -} - -func TestConnectWithPlainTextPassword(t *testing.T) { - t.Parallel() - - if plainPasswordConnConfig == nil { - t.Skip("Skipping due to undefined plainPasswordConnConfig") - } - - conn, err := pgx.ConnectConfig(context.Background(), plainPasswordConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } -} - -func TestConnectWithMD5Password(t *testing.T) { - t.Parallel() - - if md5ConnConfig == nil { - t.Skip("Skipping due to undefined md5ConnConfig") - } - - conn, err := pgx.ConnectConfig(context.Background(), md5ConnConfig) - if err != nil { - t.Fatal("Unable to establish connection: " + err.Error()) - } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } -} - -func TestConnectWithConnectionRefused(t *testing.T) { - t.Parallel() - - // Presumably nothing is listening on 127.0.0.1:1 - bad := *defaultConnConfig - bad.Host = "127.0.0.1" - bad.Port = 1 - - _, err := pgx.ConnectConfig(context.Background(), &bad) - if err == nil { - t.Fatal("Expected error establishing connection to bad port") - } -} - func TestConnectWithPreferSimpleProtocol(t *testing.T) { t.Parallel() @@ -228,67 +105,6 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) { ensureConnValid(t, conn) } -func TestConnectCustomDialer(t *testing.T) { - t.Parallel() - - if customDialerConnConfig == nil { - t.Skip("Skipping due to undefined customDialerConnConfig") - } - - dialled := false - conf := *customDialerConnConfig - conf.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { - dialled = true - return net.Dial(network, address) - } - - conn, err := pgx.ConnectConfig(context.Background(), &conf) - if err != nil { - t.Fatalf("Unable to establish connection: %s", err) - } - if !dialled { - t.Fatal("Connect did not use custom dialer") - } - - err = conn.Close() - if err != nil { - t.Fatal("Unable to close connection") - } -} - -func TestConnectWithRuntimeParams(t *testing.T) { - t.Parallel() - - connConfig := *defaultConnConfig - connConfig.RuntimeParams = map[string]string{ - "application_name": "pgxtest", - "search_path": "myschema", - } - - conn, err := pgx.ConnectConfig(context.Background(), &connConfig) - if err != nil { - t.Fatalf("Unable to establish connection: %v", err) - } - defer conn.Close() - - var s string - err = conn.QueryRow("show application_name").Scan(&s) - if err != nil { - t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) - } - if s != "pgxtest" { - t.Errorf("Expected application_name to be %s, but it was %s", "pgxtest", s) - } - - err = conn.QueryRow("show search_path").Scan(&s) - if err != nil { - t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) - } - if s != "myschema" { - t.Errorf("Expected search_path to be %s, but it was %s", "myschema", s) - } -} - func TestExec(t *testing.T) { t.Parallel() diff --git a/pgconn/helper_test.go b/pgconn/helper_test.go new file mode 100644 index 00000000..e6a7c73b --- /dev/null +++ b/pgconn/helper_test.go @@ -0,0 +1,13 @@ +package pgconn_test + +import ( + "testing" + + "github.com/jackc/pgx/pgconn" + + "github.com/stretchr/testify/require" +) + +func closeConn(t testing.TB, conn *pgconn.PgConn) { + require.Nil(t, conn.Close()) +} diff --git a/pgconn/pgconn_test.go b/pgconn/pgconn_test.go index f165786e..9e16e925 100644 --- a/pgconn/pgconn_test.go +++ b/pgconn/pgconn_test.go @@ -2,15 +2,161 @@ package pgconn_test import ( "context" + "crypto/tls" + "net" "os" "testing" + "github.com/jackc/pgx" "github.com/jackc/pgx/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) +func TestConnect(t *testing.T) { + tests := []struct { + name string + env string + }{ + {"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"}, + {"TCP", "PGX_TEST_TCP_CONN_STRING"}, + {"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"}, + {"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + connString := os.Getenv(tt.env) + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", tt.env) + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(t, err) + + err = conn.Close() + require.Nil(t, err) + }) + } +} + +// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure +// connection. +func TestConnectTLS(t *testing.T) { + connString := os.Getenv("PGX_TEST_TLS_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING") + } + + conn, err := pgconn.Connect(context.Background(), connString) + require.Nil(t, err) + + if _, ok := conn.NetConn.(*tls.Conn); !ok { + t.Error("not a TLS connection") + } + + err = conn.Close() + require.Nil(t, err) +} + +func TestConnectInvalidUser(t *testing.T) { + connString := os.Getenv("PGX_TEST_TCP_CONN_STRING") + if connString == "" { + t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING") + } + + config, err := pgconn.ParseConfig(connString) + require.Nil(t, err) + + config.User = "pgxinvalidusertest" + + conn, err := pgconn.ConnectConfig(context.Background(), config) + if err == nil { + conn.Close() + t.Fatal("expected err but got none") + } + pgErr, ok := err.(pgx.PgError) + if !ok { + t.Fatalf("Expected to receive a PgError, instead received: %v", err) + } + if pgErr.Code != "28000" && pgErr.Code != "28P01" { + t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr) + } +} + +func TestConnectWithConnectionRefused(t *testing.T) { + t.Parallel() + + // Presumably nothing is listening on 127.0.0.1:1 + conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1") + if err == nil { + conn.Close() + t.Fatal("Expected error establishing connection to bad port") + } +} + +func TestConnectCustomDialer(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + dialed := false + config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) { + dialed = true + return net.Dial(network, address) + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + require.True(t, dialed) + conn.Close() +} + +func TestConnectWithRuntimeParams(t *testing.T) { + config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.Nil(t, err) + + config.RuntimeParams = map[string]string{ + "application_name": "pgxtest", + "search_path": "myschema", + } + + conn, err := pgconn.ConnectConfig(context.Background(), config) + require.Nil(t, err) + defer closeConn(t, conn) + + // TODO - refactor these selects once there are higher level query functions + + conn.SendExec("show application_name") + conn.SendExec("show search_path") + err = conn.Flush() + require.Nil(t, err) + + result := conn.GetResult() + require.NotNil(t, result) + + rowFound := result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "pgxtest", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) + + result = conn.GetResult() + require.NotNil(t, result) + + rowFound = result.NextRow() + assert.True(t, rowFound) + if rowFound { + assert.Equal(t, "myschema", string(result.Value(0))) + } + + _, err = result.Close() + assert.Nil(t, err) +} + func TestSimple(t *testing.T) { pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE")) require.Nil(t, err)