// Package stdlib is the compatibility layer from pgx to database/sql.
//
// A database/sql connection can be established through sql.Open.
//
//	db, err := sql.Open("pgx", "postgres://pgx_md5:secret@localhost:5432/pgx_test?sslmode=disable")
//	if err != nil {
//		return err
//	}
//
// Or from a DSN string.
//
//	db, err := sql.Open("pgx", "user=postgres password=secret host=localhost port=5432 database=pgx_test sslmode=disable")
//	if err != nil {
//		return err
//	}
//
// A DriverConfig can be used to further configure the connection process. This
// allows configuring TLS, setting a custom dialer, logging, and setting an
// AfterConnect hook.
//
//	driverConfig := stdlib.DriverConfig{
//		ConnConfig: pgx.ConnConfig{
//			Logger: logger,
//		},
//		AfterConnect: func(c *pgx.Conn) error {
//			// Ensure all connections have this temp table available
//			_, err := c.Exec("create temporary table foo(...)")
//			return err
//		},
//	}
//
//	stdlib.RegisterDriverConfig(&driverConfig)
//
//	db, err := sql.Open("pgx", driverConfig.ConnectionString("postgres://pgx_md5:secret@127.0.0.1:5432/pgx_test"))
//	if err != nil {
//		return err
//	}
//
// pgx uses standard PostgreSQL positional parameters in queries, e.g. $1, $2.
// It does not support named parameters.
//
//	db.QueryRow("select * from users where id=$1", userID)
//
// AcquireConn and ReleaseConn acquire and release a *pgx.Conn from the standard
// database/sql.DB connection pool. This allows operations that must be
// performed on a single connection without running in a transaction, and it
// supports operations that use pgx-specific functionality.
//
//	conn, err := stdlib.AcquireConn(db)
//	if err != nil {
//		return err
//	}
//	defer stdlib.ReleaseConn(db, conn)
//
//	// do stuff with pgx.Conn
//
// It can also be used to enable a fast path for pgx while preserving
// compatibility with other drivers and databases.
//
//	conn, err := stdlib.AcquireConn(db)
//	if err == nil {
//		// fast path with pgx
//		// ...
// // release conn when done // stdlib.ReleaseConn(db, conn) // } else { // // normal path for other drivers and databases // } package stdlib import ( "context" "database/sql" "database/sql/driver" "fmt" "io" "math" "net" "reflect" "strings" "sync" "time" errors "golang.org/x/xerrors" "github.com/jackc/pgconn" "github.com/jackc/pgtype" "github.com/jackc/pgx/v4" ) // Only intrinsic types should be binary format with database/sql. var databaseSQLResultFormats pgx.QueryResultFormatsByOID var pgxDriver *Driver type ctxKey int var ctxKeyFakeTx ctxKey = 0 var ErrNotPgx = errors.New("not pgx *sql.DB") func init() { pgxDriver = &Driver{} fakeTxConns = make(map[*pgx.Conn]*sql.Tx) sql.Register("pgx", pgxDriver) databaseSQLResultFormats = pgx.QueryResultFormatsByOID{ pgtype.BoolOID: 1, pgtype.ByteaOID: 1, pgtype.CIDOID: 1, pgtype.DateOID: 1, pgtype.Float4OID: 1, pgtype.Float8OID: 1, pgtype.Int2OID: 1, pgtype.Int4OID: 1, pgtype.Int8OID: 1, pgtype.OIDOID: 1, pgtype.TimestampOID: 1, pgtype.TimestamptzOID: 1, pgtype.XIDOID: 1, } } var ( fakeTxMutex sync.Mutex fakeTxConns map[*pgx.Conn]*sql.Tx ) type Driver struct{} func (d *Driver) Open(name string) (driver.Conn, error) { connConfig, err := pgx.ParseConfig(name) if err != nil { return nil, err } ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) // Ensure eventual timeout defer cancel() conn, err := pgx.ConnectConfig(ctx, connConfig) if err != nil { return nil, err } c := &Conn{conn: conn, driver: d, connConfig: *connConfig} return c, nil } type Conn struct { conn *pgx.Conn psCount int64 // Counter used for creating unique prepared statement names driver *Driver connConfig pgx.ConnConfig } func (c *Conn) Prepare(query string) (driver.Stmt, error) { return c.PrepareContext(context.Background(), query) } func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } name := fmt.Sprintf("pgx_%d", c.psCount) c.psCount++ ps, err := 
c.conn.Prepare(ctx, name, query) if err != nil { return nil, err } return &Stmt{ps: ps, conn: c}, nil } func (c *Conn) Close() error { return c.conn.Close(context.Background()) } func (c *Conn) Begin() (driver.Tx, error) { return c.BeginTx(context.Background(), driver.TxOptions{}) } func (c *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } if pconn, ok := ctx.Value(ctxKeyFakeTx).(**pgx.Conn); ok { *pconn = c.conn return fakeTx{}, nil } var pgxOpts pgx.TxOptions switch sql.IsolationLevel(opts.Isolation) { case sql.LevelDefault: case sql.LevelReadUncommitted: pgxOpts.IsoLevel = pgx.ReadUncommitted case sql.LevelReadCommitted: pgxOpts.IsoLevel = pgx.ReadCommitted case sql.LevelRepeatableRead, sql.LevelSnapshot: pgxOpts.IsoLevel = pgx.RepeatableRead case sql.LevelSerializable: pgxOpts.IsoLevel = pgx.Serializable default: return nil, errors.Errorf("unsupported isolation: %v", opts.Isolation) } if opts.ReadOnly { pgxOpts.AccessMode = pgx.ReadOnly } tx, err := c.conn.BeginEx(ctx, pgxOpts) if err != nil { return nil, err } return wrapTx{tx: tx}, nil } func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } args := namedValueToInterface(argsV) commandTag, err := c.conn.Exec(ctx, query, args...) // if we got a network error before we had a chance to send the query, retry if err != nil { var netErr net.Error if is := errors.As(err, &netErr); is && errors.Is(err, pgconn.ErrNoBytesSent) { return nil, driver.ErrBadConn } } return driver.RowsAffected(commandTag.RowsAffected()), err } func (c *Conn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } args := []interface{}{databaseSQLResultFormats} args = append(args, namedValueToInterface(argsV)...) 
rows, err := c.conn.Query(ctx, query, args...) if err != nil { if errors.Is(err, pgconn.ErrNoBytesSent) { return nil, driver.ErrBadConn } return nil, err } // Preload first row because otherwise we won't know what columns are available when database/sql asks. more := rows.Next() return &Rows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil } func (c *Conn) Ping(ctx context.Context) error { if !c.conn.IsAlive() { return driver.ErrBadConn } return c.conn.Ping(ctx) } type Stmt struct { ps *pgx.PreparedStatement conn *Conn } func (s *Stmt) Close() error { return s.conn.conn.Deallocate(context.Background(), s.ps.Name) } func (s *Stmt) NumInput() int { return len(s.ps.ParameterOIDs) } func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { return nil, errors.New("Stmt.Exec deprecated and not implemented") } func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { return s.conn.ExecContext(ctx, s.ps.Name, argsV) } func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return nil, errors.New("Stmt.Query deprecated and not implemented") } func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { return s.conn.QueryContext(ctx, s.ps.Name, argsV) } type Rows struct { conn *Conn rows pgx.Rows values []interface{} skipNext bool skipNextMore bool } func (r *Rows) Columns() []string { fieldDescriptions := r.rows.FieldDescriptions() names := make([]string, 0, len(fieldDescriptions)) for _, fd := range fieldDescriptions { names = append(names, string(fd.Name)) } return names } // ColumnTypeDatabaseTypeName return the database system type name. 
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string { if dt, ok := r.conn.conn.ConnInfo.DataTypeForOID(pgtype.OID(r.rows.FieldDescriptions()[index].DataTypeOID)); ok { return strings.ToUpper(dt.Name) } return "" } const varHeaderSize = 4 // ColumnTypeLength returns the length of the column type if the column is a // variable length type. If the column is not a variable length type ok // should return false. func (r *Rows) ColumnTypeLength(index int) (int64, bool) { fd := r.rows.FieldDescriptions()[index] switch fd.DataTypeOID { case pgtype.TextOID, pgtype.ByteaOID: return math.MaxInt64, true case pgtype.VarcharOID, pgtype.BPCharArrayOID: return int64(fd.TypeModifier - varHeaderSize), true default: return 0, false } } // ColumnTypePrecisionScale should return the precision and scale for decimal // types. If not applicable, ok should be false. func (r *Rows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { fd := r.rows.FieldDescriptions()[index] switch fd.DataTypeOID { case pgtype.NumericOID: mod := fd.TypeModifier - varHeaderSize precision = int64((mod >> 16) & 0xffff) scale = int64(mod & 0xffff) return precision, scale, true default: return 0, 0, false } } // ColumnTypeScanType returns the value type that can be used to scan types into. 
func (r *Rows) ColumnTypeScanType(index int) reflect.Type { fd := r.rows.FieldDescriptions()[index] switch fd.DataTypeOID { case pgtype.Float8OID: return reflect.TypeOf(float64(0)) case pgtype.Float4OID: return reflect.TypeOf(float32(0)) case pgtype.Int8OID: return reflect.TypeOf(int64(0)) case pgtype.Int4OID: return reflect.TypeOf(int32(0)) case pgtype.Int2OID: return reflect.TypeOf(int16(0)) case pgtype.VarcharOID, pgtype.BPCharArrayOID, pgtype.TextOID: return reflect.TypeOf("") case pgtype.BoolOID: return reflect.TypeOf(false) case pgtype.NumericOID: return reflect.TypeOf(float64(0)) case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: return reflect.TypeOf(time.Time{}) case pgtype.ByteaOID: return reflect.TypeOf([]byte(nil)) default: return reflect.TypeOf(new(interface{})).Elem() } } func (r *Rows) Close() error { r.rows.Close() return nil } func (r *Rows) Next(dest []driver.Value) error { if r.values == nil { r.values = make([]interface{}, len(r.rows.FieldDescriptions())) for i, fd := range r.rows.FieldDescriptions() { switch fd.DataTypeOID { case pgtype.BoolOID: r.values[i] = &pgtype.Bool{} case pgtype.ByteaOID: r.values[i] = &pgtype.Bytea{} case pgtype.CIDOID: r.values[i] = &pgtype.CID{} case pgtype.DateOID: r.values[i] = &pgtype.Date{} case pgtype.Float4OID: r.values[i] = &pgtype.Float4{} case pgtype.Float8OID: r.values[i] = &pgtype.Float8{} case pgtype.Int2OID: r.values[i] = &pgtype.Int2{} case pgtype.Int4OID: r.values[i] = &pgtype.Int4{} case pgtype.Int8OID: r.values[i] = &pgtype.Int8{} case pgtype.JSONOID: r.values[i] = &pgtype.JSON{} case pgtype.JSONBOID: r.values[i] = &pgtype.JSONB{} case pgtype.OIDOID: r.values[i] = &pgtype.OIDValue{} case pgtype.TimestampOID: r.values[i] = &pgtype.Timestamp{} case pgtype.TimestamptzOID: r.values[i] = &pgtype.Timestamptz{} case pgtype.XIDOID: r.values[i] = &pgtype.XID{} default: r.values[i] = &pgtype.GenericText{} } } } var more bool if r.skipNext { more = r.skipNextMore r.skipNext = false } else { more = 
r.rows.Next() } if !more { if r.rows.Err() == nil { return io.EOF } else { return r.rows.Err() } } err := r.rows.Scan(r.values...) if err != nil { return err } for i, v := range r.values { dest[i], err = v.(driver.Valuer).Value() if err != nil { return err } } return nil } func valueToInterface(argsV []driver.Value) []interface{} { args := make([]interface{}, 0, len(argsV)) for _, v := range argsV { if v != nil { args = append(args, v.(interface{})) } else { args = append(args, nil) } } return args } func namedValueToInterface(argsV []driver.NamedValue) []interface{} { args := make([]interface{}, 0, len(argsV)) for _, v := range argsV { if v.Value != nil { args = append(args, v.Value.(interface{})) } else { args = append(args, nil) } } return args } type wrapTx struct{ tx *pgx.Tx } func (wtx wrapTx) Commit() error { return wtx.tx.Commit(context.Background()) } func (wtx wrapTx) Rollback() error { return wtx.tx.Rollback(context.Background()) } type fakeTx struct{} func (fakeTx) Commit() error { return nil } func (fakeTx) Rollback() error { return nil } func AcquireConn(db *sql.DB) (*pgx.Conn, error) { var conn *pgx.Conn ctx := context.WithValue(context.Background(), ctxKeyFakeTx, &conn) tx, err := db.BeginTx(ctx, nil) if err != nil { return nil, err } if conn == nil { tx.Rollback() return nil, ErrNotPgx } fakeTxMutex.Lock() fakeTxConns[conn] = tx fakeTxMutex.Unlock() return conn, nil } func ReleaseConn(db *sql.DB, conn *pgx.Conn) error { var tx *sql.Tx var ok bool fakeTxMutex.Lock() tx, ok = fakeTxConns[conn] if ok { delete(fakeTxConns, conn) fakeTxMutex.Unlock() } else { fakeTxMutex.Unlock() return errors.Errorf("can't release conn that is not acquired") } return tx.Rollback() }