diff --git a/conn_test.go b/conn_test.go index e5e607b0..6527f112 100644 --- a/conn_test.go +++ b/conn_test.go @@ -98,11 +98,10 @@ func TestConnectWithPreferSimpleProtocol(t *testing.T) { } func TestConnectConfigRequiresConnConfigFromParseConfig(t *testing.T) { - t.Parallel() - config := &pgx.ConnConfig{} - - require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgx.ConnectConfig(context.Background(), config) }) + require.PanicsWithValue(t, "config must be created by ParseConfig", func() { + pgx.ConnectConfig(context.Background(), config) + }) } func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { @@ -140,131 +139,121 @@ func TestParseConfigExtractsStatementCacheOptions(t *testing.T) { func TestExec(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } - if results := mustExec(t, conn, "create temporary table foo(id integer primary key);"); string(results) != "CREATE TABLE" { - t.Error("Unexpected results from Exec") - } + // Accept parameters + if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } - // Accept parameters - if results := mustExec(t, conn, "insert into foo(id) values($1)", 1); string(results) != "INSERT 0 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } + if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" { + t.Error("Unexpected results from Exec") + } - if results := mustExec(t, conn, "drop table foo;"); string(results) != "DROP TABLE" { - t.Error("Unexpected results from Exec") - } + // Multiple statements can be executed -- last command tag is returned + if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" { + t.Error("Unexpected results from Exec") + } - // Multiple statements can be executed -- last command tag is returned - if results := mustExec(t, conn, "create temporary table foo(id serial primary key); drop table foo;"); string(results) != "DROP TABLE" { - t.Error("Unexpected results from Exec") - } + // Can execute longer SQL strings than sharedBufferSize + if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } - // Can execute longer SQL strings than sharedBufferSize - if results := mustExec(t, conn, strings.Repeat("select 42; ", 1000)); string(results) != "SELECT 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } - - // Exec no-op which does not return a command tag - if results := mustExec(t, conn, "--;"); string(results) != "" { - t.Errorf("Unexpected results from Exec: %v", results) - } + // Exec no-op which does not return a command tag + if results := mustExec(t, conn, "--;"); string(results) != "" { + t.Errorf("Unexpected results from Exec: %v", results) + } + }) } func TestExecFailure(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + if _, err := conn.Exec(context.Background(), "selct;"); err == nil { + t.Fatal("Expected SQL syntax error") + } - if _, err := conn.Exec(context.Background(), "selct;"); err == nil { - t.Fatal("Expected SQL syntax error") - } - - rows, _ := conn.Query(context.Background(), "select 1") - rows.Close() - if rows.Err() != nil { - t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) - } + rows, _ := conn.Query(context.Background(), "select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("Exec failure appears to have broken connection: %v", rows.Err()) + } + }) } func TestExecFailureWithArguments(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + _, err := conn.Exec(context.Background(), "selct $1;", 1) + if err == nil { + t.Fatal("Expected SQL syntax error") + } + assert.False(t, pgconn.SafeToRetry(err)) - _, err := conn.Exec(context.Background(), "selct $1;", 1) - if err == nil { - t.Fatal("Expected SQL syntax error") - } - assert.False(t, pgconn.SafeToRetry(err)) - - _, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2") - if err == nil { - t.Fatal("Expected pgx arguments count error", err) - } - assert.Equal(t, "expected 1 arguments, got 2", err.Error()) + _, err = conn.Exec(context.Background(), "select $1::varchar(1);", "1", "2") + require.Error(t, err) + }) } func TestExecContextWithoutCancelation(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);") - if err != nil { - t.Fatal(err) - } - if string(commandTag) != "CREATE TABLE" { - t.Fatalf("Unexpected results from Exec: %v", commandTag) - } - assert.False(t, pgconn.SafeToRetry(err)) + commandTag, err := conn.Exec(ctx, "create temporary table foo(id integer primary key);") + if err != nil { + t.Fatal(err) + } + if string(commandTag) != "CREATE TABLE" { + t.Fatalf("Unexpected results from Exec: %v", commandTag) + } + assert.False(t, pgconn.SafeToRetry(err)) + }) } func TestExecContextFailureWithoutCancelation(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - _, err := conn.Exec(ctx, "selct;") - if err == nil { - t.Fatal("Expected SQL syntax error") - } - assert.False(t, pgconn.SafeToRetry(err)) - - rows, _ := conn.Query(context.Background(), "select 1") - rows.Close() - if rows.Err() != nil { - t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) - } - assert.False(t, pgconn.SafeToRetry(err)) + _, err := conn.Exec(ctx, "selct;") + if err == nil { + t.Fatal("Expected SQL syntax error") + } + assert.False(t, pgconn.SafeToRetry(err)) + rows, _ := conn.Query(context.Background(), "select 1") + rows.Close() + if rows.Err() != nil { + t.Fatalf("ExecEx failure appears to have broken connection: %v", rows.Err()) + } + assert.False(t, pgconn.SafeToRetry(err)) + }) } func TestExecContextFailureWithoutCancelationWithArguments(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - _, err := conn.Exec(ctx, "selct $1;", 1) - if err == nil { - t.Fatal("Expected SQL syntax error") - } - assert.False(t, pgconn.SafeToRetry(err)) + _, err := conn.Exec(ctx, "selct $1;", 1) + if err == nil { + t.Fatal("Expected SQL syntax error") + } + assert.False(t, pgconn.SafeToRetry(err)) + }) } func TestExecFailureCloseBefore(t *testing.T) { @@ -278,38 +267,6 @@ func TestExecFailureCloseBefore(t *testing.T) { assert.True(t, pgconn.SafeToRetry(err)) } -func TestExecExtendedProtocol(t *testing.T) { - t.Parallel() - - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - commandTag, err := conn.Exec(ctx, "create temporary table foo(name varchar primary key);") - if err != nil { - t.Fatal(err) - } - if string(commandTag) != "CREATE TABLE" { - t.Fatalf("Unexpected results from Exec: %v", commandTag) - } - - commandTag, err = conn.Exec( - ctx, - "insert into foo(name) values($1);", - "bar", - ) - if err != nil { - t.Fatal(err) - } - if string(commandTag) != "INSERT 0 1" { - t.Fatalf("Unexpected results from ExecEx: %v", commandTag) - } - - ensureConnValid(t, conn) -} - func TestExecStatementCacheModes(t *testing.T) { t.Parallel() @@ -360,7 +317,7 @@ func TestExecStatementCacheModes(t *testing.T) { } } -func TestExecSimpleProtocol(t *testing.T) { +func TestExecPerQuerySimpleProtocol(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) @@ -692,33 +649,31 @@ func TestFatalTxError(t *testing.T) { func TestInsertBoolArray(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } - if results := mustExec(t, conn, "create temporary table foo(spice bool[]);"); string(results) != "CREATE TABLE" { - t.Error("Unexpected results from Exec") - } - - // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []bool{true, false, true}); string(results) != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + }) } func TestInsertTimestampArray(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" { + t.Error("Unexpected results from Exec") + } - if results := mustExec(t, conn, "create temporary table foo(spice timestamp[]);"); string(results) != "CREATE TABLE" { - t.Error("Unexpected results from Exec") - } - - // Accept parameters - if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" { - t.Errorf("Unexpected results from Exec: %v", results) - } + // Accept parameters + if results := mustExec(t, conn, "insert into foo(spice) values($1)", []time.Time{time.Unix(1419143667, 0), time.Unix(1419143672, 0)}); string(results) != "INSERT 0 1" { + t.Errorf("Unexpected results from Exec: %v", results) + } + }) } type testLog struct { @@ -833,71 +788,61 @@ func TestConnInitConnInfo(t *testing.T) { } func TestUnregisteredTypeUsableAsStringArgumentAndBaseResult(t *testing.T) { - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + var n uint64 + err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) + if err != nil { + t.Fatal(err) + } - var n uint64 - err := conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) - if err != nil { - t.Fatal(err) - } - - if n != 42 { - t.Fatalf("Expected n to be 42, but was %v", n) - } - - ensureConnValid(t, conn) + if n != 42 { + t.Fatalf("Expected n to be 42, but was %v", n) + } + }) } func TestDomainType(t *testing.T) { - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + var n uint64 - var err error - var n uint64 + // Domain type uint64 is a PostgreSQL domain of underlying type numeric. - // Domain type uint64 is a PostgreSQL domain of underlying type numeric. + err := conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) + require.NoError(t, err) - // Since it is not registered, pgx does not know how to encode Go uint64 argument. - err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) - if err == nil { - t.Fatal("expected error encoding uint64 into unregistered domain") - } + // A string can be used. But a string cannot be the result because the describe result from the PostgreSQL server gives + // the underlying type of numeric. + err = conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 42 { + t.Fatalf("Expected n to be 42, but was %v", n) + } - // A string can be used. But a string cannot be the result because the describe result from the PostgreSQL server gives - // the underlying type of numeric. - err = conn.QueryRow(context.Background(), "select $1::uint64", "42").Scan(&n) - if err != nil { - t.Fatal(err) - } - if n != 42 { - t.Fatalf("Expected n to be 42, but was %v", n) - } + var uint64OID uint32 + err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID) + if err != nil { + t.Fatalf("did not find uint64 OID, %v", err) + } + conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID}) - var uint64OID uint32 - err = conn.QueryRow(context.Background(), "select t.oid from pg_type t where t.typname='uint64';").Scan(&uint64OID) - if err != nil { - t.Fatalf("did not find uint64 OID, %v", err) - } - conn.ConnInfo().RegisterDataType(pgtype.DataType{Value: &pgtype.Numeric{}, Name: "uint64", OID: uint64OID}) + // String is still an acceptable argument after registration + err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 7 { + t.Fatalf("Expected n to be 7, but was %v", n) + } - // String is still an acceptable argument after registration - err = conn.QueryRow(context.Background(), "select $1::uint64", "7").Scan(&n) - if err != nil { - t.Fatal(err) - } - if n != 7 { - t.Fatalf("Expected n to be 7, but was %v", n) - } - - // But a uint64 is acceptable - err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) - if err != nil { - t.Fatal(err) - } - if n != 24 { - t.Fatalf("Expected n to be 24, but was %v", n) - } - - ensureConnValid(t, conn) + // But a uint64 is acceptable + err = conn.QueryRow(context.Background(), "select $1::uint64", uint64(24)).Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 24 { + t.Fatalf("Expected n to be 24, but was %v", n) + } + }) } diff --git a/extended_query_builder.go b/extended_query_builder.go index 26eca505..b6a85a9e 100644 --- a/extended_query_builder.go +++ b/extended_query_builder.go @@ -134,6 +134,28 @@ func (eqb *extendedQueryBuilder) encodeExtendedParamValue(ci *pgtype.ConnInfo, o return eqb.encodeExtendedParamValue(ci, oid, value) } + // There is no data type registered for the destination OID, but maybe there is data type registered for the arg + // type. If so use it's text encoder (if available). + if dt, ok := ci.DataTypeForValue(arg); ok { + value := dt.Value + if textEncoder, ok := value.(pgtype.TextEncoder); ok { + err := value.Set(arg) + if err != nil { + return nil, err + } + + buf, err = textEncoder.EncodeText(ci, eqb.paramValueBytes) + if err != nil { + return nil, err + } + if buf == nil { + return nil, nil + } + eqb.paramValueBytes = buf + return eqb.paramValueBytes[pos:], nil + } + } + if strippedArg, ok := stripNamedType(&refVal); ok { return eqb.encodeExtendedParamValue(ci, oid, strippedArg) } diff --git a/go.mod b/go.mod index 4be08952..17804c68 100644 --- a/go.mod +++ b/go.mod @@ -4,16 +4,18 @@ go 1.12 require ( github.com/cockroachdb/apd v1.1.0 + github.com/go-stack/stack v1.8.0 // indirect github.com/gofrs/uuid v3.2.0+incompatible github.com/jackc/pgconn v1.5.0 github.com/jackc/pgio v1.0.0 github.com/jackc/pgproto3/v2 v2.0.1 - github.com/jackc/pgtype v1.3.1-0.20200505182314-3b7c47a2a7da + github.com/jackc/pgtype v1.3.1-0.20200508211315-97bbe6ae20e2 github.com/jackc/puddle v1.1.1 github.com/rs/zerolog v1.15.0 github.com/shopspring/decimal v0.0.0-20200227202807-02e2044944cc github.com/sirupsen/logrus v1.4.2 github.com/stretchr/testify v1.5.1 + go.uber.org/multierr v1.5.0 // indirect go.uber.org/zap v1.10.0 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec diff --git a/go.sum b/go.sum index e30f7b46..6cdc2fe3 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -11,6 +12,7 @@ github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/gofrs/uuid v3.2.0+incompatible h1:y12jRkkFxsd7GpqdSZ+/KCs/fJbqpEXSGd4+jfEaewE= github.com/gofrs/uuid v3.2.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/jackc/chunkreader v1.0.0 h1:4s39bBR8ByfqH+DKm8rQA3E1LHZWB9XWcrz8fqaZbe0= github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= github.com/jackc/chunkreader/v2 v2.0.0 h1:DUwgMQuuPnS0rhMXenUtZpqZqrR/30NWY+qQvTpSvEs= @@ -53,6 +55,8 @@ github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrU github.com/jackc/pgtype v1.2.0/go.mod h1:5m2OfMh1wTK7x+Fk952IDmI4nw3nPrvtQdM0ZT4WpC0= github.com/jackc/pgtype v1.3.1-0.20200505182314-3b7c47a2a7da h1:ZbfsOjqJ1nHsryU03mdXZy6ZEsymYvihkXxN9tUx1YU= github.com/jackc/pgtype v1.3.1-0.20200505182314-3b7c47a2a7da/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= +github.com/jackc/pgtype v1.3.1-0.20200508211315-97bbe6ae20e2 h1:Y6cErz3hUojOwnjUEWoZPRCBQcB7avM9ntGiYkB0wJo= +github.com/jackc/pgtype v1.3.1-0.20200508211315-97bbe6ae20e2/go.mod h1:vaogEUkALtxZMCH411K+tKzNpwzCKU+AnPzBKZ+I+Po= github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= @@ -65,6 +69,7 @@ github.com/jackc/puddle v1.1.0 h1:musOWczZC/rSbqut475Vfcczg7jJsdUQf0D6oKPLgNU= github.com/jackc/puddle v1.1.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= github.com/jackc/puddle v1.1.1 h1:PJAw7H/9hoWC4Kf3J8iNmL1SwA6E8vfsLqBiL+F6CtI= github.com/jackc/puddle v1.1.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/konsorten/go-windows-terminal-sequences v1.0.1 h1:mweAR1A6xJ3oS2pRaGiHgQ4OO8tzTaLawm8vnODuwDk= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= @@ -97,6 +102,7 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.13.0 h1:hSNcYHyxDWycfePW7pUI8swuFkcSMPKh3E63Pokg1Hk= github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= @@ -127,8 +133,13 @@ go.uber.org/atomic v1.3.2 h1:2Oa65PReHzfn29GpvgsYwloV9AVFHPDk8tYxt2c2tr4= go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU= go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.6.0 h1:Ezj3JGmsOnG1MoRWQkPBsKLe9DwWD9QeXzTRzzldNVk= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI= go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.5.0 h1:KCa4XfM8CWFCpxXRGok+Q0SS/0XBhMDbHHGABQLvD2A= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= go.uber.org/zap v1.9.1 h1:XCJQEf3W6eZaVwhRBof6ImoYGJSITeKWsyeh3HFu/5o= go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM= @@ -136,11 +147,14 @@ go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a h1:Igim7XhdOpBnWPuYJ70XcNpq8q3BCACtVgNfoJxOV7g= golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586 h1:7KByu05hhLed2MO29w7p1XfZvZ13m8mub3shuVftRs0= golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59 h1:3zb4D3T4G8jdExgVU/95+vQXfpEPiMdCaZgmGVxjNHM= golang.org/x/crypto v0.0.0-20200323165209-0ec3e9974c59/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -163,8 +177,12 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373 h1:PPwnA7z1Pjf7XYaBP9GL1VAMZmcIWyFz7QCMSIIa3Bg= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= @@ -176,7 +194,9 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec h1:RlWgLqCMMIYYEVcAR5MDsuHlVkaIPDAF+5Dehzg8L5A= gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/helper_test.go b/helper_test.go index 6b357577..fde4cbfa 100644 --- a/helper_test.go +++ b/helper_test.go @@ -2,6 +2,7 @@ package pgx_test import ( "context" + "os" "testing" "github.com/jackc/pgconn" @@ -9,6 +10,45 @@ import ( "github.com/stretchr/testify/require" ) +func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, conn *pgx.Conn)) { + t.Run("SimpleProto", + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.PreferSimpleProtocol = true + conn, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer func() { + err := conn.Close(context.Background()) + require.NoError(t, err) + }() + + f(t, conn) + + ensureConnValid(t, conn) + }, + ) + + t.Run("DefaultProto", + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + conn, err := pgx.ConnectConfig(context.Background(), config) + require.NoError(t, err) + defer func() { + err := conn.Close(context.Background()) + require.NoError(t, err) + }() + + f(t, conn) + + ensureConnValid(t, conn) + }, + ) +} + func mustConnectString(t testing.TB, connString string) *pgx.Conn { conn, err := pgx.Connect(context.Background(), connString) if err != nil { diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 8acce14e..f17560f7 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -30,6 +30,43 @@ func closeDB(t testing.TB, db *sql.DB) { require.NoError(t, err) } +func testWithAndWithoutPreferSimpleProtocol(t *testing.T, f func(t *testing.T, db *sql.DB)) { + t.Run("SimpleProto", + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + config.PreferSimpleProtocol = true + db := stdlib.OpenDB(*config) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + f(t, db) + + ensureDBValid(t, db) + }, + ) + + t.Run("DefaultProto", + func(t *testing.T) { + config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) + require.NoError(t, err) + + db := stdlib.OpenDB(*config) + defer func() { + err := db.Close() + require.NoError(t, err) + }() + + f(t, db) + + ensureDBValid(t, db) + }, + ) +} + // Do a simple query to ensure the DB is still usable. This is of less use in stdlib as the connection pool should // cover an broken connections. func ensureDBValid(t testing.TB, db *sql.DB) { @@ -182,172 +219,151 @@ func TestQueryCloseRowsEarly(t *testing.T) { } func TestConnExec(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec("create temporary table t(a varchar not null)") + require.NoError(t, err) - _, err := db.Exec("create temporary table t(a varchar not null)") - require.NoError(t, err) + result, err := db.Exec("insert into t values('hey')") + require.NoError(t, err) - result, err := db.Exec("insert into t values('hey')") - require.NoError(t, err) - - n, err := result.RowsAffected() - require.NoError(t, err) - require.EqualValues(t, 1, n) - - ensureDBValid(t, db) + n, err := result.RowsAffected() + require.NoError(t, err) + require.EqualValues(t, 1, n) + }) } func TestConnQuery(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) - require.NoError(t, err) - - rowCount := int64(0) - - for rows.Next() { - rowCount++ - - var s string - var n int64 - err := rows.Scan(&s, &n) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) require.NoError(t, err) - if s != "foo" { - t.Errorf(`Expected "foo", received "%v"`, s) - } - if n != rowCount { - t.Errorf("Expected %d, received %d", rowCount, n) - } - } - require.NoError(t, rows.Err()) - require.EqualValues(t, 10, rowCount) - err = rows.Close() - require.NoError(t, err) + rowCount := int64(0) - ensureDBValid(t, db) + for rows.Next() { + rowCount++ + + var s string + var n int64 + err := rows.Scan(&s, &n) + require.NoError(t, err) + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + require.NoError(t, rows.Err()) + require.EqualValues(t, 10, rowCount) + + err = rows.Close() + require.NoError(t, err) + }) } func TestConnQueryNull(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - rows, err := db.Query("select $1::int", nil) - require.NoError(t, err) - - rowCount := int64(0) - - for rows.Next() { - rowCount++ - - var n sql.NullInt64 - err := rows.Scan(&n) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + rows, err := db.Query("select $1::int", nil) require.NoError(t, err) - if n.Valid != false { - t.Errorf("Expected n to be null, but it was %v", n) + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + + var n sql.NullInt64 + err := rows.Scan(&n) + require.NoError(t, err) + if n.Valid != false { + t.Errorf("Expected n to be null, but it was %v", n) + } } - } - require.NoError(t, rows.Err()) - require.EqualValues(t, 1, rowCount) + require.NoError(t, rows.Err()) + require.EqualValues(t, 1, rowCount) - err = rows.Close() - require.NoError(t, err) - - ensureDBValid(t, db) + err = rows.Close() + require.NoError(t, err) + }) } func TestConnQueryRowByteSlice(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + expected := []byte{222, 173, 190, 239} + var actual []byte - expected := []byte{222, 173, 190, 239} - var actual []byte - - err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - ensureDBValid(t, db) + err := db.QueryRow(`select E'\\xdeadbeef'::bytea`).Scan(&actual) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + }) } func TestConnQueryFailure(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - _, err := db.Query("select 'foo") - require.Error(t, err) - require.IsType(t, new(pgconn.PgError), err) - - ensureDBValid(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + _, err := db.Query("select 'foo") + require.Error(t, err) + require.IsType(t, new(pgconn.PgError), err) + }) } // Test type that pgx would handle natively in binary, but since it is not a // database/sql native type should be passed through as a string func TestConnQueryRowPgxBinary(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + sql := "select $1::int4[]" + expected := "{1,2,3}" + var actual string - sql := "select $1::int4[]" - expected := "{1,2,3}" - var actual string - - err := db.QueryRow(sql, expected).Scan(&actual) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - ensureDBValid(t, db) + err := db.QueryRow(sql, expected).Scan(&actual) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + }) } func TestConnQueryRowUnknownType(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + sql := "select $1::point" + expected := "(1,2)" + var actual string - sql := "select $1::point" - expected := "(1,2)" - var actual string - - err := db.QueryRow(sql, expected).Scan(&actual) - require.NoError(t, err) - require.EqualValues(t, expected, actual) - - ensureDBValid(t, db) + err := db.QueryRow(sql, expected).Scan(&actual) + require.NoError(t, err) + require.EqualValues(t, expected, actual) + }) } func TestConnQueryJSONIntoByteSlice(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - _, err := db.Exec(` + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec(` create temporary table docs( body json not null ); insert into docs(body) values('{"foo":"bar"}'); `) - require.NoError(t, err) + require.NoError(t, err) - sql := `select * from docs` - expected := []byte(`{"foo":"bar"}`) - var actual []byte + sql := `select * from docs` + expected := []byte(`{"foo":"bar"}`) + var actual []byte - err = db.QueryRow(sql).Scan(&actual) - if err != nil { - t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) - } + err = db.QueryRow(sql).Scan(&actual) + if err != nil { + t.Errorf("Unexpected failure: %v (sql -> %v)", err, sql) + } - if bytes.Compare(actual, expected) != 0 { - t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql) - } + if bytes.Compare(actual, expected) != 0 { + t.Errorf(`Expected "%v", got "%v" (sql -> %v)`, string(expected), string(actual), sql) + } - _, err = db.Exec(`drop table docs`) - require.NoError(t, err) - - ensureDBValid(t, db) + _, err = db.Exec(`drop table docs`) + require.NoError(t, err) + }) } func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { + // Not testing with simple protocol because there is no way for that to work. A []byte will be considered binary data + // that needs to escape. No way to know whether the destination is really a text compatible or a bytea. + db := openDB(t) defer closeDB(t, db) @@ -373,329 +389,294 @@ func TestConnExecInsertByteSliceIntoJSON(t *testing.T) { _, err = db.Exec(`drop table docs`) require.NoError(t, err) - - ensureDBValid(t, db) } func TestTransactionLifeCycle(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec("create temporary table t(a varchar not null)") + require.NoError(t, err) - _, err := db.Exec("create temporary table t(a varchar not null)") - require.NoError(t, err) + tx, err := db.Begin() + require.NoError(t, err) - tx, err := db.Begin() - require.NoError(t, err) + _, err = tx.Exec("insert into t values('hi')") + require.NoError(t, err) - _, err = tx.Exec("insert into t values('hi')") - require.NoError(t, err) + err = tx.Rollback() + require.NoError(t, err) - err = tx.Rollback() - require.NoError(t, err) + var n int64 + err = db.QueryRow("select count(*) from t").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 0, n) - var n int64 - err = db.QueryRow("select count(*) from t").Scan(&n) - require.NoError(t, err) - require.EqualValues(t, 0, n) + tx, err = db.Begin() + require.NoError(t, err) - tx, err = db.Begin() - require.NoError(t, err) + _, err = tx.Exec("insert into t values('hi')") + require.NoError(t, err) - _, err = tx.Exec("insert into t values('hi')") - require.NoError(t, err) + err = tx.Commit() + require.NoError(t, err) - err = tx.Commit() - require.NoError(t, err) - - err = db.QueryRow("select count(*) from t").Scan(&n) - require.NoError(t, err) - require.EqualValues(t, 1, n) - - ensureDBValid(t, db) + err = db.QueryRow("select count(*) from t").Scan(&n) + require.NoError(t, err) + require.EqualValues(t, 1, n) + }) } func TestConnBeginTxIsolation(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + var defaultIsoLevel string + err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) + require.NoError(t, err) - var defaultIsoLevel string - err := db.QueryRow("show transaction_isolation").Scan(&defaultIsoLevel) - require.NoError(t, err) - - supportedTests := []struct { - sqlIso sql.IsolationLevel - pgIso string - }{ - {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, - {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, - {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, - {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, - {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, - {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, - } - for i, tt := range supportedTests { - func() { - tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) - if err != nil { - t.Errorf("%d. BeginTx failed: %v", i, err) - return - } - defer tx.Rollback() - - var pgIso string - err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) - if err != nil { - t.Errorf("%d. QueryRow failed: %v", i, err) - } - - if pgIso != tt.pgIso { - t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) - } - }() - } - - unsupportedTests := []struct { - sqlIso sql.IsolationLevel - }{ - {sqlIso: sql.LevelWriteCommitted}, - {sqlIso: sql.LevelLinearizable}, - } - for i, tt := range unsupportedTests { - tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) - if err == nil { - t.Errorf("%d. BeginTx should have failed", i) - tx.Rollback() + supportedTests := []struct { + sqlIso sql.IsolationLevel + pgIso string + }{ + {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, + {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, + {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, + {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, + {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, } - } + for i, tt := range supportedTests { + func() { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err != nil { + t.Errorf("%d. BeginTx failed: %v", i, err) + return + } + defer tx.Rollback() - ensureDBValid(t, db) + var pgIso string + err = tx.QueryRow("show transaction_isolation").Scan(&pgIso) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + + if pgIso != tt.pgIso { + t.Errorf("%d. pgIso => %s, want %s", i, pgIso, tt.pgIso) + } + }() + } + + unsupportedTests := []struct { + sqlIso sql.IsolationLevel + }{ + {sqlIso: sql.LevelWriteCommitted}, + {sqlIso: sql.LevelLinearizable}, + } + for i, tt := range unsupportedTests { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{Isolation: tt.sqlIso}) + if err == nil { + t.Errorf("%d. BeginTx should have failed", i) + tx.Rollback() + } + } + }) } func TestConnBeginTxReadOnly(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + require.NoError(t, err) + defer tx.Rollback() - tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) - require.NoError(t, err) - defer tx.Rollback() + var pgReadOnly string + err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) + if err != nil { + t.Errorf("QueryRow failed: %v", err) + } - var pgReadOnly string - err = tx.QueryRow("show transaction_read_only").Scan(&pgReadOnly) - if err != nil { - t.Errorf("QueryRow failed: %v", err) - } - - if pgReadOnly != "on" { - t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") - } - - ensureDBValid(t, db) + if pgReadOnly != "on" { + t.Errorf("pgReadOnly => %s, want %s", pgReadOnly, "on") + } + }) } func TestBeginTxContextCancel(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + _, err := db.Exec("drop table if exists t") + require.NoError(t, err) - _, err := db.Exec("drop table if exists t") - require.NoError(t, err) + ctx, cancelFn := context.WithCancel(context.Background()) - ctx, cancelFn := context.WithCancel(context.Background()) + tx, err := db.BeginTx(ctx, nil) + require.NoError(t, err) - tx, err := db.BeginTx(ctx, nil) - require.NoError(t, err) + _, err = tx.Exec("create table t(id serial)") + require.NoError(t, err) - _, err = tx.Exec("create table t(id serial)") - require.NoError(t, err) + cancelFn() - cancelFn() + err = tx.Commit() + if err != context.Canceled && err != sql.ErrTxDone { + t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) + } - err = tx.Commit() - if err != context.Canceled && err != sql.ErrTxDone { - t.Fatalf("err => %v, want %v or %v", err, context.Canceled, sql.ErrTxDone) - } - - var n int - err = db.QueryRow("select count(*) from t").Scan(&n) - if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" { - t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) - } - - ensureDBValid(t, db) + var n int + err = db.QueryRow("select count(*) from t").Scan(&n) + if pgErr, ok := err.(*pgconn.PgError); !ok || pgErr.Code != "42P01" { + t.Fatalf(`err => %v, want PgError{Code: "42P01"}`, err) + } + }) } func TestAcquireConn(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + var conns []*pgx.Conn - var conns []*pgx.Conn + for i := 1; i < 6; i++ { + conn, err := stdlib.AcquireConn(db) + if err != nil { + t.Errorf("%d. AcquireConn failed: %v", i, err) + continue + } - for i := 1; i < 6; i++ { - conn, err := stdlib.AcquireConn(db) - if err != nil { - t.Errorf("%d. AcquireConn failed: %v", i, err) - continue + var n int32 + err = conn.QueryRow(context.Background(), "select 1").Scan(&n) + if err != nil { + t.Errorf("%d. QueryRow failed: %v", i, err) + } + if n != 1 { + t.Errorf("%d. n => %d, want %d", i, n, 1) + } + + stats := db.Stats() + if stats.OpenConnections != i { + t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) + } + + conns = append(conns, conn) } - var n int32 - err = conn.QueryRow(context.Background(), "select 1").Scan(&n) - if err != nil { - t.Errorf("%d. QueryRow failed: %v", i, err) + for i, conn := range conns { + if err := stdlib.ReleaseConn(db, conn); err != nil { + t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) + } } - if n != 1 { - t.Errorf("%d. n => %d, want %d", i, n, 1) - } - - stats := db.Stats() - if stats.OpenConnections != i { - t.Errorf("%d. stats.OpenConnections => %d, want %d", i, stats.OpenConnections, i) - } - - conns = append(conns, conn) - } - - for i, conn := range conns { - if err := stdlib.ReleaseConn(db, conn); err != nil { - t.Errorf("%d. stdlib.ReleaseConn failed: %v", i, err) - } - } - - ensureDBValid(t, db) + }) } // https://github.com/jackc/pgx/issues/673 func TestReleaseConnWithTxInProgress(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + c1, err := stdlib.AcquireConn(db) + require.NoError(t, err) - c1, err := stdlib.AcquireConn(db) - require.NoError(t, err) + _, err = c1.Exec(context.Background(), "begin") + require.NoError(t, err) - _, err = c1.Exec(context.Background(), "begin") - require.NoError(t, err) + c1PID := c1.PgConn().PID() - c1PID := c1.PgConn().PID() + err = stdlib.ReleaseConn(db, c1) + require.NoError(t, err) - err = stdlib.ReleaseConn(db, c1) - require.NoError(t, err) + c2, err := stdlib.AcquireConn(db) + require.NoError(t, err) - c2, err := stdlib.AcquireConn(db) - require.NoError(t, err) + c2PID := c2.PgConn().PID() - c2PID := c2.PgConn().PID() + err = stdlib.ReleaseConn(db, c2) + require.NoError(t, err) - err = stdlib.ReleaseConn(db, c2) - require.NoError(t, err) + require.NotEqual(t, c1PID, c2PID) - require.NotEqual(t, c1PID, c2PID) - - // Releasing a conn with a tx in progress should close the connection - stats := db.Stats() - require.Equal(t, 1, stats.OpenConnections) - - ensureDBValid(t, db) + // Releasing a conn with a tx in progress should close the connection + stats := db.Stats() + require.Equal(t, 1, stats.OpenConnections) + }) } func TestConnPingContextSuccess(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - err := db.PingContext(context.Background()) - require.NoError(t, err) - - ensureDBValid(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + err := db.PingContext(context.Background()) + require.NoError(t, err) + }) } func TestConnPrepareContextSuccess(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - stmt, err := db.PrepareContext(context.Background(), "select now()") - require.NoError(t, err) - stmt.Close() - - ensureDBValid(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + stmt, err := db.PrepareContext(context.Background(), "select now()") + require.NoError(t, err) + err = stmt.Close() + require.NoError(t, err) + }) } func TestConnExecContextSuccess(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") - require.NoError(t, err) - - ensureDBValid(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") + require.NoError(t, err) + }) } func TestConnExecContextFailureRetry(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - // we get a connection, immediately close it, and then get it back - { - conn, err := stdlib.AcquireConn(db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + // we get a connection, immediately close it, and then get it back + { + conn, err := stdlib.AcquireConn(db) + require.NoError(t, err) + conn.Close(context.Background()) + stdlib.ReleaseConn(db, conn) + } + conn, err := db.Conn(context.Background()) require.NoError(t, err) - conn.Close(context.Background()) - stdlib.ReleaseConn(db, conn) - } - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - _, err = conn.ExecContext(context.Background(), "select 1") - require.EqualValues(t, driver.ErrBadConn, err) + _, err = conn.ExecContext(context.Background(), "select 1") + require.EqualValues(t, driver.ErrBadConn, err) + }) } func TestConnQueryContextSuccess(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") - require.NoError(t, err) - - for rows.Next() { - var n int64 - err := rows.Scan(&n) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") require.NoError(t, err) - } - require.NoError(t, rows.Err()) - ensureDBValid(t, db) + for rows.Next() { + var n int64 + err := rows.Scan(&n) + require.NoError(t, err) + } + require.NoError(t, rows.Err()) + }) } func TestConnQueryContextFailureRetry(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) - - // we get a connection, immediately close it, and then get it back - { - conn, err := stdlib.AcquireConn(db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + // we get a connection, immediately close it, and then get it back + { + conn, err := stdlib.AcquireConn(db) + require.NoError(t, err) + conn.Close(context.Background()) + stdlib.ReleaseConn(db, conn) + } + conn, err := db.Conn(context.Background()) require.NoError(t, err) - conn.Close(context.Background()) - stdlib.ReleaseConn(db, conn) - } - conn, err := db.Conn(context.Background()) - require.NoError(t, err) - _, err = conn.QueryContext(context.Background(), "select 1") - require.EqualValues(t, driver.ErrBadConn, err) + _, err = conn.QueryContext(context.Background(), "select 1") + require.EqualValues(t, driver.ErrBadConn, err) + }) } func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + rows, err := db.Query("select * from generate_series(1,10) n") + require.NoError(t, err) - rows, err := db.Query("select * from generate_series(1,10) n") - require.NoError(t, err) + columnTypes, err := rows.ColumnTypes() + require.NoError(t, err) + require.Len(t, columnTypes, 1) - columnTypes, err := rows.ColumnTypes() - require.NoError(t, err) - require.Len(t, columnTypes, 1) + if columnTypes[0].DatabaseTypeName() != "INT4" { + t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4") + } - if columnTypes[0].DatabaseTypeName() != "INT4" { - t.Errorf("columnTypes[0].DatabaseTypeName() => %v, want %v", columnTypes[0].DatabaseTypeName(), "INT4") - } - - rows.Close() - - ensureDBValid(t, db) + err = rows.Close() + require.NoError(t, err) + }) } func TestStmtExecContextSuccess(t *testing.T) { @@ -763,191 +744,180 @@ func TestStmtQueryContextSuccess(t *testing.T) { } func TestRowsColumnTypes(t *testing.T) { - columnTypesTests := []struct { - Name string - TypeName string - Length struct { - Len int64 - OK bool - } - DecimalSize struct { - Precision int64 - Scale int64 - OK bool - } - ScanType reflect.Type - }{ - { - Name: "a", - TypeName: "INT4", - Length: struct { + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + columnTypesTests := []struct { + Name string + TypeName string + Length struct { Len int64 OK bool - }{ - Len: 0, - OK: false, - }, - DecimalSize: struct { + } + DecimalSize struct { Precision int64 Scale int64 OK bool - }{ - Precision: 0, - Scale: 0, - OK: false, + } + ScanType reflect.Type + }{ + { + Name: "a", + TypeName: "INT4", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(int32(0)), + }, { + Name: "bar", + TypeName: "TEXT", + Length: struct { + Len int64 + OK bool + }{ + Len: math.MaxInt64, + OK: true, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(""), + }, { + Name: "dec", + TypeName: "NUMERIC", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 9, + Scale: 2, + OK: true, + }, + ScanType: reflect.TypeOf(float64(0)), }, - ScanType: reflect.TypeOf(int32(0)), - }, { - Name: "bar", - TypeName: "TEXT", - Length: struct { - Len int64 - OK bool - }{ - Len: math.MaxInt64, - OK: true, - }, - DecimalSize: struct { - Precision int64 - Scale int64 - OK bool - }{ - Precision: 0, - Scale: 0, - OK: false, - }, - ScanType: reflect.TypeOf(""), - }, { - Name: "dec", - TypeName: "NUMERIC", - Length: struct { - Len int64 - OK bool - }{ - Len: 0, - OK: false, - }, - DecimalSize: struct { - Precision int64 - Scale int64 - OK bool - }{ - Precision: 9, - Scale: 2, - OK: true, - }, - ScanType: reflect.TypeOf(float64(0)), - }, - } - - db := openDB(t) - defer closeDB(t, db) - - rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") - require.NoError(t, err) - - columns, err := rows.ColumnTypes() - require.NoError(t, err) - if len(columns) != 3 { - t.Errorf("expected 3 columns found %d", len(columns)) - } - - for i, tt := range columnTypesTests { - c := columns[i] - if c.Name() != tt.Name { - t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) } - if c.DatabaseTypeName() != tt.TypeName { - t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) - } - l, ok := c.Length() - if l != tt.Length.Len { - t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) - } - if ok != tt.Length.OK { - t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) - } - p, s, ok := c.DecimalSize() - if p != tt.DecimalSize.Precision { - t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) - } - if s != tt.DecimalSize.Scale { - t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) - } - if ok != tt.DecimalSize.OK { - t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) - } - if c.ScanType() != tt.ScanType { - t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) - } - } -} -func TestSimpleQueryLifeCycle(t *testing.T) { - config, err := pgx.ParseConfig(os.Getenv("PGX_TEST_DATABASE")) - require.NoError(t, err) - config.PreferSimpleProtocol = true - - db := stdlib.OpenDB(*config) - defer closeDB(t, db) - - rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) - require.NoError(t, err) - - rowCount := int64(0) - - for rows.Next() { - rowCount++ - var ( - s string - n int64 - ) - - err := rows.Scan(&s, &n) + rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") require.NoError(t, err) - if s != "foo" { - t.Errorf(`Expected "foo", received "%v"`, s) + columns, err := rows.ColumnTypes() + require.NoError(t, err) + if len(columns) != 3 { + t.Errorf("expected 3 columns found %d", len(columns)) } - if n != rowCount { - t.Errorf("Expected %d, received %d", rowCount, n) + for i, tt := range columnTypesTests { + c := columns[i] + if c.Name() != tt.Name { + t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) + } + if c.DatabaseTypeName() != tt.TypeName { + t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) + } + l, ok := c.Length() + if l != tt.Length.Len { + t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) + } + if ok != tt.Length.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) + } + p, s, ok := c.DecimalSize() + if p != tt.DecimalSize.Precision { + t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) + } + if s != tt.DecimalSize.Scale { + t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) + } + if ok != tt.DecimalSize.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) + } + if c.ScanType() != tt.ScanType { + t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) + } } - } - require.NoError(t, rows.Err()) + }) +} - err = rows.Close() - require.NoError(t, err) +func TestQueryLifeCycle(t *testing.T) { + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + rows, err := db.Query("SELECT 'foo', n FROM generate_series($1::int, $2::int) n WHERE 3 = $3", 1, 10, 3) + require.NoError(t, err) - rows, err = db.Query("select 1 where false") - require.NoError(t, err) + rowCount := int64(0) - rowCount = int64(0) + for rows.Next() { + rowCount++ + var ( + s string + n int64 + ) - for rows.Next() { - rowCount++ - } - require.NoError(t, rows.Err()) - require.EqualValues(t, 0, rowCount) + err := rows.Scan(&s, &n) + require.NoError(t, err) - err = rows.Close() - require.NoError(t, err) + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } - ensureDBValid(t, db) + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + require.NoError(t, rows.Err()) + + err = rows.Close() + require.NoError(t, err) + + rows, err = db.Query("select 1 where false") + require.NoError(t, err) + + rowCount = int64(0) + + for rows.Next() { + rowCount++ + } + require.NoError(t, rows.Err()) + require.EqualValues(t, 0, rowCount) + + err = rows.Close() + require.NoError(t, err) + }) } // https://github.com/jackc/pgx/issues/409 func TestScanJSONIntoJSONRawMessage(t *testing.T) { - db := openDB(t) - defer closeDB(t, db) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, db *sql.DB) { + var msg json.RawMessage - var msg json.RawMessage - - err := db.QueryRow("select '{}'::json").Scan(&msg) - require.NoError(t, err) - require.EqualValues(t, []byte("{}"), []byte(msg)) - - ensureDBValid(t, db) + err := db.QueryRow("select '{}'::json").Scan(&msg) + require.NoError(t, err) + require.EqualValues(t, []byte("{}"), []byte(msg)) + }) } type testLog struct { diff --git a/values.go b/values.go index afb1c46f..355510b6 100644 --- a/values.go +++ b/values.go @@ -30,22 +30,9 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return nil, nil } - setArgAndEncodeTextToString := func(t interface { - pgtype.Value - pgtype.TextEncoder - }) (interface{}, error) { - err := t.Set(arg) - if err != nil { - return nil, err - } - buf, err := t.EncodeText(ci, nil) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), nil + refVal := reflect.ValueOf(arg) + if refVal.Kind() == reflect.Ptr && refVal.IsNil() { + return nil, nil } switch arg := arg.(type) { @@ -124,36 +111,25 @@ func convertSimpleArgument(ci *pgtype.ConnInfo, arg interface{}) (interface{}, e return nil, errors.Errorf("arg too big for int64: %v", arg) } return int64(arg), nil - case []string: - return setArgAndEncodeTextToString(&pgtype.TextArray{}) - case []float32: - return setArgAndEncodeTextToString(&pgtype.Float4Array{}) - case []float64: - return setArgAndEncodeTextToString(&pgtype.Float8Array{}) - case []int16: - return setArgAndEncodeTextToString(&pgtype.Int8Array{}) - case []int32: - return setArgAndEncodeTextToString(&pgtype.Int8Array{}) - case []int64: - return setArgAndEncodeTextToString(&pgtype.Int8Array{}) - case []int: - return setArgAndEncodeTextToString(&pgtype.Int8Array{}) - case []uint16: - return setArgAndEncodeTextToString(&pgtype.Int8Array{}) - case []uint32: - return setArgAndEncodeTextToString(&pgtype.Int8Array{}) - case []uint64: - return setArgAndEncodeTextToString(&pgtype.Int8Array{}) - case []uint: - return setArgAndEncodeTextToString(&pgtype.Int8Array{}) } - refVal := reflect.ValueOf(arg) - - if refVal.Kind() == reflect.Ptr { - if refVal.IsNil() { + if dt, found := ci.DataTypeForValue(arg); found { + v := dt.Value + err := v.Set(arg) + if err != nil { + return nil, err + } + buf, err := v.(pgtype.TextEncoder).EncodeText(ci, nil) + if err != nil { + return nil, err + } + if buf == nil { return nil, nil } + return string(buf), nil + } + + if refVal.Kind() == reflect.Ptr { arg = refVal.Elem().Interface() return convertSimpleArgument(ci, arg) } diff --git a/values_test.go b/values_test.go index 282045d9..f15457fa 100644 --- a/values_test.go +++ b/values_test.go @@ -10,65 +10,65 @@ import ( "time" "github.com/jackc/pgx/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestDateTranscode(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - dates := []time.Time{ - time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), - time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), - time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), - time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), - time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), - time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), - time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), - time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), - } - - for _, actualDate := range dates { - var d time.Time - - err := conn.QueryRow(context.Background(), "select $1::date", actualDate).Scan(&d) - if err != nil { - t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + dates := []time.Time{ + time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC), + time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC), + time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC), + time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC), + time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC), + time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC), + time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC), } - if !actualDate.Equal(d) { - t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) + + for _, actualDate := range dates { + var d time.Time + + err := conn.QueryRow(context.Background(), "select $1::date", actualDate).Scan(&d) + if err != nil { + t.Fatalf("Unexpected failure on QueryRow Scan: %v", err) + } + if !actualDate.Equal(d) { + t.Errorf("Did not transcode date successfully: %v is not %v", d, actualDate) + } } - } + }) } func TestTimestampTzTranscode(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) - inputTime := time.Date(2013, 1, 2, 3, 4, 5, 6000, time.Local) + var outputTime time.Time - var outputTime time.Time - - err := conn.QueryRow(context.Background(), "select $1::timestamptz", inputTime).Scan(&outputTime) - if err != nil { - t.Fatalf("QueryRow Scan failed: %v", err) - } - if !inputTime.Equal(outputTime) { - t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) - } + err := conn.QueryRow(context.Background(), "select $1::timestamptz", inputTime).Scan(&outputTime) + if err != nil { + t.Fatalf("QueryRow Scan failed: %v", err) + } + if !inputTime.Equal(outputTime) { + t.Errorf("Did not transcode time successfully: %v is not %v", outputTime, inputTime) + } + }) } // TODO - move these tests to pgtype @@ -76,6 +76,21 @@ func TestTimestampTzTranscode(t *testing.T) { func TestJSONAndJSONBTranscode(t *testing.T) { t.Parallel() + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + for _, typename := range []string{"json", "jsonb"} { + if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { + continue // No JSON/JSONB type -- must be running against old PostgreSQL + } + + testJSONString(t, conn, typename) + testJSONStringPointer(t, conn, typename) + } + }) +} + +func TestJSONAndJSONBTranscodeExtendedOnly(t *testing.T) { + t.Parallel() + conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) defer closeConn(t, conn) @@ -83,9 +98,6 @@ func TestJSONAndJSONBTranscode(t *testing.T) { if _, ok := conn.ConnInfo().DataTypeForName(typename); !ok { continue // No JSON/JSONB type -- must be running against old PostgreSQL } - - testJSONString(t, conn, typename) - testJSONStringPointer(t, conn, typename) testJSONSingleLevelStringMap(t, conn, typename) testJSONNestedMap(t, conn, typename) testJSONStringArray(t, conn, typename) @@ -93,6 +105,7 @@ func TestJSONAndJSONBTranscode(t *testing.T) { testJSONInt16ArrayFailureDueToOverflow(t, conn, typename) testJSONStruct(t, conn, typename) } + } func testJSONString(t *testing.T, conn *pgx.Conn, typename string) { @@ -231,683 +244,670 @@ func mustParseCIDR(t *testing.T, s string) *net.IPNet { func TestStringToNotTextTypeTranscode(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + input := "01086ee0-4963-4e35-9116-30c173a8d0bd" - input := "01086ee0-4963-4e35-9116-30c173a8d0bd" + var output string + err := conn.QueryRow(context.Background(), "select $1::uuid", input).Scan(&output) + if err != nil { + t.Fatal(err) + } + if input != output { + t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) + } - var output string - err := conn.QueryRow(context.Background(), "select $1::uuid", input).Scan(&output) - if err != nil { - t.Fatal(err) - } - if input != output { - t.Errorf("uuid: Did not transcode string successfully: %s is not %s", input, output) - } - - err = conn.QueryRow(context.Background(), "select $1::uuid", &input).Scan(&output) - if err != nil { - t.Fatal(err) - } - if input != output { - t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) - } + err = conn.QueryRow(context.Background(), "select $1::uuid", &input).Scan(&output) + if err != nil { + t.Fatal(err) + } + if input != output { + t.Errorf("uuid: Did not transcode pointer to string successfully: %s is not %s", input, output) + } + }) } func TestInetCIDRTranscodeIPNet(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - tests := []struct { - sql string - value *net.IPNet - }{ - {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, - {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, - {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, - {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, - {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, - {"select $1::inet", mustParseCIDR(t, "::/128")}, - {"select $1::inet", mustParseCIDR(t, "::/0")}, - {"select $1::inet", mustParseCIDR(t, "::1/128")}, - {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, - {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, - {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, - {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, - {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, - {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, - {"select $1::cidr", mustParseCIDR(t, "::/128")}, - {"select $1::cidr", mustParseCIDR(t, "::/0")}, - {"select $1::cidr", mustParseCIDR(t, "::1/128")}, - {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, - } - - for i, tt := range tests { - var actual net.IPNet - - err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + tests := []struct { + sql string + value *net.IPNet + }{ + {"select $1::inet", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::inet", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::inet", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::inet", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::inet", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::inet", mustParseCIDR(t, "::/128")}, + {"select $1::inet", mustParseCIDR(t, "::/0")}, + {"select $1::inet", mustParseCIDR(t, "::1/128")}, + {"select $1::inet", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, + {"select $1::cidr", mustParseCIDR(t, "0.0.0.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "127.0.0.1/32")}, + {"select $1::cidr", mustParseCIDR(t, "12.34.56.0/32")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "255.0.0.0/8")}, + {"select $1::cidr", mustParseCIDR(t, "255.255.255.255/32")}, + {"select $1::cidr", mustParseCIDR(t, "::/128")}, + {"select $1::cidr", mustParseCIDR(t, "::/0")}, + {"select $1::cidr", mustParseCIDR(t, "::1/128")}, + {"select $1::cidr", mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128")}, } - if actual.String() != tt.value.String() { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) - } + for i, tt := range tests { + var actual net.IPNet - ensureConnValid(t, conn) - } + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if actual.String() != tt.value.String() { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + } + }) } func TestInetCIDRTranscodeIP(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - tests := []struct { - sql string - value net.IP - }{ - {"select $1::inet", net.ParseIP("0.0.0.0")}, - {"select $1::inet", net.ParseIP("127.0.0.1")}, - {"select $1::inet", net.ParseIP("12.34.56.0")}, - {"select $1::inet", net.ParseIP("255.255.255.255")}, - {"select $1::inet", net.ParseIP("::1")}, - {"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")}, - {"select $1::cidr", net.ParseIP("0.0.0.0")}, - {"select $1::cidr", net.ParseIP("127.0.0.1")}, - {"select $1::cidr", net.ParseIP("12.34.56.0")}, - {"select $1::cidr", net.ParseIP("255.255.255.255")}, - {"select $1::cidr", net.ParseIP("::1")}, - {"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")}, - } - - for i, tt := range tests { - var actual net.IP - - err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + tests := []struct { + sql string + value net.IP + }{ + {"select $1::inet", net.ParseIP("0.0.0.0")}, + {"select $1::inet", net.ParseIP("127.0.0.1")}, + {"select $1::inet", net.ParseIP("12.34.56.0")}, + {"select $1::inet", net.ParseIP("255.255.255.255")}, + {"select $1::inet", net.ParseIP("::1")}, + {"select $1::inet", net.ParseIP("2607:f8b0:4009:80b::200e")}, + {"select $1::cidr", net.ParseIP("0.0.0.0")}, + {"select $1::cidr", net.ParseIP("127.0.0.1")}, + {"select $1::cidr", net.ParseIP("12.34.56.0")}, + {"select $1::cidr", net.ParseIP("255.255.255.255")}, + {"select $1::cidr", net.ParseIP("::1")}, + {"select $1::cidr", net.ParseIP("2607:f8b0:4009:80b::200e")}, } - if !actual.Equal(tt.value) { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + for i, tt := range tests { + var actual net.IP + + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !actual.Equal(tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) } - ensureConnValid(t, conn) - } - - failTests := []struct { - sql string - value *net.IPNet - }{ - {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, - {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, - } - for i, tt := range failTests { - var actual net.IP - - err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) - if err == nil { - t.Errorf("%d. Expected failure but got none", i) - continue + failTests := []struct { + sql string + value *net.IPNet + }{ + {"select $1::inet", mustParseCIDR(t, "192.168.1.0/24")}, + {"select $1::cidr", mustParseCIDR(t, "192.168.1.0/24")}, } + for i, tt := range failTests { + var actual net.IP - ensureConnValid(t, conn) - } + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) + continue + } + + ensureConnValid(t, conn) + } + }) } func TestInetCIDRArrayTranscodeIPNet(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - tests := []struct { - sql string - value []*net.IPNet - }{ - { - "select $1::inet[]", - []*net.IPNet{ - mustParseCIDR(t, "0.0.0.0/32"), - mustParseCIDR(t, "127.0.0.1/32"), - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), - mustParseCIDR(t, "255.0.0.0/8"), - mustParseCIDR(t, "255.255.255.255/32"), - mustParseCIDR(t, "::/128"), - mustParseCIDR(t, "::/0"), - mustParseCIDR(t, "::1/128"), - mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + tests := []struct { + sql string + value []*net.IPNet + }{ + { + "select $1::inet[]", + []*net.IPNet{ + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + }, }, - }, - { - "select $1::cidr[]", - []*net.IPNet{ - mustParseCIDR(t, "0.0.0.0/32"), - mustParseCIDR(t, "127.0.0.1/32"), - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), - mustParseCIDR(t, "255.0.0.0/8"), - mustParseCIDR(t, "255.255.255.255/32"), - mustParseCIDR(t, "::/128"), - mustParseCIDR(t, "::/0"), - mustParseCIDR(t, "::1/128"), - mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + { + "select $1::cidr[]", + []*net.IPNet{ + mustParseCIDR(t, "0.0.0.0/32"), + mustParseCIDR(t, "127.0.0.1/32"), + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + mustParseCIDR(t, "255.0.0.0/8"), + mustParseCIDR(t, "255.255.255.255/32"), + mustParseCIDR(t, "::/128"), + mustParseCIDR(t, "::/0"), + mustParseCIDR(t, "::1/128"), + mustParseCIDR(t, "2607:f8b0:4009:80b::200e/128"), + }, }, - }, - } - - for i, tt := range tests { - var actual []*net.IPNet - - err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue } - if !reflect.DeepEqual(actual, tt.value) { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) - } + for i, tt := range tests { + var actual []*net.IPNet - ensureConnValid(t, conn) - } + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if !reflect.DeepEqual(actual, tt.value) { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } + }) } func TestInetCIDRArrayTranscodeIP(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - tests := []struct { - sql string - value []net.IP - }{ - { - "select $1::inet[]", - []net.IP{ - net.ParseIP("0.0.0.0"), - net.ParseIP("127.0.0.1"), - net.ParseIP("12.34.56.0"), - net.ParseIP("255.255.255.255"), - net.ParseIP("2607:f8b0:4009:80b::200e"), + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + tests := []struct { + sql string + value []net.IP + }{ + { + "select $1::inet[]", + []net.IP{ + net.ParseIP("0.0.0.0"), + net.ParseIP("127.0.0.1"), + net.ParseIP("12.34.56.0"), + net.ParseIP("255.255.255.255"), + net.ParseIP("2607:f8b0:4009:80b::200e"), + }, }, - }, - { - "select $1::cidr[]", - []net.IP{ - net.ParseIP("0.0.0.0"), - net.ParseIP("127.0.0.1"), - net.ParseIP("12.34.56.0"), - net.ParseIP("255.255.255.255"), - net.ParseIP("2607:f8b0:4009:80b::200e"), + { + "select $1::cidr[]", + []net.IP{ + net.ParseIP("0.0.0.0"), + net.ParseIP("127.0.0.1"), + net.ParseIP("12.34.56.0"), + net.ParseIP("255.255.255.255"), + net.ParseIP("2607:f8b0:4009:80b::200e"), + }, }, - }, - } - - for i, tt := range tests { - var actual []net.IP - - err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue } - if !reflect.DeepEqual(actual, tt.value) { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + for i, tt := range tests { + var actual []net.IP + + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + assert.Equal(t, len(tt.value), len(actual), "%d", i) + for j := range actual { + assert.True(t, actual[j].Equal(tt.value[j]), "%d", i) + } + + ensureConnValid(t, conn) } - ensureConnValid(t, conn) - } - - failTests := []struct { - sql string - value []*net.IPNet - }{ - { - "select $1::inet[]", - []*net.IPNet{ - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), + failTests := []struct { + sql string + value []*net.IPNet + }{ + { + "select $1::inet[]", + []*net.IPNet{ + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + }, }, - }, - { - "select $1::cidr[]", - []*net.IPNet{ - mustParseCIDR(t, "12.34.56.0/32"), - mustParseCIDR(t, "192.168.1.0/24"), + { + "select $1::cidr[]", + []*net.IPNet{ + mustParseCIDR(t, "12.34.56.0/32"), + mustParseCIDR(t, "192.168.1.0/24"), + }, }, - }, - } - - for i, tt := range failTests { - var actual []net.IP - - err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) - if err == nil { - t.Errorf("%d. Expected failure but got none", i) - continue } - ensureConnValid(t, conn) - } + for i, tt := range failTests { + var actual []net.IP + + err := conn.QueryRow(context.Background(), tt.sql, tt.value).Scan(&actual) + if err == nil { + t.Errorf("%d. Expected failure but got none", i) + continue + } + + ensureConnValid(t, conn) + } + }) } func TestInetCIDRTranscodeWithJustIP(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - tests := []struct { - sql string - value string - }{ - {"select $1::inet", "0.0.0.0/32"}, - {"select $1::inet", "127.0.0.1/32"}, - {"select $1::inet", "12.34.56.0/32"}, - {"select $1::inet", "255.255.255.255/32"}, - {"select $1::inet", "::/128"}, - {"select $1::inet", "2607:f8b0:4009:80b::200e/128"}, - {"select $1::cidr", "0.0.0.0/32"}, - {"select $1::cidr", "127.0.0.1/32"}, - {"select $1::cidr", "12.34.56.0/32"}, - {"select $1::cidr", "255.255.255.255/32"}, - {"select $1::cidr", "::/128"}, - {"select $1::cidr", "2607:f8b0:4009:80b::200e/128"}, - } - - for i, tt := range tests { - expected := mustParseCIDR(t, tt.value) - var actual net.IPNet - - err := conn.QueryRow(context.Background(), tt.sql, expected.IP).Scan(&actual) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) - continue + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + tests := []struct { + sql string + value string + }{ + {"select $1::inet", "0.0.0.0/32"}, + {"select $1::inet", "127.0.0.1/32"}, + {"select $1::inet", "12.34.56.0/32"}, + {"select $1::inet", "255.255.255.255/32"}, + {"select $1::inet", "::/128"}, + {"select $1::inet", "2607:f8b0:4009:80b::200e/128"}, + {"select $1::cidr", "0.0.0.0/32"}, + {"select $1::cidr", "127.0.0.1/32"}, + {"select $1::cidr", "12.34.56.0/32"}, + {"select $1::cidr", "255.255.255.255/32"}, + {"select $1::cidr", "::/128"}, + {"select $1::cidr", "2607:f8b0:4009:80b::200e/128"}, } - if actual.String() != expected.String() { - t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) - } + for i, tt := range tests { + expected := mustParseCIDR(t, tt.value) + var actual net.IPNet - ensureConnValid(t, conn) - } + err := conn.QueryRow(context.Background(), tt.sql, expected.IP).Scan(&actual) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, value -> %v)", i, err, tt.sql, tt.value) + continue + } + + if actual.String() != expected.String() { + t.Errorf("%d. Expected %v, got %v (sql -> %v)", i, tt.value, actual, tt.sql) + } + + ensureConnValid(t, conn) + } + }) } func TestArrayDecoding(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - tests := []struct { - sql string - query interface{} - scan interface{} - assert func(*testing.T, interface{}, interface{}) - }{ - { - "select $1::bool[]", []bool{true, false, true}, &[]bool{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]bool))) { - t.Errorf("failed to encode bool[]") - } - }, - }, - { - "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]int16))) { - t.Errorf("failed to encode smallint[]") - } - }, - }, - { - "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { - t.Errorf("failed to encode smallint[]") - } - }, - }, - { - "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]int32))) { - t.Errorf("failed to encode int[]") - } - }, - }, - { - "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { - t.Errorf("failed to encode int[]") - } - }, - }, - { - "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]int64))) { - t.Errorf("failed to encode bigint[]") - } - }, - }, - { - "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { - t.Errorf("failed to encode bigint[]") - } - }, - }, - { - "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]string))) { - t.Errorf("failed to encode text[]") - } - }, - }, - { - "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, - func(t *testing.T, query, scan interface{}) { - if !reflect.DeepEqual(query, *(scan.(*[]time.Time))) { - t.Errorf("failed to encode time.Time[] to timestamptz[]") - } - }, - }, - { - "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, - func(t *testing.T, query, scan interface{}) { - queryBytesSliceSlice := query.([][]byte) - scanBytesSliceSlice := *(scan.(*[][]byte)) - if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { - t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice)) - } - for i := range queryBytesSliceSlice { - qb := queryBytesSliceSlice[i] - sb := scanBytesSliceSlice[i] - if !bytes.Equal(qb, sb) { - t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + tests := []struct { + sql string + query interface{} + scan interface{} + assert func(*testing.T, interface{}, interface{}) + }{ + { + "select $1::bool[]", []bool{true, false, true}, &[]bool{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]bool))) { + t.Errorf("failed to encode bool[]") } - } + }, + }, + { + "select $1::smallint[]", []int16{2, 4, 484, 32767}, &[]int16{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]int16))) { + t.Errorf("failed to encode smallint[]") + } + }, + }, + { + "select $1::smallint[]", []uint16{2, 4, 484, 32767}, &[]uint16{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]uint16))) { + t.Errorf("failed to encode smallint[]") + } + }, + }, + { + "select $1::int[]", []int32{2, 4, 484}, &[]int32{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]int32))) { + t.Errorf("failed to encode int[]") + } + }, + }, + { + "select $1::int[]", []uint32{2, 4, 484, 2147483647}, &[]uint32{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]uint32))) { + t.Errorf("failed to encode int[]") + } + }, + }, + { + "select $1::bigint[]", []int64{2, 4, 484, 9223372036854775807}, &[]int64{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]int64))) { + t.Errorf("failed to encode bigint[]") + } + }, + }, + { + "select $1::bigint[]", []uint64{2, 4, 484, 9223372036854775807}, &[]uint64{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]uint64))) { + t.Errorf("failed to encode bigint[]") + } + }, + }, + { + "select $1::text[]", []string{"it's", "over", "9000!"}, &[]string{}, + func(t *testing.T, query, scan interface{}) { + if !reflect.DeepEqual(query, *(scan.(*[]string))) { + t.Errorf("failed to encode text[]") + } + }, + }, + { + "select $1::timestamptz[]", []time.Time{time.Unix(323232, 0), time.Unix(3239949334, 00)}, &[]time.Time{}, + func(t *testing.T, query, scan interface{}) { + queryTimeSlice := query.([]time.Time) + scanTimeSlice := *(scan.(*[]time.Time)) + require.Equal(t, len(queryTimeSlice), len(scanTimeSlice)) + for i := range queryTimeSlice { + assert.Truef(t, queryTimeSlice[i].Equal(scanTimeSlice[i]), "%d", i) + } + }, + }, + { + "select $1::bytea[]", [][]byte{{0, 1, 2, 3}, {4, 5, 6, 7}}, &[][]byte{}, + func(t *testing.T, query, scan interface{}) { + queryBytesSliceSlice := query.([][]byte) + scanBytesSliceSlice := *(scan.(*[][]byte)) + if len(queryBytesSliceSlice) != len(scanBytesSliceSlice) { + t.Errorf("failed to encode byte[][] to bytea[]: expected %d to equal %d", len(queryBytesSliceSlice), len(scanBytesSliceSlice)) + } + for i := range queryBytesSliceSlice { + qb := queryBytesSliceSlice[i] + sb := scanBytesSliceSlice[i] + if !bytes.Equal(qb, sb) { + t.Errorf("failed to encode byte[][] to bytea[]: expected %v to equal %v", qb, sb) + } + } + }, }, - }, - } - - for i, tt := range tests { - err := conn.QueryRow(context.Background(), tt.sql, tt.query).Scan(tt.scan) - if err != nil { - t.Errorf(`%d. error reading array: %v`, i, err) - continue } - tt.assert(t, tt.query, tt.scan) - ensureConnValid(t, conn) - } + + for i, tt := range tests { + err := conn.QueryRow(context.Background(), tt.sql, tt.query).Scan(tt.scan) + if err != nil { + t.Errorf(`%d. error reading array: %v`, i, err) + continue + } + tt.assert(t, tt.query, tt.scan) + ensureConnValid(t, conn) + } + }) } func TestEmptyArrayDecoding(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + var val []string - var val []string - - err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val) - if err != nil { - t.Errorf(`error reading array: %v`, err) - } - if len(val) != 0 { - t.Errorf("Expected 0 values, got %d", len(val)) - } - - var n, m int32 - - err = conn.QueryRow(context.Background(), "select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m) - if err != nil { - t.Errorf(`error reading array: %v`, err) - } - if len(val) != 0 { - t.Errorf("Expected 0 values, got %d", len(val)) - } - if n != 1 { - t.Errorf("Expected n to be 1, but it was %d", n) - } - if m != 42 { - t.Errorf("Expected n to be 42, but it was %d", n) - } - - rows, err := conn.Query(context.Background(), "select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]") - if err != nil { - t.Errorf(`error retrieving rows with array: %v`, err) - } - defer rows.Close() - - for rows.Next() { - err = rows.Scan(&n, &val) + err := conn.QueryRow(context.Background(), "select array[]::text[]").Scan(&val) if err != nil { t.Errorf(`error reading array: %v`, err) } - } + if len(val) != 0 { + t.Errorf("Expected 0 values, got %d", len(val)) + } - ensureConnValid(t, conn) + var n, m int32 + + err = conn.QueryRow(context.Background(), "select 1::integer, array[]::text[], 42::integer").Scan(&n, &val, &m) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + if len(val) != 0 { + t.Errorf("Expected 0 values, got %d", len(val)) + } + if n != 1 { + t.Errorf("Expected n to be 1, but it was %d", n) + } + if m != 42 { + t.Errorf("Expected n to be 42, but it was %d", n) + } + + rows, err := conn.Query(context.Background(), "select 1::integer, array['test']::text[] union select 2::integer, array[]::text[] union select 3::integer, array['test']::text[]") + if err != nil { + t.Errorf(`error retrieving rows with array: %v`, err) + } + defer rows.Close() + + for rows.Next() { + err = rows.Scan(&n, &val) + if err != nil { + t.Errorf(`error reading array: %v`, err) + } + } + }) } func TestPointerPointer(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) - - type allTypes struct { - s *string - i16 *int16 - i32 *int32 - i64 *int64 - f32 *float32 - f64 *float64 - b *bool - t *time.Time - } - - var actual, zero, expected allTypes - - { - s := "foo" - expected.s = &s - i16 := int16(1) - expected.i16 = &i16 - i32 := int32(1) - expected.i32 = &i32 - i64 := int64(1) - expected.i64 = &i64 - f32 := float32(1.23) - expected.f32 = &f32 - f64 := float64(1.23) - expected.f64 = &f64 - b := true - expected.b = &b - t := time.Unix(123, 5000) - expected.t = &t - } - - tests := []struct { - sql string - queryArgs []interface{} - scanArgs []interface{} - expected allTypes - }{ - {"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, - {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, - {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, - {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, - {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, - {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, - {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, - {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, - {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, - {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, - {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, - {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, - {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, - {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, - {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, - {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, - } - - for i, tt := range tests { - actual = zero - - err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) - if err != nil { - t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + type allTypes struct { + s *string + i16 *int16 + i32 *int32 + i64 *int64 + f32 *float32 + f64 *float64 + b *bool + t *time.Time } - if !reflect.DeepEqual(actual, tt.expected) { - t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) + var actual, zero, expected allTypes + + { + s := "foo" + expected.s = &s + i16 := int16(1) + expected.i16 = &i16 + i32 := int32(1) + expected.i32 = &i32 + i64 := int64(1) + expected.i64 = &i64 + f32 := float32(1.23) + expected.f32 = &f32 + f64 := float64(1.23) + expected.f64 = &f64 + b := true + expected.b = &b + t := time.Unix(123, 5000) + expected.t = &t } - ensureConnValid(t, conn) - } + tests := []struct { + sql string + queryArgs []interface{} + scanArgs []interface{} + expected allTypes + }{ + {"select $1::text", []interface{}{expected.s}, []interface{}{&actual.s}, allTypes{s: expected.s}}, + {"select $1::text", []interface{}{zero.s}, []interface{}{&actual.s}, allTypes{}}, + {"select $1::int2", []interface{}{expected.i16}, []interface{}{&actual.i16}, allTypes{i16: expected.i16}}, + {"select $1::int2", []interface{}{zero.i16}, []interface{}{&actual.i16}, allTypes{}}, + {"select $1::int4", []interface{}{expected.i32}, []interface{}{&actual.i32}, allTypes{i32: expected.i32}}, + {"select $1::int4", []interface{}{zero.i32}, []interface{}{&actual.i32}, allTypes{}}, + {"select $1::int8", []interface{}{expected.i64}, []interface{}{&actual.i64}, allTypes{i64: expected.i64}}, + {"select $1::int8", []interface{}{zero.i64}, []interface{}{&actual.i64}, allTypes{}}, + {"select $1::float4", []interface{}{expected.f32}, []interface{}{&actual.f32}, allTypes{f32: expected.f32}}, + {"select $1::float4", []interface{}{zero.f32}, []interface{}{&actual.f32}, allTypes{}}, + {"select $1::float8", []interface{}{expected.f64}, []interface{}{&actual.f64}, allTypes{f64: expected.f64}}, + {"select $1::float8", []interface{}{zero.f64}, []interface{}{&actual.f64}, allTypes{}}, + {"select $1::bool", []interface{}{expected.b}, []interface{}{&actual.b}, allTypes{b: expected.b}}, + {"select $1::bool", []interface{}{zero.b}, []interface{}{&actual.b}, allTypes{}}, + {"select $1::timestamptz", []interface{}{expected.t}, []interface{}{&actual.t}, allTypes{t: expected.t}}, + {"select $1::timestamptz", []interface{}{zero.t}, []interface{}{&actual.t}, allTypes{}}, + } + + for i, tt := range tests { + actual = zero + + err := conn.QueryRow(context.Background(), tt.sql, tt.queryArgs...).Scan(tt.scanArgs...) + if err != nil { + t.Errorf("%d. Unexpected failure: %v (sql -> %v, queryArgs -> %v)", i, err, tt.sql, tt.queryArgs) + } + + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("%d. Expected %v, got %v (sql -> %v, queryArgs -> %v)", i, tt.expected, actual, tt.sql, tt.queryArgs) + } + + ensureConnValid(t, conn) + } + }) } func TestPointerPointerNonZero(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + f := "foo" + dest := &f - f := "foo" - dest := &f - - err := conn.QueryRow(context.Background(), "select $1::text", nil).Scan(&dest) - if err != nil { - t.Errorf("Unexpected failure scanning: %v", err) - } - if dest != nil { - t.Errorf("Expected dest to be nil, got %#v", dest) - } + err := conn.QueryRow(context.Background(), "select $1::text", nil).Scan(&dest) + if err != nil { + t.Errorf("Unexpected failure scanning: %v", err) + } + if dest != nil { + t.Errorf("Expected dest to be nil, got %#v", dest) + } + }) } func TestEncodeTypeRename(t *testing.T) { t.Parallel() - conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE")) - defer closeConn(t, conn) + testWithAndWithoutPreferSimpleProtocol(t, func(t *testing.T, conn *pgx.Conn) { + type _int int + inInt := _int(1) + var outInt _int - type _int int - inInt := _int(1) - var outInt _int + type _int8 int8 + inInt8 := _int8(2) + var outInt8 _int8 - type _int8 int8 - inInt8 := _int8(2) - var outInt8 _int8 + type _int16 int16 + inInt16 := _int16(3) + var outInt16 _int16 - type _int16 int16 - inInt16 := _int16(3) - var outInt16 _int16 + type _int32 int32 + inInt32 := _int32(4) + var outInt32 _int32 - type _int32 int32 - inInt32 := _int32(4) - var outInt32 _int32 + type _int64 int64 + inInt64 := _int64(5) + var outInt64 _int64 - type _int64 int64 - inInt64 := _int64(5) - var outInt64 _int64 + type _uint uint + inUint := _uint(6) + var outUint _uint - type _uint uint - inUint := _uint(6) - var outUint _uint + type _uint8 uint8 + inUint8 := _uint8(7) + var outUint8 _uint8 - type _uint8 uint8 - inUint8 := _uint8(7) - var outUint8 _uint8 + type _uint16 uint16 + inUint16 := _uint16(8) + var outUint16 _uint16 - type _uint16 uint16 - inUint16 := _uint16(8) - var outUint16 _uint16 + type _uint32 uint32 + inUint32 := _uint32(9) + var outUint32 _uint32 - type _uint32 uint32 - inUint32 := _uint32(9) - var outUint32 _uint32 + type _uint64 uint64 + inUint64 := _uint64(10) + var outUint64 _uint64 - type _uint64 uint64 - inUint64 := _uint64(10) - var outUint64 _uint64 + type _string string + inString := _string("foo") + var outString _string - type _string string - inString := _string("foo") - var outString _string + err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", + inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, + ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) + if err != nil { + t.Fatalf("Failed with type rename: %v", err) + } - err := conn.QueryRow(context.Background(), "select $1::int, $2::int, $3::int2, $4::int4, $5::int8, $6::int, $7::int, $8::int, $9::int, $10::int, $11::text", - inInt, inInt8, inInt16, inInt32, inInt64, inUint, inUint8, inUint16, inUint32, inUint64, inString, - ).Scan(&outInt, &outInt8, &outInt16, &outInt32, &outInt64, &outUint, &outUint8, &outUint16, &outUint32, &outUint64, &outString) - if err != nil { - t.Fatalf("Failed with type rename: %v", err) - } + if inInt != outInt { + t.Errorf("int rename: expected %v, got %v", inInt, outInt) + } - if inInt != outInt { - t.Errorf("int rename: expected %v, got %v", inInt, outInt) - } + if inInt8 != outInt8 { + t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8) + } - if inInt8 != outInt8 { - t.Errorf("int8 rename: expected %v, got %v", inInt8, outInt8) - } + if inInt16 != outInt16 { + t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16) + } - if inInt16 != outInt16 { - t.Errorf("int16 rename: expected %v, got %v", inInt16, outInt16) - } + if inInt32 != outInt32 { + t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32) + } - if inInt32 != outInt32 { - t.Errorf("int32 rename: expected %v, got %v", inInt32, outInt32) - } + if inInt64 != outInt64 { + t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64) + } - if inInt64 != outInt64 { - t.Errorf("int64 rename: expected %v, got %v", inInt64, outInt64) - } + if inUint != outUint { + t.Errorf("uint rename: expected %v, got %v", inUint, outUint) + } - if inUint != outUint { - t.Errorf("uint rename: expected %v, got %v", inUint, outUint) - } + if inUint8 != outUint8 { + t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8) + } - if inUint8 != outUint8 { - t.Errorf("uint8 rename: expected %v, got %v", inUint8, outUint8) - } + if inUint16 != outUint16 { + t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16) + } - if inUint16 != outUint16 { - t.Errorf("uint16 rename: expected %v, got %v", inUint16, outUint16) - } + if inUint32 != outUint32 { + t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32) + } - if inUint32 != outUint32 { - t.Errorf("uint32 rename: expected %v, got %v", inUint32, outUint32) - } + if inUint64 != outUint64 { + t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64) + } - if inUint64 != outUint64 { - t.Errorf("uint64 rename: expected %v, got %v", inUint64, outUint64) - } - - if inString != outString { - t.Errorf("string rename: expected %v, got %v", inString, outString) - } - - ensureConnValid(t, conn) + if inString != outString { + t.Errorf("string rename: expected %v, got %v", inString, outString) + } + }) } -func TestRowDecode(t *testing.T) { +func TestRowDecodeBinary(t *testing.T) { t.Parallel() conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))