Fix Insert function to work with postgres

This commit adds the concept of dialects so we can support
different ways of escaping names, creating placeholders, etc.

Currently we are only using it on the Insert route and we only
support postgres and sqlite3, in the future we should add
more tests so we can start supporting more drivers.
pull/2/head
Vinícius Garcia 2020-12-29 23:36:07 -03:00
parent a7b3c12b95
commit de8f4e56d7
4 changed files with 218 additions and 209 deletions

35
dialect.go Normal file
View File

@ -0,0 +1,35 @@
package kissorm
import "strconv"
type dialect interface {
Escape(str string) string
Placeholder(idx int) string
}
type postgresDialect struct{}
func (postgresDialect) Escape(str string) string {
return `"` + str + `"`
}
func (postgresDialect) Placeholder(idx int) string {
return "$" + strconv.Itoa(idx+1)
}
type sqlite3Dialect struct{}
func (sqlite3Dialect) Escape(str string) string {
return "`" + str + "`"
}
func (sqlite3Dialect) Placeholder(idx int) string {
return "?"
}
func getDriverDialect(driver string) dialect {
return map[string]dialect{
"postgres": &postgresDialect{},
"sqlite3": &sqlite3Dialect{},
}[driver]
}

1
go.mod
View File

@ -5,4 +5,5 @@ go 1.14
require (
github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018
github.com/jinzhu/gorm v1.9.16
github.com/lib/pq v1.1.1
)

View File

