package pgconn

import (
	"bytes"
	"context"
	"crypto/md5"
	"crypto/tls"
	"encoding/binary"
	"encoding/hex"
	"io"
	"math"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgconn/internal/ctxwatch"
	"github.com/jackc/pgio"
	"github.com/jackc/pgproto3/v2"
	errors "golang.org/x/xerrors"
)

// Connection states stored in PgConn.status. The numeric order is significant: IsAlive reports true for any status
// that is >= connStatusIdle, so the "alive" states must stay after connStatusClosed.
const (
	connStatusUninitialized = iota
	connStatusConnecting
	connStatusClosed
	connStatusIdle
	connStatusBusy
)

// Notice represents a notice response message reported by the PostgreSQL server. Be aware that this is distinct from
// LISTEN/NOTIFY notification.
type Notice PgError

// Notification is a message received from the PostgreSQL LISTEN/NOTIFY system.
type Notification struct {
	PID     uint32 // backend pid that sent the notification
	Channel string // channel from which notification was received
	Payload string // payload copied from the server's NotificationResponse message
}

// DialFunc is a function that can be used to connect to a PostgreSQL server.
type DialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// NoticeHandler is a function that can handle notices received from the PostgreSQL server. Notices can be received at
// any time, usually during handling of a query response. The *PgConn is provided so the handler is aware of the origin
// of the notice, but it must not invoke any query method. Be aware that this is distinct from LISTEN/NOTIFY
// notification.
type NoticeHandler func(*PgConn, *Notice)

// NotificationHandler is a function that can handle notifications received from the PostgreSQL server. Notifications
// can be received at any time, usually during handling of a query response. The *PgConn is provided so the handler is
// aware of the origin of the notification, but it must not invoke any query method. Be aware that this is distinct
// from a notice event.
type NotificationHandler func(*PgConn, *Notification)

