mirror of
https://github.com/jackc/pgx.git
synced 2025-05-31 11:42:24 +00:00
Conn implements driver.Execer and driver.Queryer
This commit is contained in:
parent
b2c1a14fcc
commit
eb85aad21f
@ -48,7 +48,7 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Stmt{ps: ps, conn: c.conn}, nil
|
return &Stmt{ps: ps, conn: c}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
@ -68,31 +68,18 @@ func (c *Conn) Begin() (driver.Tx, error) {
|
|||||||
return &Tx{conn: c.conn}, nil
|
return &Tx{conn: c.conn}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Stmt struct {
|
func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) {
|
||||||
ps *pgx.PreparedStatement
|
if !c.conn.IsAlive() {
|
||||||
conn *pgx.Conn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stmt) Close() error {
|
|
||||||
return s.conn.Deallocate(s.ps.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stmt) NumInput() int {
|
|
||||||
return len(s.ps.ParameterOids)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
|
||||||
if !s.conn.IsAlive() {
|
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
args := valueToInterface(argsV)
|
args := valueToInterface(argsV)
|
||||||
commandTag, err := s.conn.Execute(s.ps.Name, args...)
|
commandTag, err := c.conn.Execute(query, args...)
|
||||||
return driver.RowsAffected(commandTag.RowsAffected()), err
|
return driver.RowsAffected(commandTag.RowsAffected()), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) {
|
||||||
if !s.conn.IsAlive() {
|
if !c.conn.IsAlive() {
|
||||||
return nil, driver.ErrBadConn
|
return nil, driver.ErrBadConn
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,7 +91,7 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
|||||||
rowChan := make(chan []driver.Value)
|
rowChan := make(chan []driver.Value)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err := s.conn.SelectFunc(s.ps.Name, func(r *pgx.DataRowReader) error {
|
err := c.conn.SelectFunc(query, func(r *pgx.DataRowReader) error {
|
||||||
if rowCount == 0 {
|
if rowCount == 0 {
|
||||||
fieldNames := make([]string, len(r.FieldDescriptions))
|
fieldNames := make([]string, len(r.FieldDescriptions))
|
||||||
for i, fd := range r.FieldDescriptions {
|
for i, fd := range r.FieldDescriptions {
|
||||||
@ -138,6 +125,27 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Stmt struct {
|
||||||
|
ps *pgx.PreparedStatement
|
||||||
|
conn *Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) Close() error {
|
||||||
|
return s.conn.conn.Deallocate(s.ps.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) NumInput() int {
|
||||||
|
return len(s.ps.ParameterOids)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) {
|
||||||
|
return s.conn.Exec(s.ps.Name, argsV)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
|
||||||
|
return s.conn.Query(s.ps.Name, argsV)
|
||||||
|
}
|
||||||
|
|
||||||
type Rows struct {
|
type Rows struct {
|
||||||
columnNames []string
|
columnNames []string
|
||||||
rowChan chan []driver.Value
|
rowChan chan []driver.Value
|
||||||
|
@ -22,20 +22,32 @@ func closeDB(t *testing.T, db *sql.DB) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type preparer interface {
|
||||||
|
Prepare(query string) (*sql.Stmt, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt {
|
||||||
|
stmt, err := p.Prepare(sql)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%v Prepare unexpectedly failed: %v", p, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeStmt(t *testing.T, stmt *sql.Stmt) {
|
||||||
|
err := stmt.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestNormalLifeCycle(t *testing.T) {
|
func TestNormalLifeCycle(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|
||||||
stmt, err := db.Prepare("select 'foo', n from generate_series($1::int, $2::int) n")
|
stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
|
||||||
if err != nil {
|
defer closeStmt(t, stmt)
|
||||||
t.Fatalf("db.Prepare unexpectedly failed: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = stmt.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("stmt.Close unexpectedly failed: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
rows, err := stmt.Query(int32(1), int32(10))
|
rows, err := stmt.Query(int32(1), int32(10))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -73,20 +85,48 @@ func TestNormalLifeCycle(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStmtExec(t *testing.T) {
|
||||||
|
db := openDB(t)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Begin unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)")
|
||||||
|
_, err = createStmt.Exec()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
closeStmt(t, createStmt)
|
||||||
|
|
||||||
|
insertStmt := prepareStmt(t, tx, "insert into t values($1::text)")
|
||||||
|
result, err := insertStmt.Exec("foo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("stmt.Exec unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("result.RowsAffected unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if n != 1 {
|
||||||
|
t.Fatalf("Expected 1, received %d", n)
|
||||||
|
}
|
||||||
|
closeStmt(t, insertStmt)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("tx.Commit unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestQueryCloseRowsEarly(t *testing.T) {
|
func TestQueryCloseRowsEarly(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|
||||||
stmt, err := db.Prepare("select 'foo', n from generate_series($1::int, $2::int) n")
|
stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n")
|
||||||
if err != nil {
|
defer closeStmt(t, stmt)
|
||||||
t.Fatalf("db.Prepare unexpectedly failed: %v", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err = stmt.Close()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("stmt.Close unexpectedly failed: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
rows, err := stmt.Query(int32(1), int32(10))
|
rows, err := stmt.Query(int32(1), int32(10))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -136,7 +176,7 @@ func TestQueryCloseRowsEarly(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExec(t *testing.T) {
|
func TestConnExec(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
|
||||||
@ -159,6 +199,46 @@ func TestExec(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestConnQuery(t *testing.T) {
|
||||||
|
db := openDB(t)
|
||||||
|
defer closeDB(t, db)
|
||||||
|
|
||||||
|
rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("db.Query unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rowCount := int64(0)
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
rowCount++
|
||||||
|
|
||||||
|
var s string
|
||||||
|
var n int64
|
||||||
|
if err := rows.Scan(&s, &n); err != nil {
|
||||||
|
t.Fatalf("rows.Scan unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
if s != "foo" {
|
||||||
|
t.Errorf(`Expected "foo", received "%v"`, s)
|
||||||
|
}
|
||||||
|
if n != rowCount {
|
||||||
|
t.Errorf("Expected %d, received %d", rowCount, n)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rows.Err unexpectedly is: %v", err)
|
||||||
|
}
|
||||||
|
if rowCount != 10 {
|
||||||
|
t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rows.Close unexpectedly failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestTransactionLifeCycle(t *testing.T) {
|
func TestTransactionLifeCycle(t *testing.T) {
|
||||||
db := openDB(t)
|
db := openDB(t)
|
||||||
defer closeDB(t, db)
|
defer closeDB(t, db)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user