@ -13,6 +13,7 @@ import (
// Client ...
type Client struct {
driver string
dialect dialect
tableName string
db *gorm.DB
}
@ -34,7 +35,13 @@ func NewClient(
db.DB().SetMaxOpenConns(maxOpenConns)
dialect := getDriverDialect(dbDriver)
if dialect == nil {
return Client{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
}
return Client{
dialect: dialect,
driver: dbDriver,
db: db,
tableName: tableName,
@ -255,7 +262,7 @@ func (c Client) Insert(
records ...interface{},
) error {
for _, record := range records {
query, params, err := buildInsertQuery(c.tableName, record, "id")
query, params, err := buildInsertQuery(c.dialect, c.tableName, record, "id")
if err != nil {
return err
}
@ -263,8 +270,14 @@ func (c Client) Insert(
switch c.driver {
case "postgres":
err = c.insertOnPostgres(ctx, record, query, params)
default:
case "sqlite3":
err = c.insertWithLastInsertID(ctx, record, query, params)
default:
err = fmt.Errorf("unsupported driver `%s`", c.driver)
}
if err != nil {
return err
}
}
@ -295,10 +308,10 @@ func (c Client) insertOnPostgres(
v := reflect.ValueOf(record)
t := v.Type()
info := getTagInfoWithCache(tagInfoCache, t.Elem())
info := getCachedTagInfo(tagInfoCache, t.Elem())
fieldAddr := v.Elem().Field(info.Index["id"]).Addr()
return rows.Scan(fieldAddr)
return rows.Scan(fieldAddr.Interface())
}
func (c Client) insertWithLastInsertID(
@ -314,7 +327,7 @@ func (c Client) insertWithLastInsertID(
v := reflect.ValueOf(record)
t := v.Type()
info := getTagInfoWithCache(tagInfoCache, t.Elem())
info := getCachedTagInfo(tagInfoCache, t.Elem())
id, err := result.LastInsertId()
if err != nil {
@ -365,6 +378,7 @@ func (c Client) Update(
}
func buildInsertQuery(
dialect dialect,
tableName string,
record interface{},
idFieldNames ...string,
@ -374,9 +388,6 @@ func buildInsertQuery(
return "", nil, err
}
numAttrs := len(recordMap)
params = make([]interface{}, numAttrs)
for _, fieldName := range idFieldNames {
// Remove any ID field that was not set:
if reflect.ValueOf(recordMap[fieldName]).IsZero() {
@ -389,16 +400,23 @@ func buildInsertQuery(
columnNames = append(columnNames, col)
}
var valuesQuery []string
params = make([]interface{}, len(recordMap))
valuesQuery := make([]string, len(recordMap))
for i, col := range columnNames {
params[i] = recordMap[col]
valuesQuery = append(valuesQuery, "?")
valuesQuery[i] = dialect.Placeholder(i)
}
// Escape all cols to be sure they will be interpreted as column names:
escapedColumnNames := []string{}
for _, col := range columnNames {
escapedColumnNames = append(escapedColumnNames, dialect.Escape(col))
}
query = fmt.Sprintf(
"INSERT INTO `%s` (`%s`) VALUES (%s)",
tableName,
strings.Join(columnNames, "`, `"),
"INSERT INTO %s (%s) VALUES (%s)",
dialect.Escape(tableName),
strings.Join(escapedColumnNames, ", "),
strings.Join(valuesQuery, ", "),
)
@ -479,7 +497,7 @@ func StructToMap(obj interface{}) (map[string]interface{}, error) {
return nil, fmt.Errorf("input must be a struct or struct pointer")
}
info := getTagInfoWithCache(tagInfoCache, t)
info := getCachedTagInfo(tagInfoCache, t)
m := map[string]interface{}{}
for i := 0; i < v.NumField(); i++ {
@ -548,7 +566,7 @@ func FillStructWith(record interface{}, dbRow map[string]interface{}) error {
)
}
info := getTagInfoWithCache(tagInfoCache, t)
info := getCachedTagInfo(tagInfoCache, t)
for colName, attr := range dbRow {
attrValue := reflect.ValueOf(attr)
@ -690,7 +708,7 @@ func scanRows(rows *sql.Rows, record interface{}) error {
return fmt.Errorf("kissorm: expected to receive a pointer to slice of structs, but got: %T", record)
}
info := getTagInfoWithCache(tagInfoCache, t)
info := getCachedTagInfo(tagInfoCache, t)
scanArgs := []interface{}{}
for _, name := range names {
@ -700,7 +718,7 @@ func scanRows(rows *sql.Rows, record interface{}) error {
return rows.Scan(scanArgs...)
}
func getTagInfoWithCache(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo {
func getCachedTagInfo(tagInfoCache map[reflect.Type]structInfo, key reflect.Type) structInfo {
info, found := tagInfoCache[key]
if !found {
info = getTagNames(key)

View File

@ -2,12 +2,14 @@ package kissorm
import (
"context"
"fmt"
"testing"
"time"
"github.com/ditointernet/go-assert"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/sqlite"
_ "github.com/lib/pq"
"github.com/vingarcia/kissorm/nullable"
)
@ -19,20 +21,17 @@ type User struct {
}
func TestQuery(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
}
t.Run("should return 0 results correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
var users []User
err := c.Query(ctx, &users, `SELECT * FROM users WHERE id=1;`)
assert.Equal(t, nil, err)
@ -45,7 +44,7 @@ func TestQuery(t *testing.T) {
})
t.Run("should return a user correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
db.Create(&User{
@ -53,10 +52,7 @@ func TestQuery(t *testing.T) {
})
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
var users []User
err = c.Query(ctx, &users, `SELECT * FROM users WHERE name=?;`, "Bia")
@ -67,7 +63,7 @@ func TestQuery(t *testing.T) {
})
t.Run("should return multiple users correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
db.Create(&User{
@ -79,10 +75,7 @@ func TestQuery(t *testing.T) {
})
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
var users []User
err = c.Query(ctx, &users, `SELECT * FROM users WHERE name like ?;`, "% Garcia")
@ -95,7 +88,7 @@ func TestQuery(t *testing.T) {
})
t.Run("should report error if input is not a pointer to a slice of structs", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
db.Create(&User{
@ -107,10 +100,7 @@ func TestQuery(t *testing.T) {
})
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "postgres", "users")
err = c.Query(ctx, &User{}, `SELECT * FROM users WHERE name like ?;`, "% Sá")
assert.NotEqual(t, nil, err)
@ -127,27 +117,24 @@ func TestQuery(t *testing.T) {
}
func TestQueryOne(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should return RecordNotFoundErr when there are no results", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "postgres", "users")
u := User{}
err := c.QueryOne(ctx, &u, `SELECT * FROM users WHERE id=1;`)
assert.Equal(t, ErrRecordNotFound, err)
})
t.Run("should return a user correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
db.Create(&User{
@ -155,10 +142,7 @@ func TestQueryOne(t *testing.T) {
})
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "postgres", "users")
u := User{}
err = c.QueryOne(ctx, &u, `SELECT * FROM users WHERE name=?;`, "Bia")
@ -168,7 +152,7 @@ func TestQueryOne(t *testing.T) {
})
t.Run("should report error if input is not a pointer to struct", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
db.Create(&User{
@ -180,10 +164,7 @@ func TestQueryOne(t *testing.T) {
})
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "postgres", "users")
err = c.QueryOne(ctx, &[]User{}, `SELECT * FROM users WHERE name like ?;`, "% Sá")
assert.NotEqual(t, nil, err)
@ -194,67 +175,62 @@ func TestQueryOne(t *testing.T) {
}
func TestInsert(t *testing.T) {
err := createTable()
if err != nil {
t.Fatal("could not create test table!")
for _, driver := range []string{"sqlite3", "postgres"} {
t.Run(driver, func(t *testing.T) {
err := createTable(driver)
if err != nil {
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should ignore empty lists of users", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestClient(db, driver, "users")
err = c.Insert(ctx)
assert.Equal(t, nil, err)
})
t.Run("should insert one user correctly", func(t *testing.T) {
db := connectDB(t, driver)
defer db.Close()
ctx := context.Background()
c := newTestClient(db, driver, "users")
u := User{
Name: "Fernanda",
}
err := c.Insert(ctx, &u)
assert.Equal(t, nil, err)
assert.NotEqual(t, 0, u.ID)
result := User{}
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, u.Name, result.Name)
assert.Equal(t, u.CreatedAt.Format(time.RFC3339), result.CreatedAt.Format(time.RFC3339))
})
})
}
t.Run("should ignore empty lists of users", func(t *testing.T) {
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
err = c.Insert(ctx)
assert.Equal(t, nil, err)
})
t.Run("should insert one user correctly", func(t *testing.T) {
db := connectDB(t)
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
u := User{
Name: "Fernanda",
}
err := c.Insert(ctx, &u)
assert.Equal(t, nil, err)
assert.NotEqual(t, 0, u.ID)
result := User{}
it := c.db.Raw("SELECT * FROM users WHERE id=?", u.ID)
it.Scan(&result)
assert.Equal(t, nil, it.Error)
assert.Equal(t, u.Name, result.Name)
assert.Equal(t, u.CreatedAt.Format(time.RFC3339), result.CreatedAt.Format(time.RFC3339))
})
}
func TestDelete(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should ignore empty lists of ids", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
u := User{
Name: "Won't be deleted",
@ -279,14 +255,11 @@ func TestDelete(t *testing.T) {
})
t.Run("should delete one id correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
u1 := User{
Name: "Fernanda",
@ -335,14 +308,11 @@ func TestDelete(t *testing.T) {
})
t.Run("should delete multiple ids correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
u1 := User{
Name: "Fernanda",
@ -394,20 +364,17 @@ func TestDelete(t *testing.T) {
}
func TestUpdate(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
t.Run("should ignore empty lists of ids", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
u := User{
Name: "Thay",
@ -430,14 +397,11 @@ func TestUpdate(t *testing.T) {
})
t.Run("should update one user correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
u := User{
Name: "Letícia",
@ -460,14 +424,11 @@ func TestUpdate(t *testing.T) {
})
t.Run("should update one user correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
u := User{
Name: "Letícia",
@ -490,14 +451,11 @@ func TestUpdate(t *testing.T) {
})
t.Run("should ignore null pointers on partial updates", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
type partialUser struct {
ID uint `gorm:"id"`
@ -530,14 +488,11 @@ func TestUpdate(t *testing.T) {
})
t.Run("should update valid pointers on partial updates", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
type partialUser struct {
ID uint `gorm:"id"`
@ -569,14 +524,11 @@ func TestUpdate(t *testing.T) {
})
t.Run("should report database errors correctly", func(t *testing.T) {
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "non_existing_table",
}
c := newTestClient(db, "sqlite3", "non_existing_table")
err = c.Update(ctx, User{
ID: 1,
@ -650,19 +602,16 @@ func TestStructToMap(t *testing.T) {
func TestQueryChunks(t *testing.T) {
t.Run("should query a single row correctly", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1"})
@ -689,19 +638,16 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should query one chunk correctly", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
@ -730,19 +676,16 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should query chunks of 1 correctly", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
@ -771,19 +714,16 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should load partially filled chunks correctly", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
@ -815,19 +755,16 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should abort the first iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
@ -857,19 +794,16 @@ func TestQueryChunks(t *testing.T) {
})
t.Run("should abort the last iteration when the callback returns an ErrAbortIteration", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
ctx := context.Background()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1"})
_ = c.Insert(ctx, &User{Name: "User2"})
@ -928,18 +862,15 @@ func TestFillSliceWith(t *testing.T) {
func TestScanRows(t *testing.T) {
t.Run("should scan users correctly", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
ctx := context.TODO()
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
c := Client{
db: db,
tableName: "users",
}
c := newTestClient(db, "sqlite3", "users")
_ = c.Insert(ctx, &User{Name: "User1", Age: 22})
_ = c.Insert(ctx, &User{Name: "User2", Age: 14})
_ = c.Insert(ctx, &User{Name: "User3", Age: 43})
@ -958,13 +889,13 @@ func TestScanRows(t *testing.T) {
})
t.Run("should report error for closed rows", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
ctx := context.TODO()
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
rows, err := db.DB().QueryContext(ctx, "select * from users where name='User2'")
@ -978,13 +909,13 @@ func TestScanRows(t *testing.T) {
})
t.Run("should report if record is not a pointer", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
ctx := context.TODO()
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
rows, err := db.DB().QueryContext(ctx, "select * from users where name='User2'")
@ -996,13 +927,13 @@ func TestScanRows(t *testing.T) {
})
t.Run("should report if record is not a pointer to struct", func(t *testing.T) {
err := createTable()
err := createTable("sqlite3")
if err != nil {
t.Fatal("could not create test table!")
t.Fatal("could not create test table!, reason:", err.Error())
}
ctx := context.TODO()
db := connectDB(t)
db := connectDB(t, "sqlite3")
defer db.Close()
rows, err := db.DB().QueryContext(ctx, "select * from users where name='User2'")
@ -1014,8 +945,18 @@ func TestScanRows(t *testing.T) {
})
}
func createTable() error {
db, err := gorm.Open("sqlite3", "/tmp/test.db")
var connectionString = map[string]string{
"postgres": "host=localhost port=5432 user=postgres dbname=kissorm sslmode=disable",
"sqlite3": "/tmp/kissorm.db",
}
func createTable(driver string) error {
connStr := connectionString[driver]
if connStr == "" {
return fmt.Errorf("unsupported driver: '%s'", driver)
}
db, err := gorm.Open(driver, connStr)
if err != nil {
return err
}
@ -1027,8 +968,22 @@ func createTable() error {
return nil
}
func connectDB(t *testing.T) *gorm.DB {
db, err := gorm.Open("sqlite3", "/tmp/test.db")
func newTestClient(db *gorm.DB, driver string, tableName string) Client {
return Client{
driver: driver,
dialect: getDriverDialect(driver),
db: db,
tableName: tableName,
}
}
func connectDB(t *testing.T, driver string) *gorm.DB {
connStr := connectionString[driver]
if connStr == "" {
panic(fmt.Sprintf("unsupported driver: '%s'", driver))
}
db, err := gorm.Open(driver, connStr)
if err != nil {
t.Fatal(err.Error())
}