Register more default types and handle unknown types better

query-exec-mode
Jack Christensen 2022-03-05 21:19:58 -06:00
parent 2831eedef3
commit 0905d1f452
2 changed files with 49 additions and 9 deletions

View File

@ -344,10 +344,12 @@ func NewMap() *Map {
registerDefaultPgTypeVariants("int8", "_int8", int64(0))
// Integer types that do not have a direct match to a PostgreSQL type
registerDefaultPgTypeVariants("int8", "_int8", int8(0))
registerDefaultPgTypeVariants("int8", "_int8", int(0))
registerDefaultPgTypeVariants("int8", "_int8", uint8(0))
registerDefaultPgTypeVariants("int8", "_int8", uint16(0))
registerDefaultPgTypeVariants("int8", "_int8", uint32(0))
registerDefaultPgTypeVariants("int8", "_int8", uint64(0))
registerDefaultPgTypeVariants("int8", "_int8", int(0))
registerDefaultPgTypeVariants("int8", "_int8", uint(0))
registerDefaultPgTypeVariants("float4", "_float4", float32(0))
@ -355,12 +357,46 @@ func NewMap() *Map {
registerDefaultPgTypeVariants("bool", "_bool", false)
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", time.Time{})
registerDefaultPgTypeVariants("interval", "_interval", time.Duration(0))
registerDefaultPgTypeVariants("text", "_text", "")
registerDefaultPgTypeVariants("bytea", "_bytea", []byte(nil))
registerDefaultPgTypeVariants("inet", "_inet", net.IP{})
registerDefaultPgTypeVariants("cidr", "_cidr", net.IPNet{})
// pgtype provided structs
registerDefaultPgTypeVariants("varbit", "_varbit", Bits{})
registerDefaultPgTypeVariants("bool", "_bool", Bool{})
registerDefaultPgTypeVariants("box", "_box", Box{})
registerDefaultPgTypeVariants("circle", "_circle", Circle{})
registerDefaultPgTypeVariants("date", "_date", Date{})
registerDefaultPgTypeVariants("daterange", "_daterange", Daterange{})
registerDefaultPgTypeVariants("float4", "_float4", Float4{})
registerDefaultPgTypeVariants("float8", "_float8", Float8{})
registerDefaultPgTypeVariants("float8range", "_float8range", Float8range{})
registerDefaultPgTypeVariants("inet", "_inet", Inet{})
registerDefaultPgTypeVariants("int2", "_int2", Int2{})
registerDefaultPgTypeVariants("int4", "_int4", Int4{})
registerDefaultPgTypeVariants("int4range", "_int4range", Int4range{})
registerDefaultPgTypeVariants("int8", "_int8", Int8{})
registerDefaultPgTypeVariants("int8range", "_int8range", Int8range{})
registerDefaultPgTypeVariants("interval", "_interval", Interval{})
registerDefaultPgTypeVariants("line", "_line", Line{})
registerDefaultPgTypeVariants("lseg", "_lseg", Lseg{})
registerDefaultPgTypeVariants("numeric", "_numeric", Numeric{})
registerDefaultPgTypeVariants("numrange", "_numrange", Numrange{})
registerDefaultPgTypeVariants("path", "_path", Path{})
registerDefaultPgTypeVariants("point", "_point", Point{})
registerDefaultPgTypeVariants("polygon", "_polygon", Polygon{})
registerDefaultPgTypeVariants("tid", "_tid", TID{})
registerDefaultPgTypeVariants("text", "_text", Text{})
registerDefaultPgTypeVariants("time", "_time", Time{})
registerDefaultPgTypeVariants("timestamp", "_timestamp", Timestamp{})
registerDefaultPgTypeVariants("timestamptz", "_timestamptz", Timestamptz{})
registerDefaultPgTypeVariants("tsrange", "_tsrange", Tsrange{})
registerDefaultPgTypeVariants("tstzrange", "_tstzrange", Tstzrange{})
registerDefaultPgTypeVariants("uuid", "_uuid", UUID{})
return m
}
@ -1181,13 +1217,13 @@ func (m *Map) PlanEncode(oid uint32, format int16, value interface{}) EncodePlan
if plan := dt.Codec.PlanEncode(m, oid, format, value); plan != nil {
return plan
}
}
for _, f := range m.TryWrapEncodePlanFuncs {
if wrapperPlan, nextValue, ok := f(value); ok {
if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
}
for _, f := range m.TryWrapEncodePlanFuncs {
if wrapperPlan, nextValue, ok := f(value); ok {
if nextPlan := m.PlanEncode(oid, format, nextValue); nextPlan != nil {
wrapperPlan.SetNext(nextPlan)
return wrapperPlan
}
}
}

View File

@ -30,9 +30,13 @@ func convertSimpleArgument(m *pgtype.Map, arg interface{}) (interface{}, error)
return nil, nil
}
if dv, ok := arg.(driver.Valuer); ok {
return dv.Value()
}
// All these could be handled by m.Encode below. However, that transforms the argument to a string. That could change
// the type of the argument. e.g. '42' instead of 42. So standard types are special cased.
switch arg := arg.(type) {
case driver.Valuer:
return arg.Value()
case float32:
return float64(arg), nil
case float64: