From e5820baebe59c24cd088f6bb27d7398ed175963b Mon Sep 17 00:00:00 2001 From: Jack Christensen <jack@jackchristensen.com> Date: Fri, 19 May 2017 17:31:56 -0500 Subject: [PATCH] Add driver.StmtQueryContext support to stdlib.Stmt --- stdlib/sql.go | 4 ++ stdlib/sql_test.go | 110 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index 408bc62a..088095ab 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -396,6 +396,10 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return s.conn.queryPrepared(s.ps.Name, argsV) } +func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) { + return s.conn.queryPreparedContext(ctx, s.ps.Name, argsV) +} + type Rows struct { rows *pgx.Rows values []interface{} diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 447aa8b6..b26c815d 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1148,3 +1148,113 @@ func TestStmtExecContextCancel(t *testing.T) { ensureConnValid(t, db) } + +func TestStmtQueryContextSuccess(t *testing.T) { + // db := openDB(t) + // defer closeDB(t, db) + + db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:15432/pgx_test?sslmode=disable") + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + + stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + rows, err := stmt.QueryContext(context.Background(), 5) + if err != nil { + t.Fatalf("stmt.QueryContext failed: %v", err) + } + + for rows.Next() { + var n int64 + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + ensureConnValid(t, db) +} + +func TestStmtQueryContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select * from generate_series(1, $1::int4) n"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.ParseComplete{}), + pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: []uint32{23}}), + pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + { + Name: "n", + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: 4294967295, + }, + }, + }), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + + pgmock.ExpectMessage(&pgproto3.Bind{PreparedStatement: "pgx_0", ParameterFormatCodes: []int16{1}, Parameters: [][]uint8{[]uint8{0x0, 0x0, 0x0, 0x2a}}, ResultFormatCodes: []int16{1}}), + pgmock.ExpectMessage(&pgproto3.Execute{}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.BindComplete{}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + stmt, err := db.Prepare("select * from generate_series(1, $1::int4) n") + if err != nil { + t.Fatal(err) + } + // defer stmt.Close() + + ctx, cancelFn := context.WithCancel(context.Background()) + + rows, err := stmt.QueryContext(ctx, 42) + if err != nil { + t.Fatalf("stmt.QueryContext failed: %v", err) + } + + cancelFn() + + for rows.Next() { + t.Fatalf("no rows should ever be received") + } + + if rows.Err() != context.Canceled { + t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +}