diff --git a/conn.go b/conn.go index fd859965..4c8b59da 100644 --- a/conn.go +++ b/conn.go @@ -1196,6 +1196,30 @@ func (c *Conn) LoadType(ctx context.Context, typeName string) (*pgtype.Type, err return &pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec}, nil case "e": // enum return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}}, nil + case "r": // range + elementOID, err := c.getRangeElementOID(ctx, oid) + if err != nil { + return nil, err + } + + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("range element OID not registered") + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}}, nil + case "m": // multirange + elementOID, err := c.getMultiRangeElementOID(ctx, oid) + if err != nil { + return nil, err + } + + dt, ok := c.TypeMap().TypeForOID(elementOID) + if !ok { + return nil, errors.New("multirange element OID not registered") + } + + return &pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}, nil default: return &pgtype.Type{}, errors.New("unknown typtype") } @@ -1212,6 +1236,28 @@ func (c *Conn) getArrayElementOID(ctx context.Context, oid uint32) (uint32, erro return typelem, nil } +func (c *Conn) getRangeElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select rngsubtype from pg_range where rngtypid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + +func (c *Conn) getMultiRangeElementOID(ctx context.Context, oid uint32) (uint32, error) { + var typelem uint32 + + err := c.QueryRow(ctx, "select rngtypid from pg_range where rngmultitypid=$1", oid).Scan(&typelem) + if err != nil { + return 0, err + } + + return typelem, nil +} + func (c *Conn) getCompositeFields(ctx context.Context, oid uint32) ([]pgtype.CompositeCodecField, error) { var typrelid uint32 diff --git a/conn_test.go b/conn_test.go index 3a23caa3..66e5573d 100644 --- a/conn_test.go +++ b/conn_test.go @@ -903,6 +903,65 @@ create type pgx_b.point as (c text); }) } +func TestLoadRangeType(t *testing.T) { + pgxtest.RunWithQueryExecModes(context.Background(), t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + pgxtest.SkipCockroachDB(t, conn, "Server does support range types") + + tx, err := conn.Begin(ctx) + require.NoError(t, err) + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, "create type examplefloatrange as range (subtype=float8, subtype_diff=float8mi, multirange_type_name=examplefloatmultirange)") + require.NoError(t, err) + + // Register types + newRangeType, err := conn.LoadType(ctx, "examplefloatrange") + require.NoError(t, err) + conn.TypeMap().RegisterType(newRangeType) + conn.TypeMap().RegisterDefaultPgType(pgtype.Range[float64]{}, "examplefloatrange") + + newMultiRangeType, err := conn.LoadType(ctx, "examplefloatmultirange") + require.NoError(t, err) + conn.TypeMap().RegisterType(newMultiRangeType) + conn.TypeMap().RegisterDefaultPgType(pgtype.Multirange[pgtype.Range[float64]]{}, "examplefloatmultirange") + + // Test range type + var inputRangeType = pgtype.Range[float64]{ + Lower: 1.0, + Upper: 2.0, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Valid: true, + } + var outputRangeType pgtype.Range[float64] + err = tx.QueryRow(ctx, "SELECT $1::examplefloatrange", inputRangeType).Scan(&outputRangeType) + require.NoError(t, err) + require.Equal(t, inputRangeType, outputRangeType) + + // Test multi range type + var inputMultiRangeType = pgtype.Multirange[pgtype.Range[float64]]{ + { + Lower: 1.0, + Upper: 2.0, + LowerType: pgtype.Inclusive, + UpperType: pgtype.Inclusive, + Valid: true, + }, + { + Lower: 3.0, + Upper: 4.0, + LowerType: pgtype.Exclusive, + UpperType: pgtype.Exclusive, + Valid: true, + }, + } + var outputMultiRangeType pgtype.Multirange[pgtype.Range[float64]] + err = tx.QueryRow(ctx, "SELECT $1::examplefloatmultirange", inputMultiRangeType).Scan(&outputMultiRangeType) + require.NoError(t, err) + require.Equal(t, inputMultiRangeType, outputMultiRangeType) + }) +} + func TestStmtCacheInvalidationConn(t *testing.T) { ctx := context.Background()