mirror of https://github.com/jackc/pgx.git
139 lines
5.0 KiB
Go
139 lines
5.0 KiB
Go
package pgx_test
|
|
|
|
import (
|
|
"context"
|
|
"os"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgxtest"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
// defaultConnTestRunner is the ConnTestRunner shared by tests in this package.
// It is populated by init in this file.
var defaultConnTestRunner pgxtest.ConnTestRunner
|
|
|
|
func init() {
|
|
defaultConnTestRunner = pgxtest.DefaultConnTestRunner()
|
|
defaultConnTestRunner.CreateConfig = func(ctx context.Context, t testing.TB) *pgx.ConnConfig {
|
|
config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
return config
|
|
}
|
|
}
|
|
|
|
func mustConnectString(t testing.TB, connString string) *pgx.Conn {
|
|
conn, err := pgx.Connect(context.Background(), connString)
|
|
if err != nil {
|
|
t.Fatalf("Unable to establish connection: %v", err)
|
|
}
|
|
return conn
|
|
}
|
|
|
|
func mustParseConfig(t testing.TB, connString string) *pgx.ConnConfig {
|
|
config, err := pgx.ParseConfig(connString)
|
|
require.Nil(t, err)
|
|
return config
|
|
}
|
|
|
|
func mustConnect(t testing.TB, config *pgx.ConnConfig) *pgx.Conn {
|
|
conn, err := pgx.ConnectConfig(context.Background(), config)
|
|
if err != nil {
|
|
t.Fatalf("Unable to establish connection: %v", err)
|
|
}
|
|
return conn
|
|
}
|
|
|
|
func closeConn(t testing.TB, conn *pgx.Conn) {
|
|
err := conn.Close(context.Background())
|
|
if err != nil {
|
|
t.Fatalf("conn.Close unexpectedly failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func mustExec(t testing.TB, conn *pgx.Conn, sql string, arguments ...any) (commandTag pgconn.CommandTag) {
|
|
var err error
|
|
if commandTag, err = conn.Exec(context.Background(), sql, arguments...); err != nil {
|
|
t.Fatalf("Exec unexpectedly failed with %v: %v", sql, err)
|
|
}
|
|
return
|
|
}
|
|
|
|
// Do a simple query to ensure the connection is still usable
|
|
func ensureConnValid(t testing.TB, conn *pgx.Conn) {
|
|
var sum, rowCount int32
|
|
|
|
rows, err := conn.Query(context.Background(), "select generate_series(1,$1)", 10)
|
|
if err != nil {
|
|
t.Fatalf("conn.Query failed: %v", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var n int32
|
|
rows.Scan(&n)
|
|
sum += n
|
|
rowCount++
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
t.Fatalf("conn.Query failed: %v", rows.Err())
|
|
}
|
|
|
|
if rowCount != 10 {
|
|
t.Error("Select called onDataRow wrong number of times")
|
|
}
|
|
if sum != 55 {
|
|
t.Error("Wrong values returned")
|
|
}
|
|
}
|
|
|
|
func assertConfigsEqual(t *testing.T, expected, actual *pgx.ConnConfig, testName string) {
|
|
if !assert.NotNil(t, expected) {
|
|
return
|
|
}
|
|
if !assert.NotNil(t, actual) {
|
|
return
|
|
}
|
|
|
|
assert.Equalf(t, expected.Tracer, actual.Tracer, "%s - Tracer", testName)
|
|
assert.Equalf(t, expected.ConnString(), actual.ConnString(), "%s - ConnString", testName)
|
|
assert.Equalf(t, expected.StatementCacheCapacity, actual.StatementCacheCapacity, "%s - StatementCacheCapacity", testName)
|
|
assert.Equalf(t, expected.DescriptionCacheCapacity, actual.DescriptionCacheCapacity, "%s - DescriptionCacheCapacity", testName)
|
|
assert.Equalf(t, expected.DefaultQueryExecMode, actual.DefaultQueryExecMode, "%s - DefaultQueryExecMode", testName)
|
|
assert.Equalf(t, expected.Host, actual.Host, "%s - Host", testName)
|
|
assert.Equalf(t, expected.Database, actual.Database, "%s - Database", testName)
|
|
assert.Equalf(t, expected.Port, actual.Port, "%s - Port", testName)
|
|
assert.Equalf(t, expected.User, actual.User, "%s - User", testName)
|
|
assert.Equalf(t, expected.Password, actual.Password, "%s - Password", testName)
|
|
assert.Equalf(t, expected.ConnectTimeout, actual.ConnectTimeout, "%s - ConnectTimeout", testName)
|
|
assert.Equalf(t, expected.RuntimeParams, actual.RuntimeParams, "%s - RuntimeParams", testName)
|
|
|
|
// Can't test function equality, so just test that they are set or not.
|
|
assert.Equalf(t, expected.ValidateConnect == nil, actual.ValidateConnect == nil, "%s - ValidateConnect", testName)
|
|
assert.Equalf(t, expected.AfterConnect == nil, actual.AfterConnect == nil, "%s - AfterConnect", testName)
|
|
|
|
if assert.Equalf(t, expected.TLSConfig == nil, actual.TLSConfig == nil, "%s - TLSConfig", testName) {
|
|
if expected.TLSConfig != nil {
|
|
assert.Equalf(t, expected.TLSConfig.InsecureSkipVerify, actual.TLSConfig.InsecureSkipVerify, "%s - TLSConfig InsecureSkipVerify", testName)
|
|
assert.Equalf(t, expected.TLSConfig.ServerName, actual.TLSConfig.ServerName, "%s - TLSConfig ServerName", testName)
|
|
}
|
|
}
|
|
|
|
if assert.Equalf(t, len(expected.Fallbacks), len(actual.Fallbacks), "%s - Fallbacks", testName) {
|
|
for i := range expected.Fallbacks {
|
|
assert.Equalf(t, expected.Fallbacks[i].Host, actual.Fallbacks[i].Host, "%s - Fallback %d - Host", testName, i)
|
|
assert.Equalf(t, expected.Fallbacks[i].Port, actual.Fallbacks[i].Port, "%s - Fallback %d - Port", testName, i)
|
|
|
|
if assert.Equalf(t, expected.Fallbacks[i].TLSConfig == nil, actual.Fallbacks[i].TLSConfig == nil, "%s - Fallback %d - TLSConfig", testName, i) {
|
|
if expected.Fallbacks[i].TLSConfig != nil {
|
|
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.InsecureSkipVerify, actual.Fallbacks[i].TLSConfig.InsecureSkipVerify, "%s - Fallback %d - TLSConfig InsecureSkipVerify", testName)
|
|
assert.Equalf(t, expected.Fallbacks[i].TLSConfig.ServerName, actual.Fallbacks[i].TLSConfig.ServerName, "%s - Fallback %d - TLSConfig ServerName", testName)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|