From 936cb688667497292ae67dddf013d3a5e067ad9b Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 19 May 2017 08:54:08 -0500 Subject: [PATCH] Add driver.Pinger support to stdlib.Conn --- stdlib/sql.go | 8 ++++++++ stdlib/sql_test.go | 50 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/stdlib/sql.go b/stdlib/sql.go index e70780c1..9e97af90 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -336,6 +336,14 @@ func (c *Conn) queryPreparedContext(ctx context.Context, name string, argsV []dr return &Rows{rows: rows}, nil } +func (c *Conn) Ping(ctx context.Context) error { + if !c.conn.IsAlive() { + return driver.ErrBadConn + } + + return c.conn.Ping(ctx) +} + // Anything that isn't a database/sql compatible type needs to be forced to // text format so that pgx.Rows.Values doesn't decode it into a native type // (e.g. []int32) diff --git a/stdlib/sql_test.go b/stdlib/sql_test.go index 4f2484d8..af2b9fe7 100644 --- a/stdlib/sql_test.go +++ b/stdlib/sql_test.go @@ -6,6 +6,7 @@ import ( "database/sql" "fmt" "testing" + "time" "github.com/jackc/pgx" "github.com/jackc/pgx/pgmock" @@ -827,3 +828,52 @@ func TestAcquireConn(t *testing.T) { ensureConnValid(t, db) } + +func TestConnPingContextSuccess(t *testing.T) { + db := openDB(t) + defer closeDB(t, db) + + if err := db.PingContext(context.Background()); err != nil { + t.Fatalf("db.PingContext failed: %v", err) + } + + ensureConnValid(t, db) +} + +func TestConnPingContextCancel(t *testing.T) { + script := &pgmock.Script{ + Steps: pgmock.AcceptUnauthenticatedConnRequestSteps(), + } + script.Steps = append(script.Steps, pgmock.PgxInitSteps()...) + script.Steps = append(script.Steps, + pgmock.ExpectMessage(&pgproto3.Query{String: ";"}), + ) + + server, err := pgmock.NewServer(script) + if err != nil { + t.Fatal(err) + } + defer server.Close() + + errChan := make(chan error) + go func() { + errChan <- server.ServeOne() + }() + + db, err := sql.Open("pgx", fmt.Sprintf("postgres://pgx_md5:secret@%s/pgx_test?sslmode=disable", server.Addr())) + if err != nil { + t.Fatalf("sql.Open failed: %v", err) + } + // defer closeDB(t, db) // mock DB doesn't close correctly yet + + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + + err = db.PingContext(ctx) + if err != context.DeadlineExceeded { + t.Errorf("err => %v, want %v", err, context.DeadlineExceeded) + } + + if err := <-errChan; err != nil { + t.Errorf("mock server err: %v", err) + } +}