diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index f12c43d5..83f32ea8 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -980,3 +980,96 @@ func TestConnExecContextCancel(t *testing.T) { t.Errorf("mock server err: %v", err) } } + +func TestConnQueryContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + rows, err := db.QueryContext(context.Background(), "select * from generate_series(1,10) n") + if err != nil { + t.Fatalf("db.QueryContext failed: %v", err) + } + + for rows.Next() { + var n int64 + if err := rows.Scan(&n); err != nil { + t.Error(err) + } + } + + if rows.Err() != nil { + t.Error(rows.Err()) + } + + ensureConnValid(t, db) +} + +func TestConnQueryContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Query: "select * from generate_series(1,10) n"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S'}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.ParseComplete{}), + pgmock.SendMessage(&pgproto3.ParameterDescription{}), + pgmock.SendMessage(&pgproto3.RowDescription{ + Fields: []pgproto3.FieldDescription{ + { + Name: "n", + DataTypeOID: 23, + DataTypeSize: 4, + TypeModifier: 4294967295, + }, + }, + }), + pgmock.SendMessage(&pgproto3.ReadyForQuery{TxStatus: 'I'}), + + pgmock.ExpectMessage(&pgproto3.Bind{ResultFormatCodes: []int16{1}}), + pgmock.ExpectMessage(&pgproto3.Execute{}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + + pgmock.SendMessage(&pgproto3.BindComplete{}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + ctx, cancelFn := context.WithCancel(context.Background()) + + rows, err := db.QueryContext(ctx, "select * from generate_series(1,10) n") + if err != nil { + t.Fatalf("db.QueryContext failed: %v", err) + } + + cancelFn() + + for rows.Next() { + t.Fatalf("no rows should ever be received") + } + + if rows.Err() != context.Canceled { + t.Errorf("rows.Err() => %v, want %v", rows.Err(), context.Canceled) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +}