diff --git a/stdlib/sql.go b/stdlib/sql.go index ce79edb6..408bc62a 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -388,6 +388,10 @@ func (s *Stmt) Exec(argsV []driver.Value) (driver.Result, error) { return s.conn.Exec(s.ps.Name, argsV) } +func (s *Stmt) ExecContext(ctx context.Context, argsV []driver.NamedValue) (driver.Result, error) { + return s.conn.ExecContext(ctx, s.ps.Name, argsV) +} + func (s *Stmt) Query(argsV []driver.Value) (driver.Rows, error) { return s.conn.queryPrepared(s.ps.Name, argsV) } diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index e7db03c9..447aa8b6 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -1100,3 +1100,51 @@ func TestRowsColumnTypeDatabaseTypeName(t *testing.T) { ensureConnValid(t, db) } + +func TestStmtExecContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("create temporary table t(id int primary key)") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + stmt, err := db.Prepare("insert into t(id) values ($1::int4)") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + _, err = stmt.ExecContext(context.Background(), 42) + if err != nil { + t.Fatal(err) + } + + ensureConnValid(t, db) +} + +func TestStmtExecContextCancel(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.Exec("create temporary table t(id int primary key)") + if err != nil { + t.Fatalf("db.Exec failed: %v", err) + } + + stmt, err := db.Prepare("insert into t(id) select $1::int4 from pg_sleep(5)") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = stmt.ExecContext(ctx, 42) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + ensureConnValid(t, db) +}