mirror of
https://github.com/jackc/pgx.git
synced 2025-05-02 13:40:00 +00:00
CopyFrom had to create a prepared statement to get the OIDs of the data types that were going to be copied into the table. Every COPY operation required an extra round trips to retrieve the type information. There was no way to customize this behavior. By leveraging the QueryExecMode feature, like in `Conn.Query`, users can specify if they want to cache the prepared statements, execute them on every request (like the old behavior), or bypass the prepared statement relying on the pgtype.Map to get the type information. The `QueryExecMode` behave exactly like in `Conn.Query` in the way the data type OIDs are fetched, meaning that: - `QueryExecModeCacheStatement`: caches the statement. - `QueryExecModeCacheDescribe`: caches the statement and assumes they do not change. - `QueryExecModeDescribeExec`: gets the statement description on every execution. This is like to the old behavior of `CopyFrom`. - `QueryExecModeExec` and `QueryExecModeSimpleProtocol`: maintain the same behavior as before, which is the same as `QueryExecModeDescribeExec`. It will keep getting the statement description on every execution The `QueryExecMode` can only be set via `ConnConfig.DefaultQueryExecMode`, unlike `Conn.Query` there's no support for specifying the `QueryExecMode` via optional arguments in the function signature.
1341 lines
42 KiB
Go
1341 lines
42 KiB
Go
package pgx
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/internal/anynil"
|
|
"github.com/jackc/pgx/v5/internal/sanitize"
|
|
"github.com/jackc/pgx/v5/internal/stmtcache"
|
|
"github.com/jackc/pgx/v5/pgconn"
|
|
"github.com/jackc/pgx/v5/pgtype"
|
|
)
|
|
|
|
// ConnConfig contains all the options used to establish a connection. It must be created by ParseConfig and
|
|
// then it can be modified. A manually initialized ConnConfig will cause ConnectConfig to panic.
|
|
type ConnConfig struct {
|
|
pgconn.Config
|
|
|
|
Tracer QueryTracer
|
|
|
|
// Original connection string that was parsed into config.
|
|
connString string
|
|
|
|
// StatementCacheCapacity is maximum size of the statement cache used when executing a query with "cache_statement"
|
|
// query exec mode.
|
|
StatementCacheCapacity int
|
|
|
|
// DescriptionCacheCapacity is the maximum size of the description cache used when executing a query with
|
|
// "cache_describe" query exec mode.
|
|
DescriptionCacheCapacity int
|
|
|
|
// DefaultQueryExecMode controls the default mode for executing queries. By default pgx uses the extended protocol
|
|
// and automatically prepares and caches prepared statements. However, this may be incompatible with proxies such as
|
|
// PGBouncer. In this case it may be preferrable to use QueryExecModeExec or QueryExecModeSimpleProtocol. The same
|
|
// functionality can be controlled on a per query basis by passing a QueryExecMode as the first query argument.
|
|
DefaultQueryExecMode QueryExecMode
|
|
|
|
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
|
|
}
|
|
|
|
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
|
|
type ParseConfigOptions struct {
|
|
pgconn.ParseConfigOptions
|
|
}
|
|
|
|
// Copy returns a deep copy of the config that is safe to use and modify.
|
|
// The only exception is the tls.Config:
|
|
// according to the tls.Config docs it must not be modified after creation.
|
|
func (cc *ConnConfig) Copy() *ConnConfig {
|
|
newConfig := new(ConnConfig)
|
|
*newConfig = *cc
|
|
newConfig.Config = *newConfig.Config.Copy()
|
|
return newConfig
|
|
}
|
|
|
|
// ConnString returns the connection string as parsed by pgx.ParseConfig into pgx.ConnConfig.
|
|
func (cc *ConnConfig) ConnString() string { return cc.connString }
|
|
|
|
// Conn is a PostgreSQL connection handle. It is not safe for concurrent usage. Use a connection pool to manage access
|
|
// to multiple database connections from multiple goroutines.
|
|
type Conn struct {
|
|
pgConn *pgconn.PgConn
|
|
config *ConnConfig // config used when establishing this connection
|
|
preparedStatements map[string]*pgconn.StatementDescription
|
|
statementCache stmtcache.Cache
|
|
descriptionCache stmtcache.Cache
|
|
|
|
queryTracer QueryTracer
|
|
batchTracer BatchTracer
|
|
copyFromTracer CopyFromTracer
|
|
prepareTracer PrepareTracer
|
|
|
|
notifications []*pgconn.Notification
|
|
|
|
doneChan chan struct{}
|
|
closedChan chan error
|
|
|
|
typeMap *pgtype.Map
|
|
|
|
wbuf []byte
|
|
eqb ExtendedQueryBuilder
|
|
}
|
|
|
|
// Identifier a PostgreSQL identifier or name. Identifiers can be composed of
|
|
// multiple parts such as ["schema", "table"] or ["table", "column"].
|
|
type Identifier []string
|
|
|
|
// Sanitize returns a sanitized string safe for SQL interpolation.
|
|
func (ident Identifier) Sanitize() string {
|
|
parts := make([]string, len(ident))
|
|
for i := range ident {
|
|
s := strings.ReplaceAll(ident[i], string([]byte{0}), "")
|
|
parts[i] = `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
|
|
}
|
|
return strings.Join(parts, ".")
|
|
}
|
|
|
|
// ErrNoRows occurs when rows are expected but none are returned.
|
|
var ErrNoRows = errors.New("no rows in result set")
|
|
|
|
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
|
|
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
|
|
|
|
// Connect establishes a connection with a PostgreSQL server with a connection string. See
|
|
// pgconn.Connect for details.
|
|
func Connect(ctx context.Context, connString string) (*Conn, error) {
|
|
connConfig, err := ParseConfig(connString)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return connect(ctx, connConfig)
|
|
}
|
|
|
|
// ConnectWithOptions behaves exactly like Connect with the addition of options. At the present options is only used to
|
|
// provide a GetSSLPassword function.
|
|
func ConnectWithOptions(ctx context.Context, connString string, options ParseConfigOptions) (*Conn, error) {
|
|
connConfig, err := ParseConfigWithOptions(connString, options)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return connect(ctx, connConfig)
|
|
}
|
|
|
|
// ConnectConfig establishes a connection with a PostgreSQL server with a configuration struct.
|
|
// connConfig must have been created by ParseConfig.
|
|
func ConnectConfig(ctx context.Context, connConfig *ConnConfig) (*Conn, error) {
|
|
// In general this improves safety. In particular avoid the config.Config.OnNotification mutation from affecting other
|
|
// connections with the same config. See https://github.com/jackc/pgx/issues/618.
|
|
connConfig = connConfig.Copy()
|
|
|
|
return connect(ctx, connConfig)
|
|
}
|
|
|
|
// ParseConfigWithOptions behaves exactly as ParseConfig does with the addition of options. At the present options is
|
|
// only used to provide a GetSSLPassword function.
|
|
func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*ConnConfig, error) {
|
|
config, err := pgconn.ParseConfigWithOptions(connString, options.ParseConfigOptions)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
statementCacheCapacity := 512
|
|
if s, ok := config.RuntimeParams["statement_cache_capacity"]; ok {
|
|
delete(config.RuntimeParams, "statement_cache_capacity")
|
|
n, err := strconv.ParseInt(s, 10, 32)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot parse statement_cache_capacity: %w", err)
|
|
}
|
|
statementCacheCapacity = int(n)
|
|
}
|
|
|
|
descriptionCacheCapacity := 512
|
|
if s, ok := config.RuntimeParams["description_cache_capacity"]; ok {
|
|
delete(config.RuntimeParams, "description_cache_capacity")
|
|
n, err := strconv.ParseInt(s, 10, 32)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("cannot parse description_cache_capacity: %w", err)
|
|
}
|
|
descriptionCacheCapacity = int(n)
|
|
}
|
|
|
|
defaultQueryExecMode := QueryExecModeCacheStatement
|
|
if s, ok := config.RuntimeParams["default_query_exec_mode"]; ok {
|
|
delete(config.RuntimeParams, "default_query_exec_mode")
|
|
switch s {
|
|
case "cache_statement":
|
|
defaultQueryExecMode = QueryExecModeCacheStatement
|
|
case "cache_describe":
|
|
defaultQueryExecMode = QueryExecModeCacheDescribe
|
|
case "describe_exec":
|
|
defaultQueryExecMode = QueryExecModeDescribeExec
|
|
case "exec":
|
|
defaultQueryExecMode = QueryExecModeExec
|
|
case "simple_protocol":
|
|
defaultQueryExecMode = QueryExecModeSimpleProtocol
|
|
default:
|
|
return nil, fmt.Errorf("invalid default_query_exec_mode: %v", err)
|
|
}
|
|
}
|
|
|
|
connConfig := &ConnConfig{
|
|
Config: *config,
|
|
createdByParseConfig: true,
|
|
StatementCacheCapacity: statementCacheCapacity,
|
|
DescriptionCacheCapacity: descriptionCacheCapacity,
|
|
DefaultQueryExecMode: defaultQueryExecMode,
|
|
connString: connString,
|
|
}
|
|
|
|
return connConfig, nil
|
|
}
|
|
|
|
// ParseConfig creates a ConnConfig from a connection string. ParseConfig handles all options that pgconn.ParseConfig
|
|
// does. In addition, it accepts the following options:
|
|
//
|
|
// default_query_exec_mode
|
|
// Possible values: "cache_statement", "cache_describe", "describe_exec", "exec", and "simple_protocol". See
|
|
// QueryExecMode constant documentation for the meaning of these values. Default: "cache_statement".
|
|
//
|
|
// statement_cache_capacity
|
|
// The maximum size of the statement cache used when executing a query with "cache_statement" query exec mode.
|
|
// Default: 512.
|
|
//
|
|
// description_cache_capacity
|
|
// The maximum size of the description cache used when executing a query with "cache_describe" query exec mode.
|
|
// Default: 512.
|
|
func ParseConfig(connString string) (*ConnConfig, error) {
|
|
return ParseConfigWithOptions(connString, ParseConfigOptions{})
|
|
}
|
|
|
|
// connect connects to a database. connect takes ownership of config. The caller must not use or access it again.
|
|
func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
|
|
if connectTracer, ok := config.Tracer.(ConnectTracer); ok {
|
|
ctx = connectTracer.TraceConnectStart(ctx, TraceConnectStartData{ConnConfig: config})
|
|
defer func() {
|
|
connectTracer.TraceConnectEnd(ctx, TraceConnectEndData{Conn: c, Err: err})
|
|
}()
|
|
}
|
|
|
|
// Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
|
|
// zero values.
|
|
if !config.createdByParseConfig {
|
|
panic("config must be created by ParseConfig")
|
|
}
|
|
|
|
c = &Conn{
|
|
config: config,
|
|
typeMap: pgtype.NewMap(),
|
|
queryTracer: config.Tracer,
|
|
}
|
|
|
|
if t, ok := c.queryTracer.(BatchTracer); ok {
|
|
c.batchTracer = t
|
|
}
|
|
if t, ok := c.queryTracer.(CopyFromTracer); ok {
|
|
c.copyFromTracer = t
|
|
}
|
|
if t, ok := c.queryTracer.(PrepareTracer); ok {
|
|
c.prepareTracer = t
|
|
}
|
|
|
|
// Only install pgx notification system if no other callback handler is present.
|
|
if config.Config.OnNotification == nil {
|
|
config.Config.OnNotification = c.bufferNotifications
|
|
}
|
|
|
|
c.pgConn, err = pgconn.ConnectConfig(ctx, &config.Config)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
c.preparedStatements = make(map[string]*pgconn.StatementDescription)
|
|
c.doneChan = make(chan struct{})
|
|
c.closedChan = make(chan error)
|
|
c.wbuf = make([]byte, 0, 1024)
|
|
|
|
if c.config.StatementCacheCapacity > 0 {
|
|
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
|
|
}
|
|
|
|
if c.config.DescriptionCacheCapacity > 0 {
|
|
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
// Close closes a connection. It is safe to call Close on a already closed
|
|
// connection.
|
|
func (c *Conn) Close(ctx context.Context) error {
|
|
if c.IsClosed() {
|
|
return nil
|
|
}
|
|
|
|
err := c.pgConn.Close(ctx)
|
|
return err
|
|
}
|
|
|
|
// Prepare creates a prepared statement with name and sql. sql can contain placeholders
|
|
// for bound parameters. These placeholders are referenced positional as $1, $2, etc.
|
|
//
|
|
// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same
|
|
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
|
|
// concern for if the statement has already been prepared.
|
|
func (c *Conn) Prepare(ctx context.Context, name, sql string) (sd *pgconn.StatementDescription, err error) {
|
|
if c.prepareTracer != nil {
|
|
ctx = c.prepareTracer.TracePrepareStart(ctx, c, TracePrepareStartData{Name: name, SQL: sql})
|
|
}
|
|
|
|
if name != "" {
|
|
var ok bool
|
|
if sd, ok = c.preparedStatements[name]; ok && sd.SQL == sql {
|
|
if c.prepareTracer != nil {
|
|
c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{AlreadyPrepared: true})
|
|
}
|
|
return sd, nil
|
|
}
|
|
}
|
|
|
|
if c.prepareTracer != nil {
|
|
defer func() {
|
|
c.prepareTracer.TracePrepareEnd(ctx, c, TracePrepareEndData{Err: err})
|
|
}()
|
|
}
|
|
|
|
sd, err = c.pgConn.Prepare(ctx, name, sql, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if name != "" {
|
|
c.preparedStatements[name] = sd
|
|
}
|
|
|
|
return sd, nil
|
|
}
|
|
|
|
// Deallocate released a prepared statement
|
|
func (c *Conn) Deallocate(ctx context.Context, name string) error {
|
|
delete(c.preparedStatements, name)
|
|
_, err := c.pgConn.Exec(ctx, "deallocate "+quoteIdentifier(name)).ReadAll()
|
|
return err
|
|
}
|
|
|
|
// DeallocateAll releases all previously prepared statements from the server and client, where it also resets the statement and description cache.
|
|
func (c *Conn) DeallocateAll(ctx context.Context) error {
|
|
c.preparedStatements = map[string]*pgconn.StatementDescription{}
|
|
if c.config.StatementCacheCapacity > 0 {
|
|
c.statementCache = stmtcache.NewLRUCache(c.config.StatementCacheCapacity)
|
|
}
|
|
if c.config.DescriptionCacheCapacity > 0 {
|
|
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
|
|
}
|
|
_, err := c.pgConn.Exec(ctx, "deallocate all").ReadAll()
|
|
return err
|
|
}
|
|
|
|
func (c *Conn) bufferNotifications(_ *pgconn.PgConn, n *pgconn.Notification) {
|
|
c.notifications = append(c.notifications, n)
|
|
}
|
|
|
|
// WaitForNotification waits for a PostgreSQL notification. It wraps the underlying pgconn notification system in a
|
|
// slightly more convenient form.
|
|
func (c *Conn) WaitForNotification(ctx context.Context) (*pgconn.Notification, error) {
|
|
var n *pgconn.Notification
|
|
|
|
// Return already received notification immediately
|
|
if len(c.notifications) > 0 {
|
|
n = c.notifications[0]
|
|
c.notifications = c.notifications[1:]
|
|
return n, nil
|
|
}
|
|
|
|
err := c.pgConn.WaitForNotification(ctx)
|
|
if len(c.notifications) > 0 {
|
|
n = c.notifications[0]
|
|
c.notifications = c.notifications[1:]
|
|
}
|
|
return n, err
|
|
}
|
|
|
|
// IsClosed reports if the connection has been closed.
|
|
func (c *Conn) IsClosed() bool {
|
|
return c.pgConn.IsClosed()
|
|
}
|
|
|
|
func (c *Conn) die(err error) {
|
|
if c.IsClosed() {
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
cancel() // force immediate hard cancel
|
|
c.pgConn.Close(ctx)
|
|
}
|
|
|
|
func quoteIdentifier(s string) string {
|
|
return `"` + strings.ReplaceAll(s, `"`, `""`) + `"`
|
|
}
|
|
|
|
// Ping executes an empty sql statement against the *Conn
|
|
// If the sql returns without error, the database Ping is considered successful, otherwise, the error is returned.
|
|
func (c *Conn) Ping(ctx context.Context) error {
|
|
_, err := c.Exec(ctx, ";")
|
|
return err
|
|
}
|
|
|
|
// PgConn returns the underlying *pgconn.PgConn. This is an escape hatch method that allows lower level access to the
|
|
// PostgreSQL connection than pgx exposes.
|
|
//
|
|
// It is strongly recommended that the connection be idle (no in-progress queries) before the underlying *pgconn.PgConn
|
|
// is used and the connection must be returned to the same state before any *pgx.Conn methods are again used.
|
|
func (c *Conn) PgConn() *pgconn.PgConn { return c.pgConn }
|
|
|
|
// TypeMap returns the connection info used for this connection.
|
|
func (c *Conn) TypeMap() *pgtype.Map { return c.typeMap }
|
|
|
|
// Config returns a copy of config that was used to establish this connection.
|
|
func (c *Conn) Config() *ConnConfig { return c.config.Copy() }
|
|
|
|
// Exec executes sql. sql can be either a prepared statement name or an SQL string. arguments should be referenced
|
|
// positionally from the sql string as $1, $2, etc.
|
|
func (c *Conn) Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error) {
|
|
if c.queryTracer != nil {
|
|
ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: arguments})
|
|
}
|
|
|
|
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
|
|
commandTag, err := c.exec(ctx, sql, arguments...)
|
|
|
|
if c.queryTracer != nil {
|
|
c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{CommandTag: commandTag, Err: err})
|
|
}
|
|
|
|
return commandTag, err
|
|
}
|
|
|
|
func (c *Conn) exec(ctx context.Context, sql string, arguments ...any) (commandTag pgconn.CommandTag, err error) {
|
|
mode := c.config.DefaultQueryExecMode
|
|
var queryRewriter QueryRewriter
|
|
|
|
optionLoop:
|
|
for len(arguments) > 0 {
|
|
switch arg := arguments[0].(type) {
|
|
case QueryExecMode:
|
|
mode = arg
|
|
arguments = arguments[1:]
|
|
case QueryRewriter:
|
|
queryRewriter = arg
|
|
arguments = arguments[1:]
|
|
default:
|
|
break optionLoop
|
|
}
|
|
}
|
|
|
|
if queryRewriter != nil {
|
|
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, fmt.Errorf("rewrite query failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// Always use simple protocol when there are no arguments.
|
|
if len(arguments) == 0 {
|
|
mode = QueryExecModeSimpleProtocol
|
|
}
|
|
|
|
if sd, ok := c.preparedStatements[sql]; ok {
|
|
return c.execPrepared(ctx, sd, arguments)
|
|
}
|
|
|
|
switch mode {
|
|
case QueryExecModeCacheStatement:
|
|
if c.statementCache == nil {
|
|
return pgconn.CommandTag{}, errDisabledStatementCache
|
|
}
|
|
sd := c.statementCache.Get(sql)
|
|
if sd == nil {
|
|
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
c.statementCache.Put(sd)
|
|
}
|
|
|
|
return c.execPrepared(ctx, sd, arguments)
|
|
case QueryExecModeCacheDescribe:
|
|
if c.descriptionCache == nil {
|
|
return pgconn.CommandTag{}, errDisabledDescriptionCache
|
|
}
|
|
sd := c.descriptionCache.Get(sql)
|
|
if sd == nil {
|
|
sd, err = c.Prepare(ctx, "", sql)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
}
|
|
|
|
return c.execParams(ctx, sd, arguments)
|
|
case QueryExecModeDescribeExec:
|
|
sd, err := c.Prepare(ctx, "", sql)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
return c.execPrepared(ctx, sd, arguments)
|
|
case QueryExecModeExec:
|
|
return c.execSQLParams(ctx, sql, arguments)
|
|
case QueryExecModeSimpleProtocol:
|
|
return c.execSimpleProtocol(ctx, sql, arguments)
|
|
default:
|
|
return pgconn.CommandTag{}, fmt.Errorf("unknown QueryExecMode: %v", mode)
|
|
}
|
|
}
|
|
|
|
func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []any) (commandTag pgconn.CommandTag, err error) {
|
|
if len(arguments) > 0 {
|
|
sql, err = c.sanitizeForSimpleQuery(sql, arguments...)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
}
|
|
|
|
mrr := c.pgConn.Exec(ctx, sql)
|
|
for mrr.NextResult() {
|
|
commandTag, err = mrr.ResultReader().Close()
|
|
}
|
|
err = mrr.Close()
|
|
return commandTag, err
|
|
}
|
|
|
|
func (c *Conn) execParams(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
|
|
err := c.eqb.Build(c.typeMap, sd, arguments)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
|
|
result := c.pgConn.ExecParams(ctx, sd.SQL, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
|
|
c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
|
|
return result.CommandTag, result.Err
|
|
}
|
|
|
|
func (c *Conn) execPrepared(ctx context.Context, sd *pgconn.StatementDescription, arguments []any) (pgconn.CommandTag, error) {
|
|
err := c.eqb.Build(c.typeMap, sd, arguments)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
|
|
result := c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
|
|
c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
|
|
return result.CommandTag, result.Err
|
|
}
|
|
|
|
type unknownArgumentTypeQueryExecModeExecError struct {
|
|
arg any
|
|
}
|
|
|
|
func (e *unknownArgumentTypeQueryExecModeExecError) Error() string {
|
|
return fmt.Sprintf("cannot use unregistered type %T as query argument in QueryExecModeExec", e.arg)
|
|
}
|
|
|
|
func (c *Conn) execSQLParams(ctx context.Context, sql string, args []any) (pgconn.CommandTag, error) {
|
|
err := c.eqb.Build(c.typeMap, nil, args)
|
|
if err != nil {
|
|
return pgconn.CommandTag{}, err
|
|
}
|
|
|
|
result := c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats).Read()
|
|
c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
|
|
return result.CommandTag, result.Err
|
|
}
|
|
|
|
func (c *Conn) getRows(ctx context.Context, sql string, args []any) *baseRows {
|
|
r := &baseRows{}
|
|
|
|
r.ctx = ctx
|
|
r.queryTracer = c.queryTracer
|
|
r.typeMap = c.typeMap
|
|
r.startTime = time.Now()
|
|
r.sql = sql
|
|
r.args = args
|
|
r.conn = c
|
|
|
|
return r
|
|
}
|
|
|
|
type QueryExecMode int32
|
|
|
|
const (
|
|
_ QueryExecMode = iota
|
|
|
|
// Automatically prepare and cache statements. This uses the extended protocol. Queries are executed in a single
|
|
// round trip after the statement is cached. This is the default.
|
|
QueryExecModeCacheStatement
|
|
|
|
// Cache statement descriptions (i.e. argument and result types) and assume they do not change. This uses the
|
|
// extended protocol. Queries are executed in a single round trip after the description is cached. If the database
|
|
// schema is modified or the search_path is changed this may result in undetected result decoding errors.
|
|
QueryExecModeCacheDescribe
|
|
|
|
// Get the statement description on every execution. This uses the extended protocol. Queries require two round trips
|
|
// to execute. It does not use prepared statements (allowing usage with most connection poolers) and is safe even
|
|
// when the the database schema is modified concurrently.
|
|
QueryExecModeDescribeExec
|
|
|
|
// Assume the PostgreSQL query parameter types based on the Go type of the arguments. This uses the extended protocol
|
|
// with text formatted parameters and results. Queries are executed in a single round trip. Type mappings can be
|
|
// registered with pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are
|
|
// unregistered or ambigious. e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know
|
|
// the PostgreSQL type can use a map[string]string directly as an argument. This mode cannot.
|
|
QueryExecModeExec
|
|
|
|
// Use the simple protocol. Assume the PostgreSQL query parameter types based on the Go type of the arguments.
|
|
// Queries are executed in a single round trip. Type mappings can be registered with
|
|
// pgtype.Map.RegisterDefaultPgType. Queries will be rejected that have arguments that are unregistered or ambigious.
|
|
// e.g. A map[string]string may have the PostgreSQL type json or hstore. Modes that know the PostgreSQL type can use
|
|
// a map[string]string directly as an argument. This mode cannot.
|
|
//
|
|
// QueryExecModeSimpleProtocol should have the user application visible behavior as QueryExecModeExec with minor
|
|
// exceptions such as behavior when multiple result returning queries are erroneously sent in a single string.
|
|
//
|
|
// QueryExecModeSimpleProtocol uses client side parameter interpolation. All values are quoted and escaped. Prefer
|
|
// QueryExecModeExec over QueryExecModeSimpleProtocol whenever possible. In general QueryExecModeSimpleProtocol
|
|
// should only be used if connecting to a proxy server, connection pool server, or non-PostgreSQL server that does
|
|
// not support the extended protocol.
|
|
QueryExecModeSimpleProtocol
|
|
)
|
|
|
|
func (m QueryExecMode) String() string {
|
|
switch m {
|
|
case QueryExecModeCacheStatement:
|
|
return "cache statement"
|
|
case QueryExecModeCacheDescribe:
|
|
return "cache describe"
|
|
case QueryExecModeDescribeExec:
|
|
return "describe exec"
|
|
case QueryExecModeExec:
|
|
return "exec"
|
|
case QueryExecModeSimpleProtocol:
|
|
return "simple protocol"
|
|
default:
|
|
return "invalid"
|
|
}
|
|
}
|
|
|
|
// QueryResultFormats controls the result format (text=0, binary=1) of a query by result column position.
|
|
type QueryResultFormats []int16
|
|
|
|
// QueryResultFormatsByOID controls the result format (text=0, binary=1) of a query by the result column OID.
|
|
type QueryResultFormatsByOID map[uint32]int16
|
|
|
|
// QueryRewriter rewrites a query when used as the first arguments to a query method.
|
|
type QueryRewriter interface {
|
|
RewriteQuery(ctx context.Context, conn *Conn, sql string, args []any) (newSQL string, newArgs []any, err error)
|
|
}
|
|
|
|
// Query sends a query to the server and returns a Rows to read the results. Only errors encountered sending the query
|
|
// and initializing Rows will be returned. Err() on the returned Rows must be checked after the Rows is closed to
|
|
// determine if the query executed successfully.
|
|
//
|
|
// The returned Rows must be closed before the connection can be used again. It is safe to attempt to read from the
|
|
// returned Rows even if an error is returned. The error will be the available in rows.Err() after rows are closed. It
|
|
// is allowed to ignore the error returned from Query and handle it in Rows.
|
|
//
|
|
// It is possible for a query to return one or more rows before encountering an error. In most cases the rows should be
|
|
// collected before processing rather than processed while receiving each row. This avoids the possibility of the
|
|
// application processing rows from a query that the server rejected. The CollectRows function is useful here.
|
|
//
|
|
// An implementor of QueryRewriter may be passed as the first element of args. It can rewrite the sql and change or
|
|
// replace args. For example, NamedArgs is QueryRewriter that implements named arguments.
|
|
//
|
|
// For extra control over how the query is executed, the types QueryExecMode, QueryResultFormats, and
|
|
// QueryResultFormatsByOID may be used as the first args to control exactly how the query is executed. This is rarely
|
|
// needed. See the documentation for those types for details.
|
|
func (c *Conn) Query(ctx context.Context, sql string, args ...any) (Rows, error) {
|
|
if c.queryTracer != nil {
|
|
ctx = c.queryTracer.TraceQueryStart(ctx, c, TraceQueryStartData{SQL: sql, Args: args})
|
|
}
|
|
|
|
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
|
if c.queryTracer != nil {
|
|
c.queryTracer.TraceQueryEnd(ctx, c, TraceQueryEndData{Err: err})
|
|
}
|
|
return &baseRows{err: err, closed: true}, err
|
|
}
|
|
|
|
var resultFormats QueryResultFormats
|
|
var resultFormatsByOID QueryResultFormatsByOID
|
|
mode := c.config.DefaultQueryExecMode
|
|
var queryRewriter QueryRewriter
|
|
|
|
optionLoop:
|
|
for len(args) > 0 {
|
|
switch arg := args[0].(type) {
|
|
case QueryResultFormats:
|
|
resultFormats = arg
|
|
args = args[1:]
|
|
case QueryResultFormatsByOID:
|
|
resultFormatsByOID = arg
|
|
args = args[1:]
|
|
case QueryExecMode:
|
|
mode = arg
|
|
args = args[1:]
|
|
case QueryRewriter:
|
|
queryRewriter = arg
|
|
args = args[1:]
|
|
default:
|
|
break optionLoop
|
|
}
|
|
}
|
|
|
|
if queryRewriter != nil {
|
|
var err error
|
|
originalSQL := sql
|
|
originalArgs := args
|
|
sql, args, err = queryRewriter.RewriteQuery(ctx, c, sql, args)
|
|
if err != nil {
|
|
rows := c.getRows(ctx, originalSQL, originalArgs)
|
|
err = fmt.Errorf("rewrite query failed: %v", err)
|
|
rows.fatal(err)
|
|
return rows, err
|
|
}
|
|
}
|
|
|
|
// Bypass any statement caching.
|
|
if sql == "" {
|
|
mode = QueryExecModeSimpleProtocol
|
|
}
|
|
|
|
c.eqb.reset()
|
|
anynil.NormalizeSlice(args)
|
|
rows := c.getRows(ctx, sql, args)
|
|
|
|
var err error
|
|
sd, explicitPreparedStatement := c.preparedStatements[sql]
|
|
if sd != nil || mode == QueryExecModeCacheStatement || mode == QueryExecModeCacheDescribe || mode == QueryExecModeDescribeExec {
|
|
if sd == nil {
|
|
sd, err = c.getStatementDescription(ctx, mode, sql)
|
|
if err != nil {
|
|
rows.fatal(err)
|
|
return rows, err
|
|
}
|
|
}
|
|
|
|
if len(sd.ParamOIDs) != len(args) {
|
|
rows.fatal(fmt.Errorf("expected %d arguments, got %d", len(sd.ParamOIDs), len(args)))
|
|
return rows, rows.err
|
|
}
|
|
|
|
rows.sql = sd.SQL
|
|
|
|
err = c.eqb.Build(c.typeMap, sd, args)
|
|
if err != nil {
|
|
rows.fatal(err)
|
|
return rows, rows.err
|
|
}
|
|
|
|
if resultFormatsByOID != nil {
|
|
resultFormats = make([]int16, len(sd.Fields))
|
|
for i := range resultFormats {
|
|
resultFormats[i] = resultFormatsByOID[uint32(sd.Fields[i].DataTypeOID)]
|
|
}
|
|
}
|
|
|
|
if resultFormats == nil {
|
|
resultFormats = c.eqb.ResultFormats
|
|
}
|
|
|
|
if !explicitPreparedStatement && mode == QueryExecModeCacheDescribe {
|
|
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, sd.ParamOIDs, c.eqb.ParamFormats, resultFormats)
|
|
} else {
|
|
rows.resultReader = c.pgConn.ExecPrepared(ctx, sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, resultFormats)
|
|
}
|
|
} else if mode == QueryExecModeExec {
|
|
err := c.eqb.Build(c.typeMap, nil, args)
|
|
if err != nil {
|
|
rows.fatal(err)
|
|
return rows, rows.err
|
|
}
|
|
|
|
rows.resultReader = c.pgConn.ExecParams(ctx, sql, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
|
} else if mode == QueryExecModeSimpleProtocol {
|
|
sql, err = c.sanitizeForSimpleQuery(sql, args...)
|
|
if err != nil {
|
|
rows.fatal(err)
|
|
return rows, err
|
|
}
|
|
|
|
mrr := c.pgConn.Exec(ctx, sql)
|
|
if mrr.NextResult() {
|
|
rows.resultReader = mrr.ResultReader()
|
|
rows.multiResultReader = mrr
|
|
} else {
|
|
err = mrr.Close()
|
|
rows.fatal(err)
|
|
return rows, err
|
|
}
|
|
|
|
return rows, nil
|
|
} else {
|
|
err = fmt.Errorf("unknown QueryExecMode: %v", mode)
|
|
rows.fatal(err)
|
|
return rows, rows.err
|
|
}
|
|
|
|
c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
|
|
|
|
return rows, rows.err
|
|
}
|
|
|
|
// getStatementDescription returns the statement description of the sql query
|
|
// according to the given mode.
|
|
//
|
|
// If the mode is one that doesn't require to know the param and result OIDs
|
|
// then nil is returned without error.
|
|
func (c *Conn) getStatementDescription(
|
|
ctx context.Context,
|
|
mode QueryExecMode,
|
|
sql string,
|
|
) (sd *pgconn.StatementDescription, err error) {
|
|
|
|
switch mode {
|
|
case QueryExecModeCacheStatement:
|
|
if c.statementCache == nil {
|
|
return nil, errDisabledStatementCache
|
|
}
|
|
sd = c.statementCache.Get(sql)
|
|
if sd == nil {
|
|
sd, err = c.Prepare(ctx, stmtcache.NextStatementName(), sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.statementCache.Put(sd)
|
|
}
|
|
case QueryExecModeCacheDescribe:
|
|
if c.descriptionCache == nil {
|
|
return nil, errDisabledDescriptionCache
|
|
}
|
|
sd = c.descriptionCache.Get(sql)
|
|
if sd == nil {
|
|
sd, err = c.Prepare(ctx, "", sql)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.descriptionCache.Put(sd)
|
|
}
|
|
case QueryExecModeDescribeExec:
|
|
return c.Prepare(ctx, "", sql)
|
|
}
|
|
return sd, err
|
|
}
|
|
|
|
// QueryRow is a convenience wrapper over Query. Any error that occurs while
|
|
// querying is deferred until calling Scan on the returned Row. That Row will
|
|
// error with ErrNoRows if no rows are returned.
|
|
func (c *Conn) QueryRow(ctx context.Context, sql string, args ...any) Row {
|
|
rows, _ := c.Query(ctx, sql, args...)
|
|
return (*connRow)(rows.(*baseRows))
|
|
}
|
|
|
|
// SendBatch sends all queued queries to the server at once. All queries are run in an implicit transaction unless
|
|
// explicit transaction control statements are executed. The returned BatchResults must be closed before the connection
|
|
// is used again.
|
|
func (c *Conn) SendBatch(ctx context.Context, b *Batch) (br BatchResults) {
|
|
if c.batchTracer != nil {
|
|
ctx = c.batchTracer.TraceBatchStart(ctx, c, TraceBatchStartData{Batch: b})
|
|
defer func() {
|
|
err := br.(interface{ earlyError() error }).earlyError()
|
|
if err != nil {
|
|
c.batchTracer.TraceBatchEnd(ctx, c, TraceBatchEndData{Err: err})
|
|
}
|
|
}()
|
|
}
|
|
|
|
if err := c.deallocateInvalidatedCachedStatements(ctx); err != nil {
|
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
|
|
mode := c.config.DefaultQueryExecMode
|
|
|
|
for _, bi := range b.queuedQueries {
|
|
var queryRewriter QueryRewriter
|
|
sql := bi.query
|
|
arguments := bi.arguments
|
|
|
|
optionLoop:
|
|
for len(arguments) > 0 {
|
|
switch arg := arguments[0].(type) {
|
|
case QueryRewriter:
|
|
queryRewriter = arg
|
|
arguments = arguments[1:]
|
|
default:
|
|
break optionLoop
|
|
}
|
|
}
|
|
|
|
if queryRewriter != nil {
|
|
var err error
|
|
sql, arguments, err = queryRewriter.RewriteQuery(ctx, c, sql, arguments)
|
|
if err != nil {
|
|
return &batchResults{ctx: ctx, conn: c, err: fmt.Errorf("rewrite query failed: %v", err)}
|
|
}
|
|
}
|
|
|
|
bi.query = sql
|
|
bi.arguments = arguments
|
|
}
|
|
|
|
if mode == QueryExecModeSimpleProtocol {
|
|
return c.sendBatchQueryExecModeSimpleProtocol(ctx, b)
|
|
}
|
|
|
|
// All other modes use extended protocol and thus can use prepared statements.
|
|
for _, bi := range b.queuedQueries {
|
|
if sd, ok := c.preparedStatements[bi.query]; ok {
|
|
bi.sd = sd
|
|
}
|
|
}
|
|
|
|
switch mode {
|
|
case QueryExecModeExec:
|
|
return c.sendBatchQueryExecModeExec(ctx, b)
|
|
case QueryExecModeCacheStatement:
|
|
return c.sendBatchQueryExecModeCacheStatement(ctx, b)
|
|
case QueryExecModeCacheDescribe:
|
|
return c.sendBatchQueryExecModeCacheDescribe(ctx, b)
|
|
case QueryExecModeDescribeExec:
|
|
return c.sendBatchQueryExecModeDescribeExec(ctx, b)
|
|
default:
|
|
panic("unknown QueryExecMode")
|
|
}
|
|
}
|
|
|
|
func (c *Conn) sendBatchQueryExecModeSimpleProtocol(ctx context.Context, b *Batch) *batchResults {
|
|
var sb strings.Builder
|
|
for i, bi := range b.queuedQueries {
|
|
if i > 0 {
|
|
sb.WriteByte(';')
|
|
}
|
|
sql, err := c.sanitizeForSimpleQuery(bi.query, bi.arguments...)
|
|
if err != nil {
|
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
sb.WriteString(sql)
|
|
}
|
|
mrr := c.pgConn.Exec(ctx, sb.String())
|
|
return &batchResults{
|
|
ctx: ctx,
|
|
conn: c,
|
|
mrr: mrr,
|
|
b: b,
|
|
qqIdx: 0,
|
|
}
|
|
}
|
|
|
|
func (c *Conn) sendBatchQueryExecModeExec(ctx context.Context, b *Batch) *batchResults {
|
|
batch := &pgconn.Batch{}
|
|
|
|
for _, bi := range b.queuedQueries {
|
|
sd := bi.sd
|
|
if sd != nil {
|
|
err := c.eqb.Build(c.typeMap, sd, bi.arguments)
|
|
if err != nil {
|
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
|
|
batch.ExecPrepared(sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
|
} else {
|
|
err := c.eqb.Build(c.typeMap, nil, bi.arguments)
|
|
if err != nil {
|
|
return &batchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
batch.ExecParams(bi.query, c.eqb.ParamValues, nil, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
|
}
|
|
}
|
|
|
|
c.eqb.reset() // Allow c.eqb internal memory to be GC'ed as soon as possible.
|
|
|
|
mrr := c.pgConn.ExecBatch(ctx, batch)
|
|
|
|
return &batchResults{
|
|
ctx: ctx,
|
|
conn: c,
|
|
mrr: mrr,
|
|
b: b,
|
|
qqIdx: 0,
|
|
}
|
|
}
|
|
|
|
func (c *Conn) sendBatchQueryExecModeCacheStatement(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
|
if c.statementCache == nil {
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledStatementCache}
|
|
}
|
|
|
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
|
distinctNewQueriesIdxMap := make(map[string]int)
|
|
|
|
for _, bi := range b.queuedQueries {
|
|
if bi.sd == nil {
|
|
sd := c.statementCache.Get(bi.query)
|
|
if sd != nil {
|
|
bi.sd = sd
|
|
} else {
|
|
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
|
bi.sd = distinctNewQueries[idx]
|
|
} else {
|
|
sd = &pgconn.StatementDescription{
|
|
Name: stmtcache.NextStatementName(),
|
|
SQL: bi.query,
|
|
}
|
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
|
distinctNewQueries = append(distinctNewQueries, sd)
|
|
bi.sd = sd
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.statementCache)
|
|
}
|
|
|
|
func (c *Conn) sendBatchQueryExecModeCacheDescribe(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
|
if c.descriptionCache == nil {
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: errDisabledDescriptionCache}
|
|
}
|
|
|
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
|
distinctNewQueriesIdxMap := make(map[string]int)
|
|
|
|
for _, bi := range b.queuedQueries {
|
|
if bi.sd == nil {
|
|
sd := c.descriptionCache.Get(bi.query)
|
|
if sd != nil {
|
|
bi.sd = sd
|
|
} else {
|
|
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
|
bi.sd = distinctNewQueries[idx]
|
|
} else {
|
|
sd = &pgconn.StatementDescription{
|
|
SQL: bi.query,
|
|
}
|
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
|
distinctNewQueries = append(distinctNewQueries, sd)
|
|
bi.sd = sd
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, c.descriptionCache)
|
|
}
|
|
|
|
func (c *Conn) sendBatchQueryExecModeDescribeExec(ctx context.Context, b *Batch) (pbr *pipelineBatchResults) {
|
|
distinctNewQueries := []*pgconn.StatementDescription{}
|
|
distinctNewQueriesIdxMap := make(map[string]int)
|
|
|
|
for _, bi := range b.queuedQueries {
|
|
if bi.sd == nil {
|
|
if idx, present := distinctNewQueriesIdxMap[bi.query]; present {
|
|
bi.sd = distinctNewQueries[idx]
|
|
} else {
|
|
sd := &pgconn.StatementDescription{
|
|
SQL: bi.query,
|
|
}
|
|
distinctNewQueriesIdxMap[sd.SQL] = len(distinctNewQueries)
|
|
distinctNewQueries = append(distinctNewQueries, sd)
|
|
bi.sd = sd
|
|
}
|
|
}
|
|
}
|
|
|
|
return c.sendBatchExtendedWithDescription(ctx, b, distinctNewQueries, nil)
|
|
}
|
|
|
|
func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, distinctNewQueries []*pgconn.StatementDescription, sdCache stmtcache.Cache) (pbr *pipelineBatchResults) {
|
|
pipeline := c.pgConn.StartPipeline(context.Background())
|
|
defer func() {
|
|
if pbr.err != nil {
|
|
pipeline.Close()
|
|
}
|
|
}()
|
|
|
|
// Prepare any needed queries
|
|
if len(distinctNewQueries) > 0 {
|
|
for _, sd := range distinctNewQueries {
|
|
pipeline.SendPrepare(sd.Name, sd.SQL, nil)
|
|
}
|
|
|
|
err := pipeline.Sync()
|
|
if err != nil {
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
|
|
for _, sd := range distinctNewQueries {
|
|
results, err := pipeline.GetResults()
|
|
if err != nil {
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
|
|
resultSD, ok := results.(*pgconn.StatementDescription)
|
|
if !ok {
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected statement description, got %T", results)}
|
|
}
|
|
|
|
// Fill in the previously empty / pending statement descriptions.
|
|
sd.ParamOIDs = resultSD.ParamOIDs
|
|
sd.Fields = resultSD.Fields
|
|
}
|
|
|
|
results, err := pipeline.GetResults()
|
|
if err != nil {
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
|
|
_, ok := results.(*pgconn.PipelineSync)
|
|
if !ok {
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: fmt.Errorf("expected sync, got %T", results)}
|
|
}
|
|
}
|
|
|
|
// Put all statements into the cache. It's fine if it overflows because HandleInvalidated will clean them up later.
|
|
if sdCache != nil {
|
|
for _, sd := range distinctNewQueries {
|
|
sdCache.Put(sd)
|
|
}
|
|
}
|
|
|
|
// Queue the queries.
|
|
for _, bi := range b.queuedQueries {
|
|
err := c.eqb.Build(c.typeMap, bi.sd, bi.arguments)
|
|
if err != nil {
|
|
// we wrap the error so we the user can understand which query failed inside the batch
|
|
err = fmt.Errorf("error building query %s: %w", bi.query, err)
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
|
|
if bi.sd.Name == "" {
|
|
pipeline.SendQueryParams(bi.sd.SQL, c.eqb.ParamValues, bi.sd.ParamOIDs, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
|
} else {
|
|
pipeline.SendQueryPrepared(bi.sd.Name, c.eqb.ParamValues, c.eqb.ParamFormats, c.eqb.ResultFormats)
|
|
}
|
|
}
|
|
|
|
err := pipeline.Sync()
|
|
if err != nil {
|
|
return &pipelineBatchResults{ctx: ctx, conn: c, err: err}
|
|
}
|
|
|
|
return &pipelineBatchResults{
|
|
ctx: ctx,
|
|
conn: c,
|
|
pipeline: pipeline,
|
|
b: b,
|
|
}
|
|
}
|
|
|
|
func (c *Conn) sanitizeForSimpleQuery(sql string, args ...any) (string, error) {
|
|
if c.pgConn.ParameterStatus("standard_conforming_strings") != "on" {
|
|
return "", errors.New("simple protocol queries must be run with standard_conforming_strings=on")
|
|
}
|
|
|
|
if c.pgConn.ParameterStatus("client_encoding") != "UTF8" {
|
|
return "", errors.New("simple protocol queries must be run with client_encoding=UTF8")
|
|
}
|
|
|
|
var err error
|
|
valueArgs := make([]any, len(args))
|
|
for i, a := range args {
|
|
valueArgs[i], err = convertSimpleArgument(c.typeMap, a)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
return sanitize.SanitizeSQL(sql, valueArgs...)
|
|
}
|
|
|
|
// LoadType inspects the database for typeName and produces a pgtype.Type suitable for registration.
|
|
func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, error) {
|
|
var oid uint32
|
|
|
|
err := c.QueryRow(ctx, "select $1::text::regtype::oid;", typeName).Scan(&oid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var typtype string
|
|
var typbasetype uint32
|
|
|
|
err = c.QueryRow(ctx, "select typtype::text, typbasetype from pg_type where oid=$1", oid).Scan(&typtype, &typbasetype)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
switch typtype {
|
|
case "b": // array
|
|
elementOID, err := c.getArrayElementOID(ctx, oid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dt, ok := c.TypeMap().TypeForOID(elementOID)
|
|
if !ok {
|
|
return nil, errors.New("array element OID not registered")
|
|
}
|
|
|
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}, nil
|
|
case "c": // composite
|
|
fields, err := c.getCompositeFields(ctx, oid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}}, nil
|
|
case "d": // domain
|
|
dt, ok := c.TypeMap().TypeForOID(typbasetype)
|
|
if !ok {
|
|
return nil, errors.New("domain base type OID not registered")
|
|
}
|
|
|
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil
|
|
case "e": // enum
|
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil
|
|
case "r": // range
|
|
elementOID, err := c.getRangeElementOID(ctx, oid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dt, ok := c.TypeMap().TypeForOID(elementOID)
|
|
if !ok {
|
|
return nil, errors.New("range element OID not registered")
|
|
}
|
|
|
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil
|
|
case "m": // multirange
|
|
elementOID, err := c.getMultiRangeElementOID(ctx, oid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
dt, ok := c.TypeMap().TypeForOID(elementOID)
|
|
if !ok {
|
|
return nil, errors.New("multirange element OID not registered")
|
|
}
|
|
|
|
return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil
|
|
default:
|
|
return &pgtype.Type{}, errors.New("unknown typtype")
|
|
}
|
|
}
|
|
|
|
func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, error) {
|
|
var typelem uint32
|
|
|
|
err := c.QueryRow(ctx, "select typelem from pg_type where oid=$1", oid).Scan(&typelem)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return typelem, nil
|
|
}
|
|
|
|
func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
|
|
var typelem uint32
|
|
|
|
err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return typelem, nil
|
|
}
|
|
|
|
func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) {
|
|
var typelem uint32
|
|
|
|
err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return typelem, nil
|
|
}
|
|
|
|
func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) {
|
|
var typrelid uint32
|
|
|
|
err := c.QueryRow(ctx, "select typrelid from pg_type where oid=$1", oid).Scan(&typrelid)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var fields []pgtype.CompositeCodecField
|
|
var fieldName string
|
|
var fieldOID uint32
|
|
rows, _ := c.Query(ctx, `select attname, atttypid
|
|
from pg_attribute
|
|
where attrelid=$1
|
|
order by attnum`,
|
|
typrelid,
|
|
)
|
|
_, err = ForEachRow(rows, []any{&fieldName, &fieldOID}, func() error {
|
|
dt, ok := c.TypeMap().TypeForOID(fieldOID)
|
|
if !ok {
|
|
return fmt.Errorf("unknown composite type field OID: %v", fieldOID)
|
|
}
|
|
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return fields, nil
|
|
}
|
|
|
|
func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error {
|
|
if c.pgConn.TxStatus() != 'I' {
|
|
return nil
|
|
}
|
|
|
|
if c.descriptionCache != nil {
|
|
c.descriptionCache.HandleInvalidated()
|
|
}
|
|
|
|
var invalidatedStatements []*pgconn.StatementDescription
|
|
if c.statementCache != nil {
|
|
invalidatedStatements = c.statementCache.HandleInvalidated()
|
|
}
|
|
|
|
if len(invalidatedStatements) == 0 {
|
|
return nil
|
|
}
|
|
|
|
pipeline := c.pgConn.StartPipeline(ctx)
|
|
defer pipeline.Close()
|
|
|
|
for _, sd := range invalidatedStatements {
|
|
pipeline.SendDeallocate(sd.Name)
|
|
}
|
|
|
|
err := pipeline.Sync()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
|
}
|
|
|
|
err = pipeline.Close()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to deallocate cached statement(s): %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|