mirror of https://github.com/jackc/pgx.git
Extract scan value to pgtype
parent
69946b35d8
commit
f756d9d591
|
@ -1,6 +1,7 @@
|
|||
package pgtype
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
@ -84,6 +85,12 @@ func (im InfinityModifier) String() string {
|
|||
}
|
||||
}
|
||||
|
||||
// PostgreSQL format codes
|
||||
const (
|
||||
TextFormatCode = 0
|
||||
BinaryFormatCode = 1
|
||||
)
|
||||
|
||||
type Value interface {
|
||||
// Set converts and assigns src to itself.
|
||||
Set(src interface{}) error
|
||||
|
@ -207,6 +214,53 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
|||
return ci2
|
||||
}
|
||||
|
||||
func (ci *ConnInfo) Scan(oid OID, formatCode int16, buf []byte, dest interface{}) error {
|
||||
if dest, ok := dest.(BinaryDecoder); ok && formatCode == BinaryFormatCode {
|
||||
return dest.DecodeBinary(ci, buf)
|
||||
}
|
||||
|
||||
if dest, ok := dest.(TextDecoder); ok && formatCode == TextFormatCode {
|
||||
return dest.DecodeText(ci, buf)
|
||||
}
|
||||
|
||||
if dt, ok := ci.DataTypeForOID(oid); ok {
|
||||
value := dt.Value
|
||||
switch formatCode {
|
||||
case TextFormatCode:
|
||||
if textDecoder, ok := value.(TextDecoder); ok {
|
||||
err := textDecoder.DecodeText(ci, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.Errorf("%T is not a pgtype.TextDecoder", value)
|
||||
}
|
||||
case BinaryFormatCode:
|
||||
if binaryDecoder, ok := value.(BinaryDecoder); ok {
|
||||
err := binaryDecoder.DecodeBinary(ci, buf)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
return errors.Errorf("%T is not a pgtype.BinaryDecoder", value)
|
||||
}
|
||||
default:
|
||||
return errors.Errorf("unknown format code: %v", formatCode)
|
||||
}
|
||||
|
||||
if scanner, ok := dest.(sql.Scanner); ok {
|
||||
sqlSrc, err := DatabaseSQLValue(ci, value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return scanner.Scan(sqlSrc)
|
||||
} else {
|
||||
return value.AssignTo(dest)
|
||||
}
|
||||
}
|
||||
return errors.Errorf("unknown oid: %v", oid)
|
||||
}
|
||||
|
||||
var nameValues map[string]Value
|
||||
|
||||
func init() {
|
||||
|
|
66
query.go
66
query.go
|
@ -2,7 +2,6 @@ package pgx
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
@ -186,9 +185,9 @@ func (rows *connRows) nextColumn() ([]byte, *FieldDescription, bool) {
|
|||
return buf, fd, true
|
||||
}
|
||||
|
||||
func (rows *connRows) Scan(dest ...interface{}) (err error) {
|
||||
func (rows *connRows) Scan(dest ...interface{}) error {
|
||||
if len(rows.fields) != len(dest) {
|
||||
err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
|
||||
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
|
||||
rows.fatal(err)
|
||||
return err
|
||||
}
|
||||
|
@ -200,63 +199,10 @@ func (rows *connRows) Scan(dest ...interface{}) (err error) {
|
|||
continue
|
||||
}
|
||||
|
||||
if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode {
|
||||
err = s.DecodeBinary(rows.conn.ConnInfo, buf)
|
||||
if err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode {
|
||||
err = s.DecodeText(rows.conn.ConnInfo, buf)
|
||||
if err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else {
|
||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
|
||||
value := dt.Value
|
||||
switch fd.FormatCode {
|
||||
case TextFormatCode:
|
||||
if textDecoder, ok := value.(pgtype.TextDecoder); ok {
|
||||
err = textDecoder.DecodeText(rows.conn.ConnInfo, buf)
|
||||
if err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else {
|
||||
rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)})
|
||||
}
|
||||
case BinaryFormatCode:
|
||||
if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok {
|
||||
err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf)
|
||||
if err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else {
|
||||
rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)})
|
||||
}
|
||||
default:
|
||||
rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)})
|
||||
}
|
||||
|
||||
if rows.Err() == nil {
|
||||
if scanner, ok := d.(sql.Scanner); ok {
|
||||
sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value)
|
||||
if err != nil {
|
||||
rows.fatal(err)
|
||||
}
|
||||
err = scanner.Scan(sqlSrc)
|
||||
if err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
} else if err := value.AssignTo(d); err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)})
|
||||
}
|
||||
}
|
||||
|
||||
if rows.Err() != nil {
|
||||
return rows.Err()
|
||||
err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
||||
if err != nil {
|
||||
rows.fatal(scanArgError{col: i, err: err})
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue