From f06706b08184604c6c8d033cf2d0b78f9422b572 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Sun, 17 Jan 2021 20:26:34 -0300 Subject: [PATCH] Fix scanRows to ignore extra columns from query without errors --- README.md | 2 -- kiss_orm.go | 16 +++++++++++++++- kiss_orm_test.go | 31 +++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1e28734..097fb24 100644 --- a/README.md +++ b/README.md @@ -205,5 +205,3 @@ read the example tests available on our [example service](./examples/example_ser - Allow the ID field to have a different name - Implement a JSON fields on the database (encoding/decoding them automatically into structs) - Implement support for nested objects with prefixed table names -- Double check if all reflection is safe on the Insert() function -- Make sure `SELECT *` works even if not all fields are present diff --git a/kiss_orm.go b/kiss_orm.go index 923f599..d857e94 100644 --- a/kiss_orm.go +++ b/kiss_orm.go @@ -612,6 +612,14 @@ func parseInputFunc(fn interface{}) (reflect.Type, error) { return argsType, nil } +type nopScanner struct{} + +var nopScannerValue = reflect.ValueOf(&nopScanner{}) + +func (nopScanner) Scan(value interface{}) error { + return nil +} + func scanRows(rows *sql.Rows, record interface{}) error { names, err := rows.Columns() if err != nil { @@ -635,7 +643,13 @@ func scanRows(rows *sql.Rows, record interface{}) error { scanArgs := []interface{}{} for _, name := range names { - scanArgs = append(scanArgs, v.Field(info.Index[name]).Addr().Interface()) + idx, found := info.Index[name] + valueScanner := v.Field(idx).Addr() + if !found { + valueScanner = nopScannerValue + } + + scanArgs = append(scanArgs, valueScanner.Interface()) } return rows.Scan(scanArgs...) diff --git a/kiss_orm_test.go b/kiss_orm_test.go index 063490e..9390312 100644 --- a/kiss_orm_test.go +++ b/kiss_orm_test.go @@ -896,6 +896,37 @@ func TestScanRows(t *testing.T) { assert.Equal(t, 14, u.Age) }) + t.Run("should ignore extra columns from query", func(t *testing.T) { + err := createTable("sqlite3") + if err != nil { + t.Fatal("could not create test table!, reason:", err.Error()) + } + + ctx := context.TODO() + db := connectDB(t, "sqlite3") + defer db.Close() + c := newTestDB(db, "sqlite3", "users") + _ = c.Insert(ctx, &User{Name: "User1", Age: 22}) + + rows, err := db.QueryContext(ctx, "SELECT * FROM users WHERE name='User1'") + assert.Equal(t, nil, err) + defer rows.Close() + + assert.Equal(t, true, rows.Next()) + + var user struct { + ID int `kissorm:"id"` + Age int `kissorm:"age"` + + // Omitted for testing purposes: + // Name string `kissorm:"name"` + } + err = scanRows(rows, &user) + assert.Equal(t, nil, err) + + assert.Equal(t, 22, user.Age) + }) + t.Run("should report error for closed rows", func(t *testing.T) { err := createTable("sqlite3") if err != nil {