From 3d4540aa1b825d8ce245a4fe6f1f474b4bd9c029 Mon Sep 17 00:00:00 2001 From: Lev Zakharov Date: Fri, 18 Aug 2023 19:01:26 +0300 Subject: [PATCH] add *sql.DB construction from *pgxpool.Pool --- stdlib/pgxpool.go | 531 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 531 insertions(+) create mode 100644 stdlib/pgxpool.go diff --git a/stdlib/pgxpool.go b/stdlib/pgxpool.go new file mode 100644 index 00000000..53f5e8c6 --- /dev/null +++ b/stdlib/pgxpool.go @@ -0,0 +1,531 @@ +package stdlib + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "io" + "math" + "reflect" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" +) + +// OptionOpenDBFromPool options for configuring the driver connector when +// opening a new *sql.DB from *pgxpool.Pool. +type OptionOpenDBFromPool func(*poolConnector) + +// OptionBeforeConnect provides a callback for before acquire. +func OptionBeforeAcquire(f func(context.Context, *pgxpool.Pool) error) OptionOpenDBFromPool { + return func(c *poolConnector) { + c.BeforeAcquire = f + } +} + +// OptionAfterAcquire provides a callback for after acquire. +func OptionAfterAcquire(f func(context.Context, *pgxpool.Conn) error) OptionOpenDBFromPool { + return func(c *poolConnector) { + c.AfterAcquire = f + } +} + +// OptionResetSession provides a callback that can be used to add custom logic +// prior to executing a query on the connection if the connection has been used +// before. If ResetSessionFunc returns ErrBadConn error the connection will be +// discarded. +func OptionPoolConnResetSession(f func(context.Context, *pgxpool.Conn) error) OptionOpenDBFromPool { + return func(c *poolConnector) { + c.ResetSession = f + } +} + +// GetPoolConnector creates a new driver connector to open a new *sql.DB from +// the provided pgx pool. +func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDBFromPool) driver.Connector { + c := &poolConnector{pool: pool, driver: pgxDriver} + for _, opt := range opts { + opt(c) + } + return c +} + +// OpenDBFromPool opens a database using the provided pgx pool. It sets the +// maximum number of connections in the idle connection pool to zero, since +// those connections are managed in this pgx pool. +func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDBFromPool) *sql.DB { + c := GetPoolConnector(pool, opts...) + db := sql.OpenDB(c) + db.SetMaxIdleConns(0) + return db +} + +type poolConnector struct { + pool *pgxpool.Pool + BeforeAcquire func(context.Context, *pgxpool.Pool) error // function to call before acquiring of every new connection + AfterAcquire func(context.Context, *pgxpool.Conn) error // function to call after acquiring of every new connection + ResetSession func(context.Context, *pgxpool.Conn) error // function is called before a connection is reused + driver driver.Driver +} + +func (c *poolConnector) Connect(ctx context.Context) (driver.Conn, error) { + if err := c.BeforeAcquire(ctx, c.pool); err != nil { + return nil, err + } + + conn, err := c.pool.Acquire(ctx) + if err != nil { + return nil, err + } + + if err := c.AfterAcquire(ctx, conn); err != nil { + return nil, err + } + + return &PoolConn{conn: conn}, nil +} + +func (c *poolConnector) Driver() driver.Driver { + return c.driver +} + +type PoolConn struct { + conn *pgxpool.Conn + psCount int64 // Counter used for creating unique prepared statement names + resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused + lastResetSessionTime time.Time +} + +// Conn returns the underlying *pgx.Conn +func (c *PoolConn) Conn() *pgx.Conn { + return c.conn.Conn() +} + +func (c *PoolConn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *PoolConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + conn := c.Conn() + + if conn.IsClosed() { + return nil, driver.ErrBadConn + } + + name := fmt.Sprintf("pgx_%d", c.psCount) + c.psCount++ + + sd, err := conn.Prepare(ctx, name, query) + if err != nil { + return nil, err + } + + return &PoolStmt{sd: sd, conn: c}, nil +} + +func (c *PoolConn) Close() error { + c.conn.Release() + return nil +} + +func (c *PoolConn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +func (c *PoolConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if c.Conn().IsClosed() { + return nil, driver.ErrBadConn + } + + var pgxOpts pgx.TxOptions + switch sql.IsolationLevel(opts.Isolation) { + case sql.LevelDefault: + case sql.LevelReadUncommitted: + pgxOpts.IsoLevel = pgx.ReadUncommitted + case sql.LevelReadCommitted: + pgxOpts.IsoLevel = pgx.ReadCommitted + case sql.LevelRepeatableRead, sql.LevelSnapshot: + pgxOpts.IsoLevel = pgx.RepeatableRead + case sql.LevelSerializable: + pgxOpts.IsoLevel = pgx.Serializable + default: + return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation) + } + + if opts.ReadOnly { + pgxOpts.AccessMode = pgx.ReadOnly + } + + tx, err := c.conn.BeginTx(ctx, pgxOpts) + if err != nil { + return nil, err + } + + return wrapTx{ctx: ctx, tx: tx}, nil +} + +func (c *PoolConn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { + if c.Conn().IsClosed() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + commandTag, err := c.conn.Exec(ctx, query, args...) + // if we got a network error before we had a chance to send the query, retry + if err != nil { + if pgconn.SafeToRetry(err) { + return nil, driver.ErrBadConn + } + } + return driver.RowsAffected(commandTag.RowsAffected()), err +} + +func (c *PoolConn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { + if c.Conn().IsClosed() { + return nil, driver.ErrBadConn + } + + args := []any{databaseSQLResultFormats} + args = append(args, namedValueToInterface(argsV)...) + + rows, err := c.conn.Query(ctx, query, args...) + if err != nil { + if pgconn.SafeToRetry(err) { + return nil, driver.ErrBadConn + } + return nil, err + } + + // Preload first row because otherwise we won't know what columns are available when database/sql asks. + more := rows.Next() + if err = rows.Err(); err != nil { + rows.Close() + return nil, err + } + return &PoolRows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil +} + +func (c *PoolConn) Ping(ctx context.Context) error { + if c.Conn().IsClosed() { + return driver.ErrBadConn + } + + err := c.conn.Ping(ctx) + if err != nil { + // A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the + // failure, but manually close it just to be sure. + c.Close() + return driver.ErrBadConn + } + + return nil +} + +func (c *PoolConn) CheckNamedValue(*driver.NamedValue) error { + // Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly. + return nil +} + +func (c *PoolConn) ResetSession(ctx context.Context) error { + if c.Conn().IsClosed() { + return driver.ErrBadConn + } + + now := time.Now() + if now.Sub(c.lastResetSessionTime) > time.Second { + if err := c.Conn().PgConn().CheckConn(); err != nil { + return driver.ErrBadConn + } + } + c.lastResetSessionTime = now + + return c.resetSessionFunc(ctx, c.Conn()) +} + +type PoolStmt struct { + sd *pgconn.StatementDescription + conn *PoolConn +} + +func (s *PoolStmt) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + return s.conn.Conn().Deallocate(ctx, s.sd.Name) +} + +func (s *PoolStmt) NumInput() int { + return len(s.sd.ParamOIDs) +} + +func (s *PoolStmt) Exec(argsV []driver.Value) (driver.Result, error) { + return nil, errors.New("Stmt.Exec deprecated and not implemented") +} + +func (s *PoolStmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { + return s.conn.ExecContext(ctx, s.sd.Name, argsV) +} + +func (s *PoolStmt) Query(argsV []driver.Value) (driver.Rows, error) { + return nil, errors.New("Stmt.Query deprecated and not implemented") +} + +func (s *PoolStmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { + return s.conn.QueryContext(ctx, s.sd.Name, argsV) +} + +type PoolRows struct { + conn *PoolConn + rows pgx.Rows + valueFuncs []rowValueFunc + skipNext bool + skipNextMore bool + + columnNames []string +} + +func (r *PoolRows) Columns() []string { + if r.columnNames == nil { + fields := r.rows.FieldDescriptions() + r.columnNames = make([]string, len(fields)) + for i, fd := range fields { + r.columnNames[i] = string(fd.Name) + } + } + + return r.columnNames +} + +// ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. +func (r *PoolRows) ColumnTypeDatabaseTypeName(index int) string { + if dt, ok := r.conn.Conn().TypeMap().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { + return strings.ToUpper(dt.Name) + } + + return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10) +} + +// ColumnTypeLength returns the length of the column type if the column is a +// variable length type. If the column is not a variable length type ok +// should return false. +func (r *PoolRows) ColumnTypeLength(index int) (int64, bool) { + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.TextOID, pgtype.ByteaOID: + return math.MaxInt64, true + case pgtype.VarcharOID, pgtype.BPCharArrayOID: + return int64(fd.TypeModifier - varHeaderSize), true + default: + return 0, false + } +} + +// ColumnTypePrecisionScale should return the precision and scale for decimal +// types. If not applicable, ok should be false. +func (r *PoolRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.NumericOID: + mod := fd.TypeModifier - varHeaderSize + precision = int64((mod >> 16) & 0xffff) + scale = int64(mod & 0xffff) + return precision, scale, true + default: + return 0, 0, false + } +} + +// ColumnTypeScanType returns the value type that can be used to scan types into. +func (r *PoolRows) ColumnTypeScanType(index int) reflect.Type { + fd := r.rows.FieldDescriptions()[index] + + switch fd.DataTypeOID { + case pgtype.Float8OID: + return reflect.TypeOf(float64(0)) + case pgtype.Float4OID: + return reflect.TypeOf(float32(0)) + case pgtype.Int8OID: + return reflect.TypeOf(int64(0)) + case pgtype.Int4OID: + return reflect.TypeOf(int32(0)) + case pgtype.Int2OID: + return reflect.TypeOf(int16(0)) + case pgtype.BoolOID: + return reflect.TypeOf(false) + case pgtype.NumericOID: + return reflect.TypeOf(float64(0)) + case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: + return reflect.TypeOf(time.Time{}) + case pgtype.ByteaOID: + return reflect.TypeOf([]byte(nil)) + default: + return reflect.TypeOf("") + } +} + +func (r *PoolRows) Close() error { + r.rows.Close() + return r.rows.Err() +} + +func (r *PoolRows) Next(dest []driver.Value) error { + m := r.conn.Conn().TypeMap() + fieldDescriptions := r.rows.FieldDescriptions() + + if r.valueFuncs == nil { + r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions)) + + for i, fd := range fieldDescriptions { + dataTypeOID := fd.DataTypeOID + format := fd.Format + + switch fd.DataTypeOID { + case pgtype.BoolOID: + var d bool + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } + case pgtype.ByteaOID: + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } + case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: + var d pgtype.Uint32 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.DateOID: + var d pgtype.Date + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.Float4OID: + var d float32 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return float64(d), err + } + case pgtype.Float8OID: + var d float64 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } + case pgtype.Int2OID: + var d int16 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return int64(d), err + } + case pgtype.Int4OID: + var d int32 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return int64(d), err + } + case pgtype.Int8OID: + var d int64 + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } + case pgtype.JSONOID, pgtype.JSONBOID: + var d []byte + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d, nil + } + case pgtype.TimestampOID: + var d pgtype.Timestamp + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + case pgtype.TimestamptzOID: + var d pgtype.Timestamptz + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + if err != nil { + return nil, err + } + return d.Value() + } + default: + var d string + scanPlan := m.PlanScan(dataTypeOID, format, &d) + r.valueFuncs[i] = func(src []byte) (driver.Value, error) { + err := scanPlan.Scan(src, &d) + return d, err + } + } + } + } + + var more bool + if r.skipNext { + more = r.skipNextMore + r.skipNext = false + } else { + more = r.rows.Next() + } + + if !more { + if r.rows.Err() == nil { + return io.EOF + } else { + return r.rows.Err() + } + } + + for i, rv := range r.rows.RawValues() { + if rv != nil { + var err error + dest[i], err = r.valueFuncs[i](rv) + if err != nil { + return fmt.Errorf("convert field %d failed: %v", i, err) + } + } else { + dest[i] = nil + } + } + + return nil +}