mirror of https://github.com/VinGarcia/ksql.git
1014 lines
25 KiB
Go
1014 lines
25 KiB
Go
package ksql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
"unicode"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/vingarcia/ksql/structs"
|
|
)
|
|
|
|
var selectQueryCache = map[string]map[reflect.Type]string{}
|
|
|
|
func init() {
|
|
for dname := range supportedDialects {
|
|
selectQueryCache[dname] = map[reflect.Type]string{}
|
|
}
|
|
}
|
|
|
|
// DB represents the ksql client responsible for
|
|
// interfacing with the "database/sql" package implementing
|
|
// the KissSQL interface `SQLProvider`.
|
|
type DB struct {
|
|
driver string
|
|
dialect dialect
|
|
db sqlProvider
|
|
}
|
|
|
|
type sqlProvider interface {
|
|
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
|
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
|
}
|
|
|
|
// Config describes the optional arguments accepted
|
|
// by the ksql.New() function.
|
|
type Config struct {
|
|
// MaxOpenCons defaults to 1 if not set
|
|
MaxOpenConns int
|
|
}
|
|
|
|
// New instantiates a new KissSQL client
|
|
func New(
|
|
dbDriver string,
|
|
connectionString string,
|
|
config Config,
|
|
) (DB, error) {
|
|
db, err := sql.Open(dbDriver, connectionString)
|
|
if err != nil {
|
|
return DB{}, err
|
|
}
|
|
if err = db.Ping(); err != nil {
|
|
return DB{}, err
|
|
}
|
|
|
|
if config.MaxOpenConns == 0 {
|
|
config.MaxOpenConns = 1
|
|
}
|
|
|
|
db.SetMaxOpenConns(config.MaxOpenConns)
|
|
|
|
dialect := supportedDialects[dbDriver]
|
|
if dialect == nil {
|
|
return DB{}, fmt.Errorf("unsupported driver `%s`", dbDriver)
|
|
}
|
|
|
|
return DB{
|
|
dialect: dialect,
|
|
driver: dbDriver,
|
|
db: db,
|
|
}, nil
|
|
}
|
|
|
|
// Query queries several rows from the database,
|
|
// the input should be a slice of structs (or *struct) passed
|
|
// by reference and it will be filled with all the results.
|
|
//
|
|
// Note: it is very important to make sure the query will
|
|
// return a small known number of results, otherwise you risk
|
|
// of overloading the available memory.
|
|
func (c DB) Query(
|
|
ctx context.Context,
|
|
records interface{},
|
|
query string,
|
|
params ...interface{},
|
|
) error {
|
|
slicePtr := reflect.ValueOf(records)
|
|
slicePtrType := slicePtr.Type()
|
|
if slicePtrType.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("ksql: expected to receive a pointer to slice of structs, but got: %T", records)
|
|
}
|
|
sliceType := slicePtrType.Elem()
|
|
slice := slicePtr.Elem()
|
|
structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(sliceType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if isSliceOfPtrs {
|
|
// Truncate the slice so there is no risk
|
|
// of overwritting records that were already saved
|
|
// on the slice:
|
|
slice = slice.Slice(0, 0)
|
|
}
|
|
|
|
info := structs.GetTagInfo(structType)
|
|
|
|
firstToken := strings.ToUpper(getFirstToken(query))
|
|
if info.IsNestedStruct && firstToken == "SELECT" {
|
|
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
|
|
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
|
|
}
|
|
|
|
if firstToken == "FROM" {
|
|
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
query = selectPrefix + query
|
|
}
|
|
|
|
rows, err := c.db.QueryContext(ctx, query, params...)
|
|
if err != nil {
|
|
return fmt.Errorf("error running query: %s", err.Error())
|
|
}
|
|
defer rows.Close()
|
|
|
|
for idx := 0; rows.Next(); idx++ {
|
|
// Allocate new slice elements
|
|
// only if they are not already allocated:
|
|
if slice.Len() <= idx {
|
|
var elemValue reflect.Value
|
|
elemValue = reflect.New(structType)
|
|
if !isSliceOfPtrs {
|
|
elemValue = elemValue.Elem()
|
|
}
|
|
slice = reflect.Append(slice, elemValue)
|
|
}
|
|
|
|
elemPtr := slice.Index(idx).Addr()
|
|
if isSliceOfPtrs {
|
|
// This is necessary since scanRows expects a *record not a **record
|
|
elemPtr = elemPtr.Elem()
|
|
}
|
|
|
|
err = scanRows(c.dialect, rows, elemPtr.Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := rows.Close(); err != nil {
|
|
return err
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
|
|
// Update the original slice passed by reference:
|
|
slicePtr.Elem().Set(slice)
|
|
|
|
return nil
|
|
}
|
|
|
|
// QueryOne queries one instance from the database,
|
|
// the input struct must be passed by reference
|
|
// and the query should return only one result.
|
|
//
|
|
// QueryOne returns a ErrRecordNotFound if
|
|
// the query returns no results.
|
|
func (c DB) QueryOne(
|
|
ctx context.Context,
|
|
record interface{},
|
|
query string,
|
|
params ...interface{},
|
|
) error {
|
|
t := reflect.TypeOf(record)
|
|
if t.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
|
}
|
|
t = t.Elem()
|
|
if t.Kind() != reflect.Struct {
|
|
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
|
}
|
|
|
|
info := structs.GetTagInfo(t)
|
|
|
|
firstToken := strings.ToUpper(getFirstToken(query))
|
|
if info.IsNestedStruct && firstToken == "SELECT" {
|
|
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
|
|
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
|
|
}
|
|
|
|
if firstToken == "FROM" {
|
|
selectPrefix, err := buildSelectQuery(c.dialect, t, info, selectQueryCache[c.dialect.DriverName()])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
query = selectPrefix + query
|
|
}
|
|
|
|
rows, err := c.db.QueryContext(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
return ErrRecordNotFound
|
|
}
|
|
|
|
err = scanRows(c.dialect, rows, record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return rows.Close()
|
|
}
|
|
|
|
// QueryChunks is meant to perform queries that returns
|
|
// more results than would normally fit on memory,
|
|
// for others cases the Query and QueryOne functions are indicated.
|
|
//
|
|
// The ChunkParser argument has 4 attributes:
|
|
// (1) The Query;
|
|
// (2) The query args;
|
|
// (3) The chunk size;
|
|
// (4) A callback function called ForEachChunk, that will be called
|
|
// to process each chunk loaded from the database.
|
|
//
|
|
// Note that the signature of the ForEachChunk callback can be
|
|
// any function that receives a slice of structs or a slice of
|
|
// pointers to struct as its only argument and that reflection
|
|
// will be used to instantiate this argument and to fill it
|
|
// with the database rows.
|
|
func (c DB) QueryChunks(
|
|
ctx context.Context,
|
|
parser ChunkParser,
|
|
) error {
|
|
fnValue := reflect.ValueOf(parser.ForEachChunk)
|
|
chunkType, err := parseInputFunc(parser.ForEachChunk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize)
|
|
|
|
structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(chunkType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
info := structs.GetTagInfo(structType)
|
|
|
|
firstToken := strings.ToUpper(getFirstToken(parser.Query))
|
|
if info.IsNestedStruct && firstToken == "SELECT" {
|
|
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
|
|
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
|
|
}
|
|
|
|
if firstToken == "FROM" {
|
|
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
parser.Query = selectPrefix + parser.Query
|
|
}
|
|
|
|
rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var idx = 0
|
|
for rows.Next() {
|
|
// Allocate new slice elements
|
|
// only if they are not already allocated:
|
|
if chunk.Len() <= idx {
|
|
var elemValue reflect.Value
|
|
elemValue = reflect.New(structType)
|
|
if !isSliceOfPtrs {
|
|
elemValue = elemValue.Elem()
|
|
}
|
|
chunk = reflect.Append(chunk, elemValue)
|
|
}
|
|
|
|
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if idx < parser.ChunkSize-1 {
|
|
idx++
|
|
continue
|
|
}
|
|
|
|
idx = 0
|
|
err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
|
|
if err != nil {
|
|
if err == ErrAbortIteration {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := rows.Close(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// If Next() returned false because of an error:
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
|
|
// If no rows were found or idx was reset to 0
|
|
// on the last iteration skip this last call to ForEachChunk:
|
|
if idx > 0 {
|
|
chunk = chunk.Slice(0, idx)
|
|
|
|
err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
|
|
if err != nil {
|
|
if err == ErrAbortIteration {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Insert one or more instances on the database
|
|
//
|
|
// If the original instances have been passed by reference
|
|
// the ID is automatically updated after insertion is completed.
|
|
func (c DB) Insert(
|
|
ctx context.Context,
|
|
table Table,
|
|
record interface{},
|
|
) error {
|
|
query, params, scanValues, err := buildInsertQuery(c.dialect, table.name, record, table.idColumns...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch table.insertMethodFor(c.dialect) {
|
|
case insertWithReturning, insertWithOutput:
|
|
err = c.insertReturningIDs(ctx, record, query, params, scanValues, table.idColumns)
|
|
case insertWithLastInsertID:
|
|
err = c.insertWithLastInsertID(ctx, record, query, params, table.idColumns[0])
|
|
case insertWithNoIDRetrieval:
|
|
err = c.insertWithNoIDRetrieval(ctx, record, query, params)
|
|
default:
|
|
// Unsupported drivers should be detected on the New() function,
|
|
// So we don't expect the code to ever get into this default case.
|
|
err = fmt.Errorf("code error: unsupported driver `%s`", c.driver)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (c DB) insertReturningIDs(
|
|
ctx context.Context,
|
|
record interface{},
|
|
query string,
|
|
params []interface{},
|
|
scanValues []interface{},
|
|
idNames []string,
|
|
) error {
|
|
rows, err := c.db.QueryContext(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
err := fmt.Errorf("unexpected error retrieving the id columns from the database")
|
|
if rows.Err() != nil {
|
|
err = rows.Err()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
err = rows.Scan(scanValues...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return rows.Close()
|
|
}
|
|
|
|
func (c DB) insertWithLastInsertID(
|
|
ctx context.Context,
|
|
record interface{},
|
|
query string,
|
|
params []interface{},
|
|
idName string,
|
|
) error {
|
|
result, err := c.db.ExecContext(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
v := reflect.ValueOf(record)
|
|
t := v.Type()
|
|
if err = assertStructPtr(t); err != nil {
|
|
return errors.Wrap(err, "can't write to `"+idName+"` field")
|
|
}
|
|
|
|
info := structs.GetTagInfo(t.Elem())
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
vID := reflect.ValueOf(id)
|
|
tID := vID.Type()
|
|
|
|
fieldAddr := v.Elem().Field(info.ByName(idName).Index).Addr()
|
|
fieldType := fieldAddr.Type().Elem()
|
|
|
|
if !tID.ConvertibleTo(fieldType) {
|
|
return fmt.Errorf(
|
|
"Can't convert last insert id of type int64 into field `%s` of type %s",
|
|
idName,
|
|
fieldType,
|
|
)
|
|
}
|
|
|
|
fieldAddr.Elem().Set(vID.Convert(fieldType))
|
|
return nil
|
|
}
|
|
|
|
func (c DB) insertWithNoIDRetrieval(
|
|
ctx context.Context,
|
|
record interface{},
|
|
query string,
|
|
params []interface{},
|
|
) error {
|
|
_, err := c.db.ExecContext(ctx, query, params...)
|
|
return err
|
|
}
|
|
|
|
func assertStructPtr(t reflect.Type) error {
|
|
if t.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("expected a Kind of Ptr but got: %s", t)
|
|
}
|
|
if t.Elem().Kind() != reflect.Struct {
|
|
return fmt.Errorf("expected a Kind of Ptr to Struct but got: %s", t)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Delete deletes one or more instances from the database by id
|
|
func (c DB) Delete(
|
|
ctx context.Context,
|
|
table Table,
|
|
ids ...interface{},
|
|
) error {
|
|
if len(ids) == 0 {
|
|
return nil
|
|
}
|
|
|
|
idMaps, err := normalizeIDsAsMaps(table.idColumns, ids)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var query string
|
|
var params []interface{}
|
|
if len(table.idColumns) == 1 {
|
|
query, params = buildSingleKeyDeleteQuery(c.dialect, table.name, table.idColumns[0], idMaps)
|
|
} else {
|
|
query, params = buildCompositeKeyDeleteQuery(c.dialect, table.name, table.idColumns, idMaps)
|
|
}
|
|
|
|
_, err = c.db.ExecContext(ctx, query, params...)
|
|
|
|
return err
|
|
}
|
|
|
|
func normalizeIDsAsMaps(idNames []string, ids []interface{}) ([]map[string]interface{}, error) {
|
|
if len(idNames) == 0 {
|
|
return nil, fmt.Errorf("internal ksql error: missing idNames")
|
|
}
|
|
|
|
idMaps := []map[string]interface{}{}
|
|
for i := range ids {
|
|
t := reflect.TypeOf(ids[i])
|
|
switch t.Kind() {
|
|
case reflect.Struct:
|
|
m, err := structs.StructToMap(ids[i])
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "could not get ID(s) from record on idx %d", i)
|
|
}
|
|
idMaps = append(idMaps, m)
|
|
case reflect.Map:
|
|
m, ok := ids[i].(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected map[string]interface{} but got %T", ids[i])
|
|
}
|
|
idMaps = append(idMaps, m)
|
|
default:
|
|
idMaps = append(idMaps, map[string]interface{}{
|
|
idNames[0]: ids[i],
|
|
})
|
|
}
|
|
}
|
|
|
|
for i, m := range idMaps {
|
|
for _, id := range idNames {
|
|
if _, found := m[id]; !found {
|
|
return nil, fmt.Errorf("missing required id field `%s` on record with idx %d", id, i)
|
|
}
|
|
}
|
|
}
|
|
|
|
return idMaps, nil
|
|
}
|
|
|
|
// Update updates the given instances on the database by id.
|
|
//
|
|
// Partial updates are supported, i.e. it will ignore nil pointer attributes
|
|
func (c DB) Update(
|
|
ctx context.Context,
|
|
table Table,
|
|
record interface{},
|
|
) error {
|
|
query, params, err := buildUpdateQuery(c.dialect, table.name, record, table.idColumns...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = c.db.ExecContext(ctx, query, params...)
|
|
|
|
return err
|
|
}
|
|
|
|
func buildInsertQuery(
|
|
dialect dialect,
|
|
tableName string,
|
|
record interface{},
|
|
idNames ...string,
|
|
) (query string, params []interface{}, scanValues []interface{}, err error) {
|
|
v := reflect.ValueOf(record)
|
|
t := v.Type()
|
|
if err = assertStructPtr(t); err != nil {
|
|
return "", nil, nil, fmt.Errorf(
|
|
"ksql: expected record to be a pointer to struct, but got: %T",
|
|
record,
|
|
)
|
|
}
|
|
|
|
info := structs.GetTagInfo(t.Elem())
|
|
|
|
recordMap, err := structs.StructToMap(record)
|
|
if err != nil {
|
|
return "", nil, nil, err
|
|
}
|
|
|
|
for _, fieldName := range idNames {
|
|
// Remove any ID field that was not set:
|
|
if reflect.ValueOf(recordMap[fieldName]).IsZero() {
|
|
delete(recordMap, fieldName)
|
|
}
|
|
}
|
|
|
|
columnNames := []string{}
|
|
for col := range recordMap {
|
|
columnNames = append(columnNames, col)
|
|
}
|
|
|
|
params = make([]interface{}, len(recordMap))
|
|
valuesQuery := make([]string, len(recordMap))
|
|
for i, col := range columnNames {
|
|
recordValue := recordMap[col]
|
|
params[i] = recordValue
|
|
if info.ByName(col).SerializeAsJSON {
|
|
params[i] = jsonSerializable{
|
|
DriverName: dialect.DriverName(),
|
|
Attr: recordValue,
|
|
}
|
|
}
|
|
|
|
valuesQuery[i] = dialect.Placeholder(i)
|
|
}
|
|
|
|
// Escape all cols to be sure they will be interpreted as column names:
|
|
escapedColumnNames := []string{}
|
|
for _, col := range columnNames {
|
|
escapedColumnNames = append(escapedColumnNames, dialect.Escape(col))
|
|
}
|
|
|
|
var returningQuery, outputQuery string
|
|
switch dialect.InsertMethod() {
|
|
case insertWithReturning:
|
|
escapedIDNames := []string{}
|
|
for _, id := range idNames {
|
|
escapedIDNames = append(escapedIDNames, dialect.Escape(id))
|
|
}
|
|
returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
|
|
|
|
for _, id := range idNames {
|
|
scanValues = append(
|
|
scanValues,
|
|
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
|
)
|
|
}
|
|
case insertWithOutput:
|
|
escapedIDNames := []string{}
|
|
for _, id := range idNames {
|
|
escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id))
|
|
}
|
|
outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
|
|
|
|
for _, id := range idNames {
|
|
scanValues = append(
|
|
scanValues,
|
|
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Note that the outputQuery and the returningQuery depend
|
|
// on the selected driver, thus, they might be empty strings.
|
|
query = fmt.Sprintf(
|
|
"INSERT INTO %s (%s)%s VALUES (%s)%s",
|
|
dialect.Escape(tableName),
|
|
strings.Join(escapedColumnNames, ", "),
|
|
outputQuery,
|
|
strings.Join(valuesQuery, ", "),
|
|
returningQuery,
|
|
)
|
|
|
|
return query, params, scanValues, nil
|
|
}
|
|
|
|
func buildUpdateQuery(
|
|
dialect dialect,
|
|
tableName string,
|
|
record interface{},
|
|
idFieldNames ...string,
|
|
) (query string, args []interface{}, err error) {
|
|
recordMap, err := structs.StructToMap(record)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
numAttrs := len(recordMap)
|
|
args = make([]interface{}, numAttrs)
|
|
numNonIDArgs := numAttrs - len(idFieldNames)
|
|
whereArgs := args[numNonIDArgs:]
|
|
|
|
whereQuery := make([]string, len(idFieldNames))
|
|
for i, fieldName := range idFieldNames {
|
|
whereArgs[i] = recordMap[fieldName]
|
|
whereQuery[i] = fmt.Sprintf(
|
|
"%s = %s",
|
|
dialect.Escape(fieldName),
|
|
dialect.Placeholder(i+numNonIDArgs),
|
|
)
|
|
|
|
delete(recordMap, fieldName)
|
|
}
|
|
|
|
keys := []string{}
|
|
for key := range recordMap {
|
|
keys = append(keys, key)
|
|
}
|
|
|
|
t := reflect.TypeOf(record)
|
|
if t.Kind() == reflect.Ptr {
|
|
t = t.Elem()
|
|
}
|
|
info := structs.GetTagInfo(t)
|
|
|
|
var setQuery []string
|
|
for i, k := range keys {
|
|
recordValue := recordMap[k]
|
|
if info.ByName(k).SerializeAsJSON {
|
|
recordValue = jsonSerializable{
|
|
DriverName: dialect.DriverName(),
|
|
Attr: recordValue,
|
|
}
|
|
}
|
|
args[i] = recordValue
|
|
setQuery = append(setQuery, fmt.Sprintf(
|
|
"%s = %s",
|
|
dialect.Escape(k),
|
|
dialect.Placeholder(i),
|
|
))
|
|
}
|
|
|
|
query = fmt.Sprintf(
|
|
"UPDATE %s SET %s WHERE %s",
|
|
dialect.Escape(tableName),
|
|
strings.Join(setQuery, ", "),
|
|
strings.Join(whereQuery, ", "),
|
|
)
|
|
|
|
return query, args, nil
|
|
}
|
|
|
|
// Exec just runs an SQL command on the database returning no rows.
|
|
func (c DB) Exec(ctx context.Context, query string, params ...interface{}) error {
|
|
_, err := c.db.ExecContext(ctx, query, params...)
|
|
return err
|
|
}
|
|
|
|
// Transaction just runs an SQL command on the database returning no rows.
|
|
func (c DB) Transaction(ctx context.Context, fn func(SQLProvider) error) error {
|
|
switch db := c.db.(type) {
|
|
case *sql.Tx:
|
|
return fn(c)
|
|
case *sql.DB:
|
|
tx, err := db.BeginTx(ctx, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
rollbackErr := tx.Rollback()
|
|
if rollbackErr != nil {
|
|
r = errors.Wrap(rollbackErr,
|
|
fmt.Sprintf("unable to rollback after panic with value: %v", r),
|
|
)
|
|
}
|
|
panic(r)
|
|
}
|
|
}()
|
|
|
|
ormCopy := c
|
|
ormCopy.db = tx
|
|
|
|
err = fn(ormCopy)
|
|
if err != nil {
|
|
rollbackErr := tx.Rollback()
|
|
if rollbackErr != nil {
|
|
err = errors.Wrap(rollbackErr,
|
|
fmt.Sprintf("unable to rollback after error: %s", err.Error()),
|
|
)
|
|
}
|
|
return err
|
|
}
|
|
|
|
return tx.Commit()
|
|
|
|
default:
|
|
return fmt.Errorf("unexpected error on ksql: db attribute has an invalid type")
|
|
}
|
|
}
|
|
|
|
var errType = reflect.TypeOf(new(error)).Elem()
|
|
|
|
func parseInputFunc(fn interface{}) (reflect.Type, error) {
|
|
if fn == nil {
|
|
return nil, fmt.Errorf("the ForEachChunk attribute is required and cannot be nil")
|
|
}
|
|
|
|
t := reflect.TypeOf(fn)
|
|
|
|
if t.Kind() != reflect.Func {
|
|
return nil, fmt.Errorf("the ForEachChunk callback must be a function")
|
|
}
|
|
if t.NumIn() != 1 {
|
|
return nil, fmt.Errorf("the ForEachChunk callback must have 1 argument")
|
|
}
|
|
|
|
if t.NumOut() != 1 {
|
|
return nil, fmt.Errorf("the ForEachChunk callback must have a single return value")
|
|
}
|
|
|
|
if t.Out(0) != errType {
|
|
return nil, fmt.Errorf("the return value of the ForEachChunk callback must be of type error")
|
|
}
|
|
|
|
argsType := t.In(0)
|
|
if argsType.Kind() != reflect.Slice {
|
|
return nil, fmt.Errorf("the argument of the ForEachChunk callback must a slice of structs")
|
|
}
|
|
|
|
return argsType, nil
|
|
}
|
|
|
|
type nopScanner struct{}
|
|
|
|
var nopScannerValue = reflect.ValueOf(&nopScanner{}).Interface()
|
|
|
|
func (nopScanner) Scan(value interface{}) error {
|
|
return nil
|
|
}
|
|
|
|
func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
|
|
v := reflect.ValueOf(record)
|
|
t := v.Type()
|
|
if t.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
|
|
}
|
|
|
|
v = v.Elem()
|
|
t = t.Elem()
|
|
|
|
if t.Kind() != reflect.Struct {
|
|
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
|
|
}
|
|
|
|
info := structs.GetTagInfo(t)
|
|
|
|
var scanArgs []interface{}
|
|
if info.IsNestedStruct {
|
|
// This version is positional meaning that it expect the arguments
|
|
// to follow an specific order. It's ok because we don't allow the
|
|
// user to type the "SELECT" part of the query for nested structs.
|
|
scanArgs = getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
|
} else {
|
|
names, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Since this version uses the names of the columns it works
|
|
// with any order of attributes/columns.
|
|
scanArgs = getScanArgsFromNames(dialect, names, v, info)
|
|
}
|
|
|
|
return rows.Scan(scanArgs...)
|
|
}
|
|
|
|
func getScanArgsForNestedStructs(dialect dialect, rows *sql.Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) []interface{} {
|
|
scanArgs := []interface{}{}
|
|
for i := 0; i < v.NumField(); i++ {
|
|
// TODO(vingarcia00): Handle case where type is pointer
|
|
nestedStructInfo := structs.GetTagInfo(t.Field(i).Type)
|
|
nestedStructValue := v.Field(i)
|
|
for j := 0; j < nestedStructValue.NumField(); j++ {
|
|
fieldInfo := nestedStructInfo.ByIndex(j)
|
|
|
|
valueScanner := nopScannerValue
|
|
if fieldInfo.Valid {
|
|
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
|
|
if fieldInfo.SerializeAsJSON {
|
|
valueScanner = &jsonSerializable{
|
|
DriverName: dialect.DriverName(),
|
|
Attr: valueScanner,
|
|
}
|
|
}
|
|
}
|
|
|
|
scanArgs = append(scanArgs, valueScanner)
|
|
}
|
|
}
|
|
|
|
return scanArgs
|
|
}
|
|
|
|
func getScanArgsFromNames(dialect dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
|
|
scanArgs := []interface{}{}
|
|
for _, name := range names {
|
|
fieldInfo := info.ByName(name)
|
|
|
|
valueScanner := nopScannerValue
|
|
if fieldInfo.Valid {
|
|
valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
|
|
if fieldInfo.SerializeAsJSON {
|
|
valueScanner = &jsonSerializable{
|
|
DriverName: dialect.DriverName(),
|
|
Attr: valueScanner,
|
|
}
|
|
}
|
|
}
|
|
|
|
scanArgs = append(scanArgs, valueScanner)
|
|
}
|
|
|
|
return scanArgs
|
|
}
|
|
|
|
func buildSingleKeyDeleteQuery(
|
|
dialect dialect,
|
|
table string,
|
|
idName string,
|
|
idMaps []map[string]interface{},
|
|
) (query string, params []interface{}) {
|
|
values := []string{}
|
|
for i, m := range idMaps {
|
|
values = append(values, dialect.Placeholder(i))
|
|
params = append(params, m[idName])
|
|
}
|
|
|
|
return fmt.Sprintf(
|
|
"DELETE FROM %s WHERE %s IN (%s)",
|
|
dialect.Escape(table),
|
|
dialect.Escape(idName),
|
|
strings.Join(values, ","),
|
|
), params
|
|
}
|
|
|
|
func buildCompositeKeyDeleteQuery(
|
|
dialect dialect,
|
|
table string,
|
|
idNames []string,
|
|
idMaps []map[string]interface{},
|
|
) (query string, params []interface{}) {
|
|
escapedNames := []string{}
|
|
for _, name := range idNames {
|
|
escapedNames = append(escapedNames, dialect.Escape(name))
|
|
}
|
|
|
|
values := []string{}
|
|
for _, m := range idMaps {
|
|
tuple := []string{}
|
|
for _, name := range idNames {
|
|
params = append(params, m[name])
|
|
tuple = append(tuple, dialect.Placeholder(len(values)))
|
|
}
|
|
values = append(values, "("+strings.Join(tuple, ",")+")")
|
|
}
|
|
|
|
return fmt.Sprintf(
|
|
"DELETE FROM %s WHERE (%s) IN (VALUES %s)",
|
|
dialect.Escape(table),
|
|
strings.Join(escapedNames, ","),
|
|
strings.Join(values, ","),
|
|
), params
|
|
}
|
|
|
|
// We implemented this function instead of using
|
|
// a regex or strings.Fields because we wanted
|
|
// to preserve the performance of the package.
|
|
func getFirstToken(s string) string {
|
|
s = strings.TrimLeftFunc(s, unicode.IsSpace)
|
|
|
|
var token strings.Builder
|
|
for _, c := range s {
|
|
if unicode.IsSpace(c) {
|
|
break
|
|
}
|
|
token.WriteRune(c)
|
|
}
|
|
return token.String()
|
|
}
|
|
|
|
func buildSelectQuery(
|
|
dialect dialect,
|
|
structType reflect.Type,
|
|
info structs.StructInfo,
|
|
selectQueryCache map[reflect.Type]string,
|
|
) (query string, err error) {
|
|
if selectQuery, found := selectQueryCache[structType]; found {
|
|
return selectQuery, nil
|
|
}
|
|
|
|
if info.IsNestedStruct {
|
|
query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
} else {
|
|
query = buildSelectQueryForPlainStructs(dialect, structType, info)
|
|
}
|
|
|
|
selectQueryCache[structType] = query
|
|
return query, nil
|
|
}
|
|
|
|
func buildSelectQueryForPlainStructs(
|
|
dialect dialect,
|
|
structType reflect.Type,
|
|
info structs.StructInfo,
|
|
) string {
|
|
var fields []string
|
|
for i := 0; i < structType.NumField(); i++ {
|
|
fields = append(fields, dialect.Escape(info.ByIndex(i).Name))
|
|
}
|
|
|
|
return "SELECT " + strings.Join(fields, ", ") + " "
|
|
}
|
|
|
|
func buildSelectQueryForNestedStructs(
|
|
dialect dialect,
|
|
structType reflect.Type,
|
|
info structs.StructInfo,
|
|
) (string, error) {
|
|
var fields []string
|
|
for i := 0; i < structType.NumField(); i++ {
|
|
nestedStructName := info.ByIndex(i).Name
|
|
nestedStructType := structType.Field(i).Type
|
|
if nestedStructType.Kind() != reflect.Struct {
|
|
return "", fmt.Errorf(
|
|
"expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v",
|
|
nestedStructName, nestedStructType,
|
|
)
|
|
}
|
|
|
|
nestedStructInfo := structs.GetTagInfo(nestedStructType)
|
|
for j := 0; j < structType.Field(i).Type.NumField(); j++ {
|
|
fields = append(
|
|
fields,
|
|
dialect.Escape(nestedStructName)+"."+dialect.Escape(nestedStructInfo.ByIndex(j).Name),
|
|
)
|
|
}
|
|
}
|
|
|
|
return "SELECT " + strings.Join(fields, ", ") + " ", nil
|
|
}
|