From d933794459715a14c0503338e5d0ad303b362b06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vin=C3=ADcius=20Garcia?= Date: Tue, 16 Jan 2024 23:18:39 -0300 Subject: [PATCH] Update Insert() method to work with qualified table names --- ksql.go | 4 ++-- test_adapters.go | 35 +++++++++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/ksql.go b/ksql.go index 050e9c3..2628c64 100644 --- a/ksql.go +++ b/ksql.go @@ -807,7 +807,7 @@ func buildInsertQuery( if len(columnNames) == 0 && dialect.DriverName() != "mysql" { query = fmt.Sprintf( "INSERT INTO %s%s DEFAULT VALUES%s", - dialect.Escape(table.name), + table.name, outputQuery, returningQuery, ) @@ -818,7 +818,7 @@ func buildInsertQuery( // on the selected driver, thus, they might be empty strings. query = fmt.Sprintf( "INSERT INTO %s (%s)%s VALUES (%s)%s", - dialect.Escape(table.name), + table.name, strings.Join(escapedColumnNames, ", "), outputQuery, strings.Join(valuesQuery, ", "), diff --git a/test_adapters.go b/test_adapters.go index c26f5fa..2cdc613 100644 --- a/test_adapters.go +++ b/test_adapters.go @@ -866,6 +866,41 @@ func InsertTest( tt.AssertNoErr(t, err) tt.AssertEqual(t, untaggedUser.Name, (*string)(nil)) }) + + t.Run("should work even when ksql.NewTable receives a qualified table name", func(t *testing.T) { + c := newTestDB(db, dialect) + + u := user{ + Name: "Amanda", + Address: address{ + Country: "Brasil", + }, + } + + var err error + switch dialect.DriverName() { + case "postgres": + // public is the default schema name for postgres: + err = c.Insert(ctx, NewTable("public.users"), &u) + case "sqlserver": + // dbo is the default schema name for sqlserver: + err = c.Insert(ctx, NewTable("dbo.users"), &u) + case "sqlite3": + // main is the default schema name for sqlite: + err = c.Insert(ctx, NewTable("main.users"), &u) + case "mysql": + err = c.Insert(ctx, NewTable("ksql.users"), &u) + } + tt.AssertNoErr(t, err) + tt.AssertNotEqual(t, u.ID, 0) + + result := user{} + err = getUserByID(c.db, c.dialect, &result, u.ID) + tt.AssertNoErr(t, err) + + tt.AssertEqual(t, result.Name, u.Name) + tt.AssertEqual(t, result.Address, u.Address) + }) }) t.Run("composite key tables", func(t *testing.T) {