diff --git a/README.md b/README.md index b7051f65..0a4cacc3 100644 --- a/README.md +++ b/README.md @@ -17,7 +17,8 @@ if err != nil { ## v4 Coming Soon -This is the current stable v3 version. v4 is currently is in prelease status. Consider using [v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. +This is the current stable v3 version. v4 is currently is in release candidate status. Consider using +[v4](https://github.com/jackc/pgx/tree/v4) for new development or test upgrading existing applications. ## Features diff --git a/batch.go b/batch.go index 4b624387..7f5422dc 100644 --- a/batch.go +++ b/batch.go @@ -135,7 +135,7 @@ func (b *Batch) Send(ctx context.Context, txOptions *TxOptions) error { _, err = b.conn.conn.Write(buf) if err != nil { - b.conn.die(err) + b.die(err) return err } @@ -268,6 +268,23 @@ func (b *Batch) Close() (err error) { } } + for b.conn.pendingReadyForQueryCount > 0 { + msg, err := b.conn.rxMsg() + if err != nil { + return err + } + + switch msg := msg.(type) { + case *pgproto3.ErrorResponse: + return b.conn.rxErrorResponse(msg) + default: + err = b.conn.processContextFreeMsg(msg) + if err != nil { + return err + } + } + } + if err = b.conn.ensureConnectionReadyForQuery(); err != nil { return err } @@ -281,10 +298,13 @@ func (b *Batch) die(err error) { } b.err = err - b.conn.die(err) + if b.conn != nil { + err = b.conn.termContext(err) + b.conn.die(err) - if b.conn != nil && b.connPool != nil { - b.connPool.Release(b.conn) + if b.connPool != nil { + b.connPool.Release(b.conn) + } } } diff --git a/batch_test.go b/batch_test.go index 61bbe357..d0e26875 100644 --- a/batch_test.go +++ b/batch_test.go @@ -701,3 +701,55 @@ func TestTxBeginBatchRollback(t *testing.T) { ensureConnValid(t, conn) } + +func TestConnBeginBatchDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred + ); + + insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + batch := conn.BeginBatch() + batch.Queue(`update t set n=n+1 where id='b' returning *`, + nil, + nil, + []int16{pgx.BinaryFormatCode}, + ) + + err := batch.Send(context.Background(), nil) + if err != nil { + t.Fatal(err) + } + + rows, err := batch.QueryResults() + if err != nil { + t.Error(err) + } + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + err = batch.Close() + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} diff --git a/chunkreader/chunkreader.go b/chunkreader/chunkreader.go index f8d437b2..5c36292d 100644 --- a/chunkreader/chunkreader.go +++ b/chunkreader/chunkreader.go @@ -28,7 +28,11 @@ func NewChunkReader(r io.Reader) *ChunkReader { func NewChunkReaderEx(r io.Reader, options Options) (*ChunkReader, error) { if options.MinBufLen == 0 { - options.MinBufLen = 4096 + // By historical reasons Postgres currently has 8KB send buffer inside, + // so here we want to have at least the same size buffer. + // @see https://github.com/postgres/postgres/blob/249d64999615802752940e017ee5166e726bc7cd/src/backend/libpq/pqcomm.c#L134 + // @see https://www.postgresql.org/message-id/0cdc5485-cb3c-5e16-4a46-e3b2f7a41322%40ya.ru + options.MinBufLen = 8192 } return &ChunkReader{ diff --git a/conn.go b/conn.go index b613707e..b4469af9 100644 --- a/conn.go +++ b/conn.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "net" "net/url" "os" @@ -61,10 +62,50 @@ type NoticeHandler func(*Conn, *Notice) // DialFunc is a function that can be used to connect to a PostgreSQL server type DialFunc func(network, addr string) (net.Conn, error) +// TargetSessionType represents target session attrs configuration parameter. +type TargetSessionType string + +// Block enumerates available values for TargetSessionType. +const ( + AnyTargetSession = "any" + ReadWriteTargetSession = "read-write" +) + +func (t TargetSessionType) isValid() error { + switch t { + case "", AnyTargetSession, ReadWriteTargetSession: + return nil + } + + return errors.New("invalid value for target_session_attrs, expected \"any\" or \"read-write\"") +} + +func (t TargetSessionType) writableRequired() bool { + return t == ReadWriteTargetSession +} + // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { - Host string // host (e.g. localhost) or path to unix domain socket directory (e.g. /private/tmp) - Port uint16 // default: 5432 + // Name of host to connect to. (e.g. localhost) + // If a host name begins with a slash, it specifies Unix-domain communication + // rather than TCP/IP communication; the value is the name of the directory + // in which the socket file is stored. (e.g. /private/tmp) + // The default behavior when host is not specified, or is empty, is to connect to localhost. + // + // A comma-separated list of host names is also accepted, + // in which case each host name in the list is tried in order; + // an empty item in the list selects the default behavior as explained above. + // @see https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS + Host string + + // Port number to connect to at the server host, + // or socket file name extension for Unix-domain connections. + // An empty or zero value, specifies the default port number — 5432. + // + // If multiple hosts were given in the Host parameter, then + // this parameter may specify a single port number to be used for all hosts, + // or for those that haven't port explicitly defined. + Port uint16 Database string User string // default: OS user name Password string @@ -89,22 +130,94 @@ type ConnConfig struct { // used by default. The same functionality can be controlled on a per query // basis by setting QueryExOptions.SimpleProtocol. PreferSimpleProtocol bool + + // TargetSessionAttr allows to specify which servers are accepted for this connection. + // "any", meaning that any kind of servers can be accepted. This is as well the default value. + // "read-write", to disallow connections to read-only servers, hot standbys for example. + // @see https://www.postgresql.org/message-id/CAD__OuhqPRGpcsfwPHz_PDqAGkoqS1UvnUnOnAB-LBWBW=wu4A@mail.gmail.com + // @see https://paquier.xyz/postgresql-2/postgres-10-libpq-read-write/ + // + // The query SHOW transaction_read_only will be sent upon any successful connection; + // if it returns on, the connection will be closed. + // If multiple hosts were specified in the connection string, + // any remaining servers will be tried just as if the connection attempt had failed. + // The default value of this parameter, any, regards all connections as acceptable. + TargetSessionAttrs TargetSessionType } -func (cc *ConnConfig) networkAddress() (network, address string) { - network = "tcp" - address = fmt.Sprintf("%s:%d", cc.Host, cc.Port) - // See if host is a valid path, if yes connect with a socket - if _, err := os.Stat(cc.Host); err == nil { - // For backward compatibility accept socket file paths -- but directories are now preferred - network = "unix" - address = cc.Host - if !strings.Contains(address, "/.s.PGSQL.") { - address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatInt(int64(cc.Port), 10) - } +// hostAddr represents network end point defined as hostname or IP + port. +type hostAddr struct { + Host string + Port uint16 +} + +// Network returns the address's network name, "tcp". +func (a *hostAddr) Network() string { return "tcp" } + +// String implements net.Addr String method. +func (a *hostAddr) String() string { + if a == nil { + return "" } - return network, address + return net.JoinHostPort(a.Host, strconv.Itoa(int(a.Port))) +} + +func (cc *ConnConfig) networkAddresses() ([]net.Addr, error) { + // See if host is a valid path, if yes connect with a unix socket + if _, err := os.Stat(cc.Host); err == nil { + // For backward compatibility accept socket file paths -- but directories are now preferred + network := "unix" + address := cc.Host + + if !strings.Contains(address, "/.s.PGSQL.") { + address = filepath.Join(address, ".s.PGSQL.") + strconv.FormatUint(uint64(cc.Port), 10) + } + + addrs := []net.Addr{ + &net.UnixAddr{Name: address, Net: network}, + } + + return addrs, nil + } + + if cc.Host == "" { + addrs := []net.Addr{ + &net.TCPAddr{Port: int(cc.Port)}, + } + + return addrs, nil + } + + var addrs []net.Addr + + hostports := strings.Split(cc.Host, ",") + for i, hostport := range hostports { + if hostport == "" { + return nil, fmt.Errorf("multi-host part %d is empty, at least host or port must be defined", i) + } + + // It's not possible to use net.TCPAddr here, cuz host may be hostname. + addr := hostAddr{ + Host: hostport, + Port: cc.Port, + } + + pos := strings.IndexByte(hostport, ':') + if pos != -1 { + p, err := strconv.ParseUint(hostport[pos+1:], 10, 16) + if err != nil { + return nil, fmt.Errorf("multi-host part %d (%s) has invalid port format", i, hostport) + } + + addr.Host = hostport[:pos] + addr.Port = uint16(p) + } + + addrs = append(addrs, &addr) + } + + return addrs, nil } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. @@ -145,6 +258,10 @@ type Conn struct { ConnInfo *pgtype.ConnInfo frontend *pgproto3.Frontend + + // In case of Multiple Hosts we need to know what addr was used to connect. + // This address will be used to send a cancellation request. + addr net.Addr } // PreparedStatement is a description of a prepared statement @@ -190,7 +307,8 @@ type Identifier []string func (ident Identifier) Sanitize() string { parts := make([]string, len(ident)) for i := range ident { - parts[i] = `"` + strings.Replace(ident[i], `"`, `""`, -1) + `"` + s := strings.Replace(ident[i], string([]byte{0}), "", -1) + parts[i] = `"` + strings.Replace(s, `"`, `""`, -1) + `"` } return strings.Join(parts, ".") } @@ -262,33 +380,123 @@ func connect(config ConnConfig, connInfo *pgtype.ConnInfo) (c *Conn, err error) } } + if err := c.config.TargetSessionAttrs.isValid(); err != nil { + return nil, err + } + c.onNotice = config.OnNotice - network, address := c.config.networkAddress() if c.config.Dial == nil { d := defaultDialer() c.config.Dial = d.Dial } - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"network": network, "address": address}) - } - err = c.connect(config, network, address, config.TLSConfig) - if err != nil && config.UseFallbackTLS { - if c.shouldLog(LogLevelInfo) { - c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{"err": err}) - } - err = c.connect(config, network, address, config.FallbackTLSConfig) - } - + addrs, err := c.config.networkAddresses() if err != nil { - if c.shouldLog(LogLevelError) { - c.log(LogLevelError, "connect failed", map[string]interface{}{"err": err}) - } return nil, err } - return c, nil + var errs []error + for _, addr := range addrs { + network, address := addr.Network(), addr.String() + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{ + "network": network, + "address": address, + }) + } + + err = c.connect(config, network, address, config.TLSConfig) + if err != nil && config.UseFallbackTLS { + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "connect with TLSConfig failed, trying FallbackTLSConfig", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + err = c.connect(config, network, address, config.FallbackTLSConfig) + } + + if err != nil { + if c.shouldLog(LogLevelError) { + c.log(LogLevelError, "connect failed", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + + // On any auth errors return immediately + if pgErr, ok := err.(PgError); ok { + switch pgErr.Code { + // @see: https://www.postgresql.org/docs/current/errcodes-appendix.html + case "28000", "28P01": // Invalid Authorization Specification + return nil, pgErr + } + } + + errs = append(errs, err) + continue + } + + err = c.checkWritable() + if err != nil { + c.die(err) + + if c.shouldLog(LogLevelInfo) { + c.log(LogLevelInfo, "host is not writable", map[string]interface{}{ + "err": err, + "network": network, + "address": address, + }) + } + + errs = append(errs, err) + continue + } + + c.addr = addr + + return c, nil + } + + // To keep backwards compatibility, if specific error type expected. + if len(errs) == 1 { + return nil, errs[0] + } + + errmsgs := make([]string, len(errs)) + for i, err := range errs { + errmsgs[i] = err.Error() + } + + return nil, errors.New(strings.Join(errmsgs, "; ")) +} + +func (c *Conn) checkWritable() error { + if !c.config.TargetSessionAttrs.writableRequired() { + return nil + } + + var st string + err := c.QueryRowEx(context.Background(), "SHOW transaction_read_only", &QueryExOptions{SimpleProtocol: true}). + Scan(&st) + + if err != nil { + return errors.Wrap(err, "failed to fetch \"transaction_read_only\" state") + } + + switch st { + case "on": + return errors.New("writable transactions disabled by server") + case "off": + // If transaction_read_only = off, then connection is writable. + return nil + } + + return errors.New("unexpected \"transaction_read_only\" status") } func (c *Conn) connect(config ConnConfig, network, address string, tlsConfig *tls.Config) (err error) { @@ -403,6 +611,7 @@ func initPostgresql(c *Conn) (*pgtype.ConnInfo, error) { end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_class base_cls ON base_type.typrelid = base_cls.oid left join pg_namespace nsp on t.typnamespace=nsp.oid where ( t.typtype in('b', 'p', 'r', 'e') @@ -750,6 +959,10 @@ func (old ConnConfig) Merge(other ConnConfig) ConnConfig { cc.PreferSimpleProtocol = old.PreferSimpleProtocol || other.PreferSimpleProtocol + if other.TargetSessionAttrs != "" { + cc.TargetSessionAttrs = other.TargetSessionAttrs + } + cc.RuntimeParams = make(map[string]string) for k, v := range old.RuntimeParams { cc.RuntimeParams[k] = v @@ -777,16 +990,26 @@ func ParseURI(uri string) (ConnConfig, error) { cp.Password, _ = url.User.Password() } - parts := strings.SplitN(url.Host, ":", 2) - cp.Host = parts[0] - if len(parts) == 2 { - p, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return cp, err + hasMuliHosts := strings.IndexByte(url.Host, ',') != -1 + if !hasMuliHosts { + parts := strings.SplitN(url.Host, ":", 2) + cp.Host = parts[0] + if len(parts) == 2 { + p, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return cp, err + } + cp.Port = uint16(p) } - cp.Port = uint16(p) + } else { + cp.Host = url.Host } + cp.Database = strings.TrimLeft(url.Path, "/") + cp.TargetSessionAttrs = TargetSessionType(url.Query().Get("target_session_attrs")) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } if pgtimeout := url.Query().Get("connect_timeout"); pgtimeout != "" { timeout, err := strconv.ParseInt(pgtimeout, 10, 64) @@ -810,11 +1033,12 @@ func ParseURI(uri string) (ConnConfig, error) { } ignoreKeys := map[string]struct{}{ - "connect_timeout": {}, - "sslcert": {}, - "sslkey": {}, - "sslmode": {}, - "sslrootcert": {}, + "connect_timeout": {}, + "sslcert": {}, + "sslkey": {}, + "sslmode": {}, + "sslrootcert": {}, + "target_session_attrs": {}, } cp.RuntimeParams = make(map[string]string) @@ -834,6 +1058,7 @@ func ParseURI(uri string) (ConnConfig, error) { if cp.Password == "" { pgpass(&cp) } + return cp, nil } @@ -859,6 +1084,7 @@ func ParseDSN(s string) (ConnConfig, error) { cp.RuntimeParams = make(map[string]string) + var hostval, portval string for _, b := range m { switch b[1] { case "user": @@ -866,13 +1092,9 @@ func ParseDSN(s string) (ConnConfig, error) { case "password": cp.Password = b[2] case "host": - cp.Host = b[2] + hostval = b[2] case "port": - p, err := strconv.ParseUint(b[2], 10, 16) - if err != nil { - return cp, err - } - cp.Port = uint16(p) + portval = b[2] case "dbname": cp.Database = b[2] case "sslmode": @@ -891,23 +1113,93 @@ func ParseDSN(s string) (ConnConfig, error) { d := defaultDialer() d.Timeout = time.Duration(timeout) * time.Second cp.Dial = d.Dial + case "target_session_attrs": + cp.TargetSessionAttrs = TargetSessionType(b[2]) + if err := cp.TargetSessionAttrs.isValid(); err != nil { + return cp, err + } default: cp.RuntimeParams[b[1]] = b[2] } } - err := configTLS(tlsArgs, &cp) + host, port, err := parseHostPortDSN(hostval, portval) if err != nil { return cp, err } + + cp.Host, cp.Port = host, port + + err = configTLS(tlsArgs, &cp) + if err != nil { + return cp, err + } + if cp.Password == "" { pgpass(&cp) } + return cp, nil } -// ParseConnectionString parses either a URI or a DSN connection string. -// see ParseURI and ParseDSN for details. +func parseHostPortDSN(hostval, portval string) (host string, port uint16, err error) { + if portval == "" { + return hostval, 0, nil + } + + hosts := strings.Split(hostval, ",") + ports := strings.Split(portval, ",") + + if len(ports) == 1 { + port, err := parsePort(portval) + if err != nil { + return "", 0, errors.Errorf("invalid port: %v", err) + } + + return hostval, port, nil + } + + if len(hosts) != len(ports) { + return "", 0, errors.New("the number of hosts and ports must be the same") + } + + hostports := make([]string, len(hosts)) + for i, host := range hosts { + hostports[i] = host + ":" + ports[i] + } + + return strings.Join(hostports, ","), 0, nil +} + +func parsePort(s string) (uint16, error) { + port, err := strconv.ParseUint(s, 10, 16) + if err != nil { + return 0, err + } + if port < 1 || port > math.MaxUint16 { + return 0, errors.New("outside range") + } + return uint16(port), nil +} + +// ParseConnectionString parses either a URI or a DSN connection string and builds ConnConfig. +// +// # Example DSN +// user=jack password=secret host=pg.example.com port=5432 dbname=mydb sslmode=verify-ca +// +// # Example URL +// postgres://jack:secret@pg.example.com:5432/mydb?sslmode=verify-ca +// +// ParseConnectionString supports specifying multiple hosts in similar manner to libpq. +// Host and port may include comma separated values that will be tried in order. +// This can be used as part of a high availability system. +// See https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS for more information. +// +// # Example URL +// postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb +// +// # Example DSN +// user=jack password=secret host=host1,host2,host3 port=5432,5433,5434 dbname=mydb sslmode=verify-ca func ParseConnectionString(s string) (ConnConfig, error) { if u, err := url.Parse(s); err == nil && u.Scheme != "" { return ParseURI(s) @@ -932,6 +1224,8 @@ func ParseConnectionString(s string) (ConnConfig, error) { // PGSSLROOTCERT // PGAPPNAME // PGCONNECT_TIMEOUT +// PGTARGETSESSIONATTRS +// @see: https://www.postgresql.org/docs/10/libpq-envars.html // // Important TLS Security Notes: // ParseEnvLibpq tries to match libpq behavior with regard to PGSSLMODE. This @@ -977,6 +1271,11 @@ func ParseEnvLibpq() (ConnConfig, error) { } } + cc.TargetSessionAttrs = TargetSessionType(os.Getenv("PGTARGETSESSIONATTRS")) + if err := cc.TargetSessionAttrs.isValid(); err != nil { + return cc, err + } + tlsArgs := configTLSArgs{ sslMode: os.Getenv("PGSSLMODE"), sslKey: os.Getenv("PGSSLKEY"), @@ -1692,8 +1991,7 @@ func quoteIdentifier(s string) string { } func doCancel(c *Conn) error { - network, address := c.config.networkAddress() - cancelConn, err := c.config.Dial(network, address) + cancelConn, err := c.config.Dial(c.addr.Network(), c.addr.String()) if err != nil { return err } diff --git a/conn_config_test.go.example b/conn_config_test.go.example index 096e1354..2ca84ac3 100644 --- a/conn_config_test.go.example +++ b/conn_config_test.go.example @@ -7,6 +7,8 @@ import ( // "go/build" // "io/ioutil" // "path" + // "net" + // "time" "github.com/jackc/pgx" ) @@ -14,6 +16,7 @@ import ( var defaultConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // To skip tests for specific connection / authentication types set that connection param to nil +var multihostConnConfig *pgx.ConnConfig = nil var tcpConnConfig *pgx.ConnConfig = nil var unixSocketConnConfig *pgx.ConnConfig = nil var md5ConnConfig *pgx.ConnConfig = nil @@ -24,6 +27,7 @@ var customDialerConnConfig *pgx.ConnConfig = nil var replicationConnConfig *pgx.ConnConfig = nil var cratedbConnConfig *pgx.ConnConfig = nil +// var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} // var tcpConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} // var unixSocketConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "/private/tmp", User: "pgx_none", Database: "pgx_test"} // var md5ConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_config_test.go.travis b/conn_config_test.go.travis index cf29a743..fbfb5252 100644 --- a/conn_config_test.go.travis +++ b/conn_config_test.go.travis @@ -5,9 +5,12 @@ import ( "github.com/jackc/pgx" "os" "strconv" + "net" + "time" ) var defaultConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} +var multihostConnConfig *pgx.ConnConfig = &pgx.ConnConfig{Host: "2.2.2.2:1,127.0.0.1,4.2.4.2", User: "pgx_md5", Password: "secret", Database: "pgx_test", Dial: (&net.Dialer{KeepAlive: 5 * time.Minute, Timeout: 100 * time.Millisecond}).Dial} var tcpConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} var unixSocketConnConfig = &pgx.ConnConfig{Host: "/var/run/postgresql", User: "postgres", Database: "pgx_test"} var md5ConnConfig = &pgx.ConnConfig{Host: "127.0.0.1", User: "pgx_md5", Password: "secret", Database: "pgx_test"} diff --git a/conn_pool.go b/conn_pool.go index 47a0b391..95e1b015 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -110,6 +110,25 @@ func (p *ConnPool) Acquire() (*Conn, error) { return c, err } +func (p *ConnPool) AcquireEx(ctx context.Context) (*Conn, error) { + var deadline *time.Time + + if p.acquireTimeout > 0 { + tmp := time.Now().Add(p.acquireTimeout) + deadline = &tmp + } + + ctxDeadline, ok := ctx.Deadline() + if ok && (deadline == nil || ctxDeadline.Before(*deadline)) { + deadline = &ctxDeadline + } + + p.cond.L.Lock() + c, err := p.acquire(deadline) + p.cond.L.Unlock() + return c, err +} + // deadlinePassed returns true if the given deadline has passed. func (p *ConnPool) deadlinePassed(deadline *time.Time) bool { return deadline != nil && time.Now().After(*deadline) @@ -319,7 +338,7 @@ func (p *ConnPool) createConnection() (*Conn, error) { func (p *ConnPool) createConnectionUnlocked() (*Conn, error) { p.inProgressConnects++ p.cond.L.Unlock() - c, err := Connect(p.config) + c, err := connect(p.config, p.connInfo.DeepCopy()) p.cond.L.Lock() p.inProgressConnects-- @@ -341,7 +360,8 @@ func (p *ConnPool) afterConnectionCreated(c *Conn) (*Conn, error) { } for _, ps := range p.preparedStatements { - if _, err := c.Prepare(ps.Name, ps.SQL); err != nil { + opts := &PrepareExOptions{ParameterOIDs: ps.ParameterOIDs} + if _, err := c.PrepareEx(context.Background(), ps.Name, ps.SQL, opts); err != nil { c.die(err) return nil, err } diff --git a/conn_pool_test.go b/conn_pool_test.go index 84a74aed..83bdf1fd 100644 --- a/conn_pool_test.go +++ b/conn_pool_test.go @@ -45,6 +45,12 @@ func acquireWithTimeTaken(pool *pgx.ConnPool) (*pgx.Conn, time.Duration, error) return c, time.Since(startTime), err } +func acquireExWithTimeTaken(pool *pgx.ConnPool, ctx context.Context) (*pgx.Conn, time.Duration, error) { + startTime := time.Now() + c, err := pool.AcquireEx(ctx) + return c, time.Since(startTime), err +} + func TestNewConnPool(t *testing.T) { t.Parallel() @@ -315,6 +321,144 @@ func TestPoolWithoutAcquireTimeoutSet(t *testing.T) { } } +func TestPoolWithAcquireExContextTimeoutSet(t *testing.T) { + t.Parallel() + + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutLower(t *testing.T) { + t.Parallel() + + connAllocTimeout := 2 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 5 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < connAllocTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", connAllocTimeout, timeTaken) + } + if timeTaken > ctxTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", ctxTimeout, timeTaken) + } +} + +func TestPoolWithAcquireExPoolTimeoutHigher(t *testing.T) { + t.Parallel() + + connAllocTimeout := 5 * time.Second + config := pgx.ConnPoolConfig{ + ConnConfig: *defaultConnConfig, + MaxConnections: 1, + AcquireTimeout: connAllocTimeout, + } + + pool, err := pgx.NewConnPool(config) + if err != nil { + t.Fatalf("Unable to create connection pool: %v", err) + } + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, config.MaxConnections) + defer releaseAllConnections(pool, allConnections) + + ctxTimeout := 2 * time.Second + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeout) + defer cancel() + + // ... then try to consume 1 more. It should fail after a short timeout. + _, timeTaken, err := acquireExWithTimeTaken(pool, ctx) + + if err == nil || err != pgx.ErrAcquireTimeout { + t.Fatalf("Expected error to be pgx.ErrAcquireTimeout, instead it was '%v'", err) + } + if timeTaken < ctxTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", ctxTimeout, timeTaken) + } + if timeTaken > connAllocTimeout { + t.Fatalf("Expected connection allocation time to be less than %v, instead it was '%v'", connAllocTimeout, timeTaken) + } +} + +func TestPoolWithoutAcquireExTimeoutSet(t *testing.T) { + t.Parallel() + + maxConnections := 1 + pool := createConnPool(t, maxConnections) + defer pool.Close() + + // Consume all connections ... + allConnections := acquireAllConnections(t, pool, maxConnections) + + // ... then try to consume 1 more. It should hang forever. + // To unblock it we release the previously taken connection in a goroutine. + stopDeadWaitTimeout := 5 * time.Second + timer := time.AfterFunc(stopDeadWaitTimeout+100*time.Millisecond, func() { + releaseAllConnections(pool, allConnections) + }) + defer timer.Stop() + + conn, timeTaken, err := acquireExWithTimeTaken(pool, context.Background()) + if err == nil { + pool.Release(conn) + } else { + t.Fatalf("Expected error to be nil, instead it was '%v'", err) + } + if timeTaken < stopDeadWaitTimeout { + t.Fatalf("Expected connection allocation time to be at least %v, instead it was '%v'", stopDeadWaitTimeout, timeTaken) + } +} + func TestPoolErrClosedPool(t *testing.T) { t.Parallel() diff --git a/conn_test.go b/conn_test.go index 6ca00c6d..c6ce50cc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -84,6 +84,105 @@ func TestConnect(t *testing.T) { } } +func TestConnectWithMultiHost(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + conn, err := pgx.Connect(*multihostConnConfig) + if err != nil { + t.Fatalf("Unable to establish connection: %v", err) + } + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err = conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + +func TestConnectWithMultiHostWritable(t *testing.T) { + t.Parallel() + + if multihostConnConfig == nil { + t.Skip("Skipping due to undefined multihostConnConfig") + } + + connConfig := *multihostConnConfig + connConfig.TargetSessionAttrs = pgx.ReadWriteTargetSession + + conn := mustConnect(t, connConfig) + defer closeConn(t, conn) + + if _, present := conn.RuntimeParams["server_version"]; !present { + t.Error("Runtime parameters not stored") + } + + if conn.PID() == 0 { + t.Error("Backend PID not stored") + } + + var currentDB string + err := conn.QueryRow("select current_database()").Scan(¤tDB) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if currentDB != defaultConnConfig.Database { + t.Errorf("Did not connect to specified database (%v)", defaultConnConfig.Database) + } + + var user string + err = conn.QueryRow("select current_user").Scan(&user) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + if user != defaultConnConfig.User { + t.Errorf("Did not connect as specified user (%v)", defaultConnConfig.User) + } + + var st string + err = conn.QueryRow("SHOW transaction_read_only").Scan(&st) + if err != nil { + t.Fatalf("QueryRow Scan unexpectedly failed: %v", err) + } + + if st == "on" { + t.Error("Connection is not writable") + } + + err = conn.Close() + if err != nil { + t.Fatal("Unable to close connection") + } +} + func TestConnectWithUnixSocketDirectory(t *testing.T) { t.Parallel() @@ -521,6 +620,38 @@ func TestParseURI(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "postgres://jack:secret@foo.example.com:5432,bar.example.com:5432/mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Password: "secret", + Host: "foo.example.com:5432,bar.example.com:5432", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "postgres://jack@localhost,10.10.20.30/mydb?application_name=pgxtest&target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost,10.10.20.30", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{ + "application_name": "pgxtest", + }, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { @@ -647,6 +778,50 @@ func TestParseDSN(t *testing.T) { RuntimeParams: map[string]string{}, }, }, + { + url: "user=jack host=localhost1,localhost2 dbname=mydb connect_timeout=10", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost1,localhost2", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Dial: (&net.Dialer{Timeout: 10 * time.Second, KeepAlive: 5 * time.Minute}).Dial, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=100.200.220.50,localhost43 port=5432,5433 dbname=mydb", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "100.200.220.50:5432,localhost43:5433", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + }, + }, + { + url: "user=jack host=localhost dbname=mydb target_session_attrs=read-write", + connParams: pgx.ConnConfig{ + User: "jack", + Host: "localhost", + Database: "mydb", + TLSConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + UseFallbackTLS: true, + FallbackTLSConfig: nil, + RuntimeParams: map[string]string{}, + TargetSessionAttrs: pgx.ReadWriteTargetSession, + }, + }, } for i, tt := range tests { @@ -1195,6 +1370,32 @@ func TestExecFailure(t *testing.T) { } } +func TestExecDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + _, err := conn.Exec(`update t set n=n+1 where id='b'`) + if err == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := err.(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestExecFailureWithArguments(t *testing.T) { t.Parallel() @@ -2142,6 +2343,24 @@ func TestSetLogLevel(t *testing.T) { } } +func TestIdentifierSanitizeNullSentToServer(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + ident := pgx.Identifier{"foo" + string([]byte{0}) + "bar"} + + var n int64 + err := conn.QueryRow(`select 1 as ` + ident.Sanitize()).Scan(&n) + if err != nil { + t.Fatal(err) + } + if n != 1 { + t.Fatal("unexpected n") + } +} + func TestIdentifierSanitize(t *testing.T) { t.Parallel() @@ -2169,6 +2388,10 @@ func TestIdentifierSanitize(t *testing.T) { ident: pgx.Identifier{`you should " not do this`, `please don't`}, expected: `"you should "" not do this"."please don't"`, }, + { + ident: pgx.Identifier{`you should ` + string([]byte{0}) + `not do this`}, + expected: `"you should not do this"`, + }, } for i, tt := range tests { diff --git a/doc.go b/doc.go index 5808c09d..0c2b35d3 100644 --- a/doc.go +++ b/doc.go @@ -225,7 +225,7 @@ notification. return nil } - if notification, err := conn.WaitForNotification(time.Second); err != nil { + if notification, err := conn.WaitForNotification(context.TODO()); err != nil { // do something with notification } diff --git a/pgmock/pgmock.go b/pgmock/pgmock.go index 4a64b506..97093968 100644 --- a/pgmock/pgmock.go +++ b/pgmock/pgmock.go @@ -210,6 +210,7 @@ func PgxInitSteps() []Step { end from pg_type t left join pg_type base_type on t.typelem=base_type.oid +left join pg_class base_cls ON base_type.typrelid = base_cls.oid left join pg_namespace nsp on t.typnamespace=nsp.oid where ( t.typtype in('b', 'p', 'r', 'e') diff --git a/pgtype/convert.go b/pgtype/convert.go index 5dfb738e..029e3d48 100644 --- a/pgtype/convert.go +++ b/pgtype/convert.go @@ -149,7 +149,7 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { switch refVal.Kind() { case reflect.Ptr: if refVal.IsNil() { - return time.Time{}, false + return nil, false } convVal := refVal.Elem().Interface() return convVal, true @@ -160,7 +160,28 @@ func underlyingTimeType(val interface{}) (interface{}, bool) { return refVal.Convert(timeType).Interface(), true } - return time.Time{}, false + return nil, false +} + +// underlyingUUIDType gets the underlying type that can be converted to [16]byte +func underlyingUUIDType(val interface{}) (interface{}, bool) { + refVal := reflect.ValueOf(val) + + switch refVal.Kind() { + case reflect.Ptr: + if refVal.IsNil() { + return time.Time{}, false + } + convVal := refVal.Elem().Interface() + return convVal, true + } + + uuidType := reflect.TypeOf([16]byte{}) + if refVal.Type().ConvertibleTo(uuidType) { + return refVal.Convert(uuidType).Interface(), true + } + + return nil, false } // underlyingSliceType gets the underlying slice type @@ -401,6 +422,14 @@ func GetAssignToDstType(dst interface{}) (interface{}, bool) { } } + if dstVal.Kind() == reflect.Array { + if baseElemType, ok := kindTypes[dstVal.Type().Elem().Kind()]; ok { + baseArrayType := reflect.PtrTo(reflect.ArrayOf(dstVal.Len(), baseElemType)) + nextDst := dstPtr.Convert(baseArrayType) + return nextDst.Interface(), dstPtr.Type() != nextDst.Type() + } + } + return nil, false } diff --git a/pgtype/uuid.go b/pgtype/uuid.go index 5e1eead5..8d33d8f8 100644 --- a/pgtype/uuid.go +++ b/pgtype/uuid.go @@ -39,7 +39,7 @@ func (dst *UUID) Set(src interface{}) error { } *dst = UUID{Bytes: uuid, Status: Present} default: - if originalSrc, ok := underlyingPtrType(src); ok { + if originalSrc, ok := underlyingUUIDType(src); ok { return dst.Set(originalSrc) } return errors.Errorf("cannot convert %v to UUID", value) diff --git a/pgtype/uuid_test.go b/pgtype/uuid_test.go index 162d999f..1eddeda1 100644 --- a/pgtype/uuid_test.go +++ b/pgtype/uuid_test.go @@ -15,6 +15,8 @@ func TestUUIDTranscode(t *testing.T) { }) } +type SomeUUIDType [16]byte + func TestUUIDSet(t *testing.T) { successfulTests := []struct { source interface{} @@ -32,6 +34,10 @@ func TestUUIDSet(t *testing.T) { source: []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, }, + { + source: SomeUUIDType{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + result: pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present}, + }, { source: ([]byte)(nil), result: pgtype.UUID{Status: pgtype.Null}, @@ -86,6 +92,21 @@ func TestUUIDAssignTo(t *testing.T) { } } + { + src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} + var dst SomeUUIDType + expected := [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + + err := src.AssignTo(&dst) + if err != nil { + t.Error(err) + } + + if dst != expected { + t.Errorf("expected %v to assign %v, but result was %v", src, expected, dst) + } + } + { src := pgtype.UUID{Bytes: [16]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, Status: pgtype.Present} var dst string diff --git a/query.go b/query.go index 5c6cbf7f..bf4ec561 100644 --- a/query.go +++ b/query.go @@ -69,6 +69,25 @@ func (rows *Rows) Close() { return } + // If there is no error and a batch operation is not in progress read until we get the ReadyForQuery message or the + // ErrorResponse. This is necessary to detect a deferred constraint violation where the ErrorResponse is sent after + // CommandComplete. + if rows.err == nil && rows.batch == nil && rows.conn.pendingReadyForQueryCount == 1 { + for rows.conn.pendingReadyForQueryCount > 0 { + msg, err := rows.conn.rxMsg() + if err != nil { + rows.err = err + break + } + + err = rows.conn.processContextFreeMsg(msg) + if err != nil { + rows.err = err + break + } + } + } + if rows.unlockConn { rows.conn.unlock() rows.unlockConn = false diff --git a/query_test.go b/query_test.go index 06b7b8b7..ea1fd66e 100644 --- a/query_test.go +++ b/query_test.go @@ -14,7 +14,7 @@ import ( "github.com/jackc/pgx" "github.com/jackc/pgx/pgtype" satori "github.com/jackc/pgx/pgtype/ext/satori-uuid" - "github.com/satori/go.uuid" + uuid "github.com/satori/go.uuid" "github.com/shopspring/decimal" ) @@ -424,6 +424,47 @@ func TestConnQueryErrorWhileReturningRows(t *testing.T) { } +// https://github.com/jackc/pgx/issues/570 +func TestConnQueryDeferredError(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + mustExec(t, conn, `create temporary table t ( + id text primary key, + n int not null, + unique (n) deferrable initially deferred +); + +insert into t (id, n) values ('a', 1), ('b', 2), ('c', 3);`) + + rows, err := conn.Query(`update t set n=n+1 where id='b' returning *`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var id string + var n int32 + err = rows.Scan(&id, &n) + if err != nil { + t.Fatal(err) + } + } + + if rows.Err() == nil { + t.Fatal("expected error 23505 but got none") + } + + if err, ok := rows.Err().(pgx.PgError); !ok || err.Code != "23505" { + t.Fatalf("expected error 23505, got %v", err) + } + + ensureConnValid(t, conn) +} + func TestQueryEncodeError(t *testing.T) { t.Parallel() diff --git a/stdlib/sql.go b/stdlib/sql.go index ec5933f3..3cd2d941 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -43,8 +43,8 @@ // // AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard // database/sql.DB connection pool. This allows operations that must be -// performed on a single connection, but should not be run in a transaction or -// to use pgx specific functionality. +// performed on a single connection without running in a transaction, and it +// supports operations that use pgx specific functionality. // // conn, err := stdlib.AcquireConn(db) // if err != nil { @@ -277,7 +277,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, e pgxOpts.IsoLevel = pgx.ReadUncommitted case sql.LevelReadCommitted: pgxOpts.IsoLevel = pgx.ReadCommitted - case sql.LevelSnapshot: + case sql.LevelRepeatableRead, sql.LevelSnapshot: pgxOpts.IsoLevel = pgx.RepeatableRead case sql.LevelSerializable: pgxOpts.IsoLevel = pgx.Serializable diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index cf2b91b1..895ee583 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -629,6 +629,7 @@ func TestConnBeginTxIsolation(t *testing.T) { {sqlIso: sql.LevelDefault, pgIso: defaultIsoLevel}, {sqlIso: sql.LevelReadUncommitted, pgIso: "read uncommitted"}, {sqlIso: sql.LevelReadCommitted, pgIso: "read committed"}, + {sqlIso: sql.LevelRepeatableRead, pgIso: "repeatable read"}, {sqlIso: sql.LevelSnapshot, pgIso: "repeatable read"}, {sqlIso: sql.LevelSerializable, pgIso: "serializable"}, }