From c4050134cc04b79fc24100eda98ea7c71d9b28c0 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 12 Jan 2019 12:19:12 -0600 Subject: [PATCH] Begin delegation of Prepare to pgconn --- conn.go | 73 +++++++++++++++++---------------------------------------- 1 file changed, 21 insertions(+), 52 deletions(-) diff --git a/conn.go b/conn.go index 8db567c2..c19dfef8 100644 --- a/conn.go +++ b/conn.go @@ -538,64 +538,33 @@ func (c *Conn) prepareEx(name, sql string, opts *PrepareExOptions) (ps *Prepared return nil, errors.Errorf("Number of PrepareExOptions ParameterOIDs must be between 0 and 65535, received %d", len(opts.ParameterOIDs)) } - buf := appendParse(c.wbuf, name, sql, opts.ParameterOIDs) - buf = appendDescribe(buf, 'S', name) - buf = appendSync(buf) + var paramOIDs []uint32 + for _, oid := range opts.ParameterOIDs { + paramOIDs = append(paramOIDs, uint32(oid)) + } - n, err := c.pgConn.Conn().Write(buf) + psd, err := c.pgConn.Prepare(context.TODO(), name, sql, paramOIDs) if err != nil { - if fatalWriteErr(n, err) { - c.die(err) - } return nil, err } - c.pendingReadyForQueryCount++ - ps = &PreparedStatement{Name: name, SQL: sql} - - var softErr error - - for { - msg, err := c.rxMsg() - if err != nil { - return nil, err - } - - switch msg := msg.(type) { - case *pgproto3.ParameterDescription: - ps.ParameterOIDs = c.rxParameterDescription(msg) - - if len(ps.ParameterOIDs) > 65535 && softErr == nil { - softErr = errors.Errorf("PostgreSQL supports maximum of 65535 parameters, received %d", len(ps.ParameterOIDs)) - } - case *pgproto3.RowDescription: - ps.FieldDescriptions = c.rxRowDescription(msg) - for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(ps.FieldDescriptions[i].DataType); ok { - ps.FieldDescriptions[i].DataTypeName = dt.Name - if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - ps.FieldDescriptions[i].FormatCode = BinaryFormatCode - } else { - ps.FieldDescriptions[i].FormatCode = TextFormatCode - } - } else { - return nil, errors.Errorf("unknown oid: %d", ps.FieldDescriptions[i].DataType) - } - } - case *pgproto3.ReadyForQuery: - c.rxReadyForQuery(msg) - - if softErr == nil { - c.preparedStatements[name] = ps - } - - return ps, softErr - default: - if e := c.processContextFreeMsg(msg); e != nil && softErr == nil { - softErr = e - } - } + ps = &PreparedStatement{ + Name: psd.Name, + SQL: psd.SQL, + ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), + FieldDescriptions: make([]FieldDescription, len(psd.Fields)), } + + for i := range ps.ParameterOIDs { + ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) + } + for i := range ps.FieldDescriptions { + c.pgproto3FieldDescriptionToPgxFieldDescription(&psd.Fields[i], &ps.FieldDescriptions[i]) + } + + c.preparedStatements[name] = ps + + return ps, nil } // Deallocate released a prepared statement