From d1e97489ef5f8cc22d6a245af5eca23397a1558a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Fri, 11 Feb 2022 17:30:42 -0300 Subject: [PATCH] Add some tests for invalid tables passed to Delete() --- contracts.go | 14 ++++++++++++++ ksql.go | 16 ++++++++-------- ksql_test.go | 22 ++++++++++++++++++++++ 3 files changed, 44 insertions(+), 8 deletions(-) diff --git a/contracts.go b/contracts.go index 85ce350..599b4c6 100644 --- a/contracts.go +++ b/contracts.go @@ -72,6 +72,20 @@ func NewTable(tableName string, ids ...string) Table { } } +func (t Table) validate() error { + if t.name == "" { + return fmt.Errorf("table name cannot be an empty string") + } + + for _, fieldName := range t.idColumns { + if fieldName == "" { + return fmt.Errorf("ID columns cannot be empty strings") + } + } + + return nil +} + func (t Table) insertMethodFor(dialect Dialect) insertMethod { if len(t.idColumns) == 1 { return dialect.InsertMethod() diff --git a/ksql.go b/ksql.go index f74f699..9a1a7ba 100644 --- a/ksql.go +++ b/ksql.go @@ -402,6 +402,10 @@ func (c DB) Insert( return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record) } + if err := table.validate(); err != nil { + return fmt.Errorf("can't insert in ksql.Table: %s", err) + } + info, err := structs.GetTagInfo(t.Elem()) if err != nil { return err @@ -543,6 +547,10 @@ func (c DB) Delete( table Table, idOrRecord interface{}, ) error { + if err := table.validate(); err != nil { + return fmt.Errorf("can't delete from ksql.Table: %s", err) + } + idMaps, err := normalizeIDsAsMaps(table.idColumns, []interface{}{idOrRecord}) if err != nil { return err @@ -678,15 +686,7 @@ func buildInsertQuery( return "", nil, nil, err } - if table.name == "" { - return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: table name cannot be an empty string") - } - for _, fieldName := range table.idColumns { - if fieldName == "" { - return "", nil, nil, fmt.Errorf("can't insert in ksql.Table: ID columns cannot be empty strings") - } - field, found := recordMap[fieldName] if !found { continue diff --git a/ksql_test.go b/ksql_test.go index ed22392..bc9ec82 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -1103,6 +1103,28 @@ func TestDelete(t *testing.T) { err := c.Delete(ctx, UsersTable, user) assert.NotEqual(t, nil, err) }) + + t.Run("should report error if table contains an empty ID name", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Delete(ctx, NewTable("users", ""), &User{Name: "fake-name"}) + tt.AssertErrContains(t, err, "ksql.Table", "ID", "empty string") + }) + + t.Run("should report error if ksql.Table.name is empty", func(t *testing.T) { + db, closer := connectDB(t, config) + defer closer.Close() + + ctx := context.Background() + c := newTestDB(db, config.driver) + + err := c.Delete(ctx, NewTable("", "id"), &User{Name: "fake-name"}) + tt.AssertErrContains(t, err, "ksql.Table", "table name", "empty string") + }) }) } }