mirror of
https://github.com/jackc/pgx.git
synced 2025-07-08 03:28:46 +00:00
Merge branch 'master' into composite
This commit is contained in:
parent
2d89e52d6f
commit
fed099f04a
@ -17,7 +17,8 @@ if err != nil {
|
||||
|
||||
## v4 Coming Soon
|
||||
|
||||
This is the current stable v3 version. v4 is currently is in prelease status. Consider using [v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications.
|
||||
This is the current stable v3 version. v4 is currently is in release candidate status. Consider using
|
||||
[v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications.
|
||||
|
||||
## Features
|
||||
|
||||
|
28
batch.go
28
batch.go
@ -135,7 +135,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error {
|
||||
|
||||
_, err = b.conn.conn.Write(buf)
|
||||
if err != nil {
|
||||
b.conn.die(err)
|
||||
b.die(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@ -268,6 +268,23 @@ func (b *Batch) Close() (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
for b.conn.pendingReadyForQueryCount > 0 {
|
||||
msg, err := b.conn.rxMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch msg := msg.(type) {
|
||||
case *pgproto3.ErrorResponse:
|
||||
return b.conn.rxErrorResponse(msg)
|
||||
default:
|
||||
err = b.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err = b.conn.ensureConnectionReadyForQuery(); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -281,10 +298,13 @@ func (b *Batch) die(err error) {
|
||||
}
|
||||
|
||||
b.err = err
|
||||
b.conn.die(err)
|
||||
if b.conn != nil {
|
||||
err = b.conn.termContext(err)
|
||||
b.conn.die(err)
|
||||
|
||||
if b.conn != nil && b.connPool != nil {
|
||||
b.connPool.Release(b.conn)
|
||||
if b.connPool != nil {
|
||||
b.connPool.Release(b.conn)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -701,3 +701,55 @@ func TestTxBeginBatchRollback(t *testing.T) {
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestConnBeginBatchDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table t (
|
||||
id text primary key,
|
||||
n int not null,
|
||||
unique (n) deferrable initially deferred
|
||||
);
|
||||
|
||||
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
|
||||
|
||||
batch := conn.BeginBatch()
|
||||
batch.Queue(`update t set n=n+1 where id='b' returning *`,
|
||||
nil,
|
||||
nil,
|
||||
[]int16{pgx.BinaryFormatCode},
|
||||
)
|
||||
|
||||
err := batch.Send(context.Background(), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rows, err := batch.QueryResults()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var id string
|
||||
var n int32
|
||||
err = rows.Scan(&id, &n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err = batch.Close()
|
||||
if err == nil {
|
||||
t.Fatal("expected error 23505 but got none")
|
||||
}
|
||||
|
||||
if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" {
|
||||
t.Fatalf("expected error 23505, got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
@ -28,7 +28,11 @@ func NewChunkReader(r io.Reader) *ChunkReader {
|
||||
|
||||
func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) {
|
||||
if options.MinBufLen == 0 {
|
||||
options.MinBufLen = 4096
|
||||
// By historical reasons Postgres currently has 8KB send buffer inside,
|
||||
// so here we want to have at least the same size buffer.
|
||||
// @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134
|
||||
// @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru
|
||||
options.MinBufLen = 8192
|
||||
}
|
||||
|
||||
return &ChunkReader{
|
||||
|
406
conn.go
406
conn.go
@ -10,6 +10,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
@ -61,10 +62,50 @@ type NoticeHandler func(*Conn, *Notice)
|
||||
// DialFunc is a function that can be used to connect to a PostgreSQL server
|
||||
type DialFunc func(network, addr string) (net.Conn, error)
|
||||
|
||||
// TargetSessionType represents target session attrs configuration parameter.
|
||||
type TargetSessionType string
|
||||
|
||||
// Block enumerates available values for TargetSessionType.
|
||||
const (
|
||||
AnyTargetSession = "any"
|
||||
ReadWriteTargetSession = "read-write"
|
||||
)
|
||||
|
||||
func (t TargetSessionType) isValid() error {
|
||||
switch t {
|
||||
case "", AnyTargetSession, ReadWriteTargetSession:
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"")
|
||||
}
|
||||
|
||||
func (t TargetSessionType) writableRequired() bool {
|
||||
return t == ReadWriteTargetSession
|
||||
}
|
||||
|
||||
// ConnConfig contains all the options used to establish a connection.
|
||||
type ConnConfig struct {
|
||||
Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp)
|
||||
Port uint16 // default: 5432
|
||||
// Name of host to connect to. (e.g. localhost)
|
||||
// If a host name begins with a slash, it specifies Unix-domain communication
|
||||
// rather than TCP/IP communication; the value is the name of the directory
|
||||
// in which the socket file is stored. (e.g. /private/tmp)
|
||||
// The default behavior when host is not specified, or is empty, is to connect to localhost.
|
||||
//
|
||||
// A comma-separated list of host names is also accepted,
|
||||
// in which case each host name in the list is tried in order;
|
||||
// an empty item in the list selects the default behavior as explained above.
|
||||
// @see https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS
|
||||
Host string
|
||||
|
||||
// Port number to connect to at the server host,
|
||||
// or socket file name extension for Unix-domain connections.
|
||||
// An empty or zero value, specifies the default port number — 5432.
|
||||
//
|
||||
// If multiple hosts were given in the Host parameter, then
|
||||
// this parameter may specify a single port number to be used for all hosts,
|
||||
// or for those that haven't port explicitly defined.
|
||||
Port uint16
|
||||
Database string
|
||||
User string // default: OS user name
|
||||
Password string
|
||||
@ -89,22 +130,94 @@ type ConnConfig struct {
|
||||
// used by default. The same functionality can be controlled on a per query
|
||||
// basis by setting QueryExOptions.SimpleProtocol.
|
||||
PreferSimpleProtocol bool
|
||||
|
||||
// TargetSessionAttr allows to specify which servers are accepted for this connection.
|
||||
// "any", meaning that any kind of servers can be accepted. This is as well the default value.
|
||||
// "read-write", to disallow connections to read-only servers, hot standbys for example.
|
||||
// @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com
|
||||
// @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/
|
||||
//
|
||||
// The query SHOW transaction_read_only will be sent upon any successful connection;
|
||||
// if it returns on, the connection will be closed.
|
||||
// If multiple hosts were specified in the connection string,
|
||||
// any remaining servers will be tried just as if the connection attempt had failed.
|
||||
// The default value of this parameter, any, regards all connections as acceptable.
|
||||
TargetSessionAttrs TargetSessionType
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) networkAddress() (network, address string) {
|
||||
network = "tcp"
|
||||
address = fmt.Sprintf("%s:%d", cc.Host, cc.Port)
|
||||
// See if host is a valid path, if yes connect with a socket
|
||||
if _, err := os.Stat(cc.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
network = "unix"
|
||||
address = cc.Host
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10)
|
||||
}
|
||||
// hostAddr represents network end point defined as hostname or IP + port.
|
||||
type hostAddr struct {
|
||||
Host string
|
||||
Port uint16
|
||||
}
|
||||
|
||||
// Network returns the address's network name, "tcp".
|
||||
func (a *hostAddr) Network() string { return "tcp" }
|
||||
|
||||
// String implements net.Addr String method.
|
||||
func (a *hostAddr) String() string {
|
||||
if a == nil {
|
||||
return "<nil>"
|
||||
}
|
||||
|
||||
return network, address
|
||||
return net.JoinHostPort(a.Host, strconv.Itoa(int(a.Port)))
|
||||
}
|
||||
|
||||
func (cc *ConnConfig) networkAddresses() ([]net.Addr, error) {
|
||||
// See if host is a valid path, if yes connect with a unix socket
|
||||
if _, err := os.Stat(cc.Host); err == nil {
|
||||
// For backward compatibility accept socket file paths -- but directories are now preferred
|
||||
network := "unix"
|
||||
address := cc.Host
|
||||
|
||||
if !strings.Contains(address, "/.s.PGSQL.") {
|
||||
address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10)
|
||||
}
|
||||
|
||||
addrs := []net.Addr{
|
||||
&net.UnixAddr{Name: address, Net: network},
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
if cc.Host == "" {
|
||||
addrs := []net.Addr{
|
||||
&net.TCPAddr{Port: int(cc.Port)},
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
var addrs []net.Addr
|
||||
|
||||
hostports := strings.Split(cc.Host, ",")
|
||||
for i, hostport := range hostports {
|
||||
if hostport == "" {
|
||||
return nil, fmt.Errorf("multi-host part %d is empty, at least host or port must be defined", i)
|
||||
}
|
||||
|
||||
// It's not possible to use net.TCPAddr here, cuz host may be hostname.
|
||||
addr := hostAddr{
|
||||
Host: hostport,
|
||||
Port: cc.Port,
|
||||
}
|
||||
|
||||
pos := strings.IndexByte(hostport, ':')
|
||||
if pos != -1 {
|
||||
p, err := strconv.ParseUint(hostport[pos+1:], 10, 16)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("multi-host part %d (%s) has invalid port format", i, hostport)
|
||||
}
|
||||
|
||||
addr.Host = hostport[:pos]
|
||||
addr.Port = uint16(p)
|
||||
}
|
||||
|
||||
addrs = append(addrs, &addr)
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage.
|
||||
@ -145,6 +258,10 @@ type Conn struct {
|
||||
ConnInfo *pgtype.ConnInfo
|
||||
|
||||
frontend *pgproto3.Frontend
|
||||
|
||||
// In case of Multiple Hosts we need to know what addr was used to connect.
|
||||
// This address will be used to send a cancellation request.
|
||||
addr net.Addr
|
||||
}
|
||||
|
||||
// PreparedStatement is a description of a prepared statement
|
||||
@ -190,7 +307,8 @@ type Identifier []string
|
||||
func (ident Identifier) Sanitize() string {
|
||||
parts := make([]string, len(ident))
|
||||
for i := range ident {
|
||||
parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"`
|
||||
s := strings.Replace(ident[i], string([]byte{0}), "", -1)
|
||||
parts[i] = `"` + strings.Replace(s, `"`, `""`, -1) + `"`
|
||||
}
|
||||
return strings.Join(parts, ".")
|
||||
}
|
||||
@ -262,33 +380,123 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.config.TargetSessionAttrs.isValid(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.onNotice = config.OnNotice
|
||||
|
||||
network, address := c.config.networkAddress()
|
||||
if c.config.Dial == nil {
|
||||
d := defaultDialer()
|
||||
c.config.Dial = d.Dial
|
||||
}
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address})
|
||||
}
|
||||
err = c.connect(config, network, address, config.TLSConfig)
|
||||
if err != nil && config.UseFallbackTLS {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err})
|
||||
}
|
||||
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
||||
}
|
||||
|
||||
addrs, err := c.config.networkAddresses()
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
var errs []error
|
||||
for _, addr := range addrs {
|
||||
network, address := addr.Network(), addr.String()
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
}
|
||||
|
||||
err = c.connect(config, network, address, config.TLSConfig)
|
||||
if err != nil && config.UseFallbackTLS {
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{
|
||||
"err": err,
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
}
|
||||
err = c.connect(config, network, address, config.FallbackTLSConfig)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if c.shouldLog(LogLevelError) {
|
||||
c.log(LogLevelError, "connect failed", map[string]interface{}{
|
||||
"err": err,
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
}
|
||||
|
||||
// On any auth errors return immediately
|
||||
if pgErr, ok := err.(PgError); ok {
|
||||
switch pgErr.Code {
|
||||
// @see: https://www.postgresql.org/docs/current/errcodes-appendix.html
|
||||
case "28000", "28P01": // Invalid Authorization Specification
|
||||
return nil, pgErr
|
||||
}
|
||||
}
|
||||
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
|
||||
err = c.checkWritable()
|
||||
if err != nil {
|
||||
c.die(err)
|
||||
|
||||
if c.shouldLog(LogLevelInfo) {
|
||||
c.log(LogLevelInfo, "host is not writable", map[string]interface{}{
|
||||
"err": err,
|
||||
"network": network,
|
||||
"address": address,
|
||||
})
|
||||
}
|
||||
|
||||
errs = append(errs, err)
|
||||
continue
|
||||
}
|
||||
|
||||
c.addr = addr
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// To keep backwards compatibility, if specific error type expected.
|
||||
if len(errs) == 1 {
|
||||
return nil, errs[0]
|
||||
}
|
||||
|
||||
errmsgs := make([]string, len(errs))
|
||||
for i, err := range errs {
|
||||
errmsgs[i] = err.Error()
|
||||
}
|
||||
|
||||
return nil, errors.New(strings.Join(errmsgs, "; "))
|
||||
}
|
||||
|
||||
func (c *Conn) checkWritable() error {
|
||||
if !c.config.TargetSessionAttrs.writableRequired() {
|
||||
return nil
|
||||
}
|
||||
|
||||
var st string
|
||||
err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}).
|
||||
Scan(&st)
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state")
|
||||
}
|
||||
|
||||
switch st {
|
||||
case "on":
|
||||
return errors.New("writable transactions disabled by server")
|
||||
case "off":
|
||||
// If transaction_read_only = off, then connection is writable.
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("unexpected \"transaction_read_only\" status")
|
||||
}
|
||||
|
||||
func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) {
|
||||
@ -403,6 +611,7 @@ func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) {
|
||||
end
|
||||
from pg_type t
|
||||
left join pg_type base_type on t.typelem=base_type.oid
|
||||
left join pg_class base_cls ON base_type.typrelid = base_cls.oid
|
||||
left join pg_namespace nsp on t.typnamespace=nsp.oid
|
||||
where (
|
||||
t.typtype in('b', 'p', 'r', 'e')
|
||||
@ -750,6 +959,10 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig {
|
||||
|
||||
cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol
|
||||
|
||||
if other.TargetSessionAttrs != "" {
|
||||
cc.TargetSessionAttrs = other.TargetSessionAttrs
|
||||
}
|
||||
|
||||
cc.RuntimeParams = make(map[string]string)
|
||||
for k, v := range old.RuntimeParams {
|
||||
cc.RuntimeParams[k] = v
|
||||
@ -777,16 +990,26 @@ func ParseURI(uri string) (ConnConfig, error) {
|
||||
cp.Password, _ = url.User.Password()
|
||||
}
|
||||
|
||||
parts := strings.SplitN(url.Host, ":", 2)
|
||||
cp.Host = parts[0]
|
||||
if len(parts) == 2 {
|
||||
p, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
hasMuliHosts := strings.IndexByte(url.Host, ',') != -1
|
||||
if !hasMuliHosts {
|
||||
parts := strings.SplitN(url.Host, ":", 2)
|
||||
cp.Host = parts[0]
|
||||
if len(parts) == 2 {
|
||||
p, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
cp.Port = uint16(p)
|
||||
}
|
||||
cp.Port = uint16(p)
|
||||
} else {
|
||||
cp.Host = url.Host
|
||||
}
|
||||
|
||||
cp.Database = strings.TrimLeft(url.Path, "/")
|
||||
cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs"))
|
||||
if err := cp.TargetSessionAttrs.isValid(); err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" {
|
||||
timeout, err := strconv.ParseInt(pgtimeout, 10, 64)
|
||||
@ -810,11 +1033,12 @@ func ParseURI(uri string) (ConnConfig, error) {
|
||||
}
|
||||
|
||||
ignoreKeys := map[string]struct{}{
|
||||
"connect_timeout": {},
|
||||
"sslcert": {},
|
||||
"sslkey": {},
|
||||
"sslmode": {},
|
||||
"sslrootcert": {},
|
||||
"connect_timeout": {},
|
||||
"sslcert": {},
|
||||
"sslkey": {},
|
||||
"sslmode": {},
|
||||
"sslrootcert": {},
|
||||
"target_session_attrs": {},
|
||||
}
|
||||
|
||||
cp.RuntimeParams = make(map[string]string)
|
||||
@ -834,6 +1058,7 @@ func ParseURI(uri string) (ConnConfig, error) {
|
||||
if cp.Password == "" {
|
||||
pgpass(&cp)
|
||||
}
|
||||
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
@ -859,6 +1084,7 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||
|
||||
cp.RuntimeParams = make(map[string]string)
|
||||
|
||||
var hostval, portval string
|
||||
for _, b := range m {
|
||||
switch b[1] {
|
||||
case "user":
|
||||
@ -866,13 +1092,9 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||
case "password":
|
||||
cp.Password = b[2]
|
||||
case "host":
|
||||
cp.Host = b[2]
|
||||
hostval = b[2]
|
||||
case "port":
|
||||
p, err := strconv.ParseUint(b[2], 10, 16)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
cp.Port = uint16(p)
|
||||
portval = b[2]
|
||||
case "dbname":
|
||||
cp.Database = b[2]
|
||||
case "sslmode":
|
||||
@ -891,23 +1113,93 @@ func ParseDSN(s string) (ConnConfig, error) {
|
||||
d := defaultDialer()
|
||||
d.Timeout = time.Duration(timeout) * time.Second
|
||||
cp.Dial = d.Dial
|
||||
case "target_session_attrs":
|
||||
cp.TargetSessionAttrs = TargetSessionType(b[2])
|
||||
if err := cp.TargetSessionAttrs.isValid(); err != nil {
|
||||
return cp, err
|
||||
}
|
||||
default:
|
||||
cp.RuntimeParams[b[1]] = b[2]
|
||||
}
|
||||
}
|
||||
|
||||
err := configTLS(tlsArgs, &cp)
|
||||
host, port, err := parseHostPortDSN(hostval, portval)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
cp.Host, cp.Port = host, port
|
||||
|
||||
err = configTLS(tlsArgs, &cp)
|
||||
if err != nil {
|
||||
return cp, err
|
||||
}
|
||||
|
||||
if cp.Password == "" {
|
||||
pgpass(&cp)
|
||||
}
|
||||
|
||||
return cp, nil
|
||||
}
|
||||
|
||||
// ParseConnectionString parses either a URI or a DSN connection string.
|
||||
// see ParseURI and ParseDSN for details.
|
||||
func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) {
|
||||
if portval == "" {
|
||||
return hostval, 0, nil
|
||||
}
|
||||
|
||||
hosts := strings.Split(hostval, ",")
|
||||
ports := strings.Split(portval, ",")
|
||||
|
||||
if len(ports) == 1 {
|
||||
port, err := parsePort(portval)
|
||||
if err != nil {
|
||||
return "", 0, errors.Errorf("invalid port: %v", err)
|
||||
}
|
||||
|
||||
return hostval, port, nil
|
||||
}
|
||||
|
||||
if len(hosts) != len(ports) {
|
||||
return "", 0, errors.New("the number of hosts and ports must be the same")
|
||||
}
|
||||
|
||||
hostports := make([]string, len(hosts))
|
||||
for i, host := range hosts {
|
||||
hostports[i] = host + ":" + ports[i]
|
||||
}
|
||||
|
||||
return strings.Join(hostports, ","), 0, nil
|
||||
}
|
||||
|
||||
func parsePort(s string) (uint16, error) {
|
||||
port, err := strconv.ParseUint(s, 10, 16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if port < 1 || port > math.MaxUint16 {
|
||||
return 0, errors.New("outside range")
|
||||
}
|
||||
return uint16(port), nil
|
||||
}
|
||||
|
||||
// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig.
|
||||
//
|
||||
// # Example DSN
|
||||
// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca
|
||||
//
|
||||
// ParseConnectionString supports specifying multiple hosts in similar manner to libpq.
|
||||
// Host and port may include comma separated values that will be tried in order.
|
||||
// This can be used as part of a high availability system.
|
||||
// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information.
|
||||
//
|
||||
// # Example URL
|
||||
// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb
|
||||
//
|
||||
// # Example DSN
|
||||
// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca
|
||||
func ParseConnectionString(s string) (ConnConfig, error) {
|
||||
if u, err := url.Parse(s); err == nil && u.Scheme != "" {
|
||||
return ParseURI(s)
|
||||
@ -932,6 +1224,8 @@ func ParseConnectionString(s string) (ConnConfig, error) {
|
||||
// PGSSLROOTCERT
|
||||
// PGAPPNAME
|
||||
// PGCONNECT_TIMEOUT
|
||||
// PGTARGETSESSIONATTRS
|
||||
// @see: https://www.postgresql.org/docs/10/libpq-envars.html
|
||||
//
|
||||
// Important TLS Security Notes:
|
||||
// ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This
|
||||
@ -977,6 +1271,11 @@ func ParseEnvLibpq() (ConnConfig, error) {
|
||||
}
|
||||
}
|
||||
|
||||
cc.TargetSessionAttrs = TargetSessionType(os.Getenv("PGTARGETSESSIONATTRS"))
|
||||
if err := cc.TargetSessionAttrs.isValid(); err != nil {
|
||||
return cc, err
|
||||
}
|
||||
|
||||
tlsArgs := configTLSArgs{
|
||||
sslMode: os.Getenv("PGSSLMODE"),
|
||||
sslKey: os.Getenv("PGSSLKEY"),
|
||||
@ -1692,8 +1991,7 @@ func quoteIdentifier(s string) string {
|
||||
}
|
||||
|
||||
func doCancel(c *Conn) error {
|
||||
network, address := c.config.networkAddress()
|
||||
cancelConn, err := c.config.Dial(network, address)
|
||||
cancelConn, err := c.config.Dial(c.addr.Network(), c.addr.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
// "go/build"
|
||||
// "io/ioutil"
|
||||
// "path"
|
||||
// "net"
|
||||
// "time"
|
||||
|
||||
"github.com/jackc/pgx"
|
||||
)
|
||||
@ -14,6 +16,7 @@ import (
|
||||
var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
|
||||
// To skip tests for specific connection / authentication types set that connection param to nil
|
||||
var multihostConnConfig *pgx.ConnConfig = nil
|
||||
var tcpConnConfig *pgx.ConnConfig = nil
|
||||
var unixSocketConnConfig *pgx.ConnConfig = nil
|
||||
var md5ConnConfig *pgx.ConnConfig = nil
|
||||
@ -24,6 +27,7 @@ var customDialerConnConfig *pgx.ConnConfig = nil
|
||||
var replicationConnConfig *pgx.ConnConfig = nil
|
||||
var cratedbConnConfig *pgx.ConnConfig = nil
|
||||
|
||||
// var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial}
|
||||
// var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
// var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"}
|
||||
// var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
|
@ -5,9 +5,12 @@ import (
|
||||
"github.com/jackc/pgx"
|
||||
"os"
|
||||
"strconv"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial}
|
||||
var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"}
|
||||
var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"}
|
||||
|
24
conn_pool.go
24
conn_pool.go
@ -110,6 +110,25 @@ func (p *ConnPool) Acquire() (*Conn, error) {
|
||||
return c, err
|
||||
}
|
||||
|
||||
func (p *ConnPool) AcquireEx(ctx context.Context) (*Conn, error) {
|
||||
var deadline *time.Time
|
||||
|
||||
if p.acquireTimeout > 0 {
|
||||
tmp := time.Now().Add(p.acquireTimeout)
|
||||
deadline = &tmp
|
||||
}
|
||||
|
||||
ctxDeadline, ok := ctx.Deadline()
|
||||
if ok && (deadline == nil || ctxDeadline.Before(*deadline)) {
|
||||
deadline = &ctxDeadline
|
||||
}
|
||||
|
||||
p.cond.L.Lock()
|
||||
c, err := p.acquire(deadline)
|
||||
p.cond.L.Unlock()
|
||||
return c, err
|
||||
}
|
||||
|
||||
// deadlinePassed returns true if the given deadline has passed.
|
||||
func (p *ConnPool) deadlinePassed(deadline *time.Time) bool {
|
||||
return deadline != nil && time.Now().After(*deadline)
|
||||
@ -319,7 +338,7 @@ func (p *ConnPool) createConnection() (*Conn, error) {
|
||||
func (p *ConnPool) createConnectionUnlocked() (*Conn, error) {
|
||||
p.inProgressConnects++
|
||||
p.cond.L.Unlock()
|
||||
c, err := Connect(p.config)
|
||||
c, err := connect(p.config, p.connInfo.DeepCopy())
|
||||
p.cond.L.Lock()
|
||||
p.inProgressConnects--
|
||||
|
||||
@ -341,7 +360,8 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) {
|
||||
}
|
||||
|
||||
for _, ps := range p.preparedStatements {
|
||||
if _, err := c.Prepare(ps.Name, ps.SQL); err != nil {
|
||||
opts := &PrepareExOptions{ParameterOIDs: ps.ParameterOIDs}
|
||||
if _, err := c.PrepareEx(context.Background(), ps.Name, ps.SQL, opts); err != nil {
|
||||
c.die(err)
|
||||
return nil, err
|
||||
}
|
||||
|
@ -45,6 +45,12 @@ func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error)
|
||||
return c, time.Since(startTime), err
|
||||
}
|
||||
|
||||
func acquireExWithTimeTaken(pool *pgx.ConnPool, ctx context.Context) (*pgx.Conn, time.Duration, error) {
|
||||
startTime := time.Now()
|
||||
c, err := pool.AcquireEx(ctx)
|
||||
return c, time.Since(startTime), err
|
||||
}
|
||||
|
||||
func TestNewConnPool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -315,6 +321,144 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithAcquireExContextTimeoutSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := pgx.ConnPoolConfig{
|
||||
ConnConfig: *defaultConnConfig,
|
||||
MaxConnections: 1,
|
||||
}
|
||||
|
||||
pool, err := pgx.NewConnPool(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create connection pool: %v", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
// Consume all connections ...
|
||||
allConnections := acquireAllConnections(t, pool, config.MaxConnections)
|
||||
defer releaseAllConnections(pool, allConnections)
|
||||
|
||||
ctxTimeout := 2 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
|
||||
defer cancel()
|
||||
|
||||
// ... then try to consume 1 more. It should fail after a short timeout.
|
||||
_, timeTaken, err := acquireExWithTimeTaken(pool, ctx)
|
||||
|
||||
if err == nil || err != pgx.ErrAcquireTimeout {
|
||||
t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err)
|
||||
}
|
||||
if timeTaken < ctxTimeout {
|
||||
t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithAcquireExPoolTimeoutLower(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connAllocTimeout := 2 * time.Second
|
||||
config := pgx.ConnPoolConfig{
|
||||
ConnConfig: *defaultConnConfig,
|
||||
MaxConnections: 1,
|
||||
AcquireTimeout: connAllocTimeout,
|
||||
}
|
||||
|
||||
pool, err := pgx.NewConnPool(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create connection pool: %v", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
// Consume all connections ...
|
||||
allConnections := acquireAllConnections(t, pool, config.MaxConnections)
|
||||
defer releaseAllConnections(pool, allConnections)
|
||||
|
||||
ctxTimeout := 5 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
|
||||
defer cancel()
|
||||
|
||||
// ... then try to consume 1 more. It should fail after a short timeout.
|
||||
_, timeTaken, err := acquireExWithTimeTaken(pool, ctx)
|
||||
|
||||
if err == nil || err != pgx.ErrAcquireTimeout {
|
||||
t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err)
|
||||
}
|
||||
if timeTaken < connAllocTimeout {
|
||||
t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken)
|
||||
}
|
||||
if timeTaken > ctxTimeout {
|
||||
t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", ctxTimeout, timeTaken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithAcquireExPoolTimeoutHigher(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
connAllocTimeout := 5 * time.Second
|
||||
config := pgx.ConnPoolConfig{
|
||||
ConnConfig: *defaultConnConfig,
|
||||
MaxConnections: 1,
|
||||
AcquireTimeout: connAllocTimeout,
|
||||
}
|
||||
|
||||
pool, err := pgx.NewConnPool(config)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to create connection pool: %v", err)
|
||||
}
|
||||
defer pool.Close()
|
||||
|
||||
// Consume all connections ...
|
||||
allConnections := acquireAllConnections(t, pool, config.MaxConnections)
|
||||
defer releaseAllConnections(pool, allConnections)
|
||||
|
||||
ctxTimeout := 2 * time.Second
|
||||
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout)
|
||||
defer cancel()
|
||||
|
||||
// ... then try to consume 1 more. It should fail after a short timeout.
|
||||
_, timeTaken, err := acquireExWithTimeTaken(pool, ctx)
|
||||
|
||||
if err == nil || err != pgx.ErrAcquireTimeout {
|
||||
t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err)
|
||||
}
|
||||
if timeTaken < ctxTimeout {
|
||||
t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken)
|
||||
}
|
||||
if timeTaken > connAllocTimeout {
|
||||
t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", connAllocTimeout, timeTaken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolWithoutAcquireExTimeoutSet(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
maxConnections := 1
|
||||
pool := createConnPool(t, maxConnections)
|
||||
defer pool.Close()
|
||||
|
||||
// Consume all connections ...
|
||||
allConnections := acquireAllConnections(t, pool, maxConnections)
|
||||
|
||||
// ... then try to consume 1 more. It should hang forever.
|
||||
// To unblock it we release the previously taken connection in a goroutine.
|
||||
stopDeadWaitTimeout := 5 * time.Second
|
||||
timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() {
|
||||
releaseAllConnections(pool, allConnections)
|
||||
})
|
||||
defer timer.Stop()
|
||||
|
||||
conn, timeTaken, err := acquireExWithTimeTaken(pool, context.Background())
|
||||
if err == nil {
|
||||
pool.Release(conn)
|
||||
} else {
|
||||
t.Fatalf("Expected error to be nil, instead it was '%v'", err)
|
||||
}
|
||||
if timeTaken < stopDeadWaitTimeout {
|
||||
t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolErrClosedPool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
223
conn_test.go
223
conn_test.go
@ -84,6 +84,105 @@ func TestConnect(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithMultiHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if multihostConnConfig == nil {
|
||||
t.Skip("Skipping due to undefined multihostConnConfig")
|
||||
}
|
||||
|
||||
conn, err := pgx.Connect(*multihostConnConfig)
|
||||
if err != nil {
|
||||
t.Fatalf("Unable to establish connection: %v", err)
|
||||
}
|
||||
|
||||
if _, present := conn.RuntimeParams["server_version"]; !present {
|
||||
t.Error("Runtime parameters not stored")
|
||||
}
|
||||
|
||||
if conn.PID() == 0 {
|
||||
t.Error("Backend PID not stored")
|
||||
}
|
||||
|
||||
var currentDB string
|
||||
err = conn.QueryRow("select current_database()").Scan(¤tDB)
|
||||
if err != nil {
|
||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
if currentDB != defaultConnConfig.Database {
|
||||
t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database)
|
||||
}
|
||||
|
||||
var user string
|
||||
err = conn.QueryRow("select current_user").Scan(&user)
|
||||
if err != nil {
|
||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
if user != defaultConnConfig.User {
|
||||
t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User)
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal("Unable to close connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithMultiHostWritable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if multihostConnConfig == nil {
|
||||
t.Skip("Skipping due to undefined multihostConnConfig")
|
||||
}
|
||||
|
||||
connConfig := *multihostConnConfig
|
||||
connConfig.TargetSessionAttrs = pgx.ReadWriteTargetSession
|
||||
|
||||
conn := mustConnect(t, connConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
if _, present := conn.RuntimeParams["server_version"]; !present {
|
||||
t.Error("Runtime parameters not stored")
|
||||
}
|
||||
|
||||
if conn.PID() == 0 {
|
||||
t.Error("Backend PID not stored")
|
||||
}
|
||||
|
||||
var currentDB string
|
||||
err := conn.QueryRow("select current_database()").Scan(¤tDB)
|
||||
if err != nil {
|
||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
if currentDB != defaultConnConfig.Database {
|
||||
t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database)
|
||||
}
|
||||
|
||||
var user string
|
||||
err = conn.QueryRow("select current_user").Scan(&user)
|
||||
if err != nil {
|
||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
if user != defaultConnConfig.User {
|
||||
t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User)
|
||||
}
|
||||
|
||||
var st string
|
||||
err = conn.QueryRow("SHOW transaction_read_only").Scan(&st)
|
||||
if err != nil {
|
||||
t.Fatalf("QueryRow Scan unexpectedly failed: %v", err)
|
||||
}
|
||||
|
||||
if st == "on" {
|
||||
t.Error("Connection is not writable")
|
||||
}
|
||||
|
||||
err = conn.Close()
|
||||
if err != nil {
|
||||
t.Fatal("Unable to close connection")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectWithUnixSocketDirectory(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -521,6 +620,38 @@ func TestParseURI(t *testing.T) {
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Password: "secret",
|
||||
Host: "foo.example.com:5432,bar.example.com:5432",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Host: "localhost,10.10.20.30",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{
|
||||
"application_name": "pgxtest",
|
||||
},
|
||||
TargetSessionAttrs: pgx.ReadWriteTargetSession,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
@ -647,6 +778,50 @@ func TestParseDSN(t *testing.T) {
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Host: "localhost1,localhost2",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial,
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Host: "100.200.220.50:5432,localhost43:5433",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
},
|
||||
},
|
||||
{
|
||||
url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write",
|
||||
connParams: pgx.ConnConfig{
|
||||
User: "jack",
|
||||
Host: "localhost",
|
||||
Database: "mydb",
|
||||
TLSConfig: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
},
|
||||
UseFallbackTLS: true,
|
||||
FallbackTLSConfig: nil,
|
||||
RuntimeParams: map[string]string{},
|
||||
TargetSessionAttrs: pgx.ReadWriteTargetSession,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
@ -1195,6 +1370,32 @@ func TestExecFailure(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table t (
|
||||
id text primary key,
|
||||
n int not null,
|
||||
unique (n) deferrable initially deferred
|
||||
);
|
||||
|
||||
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
|
||||
|
||||
_, err := conn.Exec(`update t set n=n+1 where id='b'`)
|
||||
if err == nil {
|
||||
t.Fatal("expected error 23505 but got none")
|
||||
}
|
||||
|
||||
if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" {
|
||||
t.Fatalf("expected error 23505, got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestExecFailureWithArguments(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -2142,6 +2343,24 @@ func TestSetLogLevel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentifierSanitizeNullSentToServer(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
ident := pgx.Identifier{"foo" + string([]byte{0}) + "bar"}
|
||||
|
||||
var n int64
|
||||
err := conn.QueryRow(`select 1 as ` + ident.Sanitize()).Scan(&n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != 1 {
|
||||
t.Fatal("unexpected n")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdentifierSanitize(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@ -2169,6 +2388,10 @@ func TestIdentifierSanitize(t *testing.T) {
|
||||
ident: pgx.Identifier{`you should " not do this`, `please don't`},
|
||||
expected: `"you should "" not do this"."please don't"`,
|
||||
},
|
||||
{
|
||||
ident: pgx.Identifier{`you should ` + string([]byte{0}) + `not do this`},
|
||||
expected: `"you should not do this"`,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
|
2
doc.go
2
doc.go
@ -225,7 +225,7 @@ notification.
|
||||
return nil
|
||||
}
|
||||
|
||||
if notification, err := conn.WaitForNotification(time.Second); err != nil {
|
||||
if notification, err := conn.WaitForNotification(context.TODO()); err != nil {
|
||||
// do something with notification
|
||||
}
|
||||
|
||||
|
@ -210,6 +210,7 @@ func PgxInitSteps() []Step {
|
||||
end
|
||||
from pg_type t
|
||||
left join pg_type base_type on t.typelem=base_type.oid
|
||||
left join pg_class base_cls ON base_type.typrelid = base_cls.oid
|
||||
left join pg_namespace nsp on t.typnamespace=nsp.oid
|
||||
where (
|
||||
t.typtype in('b', 'p', 'r', 'e')
|
||||
|
@ -149,7 +149,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) {
|
||||
switch refVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
if refVal.IsNil() {
|
||||
return time.Time{}, false
|
||||
return nil, false
|
||||
}
|
||||
convVal := refVal.Elem().Interface()
|
||||
return convVal, true
|
||||
@ -160,7 +160,28 @@ func underlyingTimeType(val interface{}) (interface{}, bool) {
|
||||
return refVal.Convert(timeType).Interface(), true
|
||||
}
|
||||
|
||||
return time.Time{}, false
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// underlyingUUIDType gets the underlying type that can be converted to [16]byte
|
||||
func underlyingUUIDType(val interface{}) (interface{}, bool) {
|
||||
refVal := reflect.ValueOf(val)
|
||||
|
||||
switch refVal.Kind() {
|
||||
case reflect.Ptr:
|
||||
if refVal.IsNil() {
|
||||
return time.Time{}, false
|
||||
}
|
||||
convVal := refVal.Elem().Interface()
|
||||
return convVal, true
|
||||
}
|
||||
|
||||
uuidType := reflect.TypeOf([16]byte{})
|
||||
if refVal.Type().ConvertibleTo(uuidType) {
|
||||
return refVal.Convert(uuidType).Interface(), true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// underlyingSliceType gets the underlying slice type
|
||||
@ -401,6 +422,14 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
if dstVal.Kind() == reflect.Array {
|
||||
if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok {
|
||||
baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType))
|
||||
nextDst := dstPtr.Convert(baseArrayType)
|
||||
return nextDst.Interface(), dstPtr.Type() != nextDst.Type()
|
||||
}
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
|
@ -39,7 +39,7 @@ func (dst *UUID) Set(src interface{}) error {
|
||||
}
|
||||
*dst = UUID{Bytes: uuid, Status: Present}
|
||||
default:
|
||||
if originalSrc, ok := underlyingPtrType(src); ok {
|
||||
if originalSrc, ok := underlyingUUIDType(src); ok {
|
||||
return dst.Set(originalSrc)
|
||||
}
|
||||
return errors.Errorf("cannot convert %v to UUID", value)
|
||||
|
@ -15,6 +15,8 @@ func TestUUIDTranscode(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
type SomeUUIDType [16]byte
|
||||
|
||||
func TestUUIDSet(t *testing.T) {
|
||||
successfulTests := []struct {
|
||||
source interface{}
|
||||
@ -32,6 +34,10 @@ func TestUUIDSet(t *testing.T) {
|
||||
source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
|
||||
result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
|
||||
result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present},
|
||||
},
|
||||
{
|
||||
source: ([]byte)(nil),
|
||||
result: pgtype.UUID{Status: pgtype.Null},
|
||||
@ -86,6 +92,21 @@ func TestUUIDAssignTo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}
|
||||
var dst SomeUUIDType
|
||||
expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}
|
||||
|
||||
err := src.AssignTo(&dst)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
if dst != expected {
|
||||
t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}
|
||||
var dst string
|
||||
|
19
query.go
19
query.go
@ -69,6 +69,25 @@ func (rows *Rows) Close() {
|
||||
return
|
||||
}
|
||||
|
||||
// If there is no error and a batch operation is not in progress read until we get the ReadyForQuery message or the
|
||||
// ErrorResponse. This is necessary to detect a deferred constraint violation where the ErrorResponse is sent after
|
||||
// CommandComplete.
|
||||
if rows.err == nil && rows.batch == nil && rows.conn.pendingReadyForQueryCount == 1 {
|
||||
for rows.conn.pendingReadyForQueryCount > 0 {
|
||||
msg, err := rows.conn.rxMsg()
|
||||
if err != nil {
|
||||
rows.err = err
|
||||
break
|
||||
}
|
||||
|
||||
err = rows.conn.processContextFreeMsg(msg)
|
||||
if err != nil {
|
||||
rows.err = err
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if rows.unlockConn {
|
||||
rows.conn.unlock()
|
||||
rows.unlockConn = false
|
||||
|
@ -14,7 +14,7 @@ import (
|
||||
"github.com/jackc/pgx"
|
||||
"github.com/jackc/pgx/pgtype"
|
||||
satori "github.com/jackc/pgx/pgtype/ext/satori-uuid"
|
||||
"github.com/satori/go.uuid"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
)
|
||||
|
||||
@ -424,6 +424,47 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
// https://github.com/jackc/pgx/issues/570
|
||||
func TestConnQueryDeferredError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conn := mustConnect(t, *defaultConnConfig)
|
||||
defer closeConn(t, conn)
|
||||
|
||||
mustExec(t, conn, `create temporary table t (
|
||||
id text primary key,
|
||||
n int not null,
|
||||
unique (n) deferrable initially deferred
|
||||
);
|
||||
|
||||
insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`)
|
||||
|
||||
rows, err := conn.Query(`update t set n=n+1 where id='b' returning *`)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id string
|
||||
var n int32
|
||||
err = rows.Scan(&id, &n)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() == nil {
|
||||
t.Fatal("expected error 23505 but got none")
|
||||
}
|
||||
|
||||
if err, ok := rows.Err().(pgx.PgError); !ok || err.Code != "23505" {
|
||||
t.Fatalf("expected error 23505, got %v", err)
|
||||
}
|
||||
|
||||
ensureConnValid(t, conn)
|
||||
}
|
||||
|
||||
func TestQueryEncodeError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
@ -43,8 +43,8 @@
|
||||
//
|
||||
// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard
|
||||
// database/sql.DB connection pool. This allows operations that must be
|
||||
// performed on a single connection, but should not be run in a transaction or
|
||||
// to use pgx specific functionality.
|
||||
// performed on a single connection without running in a transaction, and it
|
||||
// supports operations that use pgx specific functionality.
|
||||
//
|
||||
// conn, err := stdlib.AcquireConn(db)
|
||||
// if err != nil {
|
||||
@ -277,7 +277,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e
|
||||
pgxOpts.IsoLevel = pgx.ReadUncommitted
|
||||
case sql.LevelReadCommitted:
|
||||
pgxOpts.IsoLevel = pgx.ReadCommitted
|
||||
case sql.LevelSnapshot:
|
||||
case sql.LevelRepeatableRead, sql.LevelSnapshot:
|
||||
pgxOpts.IsoLevel = pgx.RepeatableRead
|
||||
case sql.LevelSerializable:
|
||||
pgxOpts.IsoLevel = pgx.Serializable
|
||||
|
@ -629,6 +629,7 @@ func TestConnBeginTxIsolation(t *testing.T) {
|
||||
{sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel},
|
||||
{sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"},
|
||||
{sqlIso: sql.LevelReadCommitted, pgIso: "read committed"},
|
||||
{sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"},
|
||||
{sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"},
|
||||
{sqlIso: sql.LevelSerializable, pgIso: "serializable"},
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user