// Package stdlib adapts pgx to the database/sql driver interfaces. Importing
// the package registers a driver named "pgx" whose data source name is a
// connection URI understood by pgx.ParseURI.
package stdlib

import (
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"

	"github.com/jackc/pgx"
)

// openFromConnPoolCount is the numeric suffix of the next driver name that
// OpenFromConnPool registers. It is a plain int with no synchronization, so
// OpenFromConnPool must not be called from multiple goroutines at once.
var openFromConnPoolCount int

func init() {
	d := &Driver{}
	sql.Register("pgx", d)
}

// Driver implements driver.Driver. When Pool is non-nil every connection is
// acquired from it; otherwise each Open call establishes a new connection.
type Driver struct {
	Pool *pgx.ConnPool
}

// Open returns a connection for database/sql. If the driver is backed by a
// pool, name is ignored and a connection is acquired from the pool. Otherwise
// name is parsed as a connection URI and a fresh connection is made.
func (d *Driver) Open(name string) (driver.Conn, error) {
	if d.Pool != nil {
		conn, err := d.Pool.Acquire()
		if err != nil {
			return nil, err
		}

		return &Conn{conn: conn, pool: d.Pool}, nil
	}

	connConfig, err := pgx.ParseURI(name)
	if err != nil {
		return nil, err
	}

	conn, err := pgx.Connect(connConfig)
	if err != nil {
		return nil, err
	}

	return &Conn{conn: conn}, nil
}

// OpenFromConnPool takes the existing *pgx.ConnPool pool and returns a *sql.DB
// with pool as the backend. This enables full control over the connection
// process and configuration while maintaining compatibility with the
// database/sql interface. In addition, by calling Driver() on the returned
// *sql.DB and typecasting to *stdlib.Driver a reference to the pgx.ConnPool can
// be reacquired later. This allows fast paths targeting pgx to be used while
// still maintaining compatibility with other databases and drivers.
//
// Each call registers a new driver under a unique "pgx-<n>" name.
func OpenFromConnPool(pool *pgx.ConnPool) (*sql.DB, error) {
	d := &Driver{Pool: pool}
	name := fmt.Sprintf("pgx-%d", openFromConnPoolCount)
	openFromConnPoolCount++
	sql.Register(name, d)

	db, err := sql.Open(name, "")
	if err != nil {
		return nil, err
	}

	// Presumably OpenFromConnPool is being used because the user wants to use
	// database/sql most of the time, but fast path with pgx some of the time.
	// Allow database/sql to use all the connections, but release 2 idle ones.
	// Don't have database/sql immediately release all idle connections because
	// that would mean that prepared statements would be lost (which kills
	// performance if the prepared statements constantly have to be reprepared)
	db.SetMaxIdleConns(pool.MaxConnectionCount() - 2)
	db.SetMaxOpenConns(pool.MaxConnectionCount())

	return db, nil
}

// Conn wraps a *pgx.Conn so it can be used as a database/sql connection.
type Conn struct {
	conn    *pgx.Conn
	pool    *pgx.ConnPool // non-nil when conn was acquired from a pool
	psCount int64         // Counter used for creating unique prepared statement names
}

// Prepare creates a prepared statement on the underlying connection under a
// generated, per-connection unique name. It returns driver.ErrBadConn if the
// connection is no longer alive.
func (c *Conn) Prepare(query string) (driver.Stmt, error) {
	if !c.conn.IsAlive() {
		return nil, driver.ErrBadConn
	}

	name := fmt.Sprintf("pgx_%d", c.psCount)
	c.psCount++

	ps, err := c.conn.Prepare(name, query)
	if err != nil {
		return nil, err
	}

	return &Stmt{ps: ps, conn: c}, nil
}

// Close releases the connection back to its pool when it came from one, and
// closes it otherwise.
func (c *Conn) Close() error {
	if c.pool != nil {
		c.pool.Release(c.conn)
		return nil
	}

	return c.conn.Close()
}

// Begin starts a transaction by executing "begin". It returns
// driver.ErrBadConn if the connection is no longer alive.
func (c *Conn) Begin() (driver.Tx, error) {
	if !c.conn.IsAlive() {
		return nil, driver.ErrBadConn
	}

	if _, err := c.conn.Execute("begin"); err != nil {
		return nil, err
	}

	return &Tx{conn: c.conn}, nil
}

// Exec runs query with the given arguments and reports the number of rows
// affected. It returns driver.ErrBadConn if the connection is no longer alive.
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
	if !c.conn.IsAlive() {
		return nil, driver.ErrBadConn
	}

	args := valueToInterface(argsV)
	commandTag, err := c.conn.Execute(query, args...)
	if err != nil {
		// Do not hand back a result built from the command tag of a failed
		// statement.
		return nil, err
	}

	return driver.RowsAffected(commandTag.RowsAffected()), nil
}

// Query runs query with the given arguments and streams the resulting rows.
// It returns driver.ErrBadConn if the connection is no longer alive.
//
// The query executes in a separate goroutine that feeds rows into a buffered
// channel; the returned Rows must be read to the end or closed so that the
// goroutine can finish. Column names are taken from the first row received,
// so a query that yields no rows produces a Rows with no column names.
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
	if !c.conn.IsAlive() {
		return nil, driver.ErrBadConn
	}

	args := valueToInterface(argsV)

	// columnsChan and errChan each carry at most one value and are buffered
	// so the producer never blocks on them, whatever the consumer does.
	columnsChan := make(chan []string, 1)
	errChan := make(chan error, 1)
	rowChan := make(chan []driver.Value, 8)

	go func() {
		sentColumns := false

		err := c.conn.SelectFunc(query, func(r *pgx.DataRowReader) error {
			if !sentColumns {
				fieldNames := make([]string, len(r.FieldDescriptions))
				for i, fd := range r.FieldDescriptions {
					fieldNames[i] = fd.Name
				}
				columnsChan <- fieldNames
				sentColumns = true
			}

			values := make([]driver.Value, len(r.FieldDescriptions))
			for i := range r.FieldDescriptions {
				values[i] = r.ReadValue()
			}
			rowChan <- values

			return nil
		}, args...)

		if err != nil {
			errChan <- err
		}

		// errChan is closed first: whoever observes columnsChan or rowChan
		// closed can then receive from errChan without blocking.
		close(errChan)
		close(columnsChan)
		close(rowChan)
	}()

	rows := &Rows{rowChan: rowChan, errChan: errChan}

	// Wait for either the first row (which supplies the column names) or the
	// end of the query.
	if names, ok := <-columnsChan; ok {
		rows.columnNames = names
		return rows, nil
	}

	// The query finished without producing a row: it either failed or had an
	// empty result set.
	if err := <-errChan; err != nil {
		return nil, err
	}

	return rows, nil
}

// Stmt is a prepared statement bound to the Conn that created it.
type Stmt struct {
	ps   *pgx.PreparedStatement
	conn *Conn
}

// Close deallocates the prepared statement on its connection.
func (s *Stmt) Close() error {
	return s.conn.conn.Deallocate(s.ps.Name)
}

// NumInput returns the number of parameters the prepared statement takes.
func (s *Stmt) NumInput() int {
	return len(s.ps.ParameterOids)
}

// Exec executes the prepared statement by passing its name to Conn.Exec.
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
	return s.conn.Exec(s.ps.Name, argsV)
}

// Query executes the prepared statement by passing its name to Conn.Query.
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
	return s.conn.Query(s.ps.Name, argsV)
}

// Rows streams the result of a query started by Conn.Query.
type Rows struct {
	columnNames []string
	rowChan     chan []driver.Value
	errChan     chan error // yields the query's error, if any, once rowChan is closed
}

// Columns returns the column names captured from the first row, or nil when
// the query produced no rows.
func (r *Rows) Columns() []string {
	return r.columnNames
}

// Close discards any unread rows, which lets the goroutine producing them
// run to completion. It always returns nil.
func (r *Rows) Close() error {
	for range r.rowChan {
		// Ensure all rows are read
	}

	return nil
}

// Next copies the next row into dest. Once the rows are exhausted it returns
// the error that ended the query, if there was one, and io.EOF otherwise.
func (r *Rows) Next(dest []driver.Value) error {
	row, ok := <-r.rowChan
	if !ok {
		// A nil errChan would block forever; it is only nil for a Rows that
		// was not created by Conn.Query.
		if r.errChan != nil {
			if err := <-r.errChan; err != nil {
				return err
			}
		}

		return io.EOF
	}

	copy(dest, row)

	return nil
}

// valueToInterface converts driver values into the []interface{} form taken
// by pgx's variadic query methods. Nil values (SQL NULL) are passed through
// unchanged.
func valueToInterface(argsV []driver.Value) []interface{} {
	args := make([]interface{}, 0, len(argsV))
	for _, v := range argsV {
		// driver.Value is already an empty interface, so no type assertion is
		// needed -- and asserting on a nil value would panic.
		args = append(args, v)
	}

	return args
}

// Tx is a transaction on a pgx connection, controlled with plain "commit"
// and "rollback" statements.
type Tx struct {
	conn *pgx.Conn
}

// Commit executes "commit" on the underlying connection.
func (t *Tx) Commit() error {
	_, err := t.conn.Execute("commit")
	return err
}

// Rollback executes "rollback" on the underlying connection.
func (t *Tx) Rollback() error {
	_, err := t.conn.Execute("rollback")
	return err
}