pgx/pgconn/pgconn_test.go
Alejandro Do Nascimento Mora c4ac6d810f Use DefaultQueryExecMode in CopyFrom
CopyFrom had to create a prepared statement to get the OIDs of the data
types that were going to be copied into the table. Every COPY operation
required an extra round trip to retrieve the type information. There
was no way to customize this behavior.

By leveraging the QueryExecMode feature, like in `Conn.Query`, users can
specify if they want to cache the prepared statements, execute
them on every request (like the old behavior), or bypass the prepared
statement relying on the pgtype.Map to get the type information.

The `QueryExecMode` values behave exactly as in `Conn.Query` in the way the
data type OIDs are fetched, meaning that:

- `QueryExecModeCacheStatement`: caches the statement.
- `QueryExecModeCacheDescribe`: caches the statement and assumes they do
  not change.
- `QueryExecModeDescribeExec`: gets the statement description on every
  execution. This matches the old behavior of `CopyFrom`.
- `QueryExecModeExec` and `QueryExecModeSimpleProtocol`: maintain the
  same behavior as before, which is the same as `QueryExecModeDescribeExec`.
  It will keep getting the statement description on every execution.

The `QueryExecMode` can only be set via
`ConnConfig.DefaultQueryExecMode`, unlike `Conn.Query` there's no
support for specifying the `QueryExecMode` via optional arguments
in the function signature.
2022-12-23 13:22:26 -06:00

2966 lines
86 KiB
Go

package pgconn_test
import (
"bytes"
"compress/gzip"
"context"
"crypto/tls"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"math"
"net"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/internal/pgio"
"github.com/jackc/pgx/v5/internal/pgmock"
"github.com/jackc/pgx/v5/pgconn"
"github.com/jackc/pgx/v5/pgproto3"
"github.com/jackc/pgx/v5/pgtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestConnect checks that a connection can be opened and closed for every
// connection-string flavor whose environment variable is set. Flavors without a
// configured environment variable are skipped.
func TestConnect(t *testing.T) {
	cases := []struct {
		name string
		env  string
	}{
		{name: "Unix socket", env: "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
		{name: "TCP", env: "PGX_TEST_TCP_CONN_STRING"},
		{name: "Plain password", env: "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
		{name: "MD5 password", env: "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
		{name: "SCRAM password", env: "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the subtest closure
		t.Run(tc.name, func(t *testing.T) {
			connString := os.Getenv(tc.env)
			if connString == "" {
				t.Skipf("Skipping due to missing environment variable %v", tc.env)
			}

			conn, err := pgconn.Connect(context.Background(), connString)
			require.NoError(t, err)
			closeConn(t, conn)
		})
	}
}
// TestConnectWithOptions is like TestConnect but goes through ConnectWithOptions,
// supplying a GetSSLPassword callback in the parse options.
func TestConnectWithOptions(t *testing.T) {
	cases := []struct {
		name string
		env  string
	}{
		{name: "Unix socket", env: "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
		{name: "TCP", env: "PGX_TEST_TCP_CONN_STRING"},
		{name: "Plain password", env: "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
		{name: "MD5 password", env: "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
		{name: "SCRAM password", env: "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the subtest closure
		t.Run(tc.name, func(t *testing.T) {
			connString := os.Getenv(tc.env)
			if connString == "" {
				t.Skipf("Skipping due to missing environment variable %v", tc.env)
			}

			opts := pgconn.ParseConfigOptions{GetSSLPassword: GetSSLPassword}
			conn, err := pgconn.ConnectWithOptions(context.Background(), connString, opts)
			require.NoError(t, err)
			closeConn(t, conn)
		})
	}
}
// TestConnectTLS is kept apart from the other connect tests because it also
// confirms, by querying pg_stat_ssl, that the resulting connection really is secure.
func TestConnectTLS(t *testing.T) {
	t.Parallel()

	const envVar = "PGX_TEST_TLS_CONN_STRING"
	connString := os.Getenv(envVar)
	if connString == "" {
		t.Skipf("Skipping due to missing environment variable %v", envVar)
	}

	conn, err := pgconn.Connect(context.Background(), connString)
	require.NoError(t, err)

	// Expect exactly one row/column whose text value is "t" for this backend.
	const sslQuery = `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`
	result := conn.ExecParams(context.Background(), sslQuery, nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, 1)
	require.Len(t, result.Rows[0], 1)
	require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")

	closeConn(t, conn)
}
// TestConnectTLSPasswordProtectedClientCertWithSSLPassword connects with a
// password-protected client certificate. The password is passed through the
// sslpassword connection-string parameter. The test then confirms that the
// session uses TLS.
func TestConnectTLSPasswordProtectedClientCertWithSSLPassword(t *testing.T) {
	t.Parallel()

	connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING")
	if connString == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING")
	}
	sslPassword := os.Getenv("PGX_SSL_PASSWORD")
	if sslPassword == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD")
	}
	connString += " sslpassword=" + sslPassword

	conn, err := pgconn.Connect(context.Background(), connString)
	require.NoError(t, err)

	const sslQuery = `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`
	result := conn.ExecParams(context.Background(), sslQuery, nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, 1)
	require.Len(t, result.Rows[0], 1)
	require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")

	closeConn(t, conn)
}
// TestConnectTLSPasswordProtectedClientCertWithGetSSLPasswordConfigOption connects
// with a password-protected client certificate and confirms that the session uses
// TLS. Unlike the sslpassword variant, the password is not added to the connection
// string. It is expected to come from the GetSSLPassword callback, which is defined
// elsewhere in this package. PGX_SSL_PASSWORD is only checked for presence here.
func TestConnectTLSPasswordProtectedClientCertWithGetSSLPasswordConfigOption(t *testing.T) {
	t.Parallel()
	connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING")
	if connString == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING")
	}
	if os.Getenv("PGX_SSL_PASSWORD") == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD")
	}
	var sslOptions pgconn.ParseConfigOptions
	sslOptions.GetSSLPassword = GetSSLPassword
	config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
	// NoError (not Nil) is the idiomatic error check and reports the error message on failure.
	require.NoError(t, err)
	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	result := conn.ExecParams(context.Background(), `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, 1)
	require.Len(t, result.Rows[0], 1)
	require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
	closeConn(t, conn)
}
// pgmockWaitStep is a pgmock script step that pauses for its duration without
// reading from or writing to the backend. The connect-timeout test uses it to
// make the mock server stall partway through the startup sequence.
type pgmockWaitStep time.Duration

// Step sleeps for the configured duration and always reports success.
func (s pgmockWaitStep) Step(_ *pgproto3.Backend) error {
	pause := time.Duration(s)
	time.Sleep(pause)
	return nil
}
// TestConnectTimeout checks that a connection attempt is abandoned with a timeout
// error when a mock server authenticates the client and then stalls before sending
// BackendKeyData/ReadyForQuery. It covers both a context deadline and
// Config.ConnectTimeout.
func TestConnectTimeout(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		// connect receives the subtest's *testing.T so any require failure it raises
		// is reported on, and from the goroutine of, the subtest that ran it.
		connect func(t *testing.T, connStr string) error
	}{
		{
			name: "via context that times out",
			connect: func(_ *testing.T, connStr string) error {
				ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
				defer cancel()
				_, err := pgconn.Connect(ctx, connStr)
				return err
			},
		},
		{
			name: "via config ConnectTimeout",
			connect: func(t *testing.T, connStr string) error {
				conf, err := pgconn.ParseConfig(connStr)
				require.NoError(t, err)
				// 50ms, the same as the context case. This leaves time for the dial
				// and the first startup messages, so the timeout fires during the
				// server's stall. A 50µs budget could expire during the TCP dial
				// instead.
				conf.ConnectTimeout = time.Millisecond * 50
				_, err = pgconn.ConnectConfig(context.Background(), conf)
				return err
			},
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			// The mock server authenticates, waits 500ms, then finishes startup. The
			// client is expected to time out before the 500ms mark.
			script := &pgmock.Script{
				Steps: []pgmock.Step{
					pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
					pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
					pgmockWaitStep(time.Millisecond * 500),
					pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
					pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
				},
			}
			ln, err := net.Listen("tcp", "127.0.0.1:")
			require.NoError(t, err)
			defer ln.Close()
			serverErrChan := make(chan error, 1)
			go func() {
				defer close(serverErrChan)
				conn, err := ln.Accept()
				if err != nil {
					serverErrChan <- err
					return
				}
				defer conn.Close()
				// The deadline is shorter than the scripted stall, so the mock server
				// gives up on its own instead of lingering.
				err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
				if err != nil {
					serverErrChan <- err
					return
				}
				err = script.Run(pgproto3.NewBackend(conn, conn))
				if err != nil {
					serverErrChan <- err
					return
				}
			}()
			parts := strings.Split(ln.Addr().String(), ":")
			host := parts[0]
			port := parts[1]
			connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
			tooLate := time.Now().Add(time.Millisecond * 500)
			err = tt.connect(t, connStr)
			require.True(t, pgconn.Timeout(err), err)
			require.True(t, time.Now().Before(tooLate))
		})
	}
}
// TestConnectTimeoutStuckOnTLSHandshake checks that connecting fails with a timeout
// error when the server accepts the TCP connection but never answers, so the client
// stays stuck in TLS negotiation. It covers both a context deadline and
// Config.ConnectTimeout, and fails if no result arrives within 100ms.
func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name    string
		connect func(connStr string) error
	}{
		{
			name: "via context that times out",
			connect: func(connStr string) error {
				ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
				defer cancel()
				_, err := pgconn.Connect(ctx, connStr)
				return err
			},
		},
		{
			name: "via config ConnectTimeout",
			connect: func(connStr string) error {
				// connect runs on a helper goroutine, where require/t.FailNow must not
				// be used, so a parse failure is returned to the test as the error.
				conf, err := pgconn.ParseConfig(connStr)
				if err != nil {
					return err
				}
				conf.ConnectTimeout = time.Millisecond * 10
				_, err = pgconn.ConnectConfig(context.Background(), conf)
				return err
			},
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			ln, err := net.Listen("tcp", "127.0.0.1:")
			require.NoError(t, err)
			defer ln.Close()

			// done is closed when the subtest returns, which releases the fake server.
			done := make(chan struct{})
			defer close(done)

			serverErrChan := make(chan error, 1)
			go func() {
				conn, err := ln.Accept()
				if err != nil {
					serverErrChan <- err
					return
				}
				defer conn.Close()
				// Hold the connection open without ever replying, so the client hangs
				// while negotiating TLS. Stay only until the subtest is finished rather
				// than sleeping a fixed minute past it.
				<-done
			}()
			parts := strings.Split(ln.Addr().String(), ":")
			host := parts[0]
			port := parts[1]
			connStr := fmt.Sprintf("host=%s port=%s", host, port)
			// Buffered so the connecting goroutine can always deliver its result and
			// exit, even after the select below has taken another branch.
			errChan := make(chan error, 1)
			go func() {
				errChan <- tt.connect(connStr)
			}()
			select {
			case err = <-errChan:
				require.True(t, pgconn.Timeout(err), err)
			case err = <-serverErrChan:
				t.Fatalf("server failed with error: %s", err)
			case <-time.After(time.Millisecond * 100):
				t.Fatal("exceeded connection timeout without erroring out")
			}
		})
	}
}
// TestConnectInvalidUser checks that connecting as a nonexistent user fails with
// a PgError whose SQLSTATE is 28000 (invalid authorization specification) or
// 28P01 (invalid password).
func TestConnectInvalidUser(t *testing.T) {
	t.Parallel()
	connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
	if connString == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
	}
	config, err := pgconn.ParseConfig(connString)
	require.NoError(t, err)
	config.User = "pgxinvalidusertest"
	_, err = pgconn.ConnectConfig(context.Background(), config)
	require.Error(t, err)
	// errors.As finds the PgError anywhere in the wrap chain. A type assertion on
	// a single Unwrap breaks as soon as another layer of wrapping is added.
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
	}
	if pgErr.Code != "28000" && pgErr.Code != "28P01" {
		t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
	}
}
// TestConnectWithConnectionRefused expects Connect to fail against 127.0.0.1:1,
// where nothing is presumed to be listening.
func TestConnectWithConnectionRefused(t *testing.T) {
	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), "host=127.0.0.1 port=1")
	if err != nil {
		return // expected: the connection attempt failed
	}

	// Unexpected success: release the connection before failing the test.
	conn.Close(context.Background())
	t.Fatal("Expected error establishing connection to bad port")
}
// TestConnectCustomDialer checks that a DialFunc set on the config is used to
// open the network connection.
func TestConnectCustomDialer(t *testing.T) {
	t.Parallel()

	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var dialerUsed bool
	config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
		dialerUsed = true
		return net.Dial(network, address)
	}

	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	require.True(t, dialerUsed)
	closeConn(t, conn)
}
// TestConnectCustomLookup checks that a LookupFunc set on the config is used to
// resolve the host name.
func TestConnectCustomLookup(t *testing.T) {
	t.Parallel()

	connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
	if connString == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
	}

	config, err := pgconn.ParseConfig(connString)
	require.NoError(t, err)

	var lookupUsed bool
	config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
		lookupUsed = true
		return net.LookupHost(host)
	}

	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	require.True(t, lookupUsed)
	closeConn(t, conn)
}
// TestConnectCustomLookupWithPort checks that a LookupFunc may return "host:port"
// addresses and that the port it supplies overrides the one in the config.
func TestConnectCustomLookupWithPort(t *testing.T) {
	t.Parallel()

	connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
	if connString == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
	}

	config, err := pgconn.ParseConfig(connString)
	require.NoError(t, err)

	// Remember the real port, then set the config to an invalid one. The
	// connection can then only succeed if the port from LookupFunc is used.
	realPort := strconv.FormatUint(uint64(config.Port), 10)
	config.Port = 0

	var lookupUsed bool
	config.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
		lookupUsed = true
		addrs, err := net.LookupHost(host)
		if err != nil {
			return nil, err
		}
		for i, addr := range addrs {
			addrs[i] = net.JoinHostPort(addr, realPort)
		}
		return addrs, nil
	}

	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	require.True(t, lookupUsed)
	closeConn(t, conn)
}
// TestConnectWithRuntimeParams checks that Config.RuntimeParams are sent at
// startup and are visible to the session through SHOW.
func TestConnectWithRuntimeParams(t *testing.T) {
	t.Parallel()
	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	config.RuntimeParams = map[string]string{
		"application_name": "pgxtest",
		"search_path":      "myschema",
	}
	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	defer closeConn(t, conn)

	// require (not assert) on the row count. Otherwise an empty result would make
	// the Rows[0][0] index below panic instead of failing cleanly.
	result := conn.ExecParams(context.Background(), "show application_name", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, 1)
	assert.Equal(t, "pgxtest", string(result.Rows[0][0]))

	result = conn.ExecParams(context.Background(), "show search_path", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, 1)
	assert.Equal(t, "myschema", string(result.Rows[0][0]))
}
// TestConnectWithFallback checks that when the primary host and an earlier
// fallback are unreachable, the connection falls through to a later working
// fallback.
func TestConnectWithFallback(t *testing.T) {
	t.Parallel()

	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	// The working primary settings become the first fallback.
	good := &pgconn.FallbackConfig{
		Host:      config.Host,
		Port:      config.Port,
		TLSConfig: config.TLSConfig,
	}
	config.Fallbacks = append([]*pgconn.FallbackConfig{good}, config.Fallbacks...)

	// Point the primary at a port where nothing is presumably listening.
	config.Host = "localhost"
	config.Port = 1

	// Put an equally unreachable fallback ahead of the working one.
	bad := &pgconn.FallbackConfig{
		Host:      "localhost",
		Port:      1,
		TLSConfig: config.TLSConfig,
	}
	config.Fallbacks = append([]*pgconn.FallbackConfig{bad}, config.Fallbacks...)

	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	closeConn(t, conn)
}
// TestConnectWithValidateConnect checks that a connection rejected by
// ValidateConnect causes the next fallback to be dialed, and that the first
// accepted connection is returned.
func TestConnectWithValidateConnect(t *testing.T) {
	t.Parallel()

	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	dials := 0
	config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
		dials++
		return net.Dial(network, address)
	}

	validations := 0
	config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
		validations++
		if validations == 1 {
			return errors.New("reject first conn")
		}
		return nil
	}

	// Add the primary as a fallback, then duplicate the fallback list so there
	// is always another candidate after the first rejection.
	config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{
		Host:      config.Host,
		Port:      config.Port,
		TLSConfig: config.TLSConfig,
	})
	config.Fallbacks = append(config.Fallbacks, config.Fallbacks...)

	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	closeConn(t, conn)

	assert.Greater(t, dials, 1)
	assert.Greater(t, validations, 1)
}
// TestConnectWithValidateConnectTargetSessionAttrsReadWrite forces the session to
// be read-only. It then expects the read-write validator to make the connection
// attempt fail.
func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
	t.Parallel()

	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
	config.RuntimeParams["default_transaction_read_only"] = "on"

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	conn, err := pgconn.ConnectConfig(ctx, config)
	if assert.Error(t, err) {
		return
	}
	// Unexpected success: do not leak the connection.
	conn.Close(ctx)
}
// TestConnectWithAfterConnect checks that the AfterConnect hook runs on the new
// connection. The search_path it sets must be visible to the first query.
func TestConnectWithAfterConnect(t *testing.T) {
	t.Parallel()
	config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
		_, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
		return err
	}
	conn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	// Register cleanup right away so the connection is closed even if a
	// later require fails.
	defer closeConn(t, conn)

	results, err := conn.Exec(context.Background(), "show search_path;").ReadAll()
	require.NoError(t, err)
	// Guard the indexing below so an unexpected shape fails instead of panicking.
	require.Len(t, results, 1)
	require.Len(t, results[0].Rows, 1)
	assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
}
// TestConnectConfigRequiresConfigFromParseConfig checks that ConnectConfig panics
// when given a zero-value Config that did not come from ParseConfig.
func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
	t.Parallel()

	var config pgconn.Config
	connect := func() { pgconn.ConnectConfig(context.Background(), &config) }
	require.PanicsWithValue(t, "config must be created by ParseConfig", connect)
}
// TestConnPrepareSyntaxError checks that preparing invalid SQL returns an error
// and no statement description, and that the connection stays usable.
func TestConnPrepareSyntaxError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	desc, err := pgConn.Prepare(context.Background(), "ps1", "SYNTAX ERROR", nil)
	require.Nil(t, desc)
	require.Error(t, err)

	ensureConnValid(t, pgConn)
}
// TestConnPrepareContextPrecanceled checks that Prepare with an already-canceled
// context returns context.Canceled as a retry-safe error, and that the connection
// stays usable.
func TestConnPrepareContextPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // canceled before use

	desc, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
	assert.Nil(t, desc)
	assert.ErrorIs(t, err, context.Canceled)
	assert.True(t, pgconn.SafeToRetry(err))

	ensureConnValid(t, pgConn)
}
// TestConnExec runs a single simple-protocol query. It checks the command tag
// and the single returned value.
func TestConnExec(t *testing.T) {
	t.Parallel()
	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	results, err := pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
	assert.NoError(t, err)
	// require on lengths that guard the indexing below, so a short result fails
	// instead of panicking.
	require.Len(t, results, 1)
	assert.Nil(t, results[0].Err)
	assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
	require.Len(t, results[0].Rows, 1)
	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))

	ensureConnValid(t, pgConn)
}
// TestConnExecEmpty checks that a query consisting only of ";" yields no result
// sets and closes without error.
func TestConnExecEmpty(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	mrr := pgConn.Exec(context.Background(), ";")
	var count int
	for mrr.NextResult() {
		count++
		mrr.ResultReader().Close()
	}
	assert.Equal(t, 0, count)
	assert.NoError(t, mrr.Close())

	ensureConnValid(t, pgConn)
}
// TestConnExecMultipleQueries checks that a multi-statement simple-protocol query
// returns one result per statement, in order.
func TestConnExecMultipleQueries(t *testing.T) {
	t.Parallel()
	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	results, err := pgConn.Exec(context.Background(), "select 'Hello, world'; select 1").ReadAll()
	assert.NoError(t, err)
	// require on lengths that guard indexing, so short results fail instead of panicking.
	require.Len(t, results, 2)

	assert.Nil(t, results[0].Err)
	assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
	require.Len(t, results[0].Rows, 1)
	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))

	assert.Nil(t, results[1].Err)
	assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
	require.Len(t, results[1].Rows, 1)
	assert.Equal(t, "1", string(results[1].Rows[0][0]))

	ensureConnValid(t, pgConn)
}
// TestConnExecMultipleQueriesEagerFieldDescriptions checks that each result set of
// a multi-statement query exposes its field descriptions before any rows are read.
func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	mrr := pgConn.Exec(context.Background(), "select 'Hello, world' as msg; select 1 as num")

	// expectSingleColumn advances to the next result set, checks that it has one
	// column with the given name, and closes it without reading rows.
	expectSingleColumn := func(name string) {
		require.True(t, mrr.NextResult())
		fields := mrr.ResultReader().FieldDescriptions()
		require.Len(t, fields, 1)
		assert.Equal(t, name, fields[0].Name)
		_, closeErr := mrr.ResultReader().Close()
		require.NoError(t, closeErr)
	}
	expectSingleColumn("msg")
	expectSingleColumn("num")

	require.False(t, mrr.NextResult())
	require.NoError(t, mrr.Close())

	ensureConnValid(t, pgConn)
}
// TestConnExecMultipleQueriesError checks that a division by zero in the middle
// of a multi-statement query is reported as a PgError with SQLSTATE 22012. It
// also checks that the results completed before the failure are returned, and
// that the connection stays usable.
func TestConnExecMultipleQueriesError(t *testing.T) {
	t.Parallel()
	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	results, err := pgConn.Exec(context.Background(), "select 1; select 1/0; select 1").ReadAll()
	require.NotNil(t, err)
	// errors.As also matches a wrapped PgError, unlike a bare type assertion.
	var pgErr *pgconn.PgError
	if errors.As(err, &pgErr) {
		assert.Equal(t, "22012", pgErr.Code)
	} else {
		t.Errorf("unexpected error: %v", err)
	}

	if pgConn.ParameterStatus("crdb_version") != "" {
		// CockroachDB starts the second query result set and then sends the divide by zero error.
		require.Len(t, results, 2)
		assert.Len(t, results[0].Rows, 1)
		assert.Equal(t, "1", string(results[0].Rows[0][0]))
		assert.Len(t, results[1].Rows, 0)
	} else {
		// PostgreSQL sends the divide by zero and never sends the second query result set.
		require.Len(t, results, 1)
		assert.Len(t, results[0].Rows, 1)
		assert.Equal(t, "1", string(results[0].Rows[0][0]))
	}

	ensureConnValid(t, pgConn)
}
// TestConnExecDeferredError checks that a unique violation on a deferred
// constraint is reported by Exec as a PgError with SQLSTATE 23505, and that the
// connection stays usable. It is skipped on CockroachDB, which lacks deferrable
// constraints.
func TestConnExecDeferredError(t *testing.T) {
	t.Parallel()
	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)
	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
	}
	setupSQL := `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
	// Stop on setup failure. Otherwise the assertions below would report a
	// misleading error.
	require.NoError(t, err)

	_, err = pgConn.Exec(context.Background(), `update t set n=n+1 where id='b' returning *`).ReadAll()
	require.Error(t, err)
	var pgErr *pgconn.PgError
	require.True(t, errors.As(err, &pgErr))
	require.Equal(t, "23505", pgErr.Code)

	ensureConnValid(t, pgConn)
}
// TestConnExecContextCanceled checks that a query cut short by a context deadline
// returns a timeout error wrapping context.DeadlineExceeded. The connection must
// end up closed, with cleanup finishing within five seconds.
func TestConnExecContextCanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	// pg_sleep(1) outlasts the 100ms deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	mrr := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)")
	for mrr.NextResult() {
		// Drain result sets until the deadline interrupts the query.
	}

	err = mrr.Close()
	assert.True(t, pgconn.Timeout(err))
	assert.ErrorIs(t, err, context.DeadlineExceeded)
	assert.True(t, pgConn.IsClosed())

	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}
// TestConnExecContextPrecanceled checks that Exec with an already-canceled context
// fails with a retry-safe context.Canceled error, and that the connection stays usable.
func TestConnExecContextPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // canceled before use

	_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
	assert.ErrorIs(t, err, context.Canceled)
	assert.True(t, pgconn.SafeToRetry(err))

	ensureConnValid(t, pgConn)
}
// TestConnExecParams runs a parameterized query with ExecParams. It checks the
// field description, the returned row, and the command tag.
func TestConnExecParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	args := [][]byte{[]byte("Hello, world")}
	rr := pgConn.ExecParams(context.Background(), "select $1::text as msg", args, nil, nil, nil)

	fields := rr.FieldDescriptions()
	require.Len(t, fields, 1)
	assert.Equal(t, "msg", fields[0].Name)

	rows := 0
	for rr.NextRow() {
		rows++
		assert.Equal(t, "Hello, world", string(rr.Values()[0]))
	}
	assert.Equal(t, 1, rows)

	tag, err := rr.Close()
	assert.Equal(t, "SELECT 1", tag.String())
	assert.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// TestConnExecParamsDeferredError checks that a unique violation on a deferred
// constraint is reported in the ExecParams result as a PgError with SQLSTATE 23505,
// and that the connection stays usable. It is skipped on CockroachDB.
func TestConnExecParamsDeferredError(t *testing.T) {
	t.Parallel()
	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)
	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
	}
	setupSQL := `create temporary table t (
id text primary key,
n int not null,
unique (n) deferrable initially deferred
);
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
	// Stop on setup failure. Otherwise the assertions below would report a
	// misleading error.
	require.NoError(t, err)

	result := pgConn.ExecParams(context.Background(), `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
	require.Error(t, result.Err)
	var pgErr *pgconn.PgError
	require.True(t, errors.As(result.Err, &pgErr))
	require.Equal(t, "23505", pgErr.Code)

	ensureConnValid(t, pgConn)
}
// TestConnExecParamsMaxNumberOfParams checks that ExecParams accepts exactly
// math.MaxUint16 (65535) parameters, the most the extended protocol allows.
func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	const paramCount = math.MaxUint16
	placeholders := make([]string, paramCount)
	args := make([][]byte, paramCount)
	for i := range placeholders {
		placeholders[i] = fmt.Sprintf("($%d::text)", i+1)
		args[i] = []byte(strconv.Itoa(i))
	}
	sql := "values" + strings.Join(placeholders, ", ")

	result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, paramCount)

	ensureConnValid(t, pgConn)
}
// TestConnExecParamsTooManyParams checks that ExecParams rejects 65536 parameters
// with the protocol-limit error, and that the connection stays usable.
func TestConnExecParamsTooManyParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	const paramCount = math.MaxUint16 + 1
	placeholders := make([]string, paramCount)
	args := make([][]byte, paramCount)
	for i := range placeholders {
		placeholders[i] = fmt.Sprintf("($%d::text)", i+1)
		args[i] = []byte(strconv.Itoa(i))
	}
	sql := "values" + strings.Join(placeholders, ", ")

	result := pgConn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
	require.Error(t, result.Err)
	require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())

	ensureConnValid(t, pgConn)
}
// TestConnExecParamsCanceled checks that an ExecParams query interrupted by a
// context deadline yields no rows, an empty command tag, and a timeout error
// wrapping context.DeadlineExceeded. The connection must end up closed, with
// cleanup finishing within five seconds.
func TestConnExecParamsCanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	// pg_sleep(1) outlasts the 100ms deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	rr := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
	rows := 0
	for rr.NextRow() {
		rows++
	}
	assert.Equal(t, 0, rows)

	tag, err := rr.Close()
	assert.Equal(t, pgconn.CommandTag{}, tag)
	assert.True(t, pgconn.Timeout(err))
	assert.ErrorIs(t, err, context.DeadlineExceeded)
	assert.True(t, pgConn.IsClosed())

	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}
