From 8aaf235595222e7aa973a1b166e2b92df409fef4 Mon Sep 17 00:00:00 2001 From: Jack Christensen Date: Sat, 8 Jan 2022 21:41:08 -0600 Subject: [PATCH] Standardize scanner and valuer for int types --- pgtype/int.go | 76 ++++++++++++++++++++++++++++------------------- pgtype/int.go.erb | 32 ++++++++++++-------- 2 files changed, 66 insertions(+), 42 deletions(-) diff --git a/pgtype/int.go b/pgtype/int.go index 5fee64a6..54898420 100644 --- a/pgtype/int.go +++ b/pgtype/int.go @@ -13,7 +13,11 @@ import ( ) type Int64Scanner interface { - ScanInt64(v int64, valid bool) error + ScanInt64(Int8) error +} + +type Int64Valuer interface { + Int64Value() (Int8, error) } type Int2 struct { @@ -22,23 +26,27 @@ type Int2 struct { } // ScanInt64 implements the Int64Scanner interface. -func (dst *Int2) ScanInt64(n int64, valid bool) error { - if !valid { +func (dst *Int2) ScanInt64(n Int8) error { + if !n.Valid { *dst = Int2{} return nil } - if n < math.MinInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) + if n.Int < math.MinInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int) } - if n > math.MaxInt16 { - return fmt.Errorf("%d is greater than maximum value for Int2", n) + if n.Int > math.MaxInt16 { + return fmt.Errorf("%d is greater than maximum value for Int2", n.Int) } - *dst = Int2{Int: int16(n), Valid: true} + *dst = Int2{Int: int16(n.Int), Valid: true} return nil } +func (n Int2) Int64Value() (Int8, error) { + return Int8{Int: int64(n.Int), Valid: n.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int2) Scan(src interface{}) error { if src == nil { @@ -511,7 +519,7 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } if len(src) != 2 { @@ -520,7 +528,7 @@ func (scanPlanBinaryInt2ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod n := int64(binary.BigEndian.Uint16(src)) - return s.ScanInt64(n, true) + return s.ScanInt64(Int8{Int: n, Valid: true}) } type Int4 struct { @@ -529,23 +537,27 @@ type Int4 struct { } // ScanInt64 implements the Int64Scanner interface. -func (dst *Int4) ScanInt64(n int64, valid bool) error { - if !valid { +func (dst *Int4) ScanInt64(n Int8) error { + if !n.Valid { *dst = Int4{} return nil } - if n < math.MinInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n) + if n.Int < math.MinInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int) } - if n > math.MaxInt32 { - return fmt.Errorf("%d is greater than maximum value for Int4", n) + if n.Int > math.MaxInt32 { + return fmt.Errorf("%d is greater than maximum value for Int4", n.Int) } - *dst = Int4{Int: int32(n), Valid: true} + *dst = Int4{Int: int32(n.Int), Valid: true} return nil } +func (n Int4) Int64Value() (Int8, error) { + return Int8{Int: int64(n.Int), Valid: n.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int4) Scan(src interface{}) error { if src == nil { @@ -1029,7 +1041,7 @@ func (scanPlanBinaryInt4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } if len(src) != 4 { @@ -1038,7 +1050,7 @@ func (scanPlanBinaryInt4ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod n := int64(binary.BigEndian.Uint32(src)) - return s.ScanInt64(n, true) + return s.ScanInt64(Int8{Int: n, Valid: true}) } type Int8 struct { @@ -1047,23 +1059,27 @@ type Int8 struct { } // ScanInt64 implements the Int64Scanner interface. -func (dst *Int8) ScanInt64(n int64, valid bool) error { - if !valid { +func (dst *Int8) ScanInt64(n Int8) error { + if !n.Valid { *dst = Int8{} return nil } - if n < math.MinInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n) + if n.Int < math.MinInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int) } - if n > math.MaxInt64 { - return fmt.Errorf("%d is greater than maximum value for Int8", n) + if n.Int > math.MaxInt64 { + return fmt.Errorf("%d is greater than maximum value for Int8", n.Int) } - *dst = Int8{Int: int64(n), Valid: true} + *dst = Int8{Int: int64(n.Int), Valid: true} return nil } +func (n Int8) Int64Value() (Int8, error) { + return Int8{Int: int64(n.Int), Valid: n.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int8) Scan(src interface{}) error { if src == nil { @@ -1569,7 +1585,7 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } if len(src) != 8 { @@ -1578,7 +1594,7 @@ func (scanPlanBinaryInt8ToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCod n := int64(binary.BigEndian.Uint64(src)) - return s.ScanInt64(n, true) + return s.ScanInt64(Int8{Int: n, Valid: true}) } type scanPlanTextAnyToInt8 struct{} @@ -1800,7 +1816,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode i } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } n, err := strconv.ParseInt(string(src), 10, 64) @@ -1808,7 +1824,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode i return err } - err = s.ScanInt64(n, true) + err = s.ScanInt64(Int8{Int: n, Valid: true}) if err != nil { return err } diff --git a/pgtype/int.go.erb b/pgtype/int.go.erb index 419dddd2..0d88dd42 100644 --- a/pgtype/int.go.erb +++ b/pgtype/int.go.erb @@ -11,7 +11,11 @@ import ( ) type Int64Scanner interface { - ScanInt64(v int64, valid bool) error + ScanInt64(Int8) error +} + +type Int64Valuer interface { + Int64Value() (Int8, error) } @@ -23,23 +27,27 @@ type Int<%= pg_byte_size %> struct { } // ScanInt64 implements the Int64Scanner interface. -func (dst *Int<%= pg_byte_size %>) ScanInt64(n int64, valid bool) error { - if !valid { +func (dst *Int<%= pg_byte_size %>) ScanInt64(n Int8) error { + if !n.Valid { *dst = Int<%= pg_byte_size %>{} return nil } - if n < math.MinInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + if n.Int < math.MinInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int) } - if n > math.MaxInt<%= pg_bit_size %> { - return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n) + if n.Int > math.MaxInt<%= pg_bit_size %> { + return fmt.Errorf("%d is greater than maximum value for Int<%= pg_byte_size %>", n.Int) } - *dst = Int<%= pg_byte_size %>{Int: int<%= pg_bit_size %>(n), Valid: true} + *dst = Int<%= pg_byte_size %>{Int: int<%= pg_bit_size %>(n.Int), Valid: true} return nil } +func (n Int<%= pg_byte_size %>) Int64Value() (Int8, error) { + return Int8{Int: int64(n.Int), Valid: n.Valid}, nil +} + // Scan implements the database/sql Scanner interface. func (dst *Int<%= pg_byte_size %>) Scan(src interface{}) error { if src == nil { @@ -397,7 +405,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } if len(src) != <%= pg_byte_size %> { @@ -407,7 +415,7 @@ func (scanPlanBinaryInt<%= pg_byte_size %>ToInt64Scanner) Scan(ci *ConnInfo, oid n := int64(binary.BigEndian.Uint<%= pg_bit_size %>(src)) - return s.ScanInt64(n, true) + return s.ScanInt64(Int8{Int: n, Valid: true}) } <% end %> @@ -471,7 +479,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode i } if src == nil { - return s.ScanInt64(0, false) + return s.ScanInt64(Int8{}) } n, err := strconv.ParseInt(string(src), 10, 64) @@ -479,7 +487,7 @@ func (scanPlanTextAnyToInt64Scanner) Scan(ci *ConnInfo, oid uint32, formatCode i return err } - err = s.ScanInt64(n, true) + err = s.ScanInt64(Int8{Int: n, Valid: true}) if err != nil { return err }