From e5820baebe59c24cd088f6bb27d7398ed175963b Mon Sep 17 00:00:00 2001
From: Jack Christensen <jack@jackchristensen.com>
Date: Fri, 19 May 2017 17:31:56 -0500
Subject: [PATCH] Add driver.StmtQueryContext support to stdlib.Stmt

---
 stdlib/sql.go      |   4 ++
 stdlib/sql_test.go | 110 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 114 insertions(+)

diff --git a/stdlib/sql.go b/stdlib/sql.go
index 408bc62a..088095ab 100644
--- a/stdlib/sql.go
+++ b/stdlib/sql.go
@@ -396,6 +396,10 @@ func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) {
 	return s.conn.queryPrepared(s.ps.Name, argsV)
 }
 
+func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (driver.Rows, error) {
+	return s.conn.queryPreparedContext(ctx, s.ps.Name, argsV)
+}
+
 type Rows struct {
 	rows   *pgx.Rows
 	values []interface{}
diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go
index 447aa8b6..b26c815d 100644
--- a/stdlib/sql_test.go
+++ b/stdlib/sql_test.go
@@ -1148,3 +1148,113 @@ func TestStmtExecContextCancel(t *testing.T) {
 
 	ensureConnValid(t, db)
 }
+
+func TestStmtQueryContextSuccess(t *testing.T) {
+	// db := openDB(t)
+	// defer closeDB(t, db)
+
+	db, err := sql.Open("pgx", "postgres://pgx_md5:secret@127.0.0.1:15432/pgx_test?sslmode=disable")
+	if err != nil {
+		t.Fatalf("sql.Open failed: %v", err)
+	}
+
+	stmt, err := db.Prepare("select * from generate_series(1,$1::int4) n")
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer stmt.Close()
+
+	rows, err := stmt.QueryContext(context.Background(), 5)
+	if err != nil {
+		t.Fatalf("stmt.QueryContext failed: %v", err)
+	}
+
+	for rows.Next() {
+		var n int64
+		if err := rows.Scan(&n); err != nil {
+			t.Error(err)
+		}
+	}
+
+	if rows.Err() != nil {
+		t.Error(rows.Err())
+	}
+
+	ensureConnValid(t, db)
+}
+
+func TestStmtQueryContextCancel(t *testing.T) {
+	script := &pgmock.Script{
+		Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(),
+	}
+	script.Steps = append(script.Steps, pgmock.PgxInitSteps()...)
+	script.Steps = append(script.Steps,
+		pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select * from generate_series(1, $1::int4) n"}),
+		pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}),
+		pgmock.ExpectMessage(&pgproto3.Sync{}),
+
+		pgmock.SendMessage(&pgproto3.ParseComplete{}),
+		pgmock.SendMessage(&pgproto3.ParameterDescription{ParameterOIDs: []uint32{23}}),
+		pgmock.SendMessage(&pgproto3.RowDescription{
+			Fields: []pgproto3.FieldDescription{
+				{
+					Name:         "n",
+					DataTypeOID:  23,
+					DataTypeSize: 4,
+					TypeModifier: 4294967295,
+				},
+			},
+		}),
+		pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}),
+
+		pgmock.ExpectMessage(&pgproto3.Bind{PreparedStatement: "pgx_0", ParameterFormatCodes: []int16{1}, Parameters: [][]uint8{[]uint8{0x0, 0x0, 0x0, 0x2a}}, ResultFormatCodes: []int16{1}}),
+		pgmock.ExpectMessage(&pgproto3.Execute{}),
+		pgmock.ExpectMessage(&pgproto3.Sync{}),
+
+		pgmock.SendMessage(&pgproto3.BindComplete{}),
+	)
+
+	server, err := pgmock.NewServer(script)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer server.Close()
+
+	errChan := make(chan error)
+	go func() {
+		errChan <- server.ServeOne()
+	}()
+
+	db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr()))
+	if err != nil {
+		t.Fatalf("sql.Open failed: %v", err)
+	}
+	// defer closeDB(t, db) // mock DB doesn't close correctly yet
+
+	stmt, err := db.Prepare("select * from generate_series(1, $1::int4) n")
+	if err != nil {
+		t.Fatal(err)
+	}
+	// defer stmt.Close()
+
+	ctx, cancelFn := context.WithCancel(context.Background())
+
+	rows, err := stmt.QueryContext(ctx, 42)
+	if err != nil {
+		t.Fatalf("stmt.QueryContext failed: %v", err)
+	}
+
+	cancelFn()
+
+	for rows.Next() {
+		t.Fatalf("no rows should ever be received")
+	}
+
+	if rows.Err() != context.Canceled {
+		t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled)
+	}
+
+	if err := <-errChan; err != nil {
+		t.Errorf("mock server err: %v", err)
+	}
+}