From dbcfa46d8e485185a610cdda1b9c2d7b03250fa0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 14:57:49 -0500 Subject: [PATCH] Add driver.ExecerContext support to stdlib.Conn --- stdlib/sql.go | 11 ++++++++++ stdlib/sql_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index bc2849c2..400f8311 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -279,6 +279,17 @@ func (c *Conn) Exec(query string, argsV []driver.Value) (driver.Result, error) { return driver.RowsAffected(commandTag.RowsAffected()), err } +func (c *Conn) ExecContext(ctx context.Context, query string, argsV []driver.NamedValue) (driver.Result, error) { + if !c.conn.IsAlive() { + return nil, driver.ErrBadConn + } + + args := namedValueToInterface(argsV) + + commandTag, err := c.conn.ExecEx(ctx, query, nil, args...) + return driver.RowsAffected(commandTag.RowsAffected()), err +} + func (c *Conn) Query(query string, argsV []driver.Value) (driver.Rows, error) { if !c.conn.IsAlive() { return nil, driver.ErrBadConn diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 105de3d8..f12c43d5 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -930,3 +930,53 @@ func TestConnPrepareContextCancel(t *testing.T) { t.Errorf("mock server err: %v", err) } } + +func TestConnExecContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + _, err := db.ExecContext(context.Background(), "create temporary table exec_context_test(id serial primary key)") + if err != nil { + t.Fatalf("db.ExecContext failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnExecContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: "create temporary table exec_context_test(id serial primary key)"}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + _, err = db.ExecContext(ctx, "create temporary table exec_context_test(id serial primary key)") + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +}