package pgx import ( "context" "strings" "time" errors "golang.org/x/xerrors" "github.com/jackc/pgconn" "github.com/jackc/pgproto3/v2" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4/internal/sanitize" ) const ( connStatusUninitialized = iota connStatusClosed connStatusIdle connStatusBusy ) // ConnConfig contains all the options used to establish a connection. type ConnConfig struct { pgconn.Config Logger Logger LogLevel LogLevel // PreferSimpleProtocol disables implicit prepared statement usage. By default pgx automatically uses the extended // protocol. This can improve performance due to being able to use the binary format. It also does not rely on client // side parameter sanitization. However, it does incur two round-trips per query (unless using a prepared statement) // and may be incompatible proxies such as PGBouncer. Setting PreferSimpleProtocol causes the simple protocol to be // used by default. The same functionality can be controlled on a per query basis by setting // QueryExOptions.SimpleProtocol. PreferSimpleProtocol bool } // Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access // to multiple database connections from multiple goroutines. type Conn struct { pgConn *pgconn.PgConn config *ConnConfig // config used when establishing this connection preparedStatements map[string]*PreparedStatement logger Logger logLevel LogLevel causeOfDeath error doneChan chan struct{} closedChan chan error ConnInfo *pgtype.ConnInfo wbuf []byte preallocatedRows []connRows eqb extendedQueryBuilder } // PreparedStatement is a description of a prepared statement type PreparedStatement struct { Name string SQL string FieldDescriptions []pgproto3.FieldDescription ParameterOIDs []pgtype.OID } // PrepareExOptions is an option struct that can be passed to PrepareEx type PrepareExOptions struct { ParameterOIDs []pgtype.OID } // Identifier a PostgreSQL identifier or name. Identifiers can be composed of // multiple parts such as ["schema", "table"] or ["table", "column"]. type Identifier []string // Sanitize returns a sanitized string safe for SQL interpolation. func (ident Identifier) Sanitize() string { parts := make([]string, len(ident)) for i := range ident { s := strings.Replace(ident[i], string([]byte{0}), "", -1) parts[i] = `"` + strings.Replace(s, `"`, `""`, -1) + `"` } return strings.Join(parts, ".") } // ErrNoRows occurs when rows are expected but none are returned. var ErrNoRows = errors.New("no rows in result set") // ErrDeadConn occurs on an attempt to use a dead connection var ErrDeadConn = errors.New("conn is dead") // ErrTLSRefused occurs when the connection attempt requires TLS and the // PostgreSQL server refuses to use TLS var ErrTLSRefused = pgconn.ErrTLSRefused // ErrInvalidLogLevel occurs on attempt to set an invalid log level. var ErrInvalidLogLevel = errors.New("invalid log level") // ProtocolError occurs when unexpected data is received from PostgreSQL type ProtocolError string func (e ProtocolError) Error() string { return string(e) } // Connect establishes a connection with a PostgreSQL server with a connection string. See // pgconn.Connect for details. func Connect(ctx context.Context, connString string) (*Conn, error) { connConfig, err := ParseConfig(connString) if err != nil { return nil, err } return connect(ctx, connConfig) } // Connect establishes a connection with a PostgreSQL server with a configuration struct. func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) { return connect(ctx, connConfig) } func ParseConfig(connString string) (*ConnConfig, error) { config, err := pgconn.ParseConfig(connString) if err != nil { return nil, err } connConfig := &ConnConfig{ Config: *config, } return connConfig, nil } func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) { c = new(Conn) c.config = config c.ConnInfo = pgtype.NewConnInfo() if c.config.LogLevel != 0 { c.logLevel = c.config.LogLevel } else { // Preserve pre-LogLevel behavior by defaulting to LogLevelDebug c.logLevel = LogLevelDebug } c.logger = c.config.Logger if c.shouldLog(LogLevelInfo) { c.log(ctx, LogLevelInfo, "Dialing PostgreSQL server", map[string]interface{}{"host": config.Config.Host}) } c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config) if err != nil { return nil, err } if err != nil { if c.shouldLog(LogLevelError) { c.log(ctx, LogLevelError, "connect failed", map[string]interface{}{"err": err}) } return nil, err } c.preparedStatements = make(map[string]*PreparedStatement) c.doneChan = make(chan struct{}) c.closedChan = make(chan error) c.wbuf = make([]byte, 0, 1024) // Replication connections can't execute the queries to // populate the c.PgTypes and c.pgsqlAfInet if _, ok := c.pgConn.Config.RuntimeParams["replication"]; ok { return c, nil } return c, nil } // Close closes a connection. It is safe to call Close on a already closed // connection. func (c *Conn) Close(ctx context.Context) error { if !c.IsAlive() { return nil } err := c.pgConn.Close(ctx) c.causeOfDeath = errors.New("Closed") if c.shouldLog(LogLevelInfo) { c.log(ctx, LogLevelInfo, "closed connection", nil) } return err } // Prepare creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. // // Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) Prepare(ctx context.Context, name, sql string) (ps *PreparedStatement, err error) { if name != "" { if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { return ps, nil } } if c.shouldLog(LogLevelError) { defer func() { if err != nil { c.log(ctx, LogLevelError, "Prepare failed", map[string]interface{}{"err": err, "name": name, "sql": sql}) } }() } psd, err := c.pgConn.Prepare(ctx, name, sql, nil) if err != nil { return nil, err } ps = &PreparedStatement{ Name: psd.Name, SQL: psd.SQL, ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), FieldDescriptions: psd.Fields, } for i := range ps.ParameterOIDs { ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } if name != "" { c.preparedStatements[name] = ps } return ps, nil } // Deallocate released a prepared statement func (c *Conn) Deallocate(ctx context.Context, name string) error { delete(c.preparedStatements, name) _, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll() return err } func (c *Conn) IsAlive() bool { return c.pgConn.IsAlive() } func (c *Conn) CauseOfDeath() error { return c.causeOfDeath } // Processes messages that are not exclusive to one context such as // authentication or query response. The response to these messages is the same // regardless of when they occur. It also ignores messages that are only // meaningful in a given context. These messages can occur due to a context // deadline interrupting message processing. For example, an interrupted query // may have left DataRow messages on the wire. func (c *Conn) processContextFreeMsg(msg pgproto3.BackendMessage) (err error) { switch msg := msg.(type) { case *pgproto3.ErrorResponse: return c.rxErrorResponse(msg) } return nil } func (c *Conn) rxErrorResponse(msg *pgproto3.ErrorResponse) *pgconn.PgError { err := &pgconn.PgError{ Severity: msg.Severity, Code: msg.Code, Message: msg.Message, Detail: msg.Detail, Hint: msg.Hint, Position: msg.Position, InternalPosition: msg.InternalPosition, InternalQuery: msg.InternalQuery, Where: msg.Where, SchemaName: msg.SchemaName, TableName: msg.TableName, ColumnName: msg.ColumnName, DataTypeName: msg.DataTypeName, ConstraintName: msg.ConstraintName, File: msg.File, Line: msg.Line, Routine: msg.Routine, } if err.Severity == "FATAL" { c.die(err) } return err } func (c *Conn) die(err error) { if !c.IsAlive() { return } c.causeOfDeath = err ctx, cancel := context.WithCancel(context.Background()) cancel() // force immediate hard cancel c.pgConn.Close(ctx) } func (c *Conn) shouldLog(lvl LogLevel) bool { return c.logger != nil && c.logLevel >= lvl } func (c *Conn) log(ctx context.Context, lvl LogLevel, msg string, data map[string]interface{}) { if data == nil { data = map[string]interface{}{} } if c.pgConn != nil && c.pgConn.PID() != 0 { data["pid"] = c.pgConn.PID() } c.logger.Log(ctx, lvl, msg, data) } func quoteIdentifier(s string) string { return `"` + strings.Replace(s, `"`, `""`, -1) + `"` } func (c *Conn) Ping(ctx context.Context) error { _, err := c.Exec(ctx, ";") return err } func connInfoFromRows(rows Rows, err error) (map[string]pgtype.OID, error) { if err != nil { return nil, err } defer rows.Close() nameOIDs := make(map[string]pgtype.OID, 256) for rows.Next() { var oid pgtype.OID var name pgtype.Text if err = rows.Scan(&oid, &name); err != nil { return nil, err } nameOIDs[name.String] = oid } if err = rows.Err(); err != nil { return nil, err } return nameOIDs, err } // PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the // PostgreSQL connection than pgx exposes. // // It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn // is used and the connection must be returned to the same state before any *pgx.Conn methods are again used. func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn } // Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced // positionally from the sql string as $1, $2, etc. func (c *Conn) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) { startTime := time.Now() commandTag, err := c.exec(ctx, sql, arguments...) if err != nil { if c.shouldLog(LogLevelError) { c.log(ctx, LogLevelError, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "err": err}) } return commandTag, err } if c.shouldLog(LogLevelInfo) { endTime := time.Now() c.log(ctx, LogLevelInfo, "Exec", map[string]interface{}{"sql": sql, "args": logQueryArgs(arguments), "time": endTime.Sub(startTime), "commandTag": commandTag}) } return commandTag, err } func (c *Conn) exec(ctx context.Context, sql string, arguments ...interface{}) (commandTag pgconn.CommandTag, err error) { simpleProtocol := c.config.PreferSimpleProtocol optionLoop: for len(arguments) > 0 { switch arg := arguments[0].(type) { case QuerySimpleProtocol: simpleProtocol = bool(arg) arguments = arguments[1:] default: break optionLoop } } if simpleProtocol { sql, err = c.sanitizeForSimpleQuery(sql, arguments...) if err != nil { return nil, err } mrr := c.pgConn.Exec(ctx, sql) if mrr.NextResult() { result := mrr.ResultReader().Read() err = mrr.Close() return result.CommandTag, err } else { err = mrr.Close() return nil, err } } c.eqb.Reset() if ps, ok := c.preparedStatements[sql]; ok { args, err := convertDriverValuers(arguments) if err != nil { return nil, err } for i := range args { err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i]) if err != nil { return nil, err } } for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { c.eqb.AppendResultFormat(BinaryFormatCode) } else { c.eqb.AppendResultFormat(TextFormatCode) } } } result := c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() return result.CommandTag, result.Err } if len(arguments) == 0 { results, err := c.pgConn.Exec(ctx, sql).ReadAll() if err != nil { return nil, err } if len(results) == 0 { return nil, nil } return results[len(results)-1].CommandTag, nil } else { psd, err := c.pgConn.Prepare(ctx, "", sql, nil) if err != nil { return nil, err } if len(psd.ParamOIDs) != len(arguments) { return nil, errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(arguments)) } ps := &PreparedStatement{ Name: psd.Name, SQL: psd.SQL, ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), FieldDescriptions: psd.Fields, } for i := range ps.ParameterOIDs { ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } arguments, err = convertDriverValuers(arguments) if err != nil { return nil, err } for i := range arguments { err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], arguments[i]) if err != nil { return nil, err } } for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { c.eqb.AppendResultFormat(BinaryFormatCode) } else { c.eqb.AppendResultFormat(TextFormatCode) } } } result := c.pgConn.ExecPrepared(ctx, psd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() return result.CommandTag, result.Err } } func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows { if len(c.preallocatedRows) == 0 { c.preallocatedRows = make([]connRows, 64) } r := &c.preallocatedRows[len(c.preallocatedRows)-1] c.preallocatedRows = c.preallocatedRows[0 : len(c.preallocatedRows)-1] r.ctx = ctx r.logger = c r.connInfo = c.ConnInfo r.startTime = time.Now() r.sql = sql r.args = args return r } // QuerySimpleProtocol controls whether the simple or extended protocol is used to send the query. type QuerySimpleProtocol bool // QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position. type QueryResultFormats []int16 // QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID. type QueryResultFormatsByOID map[pgtype.OID]int16 // Query executes sql with args. If there is an error the returned Rows will be returned in an error state. So it is // allowed to ignore the error returned from Query and handle it in Rows. func (c *Conn) Query(ctx context.Context, sql string, args ...interface{}) (Rows, error) { var resultFormats QueryResultFormats var resultFormatsByOID QueryResultFormatsByOID simpleProtocol := c.config.PreferSimpleProtocol optionLoop: for len(args) > 0 { switch arg := args[0].(type) { case QueryResultFormats: resultFormats = arg args = args[1:] case QueryResultFormatsByOID: resultFormatsByOID = arg args = args[1:] case QuerySimpleProtocol: simpleProtocol = bool(arg) args = args[1:] default: break optionLoop } } rows := c.getRows(ctx, sql, args) var err error if simpleProtocol { sql, err = c.sanitizeForSimpleQuery(sql, args...) if err != nil { rows.fatal(err) return rows, err } mrr := c.pgConn.Exec(ctx, sql) if mrr.NextResult() { rows.resultReader = mrr.ResultReader() rows.multiResultReader = mrr } else { err = mrr.Close() rows.fatal(err) return rows, err } return rows, nil } c.eqb.Reset() ps, ok := c.preparedStatements[sql] if !ok { psd, err := c.pgConn.Prepare(ctx, "", sql, nil) if err != nil { rows.fatal(err) return rows, rows.err } if len(psd.ParamOIDs) != len(args) { rows.fatal(errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(args))) return rows, rows.err } ps = &PreparedStatement{ Name: psd.Name, SQL: psd.SQL, ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), FieldDescriptions: psd.Fields, } for i := range ps.ParameterOIDs { ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) } } rows.sql = ps.SQL args, err = convertDriverValuers(args) if err != nil { rows.fatal(err) return rows, rows.err } for i := range args { err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i]) if err != nil { rows.fatal(err) return rows, rows.err } } if resultFormatsByOID != nil { resultFormats = make([]int16, len(ps.FieldDescriptions)) for i := range resultFormats { resultFormats[i] = resultFormatsByOID[pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)] } } if resultFormats == nil { for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { c.eqb.AppendResultFormat(BinaryFormatCode) } else { c.eqb.AppendResultFormat(TextFormatCode) } } } resultFormats = c.eqb.resultFormats } rows.resultReader = c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) return rows, rows.err } // QueryRow is a convenience wrapper over Query. Any error that occurs while // querying is deferred until calling Scan on the returned Row. That Row will // error with ErrNoRows if no rows are returned. func (c *Conn) QueryRow(ctx context.Context, sql string, args ...interface{}) Row { rows, _ := c.Query(ctx, sql, args...) return (*connRow)(rows.(*connRows)) } // SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless // explicit transaction control statements are executed. func (c *Conn) SendBatch(ctx context.Context, b *Batch) BatchResults { batch := &pgconn.Batch{} for _, bi := range b.items { c.eqb.Reset() var parameterOIDs []pgtype.OID ps := c.preparedStatements[bi.query] if ps != nil { parameterOIDs = ps.ParameterOIDs } else { parameterOIDs = bi.parameterOIDs } args, err := convertDriverValuers(bi.arguments) if err != nil { return &batchResults{err: err} } for i := range args { err = c.eqb.AppendParam(c.ConnInfo, parameterOIDs[i], args[i]) if err != nil { return &batchResults{err: err} } } if ps != nil { resultFormats := bi.resultFormatCodes if resultFormats == nil { for i := range ps.FieldDescriptions { if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { c.eqb.AppendResultFormat(BinaryFormatCode) } else { c.eqb.AppendResultFormat(TextFormatCode) } } } resultFormats = c.eqb.resultFormats } batch.ExecPrepared(ps.Name, c.eqb.paramValues, c.eqb.paramFormats, resultFormats) } else { oids := make([]uint32, len(parameterOIDs)) for i := 0; i < len(parameterOIDs); i++ { oids[i] = uint32(parameterOIDs[i]) } batch.ExecParams(bi.query, c.eqb.paramValues, oids, c.eqb.paramFormats, bi.resultFormatCodes) } } mrr := c.pgConn.ExecBatch(ctx, batch) return &batchResults{ ctx: ctx, conn: c, mrr: mrr, } } func (c *Conn) sanitizeForSimpleQuery(sql string, args ...interface{}) (string, error) { if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" { return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on") } if c.pgConn.ParameterStatus("client_encoding") != "UTF8" { return "", errors.New("simple protocol queries must be run with client_encoding=UTF8") } var err error valueArgs := make([]interface{}, len(args)) for i, a := range args { valueArgs[i], err = convertSimpleArgument(c.ConnInfo, a) if err != nil { return "", err } } return sanitize.SanitizeSQL(sql, valueArgs...) }