diff --git a/conn.go b/conn.go index 276a316d..f9b89a78 100644 --- a/conn.go +++ b/conn.go @@ -588,82 +588,7 @@ func configSSL(sslmode string, cc *ConnConfig) error { // name and sql arguments. This allows a code path to Prepare and Query/Exec without // concern for if the statement has already been prepared. func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { - if name != "" { - if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { - return ps, nil - } - } - - if c.shouldLog(LogLevelError) { - defer func() { - if err != nil { - c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err)) - } - }() - } - - // parse - wbuf := newWriteBuf(c, 'P') - wbuf.WriteCString(name) - wbuf.WriteCString(sql) - wbuf.WriteInt16(0) - - // describe - wbuf.startMsg('D') - wbuf.WriteByte('S') - wbuf.WriteCString(name) - - // sync - wbuf.startMsg('S') - wbuf.closeMsg() - - _, err = c.conn.Write(wbuf.buf) - if err != nil { - c.die(err) - return nil, err - } - - ps = &PreparedStatement{Name: name, SQL: sql} - - var softErr error - - for { - var t byte - var r *msgReader - t, r, err := c.rxMsg() - if err != nil { - return nil, err - } - - switch t { - case parseComplete: - case parameterDescription: - ps.ParameterOids = c.rxParameterDescription(r) - if len(ps.ParameterOids) > 65535 && softErr == nil { - softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids)) - } - case rowDescription: - ps.FieldDescriptions = c.rxRowDescription(r) - for i := range ps.FieldDescriptions { - t, _ := c.PgTypes[ps.FieldDescriptions[i].DataType] - ps.FieldDescriptions[i].DataTypeName = t.Name - ps.FieldDescriptions[i].FormatCode = t.DefaultFormat - } - case noData: - case readyForQuery: - c.rxReadyForQuery(r) - - if softErr == nil { - c.preparedStatements[name] = ps - } - - return ps, softErr - default: - if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil { - softErr = e - } - } - } + return c.PrepareEx(name, sql, nil) } // PrepareEx creates a prepared statement with name and sql. sql can contain placeholders @@ -673,7 +598,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) { // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec without // concern for if the statement has already been prepared. -func (c *Conn) PrepareEx(name, sql string, opts PrepareExOptions) (ps *PreparedStatement, err error) { +func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) { if name != "" { if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql { return ps, nil @@ -693,16 +618,14 @@ func (c *Conn) PrepareEx(name, sql string, opts PrepareExOptions) (ps *PreparedS wbuf.WriteCString(name) wbuf.WriteCString(sql) - if len(opts.ParameterOids) > 65535 { - return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))) - } - - if len(opts.ParameterOids) > 0 { + if opts != nil { + if len(opts.ParameterOids) > 65535 { + return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids))) + } wbuf.WriteInt16(int16(len(opts.ParameterOids))) for _, oid := range opts.ParameterOids { wbuf.WriteInt32(int32(oid)) } - } else { wbuf.WriteInt16(0) } diff --git a/conn_pool.go b/conn_pool.go index 68850987..b27074cd 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -331,34 +331,7 @@ func (p *ConnPool) Begin() (*Tx, error) { // the same name and sql arguments. This allows a code path to Prepare and // Query/Exec/PrepareEx without concern for if the statement has already been prepared. func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { - p.cond.L.Lock() - defer p.cond.L.Unlock() - - if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql { - return ps, nil - } - - c, err := p.acquire(nil) - if err != nil { - return nil, err - } - ps, err := c.Prepare(name, sql) - p.availableConnections = append(p.availableConnections, c) - if err != nil { - return nil, err - } - - for _, c := range p.availableConnections { - _, err := c.Prepare(name, sql) - if err != nil { - return nil, err - } - } - - p.invalidateAcquired() - p.preparedStatements[name] = ps - - return ps, err + return p.PrepareEx(name, sql, nil) } // PrepareEx creates a prepared statement on a connection in the pool to test the @@ -372,7 +345,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) { // PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same // name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without // concern for if the statement has already been prepared. -func (p *ConnPool) PrepareEx(name, sql string, opts PrepareExOptions) (*PreparedStatement, error) { +func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { p.cond.L.Lock() defer p.cond.L.Unlock() diff --git a/conn_test.go b/conn_test.go index 99f222b4..181a3ed2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -981,6 +981,34 @@ func TestPrepareIdempotency(t *testing.T) { } } +func TestPrepareEx(t *testing.T) { + t.Parallel() + + conn := mustConnect(t, *defaultConnConfig) + defer closeConn(t, conn) + + _, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}}) + if err != nil { + t.Errorf("Unable to prepare statement: %v", err) + return + } + + var s string + err = conn.QueryRow("test", "hello").Scan(&s) + if err != nil { + t.Errorf("Executing prepared statement failed: %v", err) + } + + if s != "hello" { + t.Errorf("Prepared statement did not return expected value: %v", s) + } + + err = conn.Deallocate("test") + if err != nil { + t.Errorf("conn.Deallocate failed: %v", err) + } +} + func TestListenNotify(t *testing.T) { t.Parallel() diff --git a/tx.go b/tx.go index 29b3b235..e5c90c23 100644 --- a/tx.go +++ b/tx.go @@ -129,15 +129,11 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag, // Prepare delegates to the underlying *Conn func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) { - if tx.status != TxStatusInProgress { - return nil, ErrTxClosed - } - - return tx.conn.Prepare(name, sql) + return tx.PrepareEx(name, sql, nil) } // PrepareEx delegates to the underlying *Conn -func (tx *Tx) PrepareEx(name, sql string, opts PrepareExOptions) (*PreparedStatement, error) { +func (tx *Tx) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) { if tx.status != TxStatusInProgress { return nil, ErrTxClosed }