diff --git a/stdlib/sql.go b/stdlib/sql.go index 8ff3fd49..db8eefc0 100644 --- a/stdlib/sql.go +++ b/stdlib/sql.go @@ -349,11 +349,14 @@ func (s *Stmt) QueryContext(ctx context.Context, argsV []driver.NamedValue) (dri } type Rows struct { - conn *Conn - rows pgx.Rows - values []interface{} - skipNext bool - skipNextMore bool + conn *Conn + rows pgx.Rows + values []interface{} + driverValuers []driver.Valuer + textDecoders []pgtype.TextDecoder + binaryDecoders []pgtype.BinaryDecoder + skipNext bool + skipNextMore bool } func (r *Rows) Columns() []string { @@ -444,42 +447,112 @@ func (r *Rows) Close() error { } func (r *Rows) Next(dest []driver.Value) error { + ci := r.conn.conn.ConnInfo() + fieldDescriptions := r.rows.FieldDescriptions() + if r.values == nil { - r.values = make([]interface{}, len(r.rows.FieldDescriptions())) - for i, fd := range r.rows.FieldDescriptions() { + r.values = make([]interface{}, len(fieldDescriptions)) + r.driverValuers = make([]driver.Valuer, len(fieldDescriptions)) + r.textDecoders = make([]pgtype.TextDecoder, len(fieldDescriptions)) + r.binaryDecoders = make([]pgtype.BinaryDecoder, len(fieldDescriptions)) + + for i, fd := range fieldDescriptions { switch fd.DataTypeOID { case pgtype.BoolOID: - r.values[i] = &pgtype.Bool{} + v := &pgtype.Bool{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.ByteaOID: - r.values[i] = &pgtype.Bytea{} + v := &pgtype.Bytea{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.CIDOID: - r.values[i] = &pgtype.CID{} + v := &pgtype.CID{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.DateOID: - r.values[i] = &pgtype.Date{} + v := &pgtype.Date{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Float4OID: - r.values[i] = &pgtype.Float4{} + v := &pgtype.Float4{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Float8OID: - r.values[i] = &pgtype.Float8{} + v := &pgtype.Float8{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Int2OID: - r.values[i] = &pgtype.Int2{} + v := &pgtype.Int2{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Int4OID: - r.values[i] = &pgtype.Int4{} + v := &pgtype.Int4{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.Int8OID: - r.values[i] = &pgtype.Int8{} + v := &pgtype.Int8{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.JSONOID: - r.values[i] = &pgtype.JSON{} + v := &pgtype.JSON{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.JSONBOID: - r.values[i] = &pgtype.JSONB{} + v := &pgtype.JSONB{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.OIDOID: - r.values[i] = &pgtype.OIDValue{} + v := &pgtype.OIDValue{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.TimestampOID: - r.values[i] = &pgtype.Timestamp{} + v := &pgtype.Timestamp{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.TimestamptzOID: - r.values[i] = &pgtype.Timestamptz{} + v := &pgtype.Timestamptz{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v case pgtype.XIDOID: - r.values[i] = &pgtype.XID{} + v := &pgtype.XID{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v + r.binaryDecoders[i] = v default: - r.values[i] = &pgtype.GenericText{} + v := &pgtype.GenericText{} + r.values[i] = v + r.driverValuers[i] = v + r.textDecoders[i] = v } } } @@ -500,15 +573,24 @@ func (r *Rows) Next(dest []driver.Value) error { } } - err := r.rows.Scan(r.values...) - if err != nil { - return err - } + for i, rv := range r.rows.RawValues() { + fd := fieldDescriptions[i] + if fd.Format == pgx.BinaryFormatCode { + err := r.binaryDecoders[i].DecodeBinary(ci, rv) + if err != nil { + return fmt.Errorf("scan field %d failed: %v", i, err) + } + } else { + err := r.textDecoders[i].DecodeText(ci, rv) + if err != nil { + return fmt.Errorf("scan field %d failed: %v", i, err) + } + } - for i, v := range r.values { - dest[i], err = v.(driver.Valuer).Value() + var err error + dest[i], err = r.driverValuers[i].Value() if err != nil { - return err + return fmt.Errorf("convert field %d failed: %v", i, err) } }