mirror of https://github.com/VinGarcia/ksql.git
Add error check for nil pointers used as arguments to QueryOne()
parent
33dd982d7c
commit
889662c4e0
28
ksql.go
28
ksql.go
|
@ -248,16 +248,22 @@ func (c DB) QueryOne(
|
|||
query string,
|
||||
params ...interface{},
|
||||
) error {
|
||||
t := reflect.TypeOf(record)
|
||||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
if t.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
||||
}
|
||||
t = t.Elem()
|
||||
if t.Kind() != reflect.Struct {
|
||||
|
||||
if v.IsNil() {
|
||||
return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
|
||||
}
|
||||
|
||||
tStruct := t.Elem()
|
||||
if tStruct.Kind() != reflect.Struct {
|
||||
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
||||
}
|
||||
|
||||
info := kstructs.GetTagInfo(t)
|
||||
info := kstructs.GetTagInfo(tStruct)
|
||||
|
||||
firstToken := strings.ToUpper(getFirstToken(query))
|
||||
if info.IsNestedStruct && firstToken == "SELECT" {
|
||||
|
@ -266,7 +272,7 @@ func (c DB) QueryOne(
|
|||
}
|
||||
|
||||
if firstToken == "FROM" {
|
||||
selectPrefix, err := buildSelectQuery(c.dialect, t, info, selectQueryCache[c.dialect.DriverName()])
|
||||
selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -286,7 +292,7 @@ func (c DB) QueryOne(
|
|||
return ErrRecordNotFound
|
||||
}
|
||||
|
||||
err = scanRows(c.dialect, rows, record)
|
||||
err = scanRowsFromType(c.dialect, rows, record, t, v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -860,6 +866,16 @@ func (nopScanner) Scan(value interface{}) error {
|
|||
func scanRows(dialect Dialect, rows Rows, record interface{}) error {
|
||||
v := reflect.ValueOf(record)
|
||||
t := v.Type()
|
||||
return scanRowsFromType(dialect, rows, record, t, v)
|
||||
}
|
||||
|
||||
func scanRowsFromType(
|
||||
dialect Dialect,
|
||||
rows Rows,
|
||||
record interface{},
|
||||
t reflect.Type,
|
||||
v reflect.Value,
|
||||
) error {
|
||||
if t.Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
|
||||
}
|
||||
|
|
11
ksql_test.go
11
ksql_test.go
|
@ -594,6 +594,17 @@ func TestQueryOne(t *testing.T) {
|
|||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
|
||||
t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) {
|
||||
db, closer := connectDB(t, config)
|
||||
defer closer.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
c := newTestDB(db, config.driver)
|
||||
var user *User
|
||||
err := c.QueryOne(ctx, user, `SELECT * FROM users`)
|
||||
assert.NotEqual(t, nil, err)
|
||||
})
|
||||
|
||||
t.Run("should report error if the query is not valid", func(t *testing.T) {
|
||||
db, closer := connectDB(t, config)
|
||||
defer closer.Close()
|
||||
|
|
Loading…
Reference in New Issue