mirror of https://github.com/VinGarcia/ksql.git
Improve error messages for scan errors on all adapters
parent
d2c90f4e42
commit
5bfb5cd92a
|
@ -3,6 +3,9 @@ package kmysql
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/vingarcia/ksql"
|
||||
)
|
||||
|
@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter
|
|||
|
||||
// QueryContext implements the DBAdapter interface
|
||||
func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
|
||||
return s.DB.QueryContext(ctx, query, args...)
|
||||
rows, err := s.DB.QueryContext(ctx, query, args...)
|
||||
return SQLRows{rows}, err
|
||||
}
|
||||
|
||||
// BeginTx implements the Tx interface
|
||||
|
@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{
|
|||
|
||||
// QueryContext implements the Tx interface
|
||||
func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
|
||||
return s.Tx.QueryContext(ctx, query, args...)
|
||||
rows, err := s.Tx.QueryContext(ctx, query, args...)
|
||||
return SQLRows{rows}, err
|
||||
}
|
||||
|
||||
// Rollback implements the Tx interface
|
||||
|
@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error {
|
|||
}
|
||||
|
||||
var _ ksql.Tx = SQLTx{}
|
||||
|
||||
// SQLRows implements the ksql.Rows interface and is used to help
|
||||
// the SQLAdapter to implement the ksql.DBAdapter interface.
|
||||
type SQLRows struct {
|
||||
*sql.Rows
|
||||
}
|
||||
|
||||
var _ ksql.Rows = SQLRows{}
|
||||
|
||||
// Scan implements the ksql.Rows interface
|
||||
func (p SQLRows) Scan(args ...interface{}) error {
|
||||
err := p.Rows.Scan(args...)
|
||||
if err != nil {
|
||||
// Since this is the error flow we decided it would be ok
|
||||
// to spend a little bit more time parsing this error in order
|
||||
// to produce better error messages.
|
||||
//
|
||||
// If the parsing fails we just return the error unchanged.
|
||||
const scanErrPrefix = "sql: Scan error on column index "
|
||||
var errMsg = err.Error()
|
||||
if strings.HasPrefix(errMsg, scanErrPrefix) {
|
||||
i := len(scanErrPrefix)
|
||||
for unicode.IsDigit(rune(errMsg[i])) {
|
||||
i++
|
||||
}
|
||||
colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i])
|
||||
if convErr == nil {
|
||||
return ksql.ScanArgError{
|
||||
ColumnIndex: colIndex,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -96,14 +96,27 @@ func (p PGXTx) Commit(ctx context.Context) error {
|
|||
|
||||
var _ ksql.Tx = PGXTx{}
|
||||
|
||||
// PGXRows implements the Rows interface and is used to help
|
||||
// the PGXAdapter to implement the DBAdapter interface.
|
||||
// PGXRows implements the ksql.Rows interface and is used to help
|
||||
// the PGXAdapter to implement the ksql.DBAdapter interface.
|
||||
type PGXRows struct {
|
||||
pgx.Rows
|
||||
}
|
||||
|
||||
var _ ksql.Rows = PGXRows{}
|
||||
|
||||
// Scan implements the ksql.Rows interface
|
||||
func (p PGXRows) Scan(args ...interface{}) error {
|
||||
err := p.Rows.Scan(args...)
|
||||
if scanErr, ok := err.(pgx.ScanArgError); ok {
|
||||
return ksql.ScanArgError{
|
||||
Err: scanErr.Err,
|
||||
ColumnIndex: scanErr.ColumnIndex,
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// Columns implements the Rows interface
|
||||
func (p PGXRows) Columns() ([]string, error) {
|
||||
var names []string
|
||||
|
|
|
@ -3,6 +3,9 @@ package ksqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/vingarcia/ksql"
|
||||
)
|
||||
|
@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter
|
|||
|
||||
// QueryContext implements the DBAdapter interface
|
||||
func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
|
||||
return s.DB.QueryContext(ctx, query, args...)
|
||||
rows, err := s.DB.QueryContext(ctx, query, args...)
|
||||
return SQLRows{rows}, err
|
||||
}
|
||||
|
||||
// BeginTx implements the Tx interface
|
||||
|
@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{
|
|||
|
||||
// QueryContext implements the Tx interface
|
||||
func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
|
||||
return s.Tx.QueryContext(ctx, query, args...)
|
||||
rows, err := s.Tx.QueryContext(ctx, query, args...)
|
||||
return SQLRows{rows}, err
|
||||
}
|
||||
|
||||
// Rollback implements the Tx interface
|
||||
|
@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error {
|
|||
}
|
||||
|
||||
var _ ksql.Tx = SQLTx{}
|
||||
|
||||
// SQLRows implements the ksql.Rows interface and is used to help
|
||||
// the SQLAdapter to implement the ksql.DBAdapter interface.
|
||||
type SQLRows struct {
|
||||
*sql.Rows
|
||||
}
|
||||
|
||||
var _ ksql.Rows = SQLRows{}
|
||||
|
||||
// Scan implements the ksql.Rows interface
|
||||
func (p SQLRows) Scan(args ...interface{}) error {
|
||||
err := p.Rows.Scan(args...)
|
||||
if err != nil {
|
||||
// Since this is the error flow we decided it would be ok
|
||||
// to spend a little bit more time parsing this error in order
|
||||
// to produce better error messages.
|
||||
//
|
||||
// If the parsing fails we just return the error unchanged.
|
||||
const scanErrPrefix = "sql: Scan error on column index "
|
||||
var errMsg = err.Error()
|
||||
if strings.HasPrefix(errMsg, scanErrPrefix) {
|
||||
i := len(scanErrPrefix)
|
||||
for unicode.IsDigit(rune(errMsg[i])) {
|
||||
i++
|
||||
}
|
||||
colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i])
|
||||
if convErr == nil {
|
||||
return ksql.ScanArgError{
|
||||
ColumnIndex: colIndex,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -3,6 +3,9 @@ package ksqlserver
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/vingarcia/ksql"
|
||||
)
|
||||
|
@ -29,7 +32,8 @@ func (s SQLAdapter) ExecContext(ctx context.Context, query string, args ...inter
|
|||
|
||||
// QueryContext implements the DBAdapter interface
|
||||
func (s SQLAdapter) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
|
||||
return s.DB.QueryContext(ctx, query, args...)
|
||||
rows, err := s.DB.QueryContext(ctx, query, args...)
|
||||
return SQLRows{rows}, err
|
||||
}
|
||||
|
||||
// BeginTx implements the Tx interface
|
||||
|
@ -56,7 +60,8 @@ func (s SQLTx) ExecContext(ctx context.Context, query string, args ...interface{
|
|||
|
||||
// QueryContext implements the Tx interface
|
||||
func (s SQLTx) QueryContext(ctx context.Context, query string, args ...interface{}) (ksql.Rows, error) {
|
||||
return s.Tx.QueryContext(ctx, query, args...)
|
||||
rows, err := s.Tx.QueryContext(ctx, query, args...)
|
||||
return SQLRows{rows}, err
|
||||
}
|
||||
|
||||
// Rollback implements the Tx interface
|
||||
|
@ -70,3 +75,40 @@ func (s SQLTx) Commit(ctx context.Context) error {
|
|||
}
|
||||
|
||||
var _ ksql.Tx = SQLTx{}
|
||||
|
||||
// SQLRows implements the ksql.Rows interface and is used to help
|
||||
// the SQLAdapter to implement the ksql.DBAdapter interface.
|
||||
type SQLRows struct {
|
||||
*sql.Rows
|
||||
}
|
||||
|
||||
var _ ksql.Rows = SQLRows{}
|
||||
|
||||
// Scan implements the ksql.Rows interface
|
||||
func (p SQLRows) Scan(args ...interface{}) error {
|
||||
err := p.Rows.Scan(args...)
|
||||
if err != nil {
|
||||
// Since this is the error flow we decided it would be ok
|
||||
// to spend a little bit more time parsing this error in order
|
||||
// to produce better error messages.
|
||||
//
|
||||
// If the parsing fails we just return the error unchanged.
|
||||
const scanErrPrefix = "sql: Scan error on column index "
|
||||
var errMsg = err.Error()
|
||||
if strings.HasPrefix(errMsg, scanErrPrefix) {
|
||||
i := len(scanErrPrefix)
|
||||
for unicode.IsDigit(rune(errMsg[i])) {
|
||||
i++
|
||||
}
|
||||
colIndex, convErr := strconv.Atoi(errMsg[len(scanErrPrefix):i])
|
||||
if convErr == nil {
|
||||
return ksql.ScanArgError{
|
||||
ColumnIndex: colIndex,
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -12,7 +12,7 @@ replace (
|
|||
|
||||
require (
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/stretchr/testify v1.7.0
|
||||
github.com/stretchr/testify v1.8.0
|
||||
github.com/vingarcia/ksql v1.4.7
|
||||
github.com/vingarcia/ksql/adapters/kmysql v0.0.0-00010101000000-000000000000
|
||||
github.com/vingarcia/ksql/adapters/kpgx v0.0.0-00010101000000-000000000000
|
||||
|
|
|
@ -230,12 +230,15 @@ github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/y
|
|||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/syndtr/gocapability v0.0.0-20200815063812-42c35b437635/go.mod h1:hkRG7XYTFWNJGYcbNJQlaLq0fg1yr4J4t/NcTQtrfww=
|
||||
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=
|
||||
github.com/tv42/httpunix v0.0.0-20191220191345-2ba4b9c3382c/go.mod h1:hzIxponao9Kjc7aWznkXaL4U4TWaDSs8zcsY4Ka08nM=
|
||||
|
@ -371,8 +374,9 @@ gopkg.in/resty.v1 v1.12.0/go.mod h1:mDo4pnntr5jdWRML875a/NmxYqAlA73dVijT2AXvQQo=
|
|||
gopkg.in/yaml.v2 v2.0.0-20170812160011-eb3733d160e7/go.mod h1:JAlM8MvJe8wmxCU4Bli9HhUf9+ttbYbLASfIpnQbh74=
|
||||
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo=
|
||||
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
|
|
|
@ -23,10 +23,11 @@ type StructInfo struct {
|
|||
// information regarding a specific field
|
||||
// of a struct.
|
||||
type FieldInfo struct {
|
||||
Name string
|
||||
Index int
|
||||
Valid bool
|
||||
Modifier ksqlmodifiers.AttrModifier
|
||||
AttrName string
|
||||
ColumnName string
|
||||
Index int
|
||||
Valid bool
|
||||
Modifier ksqlmodifiers.AttrModifier
|
||||
}
|
||||
|
||||
// ByIndex returns either the *FieldInfo of a valid
|
||||
|
@ -52,12 +53,12 @@ func (s StructInfo) ByName(name string) *FieldInfo {
|
|||
func (s StructInfo) add(field FieldInfo) {
|
||||
field.Valid = true
|
||||
s.byIndex[field.Index] = &field
|
||||
s.byName[field.Name] = &field
|
||||
s.byName[field.ColumnName] = &field
|
||||
|
||||
// Make sure to save a lowercased version because
|
||||
// some databases will set these keys to lowercase.
|
||||
if _, found := s.byName[strings.ToLower(field.Name)]; !found {
|
||||
s.byName[strings.ToLower(field.Name)] = &field
|
||||
if _, found := s.byName[strings.ToLower(field.ColumnName)]; !found {
|
||||
s.byName[strings.ToLower(field.ColumnName)] = &field
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -143,7 +144,7 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
|
|||
field = field.Elem()
|
||||
}
|
||||
|
||||
m[fieldInfo.Name] = field.Interface()
|
||||
m[fieldInfo.ColumnName] = field.Interface()
|
||||
}
|
||||
|
||||
return m, nil
|
||||
|
@ -246,6 +247,7 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) {
|
|||
return StructInfo{}, fmt.Errorf("all fields using the ksql tags must be exported, but %v is unexported", t)
|
||||
}
|
||||
|
||||
attrName := t.Field(i).Name
|
||||
name := t.Field(i).Tag.Get("ksql")
|
||||
if name == "" {
|
||||
continue
|
||||
|
@ -269,9 +271,10 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) {
|
|||
}
|
||||
|
||||
info.add(FieldInfo{
|
||||
Name: name,
|
||||
Index: i,
|
||||
Modifier: modifier,
|
||||
AttrName: attrName,
|
||||
ColumnName: name,
|
||||
Index: i,
|
||||
Modifier: modifier,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -289,8 +292,9 @@ func getTagNames(t reflect.Type) (_ StructInfo, err error) {
|
|||
}
|
||||
|
||||
info.add(FieldInfo{
|
||||
Name: name,
|
||||
Index: i,
|
||||
AttrName: t.Field(i).Name,
|
||||
ColumnName: name,
|
||||
Index: i,
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -81,7 +81,7 @@ func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []inte
|
|||
b.WriteString(" (")
|
||||
var escapedNames []string
|
||||
for i := 0; i < info.NumFields(); i++ {
|
||||
name := info.ByIndex(i).Name
|
||||
name := info.ByIndex(i).ColumnName
|
||||
escapedNames = append(escapedNames, dialect.Escape(name))
|
||||
}
|
||||
b.WriteString(strings.Join(escapedNames, ", "))
|
||||
|
|
|
@ -215,7 +215,7 @@ func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
|
|||
|
||||
var escapedNames []string
|
||||
for i := 0; i < info.NumFields(); i++ {
|
||||
name := info.ByIndex(i).Name
|
||||
name := info.ByIndex(i).ColumnName
|
||||
escapedNames = append(escapedNames, dialect.Escape(name))
|
||||
}
|
||||
|
||||
|
|
95
ksql.go
95
ksql.go
|
@ -67,6 +67,34 @@ type Rows interface {
|
|||
Columns() ([]string, error)
|
||||
}
|
||||
|
||||
// ScanArgError is a type of error that is expected to be returned
|
||||
// from the Scan() method of the Rows interface.
|
||||
//
|
||||
// It should be returned when there is an error scanning one of the input
|
||||
// values.
|
||||
//
|
||||
// This is necessary in order to allow KSQL to produce a better and more
|
||||
// readable error message when this type of error occur.
|
||||
type ScanArgError struct {
|
||||
ColumnIndex int
|
||||
Err error
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (s ScanArgError) Error() string {
|
||||
return fmt.Sprintf(
|
||||
"error scanning input attribute with index %d: %s",
|
||||
s.ColumnIndex, s.Err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
func (s ScanArgError) ErrorWithStructNames(structName string, colName string) error {
|
||||
return fmt.Errorf(
|
||||
"error scanning %s.%s: %s",
|
||||
structName, colName, s.Err.Error(),
|
||||
)
|
||||
}
|
||||
|
||||
// Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function
|
||||
type Tx interface {
|
||||
DBAdapter
|
||||
|
@ -1002,27 +1030,34 @@ func scanRowsFromType(
|
|||
return err
|
||||
}
|
||||
|
||||
var attrNames []string
|
||||
var scanArgs []interface{}
|
||||
if info.IsNestedStruct {
|
||||
// This version is positional meaning that it expect the arguments
|
||||
// to follow an specific order. It's ok because we don't allow the
|
||||
// user to type the "SELECT" part of the query for nested structs.
|
||||
scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info)
|
||||
attrNames, scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
names, err := rows.Columns()
|
||||
colNames, err := rows.Columns()
|
||||
if err != nil {
|
||||
return fmt.Errorf("KSQL: unable to read columns from returned rows: %w", err)
|
||||
}
|
||||
// Since this version uses the names of the columns it works
|
||||
// with any order of attributes/columns.
|
||||
scanArgs = getScanArgsFromNames(ctx, dialect, names, v, info)
|
||||
attrNames, scanArgs = getScanArgsFromNames(ctx, dialect, colNames, v, info)
|
||||
}
|
||||
|
||||
err = rows.Scan(scanArgs...)
|
||||
if err != nil {
|
||||
if scanErr, ok := err.(ScanArgError); ok {
|
||||
return fmt.Errorf(
|
||||
"KSQL: scan error: %w",
|
||||
scanErr.ErrorWithStructNames(t.Name(), attrNames[scanErr.ColumnIndex]),
|
||||
)
|
||||
}
|
||||
return fmt.Errorf("KSQL: scan error: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
@ -1035,8 +1070,7 @@ func getScanArgsForNestedStructs(
|
|||
t reflect.Type,
|
||||
v reflect.Value,
|
||||
info structs.StructInfo,
|
||||
) ([]interface{}, error) {
|
||||
scanArgs := []interface{}{}
|
||||
) (attrNames []string, scanArgs []interface{}, _ error) {
|
||||
for i := 0; i < v.NumField(); i++ {
|
||||
if !info.ByIndex(i).Valid {
|
||||
continue
|
||||
|
@ -1045,7 +1079,7 @@ func getScanArgsForNestedStructs(
|
|||
// TODO(vingarcia00): Handle case where type is pointer
|
||||
nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
nestedStructValue := v.Field(i)
|
||||
|
@ -1055,34 +1089,36 @@ func getScanArgsForNestedStructs(
|
|||
continue
|
||||
}
|
||||
|
||||
valueScanner := nopScannerValue
|
||||
if fieldInfo.Valid {
|
||||
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
|
||||
|
||||
if fieldInfo.Modifier.Scan != nil {
|
||||
valueScanner = &modifiers.AttrScanWrapper{
|
||||
Ctx: ctx,
|
||||
AttrPtr: valueScanner,
|
||||
ScanFn: fieldInfo.Modifier.Scan,
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
// We will not differentiate between Query, QueryOne and QueryChunks
|
||||
// if we did this could lead users to make very strange modifiers
|
||||
Method: "Query",
|
||||
},
|
||||
}
|
||||
valueScanner := nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
|
||||
if fieldInfo.Modifier.Scan != nil {
|
||||
valueScanner = &modifiers.AttrScanWrapper{
|
||||
Ctx: ctx,
|
||||
AttrPtr: valueScanner,
|
||||
ScanFn: fieldInfo.Modifier.Scan,
|
||||
OpInfo: ksqlmodifiers.OpInfo{
|
||||
DriverName: dialect.DriverName(),
|
||||
// We will not differentiate between Query, QueryOne and QueryChunks
|
||||
// if we did this could lead users to make very strange modifiers
|
||||
Method: "Query",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
scanArgs = append(scanArgs, valueScanner)
|
||||
attrNames = append(attrNames, info.ByIndex(i).AttrName+"."+fieldInfo.AttrName)
|
||||
}
|
||||
}
|
||||
|
||||
return scanArgs, nil
|
||||
return attrNames, scanArgs, nil
|
||||
}
|
||||
|
||||
func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
|
||||
scanArgs := []interface{}{}
|
||||
func getScanArgsFromNames(
|
||||
ctx context.Context,
|
||||
dialect Dialect,
|
||||
names []string,
|
||||
v reflect.Value,
|
||||
info structs.StructInfo,
|
||||
) (attrNames []string, scanArgs []interface{}) {
|
||||
for _, name := range names {
|
||||
fieldInfo := info.ByName(name)
|
||||
|
||||
|
@ -1105,9 +1141,10 @@ func getScanArgsFromNames(ctx context.Context, dialect Dialect, names []string,
|
|||
}
|
||||
|
||||
scanArgs = append(scanArgs, valueScanner)
|
||||
attrNames = append(attrNames, fieldInfo.AttrName)
|
||||
}
|
||||
|
||||
return scanArgs
|
||||
return attrNames, scanArgs
|
||||
}
|
||||
|
||||
func buildDeleteQuery(
|
||||
|
@ -1185,7 +1222,7 @@ func buildSelectQueryForPlainStructs(
|
|||
continue
|
||||
}
|
||||
|
||||
fields = append(fields, dialect.Escape(fieldInfo.Name))
|
||||
fields = append(fields, dialect.Escape(fieldInfo.ColumnName))
|
||||
}
|
||||
|
||||
return "SELECT " + strings.Join(fields, ", ") + " "
|
||||
|
@ -1203,7 +1240,7 @@ func buildSelectQueryForNestedStructs(
|
|||
continue
|
||||
}
|
||||
|
||||
nestedStructName := nestedStructInfo.Name
|
||||
nestedStructName := nestedStructInfo.ColumnName
|
||||
nestedStructType := structType.Field(i).Type
|
||||
if nestedStructType.Kind() != reflect.Struct {
|
||||
return "", fmt.Errorf(
|
||||
|
@ -1225,7 +1262,7 @@ func buildSelectQueryForNestedStructs(
|
|||
|
||||
fields = append(
|
||||
fields,
|
||||
dialect.Escape(nestedStructName)+"."+dialect.Escape(fieldInfo.Name),
|
||||
dialect.Escape(nestedStructName)+"."+dialect.Escape(fieldInfo.ColumnName),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
131
test_adapters.go
131
test_adapters.go
|
@ -6,10 +6,13 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/vingarcia/ksql/internal/modifiers"
|
||||
"github.com/vingarcia/ksql/internal/structs"
|
||||
tt "github.com/vingarcia/ksql/internal/testtools"
|
||||
"github.com/vingarcia/ksql/ksqlmodifiers"
|
||||
"github.com/vingarcia/ksql/nullable"
|
||||
|
@ -37,6 +40,8 @@ type address struct {
|
|||
Country string `json:"country"`
|
||||
}
|
||||
|
||||
var postsTable = NewTable("posts")
|
||||
|
||||
type post struct {
|
||||
ID int `ksql:"id"`
|
||||
UserID uint `ksql:"user_id"`
|
||||
|
@ -1017,7 +1022,6 @@ func InsertTest(
|
|||
tt.AssertNoErr(t, err)
|
||||
tt.AssertNotEqual(t, permission.ID, 0)
|
||||
|
||||
fmt.Println("permID:", permission.ID)
|
||||
var untaggedPerm struct {
|
||||
ID uint `ksql:"id"`
|
||||
UserID int `ksql:"user_id"`
|
||||
|
@ -3050,6 +3054,8 @@ func ScanRowsTest(
|
|||
connStr string,
|
||||
newDBAdapter func(t *testing.T) (DBAdapter, io.Closer),
|
||||
) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("ScanRows", func(t *testing.T) {
|
||||
t.Run("should scan users correctly", func(t *testing.T) {
|
||||
err := createTables(driver, connStr)
|
||||
|
@ -3058,7 +3064,6 @@ func ScanRowsTest(
|
|||
}
|
||||
|
||||
dialect := supportedDialects[driver]
|
||||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
@ -3082,12 +3087,9 @@ func ScanRowsTest(
|
|||
|
||||
t.Run("should ignore extra columns from query", func(t *testing.T) {
|
||||
err := createTables(driver, connStr)
|
||||
if err != nil {
|
||||
t.Fatal("could not create test table!, reason:", err.Error())
|
||||
}
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
dialect := supportedDialects[driver]
|
||||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
@ -3112,6 +3114,101 @@ func ScanRowsTest(
|
|||
tt.AssertEqual(t, u.Age, 22)
|
||||
})
|
||||
|
||||
t.Run("should report scan errors", func(t *testing.T) {
|
||||
type brokenUser struct {
|
||||
ID uint `ksql:"id"`
|
||||
|
||||
// The error will happen here, when scanning
|
||||
// an integer into a attribute of type struct{}:
|
||||
Age struct{} `ksql:"age"`
|
||||
}
|
||||
|
||||
type brokenNestedStruct struct {
|
||||
User struct {
|
||||
ID uint `ksql:"id"`
|
||||
|
||||
// The error will happen here, when scanning
|
||||
// an integer into a attribute of type struct:
|
||||
Age struct{} `ksql:"age"`
|
||||
} `tablename:"u"`
|
||||
Post post `tablename:"p"`
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
desc string
|
||||
query string
|
||||
scanTarget interface{}
|
||||
expectErrToContain []string
|
||||
}{
|
||||
{
|
||||
desc: "with anonymous structs",
|
||||
query: "FROM users WHERE name='User22'",
|
||||
scanTarget: &struct {
|
||||
ID uint `ksql:"id"`
|
||||
|
||||
// The error will happen here, when scanning
|
||||
// an integer into a attribute of type struct{}:
|
||||
Age struct{} `ksql:"age"`
|
||||
}{},
|
||||
expectErrToContain: []string{" .Age", "struct {}"},
|
||||
},
|
||||
{
|
||||
desc: "with named structs",
|
||||
query: "FROM users WHERE name='User22'",
|
||||
scanTarget: &brokenUser{},
|
||||
expectErrToContain: []string{"brokenUser.Age", "struct {}"},
|
||||
},
|
||||
{
|
||||
desc: "with anonymous nested structs",
|
||||
query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'",
|
||||
scanTarget: &struct {
|
||||
User struct {
|
||||
ID uint `ksql:"id"`
|
||||
|
||||
// The error will happen here, when scanning
|
||||
// an integer into a attribute of type struct:
|
||||
Age struct{} `ksql:"age"`
|
||||
} `tablename:"u"`
|
||||
Post post `tablename:"p"`
|
||||
}{},
|
||||
expectErrToContain: []string{".User.Age", "struct {}"},
|
||||
},
|
||||
{
|
||||
desc: "with named nested structs",
|
||||
query: "FROM users u JOIN posts p ON u.id = p.user_id WHERE name='User22'",
|
||||
scanTarget: &brokenNestedStruct{},
|
||||
expectErrToContain: []string{"brokenNestedStruct.User.Age", "struct {}"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.desc, func(t *testing.T) {
|
||||
err := createTables(driver, connStr)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
dialect := supportedDialects[driver]
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
c := newTestDB(db, driver)
|
||||
|
||||
u := user{Name: "User22", Age: 22}
|
||||
_ = c.Insert(ctx, usersTable, &u)
|
||||
_ = c.Insert(ctx, postsTable, &post{UserID: u.ID, Title: "FakeTitle"})
|
||||
|
||||
query := mustBuildSelectQuery(t, dialect, test.scanTarget, test.query)
|
||||
|
||||
rows, err := db.QueryContext(ctx, query)
|
||||
tt.AssertNoErr(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
tt.AssertEqual(t, rows.Next(), true)
|
||||
|
||||
err = scanRows(ctx, dialect, rows, test.scanTarget)
|
||||
tt.AssertErrContains(t, err, test.expectErrToContain...)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should report error for closed rows", func(t *testing.T) {
|
||||
err := createTables(driver, connStr)
|
||||
if err != nil {
|
||||
|
@ -3119,7 +3216,6 @@ func ScanRowsTest(
|
|||
}
|
||||
|
||||
dialect := supportedDialects[driver]
|
||||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
|
@ -3140,7 +3236,6 @@ func ScanRowsTest(
|
|||
}
|
||||
|
||||
dialect := supportedDialects[driver]
|
||||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
|
@ -3160,7 +3255,6 @@ func ScanRowsTest(
|
|||
}
|
||||
|
||||
dialect := supportedDialects[driver]
|
||||
ctx := context.TODO()
|
||||
db, closer := newDBAdapter(t)
|
||||
defer closer.Close()
|
||||
|
||||
|
@ -3446,3 +3540,22 @@ func getUserPermissionsByUser(db DBAdapter, driver string, userID int) (results
|
|||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func mustBuildSelectQuery(t *testing.T,
|
||||
dialect Dialect,
|
||||
record interface{},
|
||||
query string,
|
||||
) string {
|
||||
if strings.HasPrefix(query, "SELECT") {
|
||||
return query
|
||||
}
|
||||
|
||||
structType := reflect.TypeOf(record).Elem()
|
||||
structInfo, err := structs.GetTagInfo(structType)
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
selectPrefix, err := buildSelectQuery(dialect, structType, structInfo, selectQueryCache[dialect.DriverName()])
|
||||
tt.AssertNoErr(t, err)
|
||||
|
||||
return selectPrefix + query
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue