Add the Insert struct to the kbuilder package

pull/2/head
Vinícius Garcia 2021-09-07 12:01:34 -03:00
parent 22fa8fdfa4
commit 6935bddf29
5 changed files with 417 additions and 216 deletions

105
kbuilder/insert.go Normal file
View File

@ -0,0 +1,105 @@
package kbuilder
import (
"fmt"
"reflect"
"strings"
"github.com/vingarcia/ksql"
"github.com/vingarcia/ksql/kstructs"
)
// Insert is the struct template for building INSERT queries
type Insert struct {
// Into expects a table name, e.g. "users"
Into string
// Data expected either a single record annotated with `ksql` tags
// or a list of records annotated likewise.
Data interface{}
}
// Build is a utility function for finding the dialect based on the driver and
// then calling BuildQuery(dialect)
func (i Insert) Build(driver string) (sqlQuery string, params []interface{}, _ error) {
dialect, err := ksql.GetDriverDialect(driver)
if err != nil {
return "", nil, err
}
return i.BuildQuery(dialect)
}
// BuildQuery implements the queryBuilder interface
func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) {
var b strings.Builder
b.WriteString("INSERT INTO " + dialect.Escape(i.Into))
if i.Into == "" {
return "", nil, fmt.Errorf(
"expected the Into attr to contain the tablename, but got an empty string instead",
)
}
if i.Data == nil {
return "", nil, fmt.Errorf(
"expected the Data attr to contain a struct or a list of structs, but got `%v`",
i.Data,
)
}
v := reflect.ValueOf(i.Data)
t := v.Type()
if t.Kind() != reflect.Slice {
// Convert it to a slice of a single element:
v = reflect.Append(reflect.MakeSlice(reflect.SliceOf(t), 0, 1), v)
} else {
t = t.Elem()
}
if v.Len() == 0 {
return "", nil, fmt.Errorf(
"can't create an insertion query from an empty list of values",
)
}
isPtr := false
if t.Kind() == reflect.Ptr {
isPtr = true
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return "", nil, fmt.Errorf("expected Data attr to be a struct or slice of structs but got: %v", t)
}
info := kstructs.GetTagInfo(t)
b.WriteString(" (")
var escapedNames []string
for i := 0; i < info.NumFields(); i++ {
name := info.ByIndex(i).Name
escapedNames = append(escapedNames, dialect.Escape(name))
}
b.WriteString(strings.Join(escapedNames, ", "))
b.WriteString(") VALUES ")
params = []interface{}{}
values := []string{}
for i := 0; i < v.Len(); i++ {
record := v.Index(i)
if isPtr {
record = record.Elem()
}
placeholders := []string{}
for j := 0; j < info.NumFields(); j++ {
placeholders = append(placeholders, dialect.Placeholder(len(params)))
params = append(params, record.Field(j).Interface())
}
values = append(values, "("+strings.Join(placeholders, ", ")+")")
}
b.WriteString(strings.Join(values, ", "))
return b.String(), params, nil
}

92
kbuilder/insert_test.go Normal file
View File

@ -0,0 +1,92 @@
package kbuilder_test
import (
"testing"
"github.com/tj/assert"
"github.com/vingarcia/ksql/kbuilder"
)
func TestInsertQuery(t *testing.T) {
tests := []struct {
desc string
query kbuilder.Insert
expectedQuery string
expectedParams []interface{}
expectedErr bool
}{
{
desc: "should build queries witha single record correctly",
query: kbuilder.Insert{
Into: "users",
Data: &User{
Name: "foo",
Age: 42,
},
},
expectedQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2)`,
expectedParams: []interface{}{"foo", 42},
},
{
desc: "should build queries with multiple records correctly",
query: kbuilder.Insert{
Into: "users",
Data: []User{
{
Name: "foo",
Age: 42,
},
{
Name: "bar",
Age: 43,
},
},
},
expectedQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2), ($3, $4)`,
expectedParams: []interface{}{"foo", 42, "bar", 43},
},
/* * * * * Testing error cases: * * * * */
{
desc: "should report error if the `Data` attribute is missing",
query: kbuilder.Insert{
Into: "users",
},
expectedErr: true,
},
{
desc: "should report error if the `Into` attribute is missing",
query: kbuilder.Insert{
Data: &User{
Name: "foo",
Age: 42,
},
},
expectedErr: true,
},
{
desc: "should report error Data contains an empty list",
query: kbuilder.Insert{
Into: "users",
Data: []User{},
},
expectedErr: true,
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
b, err := kbuilder.New("postgres")
assert.Equal(t, nil, err)
query, params, err := b.Build(test.query)
expectError(t, test.expectedErr, err)
assert.Equal(t, test.expectedQuery, query)
assert.Equal(t, test.expectedParams, params)
})
}
}

