diff --git a/conn.go b/conn.go index 63c2f7f7..0b262a69 100644 --- a/conn.go +++ b/conn.go @@ -59,14 +59,14 @@ type Conn struct { RuntimeParams map[string]string // parameters that have been reported by the server config ConnConfig // config used when establishing this connection TxStatus byte - preparedStatements map[string]*preparedStatement + preparedStatements map[string]*PreparedStatement notifications []*Notification alive bool causeOfDeath error logger log.Logger } -type preparedStatement struct { +type PreparedStatement struct { Name string FieldDescriptions []FieldDescription ParameterOids []Oid @@ -185,7 +185,7 @@ func Connect(config ConnConfig) (c *Conn, err error) { c.bufSize = c.config.MsgBufSize c.buf = bytes.NewBuffer(make([]byte, 0, c.bufSize)) c.RuntimeParams = make(map[string]string) - c.preparedStatements = make(map[string]*preparedStatement) + c.preparedStatements = make(map[string]*PreparedStatement) c.alive = true if config.TLSConfig != nil { @@ -579,7 +579,7 @@ func (c *Conn) SelectValues(sql string, arguments ...interface{}) ([]interface{} // Prepare creates a prepared statement with name and sql. sql can contain placeholders // for bound parameters. These placeholders are referenced positional as $1, $2, etc. -func (c *Conn) Prepare(name, sql string) (err error) { +func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { defer func() { if err != nil { c.logger.Error(fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) @@ -595,7 +595,7 @@ func (c *Conn) Prepare(name, sql string) (err error) { binary.Write(buf, binary.BigEndian, int16(0)) err = c.txMsg('P', buf, false) if err != nil { - return err + return nil, err } // describe @@ -605,16 +605,16 @@ func (c *Conn) Prepare(name, sql string) (err error) { buf.WriteByte(0) err = c.txMsg('D', buf, false) if err != nil { - return + return nil, err } // sync err = c.txMsg('S', c.getBuf(), true) if err != nil { - return err + return nil, err } - ps := preparedStatement{Name: name} + ps = &PreparedStatement{Name: name} var softErr error @@ -623,7 +623,7 @@ func (c *Conn) Prepare(name, sql string) (err error) { var r *MessageReader t, r, err := c.rxMsg() if err != nil { - return err + return nil, err } switch t { @@ -641,8 +641,8 @@ func (c *Conn) Prepare(name, sql string) (err error) { case noData: case readyForQuery: c.rxReadyForQuery(r) - c.preparedStatements[name] = &ps - return softErr + c.preparedStatements[name] = ps + return ps, softErr default: if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { softErr = e @@ -761,7 +761,7 @@ func (c *Conn) sendSimpleQuery(sql string, arguments ...interface{}) (err error) return c.txMsg('Q', buf, true) } -func (c *Conn) sendPreparedQuery(ps *preparedStatement, arguments ...interface{}) (err error) { +func (c *Conn) sendPreparedQuery(ps *PreparedStatement, arguments ...interface{}) (err error) { if len(ps.ParameterOids) != len(arguments) { return fmt.Errorf("Prepared statement \"%v\" requires %d parameters, but %d were provided", ps.Name, len(ps.ParameterOids), len(arguments)) } diff --git a/conn_test.go b/conn_test.go index e6627eed..4eed28b1 100644 --- a/conn_test.go +++ b/conn_test.go @@ -520,7 +520,7 @@ func TestPrepare(t *testing.T) { defer conn.Close() testTranscode := func(sql string, value interface{}) { - if err = conn.Prepare("testTranscode", sql); err != nil { + if _, err = conn.Prepare("testTranscode", sql); err != nil { t.Errorf("Unable to prepare statement: %v", err) return } @@ -555,7 +555,7 @@ func TestPrepare(t *testing.T) { // Ensure that unknown types are just treated as strings testTranscode("select $1::point", "(0,0)") - if err = conn.Prepare("testByteSliceTranscode", "select $1::bytea"); err != nil { + if _, err = conn.Prepare("testByteSliceTranscode", "select $1::bytea"); err != nil { t.Errorf("Unable to prepare statement: %v", err) return } @@ -588,7 +588,7 @@ func TestPrepare(t *testing.T) { } mustExecute(t, conn, "create temporary table foo(id serial)") - if err = conn.Prepare("deleteFoo", "delete from foo"); err != nil { + if _, err = conn.Prepare("deleteFoo", "delete from foo"); err != nil { t.Fatalf("Unable to prepare delete: %v", err) } } @@ -600,7 +600,7 @@ func TestPrepareFailure(t *testing.T) { } defer conn.Close() - if err = conn.Prepare("badSQL", "select foo"); err == nil { + if _, err = conn.Prepare("badSQL", "select foo"); err == nil { t.Fatal("Prepare should have failed with syntax error") } diff --git a/helper_test.go b/helper_test.go index 1b43eaae..057c9a9e 100644 --- a/helper_test.go +++ b/helper_test.go @@ -21,7 +21,7 @@ func getSharedConnection(t testing.TB) (c *pgx.Conn) { } func mustPrepare(t testing.TB, conn *pgx.Conn, name, sql string) { - if err := conn.Prepare(name, sql); err != nil { + if _, err := conn.Prepare(name, sql); err != nil { t.Fatalf("Could not prepare %v: %v", name, err) } }