Add error check for nil pointers used as arguments to QueryOne()

pull/7/head
Vinícius Garcia 2021-09-18 14:42:53 -03:00
parent 33dd982d7c
commit 889662c4e0
2 changed files with 33 additions and 6 deletions

28
ksql.go
View File

@ -248,16 +248,22 @@ func (c DB) QueryOne(
query string,
params ...interface{},
) error {
t := reflect.TypeOf(record)
v := reflect.ValueOf(record)
t := v.Type()
if t.Kind() != reflect.Ptr {
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
}
t = t.Elem()
if t.Kind() != reflect.Struct {
if v.IsNil() {
return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
}
tStruct := t.Elem()
if tStruct.Kind() != reflect.Struct {
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
}
info := kstructs.GetTagInfo(t)
info := kstructs.GetTagInfo(tStruct)
firstToken := strings.ToUpper(getFirstToken(query))
if info.IsNestedStruct && firstToken == "SELECT" {
@ -266,7 +272,7 @@ func (c DB) QueryOne(
}
if firstToken == "FROM" {
selectPrefix, err := buildSelectQuery(c.dialect, t, info, selectQueryCache[c.dialect.DriverName()])
selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
if err != nil {
return err
}
@ -286,7 +292,7 @@ func (c DB) QueryOne(
return ErrRecordNotFound
}
err = scanRows(c.dialect, rows, record)
err = scanRowsFromType(c.dialect, rows, record, t, v)
if err != nil {
return err
}
@ -860,6 +866,16 @@ func (nopScanner) Scan(value interface{}) error {
func scanRows(dialect Dialect, rows Rows, record interface{}) error {
v := reflect.ValueOf(record)
t := v.Type()
return scanRowsFromType(dialect, rows, record, t, v)
}
func scanRowsFromType(
dialect Dialect,
rows Rows,
record interface{},
t reflect.Type,
v reflect.Value,
) error {
if t.Kind() != reflect.Ptr {
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
}

View File

@ -594,6 +594,17 @@ func TestQueryOne(t *testing.T) {
assert.NotEqual(t, nil, err)
})
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()
ctx := context.Background()
c := newTestDB(db, config.driver)
var user *User
err := c.QueryOne(ctx, user, `SELECT * FROM users`)
assert.NotEqual(t, nil, err)
})
t.Run("should report error if the query is not valid", func(t *testing.T) {
db, closer := connectDB(t, config)
defer closer.Close()