From 984eace2b5943bbec1c549efe580dc11e02ac557 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Fri, 16 Oct 2015 14:17:07 -0500 Subject: [PATCH] Make *Conn.Prepare idempotent. fixes #94 --- conn.go | 13 +++++++++++-- conn_test.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 447e5171..4038ba6e 100644 --- a/conn.go +++ b/conn.go @@ -68,6 +68,7 @@ type Conn struct { type PreparedStatement struct { Name string + SQL string FieldDescriptions []FieldDescription ParameterOids []Oid } @@ -489,7 +490,15 @@ func configSSL(sslmode string, cc *ConnConfig) error { // Prepare creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. +// +// Prepare is idempotent; i.e. it is safe to call Prepare multiple times with the same +// name and sql arguments. This allows a code path to Prepare and Query/Exec without +// concern for if the statement has already been prepared. func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { + if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { + return ps, nil + } + if c.logLevel >= LogLevelError { defer func() { if err != nil { @@ -519,7 +528,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { return nil, err } - ps = &PreparedStatement{Name: name} + ps = &PreparedStatement{Name: name, SQL: sql} var softErr error @@ -549,7 +558,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { case readyForQuery: c.rxReadyForQuery(r) - if softErr == nil { + if softErr == nil && name != "" { c.preparedStatements[name] = ps } diff --git a/conn_test.go b/conn_test.go index 6da1ed7c..a8711eb4 100644 --- a/conn_test.go +++ b/conn_test.go @@ -848,6 +848,36 @@ func TestPrepareQueryManyParameters(t *testing.T) { ensureConnValid(t, conn) } +func TestPrepareIdempotency(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + for i := 0; i < 2; i++ { + _, err := conn.Prepare("test", "select 42::integer") + if err != nil { + t.Fatalf("%d. Unable to prepare statement: %v", i, err) + } + + var n int32 + err = conn.QueryRow("test").Scan(&n) + if err != nil { + t.Errorf("%d. Executing prepared statement failed: %v", i, err) + } + + if n != int32(42) { + t.Errorf("%d. Prepared statement did not return expected value: %v", i, n) + } + } + + _, err := conn.Prepare("test", "select 'fail'::varchar") + if err == nil { + t.Fatalf("Prepare statement with same name but different SQL should have failed but it didn't") + return + } +} + func TestListenNotify(t *testing.T) { t.Parallel()