mirror of https://github.com/VinGarcia/ksql.git
Update Insert() method to work with qualified table names
parent
2b1dd6db3d
commit
d933794459
4
ksql.go
4
ksql.go
|
@ -807,7 +807,7 @@ func buildInsertQuery(
|
||||||
if len(columnNames) == 0 && dialect.DriverName() != "mysql" {
|
if len(columnNames) == 0 && dialect.DriverName() != "mysql" {
|
||||||
query = fmt.Sprintf(
|
query = fmt.Sprintf(
|
||||||
"INSERT INTO %s%s DEFAULT VALUES%s",
|
"INSERT INTO %s%s DEFAULT VALUES%s",
|
||||||
dialect.Escape(table.name),
|
table.name,
|
||||||
outputQuery,
|
outputQuery,
|
||||||
returningQuery,
|
returningQuery,
|
||||||
)
|
)
|
||||||
|
@ -818,7 +818,7 @@ func buildInsertQuery(
|
||||||
// on the selected driver, thus, they might be empty strings.
|
// on the selected driver, thus, they might be empty strings.
|
||||||
query = fmt.Sprintf(
|
query = fmt.Sprintf(
|
||||||
"INSERT INTO %s (%s)%s VALUES (%s)%s",
|
"INSERT INTO %s (%s)%s VALUES (%s)%s",
|
||||||
dialect.Escape(table.name),
|
table.name,
|
||||||
strings.Join(escapedColumnNames, ", "),
|
strings.Join(escapedColumnNames, ", "),
|
||||||
outputQuery,
|
outputQuery,
|
||||||
strings.Join(valuesQuery, ", "),
|
strings.Join(valuesQuery, ", "),
|
||||||
|
|
|
@ -866,6 +866,41 @@ func InsertTest(
|
||||||
tt.AssertNoErr(t, err)
|
tt.AssertNoErr(t, err)
|
||||||
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
|
tt.AssertEqual(t, untaggedUser.Name, (*string)(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("should work even when ksql.NewTable receives a qualified table name", func(t *testing.T) {
|
||||||
|
c := newTestDB(db, dialect)
|
||||||
|
|
||||||
|
u := user{
|
||||||
|
Name: "Amanda",
|
||||||
|
Address: address{
|
||||||
|
Country: "Brasil",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
switch dialect.DriverName() {
|
||||||
|
case "postgres":
|
||||||
|
// public is the default schema name for postgres:
|
||||||
|
err = c.Insert(ctx, NewTable("public.users"), &u)
|
||||||
|
case "sqlserver":
|
||||||
|
// dbo is the default schema name for sqlserver:
|
||||||
|
err = c.Insert(ctx, NewTable("dbo.users"), &u)
|
||||||
|
case "sqlite3":
|
||||||
|
// main is the default schema name for sqlite:
|
||||||
|
err = c.Insert(ctx, NewTable("main.users"), &u)
|
||||||
|
case "mysql":
|
||||||
|
err = c.Insert(ctx, NewTable("ksql.users"), &u)
|
||||||
|
}
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
tt.AssertNotEqual(t, u.ID, 0)
|
||||||
|
|
||||||
|
result := user{}
|
||||||
|
err = getUserByID(c.db, c.dialect, &result, u.ID)
|
||||||
|
tt.AssertNoErr(t, err)
|
||||||
|
|
||||||
|
tt.AssertEqual(t, result.Name, u.Name)
|
||||||
|
tt.AssertEqual(t, result.Address, u.Address)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("composite key tables", func(t *testing.T) {
|
t.Run("composite key tables", func(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue