1
0
mirror of https://github.com/jackc/pgx.git synced 2025-05-12 10:29:49 +00:00
James Hartig 6d336eccb1 Added LastStmtSent and use it to retry on errors if statement was not sent
Previously, a failed connection could be put back in a pool and when the
next query was attempted it would fail immediately trying to prepare the
query or reset the deadline. It wasn't clear if the Query or Exec call
could safely be retried since there was no way to know where it failed.

You can now call LastQuerySent and if it returns false then you're
guaranteed that the last call to Query(Ex)/Exec(Ex) didn't get far enough
to attempt to send the query. The call can be retried with a new
connection.

This is used in the stdlib to return a ErrBadConn if a network error
occurred and the statement was not attempted.

Fixes 
2018-11-19 10:44:40 -05:00

649 lines
16 KiB
Go

// Package stdlib is the compatibility layer from pgx to database/sql.
//
// A database/sql connection can be established through sql.Open.
//
// db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
// if err != nil {
// return err
// }
//
// Or from a DSN string.
//
// db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
// if err != nil {
// return err
// }
//
// A DriverConfig can be used to further configure the connection process. This
// allows configuring TLS configuration, setting a custom dialer, logging, and
// setting an AfterConnect hook.
//
// driverConfig := stdlib.DriverConfig{
// ConnConfig: pgx.ConnConfig{
// Logger: logger,
// },
// AfterConnect: func(c *pgx.Conn) error {
// // Ensure all connections have this temp table available
// _, err := c.Exec("create temporary table foo(...)")
// return err
// },
// }
//
// stdlib.RegisterDriverConfig(&driverConfig)
//
// db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
// if err != nil {
// return err
// }
//
// pgx uses standard PostgreSQL positional parameters in queries. e.g. $1, $2.
// It does not support named parameters.
//
// db.QueryRow("select * from users where id=$1", userID)
//
// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard
// database/sql.DB connection pool. This allows operations that must be
// performed on a single connection, but should not be run in a transaction or
// to use pgx specific functionality.
//
// conn, err := stdlib.AcquireConn(db)
// if err != nil {
// return err
// }
// defer stdlib.ReleaseConn(db, conn)
//
// // do stuff with pgx.Conn
//
// It also can be used to enable a fast path for pgx while preserving
// compatibility with other drivers and database.
//
// conn, err := stdlib.AcquireConn(db)
// if err == nil {
// // fast path with pgx
// // ...
// // release conn when done
// stdlib.ReleaseConn(db, conn)
// } else {
// // normal path for other drivers and databases
// }
package stdlib
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"net"
"reflect"
"strings"
"sync"
"github.com/pkg/errors"
"github.com/jackc/pgx"
"github.com/jackc/pgx/pgtype"
)
// oids that map to intrinsic database/sql types. These will be allowed to be
// binary, anything else will be forced to text format
var databaseSqlOIDs map[pgtype.OID]bool
var pgxDriver *Driver
type ctxKey int
var ctxKeyFakeTx ctxKey = 0
var ErrNotPgx = errors.New("not pgx *sql.DB")
func init() {
pgxDriver = &Driver{
configs: make(map[int64]*DriverConfig),
}
fakeTxConns = make(map[*pgx.Conn]*sql.Tx)
sql.Register("pgx", pgxDriver)
databaseSqlOIDs = make(map[pgtype.OID]bool)
databaseSqlOIDs[pgtype.BoolOID] = true
databaseSqlOIDs[pgtype.ByteaOID] = true
databaseSqlOIDs[pgtype.CIDOID] = true
databaseSqlOIDs[pgtype.DateOID] = true
databaseSqlOIDs[pgtype.Float4OID] = true
databaseSqlOIDs[pgtype.Float8OID] = true
databaseSqlOIDs[pgtype.Int2OID] = true
databaseSqlOIDs[pgtype.Int4OID] = true
databaseSqlOIDs[pgtype.Int8OID] = true
databaseSqlOIDs[pgtype.OIDOID] = true
databaseSqlOIDs[pgtype.TimestampOID] = true
databaseSqlOIDs[pgtype.TimestamptzOID] = true
databaseSqlOIDs[pgtype.XIDOID] = true
}
var (
fakeTxMutex sync.Mutex
fakeTxConns map[*pgx.Conn]*sql.Tx
)
type Driver struct {
configMutex sync.Mutex
configCount int64
configs map[int64]*DriverConfig
}
func (d *Driver) Open(name string) (driver.Conn, error) {
var (
connConfig pgx.ConnConfig
afterConnect func(*pgx.Conn) error
)
if len(name) >= 9 && name[0] == 0 {
idBuf := []byte(name)[1:9]
id := int64(binary.BigEndian.Uint64(idBuf))
d.configMutex.Lock()
connConfig = d.configs[id].ConnConfig
afterConnect = d.configs[id].AfterConnect
d.configMutex.Unlock()
name = name[9:]
}
parsedConfig, err := pgx.ParseConnectionString(name)
if err != nil {
return nil, err
}
connConfig = connConfig.Merge(parsedConfig)
conn, err := pgx.Connect(connConfig)
if err != nil {
return nil, err
}
if afterConnect != nil {
err = afterConnect(conn)
if err != nil {
return nil, err
}
}
c := &Conn{conn: conn, driver: d, connConfig: connConfig}
return c, nil
}
type DriverConfig struct {
pgx.ConnConfig
AfterConnect func(*pgx.Conn) error // function to call on every new connection
driver *Driver
id int64
}
// ConnectionString encodes the DriverConfig into the original connection
// string. DriverConfig must be registered before calling ConnectionString.
func (c *DriverConfig) ConnectionString(original string) string {
if c.driver == nil {
panic("DriverConfig must be registered before calling ConnectionString")
}
buf := make([]byte, 9)
binary.BigEndian.PutUint64(buf[1:], uint64(c.id))
buf = append(buf, original...)
return string(buf)
}
func (d *Driver) registerDriverConfig(c *DriverConfig) {
d.configMutex.Lock()
c.driver = d
c.id = d.configCount
d.configs[d.configCount] = c
d.configCount++
d.configMutex.Unlock()
}
func (d *Driver) unregisterDriverConfig(c *DriverConfig) {
d.configMutex.Lock()
delete(d.configs, c.id)
d.configMutex.Unlock()
}
// RegisterDriverConfig registers a DriverConfig for use with Open.
func RegisterDriverConfig(c *DriverConfig) {
pgxDriver.registerDriverConfig(c)
}
// UnregisterDriverConfig removes a DriverConfig registration.
func UnregisterDriverConfig(c *DriverConfig) {
pgxDriver.unregisterDriverConfig(c)
}
type Conn struct {
conn *pgx.Conn
psCount int64 // Counter used for creating unique prepared statement names
driver *Driver
connConfig pgx.ConnConfig
}
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
return c.PrepareContext(context.Background(), query)
}
func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
name := fmt.Sprintf("pgx_%d", c.psCount)
c.psCount++
ps, err := c.conn.PrepareEx(ctx, name, query, nil)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return &Stmt{ps: ps, conn: c}, nil
}
func (c *Conn) Close() error {
return c.conn.Close()
}
func (c *Conn) Begin() (driver.Tx, error) {
return c.BeginTx(context.Background(), driver.TxOptions{})
}
func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok {
*pconn = c.conn
return fakeTx{}, nil
}
var pgxOpts pgx.TxOptions
switch sql.IsolationLevel(opts.Isolation) {
case sql.LevelDefault:
case sql.LevelReadUncommitted:
pgxOpts.IsoLevel = pgx.ReadUncommitted
case sql.LevelReadCommitted:
pgxOpts.IsoLevel = pgx.ReadCommitted
case sql.LevelSnapshot:
pgxOpts.IsoLevel = pgx.RepeatableRead
case sql.LevelSerializable:
pgxOpts.IsoLevel = pgx.Serializable
default:
return nil, errors.Errorf("unsupported isolation: %v", opts.Isolation)
}
if opts.ReadOnly {
pgxOpts.AccessMode = pgx.ReadOnly
}
return c.conn.BeginEx(ctx, &pgxOpts)
}
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
args := valueToInterface(argsV)
commandTag, err := c.conn.Exec(query, args...)
// if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return driver.RowsAffected(commandTag.RowsAffected()), err
}
func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
args := namedValueToInterface(argsV)
commandTag, err := c.conn.ExecEx(ctx, query, nil, args...)
// if we got a network error before we had a chance to send the query, retry
if err != nil && !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return driver.RowsAffected(commandTag.RowsAffected()), err
}
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.Prepare("", query)
if err != nil {
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPrepared("", argsV)
}
rows, err := c.conn.Query(query, valueToInterface(argsV)...)
if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return nil, err
}
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
if !c.connConfig.PreferSimpleProtocol {
ps, err := c.conn.PrepareEx(ctx, "", query, nil)
if err != nil {
// since PrepareEx failed, we didn't actually get to send the values, so
// we can safely retry
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
return nil, err
}
restrictBinaryToDatabaseSqlTypes(ps)
return c.queryPreparedContext(ctx, "", argsV)
}
rows, err := c.conn.QueryEx(ctx, query, nil, namedValueToInterface(argsV)...)
if err != nil {
// if we got a network error before we had a chance to send the query, retry
if !c.conn.LastStmtSent() {
if _, is := err.(net.Error); is {
return nil, driver.ErrBadConn
}
}
return nil, err
}
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
more := rows.Next()
return &Rows{rows: rows, skipNext: true, skipNextMore: more}, nil
}
func (c *Conn) queryPrepared(name string, argsV []driver.Value) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
args := valueToInterface(argsV)
rows, err := c.conn.Query(name, args...)
if err != nil {
return nil, err
}
return &Rows{rows: rows}, nil
}
func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []driver.NamedValue) (driver.Rows, error) {
if !c.conn.IsAlive() {
return nil, driver.ErrBadConn
}
args := namedValueToInterface(argsV)
rows, err := c.conn.QueryEx(ctx, name, nil, args...)
if err != nil {
return nil, err
}
return &Rows{rows: rows}, nil
}
func (c *Conn) Ping(ctx context.Context) error {
if !c.conn.IsAlive() {
return driver.ErrBadConn
}
return c.conn.Ping(ctx)
}
// Anything that isn't a database/sql compatible type needs to be forced to
// text format so that pgx.Rows.Values doesn't decode it into a native type
// (e.g. []int32)
func restrictBinaryToDatabaseSqlTypes(ps *pgx.PreparedStatement) {
for i := range ps.FieldDescriptions {
intrinsic, _ := databaseSqlOIDs[ps.FieldDescriptions[i].DataType]
if !intrinsic {
ps.FieldDescriptions[i].FormatCode = pgx.TextFormatCode
}
}
}
type Stmt struct {
ps *pgx.PreparedStatement
conn *Conn
}
func (s *Stmt) Close() error {
return s.conn.conn.Deallocate(s.ps.Name)
}
func (s *Stmt) NumInput() int {
return len(s.ps.ParameterOIDs)
}
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
return s.conn.Exec(s.ps.Name, argsV)
}
func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
return s.conn.ExecContext(ctx, s.ps.Name, argsV)
}
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
return s.conn.queryPrepared(s.ps.Name, argsV)
}
func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
return s.conn.queryPreparedContext(ctx, s.ps.Name, argsV)
}
type Rows struct {
rows *pgx.Rows
values []interface{}
skipNext bool
skipNextMore bool
}
func (r *Rows) Columns() []string {
fieldDescriptions := r.rows.FieldDescriptions()
names := make([]string, 0, len(fieldDescriptions))
for _, fd := range fieldDescriptions {
names = append(names, fd.Name)
}
return names
}
// ColumnTypeDatabaseTypeName return the database system type name.
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
return strings.ToUpper(r.rows.FieldDescriptions()[index].DataTypeName)
}
// ColumnTypeLength returns the length of the column type if the column is a
// variable length type. If the column is not a variable length type ok
// should return false.
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
return r.rows.FieldDescriptions()[index].Length()
}
// ColumnTypePrecisionScale should return the precision and scale for decimal
// types. If not applicable, ok should be false.
func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
return r.rows.FieldDescriptions()[index].PrecisionScale()
}
// ColumnTypeScanType returns the value type that can be used to scan types into.
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
return r.rows.FieldDescriptions()[index].Type()
}
func (r *Rows) Close() error {
r.rows.Close()
return nil
}
func (r *Rows) Next(dest []driver.Value) error {
if r.values == nil {
r.values = make([]interface{}, len(r.rows.FieldDescriptions()))
for i, fd := range r.rows.FieldDescriptions() {
switch fd.DataType {
case pgtype.BoolOID:
r.values[i] = &pgtype.Bool{}
case pgtype.ByteaOID:
r.values[i] = &pgtype.Bytea{}
case pgtype.CIDOID:
r.values[i] = &pgtype.CID{}
case pgtype.DateOID:
r.values[i] = &pgtype.Date{}
case pgtype.Float4OID:
r.values[i] = &pgtype.Float4{}
case pgtype.Float8OID:
r.values[i] = &pgtype.Float8{}
case pgtype.Int2OID:
r.values[i] = &pgtype.Int2{}
case pgtype.Int4OID:
r.values[i] = &pgtype.Int4{}
case pgtype.Int8OID:
r.values[i] = &pgtype.Int8{}
case pgtype.JSONOID:
r.values[i] = &pgtype.JSON{}
case pgtype.JSONBOID:
r.values[i] = &pgtype.JSONB{}
case pgtype.OIDOID:
r.values[i] = &pgtype.OIDValue{}
case pgtype.TimestampOID:
r.values[i] = &pgtype.Timestamp{}
case pgtype.TimestamptzOID:
r.values[i] = &pgtype.Timestamptz{}
case pgtype.XIDOID:
r.values[i] = &pgtype.XID{}
default:
r.values[i] = &pgtype.GenericText{}
}
}
}
var more bool
if r.skipNext {
more = r.skipNextMore
r.skipNext = false
} else {
more = r.rows.Next()
}
if !more {
if r.rows.Err() == nil {
return io.EOF
} else {
return r.rows.Err()
}
}
err := r.rows.Scan(r.values...)
if err != nil {
return err
}
for i, v := range r.values {
dest[i], err = v.(driver.Valuer).Value()
if err != nil {
return err
}
}
return nil
}
func valueToInterface(argsV []driver.Value) []interface{} {
args := make([]interface{}, 0, len(argsV))
for _, v := range argsV {
if v != nil {
args = append(args, v.(interface{}))
} else {
args = append(args, nil)
}
}
return args
}
func namedValueToInterface(argsV []driver.NamedValue) []interface{} {
args := make([]interface{}, 0, len(argsV))
for _, v := range argsV {
if v.Value != nil {
args = append(args, v.Value.(interface{}))
} else {
args = append(args, nil)
}
}
return args
}
type fakeTx struct{}
func (fakeTx) Commit() error { return nil }
func (fakeTx) Rollback() error { return nil }
func AcquireConn(db *sql.DB) (*pgx.Conn, error) {
var conn *pgx.Conn
ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn)
tx, err := db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
if conn == nil {
tx.Rollback()
return nil, ErrNotPgx
}
fakeTxMutex.Lock()
fakeTxConns[conn] = tx
fakeTxMutex.Unlock()
return conn, nil
}
func ReleaseConn(db *sql.DB, conn *pgx.Conn) error {
var tx *sql.Tx
var ok bool
fakeTxMutex.Lock()
tx, ok = fakeTxConns[conn]
if ok {
delete(fakeTxConns, conn)
fakeTxMutex.Unlock()
} else {
fakeTxMutex.Unlock()
return errors.Errorf("can't release conn that is not acquired")
}
return tx.Rollback()
}