diff --git a/adapters/kmysql/sql_adapter.go b/adapters/kmysql/sql_adapter.go index 532767d..c040626 100644 --- a/adapters/kmysql/sql_adapter.go +++ b/adapters/kmysql/sql_adapter.go @@ -3,6 +3,9 @@ package kmysql import ( "context" "database/sql" + "strconv" + "strings" + "unicode" "github.com/vingarcia/ksql" ) @@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter // QueryContext implements the DBAdapter interface func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) { - return s.DB.QueryContext(ctx, query, args...) + rows, err := s.DB.QueryContext(ctx, query, args...) + return SQLRows{rows}, err } // BeginTx implements the Tx interface @@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{ // QueryContext implements the Tx interface func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) { - return s.Tx.QueryContext(ctx, query, args...) + rows, err := s.Tx.QueryContext(ctx, query, args...) + return SQLRows{rows}, err } // Rollback implements the Tx interface @@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error { } var _ ksql.Tx = SQLTx{} + +// SQLRows implements the ksql.Rows interface and is used to help +// the SQLAdapter to implement the ksql.DBAdapter interface. +type SQLRows struct { + *sql.Rows +} + +var _ ksql.Rows = SQLRows{} + +// Scan implements the ksql.Rows interface +func (p SQLRows) Scan(args ...interface{}) error { + err := p.Rows.Scan(args...) + if err != nil { + // Since this is the error flow we decided it would be ok + // to spend a little bit more time parsing this error in order + // to produce better error messages. + // + // If the parsing fails we just return the error unchanged. + const scanErrPrefix = "sql: Scan error on column index " + var errMsg = err.Error() + if strings.HasPrefix(errMsg, scanErrPrefix) { + i := len(scanErrPrefix) + for unicode.IsDigit(rune(errMsg[i])) { + i++ + } + colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i]) + if convErr == nil { + return ksql.ScanArgError{ + ColumnIndex: colIndex, + Err: err, + } + } + } + } + + return err +} diff --git a/adapters/kpgx/pgx_adapter.go b/adapters/kpgx/pgx_adapter.go index 93a7c40..69c75a1 100644 --- a/adapters/kpgx/pgx_adapter.go +++ b/adapters/kpgx/pgx_adapter.go @@ -96,14 +96,27 @@ func (p PGXTx) Commit(ctx context.Context) error { var _ ksql.Tx = PGXTx{} -// PGXRows implements the Rows interface and is used to help -// the PGXAdapter to implement the DBAdapter interface. +// PGXRows implements the ksql.Rows interface and is used to help +// the PGXAdapter to implement the ksql.DBAdapter interface. type PGXRows struct { pgx.Rows } var _ ksql.Rows = PGXRows{} +// Scan implements the ksql.Rows interface +func (p PGXRows) Scan(args ...interface{}) error { + err := p.Rows.Scan(args...) + if scanErr, ok := err.(pgx.ScanArgError); ok { + return ksql.ScanArgError{ + Err: scanErr.Err, + ColumnIndex: scanErr.ColumnIndex, + } + } + + return err +} + // Columns implements the Rows interface func (p PGXRows) Columns() ([]string, error) { var names []string diff --git a/adapters/ksqlite3/sql_adapter.go b/adapters/ksqlite3/sql_adapter.go index 7a06798..5939ef5 100644 --- a/adapters/ksqlite3/sql_adapter.go +++ b/adapters/ksqlite3/sql_adapter.go @@ -3,6 +3,9 @@ package ksqlite3 import ( "context" "database/sql" + "strconv" + "strings" + "unicode" "github.com/vingarcia/ksql" ) @@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter // QueryContext implements the DBAdapter interface func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) { - return s.DB.QueryContext(ctx, query, args...) + rows, err := s.DB.QueryContext(ctx, query, args...) + return SQLRows{rows}, err } // BeginTx implements the Tx interface @@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{ // QueryContext implements the Tx interface func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) { - return s.Tx.QueryContext(ctx, query, args...) + rows, err := s.Tx.QueryContext(ctx, query, args...) + return SQLRows{rows}, err } // Rollback implements the Tx interface @@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error { } var _ ksql.Tx = SQLTx{} + +// SQLRows implements the ksql.Rows interface and is used to help +// the SQLAdapter to implement the ksql.DBAdapter interface. +type SQLRows struct { + *sql.Rows +} + +var _ ksql.Rows = SQLRows{} + +// Scan implements the ksql.Rows interface +func (p SQLRows) Scan(args ...interface{}) error { + err := p.Rows.Scan(args...) + if err != nil { + // Since this is the error flow we decided it would be ok + // to spend a little bit more time parsing this error in order + // to produce better error messages. + // + // If the parsing fails we just return the error unchanged. + const scanErrPrefix = "sql: Scan error on column index " + var errMsg = err.Error() + if strings.HasPrefix(errMsg, scanErrPrefix) { + i := len(scanErrPrefix) + for unicode.IsDigit(rune(errMsg[i])) { + i++ + } + colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i]) + if convErr == nil { + return ksql.ScanArgError{ + ColumnIndex: colIndex, + Err: err, + } + } + } + } + + return err +} diff --git a/adapters/ksqlserver/sql_adapter.go b/adapters/ksqlserver/sql_adapter.go index 3bf66df..6e40b65 100644 --- a/adapters/ksqlserver/sql_adapter.go +++ b/adapters/ksqlserver/sql_adapter.go @@ -3,6 +3,9 @@ package ksqlserver import ( "context" "database/sql" + "strconv" + "strings" + "unicode" "github.com/vingarcia/ksql" ) @@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter // QueryContext implements the DBAdapter interface func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) { - return s.DB.QueryContext(ctx, query, args...) + rows, err := s.DB.QueryContext(ctx, query, args...) + return SQLRows{rows}, err } // BeginTx implements the Tx interface @@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{ // QueryContext implements the Tx interface func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) { - return s.Tx.QueryContext(ctx, query, args...) + rows, err := s.Tx.QueryContext(ctx, query, args...) + return SQLRows{rows}, err } // Rollback implements the Tx interface @@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error { } var _ ksql.Tx = SQLTx{} + +// SQLRows implements the ksql.Rows interface and is used to help +// the SQLAdapter to implement the ksql.DBAdapter interface. +type SQLRows struct { + *sql.Rows +} + +var _ ksql.Rows = SQLRows{} + +// Scan implements the ksql.Rows interface +func (p SQLRows) Scan(args ...interface{}) error { + err := p.Rows.Scan(args...) + if err != nil { + // Since this is the error flow we decided it would be ok + // to spend a little bit more time parsing this error in order + // to produce better error messages. + // + // If the parsing fails we just return the error unchanged. + const scanErrPrefix = "sql: Scan error on column index " + var errMsg = err.Error() + if strings.HasPrefix(errMsg, scanErrPrefix) { + i := len(scanErrPrefix) + for unicode.IsDigit(rune(errMsg[i])) { + i++ + } + colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i]) + if convErr == nil { + return ksql.ScanArgError{ + ColumnIndex: colIndex, + Err: err, + } + } + } + } + + return err +} diff --git a/examples/go.mod b/examples/go.mod index 5f7729f..d75ec21 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -12,7 +12,7 @@ replace ( require ( github.com/golang/mock v1.6.0 - github.com/stretchr/testify v1.7.0 + github.com/stretchr/testify v1.8.0 github.com/vingarcia/ksql v1.4.7 github.com/vingarcia/ksql/adapters/kmysql v0.0.0-00010101000000-000000000000 github.com/vingarcia/ksql/adapters/kpgx v0.0.0-00010101000000-000000000000 diff --git a/examples/go.sum b/examples/go.sum index 94f0952..98a15fb 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -230,12 +230,15 @@ github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/y github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM= @@ -371,8 +374,9 @@ gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo= gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/internal/structs/structs.go b/internal/structs/structs.go index fe62adb..e21bd56 100644 --- a/internal/structs/structs.go +++ b/internal/structs/structs.go @@ -23,10 +23,11 @@ type StructInfo struct { // information regarding a specific field // of a struct. type FieldInfo struct { - Name string - Index int - Valid bool - Modifier ksqlmodifiers.AttrModifier + AttrName string + ColumnName string + Index int + Valid bool + Modifier ksqlmodifiers.AttrModifier } // ByIndex returns either the *FieldInfo of a valid @@ -52,12 +53,12 @@ func (s StructInfo) ByName(name string) *FieldInfo { func (s StructInfo) add(field FieldInfo) { field.Valid = true s.byIndex[field.Index] = &field - s.byName[field.Name] = &field + s.byName[field.ColumnName] = &field // Make sure to save a lowercased version because // some databases will set these keys to lowercase. - if _, found := s.byName[strings.ToLower(field.Name)]; !found { - s.byName[strings.ToLower(field.Name)] = &field + if _, found := s.byName[strings.ToLower(field.ColumnName)]; !found { + s.byName[strings.ToLower(field.ColumnName)] = &field } } @@ -143,7 +144,7 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) { field = field.Elem() } - m[fieldInfo.Name] = field.Interface() + m[fieldInfo.ColumnName] = field.Interface() } return m, nil @@ -246,6 +247,7 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) { return StructInfo{}, fmt.Errorf("all fields using the ksql tags must be exported, but %v is unexported", t) } + attrName := t.Field(i).Name name := t.Field(i).Tag.Get("ksql") if name == "" { continue @@ -269,9 +271,10 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) { } info.add(FieldInfo{ - Name: name, - Index: i, - Modifier: modifier, + AttrName: attrName, + ColumnName: name, + Index: i, + Modifier: modifier, }) } @@ -289,8 +292,9 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) { } info.add(FieldInfo{ - Name: name, - Index: i, + AttrName: t.Field(i).Name, + ColumnName: name, + Index: i, }) } diff --git a/kbuilder/insert.go b/kbuilder/insert.go index f5a8e21..eb34300 100644 --- a/kbuilder/insert.go +++ b/kbuilder/insert.go @@ -81,7 +81,7 @@ func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []inte b.WriteString(" (") var escapedNames []string for i := 0; i < info.NumFields(); i++ { - name := info.ByIndex(i).Name + name := info.ByIndex(i).ColumnName escapedNames = append(escapedNames, dialect.Escape(name)) } b.WriteString(strings.Join(escapedNames, ", ")) diff --git a/kbuilder/query.go b/kbuilder/query.go index c88e01c..a36ab26 100644 --- a/kbuilder/query.go +++ b/kbuilder/query.go @@ -215,7 +215,7 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { var escapedNames []string for i := 0; i < info.NumFields(); i++ { - name := info.ByIndex(i).Name + name := info.ByIndex(i).ColumnName escapedNames = append(escapedNames, dialect.Escape(name)) } diff --git a/ksql.go b/ksql.go index 114357b..344c2d6 100644 --- a/ksql.go +++ b/ksql.go @@ -67,6 +67,34 @@ type Rows interface { Columns() ([]string, error) } +// ScanArgError is a type of error that is expected to be returned +// from the Scan() method of the Rows interface. +// +// It should be returned when there is an error scanning one of the input +// values. +// +// This is necessary in order to allow KSQL to produce a better and more +// readable error message when this type of error occur. +type ScanArgError struct { + ColumnIndex int + Err error +} + +// Error implements the error interface. +func (s ScanArgError) Error() string { + return fmt.Sprintf( + "error scanning input attribute with index %d: %s", + s.ColumnIndex, s.Err.Error(), + ) +} + +func (s ScanArgError) ErrorWithStructNames(structName string, colName string) error { + return fmt.Errorf( + "error scanning %s.%s: %s", + structName, colName, s.Err.Error(), + ) +} + // Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function type Tx interface { DBAdapter @@ -1002,27 +1030,34 @@ func scanRowsFromType( return err } + var attrNames []string var scanArgs []interface{} if info.IsNestedStruct { // This version is positional meaning that it expect the arguments // to follow an specific order. It's ok because we don't allow the // user to type the "SELECT" part of the query for nested structs. - scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info) + attrNames, scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info) if err != nil { return err } } else { - names, err := rows.Columns() + colNames, err := rows.Columns() if err != nil { return fmt.Errorf("KSQL: unable to read columns from returned rows: %w", err) } // Since this version uses the names of the columns it works // with any order of attributes/columns. - scanArgs = getScanArgsFromNames(ctx, dialect, names, v, info) + attrNames, scanArgs = getScanArgsFromNames(ctx, dialect, colNames, v, info) } err = rows.Scan(scanArgs...) if err != nil { + if scanErr, ok := err.(ScanArgError); ok { + return fmt.Errorf( + "KSQL: scan error: %w", + scanErr.ErrorWithStructNames(t.Name(), attrNames[scanErr.ColumnIndex]), + ) + } return fmt.Errorf("KSQL: scan error: %w", err) } return nil @@ -1035,8 +1070,7 @@ func getScanArgsForNestedStructs( t reflect.Type, v reflect.Value, info structs.StructInfo, -) ([]interface{}, error) { - scanArgs := []interface{}{} +) (attrNames []string, scanArgs []interface{}, _ error) { for i := 0; i < v.NumField(); i++ { if !info.ByIndex(i).Valid { continue @@ -1045,7 +1079,7 @@ func getScanArgsForNestedStructs( // TODO(vingarcia00): Handle case where type is pointer nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type) if err != nil { - return nil, err + return nil, nil, err } nestedStructValue := v.Field(i) @@ -1055,34 +1089,36 @@ func getScanArgsForNestedStructs( continue } - valueScanner := nopScannerValue - if fieldInfo.Valid { - valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface() - - if fieldInfo.Modifier.Scan != nil { - valueScanner = &modifiers.AttrScanWrapper{ - Ctx: ctx, - AttrPtr: valueScanner, - ScanFn: fieldInfo.Modifier.Scan, - OpInfo: ksqlmodifiers.OpInfo{ - DriverName: dialect.DriverName(), - // We will not differentiate between Query, QueryOne and QueryChunks - // if we did this could lead users to make very strange modifiers - Method: "Query", - }, - } + valueScanner := nestedStructValue.Field(fieldInfo.Index).Addr().Interface() + if fieldInfo.Modifier.Scan != nil { + valueScanner = &modifiers.AttrScanWrapper{ + Ctx: ctx, + AttrPtr: valueScanner, + ScanFn: fieldInfo.Modifier.Scan, + OpInfo: ksqlmodifiers.OpInfo{ + DriverName: dialect.DriverName(), + // We will not differentiate between Query, QueryOne and QueryChunks + // if we did this could lead users to make very strange modifiers + Method: "Query", + }, } } scanArgs = append(scanArgs, valueScanner) + attrNames = append(attrNames, info.ByIndex(i).AttrName+"."+fieldInfo.AttrName) } } - return scanArgs, nil + return attrNames, scanArgs, nil } -func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} { - scanArgs := []interface{}{} +func getScanArgsFromNames( + ctx context.Context, + dialect Dialect, + names []string, + v reflect.Value, + info structs.StructInfo, +) (attrNames []string, scanArgs []interface{}) { for _, name := range names { fieldInfo := info.ByName(name) @@ -1105,9 +1141,10 @@ func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, } scanArgs = append(scanArgs, valueScanner) + attrNames = append(attrNames, fieldInfo.AttrName) } - return scanArgs + return attrNames, scanArgs } func buildDeleteQuery( @@ -1185,7 +1222,7 @@ func buildSelectQueryForPlainStructs( continue } - fields = append(fields, dialect.Escape(fieldInfo.Name)) + fields = append(fields, dialect.Escape(fieldInfo.ColumnName)) } return "SELECT " + strings.Join(fields, ", ") + " " @@ -1203,7 +1240,7 @@ func buildSelectQueryForNestedStructs( continue } - nestedStructName := nestedStructInfo.Name + nestedStructName := nestedStructInfo.ColumnName nestedStructType := structType.Field(i).Type if nestedStructType.Kind() != reflect.Struct { return "", fmt.Errorf( @@ -1225,7 +1262,7 @@ func buildSelectQueryForNestedStructs( fields = append( fields, - dialect.Escape(nestedStructName)+"."+dialect.Escape(fieldInfo.Name), + dialect.Escape(nestedStructName)+"."+dialect.Escape(fieldInfo.ColumnName), ) } } diff --git a/test_adapters.go b/test_adapters.go index 1649fa2..f00c6d6 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -6,10 +6,13 @@ import ( "encoding/json" "fmt" "io" + "reflect" + "strings" "testing" "time" "github.com/vingarcia/ksql/internal/modifiers" + "github.com/vingarcia/ksql/internal/structs" tt "github.com/vingarcia/ksql/internal/testtools" "github.com/vingarcia/ksql/ksqlmodifiers" "github.com/vingarcia/ksql/nullable" @@ -37,6 +40,8 @@ type address struct { Country string `json:"country"` } +var postsTable = NewTable("posts") + type post struct { ID int `ksql:"id"` UserID uint `ksql:"user_id"` @@ -1017,7 +1022,6 @@ func InsertTest( tt.AssertNoErr(t, err) tt.AssertNotEqual(t, permission.ID, 0) - fmt.Println("permID:", permission.ID) var untaggedPerm struct { ID uint `ksql:"id"` UserID int `ksql:"user_id"` @@ -3050,6 +3054,8 @@ func ScanRowsTest( connStr string, newDBAdapter func(t *testing.T) (DBAdapter, io.Closer), ) { + ctx := context.Background() + t.Run("ScanRows", func(t *testing.T) { t.Run("should scan users correctly", func(t *testing.T) { err := createTables(driver, connStr) @@ -3058,7 +3064,6 @@ func ScanRowsTest( } dialect := supportedDialects[driver] - ctx := context.TODO() db, closer := newDBAdapter(t) defer closer.Close() c := newTestDB(db, driver) @@ -3082,12 +3087,9 @@ func ScanRowsTest( t.Run("should ignore extra columns from query", func(t *testing.T) { err := createTables(driver, connStr) - if err != nil { - t.Fatal("could not create test table!, reason:", err.Error()) - } + tt.AssertNoErr(t, err) dialect := supportedDialects[driver] - ctx := context.TODO() db, closer := newDBAdapter(t) defer closer.Close() c := newTestDB(db, driver) @@ -3112,6 +3114,101 @@ func ScanRowsTest( tt.AssertEqual(t, u.Age, 22) }) + t.Run("should report scan errors", func(t *testing.T) { + type brokenUser struct { + ID uint `ksql:"id"` + + // The error will happen here, when scanning + // an integer into a attribute of type struct{}: + Age struct{} `ksql:"age"` + } + + type brokenNestedStruct struct { + User struct { + ID uint `ksql:"id"` + + // The error will happen here, when scanning + // an integer into a attribute of type struct: + Age struct{} `ksql:"age"` + } `tablename:"u"` + Post post `tablename:"p"` + } + + tests := []struct { + desc string + query string + scanTarget interface{} + expectErrToContain []string + }{ + { + desc: "with anonymous structs", + query: "FROM users WHERE name='User22'", + scanTarget: &struct { + ID uint `ksql:"id"` + + // The error will happen here, when scanning + // an integer into a attribute of type struct{}: + Age struct{} `ksql:"age"` + }{}, + expectErrToContain: []string{" .Age", "struct {}"}, + }, + { + desc: "with named structs", + query: "FROM users WHERE name='User22'", + scanTarget: &brokenUser{}, + expectErrToContain: []string{"brokenUser.Age", "struct {}"}, + }, + { + desc: "with anonymous nested structs", + query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'", + scanTarget: &struct { + User struct { + ID uint `ksql:"id"` + + // The error will happen here, when scanning + // an integer into a attribute of type struct: + Age struct{} `ksql:"age"` + } `tablename:"u"` + Post post `tablename:"p"` + }{}, + expectErrToContain: []string{".User.Age", "struct {}"}, + }, + { + desc: "with named nested structs", + query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'", + scanTarget: &brokenNestedStruct{}, + expectErrToContain: []string{"brokenNestedStruct.User.Age", "struct {}"}, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + err := createTables(driver, connStr) + tt.AssertNoErr(t, err) + + dialect := supportedDialects[driver] + db, closer := newDBAdapter(t) + defer closer.Close() + c := newTestDB(db, driver) + + u := user{Name: "User22", Age: 22} + _ = c.Insert(ctx, usersTable, &u) + _ = c.Insert(ctx, postsTable, &post{UserID: u.ID, Title: "FakeTitle"}) + + query := mustBuildSelectQuery(t, dialect, test.scanTarget, test.query) + + rows, err := db.QueryContext(ctx, query) + tt.AssertNoErr(t, err) + defer rows.Close() + + tt.AssertEqual(t, rows.Next(), true) + + err = scanRows(ctx, dialect, rows, test.scanTarget) + tt.AssertErrContains(t, err, test.expectErrToContain...) + }) + } + }) + t.Run("should report error for closed rows", func(t *testing.T) { err := createTables(driver, connStr) if err != nil { @@ -3119,7 +3216,6 @@ func ScanRowsTest( } dialect := supportedDialects[driver] - ctx := context.TODO() db, closer := newDBAdapter(t) defer closer.Close() @@ -3140,7 +3236,6 @@ func ScanRowsTest( } dialect := supportedDialects[driver] - ctx := context.TODO() db, closer := newDBAdapter(t) defer closer.Close() @@ -3160,7 +3255,6 @@ func ScanRowsTest( } dialect := supportedDialects[driver] - ctx := context.TODO() db, closer := newDBAdapter(t) defer closer.Close() @@ -3446,3 +3540,22 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results return results, nil } + +func mustBuildSelectQuery(t *testing.T, + dialect Dialect, + record interface{}, + query string, +) string { + if strings.HasPrefix(query, "SELECT") { + return query + } + + structType := reflect.TypeOf(record).Elem() + structInfo, err := structs.GetTagInfo(structType) + tt.AssertNoErr(t, err) + + selectPrefix, err := buildSelectQuery(dialect, structType, structInfo, selectQueryCache[dialect.DriverName()]) + tt.AssertNoErr(t, err) + + return selectPrefix + query +}