From eb0a4c96264854d3e603eb90536c25104244a6ad Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 22 Jan 2022 11:21:12 -0600 Subject: [PATCH] Replace some old database/sql compatibility --- pgtype/builtin_wrappers.go | 2 ++ pgtype/database_sql.go | 41 ---------------------------- pgtype/pgtype.go | 56 ++++++++++---------------------------- 3 files changed, 16 insertions(+), 83 deletions(-) delete mode 100644 pgtype/database_sql.go diff --git a/pgtype/builtin_wrappers.go b/pgtype/builtin_wrappers.go index abf21b82..9df28f55 100644 --- a/pgtype/builtin_wrappers.go +++ b/pgtype/builtin_wrappers.go @@ -461,6 +461,8 @@ func (w timeWrapper) TimeValue() (Time, error) { type durationWrapper time.Duration +func (w durationWrapper) SkipUnderlyingTypePlan() {} + func (w *durationWrapper) ScanInterval(v Interval) error { if !v.Valid { return fmt.Errorf("cannot scan NULL into *time.Interval") diff --git a/pgtype/database_sql.go b/pgtype/database_sql.go deleted file mode 100644 index 9d1cf822..00000000 --- a/pgtype/database_sql.go +++ /dev/null @@ -1,41 +0,0 @@ -package pgtype - -import ( - "database/sql/driver" - "errors" -) - -func DatabaseSQLValue(ci *ConnInfo, src Value) (interface{}, error) { - if valuer, ok := src.(driver.Valuer); ok { - return valuer.Value() - } - - if textEncoder, ok := src.(TextEncoder); ok { - buf, err := textEncoder.EncodeText(ci, nil) - if err != nil { - return nil, err - } - return string(buf), nil - } - - if binaryEncoder, ok := src.(BinaryEncoder); ok { - buf, err := binaryEncoder.EncodeBinary(ci, nil) - if err != nil { - return nil, err - } - return buf, nil - } - - return nil, errors.New("cannot convert to database/sql compatible value") -} - -func EncodeValueText(src TextEncoder) (interface{}, error) { - buf, err := src.EncodeText(nil, make([]byte, 0, 32)) - if err != nil { - return nil, err - } - if buf == nil { - return nil, nil - } - return string(buf), err -} diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index 481c58be..ee2730ee 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -520,41 +520,6 @@ func (plan scanPlanDstTextDecoder) Scan(ci *ConnInfo, oid uint32, formatCode int return newPlan.Scan(ci, oid, formatCode, src, dst) } -type scanPlanDataTypeSQLScanner DataType - -func (plan *scanPlanDataTypeSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { - scanner, ok := dst.(sql.Scanner) - if !ok { - newPlan := ci.PlanScan(oid, formatCode, dst) - return newPlan.Scan(ci, oid, formatCode, src, dst) - } - - dt := (*DataType)(plan) - if dt.Codec != nil { - sqlValue, err := dt.Codec.DecodeDatabaseSQLValue(ci, oid, formatCode, src) - if err != nil { - return err - } - return scanner.Scan(sqlValue) - } - var err error - switch formatCode { - case BinaryFormatCode: - err = dt.binaryDecoder.DecodeBinary(ci, src) - case TextFormatCode: - err = dt.textDecoder.DecodeText(ci, src) - } - if err != nil { - return err - } - - sqlSrc, err := DatabaseSQLValue(ci, dt.Value) - if err != nil { - return err - } - return scanner.Scan(sqlSrc) -} - type scanPlanDataTypeAssignTo DataType func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -596,6 +561,18 @@ func (plan *scanPlanDataTypeAssignTo) Scan(ci *ConnInfo, oid uint32, formatCode return assignToErr } +type scanPlanCodecSQLScanner struct{ c Codec } + +func (plan *scanPlanCodecSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { + value, err := plan.c.DecodeDatabaseSQLValue(ci, oid, formatCode, src) + if err != nil { + return err + } + + scanner := dst.(sql.Scanner) + return scanner.Scan(value) +} + type scanPlanSQLScanner struct{} func (scanPlanSQLScanner) Scan(ci *ConnInfo, oid uint32, formatCode int16, src []byte, dst interface{}) error { @@ -1176,7 +1153,7 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan for _, f := range tryWrappers { if wrapperPlan, nextDst, ok := f(dst); ok { if nextPlan := ci.PlanScan(oid, formatCode, nextDst); nextPlan != nil { - if _, ok := nextPlan.(*scanPlanDataTypeAssignTo); !ok { // avoid fallthrough -- this will go away when old system removed. + if _, ok := nextPlan.(scanPlanReflection); !ok { // avoid fallthrough -- this will go away when old system removed. wrapperPlan.SetNext(nextPlan) return wrapperPlan } @@ -1187,15 +1164,10 @@ func (ci *ConnInfo) PlanScan(oid uint32, formatCode int16, dst interface{}) Scan if _, ok := dst.(*interface{}); ok { return &pointerEmptyInterfaceScanPlan{codec: dt.Codec} } - } - if dt != nil { if _, ok := dst.(sql.Scanner); ok { - if _, found := ci.preferAssignToOverSQLScannerTypes[reflect.TypeOf(dst)]; !found { - return (*scanPlanDataTypeSQLScanner)(dt) - } + return &scanPlanCodecSQLScanner{c: dt.Codec} } - return (*scanPlanDataTypeAssignTo)(dt) } if _, ok := dst.(sql.Scanner); ok {