diff --git a/Rakefile b/Rakefile index 275755bd..7076d2a0 100644 --- a/Rakefile +++ b/Rakefile @@ -7,4 +7,4 @@ rule '.go' => '.go.erb' do |task| end desc "Generate code" -task generate: ["pgtype/int.go"] +task generate: ["pgtype/int.go", "pgtype/int_test.go"] diff --git a/pgtype/int2_test.go b/pgtype/int_test.go similarity index 56% rename from pgtype/int2_test.go rename to pgtype/int_test.go index f5bdac89..3ba1306b 100644 --- a/pgtype/int2_test.go +++ b/pgtype/int_test.go @@ -1,65 +1,13 @@ +// Do not edit. Generated from pgtype/int_test.go.erb package pgtype_test import ( - "context" - "fmt" "math" - "reflect" "testing" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" - "github.com/jackc/pgx/v5/pgtype/testutil" ) -type PgxTranscodeTestCase struct { - src interface{} - dst interface{} - test func(interface{}) bool -} - -func isExpectedEq(a interface{}) func(interface{}) bool { - return func(v interface{}) bool { - return a == v - } -} - -func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) { - conn := testutil.MustConnectPgx(t) - defer testutil.MustCloseContext(t, conn) - - _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) - if err != nil { - t.Fatal(err) - } - - formats := []struct { - name string - code int16 - }{ - {name: "TextFormat", code: pgx.TextFormatCode}, - {name: "BinaryFormat", code: pgx.BinaryFormatCode}, - } - - for i, tt := range tests { - for _, format := range formats { - err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{format.code}, tt.src).Scan(tt.dst) - if err != nil { - t.Errorf("%s %d: %v", format.name, i, err) - } - - dst := reflect.ValueOf(tt.dst) - if dst.Kind() == reflect.Ptr { - dst = dst.Elem() - } - - if !tt.test(dst.Interface()) { - t.Errorf("%s %d: unexpected result for %v: %v", format.name, i, tt.src, dst.Interface()) - } - } - } -} - func TestInt2Codec(t *testing.T) { testPgxCodec(t, "int2", []PgxTranscodeTestCase{ {int8(1), new(int16), isExpectedEq(int16(1))}, diff --git a/pgtype/int_test.go.erb b/pgtype/int_test.go.erb new file mode 100644 index 00000000..be1f5358 --- /dev/null +++ b/pgtype/int_test.go.erb @@ -0,0 +1,45 @@ +package pgtype_test + +import ( + "math" + "testing" + + "github.com/jackc/pgx/v5/pgtype" +) + +<% [2].each do |pg_byte_size| %> +<% pg_bit_size = pg_byte_size * 8 %> +func TestInt<%= pg_byte_size %>Codec(t *testing.T) { + testPgxCodec(t, "int<%= pg_byte_size %>", []PgxTranscodeTestCase{ + {int8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint8(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint16(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint32(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint64(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {int(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {uint(1), new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true}, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {1, new(int8), isExpectedEq(int8(1))}, + {1, new(int16), isExpectedEq(int16(1))}, + {1, new(int32), isExpectedEq(int32(1))}, + {1, new(int64), isExpectedEq(int64(1))}, + {1, new(uint8), isExpectedEq(uint8(1))}, + {1, new(uint16), isExpectedEq(uint16(1))}, + {1, new(uint32), isExpectedEq(uint32(1))}, + {1, new(uint64), isExpectedEq(uint64(1))}, + {1, new(int), isExpectedEq(int(1))}, + {1, new(uint), isExpectedEq(uint(1))}, + {math.MinInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MinInt<%= pg_bit_size %>))}, + {-1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(-1))}, + {0, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(0))}, + {1, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(1))}, + {math.MaxInt<%= pg_bit_size %>, new(int<%= pg_bit_size %>), isExpectedEq(int<%= pg_bit_size %>(math.MaxInt<%= pg_bit_size %>))}, + {1, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{Int: 1, Valid: true})}, + {pgtype.Int<%= pg_byte_size %>{}, new(pgtype.Int<%= pg_byte_size %>), isExpectedEq(pgtype.Int<%= pg_byte_size %>{})}, + {nil, new(*int<%= pg_bit_size %>), isExpectedEq((*int<%= pg_bit_size %>)(nil))}, + }) +} +<% end %> diff --git a/pgtype/pgtype_test.go b/pgtype/pgtype_test.go index 17b8afe1..43c6c24b 100644 --- a/pgtype/pgtype_test.go +++ b/pgtype/pgtype_test.go @@ -2,13 +2,17 @@ package pgtype_test import ( "bytes" + "context" "database/sql" "errors" + "fmt" "net" + "reflect" "testing" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgtype" + "github.com/jackc/pgx/v5/pgtype/testutil" _ "github.com/jackc/pgx/v5/stdlib" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -299,3 +303,51 @@ func BenchmarkScanPlanScanInt4IntoGoInt32(b *testing.B) { } } } + +type PgxTranscodeTestCase struct { + src interface{} + dst interface{} + test func(interface{}) bool +} + +func isExpectedEq(a interface{}) func(interface{}) bool { + return func(v interface{}) bool { + return a == v + } +} + +func testPgxCodec(t testing.TB, pgTypeName string, tests []PgxTranscodeTestCase) { + conn := testutil.MustConnectPgx(t) + defer testutil.MustCloseContext(t, conn) + + _, err := conn.Prepare(context.Background(), "test", fmt.Sprintf("select $1::%s", pgTypeName)) + if err != nil { + t.Fatal(err) + } + + formats := []struct { + name string + code int16 + }{ + {name: "TextFormat", code: pgx.TextFormatCode}, + {name: "BinaryFormat", code: pgx.BinaryFormatCode}, + } + + for i, tt := range tests { + for _, format := range formats { + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{format.code}, tt.src).Scan(tt.dst) + if err != nil { + t.Errorf("%s %d: %v", format.name, i, err) + } + + dst := reflect.ValueOf(tt.dst) + if dst.Kind() == reflect.Ptr { + dst = dst.Elem() + } + + if !tt.test(dst.Interface()) { + t.Errorf("%s %d: unexpected result for %v: %v", format.name, i, tt.src, dst.Interface()) + } + } + } +}