pgx/conn.go

777 lines
22 KiB
Go

package pgx
import (
"context"
"strconv"
"strings"
"time"
errors "golang.org/x/xerrors"
"github.com/jackc/pgconn"
"github.com/jackc/pgconn/stmtcache"
"github.com/jackc/pgproto3/v2"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v4/internal/sanitize"
)
// connStatus* enumerate connection status values. The values are spelled out explicitly and match the order
// uninitialized, closed, idle, busy.
// NOTE(review): none of these constants is referenced in the visible part of this file.
const (
	connStatusUninitialized = 0
	connStatusClosed        = 1
	connStatusIdle          = 2
	connStatusBusy          = 3
)
// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and
// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic.
type ConnConfig struct {
	pgconn.Config

	// Logger receives connection log messages. A nil Logger disables logging.
	Logger Logger

	// LogLevel is the most verbose level that is passed to Logger. ParseConfig defaults it to LogLevelInfo.
	LogLevel LogLevel

	// BuildStatementCache creates the stmtcache.Cache implementation for connections created with this config. Set
	// to nil to disable automatic prepared statements.
	BuildStatementCache BuildStatementCacheFunc

	// PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended
	// protocol. This can improve performance due to being able to use the binary format. It also does not rely on client
	// side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement)
	// and may be incompatible with proxies such as PgBouncer. Setting PreferSimpleProtocol causes the simple protocol to
	// be used by default. The same functionality can be controlled on a per query basis by passing a
	// QuerySimpleProtocol value as a leading argument to Query or Exec.
	PreferSimpleProtocol bool

	createdByParseConfig bool // Used to enforce created by ParseConfig rule.
}

// BuildStatementCacheFunc is a function that can be used to create a stmtcache.Cache implementation for connection.
// It is called once per connection, after the underlying *pgconn.PgConn has been established.
type BuildStatementCacheFunc func(conn *pgconn.PgConn) stmtcache.Cache
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access
// to multiple database connections from multiple goroutines.
type Conn struct {
	// pgConn is the low-level connection established by pgconn.ConnectConfig in connect.
	pgConn *pgconn.PgConn

	config *ConnConfig // config used when establishing this connection

	// preparedStatements holds statements created with Prepare under a non-empty name, keyed by that name.
	preparedStatements map[string]*pgconn.StatementDescription

	// stmtcache is the automatic statement cache built by ConnConfig.BuildStatementCache; nil when disabled.
	stmtcache stmtcache.Cache

	// logger and logLevel are copied from the ConnConfig; a nil logger disables logging (see shouldLog).
	logger Logger

	logLevel LogLevel

	// notifications queues notifications received by bufferNotifications until WaitForNotification returns them.
	notifications []*pgconn.Notification

	// doneChan and closedChan are created in connect.
	// NOTE(review): neither is read in the visible part of this file — possibly unused.
	doneChan chan struct{}

	closedChan chan error

	// ConnInfo is the type registry used to encode query parameters and choose result formats.
	ConnInfo *pgtype.ConnInfo

	// wbuf is allocated in connect with capacity 1024.
	// NOTE(review): not used in the visible part of this file.
	wbuf []byte

	// preallocatedRows is a slab of connRows handed out one at a time by getRows.
	preallocatedRows []connRows

	// eqb is a scratch builder for extended-protocol parameters and result formats, Reset before each use.
	eqb extendedQueryBuilder
}
// Identifier is a PostgreSQL identifier or name. Identifiers can be composed of
// multiple parts such as ["schema", "table"] or ["table", "column"].
type Identifier []string

// Sanitize returns a sanitized string safe for SQL interpolation. Each part has NUL bytes removed, has embedded
// double quotes doubled, and is wrapped in double quotes; the quoted parts are then joined with ".".
func (ident Identifier) Sanitize() string {
	quoted := make([]string, 0, len(ident))
	for _, part := range ident {
		withoutNUL := strings.Replace(part, "\x00", "", -1)
		escaped := strings.Replace(withoutNUL, `"`, `""`, -1)
		quoted = append(quoted, `"`+escaped+`"`)
	}
	return strings.Join(quoted, ".")
}
var (
	// ErrNoRows occurs when rows are expected but none are returned.
	ErrNoRows = errors.New("no rows in result set")

	// ErrInvalidLogLevel occurs on attempt to set an invalid log level.
	ErrInvalidLogLevel = errors.New("invalid log level")
)
// Connect establishes a connection with a PostgreSQL server with a connection string. The string is parsed with
// ParseConfig, so it accepts everything ParseConfig does. See pgconn.Connect for details.
func Connect(ctx context.Context, connString string) (*Conn, error) {
	connConfig, err := ParseConfig(connString)
	if err != nil {
		return nil, err
	}
	return connect(ctx, connConfig)
}

// ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct. connConfig must have
// been created by ParseConfig; passing a manually initialized ConnConfig causes a panic.
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
	return connect(ctx, connConfig)
}
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig
// does. In addition, it accepts the following options:
//
// statement_cache_capacity
// The maximum size of the automatic statement cache. Set to 0 to disable automatic statement caching (any value
// <= 0 disables it). Default: 512.
//
// statement_cache_mode
// Possible values: "prepare" and "describe". "prepare" will create prepared statements on the PostgreSQL server.
// "describe" will use the anonymous prepared statement to describe a statement without creating a statement on the
// server. "describe" is primarily useful when the environment does not allow prepared statements such as when
// running a connection pooler like PgBouncer. Default: "prepare"
//
// An error is returned if pgconn.ParseConfig fails, if statement_cache_capacity is not a 32-bit integer, or if
// statement_cache_mode is not one of the values above.
func ParseConfig(connString string) (*ConnConfig, error) {
	config, err := pgconn.ParseConfig(connString)
	if err != nil {
		return nil, err
	}

	var buildStatementCache BuildStatementCacheFunc
	statementCacheCapacity := 512
	statementCacheMode := stmtcache.ModePrepare

	// The pgx-specific options are consumed here and deleted from RuntimeParams so that pgconn does not see them
	// (presumably it would otherwise forward them to the server as run-time parameters).
	if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok {
		delete(config.RuntimeParams, "statement_cache_capacity")
		n, err := strconv.ParseInt(s, 10, 32)
		if err != nil {
			return nil, errors.Errorf("cannot parse statement_cache_capacity: %w", err)
		}
		statementCacheCapacity = int(n)
	}

	if s, ok := config.RuntimeParams["statement_cache_mode"]; ok {
		delete(config.RuntimeParams, "statement_cache_mode")
		switch s {
		case "prepare":
			statementCacheMode = stmtcache.ModePrepare
		case "describe":
			statementCacheMode = stmtcache.ModeDescribe
		default:
			return nil, errors.Errorf("invalid statement_cache_mode: %s", s)
		}
	}

	if statementCacheCapacity > 0 {
		buildStatementCache = func(conn *pgconn.PgConn) stmtcache.Cache {
			return stmtcache.New(conn, statementCacheMode, statementCacheCapacity)
		}
	}

	connConfig := &ConnConfig{
		Config:               *config,
		createdByParseConfig: true,
		LogLevel:             LogLevelInfo,
		BuildStatementCache:  buildStatementCache,
	}

	return connConfig, nil
}
// connect establishes a Conn from config, which must have been created by ParseConfig (connect panics otherwise).
// It returns the connected Conn, or nil and the error reported by pgconn.ConnectConfig. The caller's config is not
// modified.
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
	// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
	// zero values.
	if !config.createdByParseConfig {
		panic("config must be created by ParseConfig")
	}
	originalConfig := config

	// Installing the notification hook below writes to config.Config. Do that on a shallow copy so the caller's
	// *ConnConfig is left untouched. Otherwise every later connection built from the same config would find
	// OnNotification already set (bound to this Conn) and deliver its notifications into this Conn's buffer.
	configCopy := *config
	config = &configCopy

	c = &Conn{
		config:   originalConfig,
		ConnInfo: pgtype.NewConnInfo(),
		logLevel: config.LogLevel,
		logger:   config.Logger,
	}

	// Only install pgx notification system if no other callback handler is present.
	if config.Config.OnNotification == nil {
		config.Config.OnNotification = c.bufferNotifications
	} else if c.shouldLog(LogLevelDebug) {
		c.log(ctx, LogLevelDebug, "pgx notification handler disabled by application supplied OnNotification", map[string]interface{}{"host": config.Config.Host})
	}

	if c.shouldLog(LogLevelInfo) {
		c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host})
	}
	c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
	if err != nil {
		if c.shouldLog(LogLevelError) {
			c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err})
		}
		return nil, err
	}

	c.preparedStatements = make(map[string]*pgconn.StatementDescription)
	c.doneChan = make(chan struct{})
	c.closedChan = make(chan error)
	c.wbuf = make([]byte, 0, 1024)

	if c.config.BuildStatementCache != nil {
		c.stmtcache = c.config.BuildStatementCache(c.pgConn)
	}

	return c, nil
}
// Close closes the connection. Calling Close on a connection that is already closed does nothing and returns nil;
// otherwise it returns the error from closing the underlying pgconn connection.
func (c *Conn) Close(ctx context.Context) error {
	if c.IsClosed() {
		return nil
	}

	closeErr := c.pgConn.Close(ctx)
	if c.shouldLog(LogLevelInfo) {
		c.log(ctx, LogLevelInfo, "closed connection", nil)
	}
	return closeErr
}
// Prepare creates a prepared statement with name and sql. sql can contain placeholders for bound parameters, which
// are referenced positionally as $1, $2, etc.
//
// Prepare is idempotent: when a statement with the same non-empty name and identical sql has already been prepared
// on this connection, the remembered description is returned without contacting the server. This lets a code path
// call Prepare and then Query/Exec without worrying whether the statement already exists.
//
// Named statements are remembered for later lookup by name; unnamed ("") statements are not. Failures are logged at
// LogLevelError when that level is enabled.
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
	if name != "" {
		if cached, found := c.preparedStatements[name]; found && cached.SQL == sql {
			return cached, nil
		}
	}

	logFailure := c.shouldLog(LogLevelError)
	defer func() {
		if logFailure && err != nil {
			c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql})
		}
	}()

	sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
	switch {
	case err != nil:
		return nil, err
	case name != "":
		c.preparedStatements[name] = sd
	}
	return sd, nil
}
// Deallocate releases the prepared statement called name. It forgets the statement locally and then runs a
// DEALLOCATE on the server with name quoted as an identifier, returning any error from that command.
func (c *Conn) Deallocate(ctx context.Context, name string) error {
	delete(c.preparedStatements, name)

	deallocateSQL := "deallocate " + quoteIdentifier(name)
	_, err := c.pgConn.Exec(ctx, deallocateSQL).ReadAll()
	return err
}

