diff --git a/stdlib/sql.go b/stdlib/sql.go index 9e97af90..bc2849c2 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -208,6 +208,10 @@ type Conn struct { } func (c *Conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +func (c *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn } @@ -215,7 +219,7 @@ func (c *Conn) Prepare(query string) (driver.Stmt, error) { name := fmt.Sprintf("pgx_%d", c.psCount) c.psCount++ - ps, err := c.conn.Prepare(name, query) + ps, err := c.conn.PrepareExContext(ctx, name, query, nil) if err != nil { return nil, err } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index af2b9fe7..105de3d8 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -877,3 +877,56 @@ func TestConnPingContextCancel(t *testing.T) { t.Errorf("mock server err: %v", err) } } + +func TestConnPrepareContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + stmt, err := db.PrepareContext(context.Background(), "select now()") + if err != nil { + t.Fatalf("db.PrepareContext failed: %v", err) + } + stmt.Close() + + ensureConnValid(t, db) +} + +func TestConnPrepareContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Parse{Name: "pgx_0", Query: "select now()"}), + pgmock.ExpectMessage(&pgproto3.Describe{ObjectType: 'S', Name: "pgx_0"}), + pgmock.ExpectMessage(&pgproto3.Sync{}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = db.PrepareContext(ctx, "select now()") + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +}