mirror of
https://github.com/jackc/pgx.git
synced 2025-04-28 05:37:41 +00:00
Arguably, PGError might have been better. But since the precedent is long since established it is better to be consistent.
3392 lines
95 KiB
Go
3392 lines
95 KiB
Go
package pgconn_test
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math"
|
|
"net"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/internal/pgio"
|
|
"github.com/jackc/pgx/v5/internal/pgmock"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgproto3"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
)
|
|
|
|
// pgbouncerConnStringEnvVar is the name of an environment variable.
// NOTE(review): presumably it holds the connection string for PgBouncer-specific
// tests; it is not referenced in this portion of the file, so confirm against its users.
const pgbouncerConnStringEnvVar = "PGX_TEST_PGBOUNCER_CONN_STRING"
|
|
|
|
func TestConnect(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
env string
|
|
}{
|
|
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
|
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
|
{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
|
|
{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
|
|
{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv(tt.env)
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", tt.env)
|
|
}
|
|
|
|
conn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
|
|
closeConn(t, conn)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectWithOptions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
env string
|
|
}{
|
|
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
|
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
|
{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
|
|
{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
|
|
{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv(tt.env)
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", tt.env)
|
|
}
|
|
var sslOptions pgconn.ParseConfigOptions
|
|
sslOptions.GetSSLPassword = GetSSLPassword
|
|
conn, err := pgconn.ConnectWithOptions(ctx, connString, sslOptions)
|
|
require.NoError(t, err)
|
|
|
|
closeConn(t, conn)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
|
|
// connection.
|
|
func TestConnectTLS(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
|
|
}
|
|
|
|
conn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
|
|
result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, 1)
|
|
require.Len(t, result.Rows[0], 1)
|
|
require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
|
|
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectTLSPasswordProtectedClientCertWithSSLPassword(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING")
|
|
}
|
|
if os.Getenv("PGX_SSL_PASSWORD") == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD")
|
|
}
|
|
|
|
connString += " sslpassword=" + os.Getenv("PGX_SSL_PASSWORD")
|
|
|
|
conn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
|
|
result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, 1)
|
|
require.Len(t, result.Rows[0], 1)
|
|
require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
|
|
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectTLSPasswordProtectedClientCertWithGetSSLPasswordConfigOption(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING")
|
|
}
|
|
if os.Getenv("PGX_SSL_PASSWORD") == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD")
|
|
}
|
|
|
|
var sslOptions pgconn.ParseConfigOptions
|
|
sslOptions.GetSSLPassword = GetSSLPassword
|
|
config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
|
|
require.Nil(t, err)
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
|
|
result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, 1)
|
|
require.Len(t, result.Rows[0], 1)
|
|
require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
|
|
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
type pgmockWaitStep time.Duration
|
|
|
|
func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
|
|
time.Sleep(time.Duration(s))
|
|
return nil
|
|
}
|
|
|
|
func TestConnectTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
connect func(connStr string) error
|
|
}{
|
|
{
|
|
name: "via context that times out",
|
|
connect: func(connStr string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
|
|
defer cancel()
|
|
_, err := pgconn.Connect(ctx, connStr)
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "via config ConnectTimeout",
|
|
connect: func(connStr string) error {
|
|
conf, err := pgconn.ParseConfig(connStr)
|
|
require.NoError(t, err)
|
|
conf.ConnectTimeout = time.Microsecond * 50
|
|
_, err = pgconn.ConnectConfig(context.Background(), conf)
|
|
return err
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
script := &pgmock.Script{
|
|
Steps: []pgmock.Step{
|
|
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
|
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
|
|
pgmockWaitStep(time.Millisecond * 500),
|
|
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
|
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
},
|
|
}
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(serverErrChan)
|
|
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
err = script.Run(pgproto3.NewBackend(conn, conn))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
}()
|
|
|
|
host, port, _ := strings.Cut(ln.Addr().String(), ":")
|
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
|
tooLate := time.Now().Add(time.Millisecond * 500)
|
|
|
|
err = tt.connect(connStr)
|
|
require.True(t, pgconn.Timeout(err), err)
|
|
require.True(t, time.Now().Before(tooLate))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
connect func(connStr string) error
|
|
}{
|
|
{
|
|
name: "via context that times out",
|
|
connect: func(connStr string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
|
|
defer cancel()
|
|
_, err := pgconn.Connect(ctx, connStr)
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "via config ConnectTimeout",
|
|
connect: func(connStr string) error {
|
|
conf, err := pgconn.ParseConfig(connStr)
|
|
require.NoError(t, err)
|
|
conf.ConnectTimeout = time.Millisecond * 10
|
|
_, err = pgconn.ConnectConfig(context.Background(), conf)
|
|
return err
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
go func() {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
var buf []byte
|
|
_, err = conn.Read(buf)
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
// Sleeping to hang the TLS handshake.
|
|
time.Sleep(time.Minute)
|
|
}()
|
|
|
|
host, port, _ := strings.Cut(ln.Addr().String(), ":")
|
|
connStr := fmt.Sprintf("host=%s port=%s", host, port)
|
|
|
|
errChan := make(chan error)
|
|
go func() {
|
|
err := tt.connect(connStr)
|
|
errChan <- err
|
|
}()
|
|
|
|
select {
|
|
case err = <-errChan:
|
|
require.True(t, pgconn.Timeout(err), err)
|
|
case err = <-serverErrChan:
|
|
t.Fatalf("server failed with error: %s", err)
|
|
case <-time.After(time.Millisecond * 500):
|
|
t.Fatal("exceeded connection timeout without erroring out")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectInvalidUser(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
config, err := pgconn.ParseConfig(connString)
|
|
require.NoError(t, err)
|
|
|
|
config.User = "pgxinvalidusertest"
|
|
|
|
_, err = pgconn.ConnectConfig(ctx, config)
|
|
require.Error(t, err)
|
|
pgErr, ok := errors.Unwrap(err).(*pgconn.PgError)
|
|
if !ok {
|
|
t.Fatalf("Expected to receive a wrapped PgError, instead received: %v", err)
|
|
}
|
|
if pgErr.Code != "28000" && pgErr.Code != "28P01" {
|
|
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
|
|
}
|
|
}
|
|
|
|
func TestConnectWithConnectionRefused(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
// Presumably nothing is listening on 127.0.0.1:1
|
|
conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1")
|
|
if err == nil {
|
|
conn.Close(ctx)
|
|
t.Fatal("Expected error establishing connection to bad port")
|
|
}
|
|
}
|
|
|
|
func TestConnectCustomDialer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
dialed := false
|
|
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
dialed = true
|
|
return net.Dial(network, address)
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
require.True(t, dialed)
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectCustomLookup(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
config, err := pgconn.ParseConfig(connString)
|
|
require.NoError(t, err)
|
|
|
|
looked := false
|
|
config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
|
|
looked = true
|
|
return net.LookupHost(host)
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
require.True(t, looked)
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectCustomLookupWithPort(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
config, err := pgconn.ParseConfig(connString)
|
|
require.NoError(t, err)
|
|
|
|
origPort := config.Port
|
|
// Change the config an invalid port so it will fail if used
|
|
config.Port = 0
|
|
|
|
looked := false
|
|
config.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
|
|
looked = true
|
|
addrs, err := net.LookupHost(host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range addrs {
|
|
addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10))
|
|
}
|
|
return addrs, nil
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
require.True(t, looked)
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectWithRuntimeParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
config.RuntimeParams = map[string]string{
|
|
"application_name": "pgxtest",
|
|
"search_path": "myschema",
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, conn)
|
|
|
|
result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read()
|
|
require.Nil(t, result.Err)
|
|
assert.Equal(t, 1, len(result.Rows))
|
|
assert.Equal(t, "pgxtest", string(result.Rows[0][0]))
|
|
|
|
result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read()
|
|
require.Nil(t, result.Err)
|
|
assert.Equal(t, 1, len(result.Rows))
|
|
assert.Equal(t, "myschema", string(result.Rows[0][0]))
|
|
}
|
|
|
|
func TestConnectWithFallback(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
// Prepend current primary config to fallbacks
|
|
config.Fallbacks = append([]*pgconn.FallbackConfig{
|
|
{
|
|
Host: config.Host,
|
|
Port: config.Port,
|
|
TLSConfig: config.TLSConfig,
|
|
},
|
|
}, config.Fallbacks...)
|
|
|
|
// Make primary config bad
|
|
config.Host = "localhost"
|
|
config.Port = 1 // presumably nothing listening here
|
|
|
|
// Prepend bad first fallback
|
|
config.Fallbacks = append([]*pgconn.FallbackConfig{
|
|
{
|
|
Host: "localhost",
|
|
Port: 1,
|
|
TLSConfig: config.TLSConfig,
|
|
},
|
|
}, config.Fallbacks...)
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectWithValidateConnect(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
dialCount := 0
|
|
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
dialCount++
|
|
return net.Dial(network, address)
|
|
}
|
|
|
|
acceptConnCount := 0
|
|
config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
|
acceptConnCount++
|
|
if acceptConnCount < 2 {
|
|
return errors.New("reject first conn")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Append current primary config to fallbacks
|
|
config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{
|
|
Host: config.Host,
|
|
Port: config.Port,
|
|
TLSConfig: config.TLSConfig,
|
|
})
|
|
|
|
// Repeat fallbacks
|
|
config.Fallbacks = append(config.Fallbacks, config.Fallbacks...)
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
closeConn(t, conn)
|
|
|
|
assert.True(t, dialCount > 1)
|
|
assert.True(t, acceptConnCount > 1)
|
|
}
|
|
|
|
func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
|
config.RuntimeParams["default_transaction_read_only"] = "on"
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
if !assert.NotNil(t, err) {
|
|
conn.Close(ctx)
|
|
}
|
|
}
|
|
|
|
func TestConnectWithAfterConnect(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
|
_, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
|
|
return err
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
|
|
results, err := conn.Exec(ctx, "show search_path;").ReadAll()
|
|
require.NoError(t, err)
|
|
defer closeConn(t, conn)
|
|
|
|
assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
|
|
}
|
|
|
|
func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config := &pgconn.Config{}
|
|
|
|
require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(ctx, config) })
|
|
}
|
|
|
|
func TestConnPrepareSyntaxError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil)
|
|
require.Nil(t, psd)
|
|
require.NotNil(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnPrepareContextPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
cancel()
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
|
assert.Nil(t, psd)
|
|
assert.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnDeallocate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
|
require.NoError(t, err)
|
|
|
|
err = pgConn.Deallocate(ctx, "ps1")
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
|
require.Error(t, err)
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, err, &pgErr)
|
|
require.Equal(t, "26000", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnDeallocateSucceedsInAbortedTransaction(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
err = pgConn.Exec(ctx, "begin").Close()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
|
require.NoError(t, err)
|
|
|
|
err = pgConn.Exec(ctx, "select 1/0").Close() // break transaction with divide by 0 error
|
|
require.Error(t, err)
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, err, &pgErr)
|
|
require.Equal(t, "22012", pgErr.Code)
|
|
|
|
err = pgConn.Deallocate(ctx, "ps1")
|
|
require.NoError(t, err)
|
|
|
|
err = pgConn.Exec(ctx, "rollback").Close()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
|
require.Error(t, err)
|
|
require.ErrorAs(t, err, &pgErr)
|
|
require.Equal(t, "26000", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnDeallocateNonExistantStatementSucceeds(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
err = pgConn.Deallocate(ctx, "ps1")
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExec(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Len(t, results, 1)
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecEmpty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
multiResult := pgConn.Exec(ctx, ";")
|
|
|
|
resultCount := 0
|
|
for multiResult.NextResult() {
|
|
resultCount++
|
|
multiResult.ResultReader().Close()
|
|
}
|
|
assert.Equal(t, 0, resultCount)
|
|
err = multiResult.Close()
|
|
assert.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecMultipleQueries(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Len(t, results, 2)
|
|
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
|
|
|
assert.Nil(t, results[1].Err)
|
|
assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
|
|
assert.Len(t, results[1].Rows, 1)
|
|
assert.Equal(t, "1", string(results[1].Rows[0][0]))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
mrr := pgConn.Exec(ctx, "select 'Hello, world' as msg; select 1 as num")
|
|
|
|
require.True(t, mrr.NextResult())
|
|
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
|
|
assert.Equal(t, "msg", mrr.ResultReader().FieldDescriptions()[0].Name)
|
|
_, err = mrr.ResultReader().Close()
|
|
require.NoError(t, err)
|
|
|
|
require.True(t, mrr.NextResult())
|
|
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
|
|
assert.Equal(t, "num", mrr.ResultReader().FieldDescriptions()[0].Name)
|
|
_, err = mrr.ResultReader().Close()
|
|
require.NoError(t, err)
|
|
|
|
require.False(t, mrr.NextResult())
|
|
|
|
require.NoError(t, mrr.Close())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecMultipleQueriesError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll()
|
|
require.NotNil(t, err)
|
|
if pgErr, ok := err.(*pgconn.PgError); ok {
|
|
assert.Equal(t, "22012", pgErr.Code)
|
|
} else {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
// CockroachDB starts the second query result set and then sends the divide by zero error.
|
|
require.Len(t, results, 2)
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "1", string(results[0].Rows[0][0]))
|
|
assert.Len(t, results[1].Rows, 0)
|
|
} else {
|
|
// PostgreSQL sends the divide by zero and never sends the second query result set.
|
|
require.Len(t, results, 1)
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "1", string(results[0].Rows[0][0]))
|
|
}
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecDeferredError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
|
}
|
|
|
|
setupSQL := `create temporary table t (
|
|
id text primary key,
|
|
n int not null,
|
|
unique (n) deferrable initially deferred
|
|
);
|
|
|
|
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
|
|
|
_, err = pgConn.Exec(ctx, setupSQL).ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll()
|
|
require.NotNil(t, err)
|
|
|
|
var pgErr *pgconn.PgError
|
|
require.True(t, errors.As(err, &pgErr))
|
|
require.Equal(t, "23505", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecContextCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
cancel()
|
|
|
|
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)")
|
|
|
|
for multiResult.NextResult() {
|
|
}
|
|
err = multiResult.Close()
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnExecContextPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
cancel()
|
|
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
|
require.Len(t, result.FieldDescriptions(), 1)
|
|
assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
|
|
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
assert.Equal(t, "Hello, world", string(result.Values()[0]))
|
|
}
|
|
assert.Equal(t, 1, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, "SELECT 1", commandTag.String())
|
|
assert.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsDeferredError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
|
}
|
|
|
|
setupSQL := `create temporary table t (
|
|
id text primary key,
|
|
n int not null,
|
|
unique (n) deferrable initially deferred
|
|
);
|
|
|
|
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
|
|
|
_, err = pgConn.Exec(ctx, setupSQL).ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
|
|
require.NotNil(t, result.Err)
|
|
var pgErr *pgconn.PgError
|
|
require.True(t, errors.As(result.Err, &pgErr))
|
|
require.Equal(t, "23505", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
paramCount := math.MaxUint16
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, paramCount)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsTooManyParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
paramCount := math.MaxUint16 + 1
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
|
|
require.Error(t, result.Err)
|
|
require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
defer cancel()
|
|
result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
}
|
|
assert.Equal(t, 0, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, pgconn.CommandTag{}, commandTag)
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnExecParamsPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
cancel()
|
|
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
|
|
require.Error(t, result.Err)
|
|
assert.True(t, errors.Is(result.Err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(result.Err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsEmptySQL(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
|
|
assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
|
|
assert.Len(t, result.Rows, 0)
|
|
assert.NoError(t, result.Err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/859
|
|
func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
|
require.Len(t, result.FieldDescriptions(), 1)
|
|
assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
|
|
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
assert.Equal(t, "Hello, world", string(result.Values()[0]))
|
|
assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0]))
|
|
}
|
|
assert.Equal(t, 1, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, "SELECT 1", commandTag.String())
|
|
assert.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPrepared(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text as msg", nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, psd)
|
|
assert.Len(t, psd.ParamOIDs, 1)
|
|
assert.Len(t, psd.Fields, 1)
|
|
|
|
result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
|
require.Len(t, result.FieldDescriptions(), 1)
|
|
assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
|
|
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
assert.Equal(t, "Hello, world", string(result.Values()[0]))
|
|
}
|
|
assert.Equal(t, 1, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, "SELECT 1", commandTag.String())
|
|
assert.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
paramCount := math.MaxUint16
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", sql, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, psd)
|
|
assert.Len(t, psd.ParamOIDs, paramCount)
|
|
assert.Len(t, psd.Fields, 1)
|
|
|
|
result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, paramCount)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPreparedTooManyParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
paramCount := math.MaxUint16 + 1
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", sql, nil)
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
// CockroachDB rejects preparing a statement with more than 65535 parameters.
|
|
require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)")
|
|
} else {
|
|
// PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol.
|
|
require.NoError(t, err)
|
|
require.NotNil(t, psd)
|
|
assert.Len(t, psd.ParamOIDs, paramCount)
|
|
assert.Len(t, psd.Fields, 1)
|
|
|
|
result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read()
|
|
require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters")
|
|
}
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPreparedCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select current_database(), pg_sleep(1)", nil)
|
|
require.NoError(t, err)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
defer cancel()
|
|
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
}
|
|
assert.Equal(t, 0, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, pgconn.CommandTag{}, commandTag)
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnExecPreparedPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select current_database(), pg_sleep(1)", nil)
|
|
require.NoError(t, err)
|
|
|
|
cancel()
|
|
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
|
require.Error(t, result.Err)
|
|
assert.True(t, errors.Is(result.Err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(result.Err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPreparedEmptySQL(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "", nil)
|
|
require.NoError(t, err)
|
|
|
|
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
|
assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
|
|
assert.Len(t, result.Rows, 0)
|
|
assert.NoError(t, result.Err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecBatch(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil)
|
|
require.NoError(t, err)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
|
batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
|
|
results, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.NoError(t, err)
|
|
require.Len(t, results, 3)
|
|
|
|
require.Len(t, results[0].Rows, 1)
|
|
require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0]))
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
|
|
require.Len(t, results[1].Rows, 1)
|
|
require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0]))
|
|
assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
|
|
|
|
require.Len(t, results[2].Rows, 1)
|
|
require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0]))
|
|
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
|
|
}
|
|
|
|
func TestConnExecBatchDeferredError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
|
}
|
|
|
|
setupSQL := `create temporary table t (
|
|
id text primary key,
|
|
n int not null,
|
|
unique (n) deferrable initially deferred
|
|
);
|
|
|
|
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
|
|
|
_, err = pgConn.Exec(ctx, setupSQL).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
|
|
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.NotNil(t, err)
|
|
var pgErr *pgconn.PgError
|
|
require.True(t, errors.As(err, &pgErr))
|
|
require.Equal(t, "23505", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecBatchPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil)
|
|
require.NoError(t, err)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
|
batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
|
|
|
|
cancel()
|
|
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// Without concurrent reading and writing large batches can deadlock.
|
|
//
|
|
// See https://github.com/jackc/pgx/issues/374.
|
|
func TestConnExecBatchHuge(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping test in short mode.")
|
|
}
|
|
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
queryCount := 100000
|
|
args := make([]string, queryCount)
|
|
|
|
for i := range args {
|
|
args[i] = strconv.Itoa(i)
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil)
|
|
}
|
|
|
|
results, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.NoError(t, err)
|
|
require.Len(t, results, queryCount)
|
|
|
|
for i := range args {
|
|
require.Len(t, results[i].Rows, 1)
|
|
require.Equal(t, args[i], string(results[i].Rows[0][0]))
|
|
assert.Equal(t, "SELECT 1", results[i].CommandTag.String())
|
|
}
|
|
}
|
|
|
|
func TestConnExecBatchImplicitTransaction(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil)
|
|
batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil)
|
|
batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil)
|
|
batch.ExecParams("select 1/0", nil, nil, nil, nil)
|
|
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.Error(t, err)
|
|
|
|
result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read()
|
|
require.Equal(t, "0", string(result.Rows[0][0]))
|
|
}
|
|
|
|
func TestConnLocking(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
mrr := pgConn.Exec(ctx, "select 'Hello, world'")
|
|
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.Error(t, err)
|
|
assert.Equal(t, "conn busy", err.Error())
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
|
|
results, err := mrr.ReadAll()
|
|
assert.NoError(t, err)
|
|
assert.Len(t, results, 1)
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnOnNotice(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
var msg string
|
|
config.OnNotice = func(c *pgconn.PgConn, notice *pgconn.Notice) {
|
|
msg = notice.Message
|
|
}
|
|
config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect.
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)")
|
|
}
|
|
|
|
multiResult := pgConn.Exec(ctx, `do $$
|
|
begin
|
|
raise notice 'hello, world';
|
|
end$$;`)
|
|
err = multiResult.Close()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "hello, world", msg)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnOnNotification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
var msg string
|
|
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
|
|
msg = n.Payload
|
|
}
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, "listen foo").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
notifier, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, notifier)
|
|
_, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, "bar", msg)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnWaitForNotification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
var msg string
|
|
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
|
|
msg = n.Payload
|
|
}
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, "listen foo").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
notifier, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, notifier)
|
|
_, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
err = pgConn.WaitForNotification(ctx)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, "bar", msg)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnWaitForNotificationPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
cancel()
|
|
err = pgConn.WaitForNotification(ctx)
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnWaitForNotificationTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 5*time.Millisecond)
|
|
err = pgConn.WaitForNotification(ctx)
|
|
cancel()
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyToSmall(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does support COPY TO")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int2,
|
|
b int4,
|
|
c int8,
|
|
d varchar,
|
|
e text,
|
|
f date,
|
|
g json
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
|
|
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
|
|
|
|
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
|
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout")
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, int64(2), res.RowsAffected())
|
|
assert.Equal(t, inputBytes, outputWriter.Bytes())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyToLarge(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does support COPY TO")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int2,
|
|
b int4,
|
|
c int8,
|
|
d varchar,
|
|
e text,
|
|
f date,
|
|
g json,
|
|
h bytea
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
inputBytes := make([]byte, 0)
|
|
|
|
for i := 0; i < 1000; i++ {
|
|
_, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll()
|
|
require.NoError(t, err)
|
|
inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
|
|
}
|
|
|
|
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
|
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout")
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, int64(1000), res.RowsAffected())
|
|
assert.Equal(t, inputBytes, outputWriter.Bytes())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyToQueryError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
outputWriter := bytes.NewBuffer(make([]byte, 0))
|
|
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout")
|
|
require.Error(t, err)
|
|
assert.IsType(t, &pgconn.PgError{}, err)
|
|
assert.Equal(t, int64(0), res.RowsAffected())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyToCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
|
|
}
|
|
|
|
outputWriter := &bytes.Buffer{}
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
defer cancel()
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
|
|
assert.Error(t, err)
|
|
assert.Equal(t, pgconn.CommandTag{}, res)
|
|
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnCopyToPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
outputWriter := &bytes.Buffer{}
|
|
|
|
cancel()
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
assert.Equal(t, pgconn.CommandTag{}, res)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFrom(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
srcBuf := &bytes.Buffer{}
|
|
|
|
inputRows := [][][]byte{}
|
|
for i := 0; i < 1000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
|
|
_, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
copySql = "COPY foo FROM STDIN WITH CSV"
|
|
}
|
|
ct, err := pgConn.CopyFrom(ctx, srcBuf, copySql)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
|
|
|
result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
assert.Equal(t, inputRows, result.Rows)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromBinary(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
buf := []byte{}
|
|
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
|
buf = pgio.AppendInt32(buf, 0)
|
|
buf = pgio.AppendInt32(buf, 0)
|
|
|
|
inputRows := [][][]byte{}
|
|
for i := 0; i < 1000; i++ {
|
|
// Number of elements in the tuple
|
|
buf = pgio.AppendInt16(buf, int16(2))
|
|
a := i
|
|
|
|
// Length of element for column `a int4`
|
|
buf = pgio.AppendInt32(buf, 4)
|
|
buf, err = pgtype.NewMap().Encode(pgtype.Int4OID, pgx.BinaryFormatCode, a, buf)
|
|
require.NoError(t, err)
|
|
|
|
b := "foo " + strconv.Itoa(a) + " bar"
|
|
lenB := int32(len([]byte(b)))
|
|
// Length of element for column `b varchar`
|
|
buf = pgio.AppendInt32(buf, lenB)
|
|
buf, err = pgtype.NewMap().Encode(pgtype.VarcharOID, pgx.BinaryFormatCode, b, buf)
|
|
require.NoError(t, err)
|
|
|
|
inputRows = append(inputRows, [][]byte{[]byte(strconv.Itoa(a)), []byte(b)})
|
|
}
|
|
|
|
srcBuf := &bytes.Buffer{}
|
|
srcBuf.Write(buf)
|
|
ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo (a, b) FROM STDIN BINARY;")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
|
|
|
result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
assert.Equal(t, inputRows, result.Rows)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
r, w := io.Pipe()
|
|
go func() {
|
|
for i := 0; i < 1000000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
if err != nil {
|
|
return
|
|
}
|
|
time.Sleep(time.Microsecond)
|
|
}
|
|
}()
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
copySql = "COPY foo FROM STDIN WITH CSV"
|
|
}
|
|
ct, err := pgConn.CopyFrom(ctx, r, copySql)
|
|
cancel()
|
|
assert.Equal(t, int64(0), ct.RowsAffected())
|
|
assert.Error(t, err)
|
|
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnCopyFromPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
r, w := io.Pipe()
|
|
go func() {
|
|
for i := 0; i < 1000000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
if err != nil {
|
|
return
|
|
}
|
|
time.Sleep(time.Microsecond)
|
|
}
|
|
}()
|
|
|
|
ctx, cancel = context.WithCancel(ctx)
|
|
cancel()
|
|
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
assert.Equal(t, pgconn.CommandTag{}, ct)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromGzipReader(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
f, err := os.CreateTemp("", "*")
|
|
require.NoError(t, err)
|
|
|
|
gw := gzip.NewWriter(f)
|
|
|
|
inputRows := [][][]byte{}
|
|
for i := 0; i < 1000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
|
|
_, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
err = gw.Close()
|
|
require.NoError(t, err)
|
|
|
|
_, err = f.Seek(0, 0)
|
|
require.NoError(t, err)
|
|
|
|
gr, err := gzip.NewReader(f)
|
|
require.NoError(t, err)
|
|
|
|
copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
copySql = "COPY foo FROM STDIN WITH CSV"
|
|
}
|
|
ct, err := pgConn.CopyFrom(ctx, gr, copySql)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
|
|
|
err = gr.Close()
|
|
require.NoError(t, err)
|
|
|
|
err = f.Close()
|
|
require.NoError(t, err)
|
|
|
|
err = os.Remove(f.Name())
|
|
require.NoError(t, err)
|
|
|
|
result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
assert.Equal(t, inputRows, result.Rows)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromQuerySyntaxError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
srcBuf := &bytes.Buffer{}
|
|
|
|
// Send data even though the COPY FROM command will be rejected with a syntax error. This ensures that this does not
|
|
// break the connection. See https://github.com/jackc/pgconn/pull/127 for context.
|
|
inputRows := [][][]byte{}
|
|
for i := 0; i < 1000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
|
|
_, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo FROM STDIN WITH (FORMAT csv)")
|
|
require.Error(t, err)
|
|
assert.IsType(t, &pgconn.PgError{}, err)
|
|
assert.Equal(t, int64(0), res.RowsAffected())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromQueryNoTableError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
srcBuf := &bytes.Buffer{}
|
|
|
|
res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout")
|
|
require.Error(t, err)
|
|
assert.IsType(t, &pgconn.PgError{}, err)
|
|
assert.Equal(t, int64(0), res.RowsAffected())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// https://github.com/jackc/pgconn/issues/21
|
|
func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table sentences(
|
|
t text,
|
|
ts tsvector
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$
|
|
begin
|
|
new.ts := to_tsvector(new.t);
|
|
return new;
|
|
end
|
|
$$ language plpgsql;`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
longString := make([]byte, 10001)
|
|
for i := range longString {
|
|
longString[i] = 'x'
|
|
}
|
|
|
|
buf := &bytes.Buffer{}
|
|
for i := 0; i < 1000; i++ {
|
|
buf.Write([]byte(fmt.Sprintf("%s\n", string(longString))))
|
|
}
|
|
|
|
_, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// delayedReader wraps an io.Reader and pauses briefly before every Read call.
type delayedReader struct {
	r io.Reader
}