// bufferNotifications is installed as the pgconn OnNotification callback by connect (unless the application supplied
// its own). It queues each notification so WaitForNotification can return it later.
func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) {
	queue := c.notifications
	c.notifications = append(queue, n)
}
// WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a
// slightly more convenient form.
//
// A notification that was already buffered is returned immediately with a nil error. Otherwise it blocks in
// pgconn's WaitForNotification; afterwards it returns the oldest buffered notification (or nil if none arrived)
// together with whatever error that wait reported.
func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) {
	var waitErr error
	if len(c.notifications) == 0 {
		waitErr = c.pgConn.WaitForNotification(ctx)
		if len(c.notifications) == 0 {
			return nil, waitErr
		}
	}

	oldest := c.notifications[0]
	c.notifications = c.notifications[1:]
	return oldest, waitErr
}

// IsClosed reports whether the underlying pgconn connection is closed.
func (c *Conn) IsClosed() bool {
	closed := c.pgConn.IsClosed()
	return closed
}
// processContextFreeMsg handles backend messages that are not tied to one particular operation. Such messages can
// be left over when, for example, a context deadline interrupts message processing partway through a query.
//
// Only ErrorResponse is acted on: it is converted by rxErrorResponse and returned as the error. Every other message
// is ignored and nil is returned.
func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) {
	if errResp, isErrResp := msg.(*pgproto3.ErrorResponse); isErrResp {
		return c.rxErrorResponse(errResp)
	}
	return nil
}

// rxErrorResponse copies every field of a server ErrorResponse into a *pgconn.PgError. If the error's severity is
// "FATAL", the connection is torn down via die before the error is returned.
func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) *pgconn.PgError {
	pgErr := &pgconn.PgError{
		Severity:         msg.Severity,
		Code:             msg.Code,
		Message:          msg.Message,
		Detail:           msg.Detail,
		Hint:             msg.Hint,
		Position:         msg.Position,
		InternalPosition: msg.InternalPosition,
		InternalQuery:    msg.InternalQuery,
		Where:            msg.Where,
		SchemaName:       msg.SchemaName,
		TableName:        msg.TableName,
		ColumnName:       msg.ColumnName,
		DataTypeName:     msg.DataTypeName,
		ConstraintName:   msg.ConstraintName,
		File:             msg.File,
		Line:             msg.Line,
		Routine:          msg.Routine,
	}

	if pgErr.Severity == "FATAL" {
		c.die(pgErr)
	}
	return pgErr
}

