Extract scan value to pgtype

pull/483/head
Jack Christensen 2019-04-12 21:31:59 -05:00
parent 69946b35d8
commit f756d9d591
2 changed files with 60 additions and 60 deletions

View File

@ -1,6 +1,7 @@
package pgtype
import (
"database/sql"
"reflect"
"github.com/pkg/errors"
@ -84,6 +85,12 @@ func (im InfinityModifier) String() string {
}
}
// PostgreSQL format codes
const (
TextFormatCode = 0
BinaryFormatCode = 1
)
type Value interface {
// Set converts and assigns src to itself.
Set(src interface{}) error
@ -207,6 +214,53 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
return ci2
}
func (ci *ConnInfo) Scan(oid OID, formatCode int16, buf []byte, dest interface{}) error {
if dest, ok := dest.(BinaryDecoder); ok && formatCode == BinaryFormatCode {
return dest.DecodeBinary(ci, buf)
}
if dest, ok := dest.(TextDecoder); ok && formatCode == TextFormatCode {
return dest.DecodeText(ci, buf)
}
if dt, ok := ci.DataTypeForOID(oid); ok {
value := dt.Value
switch formatCode {
case TextFormatCode:
if textDecoder, ok := value.(TextDecoder); ok {
err := textDecoder.DecodeText(ci, buf)
if err != nil {
return err
}
} else {
return errors.Errorf("%T is not a pgtype.TextDecoder", value)
}
case BinaryFormatCode:
if binaryDecoder, ok := value.(BinaryDecoder); ok {
err := binaryDecoder.DecodeBinary(ci, buf)
if err != nil {
return err
}
} else {
return errors.Errorf("%T is not a pgtype.BinaryDecoder", value)
}
default:
return errors.Errorf("unknown format code: %v", formatCode)
}
if scanner, ok := dest.(sql.Scanner); ok {
sqlSrc, err := DatabaseSQLValue(ci, value)
if err != nil {
return err
}
return scanner.Scan(sqlSrc)
} else {
return value.AssignTo(dest)
}
}
return errors.Errorf("unknown oid: %v", oid)
}
var nameValues map[string]Value
func init() {

View File

@ -2,7 +2,6 @@ package pgx
import (
"context"
"database/sql"
"fmt"
"reflect"
"time"
@ -186,9 +185,9 @@ func (rows *connRows) nextColumn() ([]byte, *FieldDescription, bool) {
return buf, fd, true
}
func (rows *connRows) Scan(dest ...interface{}) (err error) {
func (rows *connRows) Scan(dest ...interface{}) error {
if len(rows.fields) != len(dest) {
err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
rows.fatal(err)
return err
}
@ -200,63 +199,10 @@ func (rows *connRows) Scan(dest ...interface{}) (err error) {
continue
}
if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode {
err = s.DecodeBinary(rows.conn.ConnInfo, buf)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
}
} else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode {
err = s.DecodeText(rows.conn.ConnInfo, buf)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
}
} else {
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
value := dt.Value
switch fd.FormatCode {
case TextFormatCode:
if textDecoder, ok := value.(pgtype.TextDecoder); ok {
err = textDecoder.DecodeText(rows.conn.ConnInfo, buf)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
}
} else {
rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)})
}
case BinaryFormatCode:
if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok {
err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
}
} else {
rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)})
}
default:
rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)})
}
if rows.Err() == nil {
if scanner, ok := d.(sql.Scanner); ok {
sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value)
if err != nil {
rows.fatal(err)
}
err = scanner.Scan(sqlSrc)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
}
} else if err := value.AssignTo(d); err != nil {
rows.fatal(scanArgError{col: i, err: err})
}
}
} else {
rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)})
}
}
if rows.Err() != nil {
return rows.Err()
err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
if err != nil {
rows.fatal(scanArgError{col: i, err: err})
return err
}
}