Improve error messages for scan errors on all adapters

pull/32/head
Vinícius Garcia 2022-11-12 15:20:06 -03:00
parent d2c90f4e42
commit 5bfb5cd92a
11 changed files with 361 additions and 64 deletions

View File

@ -3,6 +3,9 @@ package kmysql
import (
"context"
"database/sql"
"strconv"
"strings"
"unicode"
"github.com/vingarcia/ksql"
)
@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter
// QueryContext implements the DBAdapter interface
func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
return s.DB.QueryContext(ctx, query, args...)
rows, err := s.DB.QueryContext(ctx, query, args...)
return SQLRows{rows}, err
}
// BeginTx implements the Tx interface
@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{
// QueryContext implements the Tx interface
func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
return s.Tx.QueryContext(ctx, query, args...)
rows, err := s.Tx.QueryContext(ctx, query, args...)
return SQLRows{rows}, err
}
// Rollback implements the Tx interface
@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error {
}
var _ ksql.Tx = SQLTx{}
// SQLRows implements the ksql.Rows interface and is used to help
// the SQLAdapter to implement the ksql.DBAdapter interface.
type SQLRows struct {
*sql.Rows
}
var _ ksql.Rows = SQLRows{}
// Scan implements the ksql.Rows interface
func (p SQLRows) Scan(args ...interface{}) error {
err := p.Rows.Scan(args...)
if err != nil {
// Since this is the error flow we decided it would be ok
// to spend a little bit more time parsing this error in order
// to produce better error messages.
//
// If the parsing fails we just return the error unchanged.
const scanErrPrefix = "sql: Scan error on column index "
var errMsg = err.Error()
if strings.HasPrefix(errMsg, scanErrPrefix) {
i := len(scanErrPrefix)
for unicode.IsDigit(rune(errMsg[i])) {
i++
}
colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i])
if convErr == nil {
return ksql.ScanArgError{
ColumnIndex: colIndex,
Err: err,
}
}
}
}
return err
}

View File

@ -96,14 +96,27 @@ func (p PGXTx) Commit(ctx context.Context) error {
var _ ksql.Tx = PGXTx{}
// PGXRows implements the Rows interface and is used to help
// the PGXAdapter to implement the DBAdapter interface.
// PGXRows implements the ksql.Rows interface and is used to help
// the PGXAdapter to implement the ksql.DBAdapter interface.
type PGXRows struct {
pgx.Rows
}
var _ ksql.Rows = PGXRows{}
// Scan implements the ksql.Rows interface
func (p PGXRows) Scan(args ...interface{}) error {
err := p.Rows.Scan(args...)
if scanErr, ok := err.(pgx.ScanArgError); ok {
return ksql.ScanArgError{
Err: scanErr.Err,
ColumnIndex: scanErr.ColumnIndex,
}
}
return err
}
// Columns implements the Rows interface
func (p PGXRows) Columns() ([]string, error) {
var names []string

View File

@ -3,6 +3,9 @@ package ksqlite3
import (
"context"
"database/sql"
"strconv"
"strings"
"unicode"
"github.com/vingarcia/ksql"
)
@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter
// QueryContext implements the DBAdapter interface
func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
return s.DB.QueryContext(ctx, query, args...)
rows, err := s.DB.QueryContext(ctx, query, args...)
return SQLRows{rows}, err
}
// BeginTx implements the Tx interface
@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{
// QueryContext implements the Tx interface
func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
return s.Tx.QueryContext(ctx, query, args...)
rows, err := s.Tx.QueryContext(ctx, query, args...)
return SQLRows{rows}, err
}
// Rollback implements the Tx interface
@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error {
}
var _ ksql.Tx = SQLTx{}
// SQLRows implements the ksql.Rows interface and is used to help
// the SQLAdapter to implement the ksql.DBAdapter interface.
type SQLRows struct {
*sql.Rows
}
var _ ksql.Rows = SQLRows{}
// Scan implements the ksql.Rows interface
func (p SQLRows) Scan(args ...interface{}) error {
err := p.Rows.Scan(args...)
if err != nil {
// Since this is the error flow we decided it would be ok
// to spend a little bit more time parsing this error in order
// to produce better error messages.
//
// If the parsing fails we just return the error unchanged.
const scanErrPrefix = "sql: Scan error on column index "
var errMsg = err.Error()
if strings.HasPrefix(errMsg, scanErrPrefix) {
i := len(scanErrPrefix)
for unicode.IsDigit(rune(errMsg[i])) {
i++
}
colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i])
if convErr == nil {
return ksql.ScanArgError{
ColumnIndex: colIndex,
Err: err,
}
}
}
}
return err
}