// die closes the underlying connection unless it is already closed. It passes pgconn.Close an already-cancelled
// context to force an immediate hard close. The close error is deliberately ignored (best-effort teardown).
// NOTE(review): the err argument is currently unused.
func (c *Conn) die(err error) {
	if !c.IsClosed() {
		ctx, cancel := context.WithCancel(context.Background())
		cancel()
		_ = c.pgConn.Close(ctx)
	}
}
// shouldLog reports whether a message at lvl would be emitted: a logger must be configured and the connection's
// log level must be at least lvl.
func (c *Conn) shouldLog(lvl LogLevel) bool {
	if c.logger == nil {
		return false
	}
	return c.logLevel >= lvl
}

// log forwards msg and data to the configured logger. When the underlying connection has a non-zero backend PID, it
// is added to data under the "pid" key. A nil data map is replaced with an empty one. Callers should check
// shouldLog first, because log assumes a non-nil logger.
func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) {
	if data == nil {
		data = make(map[string]interface{})
	}

	if c.pgConn != nil {
		if pid := c.pgConn.PID(); pid != 0 {
			data["pid"] = pid
		}
	}

	c.logger.Log(ctx, lvl, msg, data)
}
// quoteIdentifier wraps s in double quotes, doubling every embedded double quote so that s is read as a single
// identifier. Unlike Identifier.Sanitize, it does not strip NUL bytes.
func quoteIdentifier(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)

	b.WriteByte('"')
	for i := 0; i < len(s); i++ {
		if s[i] == '"' {
			b.WriteByte('"')
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('"')

	return b.String()
}
// Ping checks the connection by executing an empty statement (";") and returns any error from doing so.
func (c *Conn) Ping(ctx context.Context) error {
	if _, err := c.Exec(ctx, ";"); err != nil {
		return err
	}
	return nil
}
// connInfoFromRows drains rows of (oid, name) pairs into a name -> OID map. It accepts the (Rows, error) pair
// returned by Query directly: a non-nil err is returned unchanged. rows is always closed once it has been reached.
// Any scan error or error reported by rows.Err aborts with a nil map.
func connInfoFromRows(rows Rows, err error) (map[string]uint32, error) {
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	nameOIDs := make(map[string]uint32, 256)
	for rows.Next() {
		var (
			oid  uint32
			name pgtype.Text
		)
		if scanErr := rows.Scan(&oid, &name); scanErr != nil {
			return nil, scanErr
		}
		nameOIDs[name.String] = oid
	}

	if rowsErr := rows.Err(); rowsErr != nil {
		return nil, rowsErr
	}
	return nameOIDs, nil
}
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
// PostgreSQL connection than pgx exposes.
//
// It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn
// is used and the connection must be returned to the same state before any *pgx.Conn methods are again used.
func (c *Conn) PgConn() *pgconn.PgConn {
	return c.pgConn
}

// StatementCache returns the statement cache used for this connection. It is nil when automatic statement caching
// is disabled (ConnConfig.BuildStatementCache was nil).
func (c *Conn) StatementCache() stmtcache.Cache {
	return c.stmtcache
}
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
// positionally from the sql string as $1, $2, etc. Leading QuerySimpleProtocol arguments are treated as options
// rather than query parameters.
//
// Failures are logged at LogLevelError; successful executions are logged at LogLevelInfo with their duration and
// command tag.
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
	startTime := time.Now()

	commandTag, err := c.exec(ctx, sql, arguments...)
	switch {
	case err != nil && c.shouldLog(LogLevelError):
		c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err})
	case err == nil && c.shouldLog(LogLevelInfo):
		elapsed := time.Since(startTime)
		c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": elapsed, "commandTag": commandTag})
	}

	return commandTag, err
}
// exec implements Exec without logging. Leading QuerySimpleProtocol arguments are consumed as options. The query is
// then run by the first applicable strategy:
//  1. the simple protocol, when requested;
//  2. a statement previously created with Prepare whose name equals sql;
//  3. the simple protocol, when there are no arguments;
//  4. a freshly prepared unnamed statement, when there is no statement cache;
//  5. the statement cache: execParams in describe mode, execPrepared otherwise.
func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) {
	simpleProtocol := c.config.PreferSimpleProtocol
	for len(arguments) > 0 {
		opt, isOption := arguments[0].(QuerySimpleProtocol)
		if !isOption {
			break
		}
		simpleProtocol = bool(opt)
		arguments = arguments[1:]
	}

	if simpleProtocol {
		return c.execSimpleProtocol(ctx, sql, arguments)
	}
	if prepared, ok := c.preparedStatements[sql]; ok {
		return c.execPrepared(ctx, prepared, arguments)
	}
	if len(arguments) == 0 {
		return c.execSimpleProtocol(ctx, sql, arguments)
	}

	if c.stmtcache == nil {
		unnamed, prepErr := c.Prepare(ctx, "", sql)
		if prepErr != nil {
			return nil, prepErr
		}
		return c.execPrepared(ctx, unnamed, arguments)
	}

	cached, cacheErr := c.stmtcache.Get(ctx, sql)
	if cacheErr != nil {
		return nil, cacheErr
	}
	if c.stmtcache.Mode() == stmtcache.ModeDescribe {
		return c.execParams(ctx, cached, arguments)
	}
	return c.execPrepared(ctx, cached, arguments)
}
// execSimpleProtocol runs sql over the simple query protocol. Any arguments are first interpolated client-side by
// sanitizeForSimpleQuery. It returns the command tag of the last result, and the error reported when the
// multi-result reader is closed.
func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) {
	if len(arguments) > 0 {
		if sql, err = c.sanitizeForSimpleQuery(sql, arguments...); err != nil {
			return nil, err
		}
	}

	mrr := c.pgConn.Exec(ctx, sql)
	for mrr.NextResult() {
		// Per-result errors are not kept here; the error is taken from mrr.Close below.
		commandTag, _ = mrr.ResultReader().Close()
	}
	return commandTag, mrr.Close()
}
// execParamsAndPreparedPrefix fills c.eqb for an extended-protocol execution of sd. Each argument is encoded as a
// parameter using the OID the server reported for its position, and a result format is chosen for every result
// column.
//
// It returns an error when the number of arguments does not match the number of parameters in sd, or when an
// argument cannot be converted or encoded.
func (c *Conn) execParamsAndPreparedPrefix(sd *pgconn.StatementDescription, arguments []interface{}) error {
	c.eqb.Reset()

	// Without this check, extra arguments would index past sd.ParamOIDs and panic. Query performs the same check and
	// reports it with the same message.
	if len(sd.ParamOIDs) != len(arguments) {
		return errors.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(arguments))
	}

	args, err := convertDriverValuers(arguments)
	if err != nil {
		return err
	}

	for i := range args {
		err = c.eqb.AppendParam(c.ConnInfo, sd.ParamOIDs[i], args[i])
		if err != nil {
			return err
		}
	}

	for i := range sd.Fields {
		c.eqb.AppendResultFormat(c.ConnInfo.ResultFormatCodeForOID(sd.Fields[i].DataTypeOID))
	}

	return nil
}
// execParams encodes arguments for sd and runs sd.SQL with ExecParams, which sends the query text together with the
// parameter OIDs. It is used when the statement is not kept as a named statement on the server.
func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
	if err := c.execParamsAndPreparedPrefix(sd, arguments); err != nil {
		return nil, err
	}

	eqb := &c.eqb
	result := c.pgConn.ExecParams(ctx, sd.SQL, eqb.paramValues, sd.ParamOIDs, eqb.paramFormats, eqb.resultFormats).Read()
	return result.CommandTag, result.Err
}

