diff --git a/README.md b/README.md index 2fe2de1..f791f58 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ The current interface is as follows and we plan on keeping it with as little functions as possible, so don't expect many additions: ```go -// Provider describes the public behavior of this ORM +// Provider describes the ksql public behavior type Provider interface { Insert(ctx context.Context, table Table, record interface{}) error Update(ctx context.Context, table Table, record interface{}) error diff --git a/contracts.go b/contracts.go index 8a9866f..3325bce 100644 --- a/contracts.go +++ b/contracts.go @@ -14,7 +14,7 @@ var ErrRecordNotFound error = errors.Wrap(sql.ErrNoRows, "ksql: the query return // ErrAbortIteration ... var ErrAbortIteration error = fmt.Errorf("ksql: abort iteration, should only be used inside QueryChunks function") -// Provider describes the public behavior of this ORM +// Provider describes the ksql public behavior type Provider interface { Insert(ctx context.Context, table Table, record interface{}) error Update(ctx context.Context, table Table, record interface{}) error @@ -69,7 +69,7 @@ func NewTable(tableName string, ids ...string) Table { } } -func (t Table) insertMethodFor(dialect dialect) insertMethod { +func (t Table) insertMethodFor(dialect Dialect) insertMethod { if len(t.idColumns) == 1 { return dialect.InsertMethod() } diff --git a/dialect.go b/dialect.go index d245da7..655c29a 100644 --- a/dialect.go +++ b/dialect.go @@ -1,6 +1,9 @@ package ksql -import "strconv" +import ( + "fmt" + "strconv" +) type insertMethod int @@ -11,14 +14,16 @@ const ( insertWithNoIDRetrieval ) -var supportedDialects = map[string]dialect{ +var supportedDialects = map[string]Dialect{ "postgres": &postgresDialect{}, "sqlite3": &sqlite3Dialect{}, "mysql": &mysqlDialect{}, "sqlserver": &sqlserverDialect{}, } -type dialect interface { +// Dialect is used to represent the different ways +// of writing SQL queries used by each SQL driver. +type Dialect interface { InsertMethod() insertMethod Escape(str string) string Placeholder(idx int) string @@ -61,6 +66,21 @@ func (sqlite3Dialect) Placeholder(idx int) string { return "?" } +// GetDriverDialect instantiantes the dialect for the +// provided driver string, if the drive is not supported +// it returns an error +func GetDriverDialect(driver string) (Dialect, error) { + dialect, found := map[string]Dialect{ + "postgres": &postgresDialect{}, + "sqlite3": &sqlite3Dialect{}, + }[driver] + if !found { + return nil, fmt.Errorf("unsupported driver `%s`", driver) + } + + return dialect, nil +} + type mysqlDialect struct{} func (mysqlDialect) DriverName() string { diff --git a/go.mod b/go.mod index 4fa2edf..e53b423 100644 --- a/go.mod +++ b/go.mod @@ -13,6 +13,7 @@ require ( github.com/lib/pq v1.10.2 github.com/mattn/go-sqlite3 v1.14.6 github.com/pkg/errors v0.9.1 + github.com/stretchr/testify v1.7.0 // indirect github.com/tj/assert v0.0.3 google.golang.org/appengine v1.6.7 // indirect ) diff --git a/kbuilder/README.md b/kbuilder/README.md new file mode 100644 index 0000000..de6385d --- /dev/null +++ b/kbuilder/README.md @@ -0,0 +1,14 @@ +# Welcome to the KISS Query Builder + +This is the Keep It Stupid Simple query builder created to work +either in conjunction or separated from the ksql package. + +This package was started after ksql and while the ksql is already +in a usable state I still don't recommend using this one since this +being actively implemented and might change without further warning. + +## TODO List + +- Add support to Update and Delete operations +- Improve support to JOINs by adding the `tablename` tag to the structs +- Add error check for when the Select, Insert and Update attrs are all empty diff --git a/kbuilder/insert.go b/kbuilder/insert.go new file mode 100644 index 0000000..d65c483 --- /dev/null +++ b/kbuilder/insert.go @@ -0,0 +1,105 @@ +package kbuilder + +import ( + "fmt" + "reflect" + "strings" + + "github.com/vingarcia/ksql" + "github.com/vingarcia/ksql/kstructs" +) + +// Insert is the struct template for building INSERT queries +type Insert struct { + // Into expects a table name, e.g. "users" + Into string + + // Data expected either a single record annotated with `ksql` tags + // or a list of records annotated likewise. + Data interface{} +} + +// Build is a utility function for finding the dialect based on the driver and +// then calling BuildQuery(dialect) +func (i Insert) Build(driver string) (sqlQuery string, params []interface{}, _ error) { + dialect, err := ksql.GetDriverDialect(driver) + if err != nil { + return "", nil, err + } + + return i.BuildQuery(dialect) +} + +// BuildQuery implements the queryBuilder interface +func (i Insert) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) { + var b strings.Builder + b.WriteString("INSERT INTO " + dialect.Escape(i.Into)) + + if i.Into == "" { + return "", nil, fmt.Errorf( + "expected the Into attr to contain the tablename, but got an empty string instead", + ) + } + + if i.Data == nil { + return "", nil, fmt.Errorf( + "expected the Data attr to contain a struct or a list of structs, but got `%v`", + i.Data, + ) + } + + v := reflect.ValueOf(i.Data) + t := v.Type() + if t.Kind() != reflect.Slice { + // Convert it to a slice of a single element: + v = reflect.Append(reflect.MakeSlice(reflect.SliceOf(t), 0, 1), v) + } else { + t = t.Elem() + } + + if v.Len() == 0 { + return "", nil, fmt.Errorf( + "can't create an insertion query from an empty list of values", + ) + } + + isPtr := false + if t.Kind() == reflect.Ptr { + isPtr = true + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return "", nil, fmt.Errorf("expected Data attr to be a struct or slice of structs but got: %v", t) + } + + info := kstructs.GetTagInfo(t) + + b.WriteString(" (") + var escapedNames []string + for i := 0; i < info.NumFields(); i++ { + name := info.ByIndex(i).Name + escapedNames = append(escapedNames, dialect.Escape(name)) + } + b.WriteString(strings.Join(escapedNames, ", ")) + b.WriteString(") VALUES ") + + params = []interface{}{} + values := []string{} + for i := 0; i < v.Len(); i++ { + record := v.Index(i) + if isPtr { + record = record.Elem() + } + + placeholders := []string{} + for j := 0; j < info.NumFields(); j++ { + placeholders = append(placeholders, dialect.Placeholder(len(params))) + params = append(params, record.Field(j).Interface()) + } + values = append(values, "("+strings.Join(placeholders, ", ")+")") + } + b.WriteString(strings.Join(values, ", ")) + + return b.String(), params, nil +} diff --git a/kbuilder/insert_test.go b/kbuilder/insert_test.go new file mode 100644 index 0000000..49f55a5 --- /dev/null +++ b/kbuilder/insert_test.go @@ -0,0 +1,92 @@ +package kbuilder_test + +import ( + "testing" + + "github.com/tj/assert" + "github.com/vingarcia/ksql/kbuilder" +) + +func TestInsertQuery(t *testing.T) { + tests := []struct { + desc string + query kbuilder.Insert + expectedQuery string + expectedParams []interface{} + expectedErr bool + }{ + { + desc: "should build queries witha single record correctly", + query: kbuilder.Insert{ + Into: "users", + Data: &User{ + Name: "foo", + Age: 42, + }, + }, + expectedQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2)`, + expectedParams: []interface{}{"foo", 42}, + }, + { + desc: "should build queries with multiple records correctly", + query: kbuilder.Insert{ + Into: "users", + Data: []User{ + { + Name: "foo", + Age: 42, + }, + { + Name: "bar", + Age: 43, + }, + }, + }, + expectedQuery: `INSERT INTO "users" ("name", "age") VALUES ($1, $2), ($3, $4)`, + expectedParams: []interface{}{"foo", 42, "bar", 43}, + }, + + /* * * * * Testing error cases: * * * * */ + { + desc: "should report error if the `Data` attribute is missing", + query: kbuilder.Insert{ + Into: "users", + }, + + expectedErr: true, + }, + { + desc: "should report error if the `Into` attribute is missing", + query: kbuilder.Insert{ + Data: &User{ + Name: "foo", + Age: 42, + }, + }, + + expectedErr: true, + }, + { + desc: "should report error if `Data` contains an empty list", + query: kbuilder.Insert{ + Into: "users", + Data: []User{}, + }, + + expectedErr: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + b, err := kbuilder.New("postgres") + assert.Equal(t, nil, err) + + query, params, err := b.Build(test.query) + + expectError(t, test.expectedErr, err) + assert.Equal(t, test.expectedQuery, query) + assert.Equal(t, test.expectedParams, params) + }) + } +} diff --git a/kbuilder/kbuilder.go b/kbuilder/kbuilder.go new file mode 100644 index 0000000..0cf84de --- /dev/null +++ b/kbuilder/kbuilder.go @@ -0,0 +1,33 @@ +package kbuilder + +import ( + "github.com/vingarcia/ksql" +) + +// Builder is the basic container for injecting +// query builder configurations. +// +// All the Query structs can also be called +// directly without this builder, but we kept it +// here for convenience. +type Builder struct { + dialect ksql.Dialect +} + +type queryBuilder interface { + BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) +} + +// New creates a new Builder container. +func New(driver string) (Builder, error) { + dialect, err := ksql.GetDriverDialect(driver) + return Builder{ + dialect: dialect, + }, err +} + +// Build receives a query builder struct, injects it with the configurations +// build the query according to its arguments. +func (builder *Builder) Build(query queryBuilder) (sqlQuery string, params []interface{}, _ error) { + return query.BuildQuery(builder.dialect) +} diff --git a/kbuilder/query.go b/kbuilder/query.go new file mode 100644 index 0000000..2fc4d28 --- /dev/null +++ b/kbuilder/query.go @@ -0,0 +1,218 @@ +package kbuilder + +import ( + "fmt" + "reflect" + "strconv" + "strings" + + "github.com/pkg/errors" + "github.com/vingarcia/ksql" + "github.com/vingarcia/ksql/kstructs" +) + +// Query is is the struct template for building SELECT queries. +type Query struct { + // Select expects either a struct using the `ksql` tags + // or a string listing the column names using SQL syntax, + // e.g.: `id, username, address` + Select interface{} + + // From expects the FROM clause from an SQL query, e.g. `users JOIN posts USING(post_id)` + From string + + // Where expects a list of WhereQuery instances built + // by the public Where() function. + Where WhereQueries + + Limit int + Offset int + OrderBy OrderByQuery +} + +// Build is a utility function for finding the dialect based on the driver and +// then calling BuildQuery(dialect) +func (q Query) Build(driver string) (sqlQuery string, params []interface{}, _ error) { + dialect, err := ksql.GetDriverDialect(driver) + if err != nil { + return "", nil, err + } + + return q.BuildQuery(dialect) +} + +// BuildQuery implements the queryBuilder interface +func (q Query) BuildQuery(dialect ksql.Dialect) (sqlQuery string, params []interface{}, _ error) { + var b strings.Builder + + switch v := q.Select.(type) { + case string: + b.WriteString("SELECT " + v) + default: + selectQuery, err := buildSelectQuery(v, dialect) + if err != nil { + return "", nil, errors.Wrap(err, "error reading the Select field") + } + b.WriteString("SELECT " + selectQuery) + } + + b.WriteString(" FROM " + q.From) + + if len(q.Where) > 0 { + var whereQuery string + whereQuery, params = q.Where.build(dialect) + b.WriteString(" WHERE " + whereQuery) + } + + if strings.TrimSpace(q.From) == "" { + return "", nil, fmt.Errorf("the From field is mandatory for every query") + } + + if q.OrderBy.fields != "" { + b.WriteString(" ORDER BY " + q.OrderBy.fields) + if q.OrderBy.desc { + b.WriteString(" DESC") + } + } + + if q.Limit > 0 { + b.WriteString(" LIMIT " + strconv.Itoa(q.Limit)) + } + + if q.Offset > 0 { + b.WriteString(" OFFSET " + strconv.Itoa(q.Offset)) + } + + return b.String(), params, nil +} + +// WhereQuery represents a single condition in a WHERE expression. +type WhereQuery struct { + // Accepts any SQL boolean expression + // This expression may optionally contain + // string formatting directives %s and only %s. + // + // For each of these directives we expect a new param + // on the params list below. + // + // In the resulting query each %s will be properly replaced + // by placeholders according to the database driver, e.g. `$1` + // for postgres or `?` for sqlite3. + cond string + params []interface{} +} + +// WhereQueries is the helper for creating complex WHERE queries +// in a dynamic way. +type WhereQueries []WhereQuery + +func (w WhereQueries) build(dialect ksql.Dialect) (query string, params []interface{}) { + var conds []string + for _, whereQuery := range w { + var placeholders []interface{} + for i := range whereQuery.params { + placeholders = append(placeholders, dialect.Placeholder(len(params)+i)) + } + + conds = append(conds, fmt.Sprintf(whereQuery.cond, placeholders...)) + params = append(params, whereQuery.params...) + } + + return strings.Join(conds, " AND "), params +} + +// Where adds a new bollean condition to an existing +// WhereQueries helper. +func (w WhereQueries) Where(cond string, params ...interface{}) WhereQueries { + return append(w, WhereQuery{ + cond: cond, + params: params, + }) +} + +// WhereIf condionally adds a new boolean expression to the WhereQueries helper. +func (w WhereQueries) WhereIf(cond string, param interface{}) WhereQueries { + if param == nil || reflect.ValueOf(param).IsNil() { + return w + } + + return append(w, WhereQuery{ + cond: cond, + params: []interface{}{param}, + }) +} + +// Where adds a new bollean condition to an existing +// WhereQueries helper. +func Where(cond string, params ...interface{}) WhereQueries { + return WhereQueries{{ + cond: cond, + params: params, + }} +} + +// WhereIf condionally adds a new boolean expression to the WhereQueries helper +func WhereIf(cond string, param interface{}) WhereQueries { + if param == nil || reflect.ValueOf(param).IsNil() { + return WhereQueries{} + } + + return WhereQueries{{ + cond: cond, + params: []interface{}{param}, + }} +} + +// OrderByQuery represents the ORDER BY part of the query +type OrderByQuery struct { + fields string + desc bool +} + +// Desc is a setter function for configuring the +// ORDER BY part of the query as DESC +func (o OrderByQuery) Desc() OrderByQuery { + return OrderByQuery{ + fields: o.fields, + desc: true, + } +} + +// OrderBy is a helper for building the ORDER BY +// part of the query. +func OrderBy(fields string) OrderByQuery { + return OrderByQuery{ + fields: fields, + desc: false, + } +} + +var cachedSelectQueries = map[reflect.Type]string{} + +// Builds the select query using cached info so that its efficient +func buildSelectQuery(obj interface{}, dialect ksql.Dialect) (string, error) { + t := reflect.TypeOf(obj) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return "", fmt.Errorf("expected to receive a pointer to struct, but got: %T", obj) + } + + if query, found := cachedSelectQueries[t]; found { + return query, nil + } + + info := kstructs.GetTagInfo(t) + + var escapedNames []string + for i := 0; i < info.NumFields(); i++ { + name := info.ByIndex(i).Name + escapedNames = append(escapedNames, dialect.Escape(name)) + } + + query := strings.Join(escapedNames, ", ") + cachedSelectQueries[t] = query + return query, nil +} diff --git a/kbuilder/query_test.go b/kbuilder/query_test.go new file mode 100644 index 0000000..d737d3c --- /dev/null +++ b/kbuilder/query_test.go @@ -0,0 +1,144 @@ +package kbuilder_test + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tj/assert" + "github.com/vingarcia/ksql/kbuilder" +) + +type User struct { + Name string `ksql:"name"` + Age int `ksql:"age"` +} + +var nullField *int + +func TestSelectQuery(t *testing.T) { + tests := []struct { + desc string + query kbuilder.Query + expectedQuery string + expectedParams []interface{} + expectedErr bool + }{ + { + desc: "should build queries correctly", + query: kbuilder.Query{ + Select: &User{}, + From: "users", + Where: kbuilder. + Where("foo < %s", 42). + Where("bar LIKE %s", "%ending"). + WhereIf("foobar = %s", nullField), + + OrderBy: kbuilder.OrderBy("id").Desc(), + Offset: 100, + Limit: 10, + }, + expectedQuery: `SELECT "name", "age" FROM users WHERE foo < $1 AND bar LIKE $2 ORDER BY id DESC LIMIT 10 OFFSET 100`, + expectedParams: []interface{}{42, "%ending"}, + }, + { + desc: "should build queries omitting the OFFSET", + query: kbuilder.Query{ + Select: &User{}, + From: "users", + Where: kbuilder. + Where("foo < %s", 42). + Where("bar LIKE %s", "%ending"). + WhereIf("foobar = %s", nullField), + + OrderBy: kbuilder.OrderBy("id").Desc(), + Limit: 10, + }, + expectedQuery: `SELECT "name", "age" FROM users WHERE foo < $1 AND bar LIKE $2 ORDER BY id DESC LIMIT 10`, + expectedParams: []interface{}{42, "%ending"}, + }, + { + desc: "should build queries omitting the LIMIT", + query: kbuilder.Query{ + Select: &User{}, + From: "users", + Where: kbuilder. + Where("foo < %s", 42). + Where("bar LIKE %s", "%ending"). + WhereIf("foobar = %s", nullField), + + OrderBy: kbuilder.OrderBy("id").Desc(), + Offset: 100, + }, + expectedQuery: `SELECT "name", "age" FROM users WHERE foo < $1 AND bar LIKE $2 ORDER BY id DESC OFFSET 100`, + expectedParams: []interface{}{42, "%ending"}, + }, + { + desc: "should build queries omitting the ORDER BY clause", + query: kbuilder.Query{ + Select: &User{}, + From: "users", + Where: kbuilder. + Where("foo < %s", 42). + Where("bar LIKE %s", "%ending"). + WhereIf("foobar = %s", nullField), + + Offset: 100, + Limit: 10, + }, + expectedQuery: `SELECT "name", "age" FROM users WHERE foo < $1 AND bar LIKE $2 LIMIT 10 OFFSET 100`, + expectedParams: []interface{}{42, "%ending"}, + }, + { + desc: "should build queries omitting the WHERE clause", + query: kbuilder.Query{ + Select: &User{}, + From: "users", + + OrderBy: kbuilder.OrderBy("id").Desc(), + Offset: 100, + Limit: 10, + }, + expectedQuery: `SELECT "name", "age" FROM users ORDER BY id DESC LIMIT 10 OFFSET 100`, + }, + + /* * * * * Testing error cases: * * * * */ + { + desc: "should report error if the FROM clause is missing", + query: kbuilder.Query{ + Select: &User{}, + Where: kbuilder. + Where("foo < %s", 42). + Where("bar LIKE %s", "%ending"). + WhereIf("foobar = %s", nullField), + + OrderBy: kbuilder.OrderBy("id").Desc(), + Offset: 100, + Limit: 10, + }, + + expectedErr: true, + }, + } + + for _, test := range tests { + t.Run(test.desc, func(t *testing.T) { + b, err := kbuilder.New("postgres") + assert.Equal(t, nil, err) + + query, params, err := b.Build(test.query) + + expectError(t, test.expectedErr, err) + assert.Equal(t, test.expectedQuery, query) + assert.Equal(t, test.expectedParams, params) + }) + } +} + +func expectError(t *testing.T, expect bool, err error) { + if expect { + require.Equal(t, true, err != nil, "expected an error, but got nothing") + } else { + require.Equal(t, false, err != nil, fmt.Sprintf("unexpected error %s", err)) + } +} diff --git a/ksql.go b/ksql.go index eafeefe..a0e127a 100644 --- a/ksql.go +++ b/ksql.go @@ -26,7 +26,7 @@ func init() { // the KissSQL interface `ksql.Provider`. type DB struct { driver string - dialect dialect + dialect Dialect db DBAdapter } @@ -632,7 +632,7 @@ func (c DB) Update( } func buildInsertQuery( - dialect dialect, + dialect Dialect, tableName string, record interface{}, idNames ...string, @@ -736,7 +736,7 @@ func buildInsertQuery( } func buildUpdateQuery( - dialect dialect, + dialect Dialect, tableName string, record interface{}, idFieldNames ...string, @@ -857,7 +857,7 @@ func (nopScanner) Scan(value interface{}) error { return nil } -func scanRows(dialect dialect, rows Rows, record interface{}) error { +func scanRows(dialect Dialect, rows Rows, record interface{}) error { v := reflect.ValueOf(record) t := v.Type() if t.Kind() != reflect.Ptr { @@ -892,7 +892,7 @@ func scanRows(dialect dialect, rows Rows, record interface{}) error { return rows.Scan(scanArgs...) } -func getScanArgsForNestedStructs(dialect dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { +func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info kstructs.StructInfo) []interface{} { scanArgs := []interface{}{} for i := 0; i < v.NumField(); i++ { // TODO(vingarcia00): Handle case where type is pointer @@ -919,7 +919,7 @@ func getScanArgsForNestedStructs(dialect dialect, rows Rows, t reflect.Type, v r return scanArgs } -func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info kstructs.StructInfo) []interface{} { +func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info kstructs.StructInfo) []interface{} { scanArgs := []interface{}{} for _, name := range names { fieldInfo := info.ByName(name) @@ -942,7 +942,7 @@ func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info } func buildSingleKeyDeleteQuery( - dialect dialect, + dialect Dialect, table string, idName string, idMaps []map[string]interface{}, @@ -962,7 +962,7 @@ func buildSingleKeyDeleteQuery( } func buildCompositeKeyDeleteQuery( - dialect dialect, + dialect Dialect, table string, idNames []string, idMaps []map[string]interface{}, @@ -1007,7 +1007,7 @@ func getFirstToken(s string) string { } func buildSelectQuery( - dialect dialect, + dialect Dialect, structType reflect.Type, info kstructs.StructInfo, selectQueryCache map[reflect.Type]string, @@ -1030,7 +1030,7 @@ func buildSelectQuery( } func buildSelectQueryForPlainStructs( - dialect dialect, + dialect Dialect, structType reflect.Type, info kstructs.StructInfo, ) string { @@ -1043,7 +1043,7 @@ func buildSelectQueryForPlainStructs( } func buildSelectQueryForNestedStructs( - dialect dialect, + dialect Dialect, structType reflect.Type, info kstructs.StructInfo, ) (string, error) { diff --git a/ksql_test.go b/ksql_test.go index 18df87c..0e9b739 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -2022,7 +2022,7 @@ func shiftErrSlice(errs *[]error) error { return err } -func getUsersByID(db DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uint) error { +func getUsersByID(db DBAdapter, dialect Dialect, resultsPtr *[]User, ids ...uint) error { placeholders := make([]string, len(ids)) params := make([]interface{}, len(ids)) for i := range ids { @@ -2063,7 +2063,7 @@ func getUsersByID(db DBAdapter, dialect dialect, resultsPtr *[]User, ids ...uint return nil } -func getUserByID(db DBAdapter, dialect dialect, result *User, id uint) error { +func getUserByID(db DBAdapter, dialect Dialect, result *User, id uint) error { rows, err := db.QueryContext(context.TODO(), `SELECT id, name, age, address FROM users WHERE id=`+dialect.Placeholder(0), id) if err != nil { return err diff --git a/kstructs/structs.go b/kstructs/structs.go index 0232862..80959c0 100644 --- a/kstructs/structs.go +++ b/kstructs/structs.go @@ -51,6 +51,11 @@ func (s StructInfo) add(field FieldInfo) { s.byName[field.Name] = &field } +// NumFields ... +func (s StructInfo) NumFields() int { + return len(s.byIndex) +} + // This cache is kept as a pkg variable // because the total number of types on a program // should be finite. So keeping a single cache here diff --git a/mocks.go b/mocks.go index 97911cd..e8b251a 100644 --- a/mocks.go +++ b/mocks.go @@ -1,6 +1,9 @@ package ksql -import "context" +import ( + "context" + "fmt" +) var _ Provider = Mock{} @@ -20,40 +23,64 @@ type Mock struct { // Insert ... func (m Mock) Insert(ctx context.Context, table Table, record interface{}) error { + if m.InsertFn == nil { + panic(fmt.Errorf("Mock.Insert(ctx, %v, %v) called but the ksql.Mock.InsertFn() is not set", table, record)) + } return m.InsertFn(ctx, table, record) } // Update ... func (m Mock) Update(ctx context.Context, table Table, record interface{}) error { + if m.UpdateFn == nil { + panic(fmt.Errorf("Mock.Update(ctx, %v, %v) called but the ksql.Mock.UpdateFn() is not set", table, record)) + } return m.UpdateFn(ctx, table, record) } // Delete ... func (m Mock) Delete(ctx context.Context, table Table, ids ...interface{}) error { + if m.DeleteFn == nil { + panic(fmt.Errorf("Mock.Delete(ctx, %v, %v) called but the ksql.Mock.DeleteFn() is not set", table, ids)) + } return m.DeleteFn(ctx, table, ids...) } // Query ... func (m Mock) Query(ctx context.Context, records interface{}, query string, params ...interface{}) error { + if m.QueryFn == nil { + panic(fmt.Errorf("Mock.Query(ctx, %v, %s, %v) called but the ksql.Mock.QueryFn() is not set", records, query, params)) + } return m.QueryFn(ctx, records, query, params...) } // QueryOne ... func (m Mock) QueryOne(ctx context.Context, record interface{}, query string, params ...interface{}) error { + if m.QueryOneFn == nil { + panic(fmt.Errorf("Mock.QueryOne(ctx, %v, %s, %v) called but the ksql.Mock.QueryOneFn() is not set", record, query, params)) + } return m.QueryOneFn(ctx, record, query, params...) } // QueryChunks ... func (m Mock) QueryChunks(ctx context.Context, parser ChunkParser) error { + if m.QueryChunksFn == nil { + panic(fmt.Errorf("Mock.QueryChunks(ctx, %v) called but the ksql.Mock.QueryChunksFn() is not set", parser)) + } return m.QueryChunksFn(ctx, parser) } // Exec ... func (m Mock) Exec(ctx context.Context, query string, params ...interface{}) error { + if m.ExecFn == nil { + panic(fmt.Errorf("Mock.Exec(ctx, %s, %v) called but the ksql.Mock.ExecFn() is not set", query, params)) + } return m.ExecFn(ctx, query, params...) } // Transaction ... func (m Mock) Transaction(ctx context.Context, fn func(db Provider) error) error { + if m.TransactionFn == nil { + return fn(m) + } return m.TransactionFn(ctx, fn) }