// PgConn is a low-level PostgreSQL connection handle. It is not safe for concurrent usage.
type PgConn struct { conn net.Conn // the underlying TCP or unix domain socket connection pid uint32 // backend pid secretKey uint32 // key to use to send a cancel query message to the server parameterStatuses map[string]string // parameters that have been reported by the server TxStatus byte Frontend *pgproto3.Frontend Config *Config status byte // One of connStatus* constants bufferingReceive bool bufferingReceiveMux sync.Mutex bufferingReceiveMsg pgproto3.BackendMessage bufferingReceiveErr error // Reusable / preallocated resources wbuf []byte // write buffer resultReader ResultReader multiResultReader MultiResultReader contextWatcher *ctxwatch.ContextWatcher } // Connect establishes a connection to a PostgreSQL server using the environment and connString (in URL or DSN format) // to provide configuration. See documention for ParseConfig for details. ctx can be used to cancel a connect attempt. func Connect(ctx context.Context, connString string) (*PgConn, error) { config, err := ParseConfig(connString) if err != nil { return nil, err } return ConnectConfig(ctx, config) } // Connect establishes a connection to a PostgreSQL server using config. ctx can be used to cancel a connect attempt. // // If config.Fallbacks are present they will sequentially be tried in case of error establishing network connection. An // authentication error will terminate the chain of attempts (like libpq: // https://www.postgresql.org/docs/11/libpq-connect.html#LIBPQ-MULTIPLE-HOSTS) and be returned as the error. Otherwise, // if all attempts fail the last error is returned. func ConnectConfig(ctx context.Context, config *Config) (pgConn *PgConn, err error) { // For convenience set a few defaults if not already set. This makes it simpler to directly construct a config. 
if config.Port == 0 { config.Port = 5432 } if config.DialFunc == nil { config.DialFunc = makeDefaultDialer().DialContext } if config.RuntimeParams == nil { config.RuntimeParams = make(map[string]string) } // Simplify usage by treating primary config and fallbacks the same. fallbackConfigs := []*FallbackConfig{ { Host: config.Host, Port: config.Port, TLSConfig: config.TLSConfig, }, } fallbackConfigs = append(fallbackConfigs, config.Fallbacks...) for _, fc := range fallbackConfigs { pgConn, err = connect(ctx, config, fc) if err == nil { break } else if err, ok := err.(*PgError); ok { return nil, err } } if err != nil { return nil, err } if config.AfterConnect != nil { err := config.AfterConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() return nil, errors.Errorf("AfterConnect: %v", err) } } return pgConn, nil } func connect(ctx context.Context, config *Config, fallbackConfig *FallbackConfig) (*PgConn, error) { pgConn := new(PgConn) pgConn.Config = config pgConn.wbuf = make([]byte, 0, 1024) var err error network, address := NetworkAddress(fallbackConfig.Host, fallbackConfig.Port) pgConn.conn, err = config.DialFunc(ctx, network, address) if err != nil { return nil, err } pgConn.parameterStatuses = make(map[string]string) if fallbackConfig.TLSConfig != nil { if err := pgConn.startTLS(fallbackConfig.TLSConfig); err != nil { pgConn.conn.Close() return nil, err } } pgConn.status = connStatusConnecting pgConn.contextWatcher = ctxwatch.NewContextWatcher( func() { pgConn.conn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { pgConn.conn.SetDeadline(time.Time{}) }, ) pgConn.Frontend, err = pgproto3.NewFrontend(pgproto3.NewChunkReader(pgConn.conn), pgConn.conn) if err != nil { return nil, err } startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: make(map[string]string), } // Copy default run-time params for k, v := range config.RuntimeParams { startupMsg.Parameters[k] = v } startupMsg.Parameters["user"] = 
config.User if config.Database != "" { startupMsg.Parameters["database"] = config.Database } if _, err := pgConn.conn.Write(startupMsg.Encode(pgConn.wbuf)); err != nil { pgConn.conn.Close() return nil, err } for { msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.conn.Close() return nil, err } switch msg := msg.(type) { case *pgproto3.BackendKeyData: pgConn.pid = msg.ProcessID pgConn.secretKey = msg.SecretKey case *pgproto3.Authentication: if err = pgConn.rxAuthenticationX(msg); err != nil { pgConn.conn.Close() return nil, err } case *pgproto3.ReadyForQuery: pgConn.status = connStatusIdle if config.ValidateConnect != nil { err := config.ValidateConnect(ctx, pgConn) if err != nil { pgConn.conn.Close() return nil, errors.Errorf("ValidateConnect: %v", err) } } return pgConn, nil case *pgproto3.ParameterStatus: // handled by ReceiveMessage case *pgproto3.ErrorResponse: pgConn.conn.Close() return nil, errorResponseToPgError(msg) default: pgConn.conn.Close() return nil, errors.New("unexpected message") } } } func (pgConn *PgConn) startTLS(tlsConfig *tls.Config) (err error) { err = binary.Write(pgConn.conn, binary.BigEndian, []int32{8, 80877103}) if err != nil { return } response := make([]byte, 1) if _, err = io.ReadFull(pgConn.conn, response); err != nil { return } if response[0] != 'S' { return ErrTLSRefused } pgConn.conn = tls.Client(pgConn.conn, tlsConfig) return nil } func (pgConn *PgConn) rxAuthenticationX(msg *pgproto3.Authentication) (err error) { switch msg.Type { case pgproto3.AuthTypeOk: case pgproto3.AuthTypeCleartextPassword: err = pgConn.txPasswordMessage(pgConn.Config.Password) case pgproto3.AuthTypeMD5Password: digestedPassword := "md5" + hexMD5(hexMD5(pgConn.Config.Password+pgConn.Config.User)+string(msg.Salt[:])) err = pgConn.txPasswordMessage(digestedPassword) case pgproto3.AuthTypeSASL: err = pgConn.scramAuth(msg.SASLAuthMechanisms) default: err = errors.New("Received unknown authentication message") } return } func (pgConn *PgConn) 
txPasswordMessage(password string) (err error) { msg := &pgproto3.PasswordMessage{Password: password} _, err = pgConn.conn.Write(msg.Encode(pgConn.wbuf)) return err } func hexMD5(s string) string { hash := md5.New() io.WriteString(hash, s) return hex.EncodeToString(hash.Sum(nil)) } func (pgConn *PgConn) signalMessage() chan struct{} { if pgConn.bufferingReceive { panic("BUG: signalMessage when already in progress") } pgConn.bufferingReceive = true pgConn.bufferingReceiveMux.Lock() ch := make(chan struct{}) go func() { pgConn.bufferingReceiveMsg, pgConn.bufferingReceiveErr = pgConn.Frontend.Receive() pgConn.bufferingReceiveMux.Unlock() close(ch) }() return ch } func (pgConn *PgConn) ReceiveMessage() (pgproto3.BackendMessage, error) { var msg pgproto3.BackendMessage var err error if pgConn.bufferingReceive { pgConn.bufferingReceiveMux.Lock() msg = pgConn.bufferingReceiveMsg err = pgConn.bufferingReceiveErr pgConn.bufferingReceiveMux.Unlock() pgConn.bufferingReceive = false // If a timeout error happened in the background try the read again. 
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { msg, err = pgConn.Frontend.Receive() } } else { msg, err = pgConn.Frontend.Receive() } if err != nil { // Close on anything other than timeout error - everything else is fatal if err, ok := err.(net.Error); !(ok && err.Timeout()) { pgConn.hardClose() } return nil, err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: pgConn.TxStatus = msg.TxStatus case *pgproto3.ParameterStatus: pgConn.parameterStatuses[msg.Name] = msg.Value case *pgproto3.ErrorResponse: if msg.Severity == "FATAL" { pgConn.hardClose() return nil, errorResponseToPgError(msg) } case *pgproto3.NoticeResponse: if pgConn.Config.OnNotice != nil { pgConn.Config.OnNotice(pgConn, noticeResponseToNotice(msg)) } case *pgproto3.NotificationResponse: if pgConn.Config.OnNotification != nil { pgConn.Config.OnNotification(pgConn, &Notification{PID: msg.PID, Channel: msg.Channel, Payload: msg.Payload}) } } return msg, nil } // Conn returns the underlying net.Conn. func (pgConn *PgConn) Conn() net.Conn { return pgConn.conn } // PID returns the backend PID. func (pgConn *PgConn) PID() uint32 { return pgConn.pid } // SecretKey returns the backend secret key used to send a cancel query message to the server. func (pgConn *PgConn) SecretKey() uint32 { return pgConn.secretKey } // Close closes a connection. It is safe to call Close on a already closed connection. Close attempts a clean close by // sending the exit message to PostgreSQL. However, this could block so ctx is available to limit the time to wait. The // underlying net.Conn.Close() will always be called regardless of any other errors. 
func (pgConn *PgConn) Close(ctx context.Context) error { if pgConn.status == connStatusClosed { return nil } pgConn.status = connStatusClosed defer pgConn.conn.Close() pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() _, err := pgConn.conn.Write([]byte{'X', 0, 0, 0, 4}) if err != nil { return linkErrors(ctx.Err(), err) } _, err = pgConn.conn.Read(make([]byte, 1)) if err != io.EOF { return linkErrors(ctx.Err(), err) } return pgConn.conn.Close() } // hardClose closes the underlying connection without sending the exit message. func (pgConn *PgConn) hardClose() error { if pgConn.status == connStatusClosed { return nil } pgConn.status = connStatusClosed return pgConn.conn.Close() } // TODO - rethink how to report status. At the moment this is just a temporary measure so pgx.Conn can detect death of // underlying connection. func (pgConn *PgConn) IsAlive() bool { return pgConn.status >= connStatusIdle } // lock locks the connection. It panics if the connection is already locked or is closed. func (pgConn *PgConn) lock() error { switch pgConn.status { case connStatusBusy: return ErrConnBusy // This only should be possible in case of an application bug. case connStatusClosed: return errors.New("conn closed") case connStatusUninitialized: return errors.New("conn uninitialized") } pgConn.status = connStatusBusy return nil } func (pgConn *PgConn) unlock() { switch pgConn.status { case connStatusBusy: pgConn.status = connStatusIdle case connStatusClosed: default: panic("BUG: cannot unlock unlocked connection") // This should only be possible if there is a bug in this package. } } // ParameterStatus returns the value of a parameter reported by the server (e.g. // server_version). Returns an empty string for unknown parameters. func (pgConn *PgConn) ParameterStatus(key string) string { return pgConn.parameterStatuses[key] } // CommandTag is the result of an Exec function type CommandTag []byte // RowsAffected returns the number of rows affected. 
If the CommandTag was not // for a row affecting command (e.g. "CREATE TABLE") then it returns 0. func (ct CommandTag) RowsAffected() int64 { idx := bytes.LastIndexByte([]byte(ct), ' ') if idx == -1 { return 0 } n, _ := strconv.ParseInt(string([]byte(ct)[idx+1:]), 10, 64) return n } func (ct CommandTag) String() string { return string(ct) } type PreparedStatementDescription struct { Name string SQL string ParamOIDs []uint32 Fields []pgproto3.FieldDescription } // Prepare creates a prepared statement. func (pgConn *PgConn) Prepare(ctx context.Context, name, sql string, paramOIDs []uint32) (*PreparedStatementDescription, error) { if err := pgConn.lock(); err != nil { return nil, linkErrors(err, ErrNoBytesSent) } defer pgConn.unlock() select { case <-ctx.Done(): return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() buf := pgConn.wbuf buf = (&pgproto3.Parse{Name: name, Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Describe{ObjectType: 'S', Name: name}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() if n == 0 { err = linkErrors(err, ErrNoBytesSent) } return nil, linkErrors(ctx.Err(), err) } psd := &PreparedStatementDescription{Name: name, SQL: sql} var parseErr error readloop: for { msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { case *pgproto3.ParameterDescription: psd.ParamOIDs = make([]uint32, len(msg.ParameterOIDs)) copy(psd.ParamOIDs, msg.ParameterOIDs) case *pgproto3.RowDescription: psd.Fields = make([]pgproto3.FieldDescription, len(msg.Fields)) copy(psd.Fields, msg.Fields) case *pgproto3.ErrorResponse: parseErr = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: break readloop } } if parseErr != nil { return nil, parseErr } return psd, nil } func errorResponseToPgError(msg *pgproto3.ErrorResponse) 
*PgError { return &PgError{ Severity: msg.Severity, Code: string(msg.Code), Message: string(msg.Message), Detail: string(msg.Detail), Hint: msg.Hint, Position: msg.Position, InternalPosition: msg.InternalPosition, InternalQuery: string(msg.InternalQuery), Where: string(msg.Where), SchemaName: string(msg.SchemaName), TableName: string(msg.TableName), ColumnName: string(msg.ColumnName), DataTypeName: string(msg.DataTypeName), ConstraintName: msg.ConstraintName, File: string(msg.File), Line: msg.Line, Routine: string(msg.Routine), } } func noticeResponseToNotice(msg *pgproto3.NoticeResponse) *Notice { pgerr := errorResponseToPgError((*pgproto3.ErrorResponse)(msg)) return (*Notice)(pgerr) } // CancelRequest sends a cancel request to the PostgreSQL server. It returns an error if unable to deliver the cancel // request, but lack of an error does not ensure that the query was canceled. As specified in the documentation, there // is no way to be sure a query was canceled. See https://www.postgresql.org/docs/11/protocol-flow.html#id-1.10.5.7.9 func (pgConn *PgConn) CancelRequest(ctx context.Context) error { // Open a cancellation request to the same server. The address is taken from the net.Conn directly instead of reusing // the connection config. This is important in high availability configurations where fallback connections may be // specified or DNS may be used to load balance. 
serverAddr := pgConn.conn.RemoteAddr() cancelConn, err := pgConn.Config.DialFunc(ctx, serverAddr.Network(), serverAddr.String()) if err != nil { return err } defer cancelConn.Close() contextWatcher := ctxwatch.NewContextWatcher( func() { cancelConn.SetDeadline(time.Date(1, 1, 1, 1, 1, 1, 1, time.UTC)) }, func() { cancelConn.SetDeadline(time.Time{}) }, ) contextWatcher.Watch(ctx) defer contextWatcher.Unwatch() buf := make([]byte, 16) binary.BigEndian.PutUint32(buf[0:4], 16) binary.BigEndian.PutUint32(buf[4:8], 80877102) binary.BigEndian.PutUint32(buf[8:12], uint32(pgConn.pid)) binary.BigEndian.PutUint32(buf[12:16], uint32(pgConn.secretKey)) _, err = cancelConn.Write(buf) if err != nil { return linkErrors(ctx.Err(), err) } _, err = cancelConn.Read(buf) if err != io.EOF { return errors.Errorf("Server failed to close connection after cancel query request: %w", linkErrors(ctx.Err(), err)) } return nil } // WaitForNotification waits for a LISTON/NOTIFY message to be received. It returns an error if a notification was not // received. func (pgConn *PgConn) WaitForNotification(ctx context.Context) error { if err := pgConn.lock(); err != nil { return err } defer pgConn.unlock() select { case <-ctx.Done(): return ctx.Err() default: } pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() for { msg, err := pgConn.ReceiveMessage() if err != nil { return linkErrors(ctx.Err(), err) } switch msg.(type) { case *pgproto3.NotificationResponse: return nil } } } // Exec executes SQL via the PostgreSQL simple query protocol. SQL may contain multiple queries. Execution is // implicitly wrapped in a transaction unless a transaction is already in progress or SQL contains transaction control // statements. // // Prefer ExecParams unless executing arbitrary SQL that may contain multiple queries. 
func (pgConn *PgConn) Exec(ctx context.Context, sql string) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, err: linkErrors(err, ErrNoBytesSent), } } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, } multiResult := &pgConn.multiResultReader select { case <-ctx.Done(): multiResult.closed = true multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) pgConn.unlock() return multiResult default: } pgConn.contextWatcher.Watch(ctx) buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() pgConn.contextWatcher.Unwatch() multiResult.closed = true if n == 0 { err = linkErrors(err, ErrNoBytesSent) } multiResult.err = linkErrors(ctx.Err(), err) pgConn.unlock() return multiResult } return multiResult } // ExecParams executes a command via the PostgreSQL extended query protocol. // // sql is a SQL command string. It may only contain one query. Parameter substitution is positional using $1, $2, $3, // etc. // // paramValues are the parameter values. It must be encoded in the format given by paramFormats. // // paramOIDs is a slice of data type OIDs for paramValues. If paramOIDs is nil, the server will infer the data type for // all parameters. Any paramOID element that is 0 that will cause the server to infer the data type for that parameter. // ExecParams will panic if len(paramOIDs) is not 0, 1, or len(paramValues). // // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or // binary format. If paramFormats is nil all results will be in text protocol. ExecParams will panic if // len(paramFormats) is not 0, 1, or len(paramValues). // // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text protocol. 
// // ResultReader must be closed before PgConn can be used again. func (pgConn *PgConn) ExecParams(ctx context.Context, sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result } buf := pgConn.wbuf buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(buf) buf = (&pgproto3.Bind{ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) pgConn.execExtendedSuffix(ctx, buf, result) return result } // ExecPrepared enqueues the execution of a prepared statement via the PostgreSQL extended query protocol. // // paramValues are the parameter values. It must be encoded in the format given by paramFormats. // // paramFormats is a slice of format codes determining for each paramValue column whether it is encoded in text or // binary format. If paramFormats is nil all results will be in text protocol. ExecPrepared will panic if // len(paramFormats) is not 0, 1, or len(paramValues). // // resultFormats is a slice of format codes determining for each result column whether it is encoded in text or // binary format. If resultFormats is nil all results will be in text protocol. // // ResultReader must be closed before PgConn can be used again. 
func (pgConn *PgConn) ExecPrepared(ctx context.Context, stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) *ResultReader { result := pgConn.execExtendedPrefix(ctx, paramValues) if result.closed { return result } buf := pgConn.wbuf buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(buf) pgConn.execExtendedSuffix(ctx, buf, result) return result } func (pgConn *PgConn) execExtendedPrefix(ctx context.Context, paramValues [][]byte) *ResultReader { if err := pgConn.lock(); err != nil { return &ResultReader{ closed: true, err: linkErrors(err, ErrNoBytesSent), } } pgConn.resultReader = ResultReader{ pgConn: pgConn, ctx: ctx, } result := &pgConn.resultReader if len(paramValues) > math.MaxUint16 { result.concludeCommand(nil, errors.Errorf("extended protocol limited to %v parameters", math.MaxUint16)) result.closed = true pgConn.unlock() return result } select { case <-ctx.Done(): result.concludeCommand(nil, linkErrors(ctx.Err(), ErrNoBytesSent)) result.closed = true pgConn.unlock() return result default: } pgConn.contextWatcher.Watch(ctx) return result } func (pgConn *PgConn) execExtendedSuffix(ctx context.Context, buf []byte, result *ResultReader) { buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(buf) buf = (&pgproto3.Execute{}).Encode(buf) buf = (&pgproto3.Sync{}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() if n == 0 { err = linkErrors(err, ErrNoBytesSent) } result.concludeCommand(nil, linkErrors(ctx.Err(), err)) pgConn.contextWatcher.Unwatch() result.closed = true pgConn.unlock() } } // CopyTo executes the copy command sql and copies the results to w. 
func (pgConn *PgConn) CopyTo(ctx context.Context, w io.Writer, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { return nil, linkErrors(err, ErrNoBytesSent) } select { case <-ctx.Done(): pgConn.unlock() return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() pgConn.unlock() if n == 0 { err = linkErrors(err, ErrNoBytesSent) } return nil, linkErrors(ctx.Err(), err) } // Read results var commandTag CommandTag var pgErr error for { msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { case *pgproto3.CopyDone: case *pgproto3.CopyData: _, err := w.Write(msg.Data) if err != nil { pgConn.hardClose() return nil, err } case *pgproto3.ReadyForQuery: pgConn.unlock() return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: pgErr = errorResponseToPgError(msg) } } } // CopyFrom executes the copy command sql and copies all of r to the PostgreSQL server. // // Note: context cancellation will only interrupt operations on the underlying PostgreSQL network connection. Reads on r // could still block. 
func (pgConn *PgConn) CopyFrom(ctx context.Context, r io.Reader, sql string) (CommandTag, error) { if err := pgConn.lock(); err != nil { return nil, linkErrors(err, ErrNoBytesSent) } defer pgConn.unlock() select { case <-ctx.Done(): return nil, linkErrors(ctx.Err(), ErrNoBytesSent) default: } pgConn.contextWatcher.Watch(ctx) defer pgConn.contextWatcher.Unwatch() // Send copy to command buf := pgConn.wbuf buf = (&pgproto3.Query{String: sql}).Encode(buf) n, err := pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() if n == 0 { err = linkErrors(err, ErrNoBytesSent) } return nil, linkErrors(ctx.Err(), err) } // Read until copy in response or error. var commandTag CommandTag var pgErr error pendingCopyInResponse := true for pendingCopyInResponse { msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { case *pgproto3.CopyInResponse: pendingCopyInResponse = false case *pgproto3.ErrorResponse: pgErr = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: return commandTag, pgErr } } // Send copy data buf = make([]byte, 0, 65536) buf = append(buf, 'd') sp := len(buf) var readErr error signalMessageChan := pgConn.signalMessage() for readErr == nil && pgErr == nil { var n int n, readErr = r.Read(buf[5:cap(buf)]) if n > 0 { buf = buf[0 : n+5] pgio.SetInt32(buf[sp:], int32(n+4)) _, err = pgConn.conn.Write(buf) if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) } } select { case <-signalMessageChan: msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { case *pgproto3.ErrorResponse: pgErr = errorResponseToPgError(msg) } default: } } buf = buf[:0] if readErr == io.EOF || pgErr != nil { copyDone := &pgproto3.CopyDone{} buf = copyDone.Encode(buf) } else { copyFail := &pgproto3.CopyFail{Message: readErr.Error()} buf = copyFail.Encode(buf) } _, err = pgConn.conn.Write(buf) if err != nil 
{ pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) } // Read results for { msg, err := pgConn.ReceiveMessage() if err != nil { pgConn.hardClose() return nil, linkErrors(ctx.Err(), err) } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: return commandTag, pgErr case *pgproto3.CommandComplete: commandTag = CommandTag(msg.CommandTag) case *pgproto3.ErrorResponse: pgErr = errorResponseToPgError(msg) } } } // MultiResultReader is a reader for a command that could return multiple results such as Exec or ExecBatch. type MultiResultReader struct { pgConn *PgConn ctx context.Context rr *ResultReader closed bool err error } // ReadAll reads all available results. Calling ReadAll is mutually exclusive with all other MultiResultReader methods. func (mrr *MultiResultReader) ReadAll() ([]*Result, error) { var results []*Result for mrr.NextResult() { results = append(results, mrr.ResultReader().Read()) } err := mrr.Close() return results, err } func (mrr *MultiResultReader) receiveMessage() (pgproto3.BackendMessage, error) { msg, err := mrr.pgConn.ReceiveMessage() if err != nil { mrr.pgConn.contextWatcher.Unwatch() mrr.err = preferContextOverNetTimeoutError(mrr.ctx, err) mrr.closed = true mrr.pgConn.hardClose() return nil, mrr.err } switch msg := msg.(type) { case *pgproto3.ReadyForQuery: mrr.pgConn.contextWatcher.Unwatch() mrr.closed = true mrr.pgConn.unlock() case *pgproto3.ErrorResponse: mrr.err = errorResponseToPgError(msg) } return msg, nil } // NextResult returns advances the MultiResultReader to the next result and returns true if a result is available. 
func (mrr *MultiResultReader) NextResult() bool { for !mrr.closed && mrr.err == nil { msg, err := mrr.receiveMessage() if err != nil { return false } switch msg := msg.(type) { case *pgproto3.RowDescription: mrr.pgConn.resultReader = ResultReader{ pgConn: mrr.pgConn, multiResultReader: mrr, ctx: mrr.ctx, fieldDescriptions: msg.Fields, } mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.CommandComplete: mrr.pgConn.resultReader = ResultReader{ commandTag: CommandTag(msg.CommandTag), commandConcluded: true, closed: true, } mrr.rr = &mrr.pgConn.resultReader return true case *pgproto3.EmptyQueryResponse: return false } } return false } // ResultReader returns the current ResultReader. func (mrr *MultiResultReader) ResultReader() *ResultReader { return mrr.rr } // Close closes the MultiResultReader and returns the first error that occurred during the MultiResultReader's use. func (mrr *MultiResultReader) Close() error { for !mrr.closed { _, err := mrr.receiveMessage() if err != nil { return mrr.err } } return mrr.err } // ResultReader is a reader for the result of a single query. type ResultReader struct { pgConn *PgConn multiResultReader *MultiResultReader ctx context.Context fieldDescriptions []pgproto3.FieldDescription rowValues [][]byte commandTag CommandTag commandConcluded bool closed bool err error } // Result is the saved query response that is returned by calling Read on a ResultReader. type Result struct { FieldDescriptions []pgproto3.FieldDescription Rows [][][]byte CommandTag CommandTag Err error } // Read saves the query response to a Result. 
func (rr *ResultReader) Read() *Result { br := &Result{} for rr.NextRow() { if br.FieldDescriptions == nil { br.FieldDescriptions = make([]pgproto3.FieldDescription, len(rr.FieldDescriptions())) copy(br.FieldDescriptions, rr.FieldDescriptions()) } row := make([][]byte, len(rr.Values())) copy(row, rr.Values()) br.Rows = append(br.Rows, row) } br.CommandTag, br.Err = rr.Close() return br } // NextRow advances the ResultReader to the next row and returns true if a row is available. func (rr *ResultReader) NextRow() bool { for !rr.commandConcluded { msg, err := rr.receiveMessage() if err != nil { return false } switch msg := msg.(type) { case *pgproto3.DataRow: rr.rowValues = msg.Values return true } } return false } // FieldDescriptions returns the field descriptions for the current result set. The returned slice is only valid until // the ResultReader is closed. func (rr *ResultReader) FieldDescriptions() []pgproto3.FieldDescription { return rr.fieldDescriptions } // Values returns the current row data. NextRow must have been previously been called. The returned [][]byte is only // valid until the next NextRow call or the ResultReader is closed. However, the underlying byte data is safe to // retain a reference to and mutate. func (rr *ResultReader) Values() [][]byte { return rr.rowValues } // Close consumes any remaining result data and returns the command tag or // error. func (rr *ResultReader) Close() (CommandTag, error) { if rr.closed { return rr.commandTag, rr.err } rr.closed = true for !rr.commandConcluded { _, err := rr.receiveMessage() if err != nil { return nil, rr.err } } if rr.multiResultReader == nil { for { msg, err := rr.receiveMessage() if err != nil { return nil, rr.err } switch msg := msg.(type) { // Detect a deferred constraint violation where the ErrorResponse is sent after CommandComplete. 
case *pgproto3.ErrorResponse: rr.err = errorResponseToPgError(msg) case *pgproto3.ReadyForQuery: rr.pgConn.contextWatcher.Unwatch() rr.pgConn.unlock() return rr.commandTag, rr.err } } } return rr.commandTag, rr.err } func (rr *ResultReader) receiveMessage() (msg pgproto3.BackendMessage, err error) { if rr.multiResultReader == nil { msg, err = rr.pgConn.ReceiveMessage() } else { msg, err = rr.multiResultReader.receiveMessage() } if err != nil { rr.concludeCommand(nil, err) rr.pgConn.contextWatcher.Unwatch() rr.closed = true if rr.multiResultReader == nil { rr.pgConn.hardClose() } return nil, rr.err } switch msg := msg.(type) { case *pgproto3.RowDescription: rr.fieldDescriptions = msg.Fields case *pgproto3.CommandComplete: rr.concludeCommand(CommandTag(msg.CommandTag), nil) case *pgproto3.ErrorResponse: rr.concludeCommand(nil, errorResponseToPgError(msg)) } return msg, nil } func (rr *ResultReader) concludeCommand(commandTag CommandTag, err error) { if rr.commandConcluded { return } rr.commandTag = commandTag rr.err = preferContextOverNetTimeoutError(rr.ctx, err) rr.fieldDescriptions = nil rr.rowValues = nil rr.commandConcluded = true } // Batch is a collection of queries that can be sent to the PostgreSQL server in a single round-trip. type Batch struct { buf []byte } // ExecParams appends an ExecParams command to the batch. See PgConn.ExecParams for parameter descriptions. func (batch *Batch) ExecParams(sql string, paramValues [][]byte, paramOIDs []uint32, paramFormats []int16, resultFormats []int16) { batch.buf = (&pgproto3.Parse{Query: sql, ParameterOIDs: paramOIDs}).Encode(batch.buf) batch.ExecPrepared("", paramValues, paramFormats, resultFormats) } // ExecPrepared appends an ExecPrepared e command to the batch. See PgConn.ExecPrepared for parameter descriptions. 
func (batch *Batch) ExecPrepared(stmtName string, paramValues [][]byte, paramFormats []int16, resultFormats []int16) { batch.buf = (&pgproto3.Bind{PreparedStatement: stmtName, ParameterFormatCodes: paramFormats, Parameters: paramValues, ResultFormatCodes: resultFormats}).Encode(batch.buf) batch.buf = (&pgproto3.Describe{ObjectType: 'P'}).Encode(batch.buf) batch.buf = (&pgproto3.Execute{}).Encode(batch.buf) } // ExecBatch executes all the queries in batch in a single round-trip. Execution is implicitly transactional unless a // transaction is already in progress or SQL contains transaction control statements. func (pgConn *PgConn) ExecBatch(ctx context.Context, batch *Batch) *MultiResultReader { if err := pgConn.lock(); err != nil { return &MultiResultReader{ closed: true, err: linkErrors(err, ErrNoBytesSent), } } pgConn.multiResultReader = MultiResultReader{ pgConn: pgConn, ctx: ctx, } multiResult := &pgConn.multiResultReader select { case <-ctx.Done(): multiResult.closed = true multiResult.err = linkErrors(ctx.Err(), ErrNoBytesSent) pgConn.unlock() return multiResult default: } pgConn.contextWatcher.Watch(ctx) batch.buf = (&pgproto3.Sync{}).Encode(batch.buf) // A large batch can deadlock without concurrent reading and writing. If the Write fails the underlying net.Conn is // closed. This is all that can be done without introducing a race condition or adding a concurrent safe communication // channel to relay the error back. The practical effect of this is that the underlying Write error is not reported. // The error the code reading the batch results receives will be a closed connection error. // // See https://github.com/jackc/pgx/issues/374. go func() { _, err := pgConn.conn.Write(batch.buf) if err != nil { pgConn.conn.Close() } }() return multiResult } // EscapeString escapes a string such that it can safely be interpolated into a SQL command string. It does not include // the surrounding single quotes. 
// // The current implementation requires that standard_conforming_strings=on and client_encoding="UTF8". If these // conditions are not met an error will be returned. It is possible these restrictions will be lifted in the future. func (pgConn *PgConn) EscapeString(s string) (string, error) { if pgConn.ParameterStatus("standard_conforming_strings") != "on" { return "", errors.New("EscapeString must be run with standard_conforming_strings=on") } if pgConn.ParameterStatus("client_encoding") != "UTF8" { return "", errors.New("EscapeString must be run with client_encoding=UTF8") } return strings.Replace(s, "'", "''", -1), nil }