// execPrepared encodes arguments for sd and executes the server-side statement sd.Name with ExecPrepared.
func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []interface{}) (pgconn.CommandTag, error) {
	if err := c.execParamsAndPreparedPrefix(sd, arguments); err != nil {
		return nil, err
	}

	eqb := &c.eqb
	result := c.pgConn.ExecPrepared(ctx, sd.Name, eqb.paramValues, eqb.paramFormats, eqb.resultFormats).Read()
	return result.CommandTag, result.Err
}
// getRows hands out a *connRows initialised for ctx, sql and args. Values are carved from a slab of 64 connRows that
// is reallocated whenever it runs out, which amortises allocation across queries. Each slot is handed out at most
// once, because the slab slice shrinks past it.
func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {
	if len(c.preallocatedRows) == 0 {
		c.preallocatedRows = make([]connRows, 64)
	}

	last := len(c.preallocatedRows) - 1
	r := &c.preallocatedRows[last]
	c.preallocatedRows = c.preallocatedRows[:last]

	r.ctx = ctx
	r.logger = c
	r.connInfo = c.ConnInfo
	r.startTime = time.Now()
	r.sql = sql
	r.args = args

	return r
}
// Per-query options. A value of one of these types, passed as a leading argument to Query (QuerySimpleProtocol also
// to Exec), is consumed as an option instead of being sent as a query parameter.
type (
	// QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query.
	QuerySimpleProtocol bool

	// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
	QueryResultFormats []int16

	// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
	QueryResultFormatsByOID map[uint32]int16
)
// Query executes sql with args. sql may be either a prepared statement name or an SQL string. If there is an error
// the returned Rows will be returned in an error state. So it is allowed to ignore the error returned from Query and
// handle it in Rows.
//
// Leading arguments of type QueryResultFormats, QueryResultFormatsByOID or QuerySimpleProtocol are consumed as
// per-query options and are not sent as query parameters.
func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) {
	var resultFormats QueryResultFormats
	var resultFormatsByOID QueryResultFormatsByOID
	simpleProtocol := c.config.PreferSimpleProtocol

optionLoop:
	for len(args) > 0 {
		switch arg := args[0].(type) {
		case QueryResultFormats:
			resultFormats = arg
			args = args[1:]
		case QueryResultFormatsByOID:
			resultFormatsByOID = arg
			args = args[1:]
		case QuerySimpleProtocol:
			simpleProtocol = bool(arg)
			args = args[1:]
		default:
			break optionLoop
		}
	}

	rows := c.getRows(ctx, sql, args)

	var err error
	if simpleProtocol {
		sql, err = c.sanitizeForSimpleQuery(sql, args...)
		if err != nil {
			rows.fatal(err)
			return rows, err
		}

		mrr := c.pgConn.Exec(ctx, sql)
		if mrr.NextResult() {
			rows.resultReader = mrr.ResultReader()
			rows.multiResultReader = mrr
		} else {
			err = mrr.Close()
			rows.fatal(err)
			return rows, err
		}

		return rows, nil
	}

	c.eqb.Reset()

	// explicitlyPrepared is true when sql is the name of a statement created with Prepare rather than query text.
	sd, explicitlyPrepared := c.preparedStatements[sql]
	if !explicitlyPrepared {
		if c.stmtcache != nil {
			sd, err = c.stmtcache.Get(ctx, sql)
			if err != nil {
				rows.fatal(err)
				return rows, rows.err
			}
		} else {
			sd, err = c.pgConn.Prepare(ctx, "", sql, nil)
			if err != nil {
				rows.fatal(err)
				return rows, rows.err
			}
		}
	}

	if len(sd.ParamOIDs) != len(args) {
		rows.fatal(errors.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
		return rows, rows.err
	}

	rows.sql = sd.SQL

	args, err = convertDriverValuers(args)
	if err != nil {
		rows.fatal(err)
		return rows, rows.err
	}

	for i := range args {
		err = c.eqb.AppendParam(c.ConnInfo, sd.ParamOIDs[i], args[i])
		if err != nil {
			rows.fatal(err)
			return rows, rows.err
		}
	}

	if resultFormatsByOID != nil {
		resultFormats = make([]int16, len(sd.Fields))
		for i := range resultFormats {
			resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
		}
	}

	if resultFormats == nil {
		for i := range sd.Fields {
			c.eqb.AppendResultFormat(c.ConnInfo.ResultFormatCodeForOID(sd.Fields[i].DataTypeOID))
		}
		resultFormats = c.eqb.resultFormats
	}

	// Describe mode creates no statement on the server, so cached descriptions are executed by sending the query text
	// with ExecParams. A statement created with Prepare exists on the server under sd.Name, and in that case sql is
	// only its name, so it must always go through ExecPrepared.
	if !explicitlyPrepared && c.stmtcache != nil && c.stmtcache.Mode() == stmtcache.ModeDescribe {
		rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, resultFormats)
	} else {
		rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats)
	}

	return rows, rows.err
}
// QueryRow is a convenience wrapper over Query. Any error that occurs while querying is deferred until Scan is
// called on the returned Row. That Row will error with ErrNoRows if no rows are returned.
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row {
	// The error is intentionally discarded: Query leaves the returned Rows in an error state, which Scan reports.
	rows, _ := c.Query(ctx, sql, args...)
	underlying := rows.(*connRows)
	return (*connRow)(underlying)
}
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
// explicit transaction control statements are executed.
//
// Queries that are not names of statements created with Prepare are described first, through the connection's
// statement cache when it is large enough, otherwise through a temporary describe-mode cache. Errors found before
// anything is sent (describe failures, argument count mismatches, encoding errors) are reported through the
// returned BatchResults.
func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults {
	distinctUnpreparedQueries := map[string]struct{}{}
	for _, bi := range b.items {
		if _, ok := c.preparedStatements[bi.query]; ok {
			continue
		}
		distinctUnpreparedQueries[bi.query] = struct{}{}
	}

	// Only set up a statement cache when there is something to describe. stmtcache.New rejects a capacity below 1,
	// so building one for an empty batch, or for a batch of only prepared statements, would panic.
	var stmtCache stmtcache.Cache
	if len(distinctUnpreparedQueries) > 0 {
		if c.stmtcache != nil && c.stmtcache.Cap() >= len(distinctUnpreparedQueries) {
			stmtCache = c.stmtcache
		} else {
			stmtCache = stmtcache.New(c.pgConn, stmtcache.ModeDescribe, len(distinctUnpreparedQueries))
		}

		for sql := range distinctUnpreparedQueries {
			if _, err := stmtCache.Get(ctx, sql); err != nil {
				return &batchResults{ctx: ctx, conn: c, err: err}
			}
		}
	}

	batch := &pgconn.Batch{}

	for _, bi := range b.items {
		c.eqb.Reset()

		sd := c.preparedStatements[bi.query]
		if sd == nil {
			var err error
			sd, err = stmtCache.Get(ctx, bi.query)
			if err != nil {
				// the stmtCache was prefilled from distinctUnpreparedQueries above so we are guaranteed no errors
				panic("BUG: unexpected error from stmtCache")
			}
		}

		if len(sd.ParamOIDs) != len(bi.arguments) {
			return &batchResults{ctx: ctx, conn: c, err: errors.Errorf("mismatched param and argument count")}
		}

		args, err := convertDriverValuers(bi.arguments)
		if err != nil {
			return &batchResults{ctx: ctx, conn: c, err: err}
		}

		for i := range args {
			err = c.eqb.AppendParam(c.ConnInfo, sd.ParamOIDs[i], args[i])
			if err != nil {
				return &batchResults{ctx: ctx, conn: c, err: err}
			}
		}

		for i := range sd.Fields {
			c.eqb.AppendResultFormat(c.ConnInfo.ResultFormatCodeForOID(sd.Fields[i].DataTypeOID))
		}

		// Unnamed descriptions (describe mode) have no server-side statement, so the query text is sent instead.
		if sd.Name == "" {
			batch.ExecParams(bi.query, c.eqb.paramValues, sd.ParamOIDs, c.eqb.paramFormats, c.eqb.resultFormats)
		} else {
			batch.ExecPrepared(sd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats)
		}
	}

	mrr := c.pgConn.ExecBatch(ctx, batch)

	return &batchResults{
		ctx:  ctx,
		conn: c,
		mrr:  mrr,
	}
}
// sanitizeForSimpleQuery interpolates args into sql on the client for use with the simple protocol. It refuses to run
// unless the server reports standard_conforming_strings=on and client_encoding=UTF8. Each argument is converted with
// convertSimpleArgument before being handed to sanitize.SanitizeSQL.
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) {
	switch {
	case c.pgConn.ParameterStatus("standard_conforming_strings") != "on":
		return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
	case c.pgConn.ParameterStatus("client_encoding") != "UTF8":
		return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
	}

	converted := make([]interface{}, 0, len(args))
	for _, arg := range args {
		v, err := convertSimpleArgument(c.ConnInfo, arg)
		if err != nil {
			return "", err
		}
		converted = append(converted, v)
	}

	return sanitize.SanitizeSQL(sql, converted...)
}