From 51ade172e5c5f69c4ffafaccbb9c6fec5b3d6969 Mon Sep 17 00:00:00 2001 From: Lev Zakharov Date: Sat, 19 Aug 2023 21:42:28 +0300 Subject: [PATCH] refactor to use the same connection implementation --- stdlib/pgxpool.go | 531 ---------------------------------------------- stdlib/sql.go | 112 ++++++++-- 2 files changed, 95 insertions(+), 548 deletions(-) delete mode 100644 stdlib/pgxpool.go diff --git a/stdlib/pgxpool.go b/stdlib/pgxpool.go deleted file mode 100644 index 53f5e8c6..00000000 --- a/stdlib/pgxpool.go +++ /dev/null @@ -1,531 +0,0 @@ -package stdlib - -import ( - "context" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "io" - "math" - "reflect" - "strconv" - "strings" - "time" - - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgxpool" -) - -// OptionOpenDBFromPool options for configuring the driver connector when -// opening a new *sql.DB from *pgxpool.Pool. -type OptionOpenDBFromPool func(*poolConnector) - -// OptionBeforeConnect provides a callback for before acquire. -func OptionBeforeAcquire(f func(context.Context, *pgxpool.Pool) error) OptionOpenDBFromPool { - return func(c *poolConnector) { - c.BeforeAcquire = f - } -} - -// OptionAfterAcquire provides a callback for after acquire. -func OptionAfterAcquire(f func(context.Context, *pgxpool.Conn) error) OptionOpenDBFromPool { - return func(c *poolConnector) { - c.AfterAcquire = f - } -} - -// OptionResetSession provides a callback that can be used to add custom logic -// prior to executing a query on the connection if the connection has been used -// before. If ResetSessionFunc returns ErrBadConn error the connection will be -// discarded. -func OptionPoolConnResetSession(f func(context.Context, *pgxpool.Conn) error) OptionOpenDBFromPool { - return func(c *poolConnector) { - c.ResetSession = f - } -} - -// GetPoolConnector creates a new driver connector to open a new *sql.DB from -// the provided pgx pool. -func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDBFromPool) driver.Connector { - c := &poolConnector{pool: pool, driver: pgxDriver} - for _, opt := range opts { - opt(c) - } - return c -} - -// OpenDBFromPool opens a database using the provided pgx pool. It sets the -// maximum number of connections in the idle connection pool to zero, since -// those connections are managed in this pgx pool. -func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDBFromPool) *sql.DB { - c := GetPoolConnector(pool, opts...) - db := sql.OpenDB(c) - db.SetMaxIdleConns(0) - return db -} - -type poolConnector struct { - pool *pgxpool.Pool - BeforeAcquire func(context.Context, *pgxpool.Pool) error // function to call before acquiring of every new connection - AfterAcquire func(context.Context, *pgxpool.Conn) error // function to call after acquiring of every new connection - ResetSession func(context.Context, *pgxpool.Conn) error // function is called before a connection is reused - driver driver.Driver -} - -func (c *poolConnector) Connect(ctx context.Context) (driver.Conn, error) { - if err := c.BeforeAcquire(ctx, c.pool); err != nil { - return nil, err - } - - conn, err := c.pool.Acquire(ctx) - if err != nil { - return nil, err - } - - if err := c.AfterAcquire(ctx, conn); err != nil { - return nil, err - } - - return &PoolConn{conn: conn}, nil -} - -func (c *poolConnector) Driver() driver.Driver { - return c.driver -} - -type PoolConn struct { - conn *pgxpool.Conn - psCount int64 // Counter used for creating unique prepared statement names - resetSessionFunc func(context.Context, *pgx.Conn) error // Function is called before a connection is reused - lastResetSessionTime time.Time -} - -// Conn returns the underlying *pgx.Conn -func (c *PoolConn) Conn() *pgx.Conn { - return c.conn.Conn() -} - -func (c *PoolConn) Prepare(query string) (driver.Stmt, error) { - return c.PrepareContext(context.Background(), query) -} - -func (c *PoolConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - conn := c.Conn() - - if conn.IsClosed() { - return nil, driver.ErrBadConn - } - - name := fmt.Sprintf("pgx_%d", c.psCount) - c.psCount++ - - sd, err := conn.Prepare(ctx, name, query) - if err != nil { - return nil, err - } - - return &PoolStmt{sd: sd, conn: c}, nil -} - -func (c *PoolConn) Close() error { - c.conn.Release() - return nil -} - -func (c *PoolConn) Begin() (driver.Tx, error) { - return c.BeginTx(context.Background(), driver.TxOptions{}) -} - -func (c *PoolConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - if c.Conn().IsClosed() { - return nil, driver.ErrBadConn - } - - var pgxOpts pgx.TxOptions - switch sql.IsolationLevel(opts.Isolation) { - case sql.LevelDefault: - case sql.LevelReadUncommitted: - pgxOpts.IsoLevel = pgx.ReadUncommitted - case sql.LevelReadCommitted: - pgxOpts.IsoLevel = pgx.ReadCommitted - case sql.LevelRepeatableRead, sql.LevelSnapshot: - pgxOpts.IsoLevel = pgx.RepeatableRead - case sql.LevelSerializable: - pgxOpts.IsoLevel = pgx.Serializable - default: - return nil, fmt.Errorf("unsupported isolation: %v", opts.Isolation) - } - - if opts.ReadOnly { - pgxOpts.AccessMode = pgx.ReadOnly - } - - tx, err := c.conn.BeginTx(ctx, pgxOpts) - if err != nil { - return nil, err - } - - return wrapTx{ctx: ctx, tx: tx}, nil -} - -func (c *PoolConn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { - if c.Conn().IsClosed() { - return nil, driver.ErrBadConn - } - - args := namedValueToInterface(argsV) - - commandTag, err := c.conn.Exec(ctx, query, args...) - // if we got a network error before we had a chance to send the query, retry - if err != nil { - if pgconn.SafeToRetry(err) { - return nil, driver.ErrBadConn - } - } - return driver.RowsAffected(commandTag.RowsAffected()), err -} - -func (c *PoolConn) QueryContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Rows, error) { - if c.Conn().IsClosed() { - return nil, driver.ErrBadConn - } - - args := []any{databaseSQLResultFormats} - args = append(args, namedValueToInterface(argsV)...) - - rows, err := c.conn.Query(ctx, query, args...) - if err != nil { - if pgconn.SafeToRetry(err) { - return nil, driver.ErrBadConn - } - return nil, err - } - - // Preload first row because otherwise we won't know what columns are available when database/sql asks. - more := rows.Next() - if err = rows.Err(); err != nil { - rows.Close() - return nil, err - } - return &PoolRows{conn: c, rows: rows, skipNext: true, skipNextMore: more}, nil -} - -func (c *PoolConn) Ping(ctx context.Context) error { - if c.Conn().IsClosed() { - return driver.ErrBadConn - } - - err := c.conn.Ping(ctx) - if err != nil { - // A Ping failure implies some sort of fatal state. The connection is almost certainly already closed by the - // failure, but manually close it just to be sure. - c.Close() - return driver.ErrBadConn - } - - return nil -} - -func (c *PoolConn) CheckNamedValue(*driver.NamedValue) error { - // Underlying pgx supports sql.Scanner and driver.Valuer interfaces natively. So everything can be passed through directly. - return nil -} - -func (c *PoolConn) ResetSession(ctx context.Context) error { - if c.Conn().IsClosed() { - return driver.ErrBadConn - } - - now := time.Now() - if now.Sub(c.lastResetSessionTime) > time.Second { - if err := c.Conn().PgConn().CheckConn(); err != nil { - return driver.ErrBadConn - } - } - c.lastResetSessionTime = now - - return c.resetSessionFunc(ctx, c.Conn()) -} - -type PoolStmt struct { - sd *pgconn.StatementDescription - conn *PoolConn -} - -func (s *PoolStmt) Close() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - return s.conn.Conn().Deallocate(ctx, s.sd.Name) -} - -func (s *PoolStmt) NumInput() int { - return len(s.sd.ParamOIDs) -} - -func (s *PoolStmt) Exec(argsV []driver.Value) (driver.Result, error) { - return nil, errors.New("Stmt.Exec deprecated and not implemented") -} - -func (s *PoolStmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { - return s.conn.ExecContext(ctx, s.sd.Name, argsV) -} - -func (s *PoolStmt) Query(argsV []driver.Value) (driver.Rows, error) { - return nil, errors.New("Stmt.Query deprecated and not implemented") -} - -func (s *PoolStmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { - return s.conn.QueryContext(ctx, s.sd.Name, argsV) -} - -type PoolRows struct { - conn *PoolConn - rows pgx.Rows - valueFuncs []rowValueFunc - skipNext bool - skipNextMore bool - - columnNames []string -} - -func (r *PoolRows) Columns() []string { - if r.columnNames == nil { - fields := r.rows.FieldDescriptions() - r.columnNames = make([]string, len(fields)) - for i, fd := range fields { - r.columnNames[i] = string(fd.Name) - } - } - - return r.columnNames -} - -// ColumnTypeDatabaseTypeName returns the database system type name. If the name is unknown the OID is returned. -func (r *PoolRows) ColumnTypeDatabaseTypeName(index int) string { - if dt, ok := r.conn.Conn().TypeMap().TypeForOID(r.rows.FieldDescriptions()[index].DataTypeOID); ok { - return strings.ToUpper(dt.Name) - } - - return strconv.FormatInt(int64(r.rows.FieldDescriptions()[index].DataTypeOID), 10) -} - -// ColumnTypeLength returns the length of the column type if the column is a -// variable length type. If the column is not a variable length type ok -// should return false. -func (r *PoolRows) ColumnTypeLength(index int) (int64, bool) { - fd := r.rows.FieldDescriptions()[index] - - switch fd.DataTypeOID { - case pgtype.TextOID, pgtype.ByteaOID: - return math.MaxInt64, true - case pgtype.VarcharOID, pgtype.BPCharArrayOID: - return int64(fd.TypeModifier - varHeaderSize), true - default: - return 0, false - } -} - -// ColumnTypePrecisionScale should return the precision and scale for decimal -// types. If not applicable, ok should be false. -func (r *PoolRows) ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool) { - fd := r.rows.FieldDescriptions()[index] - - switch fd.DataTypeOID { - case pgtype.NumericOID: - mod := fd.TypeModifier - varHeaderSize - precision = int64((mod >> 16) & 0xffff) - scale = int64(mod & 0xffff) - return precision, scale, true - default: - return 0, 0, false - } -} - -// ColumnTypeScanType returns the value type that can be used to scan types into. -func (r *PoolRows) ColumnTypeScanType(index int) reflect.Type { - fd := r.rows.FieldDescriptions()[index] - - switch fd.DataTypeOID { - case pgtype.Float8OID: - return reflect.TypeOf(float64(0)) - case pgtype.Float4OID: - return reflect.TypeOf(float32(0)) - case pgtype.Int8OID: - return reflect.TypeOf(int64(0)) - case pgtype.Int4OID: - return reflect.TypeOf(int32(0)) - case pgtype.Int2OID: - return reflect.TypeOf(int16(0)) - case pgtype.BoolOID: - return reflect.TypeOf(false) - case pgtype.NumericOID: - return reflect.TypeOf(float64(0)) - case pgtype.DateOID, pgtype.TimestampOID, pgtype.TimestamptzOID: - return reflect.TypeOf(time.Time{}) - case pgtype.ByteaOID: - return reflect.TypeOf([]byte(nil)) - default: - return reflect.TypeOf("") - } -} - -func (r *PoolRows) Close() error { - r.rows.Close() - return r.rows.Err() -} - -func (r *PoolRows) Next(dest []driver.Value) error { - m := r.conn.Conn().TypeMap() - fieldDescriptions := r.rows.FieldDescriptions() - - if r.valueFuncs == nil { - r.valueFuncs = make([]rowValueFunc, len(fieldDescriptions)) - - for i, fd := range fieldDescriptions { - dataTypeOID := fd.DataTypeOID - format := fd.Format - - switch fd.DataTypeOID { - case pgtype.BoolOID: - var d bool - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - return d, err - } - case pgtype.ByteaOID: - var d []byte - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - return d, err - } - case pgtype.CIDOID, pgtype.OIDOID, pgtype.XIDOID: - var d pgtype.Uint32 - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - if err != nil { - return nil, err - } - return d.Value() - } - case pgtype.DateOID: - var d pgtype.Date - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - if err != nil { - return nil, err - } - return d.Value() - } - case pgtype.Float4OID: - var d float32 - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - return float64(d), err - } - case pgtype.Float8OID: - var d float64 - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - return d, err - } - case pgtype.Int2OID: - var d int16 - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - return int64(d), err - } - case pgtype.Int4OID: - var d int32 - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - return int64(d), err - } - case pgtype.Int8OID: - var d int64 - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - return d, err - } - case pgtype.JSONOID, pgtype.JSONBOID: - var d []byte - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - if err != nil { - return nil, err - } - return d, nil - } - case pgtype.TimestampOID: - var d pgtype.Timestamp - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - if err != nil { - return nil, err - } - return d.Value() - } - case pgtype.TimestamptzOID: - var d pgtype.Timestamptz - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - if err != nil { - return nil, err - } - return d.Value() - } - default: - var d string - scanPlan := m.PlanScan(dataTypeOID, format, &d) - r.valueFuncs[i] = func(src []byte) (driver.Value, error) { - err := scanPlan.Scan(src, &d) - return d, err - } - } - } - } - - var more bool - if r.skipNext { - more = r.skipNextMore - r.skipNext = false - } else { - more = r.rows.Next() - } - - if !more { - if r.rows.Err() == nil { - return io.EOF - } else { - return r.rows.Err() - } - } - - for i, rv := range r.rows.RawValues() { - if rv != nil { - var err error - dest[i], err = r.valueFuncs[i](rv) - if err != nil { - return fmt.Errorf("convert field %d failed: %v", i, err) - } - } else { - dest[i] = nil - } - } - - return nil -} diff --git a/stdlib/sql.go b/stdlib/sql.go index 97ecc9b2..1bb19d35 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -74,6 +74,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgxpool" ) // Only intrinsic types should be binary format with database/sql. @@ -125,7 +126,7 @@ func contains(list []string, y string) bool { type OptionOpenDB func(*connector) // OptionBeforeConnect provides a callback for before connect. It is passed a shallow copy of the ConnConfig that will -// be used to connect, so only its immediate members should be modified. +// be used to connect, so only its immediate members should be modified. Used only if db is opened with *pgx.ConnConfig. func OptionBeforeConnect(bc func(context.Context, *pgx.ConnConfig) error) OptionOpenDB { return func(dc *connector) { dc.BeforeConnect = bc @@ -139,6 +140,20 @@ func OptionAfterConnect(ac func(context.Context, *pgx.Conn) error) OptionOpenDB } } +// OptionBeforeConnect provides a callback for before acquire. Used only if db is opened with *pgxpool.Pool. +func OptionBeforeAcquire(ba func(context.Context, *pgxpool.Pool) error) OptionOpenDB { + return func(c *connector) { + c.BeforeAcquire = ba + } +} + +// OptionAfterAcquire provides a callback for after acquire. Used only if db is opened with *pgxpool.Pool. +func OptionAfterAcquire(aa func(context.Context, *pgxpool.Conn) error) OptionOpenDB { + return func(c *connector) { + c.AfterAcquire = aa + } +} + // OptionResetSession provides a callback that can be used to add custom logic prior to executing a query on the // connection if the connection has been used before. // If ResetSessionFunc returns ErrBadConn error the connection will be discarded. @@ -191,15 +206,41 @@ func GetConnector(config pgx.ConnConfig, opts ...OptionOpenDB) driver.Connector return c } +func GetPoolConnector(pool *pgxpool.Pool, opts ...OptionOpenDB) driver.Connector { + c := connector{ + pool: pool, + BeforeAcquire: func(context.Context, *pgxpool.Pool) error { return nil }, // noop before acquire by default + AfterAcquire: func(context.Context, *pgxpool.Conn) error { return nil }, // noop after acquire by default + ResetSession: func(context.Context, *pgx.Conn) error { return nil }, // noop reset session by default + driver: pgxDriver, + } + + for _, opt := range opts { + opt(&c) + } + + return c +} + func OpenDB(config pgx.ConnConfig, opts ...OptionOpenDB) *sql.DB { c := GetConnector(config, opts...) return sql.OpenDB(c) } +func OpenDBFromPool(pool *pgxpool.Pool, opts ...OptionOpenDB) *sql.DB { + c := GetPoolConnector(pool, opts...) + db := sql.OpenDB(c) + db.SetMaxIdleConns(0) + return db +} + type connector struct { pgx.ConnConfig + pool *pgxpool.Pool BeforeConnect func(context.Context, *pgx.ConnConfig) error // function to call before creation of every new connection AfterConnect func(context.Context, *pgx.Conn) error // function to call after creation of every new connection + BeforeAcquire func(context.Context, *pgxpool.Pool) error // function to call before acquiring of every new connection + AfterAcquire func(context.Context, *pgxpool.Conn) error // function to call after acquiring of every new connection ResetSession func(context.Context, *pgx.Conn) error // function is called before a connection is reused driver *Driver } @@ -207,25 +248,60 @@ type connector struct { // Connect implement driver.Connector interface func (c connector) Connect(ctx context.Context) (driver.Conn, error) { var ( - err error - conn *pgx.Conn + connConfig pgx.ConnConfig + conn *pgx.Conn + close func(context.Context) error + err error ) - // Create a shallow copy of the config, so that BeforeConnect can safely modify it - connConfig := c.ConnConfig - if err = c.BeforeConnect(ctx, &connConfig); err != nil { - return nil, err + if c.pool == nil { + // Create a shallow copy of the config, so that BeforeConnect can safely modify it + connConfig = c.ConnConfig + + if err = c.BeforeConnect(ctx, &connConfig); err != nil { + return nil, err + } + + if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil { + return nil, err + } + + if err = c.AfterConnect(ctx, conn); err != nil { + return nil, err + } + + close = conn.Close + } else { + var pconn *pgxpool.Conn + + if err = c.BeforeAcquire(ctx, c.pool); err != nil { + return nil, err + } + + pconn, err = c.pool.Acquire(ctx) + if err != nil { + return nil, err + } + + if err = c.AfterAcquire(ctx, pconn); err != nil { + return nil, err + } + + conn = pconn.Conn() + + close = func(_ context.Context) error { + pconn.Release() + return nil + } } - if conn, err = pgx.ConnectConfig(ctx, &connConfig); err != nil { - return nil, err - } - - if err = c.AfterConnect(ctx, conn); err != nil { - return nil, err - } - - return &Conn{conn: conn, driver: c.driver, connConfig: connConfig, resetSessionFunc: c.ResetSession}, nil + return &Conn{ + conn: conn, + close: close, + driver: c.driver, + connConfig: connConfig, + resetSessionFunc: c.ResetSession, + }, nil } // Driver implement driver.Connector interface @@ -302,6 +378,7 @@ func (dc *driverConnector) Connect(ctx context.Context) (driver.Conn, error) { c := &Conn{ conn: conn, + close: conn.Close, driver: dc.driver, connConfig: *connConfig, resetSessionFunc: func(context.Context, *pgx.Conn) error { return nil }, @@ -326,6 +403,7 @@ func UnregisterConnConfig(connStr string) { type Conn struct { conn *pgx.Conn + close func(context.Context) error psCount int64 // Counter used for creating unique prepared statement names driver *Driver connConfig pgx.ConnConfig @@ -361,7 +439,7 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, e func (c *Conn) Close() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() - return c.conn.Close(ctx) + return c.close(ctx) } func (c *Conn) Begin() (driver.Tx, error) {