View File

@ -3,6 +3,9 @@ package ksqlserver
import (
"context"
"database/sql"
"strconv"
"strings"
"unicode"
"github.com/vingarcia/ksql"
)
@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter
// QueryContext implements the DBAdapter interface
func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
return s.DB.QueryContext(ctx, query, args...)
rows, err := s.DB.QueryContext(ctx, query, args...)
return SQLRows{rows}, err
}
// BeginTx implements the Tx interface
@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{
// QueryContext implements the Tx interface
func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
return s.Tx.QueryContext(ctx, query, args...)
rows, err := s.Tx.QueryContext(ctx, query, args...)
return SQLRows{rows}, err
}
// Rollback implements the Tx interface
@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error {
}
var _ ksql.Tx = SQLTx{}
// SQLRows implements the ksql.Rows interface and is used to help
// the SQLAdapter to implement the ksql.DBAdapter interface.
type SQLRows struct {
*sql.Rows
}
var _ ksql.Rows = SQLRows{}
// Scan implements the ksql.Rows interface
func (p SQLRows) Scan(args ...interface{}) error {
err := p.Rows.Scan(args...)
if err != nil {
// Since this is the error flow we decided it would be ok
// to spend a little bit more time parsing this error in order
// to produce better error messages.
//
// If the parsing fails we just return the error unchanged.
const scanErrPrefix = "sql: Scan error on column index "
var errMsg = err.Error()
if strings.HasPrefix(errMsg, scanErrPrefix) {
i := len(scanErrPrefix)
for unicode.IsDigit(rune(errMsg[i])) {
i++
}
colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i])
if convErr == nil {
return ksql.ScanArgError{
ColumnIndex: colIndex,
Err: err,
}
}
}
}
return err
}

View File

@ -12,7 +12,7 @@ replace (
require (
github.com/golang/mock v1.6.0
github.com/stretchr/testify v1.7.0
github.com/stretchr/testify v1.8.0
github.com/vingarcia/ksql v1.4.7
github.com/vingarcia/ksql/adapters/kmysql v0.0.0-00010101000000-000000000000
github.com/vingarcia/ksql/adapters/kpgx v0.0.0-00010101000000-000000000000

View File

@ -230,12 +230,15 @@ github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/y
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM=
@ -371,8 +374,9 @@ gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View File

@ -23,10 +23,11 @@ type StructInfo struct {
// information regarding a specific field
// of a struct.
type FieldInfo struct {
Name string
Index int
Valid bool
Modifier ksqlmodifiers.AttrModifier
AttrName string
ColumnName string
Index int
Valid bool
Modifier ksqlmodifiers.AttrModifier
}
// ByIndex returns either the *FieldInfo of a valid
@ -52,12 +53,12 @@ func (s StructInfo) ByName(name string) *FieldInfo {
func (s StructInfo) add(field FieldInfo) {
field.Valid = true
s.byIndex[field.Index] = &field
s.byName[field.Name] = &field
s.byName[field.ColumnName] = &field
// Make sure to save a lowercased version because
// some databases will set these keys to lowercase.
if _, found := s.byName[strings.ToLower(field.Name)]; !found {
s.byName[strings.ToLower(field.Name)] = &field
if _, found := s.byName[strings.ToLower(field.ColumnName)]; !found {
s.byName[strings.ToLower(field.ColumnName)] = &field
}
}
@ -143,7 +144,7 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
field = field.Elem()
}
m[fieldInfo.Name] = field.Interface()
m[fieldInfo.ColumnName] = field.Interface()
}
return m, nil
@ -246,6 +247,7 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) {
return StructInfo{}, fmt.Errorf("all fields using the ksql tags must be exported, but %v is unexported", t)
}
attrName := t.Field(i).Name
name := t.Field(i).Tag.Get("ksql")
if name == "" {
continue
@ -269,9 +271,10 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) {
}
info.add(FieldInfo{
Name: name,
Index: i,
Modifier: modifier,
AttrName: attrName,
ColumnName: name,
Index: i,
Modifier: modifier,
})
}
@ -289,8 +292,9 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) {
}
info.add(FieldInfo{
Name: name,
Index: i,
AttrName: t.Field(i).Name,
ColumnName: name,
Index: i,
})
}

