diff --git a/kbuilder/insert.go b/kbuilder/insert.go new file mode 100644 index 0000000..d65c483 --- /dev/null +++ b/kbuilder/insert.go @@ -0,0 +1,105 @@ +package kbuilder + +import ( + "fmt" + "reflect" + "strings" + + "github.com/vingarcia/ksql" + "github.com/vingarcia/ksql/kstructs" +) + +// Insert is the struct template for building INSERT queries +type Insert struct { + // Into expects a table name, e.g. "users" + Into string + + // Data expected either a single record annotated with `ksql` tags + // or a list of records annotated likewise. + Data interface{} +} + +// Build is a utility function for finding the dialect based on the driver and +// then calling BuildQuery(dialect) +func (i Insert) Build(driver string) (sqlQuery string, params []interface{}, _ error) { + dialect, err := ksql.GetDriverDialect(driver) + if err != nil { + return "", nil, err + } + + return i.BuildQuery(dialect) +} + +// BuildQuery implements the queryBuilder interface +func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) { + var b strings.Builder + b.WriteString("INSERT INTO " + dialect.Escape(i.Into)) + + if i.Into == "" { + return "", nil, fmt.Errorf( + "expected the Into attr to contain the tablename, but got an empty string instead", + ) + } + + if i.Data == nil { + return "", nil, fmt.Errorf( + "expected the Data attr to contain a struct or a list of structs, but got `%v`", + i.Data, + ) + } + + v := reflect.ValueOf(i.Data) + t := v.Type() + if t.Kind() != reflect.Slice { + // Convert it to a slice of a single element: + v = reflect.Append(reflect.MakeSlice(reflect.SliceOf(t), 0, 1), v) + } else { + t = t.Elem() + } + + if v.Len() == 0 { + return "", nil, fmt.Errorf( + "can't create an insertion query from an empty list of values", + ) + } + + isPtr := false + if t.Kind() == reflect.Ptr { + isPtr = true + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return "", nil, fmt.Errorf("expected Data attr to be a struct or slice of structs but got: %v", t) + } + + info := kstructs.GetTagInfo(t) + + b.WriteString(" (") + var escapedNames []string + for i := 0; i < info.NumFields(); i++ { + name := info.ByIndex(i).Name + escapedNames = append(escapedNames, dialect.Escape(name)) + } + b.WriteString(strings.Join(escapedNames, ", ")) + b.WriteString(") VALUES ") + + params = []interface{}{} + values := []string{} + for i := 0; i < v.Len(); i++ { + record := v.Index(i) + if isPtr { + record = record.Elem() + } + + placeholders := []string{} + for j := 0; j < info.NumFields(); j++ { + placeholders = append(placeholders, dialect.Placeholder(len(params))) + params = append(params, record.Field(j).Interface()) + } + values = append(values, "("+strings.Join(placeholders, ", ")+")") + } + b.WriteString(strings.Join(values, ", ")) + + return b.String(), params, nil +} diff --git a/kbuilder/insert_test.go b/kbuilder/insert_test.go new file mode 100644 index 0000000..745e8bf --- /dev/null +++ b/kbuilder/insert_test.go @@ -0,0 +1,92 @@ +package kbuilder_test + +import ( + "testing" + + "github.com/tj/assert" + "github.com/vingarcia/ksql/kbuilder" +) + +func TestInsertQuery(t *testing.T) { + tests := []struct { + desc string + query kbuilder.Insert + expectedQuery string + expectedParams []interface{} + expectedErr bool + }{ + { + desc: "should build queries witha single record correctly", + query: kbuilder.Insert{ + Into: "users", + Data: &User{ + Name: "foo", + Age: 42, + }, + }, + expectedQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2)`, + expectedParams: []interface{}{"foo", 42}, + }, + { + desc: "should build queries with multiple records correctly", + query: kbuilder.Insert{ + Into: "users", + Data: []User{ + { + Name: "foo", + Age: 42, + }, + { + Name: "bar", + Age: 43, + }, + }, + }, + expectedQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2), ($3, $4)`, + expectedParams: []interface{}{"foo", 42, "bar", 43}, + }, + + /* * * * * Testing error cases: * * * * */ + { + desc: "should report error if the `Data` attribute is missing", + query: kbuilder.Insert{ + Into: "users", + }, + + expectedErr: true, + }, + { + desc: "should report error if the `Into` attribute is missing", + query: kbuilder.Insert{ + Data: &User{ + Name: "foo", + Age: 42, + }, + }, + + expectedErr: true, + }, + { + desc: "should report error Data contains an empty list", + query: kbuilder.Insert{ + Into: "users", + Data: []User{}, + }, + + expectedErr: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + b, err := kbuilder.New("postgres") + assert.Equal(t, nil, err) + + query, params, err := b.Build(test.query) + + expectError(t, test.expectedErr, err) + assert.Equal(t, test.expectedQuery, query) + assert.Equal(t, test.expectedParams, params) + }) + } +} diff --git a/kbuilder/kbuilder.go b/kbuilder/kbuilder.go index ed188e4..0cf84de 100644 --- a/kbuilder/kbuilder.go +++ b/kbuilder/kbuilder.go @@ -1,14 +1,7 @@ package kbuilder import ( - "fmt" - "reflect" - "strconv" - "strings" - - "github.com/pkg/errors" "github.com/vingarcia/ksql" - "github.com/vingarcia/ksql/kstructs" ) // Builder is the basic container for injecting @@ -38,209 +31,3 @@ func New(driver string) (Builder, error) { func (builder *Builder) Build(query queryBuilder) (sqlQuery string, params []interface{}, _ error) { return query.BuildQuery(builder.dialect) } - -// Query is is the struct template for building SELECT queries. -type Query struct { - // Select expects either a struct using the `ksql` tags - // or a string listing the column names using SQL syntax, - // e.g.: `id, username, address` - Select interface{} - - // From expects the FROM clause from an SQL query, e.g. `users JOIN posts USING(post_id)` - From string - - // Where expects a list of WhereQuery instances built - // by the public Where() function. - Where WhereQueries - - Limit int - Offset int - OrderBy OrderByQuery -} - -// Build is a utility function for finding the dialect based on the driver and -// then calling BuildQuery(dialect) -func (q Query) Build(driver string) (sqlQuery string, params []interface{}, _ error) { - dialect, err := ksql.GetDriverDialect(driver) - if err != nil { - return "", nil, err - } - - return q.BuildQuery(dialect) -} - -// BuildQuery implements the QueryBuilder interface -func (q Query) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) { - var b strings.Builder - - switch v := q.Select.(type) { - case string: - b.WriteString("SELECT " + v) - default: - selectQuery, err := buildSelectQuery(v, dialect) - if err != nil { - return "", nil, errors.Wrap(err, "error reading the Select field") - } - b.WriteString("SELECT " + selectQuery) - } - - b.WriteString(" FROM " + q.From) - - if len(q.Where) > 0 { - var whereQuery string - whereQuery, params = q.Where.build(dialect) - b.WriteString(" WHERE " + whereQuery) - } - - if strings.TrimSpace(q.From) == "" { - return "", nil, fmt.Errorf("the From field is mandatory for every query") - } - - if q.OrderBy.fields != "" { - b.WriteString(" ORDER BY " + q.OrderBy.fields) - if q.OrderBy.desc { - b.WriteString(" DESC") - } - } - - if q.Limit > 0 { - b.WriteString(" LIMIT " + strconv.Itoa(q.Limit)) - } - - if q.Offset > 0 { - b.WriteString(" OFFSET " + strconv.Itoa(q.Offset)) - } - - return b.String(), params, nil -} - -// WhereQuery represents a single condition in a WHERE expression. -type WhereQuery struct { - // Accepts any SQL boolean expression - // This expression may optionally contain - // string formatting directives %s and only %s. - // - // For each of these directives we expect a new param - // on the params list below. - // - // In the resulting query each %s will be properly replaced - // by placeholders according to the database driver, e.g. `$1` - // for postgres or `?` for sqlite3. - cond string - params []interface{} -} - -// WhereQueries is the helper for creating complex WHERE queries -// in a dynamic way. -type WhereQueries []WhereQuery - -func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interface{}) { - var conds []string - for _, whereQuery := range w { - var placeholders []interface{} - for i := range whereQuery.params { - placeholders = append(placeholders, dialect.Placeholder(len(params)+i)) - } - - conds = append(conds, fmt.Sprintf(whereQuery.cond, placeholders...)) - params = append(params, whereQuery.params...) - } - - return strings.Join(conds, " AND "), params -} - -// Where adds a new bollean condition to an existing -// WhereQueries helper. -func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries { - return append(w, WhereQuery{ - cond: cond, - params: params, - }) -} - -// WhereIf condionally adds a new boolean expression to the WhereQueries helper. -func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries { - if param == nil || reflect.ValueOf(param).IsNil() { - return w - } - - return append(w, WhereQuery{ - cond: cond, - params: []interface{}{param}, - }) -} - -// Where adds a new bollean condition to an existing -// WhereQueries helper. -func Where(cond string, params ...interface{}) WhereQueries { - return WhereQueries{{ - cond: cond, - params: params, - }} -} - -// WhereIf condionally adds a new boolean expression to the WhereQueries helper -func WhereIf(cond string, param interface{}) WhereQueries { - if param == nil || reflect.ValueOf(param).IsNil() { - return WhereQueries{} - } - - return WhereQueries{{ - cond: cond, - params: []interface{}{param}, - }} -} - -// OrderByQuery represents the ORDER BY part of the query -type OrderByQuery struct { - fields string - desc bool -} - -// Desc is a setter function for configuring the -// ORDER BY part of the query as DESC -func (o OrderByQuery) Desc() OrderByQuery { - return OrderByQuery{ - fields: o.fields, - desc: true, - } -} - -// OrderBy is a helper for building the ORDER BY -// part of the query. -func OrderBy(fields string) OrderByQuery { - return OrderByQuery{ - fields: fields, - desc: false, - } -} - -var cachedSelectQueries = map[reflect.Type]string{} - -// Builds the select query using cached info so that its efficient -func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { - t := reflect.TypeOf(obj) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - - if t.Kind() != reflect.Struct { - return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj) - } - - if query, found := cachedSelectQueries[t]; found { - return query, nil - } - - info := kstructs.GetTagInfo(t) - - var escapedNames []string - for i := 0; i < info.NumFields(); i++ { - name := info.ByIndex(i).Name - escapedNames = append(escapedNames, dialect.Escape(name)) - } - - query := strings.Join(escapedNames, ", ") - cachedSelectQueries[t] = query - return query, nil -} diff --git a/kbuilder/query.go b/kbuilder/query.go new file mode 100644 index 0000000..2fc4d28 --- /dev/null +++ b/kbuilder/query.go @@ -0,0 +1,218 @@ +package kbuilder + +import ( + "fmt" + "reflect" + "strconv" + "strings" + + "github.com/pkg/errors" + "github.com/vingarcia/ksql" + "github.com/vingarcia/ksql/kstructs" +) + +// Query is is the struct template for building SELECT queries. +type Query struct { + // Select expects either a struct using the `ksql` tags + // or a string listing the column names using SQL syntax, + // e.g.: `id, username, address` + Select interface{} + + // From expects the FROM clause from an SQL query, e.g. `users JOIN posts USING(post_id)` + From string + + // Where expects a list of WhereQuery instances built + // by the public Where() function. + Where WhereQueries + + Limit int + Offset int + OrderBy OrderByQuery +} + +// Build is a utility function for finding the dialect based on the driver and +// then calling BuildQuery(dialect) +func (q Query) Build(driver string) (sqlQuery string, params []interface{}, _ error) { + dialect, err := ksql.GetDriverDialect(driver) + if err != nil { + return "", nil, err + } + + return q.BuildQuery(dialect) +} + +// BuildQuery implements the queryBuilder interface +func (q Query) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) { + var b strings.Builder + + switch v := q.Select.(type) { + case string: + b.WriteString("SELECT " + v) + default: + selectQuery, err := buildSelectQuery(v, dialect) + if err != nil { + return "", nil, errors.Wrap(err, "error reading the Select field") + } + b.WriteString("SELECT " + selectQuery) + } + + b.WriteString(" FROM " + q.From) + + if len(q.Where) > 0 { + var whereQuery string + whereQuery, params = q.Where.build(dialect) + b.WriteString(" WHERE " + whereQuery) + } + + if strings.TrimSpace(q.From) == "" { + return "", nil, fmt.Errorf("the From field is mandatory for every query") + } + + if q.OrderBy.fields != "" { + b.WriteString(" ORDER BY " + q.OrderBy.fields) + if q.OrderBy.desc { + b.WriteString(" DESC") + } + } + + if q.Limit > 0 { + b.WriteString(" LIMIT " + strconv.Itoa(q.Limit)) + } + + if q.Offset > 0 { + b.WriteString(" OFFSET " + strconv.Itoa(q.Offset)) + } + + return b.String(), params, nil +} + +// WhereQuery represents a single condition in a WHERE expression. +type WhereQuery struct { + // Accepts any SQL boolean expression + // This expression may optionally contain + // string formatting directives %s and only %s. + // + // For each of these directives we expect a new param + // on the params list below. + // + // In the resulting query each %s will be properly replaced + // by placeholders according to the database driver, e.g. `$1` + // for postgres or `?` for sqlite3. + cond string + params []interface{} +} + +// WhereQueries is the helper for creating complex WHERE queries +// in a dynamic way. +type WhereQueries []WhereQuery + +func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interface{}) { + var conds []string + for _, whereQuery := range w { + var placeholders []interface{} + for i := range whereQuery.params { + placeholders = append(placeholders, dialect.Placeholder(len(params)+i)) + } + + conds = append(conds, fmt.Sprintf(whereQuery.cond, placeholders...)) + params = append(params, whereQuery.params...) + } + + return strings.Join(conds, " AND "), params +} + +// Where adds a new bollean condition to an existing +// WhereQueries helper. +func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries { + return append(w, WhereQuery{ + cond: cond, + params: params, + }) +} + +// WhereIf condionally adds a new boolean expression to the WhereQueries helper. +func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries { + if param == nil || reflect.ValueOf(param).IsNil() { + return w + } + + return append(w, WhereQuery{ + cond: cond, + params: []interface{}{param}, + }) +} + +// Where adds a new bollean condition to an existing +// WhereQueries helper. +func Where(cond string, params ...interface{}) WhereQueries { + return WhereQueries{{ + cond: cond, + params: params, + }} +} + +// WhereIf condionally adds a new boolean expression to the WhereQueries helper +func WhereIf(cond string, param interface{}) WhereQueries { + if param == nil || reflect.ValueOf(param).IsNil() { + return WhereQueries{} + } + + return WhereQueries{{ + cond: cond, + params: []interface{}{param}, + }} +} + +// OrderByQuery represents the ORDER BY part of the query +type OrderByQuery struct { + fields string + desc bool +} + +// Desc is a setter function for configuring the +// ORDER BY part of the query as DESC +func (o OrderByQuery) Desc() OrderByQuery { + return OrderByQuery{ + fields: o.fields, + desc: true, + } +} + +// OrderBy is a helper for building the ORDER BY +// part of the query. +func OrderBy(fields string) OrderByQuery { + return OrderByQuery{ + fields: fields, + desc: false, + } +} + +var cachedSelectQueries = map[reflect.Type]string{} + +// Builds the select query using cached info so that its efficient +func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { + t := reflect.TypeOf(obj) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj) + } + + if query, found := cachedSelectQueries[t]; found { + return query, nil + } + + info := kstructs.GetTagInfo(t) + + var escapedNames []string + for i := 0; i < info.NumFields(); i++ { + name := info.ByIndex(i).Name + escapedNames = append(escapedNames, dialect.Escape(name)) + } + + query := strings.Join(escapedNames, ", ") + cachedSelectQueries[t] = query + return query, nil +} diff --git a/kbuilder/kbuilder_test.go b/kbuilder/query_test.go similarity index 98% rename from kbuilder/kbuilder_test.go rename to kbuilder/query_test.go index bbf4503..d737d3c 100644 --- a/kbuilder/kbuilder_test.go +++ b/kbuilder/query_test.go @@ -11,13 +11,12 @@ import ( type User struct { Name string `ksql:"name"` - Age string `ksql:"age"` + Age int `ksql:"age"` } var nullField *int -func TestBuilder(t *testing.T) { - +func TestSelectQuery(t *testing.T) { tests := []struct { desc string query kbuilder.Query