View File

@ -1,14 +1,7 @@
package kbuilder
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/vingarcia/ksql"
"github.com/vingarcia/ksql/kstructs"
)
// Builder is the basic container for injecting
@ -38,209 +31,3 @@ func New(driver string) (Builder, error) {
func (builder *Builder) Build(query queryBuilder) (sqlQuery string, params []interface{}, _ error) {
return query.BuildQuery(builder.dialect)
}
// Query is is the struct template for building SELECT queries.
type Query struct {
// Select expects either a struct using the `ksql` tags
// or a string listing the column names using SQL syntax,
// e.g.: `id, username, address`
Select interface{}
// From expects the FROM clause from an SQL query, e.g. `users JOIN posts USING(post_id)`
From string
// Where expects a list of WhereQuery instances built
// by the public Where() function.
Where WhereQueries
Limit int
Offset int
OrderBy OrderByQuery
}
// Build is a utility function for finding the dialect based on the driver and
// then calling BuildQuery(dialect)
func (q Query) Build(driver string) (sqlQuery string, params []interface{}, _ error) {
dialect, err := ksql.GetDriverDialect(driver)
if err != nil {
return "", nil, err
}
return q.BuildQuery(dialect)
}
// BuildQuery implements the QueryBuilder interface
func (q Query) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) {
var b strings.Builder
switch v := q.Select.(type) {
case string:
b.WriteString("SELECT " + v)
default:
selectQuery, err := buildSelectQuery(v, dialect)
if err != nil {
return "", nil, errors.Wrap(err, "error reading the Select field")
}
b.WriteString("SELECT " + selectQuery)
}
b.WriteString(" FROM " + q.From)
if len(q.Where) > 0 {
var whereQuery string
whereQuery, params = q.Where.build(dialect)
b.WriteString(" WHERE " + whereQuery)
}
if strings.TrimSpace(q.From) == "" {
return "", nil, fmt.Errorf("the From field is mandatory for every query")
}
if q.OrderBy.fields != "" {
b.WriteString(" ORDER BY " + q.OrderBy.fields)
if q.OrderBy.desc {
b.WriteString(" DESC")
}
}
if q.Limit > 0 {
b.WriteString(" LIMIT " + strconv.Itoa(q.Limit))
}
if q.Offset > 0 {
b.WriteString(" OFFSET " + strconv.Itoa(q.Offset))
}
return b.String(), params, nil
}
// WhereQuery represents a single condition in a WHERE expression.
type WhereQuery struct {
// Accepts any SQL boolean expression
// This expression may optionally contain
// string formatting directives %s and only %s.
//
// For each of these directives we expect a new param
// on the params list below.
//
// In the resulting query each %s will be properly replaced
// by placeholders according to the database driver, e.g. `$1`
// for postgres or `?` for sqlite3.
cond string
params []interface{}
}
// WhereQueries is the helper for creating complex WHERE queries
// in a dynamic way.
type WhereQueries []WhereQuery
func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interface{}) {
var conds []string
for _, whereQuery := range w {
var placeholders []interface{}
for i := range whereQuery.params {
placeholders = append(placeholders, dialect.Placeholder(len(params)+i))
}
conds = append(conds, fmt.Sprintf(whereQuery.cond, placeholders...))
params = append(params, whereQuery.params...)
}
return strings.Join(conds, " AND "), params
}
// Where adds a new bollean condition to an existing
// WhereQueries helper.
func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries {
return append(w, WhereQuery{
cond: cond,
params: params,
})
}
// WhereIf condionally adds a new boolean expression to the WhereQueries helper.
func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries {
if param == nil || reflect.ValueOf(param).IsNil() {
return w
}
return append(w, WhereQuery{
cond: cond,
params: []interface{}{param},
})
}
// Where adds a new bollean condition to an existing
// WhereQueries helper.
func Where(cond string, params ...interface{}) WhereQueries {
return WhereQueries{{
cond: cond,
params: params,
}}
}
// WhereIf condionally adds a new boolean expression to the WhereQueries helper
func WhereIf(cond string, param interface{}) WhereQueries {
if param == nil || reflect.ValueOf(param).IsNil() {
return WhereQueries{}
}
return WhereQueries{{
cond: cond,
params: []interface{}{param},
}}
}
// OrderByQuery represents the ORDER BY part of the query
type OrderByQuery struct {
fields string
desc bool
}
// Desc is a setter function for configuring the
// ORDER BY part of the query as DESC
func (o OrderByQuery) Desc() OrderByQuery {
return OrderByQuery{
fields: o.fields,
desc: true,
}
}
// OrderBy is a helper for building the ORDER BY
// part of the query.
func OrderBy(fields string) OrderByQuery {
return OrderByQuery{
fields: fields,
desc: false,
}
}
var cachedSelectQueries = map[reflect.Type]string{}
// Builds the select query using cached info so that its efficient
func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
t := reflect.TypeOf(obj)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj)
}
if query, found := cachedSelectQueries[t]; found {
return query, nil
}
info := kstructs.GetTagInfo(t)
var escapedNames []string
for i := 0; i < info.NumFields(); i++ {
name := info.ByIndex(i).Name
escapedNames = append(escapedNames, dialect.Escape(name))
}
query := strings.Join(escapedNames, ", ")
cachedSelectQueries[t] = query
return query, nil
}

218
kbuilder/query.go Normal file
View File

@ -0,0 +1,218 @@
package kbuilder
import (
"fmt"
"reflect"
"strconv"
"strings"
"github.com/pkg/errors"
"github.com/vingarcia/ksql"
"github.com/vingarcia/ksql/kstructs"
)
// Query is is the struct template for building SELECT queries.
type Query struct {
// Select expects either a struct using the `ksql` tags
// or a string listing the column names using SQL syntax,
// e.g.: `id, username, address`
Select interface{}
// From expects the FROM clause from an SQL query, e.g. `users JOIN posts USING(post_id)`
From string
// Where expects a list of WhereQuery instances built
// by the public Where() function.
Where WhereQueries
Limit int
Offset int
OrderBy OrderByQuery
}
// Build is a utility function for finding the dialect based on the driver and
// then calling BuildQuery(dialect)
func (q Query) Build(driver string) (sqlQuery string, params []interface{}, _ error) {
dialect, err := ksql.GetDriverDialect(driver)
if err != nil {
return "", nil, err
}
return q.BuildQuery(dialect)
}
// BuildQuery implements the queryBuilder interface
func (q Query) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) {
var b strings.Builder
switch v := q.Select.(type) {
case string:
b.WriteString("SELECT " + v)
default:
selectQuery, err := buildSelectQuery(v, dialect)
if err != nil {
return "", nil, errors.Wrap(err, "error reading the Select field")
}
b.WriteString("SELECT " + selectQuery)
}
b.WriteString(" FROM " + q.From)
if len(q.Where) > 0 {
var whereQuery string
whereQuery, params = q.Where.build(dialect)
b.WriteString(" WHERE " + whereQuery)
}
if strings.TrimSpace(q.From) == "" {
return "", nil, fmt.Errorf("the From field is mandatory for every query")
}
if q.OrderBy.fields != "" {
b.WriteString(" ORDER BY " + q.OrderBy.fields)
if q.OrderBy.desc {
b.WriteString(" DESC")
}
}
if q.Limit > 0 {
b.WriteString(" LIMIT " + strconv.Itoa(q.Limit))
}
if q.Offset > 0 {
b.WriteString(" OFFSET " + strconv.Itoa(q.Offset))
}
return b.String(), params, nil
}
// WhereQuery represents a single condition in a WHERE expression.
type WhereQuery struct {
// Accepts any SQL boolean expression
// This expression may optionally contain
// string formatting directives %s and only %s.
//
// For each of these directives we expect a new param
// on the params list below.
//
// In the resulting query each %s will be properly replaced
// by placeholders according to the database driver, e.g. `$1`
// for postgres or `?` for sqlite3.
cond string
params []interface{}
}
// WhereQueries is the helper for creating complex WHERE queries
// in a dynamic way.
type WhereQueries []WhereQuery
func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interface{}) {
var conds []string
for _, whereQuery := range w {
var placeholders []interface{}
for i := range whereQuery.params {
placeholders = append(placeholders, dialect.Placeholder(len(params)+i))
}
conds = append(conds, fmt.Sprintf(whereQuery.cond, placeholders...))
params = append(params, whereQuery.params...)
}
return strings.Join(conds, " AND "), params
}
// Where adds a new bollean condition to an existing
// WhereQueries helper.
func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries {
return append(w, WhereQuery{
cond: cond,
params: params,
})
}
// WhereIf condionally adds a new boolean expression to the WhereQueries helper.
func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries {
if param == nil || reflect.ValueOf(param).IsNil() {
return w
}
return append(w, WhereQuery{
cond: cond,
params: []interface{}{param},
})
}
// Where adds a new bollean condition to an existing
// WhereQueries helper.
func Where(cond string, params ...interface{}) WhereQueries {
return WhereQueries{{
cond: cond,
params: params,
}}
}
// WhereIf condionally adds a new boolean expression to the WhereQueries helper
func WhereIf(cond string, param interface{}) WhereQueries {
if param == nil || reflect.ValueOf(param).IsNil() {
return WhereQueries{}
}
return WhereQueries{{
cond: cond,
params: []interface{}{param},
}}
}
// OrderByQuery represents the ORDER BY part of the query
type OrderByQuery struct {
fields string
desc bool
}
// Desc is a setter function for configuring the
// ORDER BY part of the query as DESC
func (o OrderByQuery) Desc() OrderByQuery {
return OrderByQuery{
fields: o.fields,
desc: true,
}
}
// OrderBy is a helper for building the ORDER BY
// part of the query.
func OrderBy(fields string) OrderByQuery {
return OrderByQuery{
fields: fields,
desc: false,
}
}
var cachedSelectQueries = map[reflect.Type]string{}
// Builds the select query using cached info so that its efficient
func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) {
t := reflect.TypeOf(obj)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj)
}
if query, found := cachedSelectQueries[t]; found {
return query, nil
}
info := kstructs.GetTagInfo(t)
var escapedNames []string
for i := 0; i < info.NumFields(); i++ {
name := info.ByIndex(i).Name
escapedNames = append(escapedNames, dialect.Escape(name))
}
query := strings.Join(escapedNames, ", ")
cachedSelectQueries[t] = query
return query, nil
}

View File

@ -11,13 +11,12 @@ import (
type User struct {
Name string `ksql:"name"`
Age string `ksql:"age"`
Age int `ksql:"age"`
}
var nullField *int
func TestBuilder(t *testing.T) {
func TestSelectQuery(t *testing.T) {
tests := []struct {
desc string
query kbuilder.Query