View File

@ -81,7 +81,7 @@ func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []inte
b.WriteString(" (")
var escapedNames []string
for i := 0; i < info.NumFields(); i++ {
name := info.ByIndex(i).Name
name := info.ByIndex(i).ColumnName
escapedNames = append(escapedNames, dialect.Escape(name))
}
b.WriteString(strings.Join(escapedNames, ", "))

View File

@ -215,7 +215,7 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
var escapedNames []string
for i := 0; i < info.NumFields(); i++ {
name := info.ByIndex(i).Name
name := info.ByIndex(i).ColumnName
escapedNames = append(escapedNames, dialect.Escape(name))
}

95
ksql.go
View File

@ -67,6 +67,34 @@ type Rows interface {
Columns() ([]string, error)
}
// ScanArgError is a type of error that is expected to be returned
// from the Scan() method of the Rows interface.
//
// It should be returned when there is an error scanning one of the input
// values.
//
// This is necessary in order to allow KSQL to produce a better and more
// readable error message when this type of error occur.
type ScanArgError struct {
ColumnIndex int
Err error
}
// Error implements the error interface.
func (s ScanArgError) Error() string {
return fmt.Sprintf(
"error scanning input attribute with index %d: %s",
s.ColumnIndex, s.Err.Error(),
)
}
func (s ScanArgError) ErrorWithStructNames(structName string, colName string) error {
return fmt.Errorf(
"error scanning %s.%s: %s",
structName, colName, s.Err.Error(),
)
}
// Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function
type Tx interface {
DBAdapter
@ -1002,27 +1030,34 @@ func scanRowsFromType(
return err
}
var attrNames []string
var scanArgs []interface{}
if info.IsNestedStruct {
// This version is positional meaning that it expect the arguments
// to follow an specific order. It's ok because we don't allow the
// user to type the "SELECT" part of the query for nested structs.
scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info)
attrNames, scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info)
if err != nil {
return err
}
} else {
names, err := rows.Columns()
colNames, err := rows.Columns()
if err != nil {
return fmt.Errorf("KSQL: unable to read columns from returned rows: %w", err)
}
// Since this version uses the names of the columns it works
// with any order of attributes/columns.
scanArgs = getScanArgsFromNames(ctx, dialect, names, v, info)
attrNames, scanArgs = getScanArgsFromNames(ctx, dialect, colNames, v, info)
}
err = rows.Scan(scanArgs...)
if err != nil {
if scanErr, ok := err.(ScanArgError); ok {
return fmt.Errorf(
"KSQL: scan error: %w",
scanErr.ErrorWithStructNames(t.Name(), attrNames[scanErr.ColumnIndex]),
)
}
return fmt.Errorf("KSQL: scan error: %w", err)
}
return nil
@ -1035,8 +1070,7 @@ func getScanArgsForNestedStructs(
t reflect.Type,
v reflect.Value,
info structs.StructInfo,
) ([]interface{}, error) {
scanArgs := []interface{}{}
) (attrNames []string, scanArgs []interface{}, _ error) {
for i := 0; i < v.NumField(); i++ {
if !info.ByIndex(i).Valid {
continue
@ -1045,7 +1079,7 @@ func getScanArgsForNestedStructs(
// TODO(vingarcia00): Handle case where type is pointer
nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type)
if err != nil {
return nil, err
return nil, nil, err
}
nestedStructValue := v.Field(i)
@ -1055,34 +1089,36 @@ func getScanArgsForNestedStructs(
continue
}
valueScanner := nopScannerValue
if fieldInfo.Valid {
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.Modifier.Scan != nil {
valueScanner = &modifiers.AttrScanWrapper{
Ctx: ctx,
AttrPtr: valueScanner,
ScanFn: fieldInfo.Modifier.Scan,
OpInfo: ksqlmodifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers
Method: "Query",
},
}
valueScanner := nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.Modifier.Scan != nil {
valueScanner = &modifiers.AttrScanWrapper{
Ctx: ctx,
AttrPtr: valueScanner,
ScanFn: fieldInfo.Modifier.Scan,
OpInfo: ksqlmodifiers.OpInfo{
DriverName: dialect.DriverName(),
// We will not differentiate between Query, QueryOne and QueryChunks
// if we did this could lead users to make very strange modifiers
Method: "Query",
},
}
}
scanArgs = append(scanArgs, valueScanner)
attrNames = append(attrNames, info.ByIndex(i).AttrName+"."+fieldInfo.AttrName)
}
}
return scanArgs, nil
return attrNames, scanArgs, nil
}
func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
scanArgs := []interface{}{}
func getScanArgsFromNames(
ctx context.Context,
dialect Dialect,
names []string,
v reflect.Value,
info structs.StructInfo,
) (attrNames []string, scanArgs []interface{}) {
for _, name := range names {
fieldInfo := info.ByName(name)
@ -1105,9 +1141,10 @@ func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string,
}
scanArgs = append(scanArgs, valueScanner)
attrNames = append(attrNames, fieldInfo.AttrName)
}
return scanArgs
return attrNames, scanArgs
}
func buildDeleteQuery(
@ -1185,7 +1222,7 @@ func buildSelectQueryForPlainStructs(
continue
}
fields = append(fields, dialect.Escape(fieldInfo.Name))
fields = append(fields, dialect.Escape(fieldInfo.ColumnName))
}
return "SELECT " + strings.Join(fields, ", ") + " "
@ -1203,7 +1240,7 @@ func buildSelectQueryForNestedStructs(
continue
}
nestedStructName := nestedStructInfo.Name
nestedStructName := nestedStructInfo.ColumnName
nestedStructType := structType.Field(i).Type
if nestedStructType.Kind() != reflect.Struct {
return "", fmt.Errorf(
@ -1225,7 +1262,7 @@ func buildSelectQueryForNestedStructs(
fields = append(
fields,
dialect.Escape(nestedStructName)+"."+dialect.Escape(fieldInfo.Name),
dialect.Escape(nestedStructName)+"."+dialect.Escape(fieldInfo.ColumnName),
)
}
}

View File

@ -6,10 +6,13 @@ import (
"encoding/json"
"fmt"
"io"
"reflect"
"strings"
"testing"
"time"
"github.com/vingarcia/ksql/internal/modifiers"
"github.com/vingarcia/ksql/internal/structs"
tt "github.com/vingarcia/ksql/internal/testtools"
"github.com/vingarcia/ksql/ksqlmodifiers"
"github.com/vingarcia/ksql/nullable"
@ -37,6 +40,8 @@ type address struct {
Country string `json:"country"`
}
var postsTable = NewTable("posts")
type post struct {
ID int `ksql:"id"`
UserID uint `ksql:"user_id"`
@ -1017,7 +1022,6 @@ func InsertTest(
tt.AssertNoErr(t, err)
tt.AssertNotEqual(t, permission.ID, 0)
fmt.Println("permID:", permission.ID)
var untaggedPerm struct {
ID uint `ksql:"id"`
UserID int `ksql:"user_id"`
@ -3050,6 +3054,8 @@ func ScanRowsTest(
connStr string,
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
) {
ctx := context.Background()
t.Run("ScanRows", func(t *testing.T) {
t.Run("should scan users correctly", func(t *testing.T) {
err := createTables(driver, connStr)
@ -3058,7 +3064,6 @@ func ScanRowsTest(
}
dialect := supportedDialects[driver]
ctx := context.TODO()
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
@ -3082,12 +3087,9 @@ func ScanRowsTest(
t.Run("should ignore extra columns from query", func(t *testing.T) {
err := createTables(driver, connStr)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
tt.AssertNoErr(t, err)
dialect := supportedDialects[driver]
ctx := context.TODO()
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
@ -3112,6 +3114,101 @@ func ScanRowsTest(
tt.AssertEqual(t, u.Age, 22)
})
t.Run("should report scan errors", func(t *testing.T) {
type brokenUser struct {
ID uint `ksql:"id"`
// The error will happen here, when scanning
// an integer into a attribute of type struct{}:
Age struct{} `ksql:"age"`
}
type brokenNestedStruct struct {
User struct {
ID uint `ksql:"id"`
// The error will happen here, when scanning
// an integer into a attribute of type struct:
Age struct{} `ksql:"age"`
} `tablename:"u"`
Post post `tablename:"p"`
}
tests := []struct {
desc string
query string
scanTarget interface{}
expectErrToContain []string
}{
{
desc: "with anonymous structs",
query: "FROM users WHERE name='User22'",
scanTarget: &struct {
ID uint `ksql:"id"`
// The error will happen here, when scanning
// an integer into a attribute of type struct{}:
Age struct{} `ksql:"age"`
}{},
expectErrToContain: []string{" .Age", "struct {}"},
},
{
desc: "with named structs",
query: "FROM users WHERE name='User22'",
scanTarget: &brokenUser{},
expectErrToContain: []string{"brokenUser.Age", "struct {}"},
},
{
desc: "with anonymous nested structs",
query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'",
scanTarget: &struct {
User struct {
ID uint `ksql:"id"`
// The error will happen here, when scanning
// an integer into a attribute of type struct:
Age struct{} `ksql:"age"`
} `tablename:"u"`
Post post `tablename:"p"`
}{},
expectErrToContain: []string{".User.Age", "struct {}"},
},
{
desc: "with named nested structs",
query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'",
scanTarget: &brokenNestedStruct{},
expectErrToContain: []string{"brokenNestedStruct.User.Age", "struct {}"},
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
err := createTables(driver, connStr)
tt.AssertNoErr(t, err)
dialect := supportedDialects[driver]
db, closer := newDBAdapter(t)
defer closer.Close()
c := newTestDB(db, driver)
u := user{Name: "User22", Age: 22}
_ = c.Insert(ctx, usersTable, &u)
_ = c.Insert(ctx, postsTable, &post{UserID: u.ID, Title: "FakeTitle"})
query := mustBuildSelectQuery(t, dialect, test.scanTarget, test.query)
rows, err := db.QueryContext(ctx, query)
tt.AssertNoErr(t, err)
defer rows.Close()
tt.AssertEqual(t, rows.Next(), true)
err = scanRows(ctx, dialect, rows, test.scanTarget)
tt.AssertErrContains(t, err, test.expectErrToContain...)
})
}
})
t.Run("should report error for closed rows", func(t *testing.T) {
err := createTables(driver, connStr)
if err != nil {
@ -3119,7 +3216,6 @@ func ScanRowsTest(
}
dialect := supportedDialects[driver]
ctx := context.TODO()
db, closer := newDBAdapter(t)
defer closer.Close()
@ -3140,7 +3236,6 @@ func ScanRowsTest(
}
dialect := supportedDialects[driver]
ctx := context.TODO()
db, closer := newDBAdapter(t)
defer closer.Close()
@ -3160,7 +3255,6 @@ func ScanRowsTest(
}
dialect := supportedDialects[driver]
ctx := context.TODO()
db, closer := newDBAdapter(t)
defer closer.Close()
@ -3446,3 +3540,22 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
return results, nil
}
func mustBuildSelectQuery(t *testing.T,
dialect Dialect,
record interface{},
query string,
) string {
if strings.HasPrefix(query, "SELECT") {
return query
}
structType := reflect.TypeOf(record).Elem()
structInfo, err := structs.GetTagInfo(structType)
tt.AssertNoErr(t, err)
selectPrefix, err := buildSelectQuery(dialect, structType, structInfo, selectQueryCache[dialect.DriverName()])
tt.AssertNoErr(t, err)
return selectPrefix + query
}