mirror of
https://github.com/jackc/pgx.git
synced 2025-04-27 13:14:32 +00:00
refactor to use the same connection implementation
This commit is contained in:
parent
3d4540aa1b
commit
51ade172e5
@ -1,531 +0,0 @@
|
|||||||
package stdlib
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
|
||||||
"github.com/jackc/pgx/v5/pgxpool"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OptionOpenDBFromPool options for configuring the driver connector when
|
|
||||||
// opening a new *sql.DB from *pgxpool.Pool.
|
|
||||||
type OptionOpenDBFromPool func(*poolConnector)
|
|
||||||
|
|
||||||
// OptionBeforeConnect provides a callback for before acquire.
|
|
||||||
func OptionBeforeAcquire(f func(context.Context, *pgxpool.Pool) error) OptionOpenDBFromPool {
|
|
||||||
return func(c *poolConnector) {
|
|
||||||
c.BeforeAcquire = f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OptionAfterAcquire provides a callback for after acquire.
|
|
||||||
func OptionAfterAcquire(f func(context.Context, *pgxpool.Conn) error) OptionOpenDBFromPool {
|
|
||||||
return func(c *poolConnector) {
|
|
||||||
c.AfterAcquire = f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OptionResetSession provides a callback that can be used to add custom logic
|
|
||||||
// prior to executing a query on the connection if the connection has been used
|
|
||||||
// before. If ResetSessionFunc returns ErrBadConn error the connection will be
|
|
||||||
// discarded.
|
|
||||||
func OptionPoolConnResetSession(f func(context.Context, *pgxpool.Conn) error) OptionOpenDBFromPool {
|
|
||||||
return func(c *poolConnector) {
|
|
||||||
c.ResetSession = f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetPoolConnector creates a new driver connector to open a new *sql.DB from
|
|
||||||
// the provided pgx pool.
|
|
||||||
func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDBFromPool) driver.Connector {
|
|
||||||
c := &poolConnector{pool: pool, driver: pgxDriver}
|
|
||||||
for _, opt := range opts {
|
|
||||||
opt(c)
|
|
||||||
}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenDBFromPool opens a database using the provided pgx pool. It sets the
|
|
||||||
// maximum number of connections in the idle connection pool to zero, since
|
|
||||||
// those connections are managed in this pgx pool.
|
|
||||||
func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDBFromPool) *sql.DB {
|
|
||||||
c := GetPoolConnector(pool, opts...)
|
|
||||||
db := sql.OpenDB(c)
|
|
||||||
db.SetMaxIdleConns(0)
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
type poolConnector struct {
|
|
||||||
pool *pgxpool.Pool
|
|
||||||
BeforeAcquire func(context.Context, *pgxpool.Pool) error // function to call before acquiring of every new connection
|
|
||||||
AfterAcquire func(context.Context, *pgxpool.Conn) error // function to call after acquiring of every new connection
|
|
||||||
ResetSession func(context.Context, *pgxpool.Conn) error // function is called before a connection is reused
|
|
||||||
driver driver.Driver
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *poolConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
|
||||||
if err := c.BeforeAcquire(ctx, c.pool); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := c.pool.Acquire(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.AfterAcquire(ctx, conn); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PoolConn{conn: conn}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *poolConnector) Driver() driver.Driver {
|
|
||||||
return c.driver
|
|
||||||
}
|
|
||||||
|
|
||||||
type PoolConn struct {
|
|
||||||
conn *pgxpool.Conn
|
|
||||||
psCount int64 // Counter used for creating unique prepared statement names
|
|
||||||
resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused
|
|
||||||
lastResetSessionTime time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conn returns the underlying *pgx.Conn
|
|
||||||
func (c *PoolConn) Conn() *pgx.Conn {
|
|
||||||
return c.conn.Conn()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) Prepare(query string) (driver.Stmt, error) {
|
|
||||||
return c.PrepareContext(context.Background(), query)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
|
||||||
conn := c.Conn()
|
|
||||||
|
|
||||||
if conn.IsClosed() {
|
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
name := fmt.Sprintf("pgx_%d", c.psCount)
|
|
||||||
c.psCount++
|
|
||||||
|
|
||||||
sd, err := conn.Prepare(ctx, name, query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &PoolStmt{sd: sd, conn: c}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) Close() error {
|
|
||||||
c.conn.Release()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) Begin() (driver.Tx, error) {
|
|
||||||
return c.BeginTx(context.Background(), driver.TxOptions{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
|
||||||
if c.Conn().IsClosed() {
|
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
var pgxOpts pgx.TxOptions
|
|
||||||
switch sql.IsolationLevel(opts.Isolation) {
|
|
||||||
case sql.LevelDefault:
|
|
||||||
case sql.LevelReadUncommitted:
|
|
||||||
pgxOpts.IsoLevel = pgx.ReadUncommitted
|
|
||||||
case sql.LevelReadCommitted:
|
|
||||||
pgxOpts.IsoLevel = pgx.ReadCommitted
|
|
||||||
case sql.LevelRepeatableRead, sql.LevelSnapshot:
|
|
||||||
pgxOpts.IsoLevel = pgx.RepeatableRead
|
|
||||||
case sql.LevelSerializable:
|
|
||||||
pgxOpts.IsoLevel = pgx.Serializable
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation)
|
|
||||||
}
|
|
||||||
|
|
||||||
if opts.ReadOnly {
|
|
||||||
pgxOpts.AccessMode = pgx.ReadOnly
|
|
||||||
}
|
|
||||||
|
|
||||||
tx, err := c.conn.BeginTx(ctx, pgxOpts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return wrapTx{ctx: ctx, tx: tx}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) {
|
|
||||||
if c.Conn().IsClosed() {
|
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
args := namedValueToInterface(argsV)
|
|
||||||
|
|
||||||
commandTag, err := c.conn.Exec(ctx, query, args...)
|
|
||||||
// if we got a network error before we had a chance to send the query, retry
|
|
||||||
if err != nil {
|
|
||||||
if pgconn.SafeToRetry(err) {
|
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return driver.RowsAffected(commandTag.RowsAffected()), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) {
|
|
||||||
if c.Conn().IsClosed() {
|
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
args := []any{databaseSQLResultFormats}
|
|
||||||
args = append(args, namedValueToInterface(argsV)...)
|
|
||||||
|
|
||||||
rows, err := c.conn.Query(ctx, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
if pgconn.SafeToRetry(err) {
|
|
||||||
return nil, driver.ErrBadConn
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Preload first row because otherwise we won't know what columns are available when database/sql asks.
|
|
||||||
more := rows.Next()
|
|
||||||
if err = rows.Err(); err != nil {
|
|
||||||
rows.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &PoolRows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) Ping(ctx context.Context) error {
|
|
||||||
if c.Conn().IsClosed() {
|
|
||||||
return driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
err := c.conn.Ping(ctx)
|
|
||||||
if err != nil {
|
|
||||||
// A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the
|
|
||||||
// failure, but manually close it just to be sure.
|
|
||||||
c.Close()
|
|
||||||
return driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) CheckNamedValue(*driver.NamedValue) error {
|
|
||||||
// Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly.
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *PoolConn) ResetSession(ctx context.Context) error {
|
|
||||||
if c.Conn().IsClosed() {
|
|
||||||
return driver.ErrBadConn
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
if now.Sub(c.lastResetSessionTime) > time.Second {
|
|
||||||
if err := c.Conn().PgConn().CheckConn(); err != nil {
|
|
||||||
return driver.ErrBadConn
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.lastResetSessionTime = now
|
|
||||||
|
|
||||||
return c.resetSessionFunc(ctx, c.Conn())
|
|
||||||
}
|
|
||||||
|
|
||||||
type PoolStmt struct {
|
|
||||||
sd *pgconn.StatementDescription
|
|
||||||
conn *PoolConn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PoolStmt) Close() error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
||||||
defer cancel()
|
|
||||||
return s.conn.Conn().Deallocate(ctx, s.sd.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PoolStmt) NumInput() int {
|
|
||||||
return len(s.sd.ParamOIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PoolStmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
|
||||||
return nil, errors.New("Stmt.Exec deprecated and not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PoolStmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) {
|
|
||||||
return s.conn.ExecContext(ctx, s.sd.Name, argsV)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PoolStmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
|
||||||
return nil, errors.New("Stmt.Query deprecated and not implemented")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *PoolStmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
|
|
||||||
return s.conn.QueryContext(ctx, s.sd.Name, argsV)
|
|
||||||
}
|
|
||||||
|
|
||||||
type PoolRows struct {
|
|
||||||
conn *PoolConn
|
|
||||||
rows pgx.Rows
|
|
||||||
valueFuncs []rowValueFunc
|
|
||||||
skipNext bool
|
|
||||||
skipNextMore bool
|
|
||||||
|
|
||||||
columnNames []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *PoolRows) Columns() []string {
|
|
||||||
if r.columnNames == nil {
|
|
||||||
fields := r.rows.FieldDescriptions()
|
|
||||||
r.columnNames = make([]string, len(fields))
|
|
||||||
for i, fd := range fields {
|
|
||||||
r.columnNames[i] = string(fd.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return r.columnNames
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned.
|
|
||||||
func (r *PoolRows) ColumnTypeDatabaseTypeName(index int) string {
|
|
||||||
if dt, ok := r.conn.Conn().TypeMap().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok {
|
|
||||||
return strings.ToUpper(dt.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColumnTypeLength returns the length of the column type if the column is a
|
|
||||||
// variable length type. If the column is not a variable length type ok
|
|
||||||
// should return false.
|
|
||||||
func (r *PoolRows) ColumnTypeLength(index int) (int64, bool) {
|
|
||||||
fd := r.rows.FieldDescriptions()[index]
|
|
||||||
|
|
||||||
switch fd.DataTypeOID {
|
|
||||||
case pgtype.TextOID, pgtype.ByteaOID:
|
|
||||||
return math.MaxInt64, true
|
|
||||||
case pgtype.VarcharOID, pgtype.BPCharArrayOID:
|
|
||||||
return int64(fd.TypeModifier - varHeaderSize), true
|
|
||||||
default:
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColumnTypePrecisionScale should return the precision and scale for decimal
|
|
||||||
// types. If not applicable, ok should be false.
|
|
||||||
func (r *PoolRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) {
|
|
||||||
fd := r.rows.FieldDescriptions()[index]
|
|
||||||
|
|
||||||
switch fd.DataTypeOID {
|
|
||||||
case pgtype.NumericOID:
|
|
||||||
mod := fd.TypeModifier - varHeaderSize
|
|
||||||
precision = int64((mod >> 16) & 0xffff)
|
|
||||||
scale = int64(mod & 0xffff)
|
|
||||||
return precision, scale, true
|
|
||||||
default:
|
|
||||||
return 0, 0, false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColumnTypeScanType returns the value type that can be used to scan types into.
|
|
||||||
func (r *PoolRows) ColumnTypeScanType(index int) reflect.Type {
|
|
||||||
fd := r.rows.FieldDescriptions()[index]
|
|
||||||
|
|
||||||
switch fd.DataTypeOID {
|
|
||||||
case pgtype.Float8OID:
|
|
||||||
return reflect.TypeOf(float64(0))
|
|
||||||
case pgtype.Float4OID:
|
|
||||||
return reflect.TypeOf(float32(0))
|
|
||||||
case pgtype.Int8OID:
|
|
||||||
return reflect.TypeOf(int64(0))
|
|
||||||
case pgtype.Int4OID:
|
|
||||||
return reflect.TypeOf(int32(0))
|
|
||||||
case pgtype.Int2OID:
|
|
||||||
return reflect.TypeOf(int16(0))
|
|
||||||
case pgtype.BoolOID:
|
|
||||||
return reflect.TypeOf(false)
|
|
||||||
case pgtype.NumericOID:
|
|
||||||
return reflect.TypeOf(float64(0))
|
|
||||||
case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID:
|
|
||||||
return reflect.TypeOf(time.Time{})
|
|
||||||
case pgtype.ByteaOID:
|
|
||||||
return reflect.TypeOf([]byte(nil))
|
|
||||||
default:
|
|
||||||
return reflect.TypeOf("")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *PoolRows) Close() error {
|
|
||||||
r.rows.Close()
|
|
||||||
return r.rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *PoolRows) Next(dest []driver.Value) error {
|
|
||||||
m := r.conn.Conn().TypeMap()
|
|
||||||
fieldDescriptions := r.rows.FieldDescriptions()
|
|
||||||
|
|
||||||
if r.valueFuncs == nil {
|
|
||||||
r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions))
|
|
||||||
|
|
||||||
for i, fd := range fieldDescriptions {
|
|
||||||
dataTypeOID := fd.DataTypeOID
|
|
||||||
format := fd.Format
|
|
||||||
|
|
||||||
switch fd.DataTypeOID {
|
|
||||||
case pgtype.BoolOID:
|
|
||||||
var d bool
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
return d, err
|
|
||||||
}
|
|
||||||
case pgtype.ByteaOID:
|
|
||||||
var d []byte
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
return d, err
|
|
||||||
}
|
|
||||||
case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID:
|
|
||||||
var d pgtype.Uint32
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return d.Value()
|
|
||||||
}
|
|
||||||
case pgtype.DateOID:
|
|
||||||
var d pgtype.Date
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return d.Value()
|
|
||||||
}
|
|
||||||
case pgtype.Float4OID:
|
|
||||||
var d float32
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
return float64(d), err
|
|
||||||
}
|
|
||||||
case pgtype.Float8OID:
|
|
||||||
var d float64
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
return d, err
|
|
||||||
}
|
|
||||||
case pgtype.Int2OID:
|
|
||||||
var d int16
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
return int64(d), err
|
|
||||||
}
|
|
||||||
case pgtype.Int4OID:
|
|
||||||
var d int32
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
return int64(d), err
|
|
||||||
}
|
|
||||||
case pgtype.Int8OID:
|
|
||||||
var d int64
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
return d, err
|
|
||||||
}
|
|
||||||
case pgtype.JSONOID, pgtype.JSONBOID:
|
|
||||||
var d []byte
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return d, nil
|
|
||||||
}
|
|
||||||
case pgtype.TimestampOID:
|
|
||||||
var d pgtype.Timestamp
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return d.Value()
|
|
||||||
}
|
|
||||||
case pgtype.TimestamptzOID:
|
|
||||||
var d pgtype.Timestamptz
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return d.Value()
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
var d string
|
|
||||||
scanPlan := m.PlanScan(dataTypeOID, format, &d)
|
|
||||||
r.valueFuncs[i] = func(src []byte) (driver.Value, error) {
|
|
||||||
err := scanPlan.Scan(src, &d)
|
|
||||||
return d, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var more bool
|
|
||||||
if r.skipNext {
|
|
||||||
more = r.skipNextMore
|
|
||||||
r.skipNext = false
|
|
||||||
} else {
|
|
||||||
more = r.rows.Next()
|
|
||||||
}
|
|
||||||
|
|
||||||
if !more {
|
|
||||||
if r.rows.Err() == nil {
|
|
||||||
return io.EOF
|
|
||||||
} else {
|
|
||||||
return r.rows.Err()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, rv := range r.rows.RawValues() {
|
|
||||||
if rv != nil {
|
|
||||||
var err error
|
|
||||||
dest[i], err = r.valueFuncs[i](rv)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("convert field %d failed: %v", i, err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
dest[i] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
@ -74,6 +74,7 @@ import (
|
|||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
"github.com/jackc/pgx/v5/pgconn"
|
"github.com/jackc/pgx/v5/pgconn"
|
||||||
"github.com/jackc/pgx/v5/pgtype"
|
"github.com/jackc/pgx/v5/pgtype"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Only intrinsic types should be binary format with database/sql.
|
// Only intrinsic types should be binary format with database/sql.
|
||||||
@ -125,7 +126,7 @@ func contains(list []string, y string) bool {
|
|||||||
type OptionOpenDB func(*connector)
|
type OptionOpenDB func(*connector)
|
||||||
|
|
||||||
// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
|
// OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will
|
||||||
// be used to connect, so only its immediate members should be modified.
|
// be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig.
|
||||||
func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
|
func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB {
|
||||||
return func(dc *connector) {
|
return func(dc *connector) {
|
||||||
dc.BeforeConnect = bc
|
dc.BeforeConnect = bc
|
||||||
@ -139,6 +140,20 @@ func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OptionBeforeConnect provides a callback for before acquire. Used only if db is opened with *pgxpool.Pool.
|
||||||
|
func OptionBeforeAcquire(ba func(context.Context, *pgxpool.Pool) error) OptionOpenDB {
|
||||||
|
return func(c *connector) {
|
||||||
|
c.BeforeAcquire = ba
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OptionAfterAcquire provides a callback for after acquire. Used only if db is opened with *pgxpool.Pool.
|
||||||
|
func OptionAfterAcquire(aa func(context.Context, *pgxpool.Conn) error) OptionOpenDB {
|
||||||
|
return func(c *connector) {
|
||||||
|
c.AfterAcquire = aa
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the
|
// OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the
|
||||||
// connection if the connection has been used before.
|
// connection if the connection has been used before.
|
||||||
// If ResetSessionFunc returns ErrBadConn error the connection will be discarded.
|
// If ResetSessionFunc returns ErrBadConn error the connection will be discarded.
|
||||||
@ -191,15 +206,41 @@ func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector
|
|||||||
return c
|
return c
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDB) driver.Connector {
|
||||||
|
c := connector{
|
||||||
|
pool: pool,
|
||||||
|
BeforeAcquire: func(context.Context, *pgxpool.Pool) error { return nil }, // noop before acquire by default
|
||||||
|
AfterAcquire: func(context.Context, *pgxpool.Conn) error { return nil }, // noop after acquire by default
|
||||||
|
ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default
|
||||||
|
driver: pgxDriver,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(&c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
|
func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB {
|
||||||
c := GetConnector(config, opts...)
|
c := GetConnector(config, opts...)
|
||||||
return sql.OpenDB(c)
|
return sql.OpenDB(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB {
|
||||||
|
c := GetPoolConnector(pool, opts...)
|
||||||
|
db := sql.OpenDB(c)
|
||||||
|
db.SetMaxIdleConns(0)
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
|
||||||
type connector struct {
|
type connector struct {
|
||||||
pgx.ConnConfig
|
pgx.ConnConfig
|
||||||
|
pool *pgxpool.Pool
|
||||||
BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
|
BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection
|
||||||
AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection
|
AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection
|
||||||
|
BeforeAcquire func(context.Context, *pgxpool.Pool) error // function to call before acquiring of every new connection
|
||||||
|
AfterAcquire func(context.Context, *pgxpool.Conn) error // function to call after acquiring of every new connection
|
||||||
ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused
|
ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused
|
||||||
driver *Driver
|
driver *Driver
|
||||||
}
|
}
|
||||||
@ -207,12 +248,16 @@ type connector struct {
|
|||||||
// Connect implement driver.Connector interface
|
// Connect implement driver.Connector interface
|
||||||
func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
||||||
var (
|
var (
|
||||||
err error
|
connConfig pgx.ConnConfig
|
||||||
conn *pgx.Conn
|
conn *pgx.Conn
|
||||||
|
close func(context.Context) error
|
||||||
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if c.pool == nil {
|
||||||
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
|
// Create a shallow copy of the config, so that BeforeConnect can safely modify it
|
||||||
connConfig := c.ConnConfig
|
connConfig = c.ConnConfig
|
||||||
|
|
||||||
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
|
if err = c.BeforeConnect(ctx, &connConfig); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -225,7 +270,38 @@ func (c connector) Connect(ctx context.Context) (driver.Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil
|
close = conn.Close
|
||||||
|
} else {
|
||||||
|
var pconn *pgxpool.Conn
|
||||||
|
|
||||||
|
if err = c.BeforeAcquire(ctx, c.pool); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pconn, err = c.pool.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = c.AfterAcquire(ctx, pconn); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn = pconn.Conn()
|
||||||
|
|
||||||
|
close = func(_ context.Context) error {
|
||||||
|
pconn.Release()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Conn{
|
||||||
|
conn: conn,
|
||||||
|
close: close,
|
||||||
|
driver: c.driver,
|
||||||
|
connConfig: connConfig,
|
||||||
|
resetSessionFunc: c.ResetSession,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Driver implement driver.Connector interface
|
// Driver implement driver.Connector interface
|
||||||
@ -302,6 +378,7 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) {
|
|||||||
|
|
||||||
c := &Conn{
|
c := &Conn{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
close: conn.Close,
|
||||||
driver: dc.driver,
|
driver: dc.driver,
|
||||||
connConfig: *connConfig,
|
connConfig: *connConfig,
|
||||||
resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
|
resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil },
|
||||||
@ -326,6 +403,7 @@ func UnregisterConnConfig(connStr string) {
|
|||||||
|
|
||||||
type Conn struct {
|
type Conn struct {
|
||||||
conn *pgx.Conn
|
conn *pgx.Conn
|
||||||
|
close func(context.Context) error
|
||||||
psCount int64 // Counter used for creating unique prepared statement names
|
psCount int64 // Counter used for creating unique prepared statement names
|
||||||
driver *Driver
|
driver *Driver
|
||||||
connConfig pgx.ConnConfig
|
connConfig pgx.ConnConfig
|
||||||
@ -361,7 +439,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e
|
|||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
return c.conn.Close(ctx)
|
return c.close(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Begin() (driver.Tx, error) {
|
func (c *Conn) Begin() (driver.Tx, error) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user