diff --git a/stdlib/sql.go b/stdlib/sql.go index dfb168e9..c43450f6 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -48,7 +48,7 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { return nil, err } - return &Stmt{ps: ps, conn: c.conn}, nil + return &Stmt{ps: ps, conn: c}, nil } func (c *Conn) Close() error { @@ -68,31 +68,18 @@ func (c *Conn) Begin() (driver.Tx, error) { return &Tx{conn: c.conn}, nil } -type Stmt struct { - ps *pgx.PreparedStatement - conn *pgx.Conn -} - -func (s *Stmt) Close() error { - return s.conn.Deallocate(s.ps.Name) -} - -func (s *Stmt) NumInput() int { - return len(s.ps.ParameterOids) -} - -func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { - if !s.conn.IsAlive() { +func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { + if !c.conn.IsAlive() { return nil, driver.ErrBadConn } args := valueToInterface(argsV) - commandTag, err := s.conn.Execute(s.ps.Name, args...) + commandTag, err := c.conn.Execute(query, args...) return driver.RowsAffected(commandTag.RowsAffected()), err } -func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { - if !s.conn.IsAlive() { +func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { + if !c.conn.IsAlive() { return nil, driver.ErrBadConn } @@ -104,7 +91,7 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { rowChan := make(chan []driver.Value) go func() { - err := s.conn.SelectFunc(s.ps.Name, func(r *pgx.DataRowReader) error { + err := c.conn.SelectFunc(query, func(r *pgx.DataRowReader) error { if rowCount == 0 { fieldNames := make([]string, len(r.FieldDescriptions)) for i, fd := range r.FieldDescriptions { @@ -138,6 +125,27 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { } } +type Stmt struct { + ps *pgx.PreparedStatement + conn *Conn +} + +func (s *Stmt) Close() error { + return s.conn.conn.Deallocate(s.ps.Name) +} + +func (s *Stmt) NumInput() int { + return len(s.ps.ParameterOids) +} + +func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { + return s.conn.Exec(s.ps.Name, argsV) +} + +func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { + return s.conn.Query(s.ps.Name, argsV) +} + type Rows struct { columnNames []string rowChan chan []driver.Value diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 76848df9..9ae4b808 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -22,20 +22,32 @@ func closeDB(t *testing.T, db *sql.DB) { } } +type preparer interface { + Prepare(query string) (*sql.Stmt, error) +} + +func prepareStmt(t *testing.T, p preparer, sql string) *sql.Stmt { + stmt, err := p.Prepare(sql) + if err != nil { + t.Fatalf("%v Prepare unexpectedly failed: %v", p, err) + } + + return stmt +} + +func closeStmt(t *testing.T, stmt *sql.Stmt) { + err := stmt.Close() + if err != nil { + t.Fatalf("stmt.Close unexpectedly failed: %v", err) + } +} + func TestNormalLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db) - stmt, err := db.Prepare("select 'foo', n from generate_series($1::int, $2::int) n") - if err != nil { - t.Fatalf("db.Prepare unexpectedly failed: %v", err) - } - defer func() { - err = stmt.Close() - if err != nil { - t.Fatalf("stmt.Close unexpectedly failed: %v", err) - } - }() + stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") + defer closeStmt(t, stmt) rows, err := stmt.Query(int32(1), int32(10)) if err != nil { @@ -73,20 +85,48 @@ func TestNormalLifeCycle(t *testing.T) { } } +func TestStmtExec(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + tx, err := db.Begin() + if err != nil { + t.Fatalf("db.Begin unexpectedly failed: %v", err) + } + + createStmt := prepareStmt(t, tx, "create temporary table t(a varchar not null)") + _, err = createStmt.Exec() + if err != nil { + t.Fatalf("stmt.Exec unexpectedly failed: %v", err) + } + closeStmt(t, createStmt) + + insertStmt := prepareStmt(t, tx, "insert into t values($1::text)") + result, err := insertStmt.Exec("foo") + if err != nil { + t.Fatalf("stmt.Exec unexpectedly failed: %v", err) + } + + n, err := result.RowsAffected() + if err != nil { + t.Fatalf("result.RowsAffected unexpectedly failed: %v", err) + } + if n != 1 { + t.Fatalf("Expected 1, received %d", n) + } + closeStmt(t, insertStmt) + + if err != nil { + t.Fatalf("tx.Commit unexpectedly failed: %v", err) + } +} + func TestQueryCloseRowsEarly(t *testing.T) { db := openDB(t) defer closeDB(t, db) - stmt, err := db.Prepare("select 'foo', n from generate_series($1::int, $2::int) n") - if err != nil { - t.Fatalf("db.Prepare unexpectedly failed: %v", err) - } - defer func() { - err = stmt.Close() - if err != nil { - t.Fatalf("stmt.Close unexpectedly failed: %v", err) - } - }() + stmt := prepareStmt(t, db, "select 'foo', n from generate_series($1::int, $2::int) n") + defer closeStmt(t, stmt) rows, err := stmt.Query(int32(1), int32(10)) if err != nil { @@ -136,7 +176,7 @@ func TestQueryCloseRowsEarly(t *testing.T) { } } -func TestExec(t *testing.T) { +func TestConnExec(t *testing.T) { db := openDB(t) defer closeDB(t, db) @@ -159,6 +199,46 @@ func TestExec(t *testing.T) { } } +func TestConnQuery(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.Query("select 'foo', n from generate_series($1::int, $2::int) n", int32(1), int32(10)) + if err != nil { + t.Fatalf("db.Query unexpectedly failed: %v", err) + } + + rowCount := int64(0) + + for rows.Next() { + rowCount++ + + var s string + var n int64 + if err := rows.Scan(&s, &n); err != nil { + t.Fatalf("rows.Scan unexpectedly failed: %v", err) + } + if s != "foo" { + t.Errorf(`Expected "foo", received "%v"`, s) + } + if n != rowCount { + t.Errorf("Expected %d, received %d", rowCount, n) + } + } + err = rows.Err() + if err != nil { + t.Fatalf("rows.Err unexpectedly is: %v", err) + } + if rowCount != 10 { + t.Fatalf("Expected to receive 10 rows, instead received %d", rowCount) + } + + err = rows.Close() + if err != nil { + t.Fatalf("rows.Close unexpectedly failed: %v", err) + } +} + func TestTransactionLifeCycle(t *testing.T) { db := openDB(t) defer closeDB(t, db)