diff --git a/conn.go b/conn.go index c857234c..d89fe575 100644 --- a/conn.go +++ b/conn.go @@ -430,108 +430,67 @@ optionLoop: } if simpleProtocol { + return c.execSimpleProtocol(ctx, sql, arguments) + } + + if ps, ok := c.preparedStatements[sql]; ok { + return c.execPrepared(ctx, ps, arguments) + } + + if len(arguments) == 0 { + return c.execSimpleProtocol(ctx, sql, arguments) + } + + ps, err := c.Prepare(ctx, "", sql) + if err != nil { + return nil, err + } + return c.execPrepared(ctx, ps, arguments) +} + +func (c *Conn) execSimpleProtocol(ctx context.Context, sql string, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { + if len(arguments) > 0 { sql, err = c.sanitizeForSimpleQuery(sql, arguments...) if err != nil { return nil, err } - - mrr := c.pgConn.Exec(ctx, sql) - if mrr.NextResult() { - result := mrr.ResultReader().Read() - err = mrr.Close() - return result.CommandTag, err - } else { - err = mrr.Close() - return nil, err - } } + mrr := c.pgConn.Exec(ctx, sql) + for mrr.NextResult() { + commandTag, err = mrr.ResultReader().Close() + } + err = mrr.Close() + return commandTag, err +} + +func (c *Conn) execPrepared(ctx context.Context, ps *PreparedStatement, arguments []interface{}) (commandTag pgconn.CommandTag, err error) { c.eqb.Reset() - if ps, ok := c.preparedStatements[sql]; ok { - args, err := convertDriverValuers(arguments) - if err != nil { - return nil, err - } - - for i := range args { - err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i]) - if err != nil { - return nil, err - } - } - - for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { - if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - c.eqb.AppendResultFormat(BinaryFormatCode) - } else { - c.eqb.AppendResultFormat(TextFormatCode) - } - } - } - - result := c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() - return result.CommandTag, result.Err + args, err := convertDriverValuers(arguments) + if err != nil { + return nil, err } - if len(arguments) == 0 { - results, err := c.pgConn.Exec(ctx, sql).ReadAll() + for i := range args { + err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], args[i]) if err != nil { return nil, err } - if len(results) == 0 { - return nil, nil - } - - return results[len(results)-1].CommandTag, nil - } else { - psd, err := c.pgConn.Prepare(ctx, "", sql, nil) - if err != nil { - return nil, err - } - - if len(psd.ParamOIDs) != len(arguments) { - return nil, errors.Errorf("expected %d arguments, got %d", len(psd.ParamOIDs), len(arguments)) - } - - ps := &PreparedStatement{ - Name: psd.Name, - SQL: psd.SQL, - ParameterOIDs: make([]pgtype.OID, len(psd.ParamOIDs)), - FieldDescriptions: psd.Fields, - } - - for i := range ps.ParameterOIDs { - ps.ParameterOIDs[i] = pgtype.OID(psd.ParamOIDs[i]) - } - - arguments, err = convertDriverValuers(arguments) - if err != nil { - return nil, err - } - - for i := range arguments { - err = c.eqb.AppendParam(c.ConnInfo, ps.ParameterOIDs[i], arguments[i]) - if err != nil { - return nil, err - } - } - - for i := range ps.FieldDescriptions { - if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { - if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { - c.eqb.AppendResultFormat(BinaryFormatCode) - } else { - c.eqb.AppendResultFormat(TextFormatCode) - } - } - } - - result := c.pgConn.ExecPrepared(ctx, psd.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() - return result.CommandTag, result.Err } + for i := range ps.FieldDescriptions { + if dt, ok := c.ConnInfo.DataTypeForOID(pgtype.OID(ps.FieldDescriptions[i].DataTypeOID)); ok { + if _, ok := dt.Value.(pgtype.BinaryDecoder); ok { + c.eqb.AppendResultFormat(BinaryFormatCode) + } else { + c.eqb.AppendResultFormat(TextFormatCode) + } + } + } + + result := c.pgConn.ExecPrepared(ctx, ps.Name, c.eqb.paramValues, c.eqb.paramFormats, c.eqb.resultFormats).Read() + return result.CommandTag, result.Err } func (c *Conn) getRows(ctx context.Context, sql string, args []interface{}) *connRows {