// Read sleeps for one millisecond and then delegates to the wrapped reader.
func (d delayedReader) Read(p []byte) (int, error) {
	// The pause is what makes the scenario reproduce: without the sleep the test passes, with the sleep it fails.
	const pause = time.Millisecond
	time.Sleep(pause)

	n, err := d.r.Read(p)
	return n, err
}
|
|
|
|
// https://github.com/jackc/pgconn/issues/128
|
|
func TestConnCopyFromDataWriteAfterErrorAndReturn(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_DATABASE")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_DATABASE")
|
|
}
|
|
|
|
config, err := pgconn.ParseConfig(connString)
|
|
require.NoError(t, err)
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not fully support COPY FROM")
|
|
}
|
|
|
|
setupSQL := `create temporary table t (
|
|
id text primary key,
|
|
n int not null
|
|
);`
|
|
|
|
_, err = pgConn.Exec(ctx, setupSQL).ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
r1 := delayedReader{r: strings.NewReader(`id 0\n`)}
|
|
// Generate an error with a bogus COPY command
|
|
_, err = pgConn.CopyFrom(ctx, r1, "COPY nosuchtable FROM STDIN ")
|
|
assert.Error(t, err)
|
|
|
|
r2 := delayedReader{r: strings.NewReader(`id 0\n`)}
|
|
_, err = pgConn.CopyFrom(ctx, r2, "COPY t FROM STDIN")
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestConnEscapeString(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
tests := []struct {
|
|
in string
|
|
out string
|
|
}{
|
|
{in: "", out: ""},
|
|
{in: "42", out: "42"},
|
|
{in: "'", out: "''"},
|
|
{in: "hi'there", out: "hi''there"},
|
|
{in: "'hi there'", out: "''hi there''"},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
value, err := pgConn.EscapeString(tt.in)
|
|
if assert.NoErrorf(t, err, "%d.", i) {
|
|
assert.Equalf(t, tt.out, value, "%d.", i)
|
|
}
|
|
}
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCancelRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
|
|
}
|
|
|
|
multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(25)")
|
|
|
|
errChan := make(chan error)
|
|
go func() {
|
|
// The query is actually sent when multiResult.NextResult() is called. So wait to ensure it is sent.
|
|
// Once Flush is available this could use that instead.
|
|
time.Sleep(1 * time.Second)
|
|
|
|
err := pgConn.CancelRequest(ctx)
|
|
errChan <- err
|
|
}()
|
|
|
|
for multiResult.NextResult() {
|
|
}
|
|
err = multiResult.Close()
|
|
|
|
require.IsType(t, &pgconn.PgError{}, err)
|
|
require.Equal(t, "57014", err.(*pgconn.PgError).Code)
|
|
|
|
err = <-errChan
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/659
|
|
func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("postgres", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testConnContextCanceledCancelsRunningQueryOnServer(t, os.Getenv("PGX_TEST_DATABASE"), "postgres")
|
|
})
|
|
|
|
t.Run("pgbouncer", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
connString := os.Getenv(pgbouncerConnStringEnvVar)
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", pgbouncerConnStringEnvVar)
|
|
}
|
|
|
|
testConnContextCanceledCancelsRunningQueryOnServer(t, connString, "pgbouncer")
|
|
})
|
|
}
|
|
|
|
func testConnContextCanceledCancelsRunningQueryOnServer(t *testing.T, connString, dbType string) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
// Getting the actual PostgreSQL server process ID (PID) from a query executed through pgbouncer is not straightforward
|
|
// because pgbouncer abstracts the underlying database connections, and it doesn't expose the PID of the PostgreSQL
|
|
// server process to clients. However, we can check if the query is running by checking the generated query ID.
|
|
queryID := fmt.Sprintf("%s testConnContextCanceled %d", dbType, time.Now().UnixNano())
|
|
|
|
multiResult := pgConn.Exec(ctx, fmt.Sprintf(`
|
|
-- %v
|
|
select 'Hello, world', pg_sleep(30)
|
|
`, queryID))
|
|
|
|
for multiResult.NextResult() {
|
|
}
|
|
err = multiResult.Close()
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
|
|
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
otherConn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, otherConn)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, time.Second*5)
|
|
defer cancel()
|
|
|
|
for {
|
|
result := otherConn.ExecParams(ctx,
|
|
`select 1 from pg_stat_activity where query like $1`,
|
|
[][]byte{[]byte("%" + queryID + "%")},
|
|
nil,
|
|
nil,
|
|
nil,
|
|
).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
if len(result.Rows) == 0 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHijackAndConstruct(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
origConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
err = origConn.SyncConn(ctx)
|
|
require.NoError(t, err)
|
|
|
|
hc, err := origConn.Hijack()
|
|
require.NoError(t, err)
|
|
|
|
_, err = origConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
require.Error(t, err)
|
|
|
|
newConn, err := pgconn.Construct(hc)
|
|
require.NoError(t, err)
|
|
|
|
defer closeConn(t, newConn)
|
|
|
|
results, err := newConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Len(t, results, 1)
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
|
|
|
ensureConnValid(t, newConn)
|
|
}
|
|
|
|
func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
pgConn.Exec(ctx, "select n from generate_series(1,10) n")
|
|
|
|
closeCtx, _ := context.WithCancel(ctx)
|
|
pgConn.Close(closeCtx)
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/800
|
|
func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{}))
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
|
{Name: []byte("mock")},
|
|
}}))
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")}))
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
|
|
|
|
script := &pgmock.Script{Steps: steps}
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(serverErrChan)
|
|
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
err = script.Run(pgproto3.NewBackend(conn, conn))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
}()
|
|
|
|
host, port, _ := strings.Cut(ln.Addr().String(), ":")
|
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
conn, err := pgconn.Connect(ctx, connStr)
|
|
require.NoError(t, err)
|
|
|
|
rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil)
|
|
|
|
for rr.NextRow() {
|
|
}
|
|
|
|
_, err = rr.Close()
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// https://github.com/jackc/pgconn/issues/27
|
|
func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, "set client_min_messages = debug5").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
// The actual contents of this test aren't important. What's important is a large amount of data to be written and
|
|
// because of client_min_messages = debug5 the server will return a large amount of data.
|
|
|
|
paramCount := math.MaxUint16
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, paramCount)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCheckConn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
// Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtly different.)
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
c1, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer c1.Close(ctx)
|
|
|
|
if c1.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
|
}
|
|
|
|
err = c1.CheckConn()
|
|
require.NoError(t, err)
|
|
|
|
c2, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer c2.Close(ctx)
|
|
|
|
_, err = c2.Exec(ctx, fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
// It may take a while for the server to kill the backend. Retry until the error is detected or the test context is
|
|
// canceled.
|
|
for err == nil && ctx.Err() == nil {
|
|
time.Sleep(50 * time.Millisecond)
|
|
err = c1.CheckConn()
|
|
}
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestConnPing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
// Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtly different.)
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
c1, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer c1.Close(ctx)
|
|
|
|
if c1.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
|
}
|
|
|
|
err = c1.Exec(ctx, "set log_statement = 'all'").Close()
|
|
require.NoError(t, err)
|
|
|
|
err = c1.Ping(ctx)
|
|
require.NoError(t, err)
|
|
|
|
c2, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer c2.Close(ctx)
|
|
|
|
_, err = c2.Exec(ctx, fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
// Give a little time for the signal to actually kill the backend.
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
err = c1.Ping(ctx)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestPipelinePrepare(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, `create temporary table t (id text primary key)`, nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
|
|
pipeline.SendPrepare("selectText", "select $1::text as b", nil)
|
|
pipeline.SendPrepare("selectNoParams", "select 42 as c", nil)
|
|
pipeline.SendPrepare("insertNoResults", "insert into t (id) values ($1)", nil)
|
|
pipeline.SendPrepare("insertNoParamsOrResults", "insert into t (id) values ('foo')", nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, string(sd.Fields[0].Name), "a")
|
|
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok = results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, string(sd.Fields[0].Name), "b")
|
|
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok = results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, string(sd.Fields[0].Name), "c")
|
|
require.Equal(t, []uint32{}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok = results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 0)
|
|
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok = results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 0)
|
|
require.Len(t, sd.ParamOIDs, 0)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelinePrepareError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
|
|
pipeline.SendPrepare("selectError", "bad", nil)
|
|
pipeline.SendPrepare("selectText", "select $1::text as b", nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, string(sd.Fields[0].Name), "a")
|
|
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, err, &pgErr)
|
|
require.Nil(t, results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelinePrepareAndDeallocate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
|
|
pipeline.SendDeallocate("selectInt")
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, string(sd.Fields[0].Name), "a")
|
|
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.CloseComplete)
|
|
require.Truef(t, ok, "expected CloseComplete, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineQuery(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "4", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelinePrepareQuery(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, string(sd.Fields[0].Name), "msg")
|
|
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "hello", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "goodbye", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, readResult.Err, &pgErr)
|
|
require.Equal(t, "22012", pgErr.Code)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "6", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
err = pipeline.Close()
|
|
require.EqualError(t, err, "pipeline has unsynced requests")
|
|
}
|
|
|
|
func TestConnOnPgError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
config.OnPgError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool {
|
|
require.NotNil(t, c)
|
|
require.NotNil(t, pgErr)
|
|
// close connection on undefined tables only
|
|
if pgErr.Code == "42P01" {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.NoError(t, err)
|
|
assert.False(t, pgConn.IsClosed())
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1/0").ReadAll()
|
|
assert.Error(t, err)
|
|
assert.False(t, pgConn.IsClosed())
|
|
|
|
_, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll()
|
|
assert.Error(t, err)
|
|
assert.True(t, pgConn.IsClosed())
|
|
}
|
|
|
|
func Example() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
defer pgConn.Close(ctx)
|
|
|
|
result := pgConn.ExecParams(ctx, "select generate_series(1,3)", nil, nil, nil, nil).Read()
|
|
if result.Err != nil {
|
|
log.Fatalln(result.Err)
|
|
}
|
|
|
|
for _, row := range result.Rows {
|
|
fmt.Println(string(row[0]))
|
|
}
|
|
|
|
fmt.Println(result.CommandTag)
|
|
// Output:
|
|
// 1
|
|
// 2
|
|
// 3
|
|
// SELECT 3
|
|
}
|
|
|
|
// GetSSLPassword returns the value of the PGX_SSL_PASSWORD environment
// variable, or the empty string when it is unset. The ctx argument is accepted
// but not used.
func GetSSLPassword(ctx context.Context) string {
	return os.Getenv("PGX_SSL_PASSWORD")
}
|
|
|
|
// rsaCertPEM is the PEM-encoded certificate that TestSNISupport pairs with
// rsaKeyPEM (via tls.X509KeyPair) for its throwaway TLS server.
// NOTE(review): the encoded subject and issuer both appear to be CN=localhost
// (i.e. self-signed) with a validity window ending in August 2023 — confirm
// before reusing it anywhere certificate validation is enabled.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx
NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct
Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39
tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d
9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp
0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv
MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E
FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o
6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2
gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I
81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB
Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf
hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS
VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27
MlascjupnaptKX/wMA==
-----END CERTIFICATE-----
`
|
|
|
|
// rsaKeyPEM is the PEM-encoded private key that TestSNISupport pairs with
// rsaCertPEM via tls.X509KeyPair. The literal is written with a
// "TESTING KEY" label and passed through testingKey, which rewrites the label
// to "PRIVATE KEY" at package initialization.
// NOTE(review): the placeholder label presumably keeps secret scanners from
// flagging this test-only key — confirm.
var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY-----
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv
ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx
Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo
qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM
Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK
o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs
WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa
ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv
Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B
QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+
QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC
CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods
bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3
1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2
SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6
MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G
McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC
I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD
QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf
k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS
lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4
TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr
5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi
UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T
z3w+CgS20UrbLIR1YXfqUXge1g==
-----END TESTING KEY-----
`)
|
|
|
|
// testingKey returns s with every occurrence of "TESTING KEY" replaced by
// "PRIVATE KEY", turning the placeholder PEM label used in this file's key
// literal into a real one.
func testingKey(s string) string {
	const (
		placeholderLabel = "TESTING KEY"
		realLabel        = "PRIVATE KEY"
	)
	return strings.ReplaceAll(s, placeholderLabel, realLabel)
}
|
|
|
|
func TestSNISupport(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
sni_param string
|
|
sni_set bool
|
|
}{
|
|
{
|
|
name: "SNI is passed by default",
|
|
sni_param: "",
|
|
sni_set: true,
|
|
},
|
|
{
|
|
name: "SNI is passed when asked for",
|
|
sni_param: "sslsni=1",
|
|
sni_set: true,
|
|
},
|
|
{
|
|
name: "SNI is not passed when disabled",
|
|
sni_param: "sslsni=0",
|
|
sni_set: false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
serverSNINameChan := make(chan string, 1)
|
|
defer close(serverErrChan)
|
|
defer close(serverSNINameChan)
|
|
|
|
go func() {
|
|
var sniHost string
|
|
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
backend := pgproto3.NewBackend(conn, conn)
|
|
startupMessage, err := backend.ReceiveStartupMessage()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
switch startupMessage.(type) {
|
|
case *pgproto3.SSLRequest:
|
|
_, err = conn.Write([]byte("S"))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
default:
|
|
serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
|
|
return
|
|
}
|
|
|
|
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
srv := tls.Server(conn, &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
sniHost = argHello.ServerName
|
|
return nil, nil
|
|
},
|
|
})
|
|
defer srv.Close()
|
|
|
|
if err := srv.Handshake(); err != nil {
|
|
serverErrChan <- fmt.Errorf("handshake: %w", err)
|
|
return
|
|
}
|
|
|
|
srv.Write((&pgproto3.AuthenticationOk{}).Encode(nil))
|
|
srv.Write((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil))
|
|
srv.Write((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil))
|
|
|
|
serverSNINameChan <- sniHost
|
|
}()
|
|
|
|
_, port, _ := strings.Cut(ln.Addr().String(), ":")
|
|
connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param)
|
|
_, err = pgconn.Connect(ctx, connStr)
|
|
|
|
select {
|
|
case sniHost := <-serverSNINameChan:
|
|
if tt.sni_set {
|
|
require.Equal(t, sniHost, "localhost")
|
|
} else {
|
|
require.Equal(t, sniHost, "")
|
|
}
|
|
case err = <-serverErrChan:
|
|
t.Fatalf("server failed with error: %+v", err)
|
|
case <-time.After(time.Millisecond * 100):
|
|
t.Fatal("exceeded connection timeout without erroring out")
|
|
}
|
|
})
|
|
}
|
|
}
|