diff --git a/ksql.go b/ksql.go index a0e127a..b0bc26f 100644 --- a/ksql.go +++ b/ksql.go @@ -248,16 +248,22 @@ func (c DB) QueryOne( query string, params ...interface{}, ) error { - t := reflect.TypeOf(record) + v := reflect.ValueOf(record) + t := v.Type() if t.Kind() != reflect.Ptr { return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } - t = t.Elem() - if t.Kind() != reflect.Struct { + + if v.IsNil() { + return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record) + } + + tStruct := t.Elem() + if tStruct.Kind() != reflect.Struct { return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record) } - info := kstructs.GetTagInfo(t) + info := kstructs.GetTagInfo(tStruct) firstToken := strings.ToUpper(getFirstToken(query)) if info.IsNestedStruct && firstToken == "SELECT" { @@ -266,7 +272,7 @@ func (c DB) QueryOne( } if firstToken == "FROM" { - selectPrefix, err := buildSelectQuery(c.dialect, t, info, selectQueryCache[c.dialect.DriverName()]) + selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()]) if err != nil { return err } @@ -286,7 +292,7 @@ func (c DB) QueryOne( return ErrRecordNotFound } - err = scanRows(c.dialect, rows, record) + err = scanRowsFromType(c.dialect, rows, record, t, v) if err != nil { return err } @@ -860,6 +866,16 @@ func (nopScanner) Scan(value interface{}) error { func scanRows(dialect Dialect, rows Rows, record interface{}) error { v := reflect.ValueOf(record) t := v.Type() + return scanRowsFromType(dialect, rows, record, t, v) +} + +func scanRowsFromType( + dialect Dialect, + rows Rows, + record interface{}, + t reflect.Type, + v reflect.Value, +) error { if t.Kind() != reflect.Ptr { return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record) } diff --git a/ksql_test.go b/ksql_test.go index 0e9b739..257b8c6 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -594,6 +594,17 @@ func TestQueryOne(t *testing.T) { assert.NotEqual(t, nil, err) }) + t.Run("should report error if it receives a nil pointer to a struct", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + var user *User + err := c.QueryOne(ctx, user, `SELECT * FROM users`) + assert.NotEqual(t, nil, err) + }) + t.Run("should report error if the query is not valid", func(t *testing.T) { db, closer := connectDB(t, config) defer closer.Close()