// Mirror of https://github.com/jackc/pgx.git
// 4069 lines, 116 KiB, Go
package pgconn_test
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/gzip"
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math"
|
|
"net"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/internal/pgio"
|
|
"github.com/jackc/pgx/v5/internal/pgmock"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgconn/ctxwatch"
|
|
"github.com/jackc/pgx/v5/pgproto3"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
)
|
|
|
|
// pgbouncerConnStringEnvVar is the name of the environment variable that
// presumably holds a PgBouncer connection string for tests. NOTE(review): it is
// not referenced within this excerpt; confirm how callers treat it when unset.
const pgbouncerConnStringEnvVar = "PGX_TEST_PGBOUNCER_CONN_STRING"
|
|
|
|
func TestConnect(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
env string
|
|
}{
|
|
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
|
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
|
{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
|
|
{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
|
|
{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv(tt.env)
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", tt.env)
|
|
}
|
|
|
|
conn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
|
|
closeConn(t, conn)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectWithOptions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
env string
|
|
}{
|
|
{"Unix socket", "PGX_TEST_UNIX_SOCKET_CONN_STRING"},
|
|
{"TCP", "PGX_TEST_TCP_CONN_STRING"},
|
|
{"Plain password", "PGX_TEST_PLAIN_PASSWORD_CONN_STRING"},
|
|
{"MD5 password", "PGX_TEST_MD5_PASSWORD_CONN_STRING"},
|
|
{"SCRAM password", "PGX_TEST_SCRAM_PASSWORD_CONN_STRING"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv(tt.env)
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", tt.env)
|
|
}
|
|
var sslOptions pgconn.ParseConfigOptions
|
|
sslOptions.GetSSLPassword = GetSSLPassword
|
|
conn, err := pgconn.ConnectWithOptions(ctx, connString, sslOptions)
|
|
require.NoError(t, err)
|
|
|
|
closeConn(t, conn)
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestConnectTLS is separate from other connect tests because it has an additional test to ensure it really is a secure
|
|
// connection.
|
|
func TestConnectTLS(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TLS_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CONN_STRING")
|
|
}
|
|
|
|
conn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
|
|
result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, 1)
|
|
require.Len(t, result.Rows[0], 1)
|
|
require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
|
|
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectTLSPasswordProtectedClientCertWithSSLPassword(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING")
|
|
}
|
|
if os.Getenv("PGX_SSL_PASSWORD") == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD")
|
|
}
|
|
|
|
connString += " sslpassword=" + os.Getenv("PGX_SSL_PASSWORD")
|
|
|
|
conn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
|
|
result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, 1)
|
|
require.Len(t, result.Rows[0], 1)
|
|
require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
|
|
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectTLSPasswordProtectedClientCertWithGetSSLPasswordConfigOption(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TLS_CLIENT_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TLS_CLIENT_CONN_STRING")
|
|
}
|
|
if os.Getenv("PGX_SSL_PASSWORD") == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_SSL_PASSWORD")
|
|
}
|
|
|
|
var sslOptions pgconn.ParseConfigOptions
|
|
sslOptions.GetSSLPassword = GetSSLPassword
|
|
config, err := pgconn.ParseConfigWithOptions(connString, sslOptions)
|
|
require.Nil(t, err)
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
|
|
result := conn.ExecParams(ctx, `select ssl from pg_stat_ssl where pg_backend_pid() = pid;`, nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, 1)
|
|
require.Len(t, result.Rows[0], 1)
|
|
require.Equalf(t, "t", string(result.Rows[0][0]), "not a TLS connection")
|
|
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
type pgmockWaitStep time.Duration
|
|
|
|
func (s pgmockWaitStep) Step(*pgproto3.Backend) error {
|
|
time.Sleep(time.Duration(s))
|
|
return nil
|
|
}
|
|
|
|
func TestConnectTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
connect func(connStr string) error
|
|
}{
|
|
{
|
|
name: "via context that times out",
|
|
connect: func(connStr string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*50)
|
|
defer cancel()
|
|
_, err := pgconn.Connect(ctx, connStr)
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "via config ConnectTimeout",
|
|
connect: func(connStr string) error {
|
|
conf, err := pgconn.ParseConfig(connStr)
|
|
require.NoError(t, err)
|
|
conf.ConnectTimeout = time.Microsecond * 50
|
|
_, err = pgconn.ConnectConfig(context.Background(), conf)
|
|
return err
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
script := &pgmock.Script{
|
|
Steps: []pgmock.Step{
|
|
pgmock.ExpectAnyMessage(&pgproto3.StartupMessage{ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: map[string]string{}}),
|
|
pgmock.SendMessage(&pgproto3.AuthenticationOk{}),
|
|
pgmockWaitStep(time.Millisecond * 500),
|
|
pgmock.SendMessage(&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}),
|
|
pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
|
|
},
|
|
}
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(serverErrChan)
|
|
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.SetDeadline(time.Now().Add(time.Millisecond * 450))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
err = script.Run(pgproto3.NewBackend(conn, conn))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
}()
|
|
|
|
host, port, _ := strings.Cut(ln.Addr().String(), ":")
|
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
|
tooLate := time.Now().Add(time.Millisecond * 500)
|
|
|
|
err = tt.connect(connStr)
|
|
require.True(t, pgconn.Timeout(err), err)
|
|
require.True(t, time.Now().Before(tooLate))
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectTimeoutStuckOnTLSHandshake(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
connect func(connStr string) error
|
|
}{
|
|
{
|
|
name: "via context that times out",
|
|
connect: func(connStr string) error {
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10)
|
|
defer cancel()
|
|
_, err := pgconn.Connect(ctx, connStr)
|
|
return err
|
|
},
|
|
},
|
|
{
|
|
name: "via config ConnectTimeout",
|
|
connect: func(connStr string) error {
|
|
conf, err := pgconn.ParseConfig(connStr)
|
|
require.NoError(t, err)
|
|
conf.ConnectTimeout = time.Millisecond * 10
|
|
_, err = pgconn.ConnectConfig(context.Background(), conf)
|
|
return err
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
go func() {
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
var buf []byte
|
|
_, err = conn.Read(buf)
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
// Sleeping to hang the TLS handshake.
|
|
time.Sleep(time.Minute)
|
|
}()
|
|
|
|
host, port, _ := strings.Cut(ln.Addr().String(), ":")
|
|
connStr := fmt.Sprintf("host=%s port=%s", host, port)
|
|
|
|
errChan := make(chan error)
|
|
go func() {
|
|
err := tt.connect(connStr)
|
|
errChan <- err
|
|
}()
|
|
|
|
select {
|
|
case err = <-errChan:
|
|
require.True(t, pgconn.Timeout(err), err)
|
|
case err = <-serverErrChan:
|
|
t.Fatalf("server failed with error: %s", err)
|
|
case <-time.After(time.Millisecond * 500):
|
|
t.Fatal("exceeded connection timeout without erroring out")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestConnectInvalidUser(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
config, err := pgconn.ParseConfig(connString)
|
|
require.NoError(t, err)
|
|
|
|
config.User = "pgxinvalidusertest"
|
|
|
|
_, err = pgconn.ConnectConfig(ctx, config)
|
|
require.Error(t, err)
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, err, &pgErr)
|
|
if pgErr.Code != "28000" && pgErr.Code != "28P01" {
|
|
t.Fatalf("Expected to receive a PgError with code 28000 or 28P01, instead received: %v", pgErr)
|
|
}
|
|
}
|
|
|
|
func TestConnectWithConnectionRefused(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
// Presumably nothing is listening on 127.0.0.1:1
|
|
conn, err := pgconn.Connect(ctx, "host=127.0.0.1 port=1")
|
|
if err == nil {
|
|
conn.Close(ctx)
|
|
t.Fatal("Expected error establishing connection to bad port")
|
|
}
|
|
}
|
|
|
|
func TestConnectCustomDialer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
dialed := false
|
|
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
dialed = true
|
|
return net.Dial(network, address)
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
require.True(t, dialed)
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectCustomLookup(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
config, err := pgconn.ParseConfig(connString)
|
|
require.NoError(t, err)
|
|
|
|
looked := false
|
|
config.LookupFunc = func(ctx context.Context, host string) (addrs []string, err error) {
|
|
looked = true
|
|
return net.LookupHost(host)
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
require.True(t, looked)
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectCustomLookupWithPort(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
config, err := pgconn.ParseConfig(connString)
|
|
require.NoError(t, err)
|
|
|
|
origPort := config.Port
|
|
// Change the config an invalid port so it will fail if used
|
|
config.Port = 0
|
|
|
|
looked := false
|
|
config.LookupFunc = func(ctx context.Context, host string) ([]string, error) {
|
|
looked = true
|
|
addrs, err := net.LookupHost(host)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
for i := range addrs {
|
|
addrs[i] = net.JoinHostPort(addrs[i], strconv.FormatUint(uint64(origPort), 10))
|
|
}
|
|
return addrs, nil
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
require.True(t, looked)
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectWithRuntimeParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
config.RuntimeParams = map[string]string{
|
|
"application_name": "pgxtest",
|
|
"search_path": "myschema",
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, conn)
|
|
|
|
result := conn.ExecParams(ctx, "show application_name", nil, nil, nil, nil).Read()
|
|
require.Nil(t, result.Err)
|
|
assert.Equal(t, 1, len(result.Rows))
|
|
assert.Equal(t, "pgxtest", string(result.Rows[0][0]))
|
|
|
|
result = conn.ExecParams(ctx, "show search_path", nil, nil, nil, nil).Read()
|
|
require.Nil(t, result.Err)
|
|
assert.Equal(t, 1, len(result.Rows))
|
|
assert.Equal(t, "myschema", string(result.Rows[0][0]))
|
|
}
|
|
|
|
func TestConnectWithFallback(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
// Prepend current primary config to fallbacks
|
|
config.Fallbacks = append([]*pgconn.FallbackConfig{
|
|
{
|
|
Host: config.Host,
|
|
Port: config.Port,
|
|
TLSConfig: config.TLSConfig,
|
|
},
|
|
}, config.Fallbacks...)
|
|
|
|
// Make primary config bad
|
|
config.Host = "localhost"
|
|
config.Port = 1 // presumably nothing listening here
|
|
|
|
// Prepend bad first fallback
|
|
config.Fallbacks = append([]*pgconn.FallbackConfig{
|
|
{
|
|
Host: "localhost",
|
|
Port: 1,
|
|
TLSConfig: config.TLSConfig,
|
|
},
|
|
}, config.Fallbacks...)
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
closeConn(t, conn)
|
|
}
|
|
|
|
func TestConnectFailsWithResolveFailureAndFailedConnectionAttempts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
conn, err := pgconn.Connect(ctx, "host=localhost,127.0.0.1,foo.invalid port=1,2,3 sslmode=disable")
|
|
require.Error(t, err)
|
|
require.Nil(t, conn)
|
|
|
|
require.ErrorContains(t, err, "lookup foo.invalid")
|
|
// Not testing the entire string as depending on IPv4 or IPv6 support localhost may resolve to 127.0.0.1 or ::1.
|
|
require.ErrorContains(t, err, ":1 (localhost): dial error:")
|
|
require.ErrorContains(t, err, ":2 (127.0.0.1): dial error:")
|
|
}
|
|
|
|
func TestConnectWithValidateConnect(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
dialCount := 0
|
|
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
dialCount++
|
|
return net.Dial(network, address)
|
|
}
|
|
|
|
acceptConnCount := 0
|
|
config.ValidateConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
|
acceptConnCount++
|
|
if acceptConnCount < 2 {
|
|
return errors.New("reject first conn")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Append current primary config to fallbacks
|
|
config.Fallbacks = append(config.Fallbacks, &pgconn.FallbackConfig{
|
|
Host: config.Host,
|
|
Port: config.Port,
|
|
TLSConfig: config.TLSConfig,
|
|
})
|
|
|
|
// Repeat fallbacks
|
|
config.Fallbacks = append(config.Fallbacks, config.Fallbacks...)
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
closeConn(t, conn)
|
|
|
|
assert.True(t, dialCount > 1)
|
|
assert.True(t, acceptConnCount > 1)
|
|
}
|
|
|
|
func TestConnectWithValidateConnectTargetSessionAttrsReadWrite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
config.ValidateConnect = pgconn.ValidateConnectTargetSessionAttrsReadWrite
|
|
config.RuntimeParams["default_transaction_read_only"] = "on"
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
if !assert.NotNil(t, err) {
|
|
conn.Close(ctx)
|
|
}
|
|
}
|
|
|
|
func TestConnectWithAfterConnect(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
config.AfterConnect = func(ctx context.Context, conn *pgconn.PgConn) error {
|
|
_, err := conn.Exec(ctx, "set search_path to foobar;").ReadAll()
|
|
return err
|
|
}
|
|
|
|
conn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
|
|
results, err := conn.Exec(ctx, "show search_path;").ReadAll()
|
|
require.NoError(t, err)
|
|
defer closeConn(t, conn)
|
|
|
|
assert.Equal(t, []byte("foobar"), results[0].Rows[0][0])
|
|
}
|
|
|
|
func TestConnectConfigRequiresConfigFromParseConfig(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config := &pgconn.Config{}
|
|
|
|
require.PanicsWithValue(t, "config must be created by ParseConfig", func() { pgconn.ConnectConfig(ctx, config) })
|
|
}
|
|
|
|
func TestConnPrepareSyntaxError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", "SYNTAX ERROR", nil)
|
|
require.Nil(t, psd)
|
|
require.NotNil(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnPrepareContextPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
cancel()
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
|
assert.Nil(t, psd)
|
|
assert.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnDeallocate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
|
require.NoError(t, err)
|
|
|
|
err = pgConn.Deallocate(ctx, "ps1")
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
|
require.Error(t, err)
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, err, &pgErr)
|
|
require.Equal(t, "26000", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnDeallocateSucceedsInAbortedTransaction(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
err = pgConn.Exec(ctx, "begin").Close()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select 1", nil)
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
|
require.NoError(t, err)
|
|
|
|
err = pgConn.Exec(ctx, "select 1/0").Close() // break transaction with divide by 0 error
|
|
require.Error(t, err)
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, err, &pgErr)
|
|
require.Equal(t, "22012", pgErr.Code)
|
|
|
|
err = pgConn.Deallocate(ctx, "ps1")
|
|
require.NoError(t, err)
|
|
|
|
err = pgConn.Exec(ctx, "rollback").Close()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Close()
|
|
require.Error(t, err)
|
|
require.ErrorAs(t, err, &pgErr)
|
|
require.Equal(t, "26000", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnDeallocateNonExistantStatementSucceeds(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
err = pgConn.Deallocate(ctx, "ps1")
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExec(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
results, err := pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Len(t, results, 1)
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecEmpty(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
multiResult := pgConn.Exec(ctx, ";")
|
|
|
|
resultCount := 0
|
|
for multiResult.NextResult() {
|
|
resultCount++
|
|
multiResult.ResultReader().Close()
|
|
}
|
|
assert.Equal(t, 0, resultCount)
|
|
err = multiResult.Close()
|
|
assert.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecMultipleQueries(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
results, err := pgConn.Exec(ctx, "select 'Hello, world'; select 1").ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Len(t, results, 2)
|
|
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
|
|
|
assert.Nil(t, results[1].Err)
|
|
assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
|
|
assert.Len(t, results[1].Rows, 1)
|
|
assert.Equal(t, "1", string(results[1].Rows[0][0]))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecMultipleQueriesEagerFieldDescriptions(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
mrr := pgConn.Exec(ctx, "select 'Hello, world' as msg; select 1 as num")
|
|
|
|
require.True(t, mrr.NextResult())
|
|
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
|
|
assert.Equal(t, "msg", mrr.ResultReader().FieldDescriptions()[0].Name)
|
|
_, err = mrr.ResultReader().Close()
|
|
require.NoError(t, err)
|
|
|
|
require.True(t, mrr.NextResult())
|
|
require.Len(t, mrr.ResultReader().FieldDescriptions(), 1)
|
|
assert.Equal(t, "num", mrr.ResultReader().FieldDescriptions()[0].Name)
|
|
_, err = mrr.ResultReader().Close()
|
|
require.NoError(t, err)
|
|
|
|
require.False(t, mrr.NextResult())
|
|
|
|
require.NoError(t, mrr.Close())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecMultipleQueriesError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
results, err := pgConn.Exec(ctx, "select 1; select 1/0; select 1").ReadAll()
|
|
require.NotNil(t, err)
|
|
if pgErr, ok := err.(*pgconn.PgError); ok {
|
|
assert.Equal(t, "22012", pgErr.Code)
|
|
} else {
|
|
t.Errorf("unexpected error: %v", err)
|
|
}
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
// CockroachDB starts the second query result set and then sends the divide by zero error.
|
|
require.Len(t, results, 2)
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "1", string(results[0].Rows[0][0]))
|
|
assert.Len(t, results[1].Rows, 0)
|
|
} else {
|
|
// PostgreSQL sends the divide by zero and never sends the second query result set.
|
|
require.Len(t, results, 1)
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "1", string(results[0].Rows[0][0]))
|
|
}
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecDeferredError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
|
}
|
|
|
|
setupSQL := `create temporary table t (
|
|
id text primary key,
|
|
n int not null,
|
|
unique (n) deferrable initially deferred
|
|
);
|
|
|
|
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
|
|
|
_, err = pgConn.Exec(ctx, setupSQL).ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `update t set n=n+1 where id='b' returning *`).ReadAll()
|
|
require.NotNil(t, err)
|
|
|
|
var pgErr *pgconn.PgError
|
|
require.True(t, errors.As(err, &pgErr))
|
|
require.Equal(t, "23505", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecContextCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
cancel()
|
|
|
|
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(1)")
|
|
|
|
for multiResult.NextResult() {
|
|
}
|
|
err = multiResult.Close()
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnExecContextPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
cancel()
|
|
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
|
require.Len(t, result.FieldDescriptions(), 1)
|
|
assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
|
|
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
assert.Equal(t, "Hello, world", string(result.Values()[0]))
|
|
}
|
|
assert.Equal(t, 1, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, "SELECT 1", commandTag.String())
|
|
assert.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsDeferredError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
|
}
|
|
|
|
setupSQL := `create temporary table t (
|
|
id text primary key,
|
|
n int not null,
|
|
unique (n) deferrable initially deferred
|
|
);
|
|
|
|
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
|
|
|
_, err = pgConn.Exec(ctx, setupSQL).ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
result := pgConn.ExecParams(ctx, `update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil).Read()
|
|
require.NotNil(t, result.Err)
|
|
var pgErr *pgconn.PgError
|
|
require.True(t, errors.As(result.Err, &pgErr))
|
|
require.Equal(t, "23505", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsMaxNumberOfParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
paramCount := math.MaxUint16
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, paramCount)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsTooManyParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
paramCount := math.MaxUint16 + 1
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
|
|
require.Error(t, result.Err)
|
|
require.Equal(t, "extended protocol limited to 65535 parameters", result.Err.Error())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
defer cancel()
|
|
result := pgConn.ExecParams(ctx, "select current_database(), pg_sleep(1)", nil, nil, nil, nil)
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
}
|
|
assert.Equal(t, 0, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, pgconn.CommandTag{}, commandTag)
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnExecParamsPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
cancel()
|
|
result := pgConn.ExecParams(ctx, "select $1::text", [][]byte{[]byte("Hello, world")}, nil, nil, nil).Read()
|
|
require.Error(t, result.Err)
|
|
assert.True(t, errors.Is(result.Err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(result.Err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecParamsEmptySQL(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, "", nil, nil, nil, nil).Read()
|
|
assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
|
|
assert.Len(t, result.Rows, 0)
|
|
assert.NoError(t, result.Err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/859
|
|
func TestResultReaderValuesHaveSameCapacityAsLength(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, "select $1::text as msg", [][]byte{[]byte("Hello, world")}, nil, nil, nil)
|
|
require.Len(t, result.FieldDescriptions(), 1)
|
|
assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
|
|
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
assert.Equal(t, "Hello, world", string(result.Values()[0]))
|
|
assert.Equal(t, len(result.Values()[0]), cap(result.Values()[0]))
|
|
}
|
|
assert.Equal(t, 1, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, "SELECT 1", commandTag.String())
|
|
assert.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/1987
|
|
func TestResultReaderReadNil(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, "select null::text", nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Nil(t, result.Rows[0][0])
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPrepared(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", "select $1::text as msg", nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, psd)
|
|
assert.Len(t, psd.ParamOIDs, 1)
|
|
assert.Len(t, psd.Fields, 1)
|
|
|
|
result := pgConn.ExecPrepared(ctx, "ps1", [][]byte{[]byte("Hello, world")}, nil, nil)
|
|
require.Len(t, result.FieldDescriptions(), 1)
|
|
assert.Equal(t, "msg", result.FieldDescriptions()[0].Name)
|
|
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
assert.Equal(t, "Hello, world", string(result.Values()[0]))
|
|
}
|
|
assert.Equal(t, 1, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, "SELECT 1", commandTag.String())
|
|
assert.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPreparedMaxNumberOfParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
paramCount := math.MaxUint16
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", sql, nil)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, psd)
|
|
assert.Len(t, psd.ParamOIDs, paramCount)
|
|
assert.Len(t, psd.Fields, 1)
|
|
|
|
result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, paramCount)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPreparedTooManyParams(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
paramCount := math.MaxUint16 + 1
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
psd, err := pgConn.Prepare(ctx, "ps1", sql, nil)
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
// CockroachDB rejects preparing a statement with more than 65535 parameters.
|
|
require.EqualError(t, err, "ERROR: more than 65535 arguments to prepared statement: 65536 (SQLSTATE 08P01)")
|
|
} else {
|
|
// PostgreSQL accepts preparing a statement with more than 65535 parameters and only fails when executing it through the extended protocol.
|
|
require.NoError(t, err)
|
|
require.NotNil(t, psd)
|
|
assert.Len(t, psd.ParamOIDs, paramCount)
|
|
assert.Len(t, psd.Fields, 1)
|
|
|
|
result := pgConn.ExecPrepared(ctx, "ps1", args, nil, nil).Read()
|
|
require.EqualError(t, result.Err, "extended protocol limited to 65535 parameters")
|
|
}
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPreparedCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select current_database(), pg_sleep(1)", nil)
|
|
require.NoError(t, err)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
defer cancel()
|
|
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil)
|
|
rowCount := 0
|
|
for result.NextRow() {
|
|
rowCount += 1
|
|
}
|
|
assert.Equal(t, 0, rowCount)
|
|
commandTag, err := result.Close()
|
|
assert.Equal(t, pgconn.CommandTag{}, commandTag)
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnExecPreparedPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select current_database(), pg_sleep(1)", nil)
|
|
require.NoError(t, err)
|
|
|
|
cancel()
|
|
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
|
require.Error(t, result.Err)
|
|
assert.True(t, errors.Is(result.Err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(result.Err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecPreparedEmptySQL(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "", nil)
|
|
require.NoError(t, err)
|
|
|
|
result := pgConn.ExecPrepared(ctx, "ps1", nil, nil, nil).Read()
|
|
assert.Equal(t, pgconn.CommandTag{}, result.CommandTag)
|
|
assert.Len(t, result.Rows, 0)
|
|
assert.NoError(t, result.Err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecBatch(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil)
|
|
require.NoError(t, err)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
|
batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
|
|
results, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.NoError(t, err)
|
|
require.Len(t, results, 3)
|
|
|
|
require.Len(t, results[0].Rows, 1)
|
|
require.Equal(t, "ExecParams 1", string(results[0].Rows[0][0]))
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
|
|
require.Len(t, results[1].Rows, 1)
|
|
require.Equal(t, "ExecPrepared 1", string(results[1].Rows[0][0]))
|
|
assert.Equal(t, "SELECT 1", results[1].CommandTag.String())
|
|
|
|
require.Len(t, results[2].Rows, 1)
|
|
require.Equal(t, "ExecParams 2", string(results[2].Rows[0][0]))
|
|
assert.Equal(t, "SELECT 1", results[2].CommandTag.String())
|
|
}
|
|
|
|
// mockConnection is a net.Conn wrapper that delays every Write by the duration writeLatency points
// at. Because the latency is read through a pointer on each call, a test can change it after the
// connection has been dialed.
type mockConnection struct {
	net.Conn
	writeLatency *time.Duration
}

// Write sleeps for the current write latency and then forwards b to the wrapped connection,
// returning whatever the wrapped Write returns.
func (m mockConnection) Write(b []byte) (n int, err error) {
	delay := *m.writeLatency
	time.Sleep(delay)
	n, err = m.Conn.Write(b)
	return n, err
}
|
|
|
|
func TestConnExecBatchWriteError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
var mockConn mockConnection
|
|
writeLatency := 0 * time.Second
|
|
config.DialFunc = func(ctx context.Context, network, address string) (net.Conn, error) {
|
|
conn, err := net.Dial(network, address)
|
|
mockConn = mockConnection{conn, &writeLatency}
|
|
return mockConn, err
|
|
}
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
batch := &pgconn.Batch{}
|
|
pgConn.Conn()
|
|
|
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 1*time.Second)
|
|
defer cancel2()
|
|
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
|
writeLatency = 2 * time.Second
|
|
mrr := pgConn.ExecBatch(ctx2, batch)
|
|
err = mrr.Close()
|
|
require.Error(t, err)
|
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
require.True(t, pgConn.IsClosed())
|
|
}
|
|
|
|
func TestConnExecBatchDeferredError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support deferred constraint (https://github.com/cockroachdb/cockroach/issues/31632)")
|
|
}
|
|
|
|
setupSQL := `create temporary table t (
|
|
id text primary key,
|
|
n int not null,
|
|
unique (n) deferrable initially deferred
|
|
);
|
|
|
|
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`
|
|
|
|
_, err = pgConn.Exec(ctx, setupSQL).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
batch.ExecParams(`update t set n=n+1 where id='b' returning *`, nil, nil, nil, nil)
|
|
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.NotNil(t, err)
|
|
var pgErr *pgconn.PgError
|
|
require.True(t, errors.As(err, &pgErr))
|
|
require.Equal(t, "23505", pgErr.Code)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnExecBatchPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Prepare(ctx, "ps1", "select $1::text", nil)
|
|
require.NoError(t, err)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 1")}, nil, nil, nil)
|
|
batch.ExecPrepared("ps1", [][]byte{[]byte("ExecPrepared 1")}, nil, nil)
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte("ExecParams 2")}, nil, nil, nil)
|
|
|
|
cancel()
|
|
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// Without concurrent reading and writing large batches can deadlock.
|
|
//
|
|
// See https://github.com/jackc/pgx/issues/374.
|
|
func TestConnExecBatchHuge(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping test in short mode.")
|
|
}
|
|
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
queryCount := 100000
|
|
args := make([]string, queryCount)
|
|
|
|
for i := range args {
|
|
args[i] = strconv.Itoa(i)
|
|
batch.ExecParams("select $1::text", [][]byte{[]byte(args[i])}, nil, nil, nil)
|
|
}
|
|
|
|
results, err := pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.NoError(t, err)
|
|
require.Len(t, results, queryCount)
|
|
|
|
for i := range args {
|
|
require.Len(t, results[i].Rows, 1)
|
|
require.Equal(t, args[i], string(results[i].Rows[0][0]))
|
|
assert.Equal(t, "SELECT 1", results[i].CommandTag.String())
|
|
}
|
|
}
|
|
|
|
func TestConnExecBatchImplicitTransaction(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Skipping due to known server issue: (https://github.com/cockroachdb/cockroach/issues/44803)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, "create temporary table t(id int)").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
batch := &pgconn.Batch{}
|
|
|
|
batch.ExecParams("insert into t(id) values(1)", nil, nil, nil, nil)
|
|
batch.ExecParams("insert into t(id) values(2)", nil, nil, nil, nil)
|
|
batch.ExecParams("insert into t(id) values(3)", nil, nil, nil, nil)
|
|
batch.ExecParams("select 1/0", nil, nil, nil, nil)
|
|
_, err = pgConn.ExecBatch(ctx, batch).ReadAll()
|
|
require.Error(t, err)
|
|
|
|
result := pgConn.ExecParams(ctx, "select count(*) from t", nil, nil, nil, nil).Read()
|
|
require.Equal(t, "0", string(result.Rows[0][0]))
|
|
}
|
|
|
|
func TestConnLocking(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
mrr := pgConn.Exec(ctx, "select 'Hello, world'")
|
|
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.Error(t, err)
|
|
assert.Equal(t, "conn busy", err.Error())
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
|
|
results, err := mrr.ReadAll()
|
|
assert.NoError(t, err)
|
|
assert.Len(t, results, 1)
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnOnNotice(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
var notice *pgconn.Notice
|
|
config.OnNotice = func(c *pgconn.PgConn, n *pgconn.Notice) {
|
|
notice = n
|
|
}
|
|
config.RuntimeParams["client_min_messages"] = "notice" // Ensure we only get the message we expect.
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support PL/PGSQL (https://github.com/cockroachdb/cockroach/issues/17511)")
|
|
}
|
|
|
|
multiResult := pgConn.Exec(ctx, `do $$
|
|
begin
|
|
raise notice 'hello, world';
|
|
end$$;`)
|
|
err = multiResult.Close()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "NOTICE", notice.SeverityUnlocalized)
|
|
assert.Equal(t, "hello, world", notice.Message)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnOnNotification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
var msg string
|
|
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
|
|
msg = n.Payload
|
|
}
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, "listen foo").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
notifier, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, notifier)
|
|
_, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, "bar", msg)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnWaitForNotification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
var msg string
|
|
config.OnNotification = func(c *pgconn.PgConn, n *pgconn.Notification) {
|
|
msg = n.Payload
|
|
}
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support LISTEN / NOTIFY (https://github.com/cockroachdb/cockroach/issues/41522)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, "listen foo").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
notifier, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, notifier)
|
|
_, err = notifier.Exec(ctx, "notify foo, 'bar'").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
err = pgConn.WaitForNotification(ctx)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, "bar", msg)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnWaitForNotificationPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
cancel()
|
|
err = pgConn.WaitForNotification(ctx)
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnWaitForNotificationTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 5*time.Millisecond)
|
|
err = pgConn.WaitForNotification(ctx)
|
|
cancel()
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyToSmall(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does support COPY TO")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int2,
|
|
b int4,
|
|
c int8,
|
|
d varchar,
|
|
e text,
|
|
f date,
|
|
g json
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}')`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `insert into foo values (null, null, null, null, null, null, null)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
inputBytes := []byte("0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\n" +
|
|
"\\N\t\\N\t\\N\t\\N\t\\N\t\\N\t\\N\n")
|
|
|
|
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
|
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout")
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, int64(2), res.RowsAffected())
|
|
assert.Equal(t, inputBytes, outputWriter.Bytes())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyToLarge(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does support COPY TO")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int2,
|
|
b int4,
|
|
c int8,
|
|
d varchar,
|
|
e text,
|
|
f date,
|
|
g json,
|
|
h bytea
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
inputBytes := make([]byte, 0)
|
|
|
|
for i := 0; i < 1000; i++ {
|
|
_, err = pgConn.Exec(ctx, `insert into foo values (0, 1, 2, 'abc', 'efg', '2000-01-01', '{"abc":"def","foo":"bar"}', 'oooo')`).ReadAll()
|
|
require.NoError(t, err)
|
|
inputBytes = append(inputBytes, "0\t1\t2\tabc\tefg\t2000-01-01\t{\"abc\":\"def\",\"foo\":\"bar\"}\t\\\\x6f6f6f6f\n"...)
|
|
}
|
|
|
|
outputWriter := bytes.NewBuffer(make([]byte, 0, len(inputBytes)))
|
|
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy foo to stdout")
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, int64(1000), res.RowsAffected())
|
|
assert.Equal(t, inputBytes, outputWriter.Bytes())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyToQueryError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
outputWriter := bytes.NewBuffer(make([]byte, 0))
|
|
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "cropy foo to stdout")
|
|
require.Error(t, err)
|
|
assert.IsType(t, &pgconn.PgError{}, err)
|
|
assert.Equal(t, int64(0), res.RowsAffected())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyToCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
|
|
}
|
|
|
|
outputWriter := &bytes.Buffer{}
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
defer cancel()
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select *, pg_sleep(0.01) from generate_series(1,1000)) to stdout")
|
|
assert.Error(t, err)
|
|
assert.Equal(t, pgconn.CommandTag{}, res)
|
|
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnCopyToPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
outputWriter := &bytes.Buffer{}
|
|
|
|
cancel()
|
|
res, err := pgConn.CopyTo(ctx, outputWriter, "copy (select * from generate_series(1,1000)) to stdout")
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
assert.Equal(t, pgconn.CommandTag{}, res)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFrom(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
srcBuf := &bytes.Buffer{}
|
|
|
|
inputRows := [][][]byte{}
|
|
for i := 0; i < 1000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
|
|
_, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
copySql = "COPY foo FROM STDIN WITH CSV"
|
|
}
|
|
ct, err := pgConn.CopyFrom(ctx, srcBuf, copySql)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
|
|
|
result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
assert.Equal(t, inputRows, result.Rows)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromBinary(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
buf := []byte{}
|
|
buf = append(buf, "PGCOPY\n\377\r\n\000"...)
|
|
buf = pgio.AppendInt32(buf, 0)
|
|
buf = pgio.AppendInt32(buf, 0)
|
|
|
|
inputRows := [][][]byte{}
|
|
for i := 0; i < 1000; i++ {
|
|
// Number of elements in the tuple
|
|
buf = pgio.AppendInt16(buf, int16(2))
|
|
a := i
|
|
|
|
// Length of element for column `a int4`
|
|
buf = pgio.AppendInt32(buf, 4)
|
|
buf, err = pgtype.NewMap().Encode(pgtype.Int4OID, pgx.BinaryFormatCode, a, buf)
|
|
require.NoError(t, err)
|
|
|
|
b := "foo " + strconv.Itoa(a) + " bar"
|
|
lenB := int32(len([]byte(b)))
|
|
// Length of element for column `b varchar`
|
|
buf = pgio.AppendInt32(buf, lenB)
|
|
buf, err = pgtype.NewMap().Encode(pgtype.VarcharOID, pgx.BinaryFormatCode, b, buf)
|
|
require.NoError(t, err)
|
|
|
|
inputRows = append(inputRows, [][]byte{[]byte(strconv.Itoa(a)), []byte(b)})
|
|
}
|
|
|
|
srcBuf := &bytes.Buffer{}
|
|
srcBuf.Write(buf)
|
|
ct, err := pgConn.CopyFrom(ctx, srcBuf, "COPY foo (a, b) FROM STDIN BINARY;")
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
|
|
|
result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
assert.Equal(t, inputRows, result.Rows)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromCanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
r, w := io.Pipe()
|
|
go func() {
|
|
for i := 0; i < 1000000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
if err != nil {
|
|
return
|
|
}
|
|
time.Sleep(time.Microsecond)
|
|
}
|
|
}()
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
copySql = "COPY foo FROM STDIN WITH CSV"
|
|
}
|
|
ct, err := pgConn.CopyFrom(ctx, r, copySql)
|
|
cancel()
|
|
assert.Equal(t, int64(0), ct.RowsAffected())
|
|
assert.Error(t, err)
|
|
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
func TestConnCopyFromPrecanceled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
r, w := io.Pipe()
|
|
go func() {
|
|
for i := 0; i < 1000000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
_, err := w.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
if err != nil {
|
|
return
|
|
}
|
|
time.Sleep(time.Microsecond)
|
|
}
|
|
}()
|
|
|
|
ctx, cancel = context.WithCancel(ctx)
|
|
cancel()
|
|
ct, err := pgConn.CopyFrom(ctx, r, "COPY foo FROM STDIN WITH (FORMAT csv)")
|
|
require.Error(t, err)
|
|
assert.True(t, errors.Is(err, context.Canceled))
|
|
assert.True(t, pgconn.SafeToRetry(err))
|
|
assert.Equal(t, pgconn.CommandTag{}, ct)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromGzipReader(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not fully support COPY FROM (https://www.cockroachlabs.com/docs/v20.2/copy-from.html)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
f, err := os.CreateTemp(t.TempDir(), "*")
|
|
require.NoError(t, err)
|
|
defer f.Close()
|
|
|
|
gw := gzip.NewWriter(f)
|
|
|
|
inputRows := [][][]byte{}
|
|
for i := 0; i < 1000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
|
|
_, err = gw.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
err = gw.Close()
|
|
require.NoError(t, err)
|
|
|
|
_, err = f.Seek(0, 0)
|
|
require.NoError(t, err)
|
|
|
|
gr, err := gzip.NewReader(f)
|
|
require.NoError(t, err)
|
|
|
|
copySql := "COPY foo FROM STDIN WITH (FORMAT csv)"
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
copySql = "COPY foo FROM STDIN WITH CSV"
|
|
}
|
|
ct, err := pgConn.CopyFrom(ctx, gr, copySql)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, int64(len(inputRows)), ct.RowsAffected())
|
|
|
|
err = gr.Close()
|
|
require.NoError(t, err)
|
|
|
|
result := pgConn.ExecParams(ctx, "select * from foo", nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
assert.Equal(t, inputRows, result.Rows)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromQuerySyntaxError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table foo(
|
|
a int4,
|
|
b varchar
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
srcBuf := &bytes.Buffer{}
|
|
|
|
// Send data even though the COPY FROM command will be rejected with a syntax error. This ensures that this does not
|
|
// break the connection. See https://github.com/jackc/pgconn/pull/127 for context.
|
|
inputRows := [][][]byte{}
|
|
for i := 0; i < 1000; i++ {
|
|
a := strconv.Itoa(i)
|
|
b := "foo " + a + " bar"
|
|
inputRows = append(inputRows, [][]byte{[]byte(a), []byte(b)})
|
|
_, err = srcBuf.Write([]byte(fmt.Sprintf("%s,\"%s\"\n", a, b)))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
res, err := pgConn.CopyFrom(ctx, srcBuf, "cropy foo FROM STDIN WITH (FORMAT csv)")
|
|
require.Error(t, err)
|
|
assert.IsType(t, &pgconn.PgError{}, err)
|
|
assert.Equal(t, int64(0), res.RowsAffected())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCopyFromQueryNoTableError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
srcBuf := &bytes.Buffer{}
|
|
|
|
res, err := pgConn.CopyFrom(ctx, srcBuf, "copy foo to stdout")
|
|
require.Error(t, err)
|
|
assert.IsType(t, &pgconn.PgError{}, err)
|
|
assert.Equal(t, int64(0), res.RowsAffected())
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// https://github.com/jackc/pgconn/issues/21
|
|
func TestConnCopyFromNoticeResponseReceivedMidStream(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support triggers (https://github.com/cockroachdb/cockroach/issues/28296)")
|
|
}
|
|
|
|
_, err = pgConn.Exec(ctx, `create temporary table sentences(
|
|
t text,
|
|
ts tsvector
|
|
)`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `create function pg_temp.sentences_trigger() returns trigger as $$
|
|
begin
|
|
new.ts := to_tsvector(new.t);
|
|
return new;
|
|
end
|
|
$$ language plpgsql;`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pgConn.Exec(ctx, `create trigger sentences_update before insert on sentences for each row execute procedure pg_temp.sentences_trigger();`).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
longString := make([]byte, 10001)
|
|
for i := range longString {
|
|
longString[i] = 'x'
|
|
}
|
|
|
|
buf := &bytes.Buffer{}
|
|
for i := 0; i < 1000; i++ {
|
|
buf.Write([]byte(fmt.Sprintf("%s\n", string(longString))))
|
|
}
|
|
|
|
_, err = pgConn.CopyFrom(ctx, buf, "COPY sentences(t) FROM STDIN WITH (FORMAT csv)")
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// delayedReader is an io.Reader that sleeps for one millisecond before delegating each Read call to the wrapped
// reader r. NOTE(review): it appears to exist to slow down the data side of a CopyFrom in tests.
type delayedReader struct {
	r io.Reader
}

// Read pauses for one millisecond and then returns whatever the underlying reader returns.
func (dr delayedReader) Read(buf []byte) (int, error) {
	// A note previously kept here recorded that TestConnCopyFromDataWriteAfterErrorAndReturn passed without this sleep
	// and failed with it. In other words, the delay is what made https://github.com/jackc/pgconn/issues/128 reproducible.
	time.Sleep(time.Millisecond)
	return dr.r.Read(buf)
}
|
|
|
|
// https://github.com/jackc/pgconn/issues/128
|
|
func TestConnCopyFromDataWriteAfterErrorAndReturn(t *testing.T) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
connString := os.Getenv("PGX_TEST_DATABASE")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_DATABASE")
|
|
}
|
|
|
|
config, err := pgconn.ParseConfig(connString)
|
|
require.NoError(t, err)
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not fully support COPY FROM")
|
|
}
|
|
|
|
setupSQL := `create temporary table t (
|
|
id text primary key,
|
|
n int not null
|
|
);`
|
|
|
|
_, err = pgConn.Exec(ctx, setupSQL).ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
r1 := delayedReader{r: strings.NewReader(`id 0\n`)}
|
|
// Generate an error with a bogus COPY command
|
|
_, err = pgConn.CopyFrom(ctx, r1, "COPY nosuchtable FROM STDIN ")
|
|
assert.Error(t, err)
|
|
|
|
r2 := delayedReader{r: strings.NewReader(`id 0\n`)}
|
|
_, err = pgConn.CopyFrom(ctx, r2, "COPY t FROM STDIN")
|
|
assert.NoError(t, err)
|
|
}
|
|
|
|
func TestConnEscapeString(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
tests := []struct {
|
|
in string
|
|
out string
|
|
}{
|
|
{in: "", out: ""},
|
|
{in: "42", out: "42"},
|
|
{in: "'", out: "''"},
|
|
{in: "hi'there", out: "hi''there"},
|
|
{in: "'hi there'", out: "''hi there''"},
|
|
}
|
|
|
|
for i, tt := range tests {
|
|
value, err := pgConn.EscapeString(tt.in)
|
|
if assert.NoErrorf(t, err, "%d.", i) {
|
|
assert.Equalf(t, tt.out, value, "%d.", i)
|
|
}
|
|
}
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCancelRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
if pgConn.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support query cancellation (https://github.com/cockroachdb/cockroach/issues/41335)")
|
|
}
|
|
|
|
multiResult := pgConn.Exec(ctx, "select 'Hello, world', pg_sleep(25)")
|
|
|
|
errChan := make(chan error)
|
|
go func() {
|
|
// The query is actually sent when multiResult.NextResult() is called. So wait to ensure it is sent.
|
|
// Once Flush is available this could use that instead.
|
|
time.Sleep(1 * time.Second)
|
|
|
|
err := pgConn.CancelRequest(ctx)
|
|
errChan <- err
|
|
}()
|
|
|
|
for multiResult.NextResult() {
|
|
}
|
|
err = multiResult.Close()
|
|
|
|
require.IsType(t, &pgconn.PgError{}, err)
|
|
require.Equal(t, "57014", err.(*pgconn.PgError).Code)
|
|
|
|
err = <-errChan
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/659
|
|
func TestConnContextCanceledCancelsRunningQueryOnServer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("postgres", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testConnContextCanceledCancelsRunningQueryOnServer(t, os.Getenv("PGX_TEST_DATABASE"), "postgres")
|
|
})
|
|
|
|
t.Run("pgbouncer", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
connString := os.Getenv(pgbouncerConnStringEnvVar)
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", pgbouncerConnStringEnvVar)
|
|
}
|
|
|
|
testConnContextCanceledCancelsRunningQueryOnServer(t, connString, "pgbouncer")
|
|
})
|
|
}
|
|
|
|
func testConnContextCanceledCancelsRunningQueryOnServer(t *testing.T, connString, dbType string) {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
// Getting the actual PostgreSQL server process ID (PID) from a query executed through pgbouncer is not straightforward
|
|
// because pgbouncer abstracts the underlying database connections, and it doesn't expose the PID of the PostgreSQL
|
|
// server process to clients. However, we can check if the query is running by checking the generated query ID.
|
|
queryID := fmt.Sprintf("%s testConnContextCanceled %d", dbType, time.Now().UnixNano())
|
|
|
|
multiResult := pgConn.Exec(ctx, fmt.Sprintf(`
|
|
-- %v
|
|
select 'Hello, world', pg_sleep(30)
|
|
`, queryID))
|
|
|
|
for multiResult.NextResult() {
|
|
}
|
|
err = multiResult.Close()
|
|
assert.True(t, pgconn.Timeout(err))
|
|
assert.True(t, pgConn.IsClosed())
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
|
|
ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second)
|
|
defer cancel()
|
|
|
|
otherConn, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, otherConn)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, time.Second*5)
|
|
defer cancel()
|
|
|
|
for {
|
|
result := otherConn.ExecParams(ctx,
|
|
`select 1 from pg_stat_activity where query like $1`,
|
|
[][]byte{[]byte("%" + queryID + "%")},
|
|
nil,
|
|
nil,
|
|
nil,
|
|
).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
if len(result.Rows) == 0 {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHijackAndConstruct(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
origConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
err = origConn.SyncConn(ctx)
|
|
require.NoError(t, err)
|
|
|
|
hc, err := origConn.Hijack()
|
|
require.NoError(t, err)
|
|
|
|
_, err = origConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
require.Error(t, err)
|
|
|
|
newConn, err := pgconn.Construct(hc)
|
|
require.NoError(t, err)
|
|
|
|
defer closeConn(t, newConn)
|
|
|
|
results, err := newConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Len(t, results, 1)
|
|
assert.Nil(t, results[0].Err)
|
|
assert.Equal(t, "SELECT 1", results[0].CommandTag.String())
|
|
assert.Len(t, results[0].Rows, 1)
|
|
assert.Equal(t, "Hello, world", string(results[0].Rows[0][0]))
|
|
|
|
ensureConnValid(t, newConn)
|
|
}
|
|
|
|
func TestConnCloseWhileCancellableQueryInProgress(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
|
|
pgConn.Exec(ctx, "select n from generate_series(1,10) n")
|
|
|
|
closeCtx, _ := context.WithCancel(ctx)
|
|
pgConn.Close(closeCtx)
|
|
select {
|
|
case <-pgConn.CleanupDone():
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Connection cleanup exceeded maximum time")
|
|
}
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/800
|
|
func TestFatalErrorReceivedAfterCommandComplete(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Bind{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Execute{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Sync{}))
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
|
{Name: []byte("mock")},
|
|
}}))
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.CommandComplete{CommandTag: []byte("SELECT 0")}))
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
|
|
|
|
script := &pgmock.Script{Steps: steps}
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(serverErrChan)
|
|
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
err = script.Run(pgproto3.NewBackend(conn, conn))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
}()
|
|
|
|
host, port, _ := strings.Cut(ln.Addr().String(), ":")
|
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
conn, err := pgconn.Connect(ctx, connStr)
|
|
require.NoError(t, err)
|
|
|
|
rr := conn.ExecParams(ctx, "mocked...", nil, nil, nil, nil)
|
|
|
|
for rr.NextRow() {
|
|
}
|
|
|
|
_, err = rr.Close()
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// https://github.com/jackc/pgconn/issues/27
|
|
func TestConnLargeResponseWhileWritingDoesNotDeadlock(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, "set client_min_messages = debug5").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
// The actual contents of this test aren't important. What's important is a large amount of data to be written and
|
|
// because of client_min_messages = debug5 the server will return a large amount of data.
|
|
|
|
paramCount := math.MaxUint16
|
|
params := make([]string, 0, paramCount)
|
|
args := make([][]byte, 0, paramCount)
|
|
for i := 0; i < paramCount; i++ {
|
|
params = append(params, fmt.Sprintf("($%d::text)", i+1))
|
|
args = append(args, []byte(strconv.Itoa(i)))
|
|
}
|
|
sql := "values" + strings.Join(params, ", ")
|
|
|
|
result := pgConn.ExecParams(ctx, sql, args, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
require.Len(t, result.Rows, paramCount)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestConnCheckConn(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
// Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtly different.)
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
c1, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer c1.Close(ctx)
|
|
|
|
if c1.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
|
}
|
|
|
|
err = c1.CheckConn()
|
|
require.NoError(t, err)
|
|
|
|
c2, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer c2.Close(ctx)
|
|
|
|
_, err = c2.Exec(ctx, fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
// It may take a while for the server to kill the backend. Retry until the error is detected or the test context is
|
|
// canceled.
|
|
for err == nil && ctx.Err() == nil {
|
|
time.Sleep(50 * time.Millisecond)
|
|
err = c1.CheckConn()
|
|
}
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestConnPing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
// Intentionally using TCP connection for more predictable close behavior. (Not sure if Unix domain sockets would behave subtly different.)
|
|
|
|
connString := os.Getenv("PGX_TEST_TCP_CONN_STRING")
|
|
if connString == "" {
|
|
t.Skipf("Skipping due to missing environment variable %v", "PGX_TEST_TCP_CONN_STRING")
|
|
}
|
|
|
|
c1, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer c1.Close(ctx)
|
|
|
|
if c1.ParameterStatus("crdb_version") != "" {
|
|
t.Skip("Server does not support pg_terminate_backend() (https://github.com/cockroachdb/cockroach/issues/35897)")
|
|
}
|
|
|
|
err = c1.Exec(ctx, "set log_statement = 'all'").Close()
|
|
require.NoError(t, err)
|
|
|
|
err = c1.Ping(ctx)
|
|
require.NoError(t, err)
|
|
|
|
c2, err := pgconn.Connect(ctx, connString)
|
|
require.NoError(t, err)
|
|
defer c2.Close(ctx)
|
|
|
|
_, err = c2.Exec(ctx, fmt.Sprintf("select pg_terminate_backend(%d)", c1.PID())).ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
// Give a little time for the signal to actually kill the backend.
|
|
time.Sleep(500 * time.Millisecond)
|
|
|
|
err = c1.Ping(ctx)
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestPipelinePrepare(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
result := pgConn.ExecParams(ctx, `create temporary table t (id text primary key)`, nil, nil, nil, nil).Read()
|
|
require.NoError(t, result.Err)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
|
|
pipeline.SendPrepare("selectText", "select $1::text as b", nil)
|
|
pipeline.SendPrepare("selectNoParams", "select 42 as c", nil)
|
|
pipeline.SendPrepare("insertNoResults", "insert into t (id) values ($1)", nil)
|
|
pipeline.SendPrepare("insertNoParamsOrResults", "insert into t (id) values ('foo')", nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, "a", string(sd.Fields[0].Name))
|
|
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok = results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, "b", string(sd.Fields[0].Name))
|
|
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok = results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, "c", string(sd.Fields[0].Name))
|
|
require.Equal(t, []uint32{}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok = results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 0)
|
|
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok = results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 0)
|
|
require.Len(t, sd.ParamOIDs, 0)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelinePrepareError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
|
|
pipeline.SendPrepare("selectError", "bad", nil)
|
|
pipeline.SendPrepare("selectText", "select $1::text as b", nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, "a", string(sd.Fields[0].Name))
|
|
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, err, &pgErr)
|
|
require.Nil(t, results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelinePrepareAndDeallocate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("selectInt", "select $1::bigint as a", nil)
|
|
pipeline.SendDeallocate("selectInt")
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, "a", string(sd.Fields[0].Name))
|
|
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.CloseComplete)
|
|
require.Truef(t, ok, "expected CloseComplete, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineQuery(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "4", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelinePrepareQuery(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("goodbye")}, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, "msg", string(sd.Fields[0].Name))
|
|
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "hello", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "goodbye", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineQueryErrorBetweenSyncs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 6`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, readResult.Err, &pgErr)
|
|
require.Equal(t, "22012", pgErr.Code)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "6", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineFlushForSingleRequests(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
|
|
pipeline.SendPrepare("ps", "select $1::text as msg", nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, "msg", string(sd.Fields[0].Name))
|
|
require.Equal(t, []uint32{pgtype.TextOID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("hello")}, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "hello", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendDeallocate("ps")
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.CloseComplete)
|
|
require.Truef(t, ok, "expected CloseComplete, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineFlushForRequestSeries(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("ps", "select $1::bigint as num", nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
sd, ok := results.(*pgconn.StatementDescription)
|
|
require.Truef(t, ok, "expected StatementDescription, got: %#v", results)
|
|
require.Len(t, sd.Fields, 1)
|
|
require.Equal(t, "num", string(sd.Fields[0].Name))
|
|
require.Equal(t, []uint32{pgtype.Int8OID}, sd.ParamOIDs)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("1")}, nil, nil)
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("2")}, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "2", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("3")}, nil, nil)
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("4")}, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "3", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "4", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("5")}, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendQueryPrepared(`ps`, [][]byte{[]byte("6")}, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "6", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineFlushWithError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 1/(3-n) from generate_series(1,10) n`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, readResult.Err, &pgErr)
|
|
require.Equal(t, "22012", pgErr.Code)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
pipeline.SendPipelineSync()
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
pipeline.SendFlushRequest()
|
|
err = pipeline.Flush()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok = results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult = rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "5", string(readResult.Rows[0][0]))
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
require.Nil(t, results)
|
|
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, ok = results.(*pgconn.PipelineSync)
|
|
require.Truef(t, ok, "expected PipelineSync, got: %#v", results)
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineCloseReadsUnreadResults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
err = pipeline.Close()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func TestPipelineCloseDetectsUnsyncedRequests(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pipeline := pgConn.StartPipeline(ctx)
|
|
pipeline.SendQueryParams(`select 1`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 2`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 3`, nil, nil, nil, nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
pipeline.SendQueryParams(`select 4`, nil, nil, nil, nil)
|
|
pipeline.SendQueryParams(`select 5`, nil, nil, nil, nil)
|
|
|
|
results, err := pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
rr, ok := results.(*pgconn.ResultReader)
|
|
require.Truef(t, ok, "expected ResultReader, got: %#v", results)
|
|
readResult := rr.Read()
|
|
require.NoError(t, readResult.Err)
|
|
require.Len(t, readResult.Rows, 1)
|
|
require.Len(t, readResult.Rows[0], 1)
|
|
require.Equal(t, "1", string(readResult.Rows[0][0]))
|
|
|
|
err = pipeline.Close()
|
|
require.EqualError(t, err, "pipeline has unsynced requests")
|
|
}
|
|
|
|
func TestConnOnPgError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
config.OnPgError = func(c *pgconn.PgConn, pgErr *pgconn.PgError) bool {
|
|
require.NotNil(t, c)
|
|
require.NotNil(t, pgErr)
|
|
// close connection on undefined tables only
|
|
if pgErr.Code == "42P01" {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
pgConn, err := pgconn.ConnectConfig(ctx, config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
_, err = pgConn.Exec(ctx, "select 'Hello, world'").ReadAll()
|
|
assert.NoError(t, err)
|
|
assert.False(t, pgConn.IsClosed())
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1/0").ReadAll()
|
|
assert.Error(t, err)
|
|
assert.False(t, pgConn.IsClosed())
|
|
|
|
_, err = pgConn.Exec(ctx, "select * from non_existant_table").ReadAll()
|
|
assert.Error(t, err)
|
|
assert.True(t, pgConn.IsClosed())
|
|
}
|
|
|
|
func TestConnCustomData(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
pgConn.CustomData()["foo"] = "bar"
|
|
assert.Equal(t, "bar", pgConn.CustomData()["foo"])
|
|
|
|
ensureConnValid(t, pgConn)
|
|
}
|
|
|
|
func Example() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
pgConn, err := pgconn.Connect(ctx, os.Getenv("PGX_TEST_DATABASE"))
|
|
if err != nil {
|
|
log.Fatalln(err)
|
|
}
|
|
defer pgConn.Close(ctx)
|
|
|
|
result := pgConn.ExecParams(ctx, "select generate_series(1,3)", nil, nil, nil, nil).Read()
|
|
if result.Err != nil {
|
|
log.Fatalln(result.Err)
|
|
}
|
|
|
|
for _, row := range result.Rows {
|
|
fmt.Println(string(row[0]))
|
|
}
|
|
|
|
fmt.Println(result.CommandTag)
|
|
// Output:
|
|
// 1
|
|
// 2
|
|
// 3
|
|
// SELECT 3
|
|
}
|
|
|
|
// GetSSLPassword returns the value of the PGX_SSL_PASSWORD environment
// variable, or "" when it is unset. ctx is not used.
func GetSSLPassword(ctx context.Context) string {
	return os.Getenv("PGX_SSL_PASSWORD")
}
|
|
|
|
// rsaCertPEM is a PEM-encoded X.509 certificate. TestSNISupport pairs it with
// rsaKeyPEM (via tls.X509KeyPair) to serve TLS from its mock server.
// NOTE(review): the encoded subject appears to be CN=localhost and the validity
// period appears to end in August 2023. It presumably still works because the
// client does not verify the server certificate under sslmode=require — confirm.
var rsaCertPEM = `-----BEGIN CERTIFICATE-----
|
|
MIIDCTCCAfGgAwIBAgIUQDlN1g1bzxIJ8KWkayNcQY5gzMEwDQYJKoZIhvcNAQEL
|
|
BQAwFDESMBAGA1UEAwwJbG9jYWxob3N0MB4XDTIyMDgxNTIxNDgyNloXDTIzMDgx
|
|
NTIxNDgyNlowFDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEF
|
|
AAOCAQ8AMIIBCgKCAQEA0vOppiT8zE+076acRORzD5JVbRYKMK3XlWLVrHua4+ct
|
|
Rm54WyP+3XsYU4JGGGKgb8E+u2UosGJYcSM+b+U1/5XPTcpuumS+pCiD9WP++A39
|
|
tsukYwR7m65cgpiI4dlLEZI3EWpAW+Bb3230KiYW4sAmQ0Ih4PrN+oPvzcs86F4d
|
|
9Y03CqVUxRKLBLaClZQAg8qz2Pawwj1FKKjDX7u2fRVR0wgOugpCMOBJMcCgz9pp
|
|
0HSa4x3KZDHEZY7Pah5XwWrCfAEfRWsSTGcNaoN8gSxGFM1JOEJa8SAuPGjFcYIv
|
|
MmVWdw0FXCgYlSDL02fzLE0uyvXBDibzSqOk770JhQIDAQABo1MwUTAdBgNVHQ4E
|
|
FgQUiJ8JLENJ+2k1Xl4o6y2Lc/qHTh0wHwYDVR0jBBgwFoAUiJ8JLENJ+2k1Xl4o
|
|
6y2Lc/qHTh0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAwjn2
|
|
gnNAhFvh58VqLIjU6ftvn6rhz5B9dg2+XyY8sskLhhkO1nL9339BVZsRt+eI3a7I
|
|
81GNIm9qHVM3MUAcQv3SZy+0UPVUT8DNH2LwHT3CHnYTBP8U+8n8TDNGSTMUhIBB
|
|
Rx+6KwODpwLdI79VGT3IkbU9bZwuepB9I9nM5t/tt5kS4gHmJFlO0aLJFCTO4Scf
|
|
hp/WLPv4XQUH+I3cPfaJRxz2j0Kc8iOzMhFmvl1XOGByjX6X33LnOzY/LVeTSGyS
|
|
VgC32BGtnMwuy5XZYgFAeUx9HKy4tG4OH2Ux6uPF/WAhsug6PXSjV7BK6wYT5i27
|
|
MlascjupnaptKX/wMA==
|
|
-----END CERTIFICATE-----
|
|
`
|
|
|
|
// rsaKeyPEM is the PEM-encoded private key that tls.X509KeyPair pairs with
// rsaCertPEM. The literal is written with a "TESTING KEY" header and footer.
// It is passed through testingKey, which rewrites those to "PRIVATE KEY" when
// the package-level variable is initialized.
// NOTE(review): presumably this keeps secret scanners from flagging a committed
// private key — confirm.
var rsaKeyPEM = testingKey(`-----BEGIN TESTING KEY-----
|
|
MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDS86mmJPzMT7Tv
|
|
ppxE5HMPklVtFgowrdeVYtWse5rj5y1GbnhbI/7dexhTgkYYYqBvwT67ZSiwYlhx
|
|
Iz5v5TX/lc9Nym66ZL6kKIP1Y/74Df22y6RjBHubrlyCmIjh2UsRkjcRakBb4Fvf
|
|
bfQqJhbiwCZDQiHg+s36g+/NyzzoXh31jTcKpVTFEosEtoKVlACDyrPY9rDCPUUo
|
|
qMNfu7Z9FVHTCA66CkIw4EkxwKDP2mnQdJrjHcpkMcRljs9qHlfBasJ8AR9FaxJM
|
|
Zw1qg3yBLEYUzUk4QlrxIC48aMVxgi8yZVZ3DQVcKBiVIMvTZ/MsTS7K9cEOJvNK
|
|
o6TvvQmFAgMBAAECggEAKzTK54Ol33bn2TnnwdiElIjlRE2CUswYXrl6iDRc2hbs
|
|
WAOiVRB/T/+5UMla7/2rXJhY7+rdNZs/ABU24ZYxxCJ77jPrD/Q4c8j0lhsgCtBa
|
|
ycjV543wf0dsHTd+ubtWu8eVzdRUUD0YtB+CJevdPh4a+CWgaMMV0xyYzi61T+Yv
|
|
Z7Uc3awIAiT4Kw9JRmJiTnyMJg5vZqW3BBAX4ZIvS/54ipwEU+9sWLcuH7WmCR0B
|
|
QCTqS6hfJDLm//dGC89Iyno57zfYuiT3PYCWH5crr/DH3LqnwlNaOGSBkhkXuIL+
|
|
QvOaUMe2i0pjqxDrkBx05V554vyy9jEvK7i330HL4QKBgQDUJmouEr0+o7EMBApC
|
|
CPPu58K04qY5t9aGciG/pOurN42PF99yNZ1CnynH6DbcnzSl8rjc6Y65tzTlWods
|
|
bjwVfcmcokG7sPcivJvVjrjKpSQhL8xdZwSAjcqjN4yoJ/+ghm9w+SRmZr6oCQZ3
|
|
1jREfJKT+PGiWTEjYcExPWUD2QKBgQD+jdgq4c3tFavU8Hjnlf75xbStr5qu+fp2
|
|
SGLRRbX+msQwVbl2ZM9AJLoX9MTCl7D9zaI3ONhheMmfJ77lDTa3VMFtr3NevGA6
|
|
MxbiCEfRtQpNkJnsqCixLckx3bskj5+IF9BWzw7y7nOzdhoWVFv/+TltTm3RB51G
|
|
McdlmmVjjQKBgQDSFAw2/YV6vtu2O1XxGC591/Bd8MaMBziev+wde3GHhaZfGVPC
|
|
I8dLTpMwCwowpFKdNeLLl1gnHX161I+f1vUWjw4TVjVjaBUBx+VEr2Tb/nXtiwiD
|
|
QV0a883CnGJjreAblKRMKdpasMmBWhaWmn39h6Iad3zHuCzJjaaiXNpn2QKBgQCf
|
|
k1Q8LanmQnuh1c41f7aD5gjKCRezMUpt9BrejhD1NxheJJ9LNQ8nat6uPedLBcUS
|
|
lmJms+AR2qKqf0QQWyQ98YgAtshgTz8TvQtPT1mWgSOgVFHqJdC8obNK63FyDgc4
|
|
TZVxlgQNDqbBjfv0m5XA9f+mIlB9hYR2iKYzb4K30QKBgQC+LEJYZh00zsXttGHr
|
|
5wU1RzbgDIEsNuu+nZ4MxsaCik8ILNRHNXdeQbnADKuo6ATfhdmDIQMVZLG8Mivi
|
|
UwnwLd1GhizvqvLHa3ULnFphRyMGFxaLGV48axTT2ADoMX67ILrIY/yjycLqRZ3T
|
|
z3w+CgS20UrbLIR1YXfqUXge1g==
|
|
-----END TESTING KEY-----
|
|
`)
|
|
|
|
// testingKey returns s with every "TESTING KEY" occurrence replaced by
// "PRIVATE KEY", turning a disguised PEM header/footer into a real one.
func testingKey(s string) string {
	const from, to = "TESTING KEY", "PRIVATE KEY"
	return strings.ReplaceAll(s, from, to)
}
|
|
|
|
func TestSNISupport(t *testing.T) {
|
|
t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
sni_param string
|
|
sni_set bool
|
|
}{
|
|
{
|
|
name: "SNI is passed by default",
|
|
sni_param: "",
|
|
sni_set: true,
|
|
},
|
|
{
|
|
name: "SNI is passed when asked for",
|
|
sni_param: "sslsni=1",
|
|
sni_set: true,
|
|
},
|
|
{
|
|
name: "SNI is not passed when disabled",
|
|
sni_param: "sslsni=0",
|
|
sni_set: false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
serverSNINameChan := make(chan string, 1)
|
|
defer close(serverErrChan)
|
|
defer close(serverSNINameChan)
|
|
|
|
go func() {
|
|
var sniHost string
|
|
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.SetDeadline(time.Now().Add(5 * time.Second))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
backend := pgproto3.NewBackend(conn, conn)
|
|
startupMessage, err := backend.ReceiveStartupMessage()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
switch startupMessage.(type) {
|
|
case *pgproto3.SSLRequest:
|
|
_, err = conn.Write([]byte("S"))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
default:
|
|
serverErrChan <- fmt.Errorf("unexpected startup message: %#v", startupMessage)
|
|
return
|
|
}
|
|
|
|
cert, err := tls.X509KeyPair([]byte(rsaCertPEM), []byte(rsaKeyPEM))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
srv := tls.Server(conn, &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
sniHost = argHello.ServerName
|
|
return nil, nil
|
|
},
|
|
})
|
|
defer srv.Close()
|
|
|
|
if err := srv.Handshake(); err != nil {
|
|
serverErrChan <- fmt.Errorf("handshake: %w", err)
|
|
return
|
|
}
|
|
|
|
srv.Write(mustEncode((&pgproto3.AuthenticationOk{}).Encode(nil)))
|
|
srv.Write(mustEncode((&pgproto3.BackendKeyData{ProcessID: 0, SecretKey: 0}).Encode(nil)))
|
|
srv.Write(mustEncode((&pgproto3.ReadyForQuery{TxStatus: 'I'}).Encode(nil)))
|
|
|
|
serverSNINameChan <- sniHost
|
|
}()
|
|
|
|
_, port, _ := strings.Cut(ln.Addr().String(), ":")
|
|
connStr := fmt.Sprintf("sslmode=require host=localhost port=%s %s", port, tt.sni_param)
|
|
_, err = pgconn.Connect(ctx, connStr)
|
|
|
|
select {
|
|
case sniHost := <-serverSNINameChan:
|
|
if tt.sni_set {
|
|
require.Equal(t, "localhost", sniHost)
|
|
} else {
|
|
require.Equal(t, "", sniHost)
|
|
}
|
|
case err = <-serverErrChan:
|
|
t.Fatalf("server failed with error: %+v", err)
|
|
case <-time.After(time.Millisecond * 100):
|
|
t.Fatal("exceeded connection timeout without erroring out")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// https://github.com/jackc/pgx/issues/1920
|
|
func TestFatalErrorReceivedInPipelineMode(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
|
|
defer cancel()
|
|
|
|
steps := pgmock.AcceptUnauthenticatedConnRequestSteps()
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Parse{}))
|
|
steps = append(steps, pgmock.ExpectAnyMessage(&pgproto3.Describe{}))
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.RowDescription{Fields: []pgproto3.FieldDescription{
|
|
{Name: []byte("mock")},
|
|
}}))
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
|
|
// We shouldn't get anything after the first fatal error. But the reported issue was with PgBouncer so maybe that
|
|
// causes the issue. Anyway, a FATAL error after the connection had already been killed could cause a panic.
|
|
steps = append(steps, pgmock.SendMessage(&pgproto3.ErrorResponse{Severity: "FATAL", Code: "57P01"}))
|
|
|
|
script := &pgmock.Script{Steps: steps}
|
|
|
|
ln, err := net.Listen("tcp", "127.0.0.1:")
|
|
require.NoError(t, err)
|
|
defer ln.Close()
|
|
|
|
serverKeepAlive := make(chan struct{})
|
|
defer close(serverKeepAlive)
|
|
|
|
serverErrChan := make(chan error, 1)
|
|
go func() {
|
|
defer close(serverErrChan)
|
|
|
|
conn, err := ln.Accept()
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
err = conn.SetDeadline(time.Now().Add(59 * time.Second))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
err = script.Run(pgproto3.NewBackend(conn, conn))
|
|
if err != nil {
|
|
serverErrChan <- err
|
|
return
|
|
}
|
|
|
|
<-serverKeepAlive
|
|
}()
|
|
|
|
parts := strings.Split(ln.Addr().String(), ":")
|
|
host := parts[0]
|
|
port := parts[1]
|
|
connStr := fmt.Sprintf("sslmode=disable host=%s port=%s", host, port)
|
|
|
|
ctx, cancel = context.WithTimeout(ctx, 59*time.Second)
|
|
defer cancel()
|
|
conn, err := pgconn.Connect(ctx, connStr)
|
|
require.NoError(t, err)
|
|
|
|
pipeline := conn.StartPipeline(ctx)
|
|
pipeline.SendPrepare("s1", "select 1", nil)
|
|
pipeline.SendPrepare("s2", "select 2", nil)
|
|
pipeline.SendPrepare("s3", "select 3", nil)
|
|
err = pipeline.Sync()
|
|
require.NoError(t, err)
|
|
|
|
_, err = pipeline.GetResults()
|
|
require.NoError(t, err)
|
|
_, err = pipeline.GetResults()
|
|
require.Error(t, err)
|
|
|
|
err = pipeline.Close()
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// mustEncode unwraps the (buf, err) pair returned by an Encode call. It panics
// with err when encoding failed, and otherwise returns buf unchanged.
func mustEncode(buf []byte, err error) []byte {
	if err == nil {
		return buf
	}
	panic(err)
}
|
|
|
|
func TestDeadlineContextWatcherHandler(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("DeadlineExceeded with zero DeadlineDelay", func(t *testing.T) {
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
|
return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn()}
|
|
}
|
|
config.ConnectTimeout = 5 * time.Second
|
|
|
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll()
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
require.True(t, pgConn.IsClosed())
|
|
})
|
|
|
|
t.Run("DeadlineExceeded with DeadlineDelay", func(t *testing.T) {
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
|
return &pgconn.DeadlineContextWatcherHandler{Conn: conn.Conn(), DeadlineDelay: 500 * time.Millisecond}
|
|
}
|
|
config.ConnectTimeout = 5 * time.Second
|
|
|
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
})
|
|
}
|
|
|
|
func TestCancelRequestContextWatcherHandler(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("DeadlineExceeded cancels request after CancelRequestDelay", func(t *testing.T) {
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
|
return &pgconn.CancelRequestContextWatcherHandler{
|
|
Conn: conn,
|
|
CancelRequestDelay: 250 * time.Millisecond,
|
|
DeadlineDelay: 5000 * time.Millisecond,
|
|
}
|
|
}
|
|
config.ConnectTimeout = 5 * time.Second
|
|
|
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(3)").ReadAll()
|
|
require.Error(t, err)
|
|
var pgErr *pgconn.PgError
|
|
require.ErrorAs(t, err, &pgErr)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
})
|
|
|
|
t.Run("DeadlineExceeded - do not send cancel request when query finishes in grace period", func(t *testing.T) {
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
|
return &pgconn.CancelRequestContextWatcherHandler{
|
|
Conn: conn,
|
|
CancelRequestDelay: 1000 * time.Millisecond,
|
|
DeadlineDelay: 5000 * time.Millisecond,
|
|
}
|
|
}
|
|
config.ConnectTimeout = 5 * time.Second
|
|
|
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(0.250)").ReadAll()
|
|
require.NoError(t, err)
|
|
|
|
ensureConnValid(t, pgConn)
|
|
})
|
|
|
|
t.Run("DeadlineExceeded sets conn deadline with DeadlineDelay", func(t *testing.T) {
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
|
return &pgconn.CancelRequestContextWatcherHandler{
|
|
Conn: conn,
|
|
CancelRequestDelay: 5000 * time.Millisecond, // purposely setting this higher than DeadlineDelay to ensure the cancel request never happens.
|
|
DeadlineDelay: 250 * time.Millisecond,
|
|
}
|
|
}
|
|
config.ConnectTimeout = 5 * time.Second
|
|
|
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
defer cancel()
|
|
|
|
_, err = pgConn.Exec(ctx, "select 1, pg_sleep(1)").ReadAll()
|
|
require.Error(t, err)
|
|
require.ErrorIs(t, err, context.DeadlineExceeded)
|
|
require.True(t, pgConn.IsClosed())
|
|
})
|
|
|
|
for i := 0; i < 10; i++ {
|
|
t.Run(fmt.Sprintf("Stress %d", i), func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
config, err := pgconn.ParseConfig(os.Getenv("PGX_TEST_DATABASE"))
|
|
require.NoError(t, err)
|
|
config.BuildContextWatcherHandler = func(conn *pgconn.PgConn) ctxwatch.Handler {
|
|
return &pgconn.CancelRequestContextWatcherHandler{
|
|
Conn: conn,
|
|
CancelRequestDelay: 5 * time.Millisecond,
|
|
DeadlineDelay: 1000 * time.Millisecond,
|
|
}
|
|
}
|
|
config.ConnectTimeout = 5 * time.Second
|
|
|
|
pgConn, err := pgconn.ConnectConfig(context.Background(), config)
|
|
require.NoError(t, err)
|
|
defer closeConn(t, pgConn)
|
|
|
|
for i := 0; i < 20; i++ {
|
|
func() {
|
|
ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond)
|
|
defer cancel()
|
|
pgConn.Exec(ctx, "select 1, pg_sleep(0.010)").ReadAll()
|
|
time.Sleep(100 * time.Millisecond) // ensure a cancel request that was a little late doesn't interrupt ensureConnValid.
|
|
ensureConnValid(t, pgConn)
|
|
}()
|
|
}
|
|
})
|
|
}
|
|
}
|