mirror of https://github.com/jackc/pgx.git
Extract scan value to pgtype
parent
69946b35d8
commit
f756d9d591
|
@ -1,6 +1,7 @@
|
||||||
package pgtype
|
package pgtype
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -84,6 +85,12 @@ func (im InfinityModifier) String() string {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PostgreSQL format codes
|
||||||
|
const (
|
||||||
|
TextFormatCode = 0
|
||||||
|
BinaryFormatCode = 1
|
||||||
|
)
|
||||||
|
|
||||||
type Value interface {
|
type Value interface {
|
||||||
// Set converts and assigns src to itself.
|
// Set converts and assigns src to itself.
|
||||||
Set(src interface{}) error
|
Set(src interface{}) error
|
||||||
|
@ -207,6 +214,53 @@ func (ci *ConnInfo) DeepCopy() *ConnInfo {
|
||||||
return ci2
|
return ci2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ci *ConnInfo) Scan(oid OID, formatCode int16, buf []byte, dest interface{}) error {
|
||||||
|
if dest, ok := dest.(BinaryDecoder); ok && formatCode == BinaryFormatCode {
|
||||||
|
return dest.DecodeBinary(ci, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dest, ok := dest.(TextDecoder); ok && formatCode == TextFormatCode {
|
||||||
|
return dest.DecodeText(ci, buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dt, ok := ci.DataTypeForOID(oid); ok {
|
||||||
|
value := dt.Value
|
||||||
|
switch formatCode {
|
||||||
|
case TextFormatCode:
|
||||||
|
if textDecoder, ok := value.(TextDecoder); ok {
|
||||||
|
err := textDecoder.DecodeText(ci, buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.Errorf("%T is not a pgtype.TextDecoder", value)
|
||||||
|
}
|
||||||
|
case BinaryFormatCode:
|
||||||
|
if binaryDecoder, ok := value.(BinaryDecoder); ok {
|
||||||
|
err := binaryDecoder.DecodeBinary(ci, buf)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return errors.Errorf("%T is not a pgtype.BinaryDecoder", value)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return errors.Errorf("unknown format code: %v", formatCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if scanner, ok := dest.(sql.Scanner); ok {
|
||||||
|
sqlSrc, err := DatabaseSQLValue(ci, value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return scanner.Scan(sqlSrc)
|
||||||
|
} else {
|
||||||
|
return value.AssignTo(dest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.Errorf("unknown oid: %v", oid)
|
||||||
|
}
|
||||||
|
|
||||||
var nameValues map[string]Value
|
var nameValues map[string]Value
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
62
query.go
62
query.go
|
@ -2,7 +2,6 @@ package pgx
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
"time"
|
"time"
|
||||||
|
@ -186,9 +185,9 @@ func (rows *connRows) nextColumn() ([]byte, *FieldDescription, bool) {
|
||||||
return buf, fd, true
|
return buf, fd, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rows *connRows) Scan(dest ...interface{}) (err error) {
|
func (rows *connRows) Scan(dest ...interface{}) error {
|
||||||
if len(rows.fields) != len(dest) {
|
if len(rows.fields) != len(dest) {
|
||||||
err = errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
|
err := errors.Errorf("Scan received wrong number of arguments, got %d but expected %d", len(dest), len(rows.fields))
|
||||||
rows.fatal(err)
|
rows.fatal(err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -200,63 +199,10 @@ func (rows *connRows) Scan(dest ...interface{}) (err error) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if s, ok := d.(pgtype.BinaryDecoder); ok && fd.FormatCode == BinaryFormatCode {
|
err := rows.conn.ConnInfo.Scan(fd.DataType, fd.FormatCode, buf, d)
|
||||||
err = s.DecodeBinary(rows.conn.ConnInfo, buf)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
rows.fatal(scanArgError{col: i, err: err})
|
||||||
}
|
return err
|
||||||
} else if s, ok := d.(pgtype.TextDecoder); ok && fd.FormatCode == TextFormatCode {
|
|
||||||
err = s.DecodeText(rows.conn.ConnInfo, buf)
|
|
||||||
if err != nil {
|
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if dt, ok := rows.conn.ConnInfo.DataTypeForOID(fd.DataType); ok {
|
|
||||||
value := dt.Value
|
|
||||||
switch fd.FormatCode {
|
|
||||||
case TextFormatCode:
|
|
||||||
if textDecoder, ok := value.(pgtype.TextDecoder); ok {
|
|
||||||
err = textDecoder.DecodeText(rows.conn.ConnInfo, buf)
|
|
||||||
if err != nil {
|
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.TextDecoder", value)})
|
|
||||||
}
|
|
||||||
case BinaryFormatCode:
|
|
||||||
if binaryDecoder, ok := value.(pgtype.BinaryDecoder); ok {
|
|
||||||
err = binaryDecoder.DecodeBinary(rows.conn.ConnInfo, buf)
|
|
||||||
if err != nil {
|
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rows.fatal(scanArgError{col: i, err: errors.Errorf("%T is not a pgtype.BinaryDecoder", value)})
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown format code: %v", fd.FormatCode)})
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Err() == nil {
|
|
||||||
if scanner, ok := d.(sql.Scanner); ok {
|
|
||||||
sqlSrc, err := pgtype.DatabaseSQLValue(rows.conn.ConnInfo, value)
|
|
||||||
if err != nil {
|
|
||||||
rows.fatal(err)
|
|
||||||
}
|
|
||||||
err = scanner.Scan(sqlSrc)
|
|
||||||
if err != nil {
|
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
|
||||||
}
|
|
||||||
} else if err := value.AssignTo(d); err != nil {
|
|
||||||
rows.fatal(scanArgError{col: i, err: err})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
rows.fatal(scanArgError{col: i, err: errors.Errorf("unknown oid: %v", fd.DataType)})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows.Err() != nil {
|
|
||||||
return rows.Err()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue