Pass PrepareEx opts as pointer and DRY implementation

pull/149/head
Jack Christensen 2016-05-20 08:30:10 -05:00
parent 7954a36b2d
commit 40f00f4a82
4 changed files with 38 additions and 118 deletions

89
conn.go
View File

@ -588,82 +588,7 @@ func configSSL(sslmode string, cc *ConnConfig) error {
// name and sql arguments. This allows a code path to Prepare and Query/Exec without
// concern for if the statement has already been prepared.
func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
if name != "" {
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil
}
}
if c.shouldLog(LogLevelError) {
defer func() {
if err != nil {
c.log(LogLevelError, fmt.Sprintf("Prepare `%s` as `%s` failed: %v", name, sql, err))
}
}()
}
// parse
wbuf := newWriteBuf(c, 'P')
wbuf.WriteCString(name)
wbuf.WriteCString(sql)
wbuf.WriteInt16(0)
// describe
wbuf.startMsg('D')
wbuf.WriteByte('S')
wbuf.WriteCString(name)
// sync
wbuf.startMsg('S')
wbuf.closeMsg()
_, err = c.conn.Write(wbuf.buf)
if err != nil {
c.die(err)
return nil, err
}
ps = &PreparedStatement{Name: name, SQL: sql}
var softErr error
for {
var t byte
var r *msgReader
t, r, err := c.rxMsg()
if err != nil {
return nil, err
}
switch t {
case parseComplete:
case parameterDescription:
ps.ParameterOids = c.rxParameterDescription(r)
if len(ps.ParameterOids) > 65535 && softErr == nil {
softErr = fmt.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOids))
}
case rowDescription:
ps.FieldDescriptions = c.rxRowDescription(r)
for i := range ps.FieldDescriptions {
t, _ := c.PgTypes[ps.FieldDescriptions[i].DataType]
ps.FieldDescriptions[i].DataTypeName = t.Name
ps.FieldDescriptions[i].FormatCode = t.DefaultFormat
}
case noData:
case readyForQuery:
c.rxReadyForQuery(r)
if softErr == nil {
c.preparedStatements[name] = ps
}
return ps, softErr
default:
if e := c.processContextFreeMsg(t, r); e != nil && softErr == nil {
softErr = e
}
}
}
return c.PrepareEx(name, sql, nil)
}
// PrepareEx creates a prepared statement with name and sql. sql can contain placeholders
@ -673,7 +598,7 @@ func (c *Conn) Prepare(name, sql string) (ps *PreparedStatement, err error) {
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec without
// concern for if the statement has already been prepared.
func (c *Conn) PrepareEx(name, sql string, opts PrepareExOptions) (ps *PreparedStatement, err error) {
func (c *Conn) PrepareEx(name, sql string, opts *PrepareExOptions) (ps *PreparedStatement, err error) {
if name != "" {
if ps, ok := c.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil
@ -693,16 +618,14 @@ func (c *Conn) PrepareEx(name, sql string, opts PrepareExOptions) (ps *PreparedS
wbuf.WriteCString(name)
wbuf.WriteCString(sql)
if len(opts.ParameterOids) > 65535 {
return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)))
}
if len(opts.ParameterOids) > 0 {
if opts != nil {
if len(opts.ParameterOids) > 65535 {
return nil, errors.New(fmt.Sprintf("Number of PrepareExOptions ParameterOids must be between 0 and 65535, received %d", len(opts.ParameterOids)))
}
wbuf.WriteInt16(int16(len(opts.ParameterOids)))
for _, oid := range opts.ParameterOids {
wbuf.WriteInt32(int32(oid))
}
} else {
wbuf.WriteInt16(0)
}

View File

@ -331,34 +331,7 @@ func (p *ConnPool) Begin() (*Tx, error) {
// the same name and sql arguments. This allows a code path to Prepare and
// Query/Exec/PrepareEx without concern for if the statement has already been prepared.
func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
p.cond.L.Lock()
defer p.cond.L.Unlock()
if ps, ok := p.preparedStatements[name]; ok && ps.SQL == sql {
return ps, nil
}
c, err := p.acquire(nil)
if err != nil {
return nil, err
}
ps, err := c.Prepare(name, sql)
p.availableConnections = append(p.availableConnections, c)
if err != nil {
return nil, err
}
for _, c := range p.availableConnections {
_, err := c.Prepare(name, sql)
if err != nil {
return nil, err
}
}
p.invalidateAcquired()
p.preparedStatements[name] = ps
return ps, err
return p.PrepareEx(name, sql, nil)
}
// PrepareEx creates a prepared statement on a connection in the pool to test the
@ -372,7 +345,7 @@ func (p *ConnPool) Prepare(name, sql string) (*PreparedStatement, error) {
// PrepareEx is idempotent; i.e. it is safe to call PrepareEx multiple times with the same
// name and sql arguments. This allows a code path to PrepareEx and Query/Exec/Prepare without
// concern for if the statement has already been prepared.
func (p *ConnPool) PrepareEx(name, sql string, opts PrepareExOptions) (*PreparedStatement, error) {
func (p *ConnPool) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
p.cond.L.Lock()
defer p.cond.L.Unlock()

View File

@ -981,6 +981,34 @@ func TestPrepareIdempotency(t *testing.T) {
}
}
func TestPrepareEx(t *testing.T) {
t.Parallel()
conn := mustConnect(t, *defaultConnConfig)
defer closeConn(t, conn)
_, err := conn.PrepareEx("test", "select $1", &pgx.PrepareExOptions{ParameterOids: []pgx.Oid{pgx.TextOid}})
if err != nil {
t.Errorf("Unable to prepare statement: %v", err)
return
}
var s string
err = conn.QueryRow("test", "hello").Scan(&s)
if err != nil {
t.Errorf("Executing prepared statement failed: %v", err)
}
if s != "hello" {
t.Errorf("Prepared statement did not return expected value: %v", s)
}
err = conn.Deallocate("test")
if err != nil {
t.Errorf("conn.Deallocate failed: %v", err)
}
}
func TestListenNotify(t *testing.T) {
t.Parallel()

8
tx.go
View File

@ -129,15 +129,11 @@ func (tx *Tx) Exec(sql string, arguments ...interface{}) (commandTag CommandTag,
// Prepare delegates to the underlying *Conn
func (tx *Tx) Prepare(name, sql string) (*PreparedStatement, error) {
if tx.status != TxStatusInProgress {
return nil, ErrTxClosed
}
return tx.conn.Prepare(name, sql)
return tx.PrepareEx(name, sql, nil)
}
// PrepareEx delegates to the underlying *Conn
func (tx *Tx) PrepareEx(name, sql string, opts PrepareExOptions) (*PreparedStatement, error) {
func (tx *Tx) PrepareEx(name, sql string, opts *PrepareExOptions) (*PreparedStatement, error) {
if tx.status != TxStatusInProgress {
return nil, ErrTxClosed
}