diff --git a/pgtype/testutil/testutil.go b/pgtype/testutil/testutil.go index 462549a7..3711381c 100644 --- a/pgtype/testutil/testutil.go +++ b/pgtype/testutil/testutil.go @@ -107,7 +107,7 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] conn := MustConnectPgx(t) defer MustCloseContext(t, conn) - ps, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) + _, err := conn.Prepare("test", fmt.Sprintf("select $1::%s", pgTypeName)) if err != nil { t.Fatal(err) } @@ -121,29 +121,43 @@ func TestPgxSuccessfulTranscodeEqFunc(t testing.TB, pgTypeName string, values [] } for i, v := range values { - for _, fc := range formats { - ps.FieldDescriptions[0].FormatCode = fc.formatCode - vEncoder := ForceEncoder(v, fc.formatCode) - if vEncoder == nil { - t.Logf("Skipping: %#v does not implement %v", v, fc.name) - continue - } - // Derefence value if it is a pointer - derefV := v - refVal := reflect.ValueOf(v) - if refVal.Kind() == reflect.Ptr { - derefV = refVal.Elem().Interface() - } + for _, paramFormat := range formats { + for _, resultFormat := range formats { + vEncoder := ForceEncoder(v, paramFormat.formatCode) + if vEncoder == nil { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for encoding", paramFormat.name, resultFormat.name, v, paramFormat.name) + continue + } + switch resultFormat.formatCode { + case pgx.TextFormatCode: + if _, ok := v.(pgtype.TextEncoder); !ok { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) + continue + } + case pgx.BinaryFormatCode: + if _, ok := v.(pgtype.BinaryEncoder); !ok { + t.Logf("Skipping Param %s Result %s: %#v does not implement %v for decoding", paramFormat.name, resultFormat.name, v, resultFormat.name) + continue + } + } - result := reflect.New(reflect.TypeOf(derefV)) + // Derefence value if it is a pointer + derefV := v + refVal := reflect.ValueOf(v) + if refVal.Kind() == reflect.Ptr { + derefV = refVal.Elem().Interface() + } - err := conn.QueryRow(context.Background(), "test", vEncoder).Scan(result.Interface()) - if err != nil { - t.Errorf("%v %d: %v", fc.name, i, err) - } + result := reflect.New(reflect.TypeOf(derefV)) - if !eqFunc(result.Elem().Interface(), derefV) { - t.Errorf("%v %d: expected %v, got %v", fc.name, i, derefV, result.Elem().Interface()) + err := conn.QueryRow(context.Background(), "test", pgx.QueryResultFormats{resultFormat.formatCode}, vEncoder).Scan(result.Interface()) + if err != nil { + t.Errorf("Param %s Result %s %d: %v", paramFormat.name, resultFormat.name, i, err) + } + + if !eqFunc(result.Elem().Interface(), derefV) { + t.Errorf("Param %s Result %s %d: expected %v, got %v", paramFormat.name, resultFormat.name, i, derefV, result.Elem().Interface()) + } } } }