// TestConnExecParamsPrecanceled checks that ExecParams with an already-canceled
// context reports a retry-safe context.Canceled error, and that the connection
// stays usable.
func TestConnExecParamsPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // canceled before use

	args := [][]byte{[]byte("Hello, world")}
	result := pgConn.ExecParams(ctx, "select $1::text", args, nil, nil, nil).Read()
	require.Error(t, result.Err)
	assert.ErrorIs(t, result.Err, context.Canceled)
	assert.True(t, pgconn.SafeToRetry(result.Err))

	ensureConnValid(t, pgConn)
}
// TestConnExecParamsEmptySQL checks that ExecParams with an empty query string
// succeeds, returning an empty command tag and no rows.
func TestConnExecParamsEmptySQL(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	res := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
	assert.Equal(t, pgconn.CommandTag{}, res.CommandTag)
	assert.Len(t, res.Rows, 0)
	assert.NoError(t, res.Err)

	ensureConnValid(t, pgConn)
}
// TestResultReaderValuesHaveSameCapacityAsLength is a regression test for
// https://github.com/jackc/pgx/issues/859. Each value slice returned by
// ResultReader.Values must have a capacity equal to its length.
func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	args := [][]byte{[]byte("Hello, world")}
	rr := pgConn.ExecParams(context.Background(), "select $1::text as msg", args, nil, nil, nil)

	fields := rr.FieldDescriptions()
	require.Len(t, fields, 1)
	assert.Equal(t, "msg", fields[0].Name)

	rows := 0
	for rr.NextRow() {
		rows++
		value := rr.Values()[0]
		assert.Equal(t, "Hello, world", string(value))
		assert.Equal(t, len(value), cap(value))
	}
	assert.Equal(t, 1, rows)

	tag, err := rr.Close()
	assert.Equal(t, "SELECT 1", tag.String())
	assert.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// TestConnExecPrepared prepares a one-parameter statement and checks its
// description. It then executes the statement with ExecPrepared and checks the
// field description, row, and command tag.
func TestConnExecPrepared(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	desc, err := pgConn.Prepare(context.Background(), "ps1", "select $1::text as msg", nil)
	require.NoError(t, err)
	require.NotNil(t, desc)
	assert.Len(t, desc.ParamOIDs, 1)
	assert.Len(t, desc.Fields, 1)

	args := [][]byte{[]byte("Hello, world")}
	rr := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil)

	fields := rr.FieldDescriptions()
	require.Len(t, fields, 1)
	assert.Equal(t, "msg", fields[0].Name)

	rows := 0
	for rr.NextRow() {
		rows++
		assert.Equal(t, "Hello, world", string(rr.Values()[0]))
	}
	assert.Equal(t, 1, rows)

	tag, err := rr.Close()
	assert.Equal(t, "SELECT 1", tag.String())
	assert.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// TestConnExecPreparedMaxNumberOfParams checks that a statement with exactly
// math.MaxUint16 (65535) parameters can be prepared and executed.
func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	const paramCount = math.MaxUint16
	placeholders := make([]string, paramCount)
	args := make([][]byte, paramCount)
	for i := range placeholders {
		placeholders[i] = fmt.Sprintf("($%d::text)", i+1)
		args[i] = []byte(strconv.Itoa(i))
	}
	sql := "values" + strings.Join(placeholders, ", ")

	desc, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
	require.NoError(t, err)
	require.NotNil(t, desc)
	assert.Len(t, desc.ParamOIDs, paramCount)
	assert.Len(t, desc.Fields, 1)

	result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, paramCount)

	ensureConnValid(t, pgConn)
}
// TestConnExecPreparedTooManyParams checks the handling of a statement with 65536
// parameters. CockroachDB refuses to prepare it. PostgreSQL prepares it, but
// ExecPrepared then fails with the protocol-limit error.
func TestConnExecPreparedTooManyParams(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	const paramCount = math.MaxUint16 + 1
	placeholders := make([]string, paramCount)
	args := make([][]byte, paramCount)
	for i := range placeholders {
		placeholders[i] = fmt.Sprintf("($%d::text)", i+1)
		args[i] = []byte(strconv.Itoa(i))
	}
	sql := "values" + strings.Join(placeholders, ", ")

	desc, err := pgConn.Prepare(context.Background(), "ps1", sql, nil)
	if pgConn.ParameterStatus("crdb_version") == "" {
		// PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol.
		require.NoError(t, err)
		require.NotNil(t, desc)
		assert.Len(t, desc.ParamOIDs, paramCount)
		assert.Len(t, desc.Fields, 1)

		result := pgConn.ExecPrepared(context.Background(), "ps1", args, nil, nil).Read()
		require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters")
	} else {
		// CockroachDB rejects preparing a statement with more than 65535 parameters.
		require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)")
	}

	ensureConnValid(t, pgConn)
}
// TestConnExecPreparedCanceled checks that an ExecPrepared query interrupted by a
// context deadline yields no rows, an empty command tag, and a timeout error. The
// connection must end up closed, with cleanup finishing within five seconds.
func TestConnExecPreparedCanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
	require.NoError(t, err)

	// pg_sleep(1) outlasts the 100ms deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()

	rr := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
	rows := 0
	for rr.NextRow() {
		rows++
	}
	assert.Equal(t, 0, rows)

	tag, err := rr.Close()
	assert.Equal(t, pgconn.CommandTag{}, tag)
	assert.True(t, pgconn.Timeout(err))
	assert.True(t, pgConn.IsClosed())

	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}
// TestConnExecPreparedPrecanceled checks that ExecPrepared given an
// already-canceled context fails with context.Canceled, that the error is
// reported as safe to retry, and that the connection remains usable.
func TestConnExecPreparedPrecanceled(t *testing.T) {
	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	_, err = conn.Prepare(context.Background(), "ps1", "select current_database(), pg_sleep(1)", nil)
	require.NoError(t, err)

	canceledCtx, cancelFn := context.WithCancel(context.Background())
	cancelFn()

	execErr := conn.ExecPrepared(canceledCtx, "ps1", nil, nil, nil).Read().Err
	require.Error(t, execErr)
	assert.True(t, errors.Is(execErr, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(execErr))

	ensureConnValid(t, conn)
}
// TestConnExecPreparedEmptySQL checks that an empty statement can be prepared
// and executed, producing neither a command tag nor rows nor an error.
func TestConnExecPreparedEmptySQL(t *testing.T) {
	t.Parallel()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	conn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	_, err = conn.Prepare(ctx, "ps1", "", nil)
	require.NoError(t, err)

	res := conn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
	assert.Equal(t, pgconn.CommandTag{}, res.CommandTag)
	assert.Len(t, res.Rows, 0)
	assert.NoError(t, res.Err)

	ensureConnValid(t, conn)
}
// TestConnExecBatch runs a batch that mixes ExecParams and ExecPrepared and
// checks that every query echoes its argument back as a single row, in
// submission order, with a "SELECT 1" command tag.
func TestConnExecBatch(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Prepare(context.Background(), "ps1", "select $1::text", nil)
	require.NoError(t, err)

	batch := &pgconn.Batch{}
	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
	batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
	batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)

	results, err := pgConn.ExecBatch(context.Background(), batch).ReadAll()
	require.NoError(t, err)

	// Expected echoed value for each queued query, in queue order.
	expected := []string{"ExecParams 1", "ExecPrepared 1", "ExecParams 2"}
	require.Len(t, results, len(expected))
	for i, want := range expected {
		require.Len(t, results[i].Rows, 1)
		require.Equal(t, want, string(results[i].Rows[0][0]))
		assert.Equal(t, "SELECT 1", results[i].CommandTag.String())
	}

	// The connection must remain usable once the batch has been fully read.
	ensureConnValid(t, pgConn)
}
// TestConnExecBatchDeferredError checks that a unique-constraint violation
// that is only detected at commit time (the constraint is declared
// "deferrable initially deferred") is still returned from ExecBatch as a
// *pgconn.PgError, and that the connection survives it.
func TestConnExecBatchDeferredError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
	}

	setupSQL := `create temporary table t (
		id text primary key,
		n int not null,
		unique (n) deferrable initially deferred
	);

	insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`

	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
	require.NoError(t, err)

	// Bumping b's n from 2 to 3 collides with c; because the constraint is
	// deferred, the violation only surfaces when the batch commits.
	batch := &pgconn.Batch{}
	batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
	_, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
	require.Error(t, err)

	var pgErr *pgconn.PgError
	require.ErrorAs(t, err, &pgErr)
	// 23505 is PostgreSQL's unique_violation SQLSTATE.
	require.Equal(t, "23505", pgErr.Code)

	ensureConnValid(t, pgConn)
}
// TestConnExecBatchPrecanceled checks that ExecBatch given an already-canceled
// context fails with context.Canceled, that the error is safe to retry, and
// that the connection stays usable.
func TestConnExecBatchPrecanceled(t *testing.T) {
	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	_, err = conn.Prepare(context.Background(), "ps1", "select $1::text", nil)
	require.NoError(t, err)

	b := &pgconn.Batch{}
	b.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
	b.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
	b.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)

	canceledCtx, cancelFn := context.WithCancel(context.Background())
	cancelFn()

	_, batchErr := conn.ExecBatch(canceledCtx, b).ReadAll()
	require.Error(t, batchErr)
	assert.True(t, errors.Is(batchErr, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(batchErr))

	ensureConnValid(t, conn)
}
// TestConnExecBatchHuge sends a batch of 100,000 queries and checks every
// result. Without concurrent reading and writing, large batches can deadlock.
//
// See https://github.com/jackc/pgx/issues/374.
func TestConnExecBatchHuge(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping test in short mode.")
	}

	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	const queryCount = 100000
	batch := &pgconn.Batch{}
	values := make([]string, queryCount)
	for i := 0; i < queryCount; i++ {
		values[i] = strconv.Itoa(i)
		batch.ExecParams("select $1::text", [][]byte{[]byte(values[i])}, nil, nil, nil)
	}

	results, err := conn.ExecBatch(context.Background(), batch).ReadAll()
	require.NoError(t, err)
	require.Len(t, results, queryCount)

	for i, want := range values {
		require.Len(t, results[i].Rows, 1)
		require.Equal(t, want, string(results[i].Rows[0][0]))
		assert.Equal(t, "SELECT 1", results[i].CommandTag.String())
	}
}
// TestConnExecBatchImplicitTransaction checks that a batch runs as one
// implicit transaction: when its final statement fails, the inserts queued
// before it are rolled back and the table stays empty.
func TestConnExecBatchImplicitTransaction(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)")
	}

	_, err = pgConn.Exec(context.Background(), "create temporary table t(id int)").ReadAll()
	require.NoError(t, err)

	batch := &pgconn.Batch{}
	batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil)
	batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil)
	batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil)
	// Division by zero makes the last statement, and therefore the batch, fail.
	batch.ExecParams("select 1/0", nil, nil, nil, nil)
	_, err = pgConn.ExecBatch(context.Background(), batch).ReadAll()
	require.Error(t, err)

	// Check the error and row count before indexing, so a failing count query
	// is reported as a test failure rather than an index-out-of-range panic.
	result := pgConn.ExecParams(context.Background(), "select count(*) from t", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, 1)
	require.Equal(t, "0", string(result.Rows[0][0]))

	ensureConnValid(t, pgConn)
}
// TestConnLocking checks that issuing a second query while the result of a
// previous one is still unread fails immediately with a retryable
// "conn busy" error, and that the first result can still be read afterwards.
func TestConnLocking(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	// While mrr is unread the connection is expected to be busy.
	mrr := pgConn.Exec(context.Background(), "select 'Hello, world'")
	_, err = pgConn.Exec(context.Background(), "select 'Hello, world'").ReadAll()
	// require (not assert): inspecting the message of a nil error would panic.
	require.Error(t, err)
	assert.EqualError(t, err, "conn busy")
	assert.True(t, pgconn.SafeToRetry(err))

	results, err := mrr.ReadAll()
	assert.NoError(t, err)
	assert.Len(t, results, 1)
	assert.Nil(t, results[0].Err)
	assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
	assert.Len(t, results[0].Rows, 1)
	assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))

	ensureConnValid(t, pgConn)
}
// TestConnOnNotice checks that a NOTICE raised by a PL/pgSQL block is handed
// to the Config.OnNotice callback.
func TestConnOnNotice(t *testing.T) {
	t.Parallel()

	cfg, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var gotMessage string
	cfg.OnNotice = func(_ *pgconn.PgConn, n *pgconn.Notice) {
		gotMessage = n.Message
	}
	cfg.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect.

	conn, err := pgconn.ConnectConfig(context.Background(), cfg)
	require.NoError(t, err)
	defer closeConn(t, conn)

	if conn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)")
	}

	raiseNotice := `do $$
	begin
		raise notice 'hello, world';
	end$$;`
	require.NoError(t, conn.Exec(context.Background(), raiseNotice).Close())
	assert.Equal(t, "hello, world", gotMessage)

	ensureConnValid(t, conn)
}
// TestConnOnNotification checks that a NOTIFY sent from a second connection is
// passed to Config.OnNotification on the listening connection. The listener
// issues a query so that it reads from the server after the NOTIFY is sent.
func TestConnOnNotification(t *testing.T) {
	t.Parallel()

	cfg, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var payload string
	cfg.OnNotification = func(_ *pgconn.PgConn, n *pgconn.Notification) {
		payload = n.Payload
	}

	listener, err := pgconn.ConnectConfig(context.Background(), cfg)
	require.NoError(t, err)
	defer closeConn(t, listener)

	if listener.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
	}

	_, err = listener.Exec(context.Background(), "listen foo").ReadAll()
	require.NoError(t, err)

	notifier, err := pgconn.ConnectConfig(context.Background(), cfg)
	require.NoError(t, err)
	defer closeConn(t, notifier)

	_, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll()
	require.NoError(t, err)

	_, err = listener.Exec(context.Background(), "select 1").ReadAll()
	require.NoError(t, err)
	assert.Equal(t, "bar", payload)

	ensureConnValid(t, listener)
}
// TestConnWaitForNotification checks that WaitForNotification returns without
// error once a NOTIFY from a second connection arrives, and that the payload
// was passed to Config.OnNotification.
func TestConnWaitForNotification(t *testing.T) {
	t.Parallel()

	cfg, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	var payload string
	cfg.OnNotification = func(_ *pgconn.PgConn, n *pgconn.Notification) {
		payload = n.Payload
	}

	listener, err := pgconn.ConnectConfig(context.Background(), cfg)
	require.NoError(t, err)
	defer closeConn(t, listener)

	if listener.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
	}

	_, err = listener.Exec(context.Background(), "listen foo").ReadAll()
	require.NoError(t, err)

	notifier, err := pgconn.ConnectConfig(context.Background(), cfg)
	require.NoError(t, err)
	defer closeConn(t, notifier)

	_, err = notifier.Exec(context.Background(), "notify foo, 'bar'").ReadAll()
	require.NoError(t, err)

	require.NoError(t, listener.WaitForNotification(context.Background()))
	assert.Equal(t, "bar", payload)

	ensureConnValid(t, listener)
}
// TestConnWaitForNotificationPrecanceled checks that WaitForNotification given
// an already-canceled context returns context.Canceled and leaves the
// connection usable.
func TestConnWaitForNotificationPrecanceled(t *testing.T) {
	t.Parallel()

	cfg, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	conn, err := pgconn.ConnectConfig(context.Background(), cfg)
	require.NoError(t, err)
	defer closeConn(t, conn)

	canceledCtx, cancelFn := context.WithCancel(context.Background())
	cancelFn()

	require.ErrorIs(t, conn.WaitForNotification(canceledCtx), context.Canceled)

	ensureConnValid(t, conn)
}
// TestConnWaitForNotificationTimeout checks that WaitForNotification with a
// short deadline and no pending notification returns a timeout error wrapping
// context.DeadlineExceeded, and that the connection stays usable.
func TestConnWaitForNotificationTimeout(t *testing.T) {
	t.Parallel()

	cfg, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	conn, err := pgconn.ConnectConfig(context.Background(), cfg)
	require.NoError(t, err)
	defer closeConn(t, conn)

	// The closure scopes the timeout so it is released right after the wait.
	waitErr := func() error {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
		defer cancel()
		return conn.WaitForNotification(ctx)
	}()
	assert.True(t, pgconn.Timeout(waitErr))
	assert.ErrorIs(t, waitErr, context.DeadlineExceeded)

	ensureConnValid(t, conn)
}
// TestConnCopyToSmall copies a two-row table (one fully populated row and one
// all-NULL row) to a buffer and compares it with the expected text-format
// COPY output.
func TestConnCopyToSmall(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support COPY TO")
	}

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int2,
		b int4,
		c int8,
		d varchar,
		e text,
		f date,
		g json
	)`).ReadAll()
	require.NoError(t, err)

	_, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll()
	require.NoError(t, err)

	_, err = pgConn.Exec(context.Background(), `insert into foo values (null, null, null, null, null, null, null)`).ReadAll()
	require.NoError(t, err)

	// Tab-separated columns, one row per line; NULL is written as \N.
	inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
		"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")

	outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))

	res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
	require.NoError(t, err)

	assert.Equal(t, int64(2), res.RowsAffected())
	assert.Equal(t, inputBytes, outputWriter.Bytes())

	ensureConnValid(t, pgConn)
}
// TestConnCopyToLarge inserts 1000 identical rows, including a bytea column,
// copies the table to a buffer, and compares it with the expected text-format
// COPY output.
func TestConnCopyToLarge(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support COPY TO")
	}

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int2,
		b int4,
		c int8,
		d varchar,
		e text,
		f date,
		g json,
		h bytea
	)`).ReadAll()
	require.NoError(t, err)

	// Expected COPY line for every inserted row. The bytea value 'oooo' is
	// written in hex form, and COPY escapes its leading backslash.
	const rowCount = 1000
	const expectedRow = "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"

	inputBytes := make([]byte, 0, rowCount*len(expectedRow))
	for i := 0; i < rowCount; i++ {
		_, err = pgConn.Exec(context.Background(), `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll()
		require.NoError(t, err)
		inputBytes = append(inputBytes, expectedRow...)
	}

	outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))

	res, err := pgConn.CopyTo(context.Background(), outputWriter, "copy foo to stdout")
	require.NoError(t, err)

	assert.Equal(t, int64(rowCount), res.RowsAffected())
	assert.Equal(t, inputBytes, outputWriter.Bytes())

	ensureConnValid(t, pgConn)
}
// TestConnCopyToQueryError checks that a malformed COPY TO statement returns a
// *pgconn.PgError with zero rows affected, and that the connection survives it.
func TestConnCopyToQueryError(t *testing.T) {
	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	var out bytes.Buffer
	tag, err := conn.CopyTo(context.Background(), &out, "cropy foo to stdout")
	require.Error(t, err)
	assert.IsType(t, &pgconn.PgError{}, err)
	assert.Equal(t, int64(0), tag.RowsAffected())

	ensureConnValid(t, conn)
}
// TestConnCopyToCanceled checks that a context deadline expiring in the middle
// of a COPY TO returns an error with an empty command tag and leaves the
// connection closed, with cleanup finishing within a bounded time.
func TestConnCopyToCanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
	}

	outputWriter := &bytes.Buffer{}

	// 1000 rows with a 10ms pg_sleep each would take about 10s, far longer than
	// the 100ms deadline, so the copy is guaranteed to be interrupted.
	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
	assert.Error(t, err)
	assert.Equal(t, pgconn.CommandTag{}, res)

	// An interrupted COPY leaves the connection closed rather than reusable.
	assert.True(t, pgConn.IsClosed())
	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}
// TestConnCopyToPrecanceled checks that CopyTo given an already-canceled
// context fails with a retryable context.Canceled error and an empty command
// tag, and that the connection stays usable.
func TestConnCopyToPrecanceled(t *testing.T) {
	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	canceledCtx, cancelFn := context.WithCancel(context.Background())
	cancelFn()

	var sink bytes.Buffer
	tag, copyErr := conn.CopyTo(canceledCtx, &sink, "copy (select * from generate_series(1,1000)) to stdout")
	require.Error(t, copyErr)
	assert.True(t, errors.Is(copyErr, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(copyErr))
	assert.Equal(t, pgconn.CommandTag{}, tag)

	ensureConnValid(t, conn)
}
// TestConnCopyFrom streams 1000 CSV rows into a temporary table with CopyFrom
// and checks both the reported row count and the data read back.
func TestConnCopyFrom(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int4,
		b varchar
	)`).ReadAll()
	require.NoError(t, err)

	srcBuf := &bytes.Buffer{}

	inputRows := [][][]byte{}
	for i := 0; i < 1000; i++ {
		a := strconv.Itoa(i)
		b := "foo " + a + " bar"
		inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
		// Format directly into the buffer instead of building a temporary
		// string and copying it into a byte slice first.
		_, err = fmt.Fprintf(srcBuf, "%s,\"%s\"\n", a, b)
		require.NoError(t, err)
	}

	// CockroachDB is given the `WITH CSV` option spelling instead.
	copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
	if pgConn.ParameterStatus("crdb_version") != "" {
		copySql = "COPY foo FROM STDIN WITH CSV"
	}
	ct, err := pgConn.CopyFrom(context.Background(), srcBuf, copySql)
	require.NoError(t, err)
	assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())

	result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	assert.Equal(t, inputRows, result.Rows)

	ensureConnValid(t, pgConn)
}
// TestConnCopyFromBinary builds a binary-format COPY stream of 1000 (int4,
// varchar) tuples by hand, loads it with CopyFrom, and checks the row count
// and the data read back.
func TestConnCopyFromBinary(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int4,
		b varchar
	)`).ReadAll()
	require.NoError(t, err)

	// A single type map encodes every value; building a new one for each
	// column of each row was needless work.
	typeMap := pgtype.NewMap()

	// Binary COPY header: the 11-byte signature, a 32-bit flags field, and a
	// 32-bit header-extension length.
	buf := []byte{}
	buf = append(buf, "PGCOPY\n\377\r\n\000"...)
	buf = pgio.AppendInt32(buf, 0)
	buf = pgio.AppendInt32(buf, 0)

	inputRows := [][][]byte{}
	for i := 0; i < 1000; i++ {
		// Number of elements in the tuple
		buf = pgio.AppendInt16(buf, int16(2))
		a := i

		// Length of element for column `a int4`
		buf = pgio.AppendInt32(buf, 4)
		buf, err = typeMap.Encode(pgtype.Int4OID, pgx.BinaryFormatCode, a, buf)
		require.NoError(t, err)

		b := "foo " + strconv.Itoa(a) + " bar"
		// Length of element for column `b varchar` (len counts bytes).
		buf = pgio.AppendInt32(buf, int32(len(b)))
		buf, err = typeMap.Encode(pgtype.VarcharOID, pgx.BinaryFormatCode, b, buf)
		require.NoError(t, err)

		inputRows = append(inputRows, [][]byte{[]byte(strconv.Itoa(a)), []byte(b)})
	}

	ct, err := pgConn.CopyFrom(context.Background(), bytes.NewReader(buf), "COPY foo (a, b) FROM STDIN BINARY;")
	require.NoError(t, err)
	assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())

	result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	assert.Equal(t, inputRows, result.Rows)

	ensureConnValid(t, pgConn)
}
// TestConnCopyFromCanceled checks that a context deadline expiring while
// CopyFrom is streaming from a slow producer aborts the copy with an error,
// reports zero rows, and leaves the connection closed with cleanup finishing
// within a bounded time.
func TestConnCopyFromCanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int4,
		b varchar
	)`).ReadAll()
	require.NoError(t, err)

	r, w := io.Pipe()
	// Once CopyFrom gives up, nothing reads the pipe and the producer would
	// block in w.Write forever. Closing the read side on return makes pending
	// and future writes fail, so the goroutine exits instead of leaking.
	defer r.Close()
	go func() {
		// Signal EOF to the reader should every row ever be produced.
		defer w.Close()
		for i := 0; i < 1000000; i++ {
			a := strconv.Itoa(i)
			b := "foo " + a + " bar"
			if _, err := fmt.Fprintf(w, "%s,\"%s\"\n", a, b); err != nil {
				return
			}
			time.Sleep(time.Microsecond)
		}
	}()

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
	if pgConn.ParameterStatus("crdb_version") != "" {
		copySql = "COPY foo FROM STDIN WITH CSV"
	}
	ct, err := pgConn.CopyFrom(ctx, r, copySql)
	cancel()
	assert.Equal(t, int64(0), ct.RowsAffected())
	assert.Error(t, err)

	assert.True(t, pgConn.IsClosed())
	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}
// TestConnCopyFromPrecanceled checks that CopyFrom given an already-canceled
// context fails with a retryable context.Canceled error and an empty command
// tag, and that the connection stays usable.
func TestConnCopyFromPrecanceled(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int4,
		b varchar
	)`).ReadAll()
	require.NoError(t, err)

	r, w := io.Pipe()
	// CopyFrom is expected to return without consuming r, which would leave
	// the producer blocked in its first write forever. Closing the read side
	// on return makes that write fail so the goroutine exits.
	defer r.Close()
	go func() {
		// Signal EOF to the reader should every row ever be produced.
		defer w.Close()
		for i := 0; i < 1000000; i++ {
			a := strconv.Itoa(i)
			b := "foo " + a + " bar"
			if _, err := fmt.Fprintf(w, "%s,\"%s\"\n", a, b); err != nil {
				return
			}
			time.Sleep(time.Microsecond)
		}
	}()

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
	require.Error(t, err)
	assert.True(t, errors.Is(err, context.Canceled))
	assert.True(t, pgconn.SafeToRetry(err))
	assert.Equal(t, pgconn.CommandTag{}, ct)

	ensureConnValid(t, pgConn)
}
// TestConnCopyFromGzipReader writes 1000 gzip-compressed CSV rows to a
// temporary file, streams them into a table through a gzip.Reader with
// CopyFrom, and checks the row count and the data read back.
func TestConnCopyFromGzipReader(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)")
	}

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int4,
		b varchar
	)`).ReadAll()
	require.NoError(t, err)

	f, err := ioutil.TempFile("", "*")
	require.NoError(t, err)
	// Remove the temp file even if an assertion below stops the test early.
	defer os.Remove(f.Name())

	gw := gzip.NewWriter(f)

	inputRows := [][][]byte{}
	for i := 0; i < 1000; i++ {
		a := strconv.Itoa(i)
		b := "foo " + a + " bar"
		inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
		_, err = fmt.Fprintf(gw, "%s,\"%s\"\n", a, b)
		require.NoError(t, err)
	}

	err = gw.Close()
	require.NoError(t, err)

	// Rewind so the compressed data can be read back from the start.
	_, err = f.Seek(0, io.SeekStart)
	require.NoError(t, err)

	gr, err := gzip.NewReader(f)
	require.NoError(t, err)

	// CockroachDB was skipped above, so only the PostgreSQL COPY syntax is needed.
	ct, err := pgConn.CopyFrom(context.Background(), gr, "COPY foo FROM STDIN WITH (FORMAT csv)")
	require.NoError(t, err)
	assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())

	err = gr.Close()
	require.NoError(t, err)

	err = f.Close()
	require.NoError(t, err)

	result := pgConn.ExecParams(context.Background(), "select * from foo", nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	assert.Equal(t, inputRows, result.Rows)

	ensureConnValid(t, pgConn)
}
// TestConnCopyFromQuerySyntaxError checks that a malformed COPY FROM statement
// returns a *pgconn.PgError with zero rows affected, even though copy data is
// supplied, and that the connection survives it.
func TestConnCopyFromQuerySyntaxError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	_, err = pgConn.Exec(context.Background(), `create temporary table foo(
		a int4,
		b varchar
	)`).ReadAll()
	require.NoError(t, err)

	srcBuf := &bytes.Buffer{}

	// Send data even though the COPY FROM command will be rejected with a syntax error. This ensures that this does not
	// break the connection. See https://github.com/jackc/pgconn/pull/127 for context.
	for i := 0; i < 1000; i++ {
		a := strconv.Itoa(i)
		b := "foo " + a + " bar"
		_, err = fmt.Fprintf(srcBuf, "%s,\"%s\"\n", a, b)
		require.NoError(t, err)
	}

	res, err := pgConn.CopyFrom(context.Background(), srcBuf, "cropy foo FROM STDIN WITH (FORMAT csv)")
	require.Error(t, err)
	assert.IsType(t, &pgconn.PgError{}, err)
	assert.Equal(t, int64(0), res.RowsAffected())

	ensureConnValid(t, pgConn)
}
// TestConnCopyFromQueryNoTableError checks that CopyFrom against a table that
// does not exist returns a *pgconn.PgError with zero rows affected, and that
// the connection survives it.
func TestConnCopyFromQueryNoTableError(t *testing.T) {
	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	var src bytes.Buffer
	tag, err := conn.CopyFrom(context.Background(), &src, "copy foo to stdout")
	require.Error(t, err)
	assert.IsType(t, &pgconn.PgError{}, err)
	assert.Equal(t, int64(0), tag.RowsAffected())

	ensureConnValid(t, conn)
}
// TestConnCopyFromNoticeResponseReceivedMidStream copies 1000 rows, each a
// single 10001-character word, into a table whose BEFORE INSERT trigger runs
// to_tsvector on every row. It checks that the copy succeeds and the
// connection stays usable. Presumably the oversized words make the server
// emit NOTICE messages while the COPY data is still streaming — the scenario
// from the issue below.
//
// https://github.com/jackc/pgconn/issues/21
func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)")
	}

	_, err = pgConn.Exec(ctx, `create temporary table sentences(
		t text,
		ts tsvector
	)`).ReadAll()
	require.NoError(t, err)

	_, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$
	begin
		new.ts := to_tsvector(new.t);
		return new;
	end
	$$ language plpgsql;`).ReadAll()
	require.NoError(t, err)

	_, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll()
	require.NoError(t, err)

	longString := bytes.Repeat([]byte{'x'}, 10001)

	// Write each line directly, avoiding the byte->string->byte round trip.
	buf := &bytes.Buffer{}
	for i := 0; i < 1000; i++ {
		buf.Write(longString)
		buf.WriteByte('\n')
	}

	_, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)")
	require.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// delayedReader wraps an io.Reader and pauses for one millisecond before
// every Read, slowing down how quickly data is handed to the caller.
type delayedReader struct {
	r io.Reader
}

// Read sleeps for one millisecond, then delegates to the wrapped reader and
// returns its result unchanged. According to the original note, the problem
// tracked in pgconn issue 128 only reproduced with this delay in place.
func (d delayedReader) Read(p []byte) (n int, err error) {
	const delay = time.Millisecond
	time.Sleep(delay)
	n, err = d.r.Read(p)
	return n, err
}
// TestConnCopyFromDataWriteAfterErrorAndReturn runs a CopyFrom against a
// nonexistent table, which must fail, while its slow reader still has data.
// A following CopyFrom on the same connection must then succeed.
//
// https://github.com/jackc/pgconn/issues/128
func TestConnCopyFromDataWriteAfterErrorAndReturn(t *testing.T) {
	connString := os.Getenv("PGX_TEST_DATABASE")
	if connString == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_DATABASE")
	}

	config, err := pgconn.ParseConfig(connString)
	require.NoError(t, err)

	pgConn, err := pgconn.ConnectConfig(context.Background(), config)
	require.NoError(t, err)
	// The connection used to be leaked; close it like every other test does.
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not fully support COPY FROM")
	}

	setupSQL := `create temporary table t (
		id text primary key,
		n int not null
	);`

	_, err = pgConn.Exec(context.Background(), setupSQL).ReadAll()
	require.NoError(t, err)

	// One text-format COPY row: the values for id and n separated by a tab and
	// terminated by a newline. An interpreted string keeps the delimiter
	// visible. Inside a raw string, `\n` is a literal backslash followed by n,
	// not a newline.
	const copyRow = "id\t0\n"

	r1 := delayedReader{r: strings.NewReader(copyRow)}
	// Generate an error with a bogus COPY command
	_, err = pgConn.CopyFrom(context.Background(), r1, "COPY nosuchtable FROM STDIN ")
	assert.Error(t, err)

	r2 := delayedReader{r: strings.NewReader(copyRow)}
	_, err = pgConn.CopyFrom(context.Background(), r2, "COPY t FROM STDIN")
	assert.NoError(t, err)
}
// TestConnEscapeString checks EscapeString on inputs with and without single
// quotes: every single quote must come back doubled.
func TestConnEscapeString(t *testing.T) {
	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	cases := []struct {
		in  string
		out string
	}{
		{in: "", out: ""},
		{in: "42", out: "42"},
		{in: "'", out: "''"},
		{in: "hi'there", out: "hi''there"},
		{in: "'hi there'", out: "''hi there''"},
	}

	for idx := range cases {
		tc := cases[idx]
		got, escErr := conn.EscapeString(tc.in)
		if !assert.NoErrorf(t, escErr, "%d.", idx) {
			continue
		}
		assert.Equalf(t, tc.out, got, "%d.", idx)
	}

	ensureConnValid(t, conn)
}
// TestConnCancelRequest checks that CancelRequest, sent from another goroutine
// while a long query is running, makes the query fail with SQLSTATE 57014
// (query_canceled), and that the connection remains usable.
func TestConnCancelRequest(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	if pgConn.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
	}

	multiResult := pgConn.Exec(context.Background(), "select 'Hello, world', pg_sleep(2)")

	// require.* calls t.FailNow, which may only be called from the goroutine
	// running the test, so the canceller reports its error over a channel.
	// The buffer keeps the send from blocking.
	cancelErrChan := make(chan error, 1)
	go func() {
		// The query is actually sent when multiResult.NextResult() is called. So wait to ensure it is sent.
		// Once Flush is available this could use that instead.
		time.Sleep(500 * time.Millisecond)
		cancelErrChan <- pgConn.CancelRequest(context.Background())
	}()

	for multiResult.NextResult() {
	}
	err = multiResult.Close()

	require.NoError(t, <-cancelErrChan)

	require.IsType(t, &pgconn.PgError{}, err)
	require.Equal(t, "57014", err.(*pgconn.PgError).Code)

	ensureConnValid(t, pgConn)
}
// TestConnContextCanceledCancelsRunningQueryOnServer checks that when a query's
// context times out, the client connection is closed and the server-side
// backend also goes away: its PID eventually disappears from
// pg_stat_activity.
//
// https://github.com/jackc/pgx/issues/659
func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	pid := pgConn.PID()

	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
	defer cancel()
	multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(30)")

	for multiResult.NextResult() {
	}
	err = multiResult.Close()
	assert.True(t, pgconn.Timeout(err))
	assert.True(t, pgConn.IsClosed())
	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}

	otherConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, otherConn)

	// Poll until the original backend no longer appears. The 5 second context
	// bounds the total wait.
	ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
	defer cancel()

	for {
		result := otherConn.ExecParams(ctx,
			`select 1 from pg_stat_activity where pid=$1`,
			[][]byte{[]byte(strconv.FormatInt(int64(pid), 10))},
			nil,
			nil,
			nil,
		).Read()
		require.NoError(t, result.Err)

		if len(result.Rows) == 0 {
			break
		}

		// Pause between polls rather than issuing queries back to back.
		time.Sleep(10 * time.Millisecond)
	}
}
// TestHijackAndConstruct checks that a PgConn can no longer run queries after
// Hijack, and that Construct builds a working connection from the hijacked
// state.
func TestHijackAndConstruct(t *testing.T) {
	t.Parallel()

	original, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	hijacked, err := original.Hijack()
	require.NoError(t, err)

	_, err = original.Exec(context.Background(), "select 'Hello, world'").ReadAll()
	require.Error(t, err)

	rebuilt, err := pgconn.Construct(hijacked)
	require.NoError(t, err)
	defer closeConn(t, rebuilt)

	results, err := rebuilt.Exec(context.Background(), "select 'Hello, world'").ReadAll()
	assert.NoError(t, err)

	assert.Len(t, results, 1)
	first := results[0]
	assert.Nil(t, first.Err)
	assert.Equal(t, "SELECT 1", first.CommandTag.String())
	assert.Len(t, first.Rows, 1)
	assert.Equal(t, "Hello, world", string(first.Rows[0][0]))

	ensureConnValid(t, rebuilt)
}
// TestConnCloseWhileCancellableQueryInProgress checks that a connection can be
// closed while a query issued with a cancellable context is still pending,
// and that cleanup completes within a bounded time.
func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)

	// Keep and defer the cancel funcs. Discarding them leaks the contexts and
	// is flagged by go vet's lostcancel check. They only run when the test
	// returns, after the cleanup check below.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// The result is deliberately left unread so the query is still in
	// progress when the connection is closed.
	pgConn.Exec(ctx, "select n from generate_series(1,10) n")

	closeCtx, closeCancel := context.WithCancel(context.Background())
	defer closeCancel()
	// Close's error is not checked; the test only asserts that cleanup finishes.
	pgConn.Close(closeCtx)

	select {
	case <-pgConn.CleanupDone():
	case <-time.After(5 * time.Second):
		t.Fatal("Connection cleanup exceeded maximum time")
	}
}
// TestFatalErrorReceivedAfterCommandComplete runs pgconn against a scripted
// mock server that sends a FATAL ErrorResponse (code 57P01) immediately after
// CommandComplete, and checks that the result reader reports an error.
//
// https://github.com/jackc/pgx/issues/800
func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) {
	t.Parallel()

	steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{}))
	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{}))
	steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{}))
	steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
		{Name: []byte("mock")},
	}}))
	steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")}))
	steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))

	script := &pgmock.Script{Steps: steps}

	ln, err := net.Listen("tcp", "127.0.0.1:")
	require.NoError(t, err)
	defer ln.Close()

	// The mock server reports any failure on serverErrChan and closes the
	// channel when it exits.
	serverErrChan := make(chan error, 1)
	go func() {
		defer close(serverErrChan)

		conn, err := ln.Accept()
		if err != nil {
			serverErrChan <- err
			return
		}
		defer conn.Close()

		err = conn.SetDeadline(time.Now().Add(5 * time.Second))
		if err != nil {
			serverErrChan <- err
			return
		}

		err = script.Run(pgproto3.NewBackend(conn, conn))
		if err != nil {
			serverErrChan <- err
			return
		}
	}()

	// net.SplitHostPort parses the address correctly, including bracketed
	// IPv6 forms that a plain split on ":" would mangle.
	host, port, err := net.SplitHostPort(ln.Addr().String())
	require.NoError(t, err)
	connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	conn, err := pgconn.Connect(ctx, connStr)
	require.NoError(t, err)

	rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil)

	for rr.NextRow() {
	}

	_, err = rr.Close()
	require.Error(t, err)

	// A receive from the closed channel yields nil, so this fails only if the
	// mock server itself reported an error (e.g. an unexpected client message).
	require.NoError(t, <-serverErrChan)
}
// TestConnLargeResponseWhileWritingDoesNotDeadlock sends a very large
// ExecParams request while the server is set up to reply with a large volume
// of debug messages, and checks that the query completes.
//
// https://github.com/jackc/pgconn/issues/27
func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) {
	t.Parallel()

	conn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, conn)

	_, err = conn.Exec(context.Background(), "set client_min_messages = debug5").ReadAll()
	require.NoError(t, err)

	// The actual contents of this test aren't important. What's important is a large amount of data to be written and
	// because of client_min_messages = debug5 the server will return a large amount of data.
	const paramCount = math.MaxUint16
	placeholders := make([]string, paramCount)
	args := make([][]byte, paramCount)
	for i := range placeholders {
		placeholders[i] = fmt.Sprintf("($%d::text)", i+1)
		args[i] = []byte(strconv.Itoa(i))
	}
	sql := "values" + strings.Join(placeholders, ", ")

	result := conn.ExecParams(context.Background(), sql, args, nil, nil, nil).Read()
	require.NoError(t, result.Err)
	require.Len(t, result.Rows, paramCount)

	ensureConnValid(t, conn)
}
// TestConnCheckConn checks that CheckConn reports no error on a live
// connection, and reports an error once that connection's server backend has
// been terminated from a second connection.
func TestConnCheckConn(t *testing.T) {
	t.Parallel()

	// Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtly different.)
	connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
	if connString == "" {
		t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
	}

	c1, err := pgconn.Connect(context.Background(), connString)
	require.NoError(t, err)
	defer c1.Close(context.Background())

	if c1.ParameterStatus("crdb_version") != "" {
		t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
	}

	err = c1.CheckConn()
	require.NoError(t, err)

	// A second connection terminates c1's backend by PID.
	c2, err := pgconn.Connect(context.Background(), connString)
	require.NoError(t, err)
	defer c2.Close(context.Background())

	_, err = c2.Exec(context.Background(), fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
	require.NoError(t, err)

	// Give a little time for the signal to actually kill the backend.
	time.Sleep(500 * time.Millisecond)

	err = c1.CheckConn()
	require.Error(t, err)
}
// TestPipelinePrepare queues several Prepare requests in a single pipeline and checks,
// in order, the statement description returned for each. It then expects the
// PipelineSync marker, and finally a nil result once everything has been read.
func TestPipelinePrepare(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	result := pgConn.ExecParams(context.Background(), `create temporary table t (id text primary key)`, nil, nil, nil, nil).Read()
	require.NoError(t, result.Err)

	pipeline := pgConn.StartPipeline(context.Background())
	pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
	pipeline.SendPrepare("selectText", "select $1::text as b", nil)
	pipeline.SendPrepare("selectNoParams", "select 42 as c", nil)
	pipeline.SendPrepare("insertNoResults", "insert into t (id) values ($1)", nil)
	pipeline.SendPrepare("insertNoParamsOrResults", "insert into t (id) values ('foo')", nil)
	err = pipeline.Sync()
	require.NoError(t, err)

	// nextStatementDescription reads the next pipeline result and requires it to be a
	// *pgconn.StatementDescription.
	nextStatementDescription := func() *pgconn.StatementDescription {
		t.Helper()
		results, err := pipeline.GetResults()
		require.NoError(t, err)
		sd, ok := results.(*pgconn.StatementDescription)
		require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
		return sd
	}

	// testify's Equal takes (expected, actual); keeping that order makes failure
	// messages label the two values correctly.
	sd := nextStatementDescription()
	require.Len(t, sd.Fields, 1)
	require.Equal(t, "a", string(sd.Fields[0].Name))
	require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)

	sd = nextStatementDescription()
	require.Len(t, sd.Fields, 1)
	require.Equal(t, "b", string(sd.Fields[0].Name))
	require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)

	sd = nextStatementDescription()
	require.Len(t, sd.Fields, 1)
	require.Equal(t, "c", string(sd.Fields[0].Name))
	require.Equal(t, []uint32{}, sd.ParamOIDs)

	// The inserts return no rows, so their descriptions have no fields.
	sd = nextStatementDescription()
	require.Len(t, sd.Fields, 0)
	require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)

	sd = nextStatementDescription()
	require.Len(t, sd.Fields, 0)
	require.Len(t, sd.ParamOIDs, 0)

	results, err := pipeline.GetResults()
	require.NoError(t, err)
	_, ok := results.(*pgconn.PipelineSync)
	require.Truef(t, ok, "expected PipelineSync, got: %#v", results)

	// Every result has been consumed.
	results, err = pipeline.GetResults()
	require.NoError(t, err)
	require.Nil(t, results)

	err = pipeline.Close()
	require.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// TestPipelinePrepareError checks how a failing Prepare in a pipeline is reported:
//   - the request before it still returns its statement description;
//   - the failing request surfaces a *pgconn.PgError with a nil result;
//   - the request queued after it before the Sync yields no result, so the next
//     result is the PipelineSync marker.
func TestPipelinePrepareError(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	pipeline := pgConn.StartPipeline(context.Background())
	pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
	pipeline.SendPrepare("selectError", "bad", nil)
	pipeline.SendPrepare("selectText", "select $1::text as b", nil)
	err = pipeline.Sync()
	require.NoError(t, err)

	results, err := pipeline.GetResults()
	require.NoError(t, err)
	sd, ok := results.(*pgconn.StatementDescription)
	require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
	require.Len(t, sd.Fields, 1)
	// testify's Equal takes (expected, actual).
	require.Equal(t, "a", string(sd.Fields[0].Name))
	require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)

	// "bad" is not valid SQL, so preparing it fails with a server error.
	results, err = pipeline.GetResults()
	var pgErr *pgconn.PgError
	require.ErrorAs(t, err, &pgErr)
	require.Nil(t, results)

	// No result is returned for selectText; the next result is the PipelineSync.
	results, err = pipeline.GetResults()
	require.NoError(t, err)
	_, ok = results.(*pgconn.PipelineSync)
	require.Truef(t, ok, "expected PipelineSync, got: %#v", results)

	results, err = pipeline.GetResults()
	require.NoError(t, err)
	require.Nil(t, results)

	err = pipeline.Close()
	require.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// TestPipelinePrepareAndDeallocate prepares a statement and deallocates it in the same
// pipeline. It expects the statement description, then a CloseComplete for the
// deallocation, then the PipelineSync marker, then a nil result.
func TestPipelinePrepareAndDeallocate(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	pipeline := pgConn.StartPipeline(context.Background())
	pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
	pipeline.SendDeallocate("selectInt")
	err = pipeline.Sync()
	require.NoError(t, err)

	results, err := pipeline.GetResults()
	require.NoError(t, err)
	sd, ok := results.(*pgconn.StatementDescription)
	require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
	require.Len(t, sd.Fields, 1)
	// testify's Equal takes (expected, actual).
	require.Equal(t, "a", string(sd.Fields[0].Name))
	require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)

	results, err = pipeline.GetResults()
	require.NoError(t, err)
	_, ok = results.(*pgconn.CloseComplete)
	require.Truef(t, ok, "expected CloseComplete, got: %#v", results)

	results, err = pipeline.GetResults()
	require.NoError(t, err)
	_, ok = results.(*pgconn.PipelineSync)
	require.Truef(t, ok, "expected PipelineSync, got: %#v", results)

	results, err = pipeline.GetResults()
	require.NoError(t, err)
	require.Nil(t, results)

	err = pipeline.Close()
	require.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// TestPipelineQuery sends five single-value queries split into two Sync batches.
// It checks each query's one-row, one-column result, with a PipelineSync marker
// after each batch and a nil result once everything has been read.
func TestPipelineQuery(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	pipeline := pgConn.StartPipeline(ctx)
	batches := [][]string{
		{`select 1`, `select 2`, `select 3`},
		{`select 4`, `select 5`},
	}
	for _, batch := range batches {
		for _, sql := range batch {
			pipeline.SendQueryParams(sql, nil, nil, nil, nil)
		}
		require.NoError(t, pipeline.Sync())
	}

	// expectSingleValue requires the next result to be a ResultReader yielding exactly
	// one row with one column whose text is want.
	expectSingleValue := func(want string) {
		t.Helper()
		results, err := pipeline.GetResults()
		require.NoError(t, err)
		rr, ok := results.(*pgconn.ResultReader)
		require.Truef(t, ok, "expected ResultReader, got: %#v", results)
		readResult := rr.Read()
		require.NoError(t, readResult.Err)
		require.Len(t, readResult.Rows, 1)
		require.Len(t, readResult.Rows[0], 1)
		require.Equal(t, want, string(readResult.Rows[0][0]))
	}

	// expectSync requires the next result to be the PipelineSync marker.
	expectSync := func() {
		t.Helper()
		results, err := pipeline.GetResults()
		require.NoError(t, err)
		_, ok := results.(*pgconn.PipelineSync)
		require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
	}

	for _, want := range []string{"1", "2", "3"} {
		expectSingleValue(want)
	}
	expectSync()

	for _, want := range []string{"4", "5"} {
		expectSingleValue(want)
	}
	expectSync()

	results, err := pipeline.GetResults()
	require.NoError(t, err)
	require.Nil(t, results)

	require.NoError(t, pipeline.Close())

	ensureConnValid(t, pgConn)
}
// TestPipelinePrepareQuery prepares a statement and executes it twice within the same
// pipeline. It checks the statement description, then each execution's single value,
// then the PipelineSync marker and a final nil result.
func TestPipelinePrepareQuery(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	pipeline := pgConn.StartPipeline(context.Background())
	pipeline.SendPrepare("ps", "select $1::text as msg", nil)
	pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
	pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil)
	err = pipeline.Sync()
	require.NoError(t, err)

	results, err := pipeline.GetResults()
	require.NoError(t, err)
	sd, ok := results.(*pgconn.StatementDescription)
	require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
	require.Len(t, sd.Fields, 1)
	// testify's Equal takes (expected, actual).
	require.Equal(t, "msg", string(sd.Fields[0].Name))
	require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)

	results, err = pipeline.GetResults()
	require.NoError(t, err)
	rr, ok := results.(*pgconn.ResultReader)
	require.Truef(t, ok, "expected ResultReader, got: %#v", results)
	readResult := rr.Read()
	require.NoError(t, readResult.Err)
	require.Len(t, readResult.Rows, 1)
	require.Len(t, readResult.Rows[0], 1)
	require.Equal(t, "hello", string(readResult.Rows[0][0]))

	results, err = pipeline.GetResults()
	require.NoError(t, err)
	rr, ok = results.(*pgconn.ResultReader)
	require.Truef(t, ok, "expected ResultReader, got: %#v", results)
	readResult = rr.Read()
	require.NoError(t, readResult.Err)
	require.Len(t, readResult.Rows, 1)
	require.Len(t, readResult.Rows[0], 1)
	require.Equal(t, "goodbye", string(readResult.Rows[0][0]))

	results, err = pipeline.GetResults()
	require.NoError(t, err)
	_, ok = results.(*pgconn.PipelineSync)
	require.Truef(t, ok, "expected PipelineSync, got: %#v", results)

	results, err = pipeline.GetResults()
	require.NoError(t, err)
	require.Nil(t, results)

	err = pipeline.Close()
	require.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// TestPipelineQueryErrorBetweenSyncs sends three Sync batches; the middle one contains
// a query that fails. It checks that:
//   - the batch before the failure is unaffected;
//   - in the failing batch, the query before the failure succeeds, the failing query
//     reports the error, and the query after it yields no result;
//   - the batch after the next Sync runs normally.
func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
	t.Parallel()

	pgConn, err := pgconn.Connect(context.Background(), os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	pipeline := pgConn.StartPipeline(context.Background())
	pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
	pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
	err = pipeline.Sync()
	require.NoError(t, err)

	// 1/(3-n) divides by zero (SQLSTATE 22012) when n reaches 3.
	pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
	pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
	pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
	err = pipeline.Sync()
	require.NoError(t, err)

	pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
	pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil)
	err = pipeline.Sync()
	require.NoError(t, err)

	// nextResultReader reads the next pipeline result and requires it to be a
	// *pgconn.ResultReader.
	nextResultReader := func() *pgconn.ResultReader {
		t.Helper()
		results, err := pipeline.GetResults()
		require.NoError(t, err)
		rr, ok := results.(*pgconn.ResultReader)
		require.Truef(t, ok, "expected ResultReader, got: %#v", results)
		return rr
	}

	// expectSingleValue requires the next result to be exactly one row with one
	// column whose text is want.
	expectSingleValue := func(want string) {
		t.Helper()
		readResult := nextResultReader().Read()
		require.NoError(t, readResult.Err)
		require.Len(t, readResult.Rows, 1)
		require.Len(t, readResult.Rows[0], 1)
		require.Equal(t, want, string(readResult.Rows[0][0]))
	}

	// expectSync requires the next result to be the PipelineSync marker.
	expectSync := func() {
		t.Helper()
		results, err := pipeline.GetResults()
		require.NoError(t, err)
		_, ok := results.(*pgconn.PipelineSync)
		require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
	}

	// First batch.
	expectSingleValue("1")
	expectSingleValue("2")
	expectSync()

	// Second batch: `select 3` succeeds, the division fails, and `select 4` yields no
	// result, so the PipelineSync follows the error directly.
	expectSingleValue("3")
	readResult := nextResultReader().Read()
	var pgErr *pgconn.PgError
	require.ErrorAs(t, readResult.Err, &pgErr)
	require.Equal(t, "22012", pgErr.Code)
	expectSync()

	// Third batch runs normally after the failed one.
	expectSingleValue("5")
	expectSingleValue("6")
	expectSync()

	// All results have been consumed, as the other pipeline tests also verify.
	results, err := pipeline.GetResults()
	require.NoError(t, err)
	require.Nil(t, results)

	err = pipeline.Close()
	require.NoError(t, err)

	ensureConnValid(t, pgConn)
}
// TestPipelineCloseReadsUnreadResults checks that Close succeeds, and the connection
// stays usable, when only the first of several synced results has been read.
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	pipeline := pgConn.StartPipeline(ctx)
	batches := [][]string{
		{`select 1`, `select 2`, `select 3`},
		{`select 4`, `select 5`},
	}
	for _, batch := range batches {
		for _, sql := range batch {
			pipeline.SendQueryParams(sql, nil, nil, nil, nil)
		}
		require.NoError(t, pipeline.Sync())
	}

	// Only the first result is read; the rest are still unread when Close is called.
	results, err := pipeline.GetResults()
	require.NoError(t, err)
	rr, ok := results.(*pgconn.ResultReader)
	require.Truef(t, ok, "expected ResultReader, got: %#v", results)

	readResult := rr.Read()
	require.NoError(t, readResult.Err)
	require.Len(t, readResult.Rows, 1)
	require.Len(t, readResult.Rows[0], 1)
	require.Equal(t, "1", string(readResult.Rows[0][0]))

	require.NoError(t, pipeline.Close())

	ensureConnValid(t, pgConn)
}
// TestPipelineCloseDetectsUnsyncedRequests checks that Close reports an error when
// requests were queued after the last Sync.
func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	require.NoError(t, err)
	defer closeConn(t, pgConn)

	pipeline := pgConn.StartPipeline(ctx)
	for _, sql := range []string{`select 1`, `select 2`, `select 3`} {
		pipeline.SendQueryParams(sql, nil, nil, nil, nil)
	}
	require.NoError(t, pipeline.Sync())

	// These two requests are never followed by a Sync.
	for _, sql := range []string{`select 4`, `select 5`} {
		pipeline.SendQueryParams(sql, nil, nil, nil, nil)
	}

	results, err := pipeline.GetResults()
	require.NoError(t, err)
	rr, ok := results.(*pgconn.ResultReader)
	require.Truef(t, ok, "expected ResultReader, got: %#v", results)

	readResult := rr.Read()
	require.NoError(t, readResult.Err)
	require.Len(t, readResult.Rows, 1)
	require.Len(t, readResult.Rows[0], 1)
	require.Equal(t, "1", string(readResult.Rows[0][0]))

	require.EqualError(t, pipeline.Close(), "pipeline has unsynced requests")
}
// Example connects to the database named by PGX_TEST_DATABASE and runs a query that
// produces three rows. It prints each row's single value, followed by the command tag.
func Example() {
	ctx := context.Background()

	pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
	if err != nil {
		log.Fatalln(err)
	}
	defer pgConn.Close(ctx)

	result := pgConn.ExecParams(ctx, "select generate_series(1,3)", nil, nil, nil, nil).Read()
	if err := result.Err; err != nil {
		log.Fatalln(err)
	}

	for i := range result.Rows {
		fmt.Println(string(result.Rows[i][0]))
	}
	fmt.Println(result.CommandTag)
	// Output:
	// 1
	// 2
	// 3
	// SELECT 3
}
// GetSSLPassword returns the value of the PGX_SSL_PASSWORD environment variable, or
// "" when it is unset. ctx is not used.
// NOTE(review): the signature presumably matches pgconn's SSL password callback
// type; confirm at the call site.
func GetSSLPassword(ctx context.Context) string {
	return os.Getenv("PGX_SSL_PASSWORD")
}
// rsaCertPEM is a PEM-encoded X.509 certificate. TestSNISupport's fake TLS server
// loads it together with rsaKeyPEM via tls.X509KeyPair.
//
// NOTE(review): decoded, this appears to be a self-signed certificate for
// CN=localhost, valid only from 2022-08-15 to 2023-08-15, so it has expired.
// TestSNISupport presumably still passes because its client connects with
// sslmode=require, which is expected not to verify the server certificate.
// Confirm before reusing this fixture in a test that verifies certificates.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx
NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct
Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39
tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d
9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp
0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv
MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E
FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o
6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2
gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I
81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB
Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf
hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS
VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27
MlascjupnaptKX/wMA==
-----END CERTIFICATE-----
`
// rsaKeyPEM is the private key that tls.X509KeyPair pairs with rsaCertPEM in
// TestSNISupport. It is stored with a "TESTING KEY" PEM label. testingKey rewrites
// that label to "PRIVATE KEY" when this package-level variable is initialized.
// Presumably this keeps secret scanners from flagging the fixture as a leaked key.
var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv
ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx
Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo
qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM
Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK
o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs
WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa
ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv
Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B
QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+
QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC
CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods
bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3
1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2
SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6
MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G
McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC
I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD
QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf
k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS
lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4
TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr
5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi
UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T
z3w+CgS20UrbLIR1YXfqUXge1g==
-----END TESTING KEY-----
`)
// testingKey returns s with every "TESTING KEY" replaced by "PRIVATE KEY". This turns
// a placeholder PEM label in a test fixture into the real private-key label.
func testingKey(s string) string {
	relabel := strings.NewReplacer("TESTING KEY", "PRIVATE KEY")
	return relabel.Replace(s)
}
// TestSNISupport checks whether the TLS ClientHello sent by pgconn.Connect carries a
// server name (SNI), depending on the sslsni connection parameter.
//
// A minimal fake server does the following:
//   - accepts one connection and answers the SSLRequest with 'S';
//   - performs the TLS handshake, recording the ClientHello's server name;
//   - replies with AuthenticationOk, BackendKeyData and ReadyForQuery;
//   - reports the recorded name, or any error, back to the test.
func TestSNISupport(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name     string
		sniParam string
		sniSet   bool
	}{
		{
			name:     "SNI is passed by default",
			sniParam: "",
			sniSet:   true,
		},
		{
			name:     "SNI is passed when asked for",
			sniParam: "sslsni=1",
			sniSet:   true,
		},
		{
			name:     "SNI is not passed when disabled",
			sniParam: "sslsni=0",
			sniSet:   false,
		},
	}
	for _, tt := range tests {
		tt := tt
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			ln, err := net.Listen("tcp", "127.0.0.1:")
			require.NoError(t, err)
			defer ln.Close()

			// The server goroutine sends at most one value in total, to exactly one of
			// these channels, so a buffer of 1 means it never blocks.
			//
			// The channels are deliberately never closed. On a failure path the subtest
			// can return while the goroutine is still running; for example, ln.Close
			// unblocks Accept, which then reports an error. A send on a closed channel
			// would panic.
			serverErrChan := make(chan error, 1)
			serverSNINameChan := make(chan string, 1)

			go func() {
				var sniHost string

				conn, err := ln.Accept()
				if err != nil {
					serverErrChan <- err
					return
				}
				defer conn.Close()

				err = conn.SetDeadline(time.Now().Add(5 * time.Second))
				if err != nil {
					serverErrChan <- err
					return
				}

				backend := pgproto3.NewBackend(conn, conn)
				startupMessage, err := backend.ReceiveStartupMessage()
				if err != nil {
					serverErrChan <- err
					return
				}

				// The client is expected to open with an SSLRequest; 'S' accepts TLS.
				switch startupMessage.(type) {
				case *pgproto3.SSLRequest:
					_, err = conn.Write([]byte("S"))
					if err != nil {
						serverErrChan <- err
						return
					}
				default:
					serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
					return
				}

				cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
				if err != nil {
					serverErrChan <- err
					return
				}

				srv := tls.Server(conn, &tls.Config{
					Certificates: []tls.Certificate{cert},
					// Capture the SNI server name from the ClientHello. Returning a nil
					// config tells crypto/tls to keep using this one.
					GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
						sniHost = argHello.ServerName
						return nil, nil
					},
				})
				defer srv.Close()

				if err := srv.Handshake(); err != nil {
					serverErrChan <- fmt.Errorf("handshake: %w", err)
					return
				}

				// Send a minimal successful startup sequence in one write, and report a
				// failed write instead of silently ignoring it.
				var reply []byte
				reply = (&pgproto3.AuthenticationOk{}).Encode(reply)
				reply = (&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(reply)
				reply = (&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(reply)
				if _, err := srv.Write(reply); err != nil {
					serverErrChan <- fmt.Errorf("write startup reply: %w", err)
					return
				}

				serverSNINameChan <- sniHost
			}()

			_, port, err := net.SplitHostPort(ln.Addr().String())
			require.NoError(t, err)
			connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sniParam)

			// The connect error itself is not asserted: the test only cares what the
			// server observed. A successfully established connection is closed so it
			// does not leak.
			clientConn, err := pgconn.Connect(context.Background(), connStr)
			if err == nil {
				defer clientConn.Close(context.Background())
			}

			select {
			case sniHost := <-serverSNINameChan:
				if tt.sniSet {
					require.Equal(t, "localhost", sniHost)
				} else {
					require.Equal(t, "", sniHost)
				}
			case serverErr := <-serverErrChan:
				t.Fatalf("server failed with error: %+v", serverErr)
			case <-time.After(time.Millisecond * 100):
				t.Fatal("exceeded connection timeout without erroring out")
			}
		})
	}
}