mirror of
https://github.com/VinGarcia/ksql.git
synced 2025-05-03 22:22:15 +00:00
As is well known, a Go map cannot be read and written concurrently from multiple goroutines. This change replaces every global map used as a cache with sync.Map, which is safe for concurrent use.
1148 lines
28 KiB
Go
1148 lines
28 KiB
Go
package ksql
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"io"
|
|
"reflect"
|
|
"strings"
|
|
"sync"
|
|
"unicode"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/vingarcia/ksql/internal/structs"
|
|
"github.com/vingarcia/ksql/ksqltest"
|
|
)
|
|
|
|
var selectQueryCache = initializeQueryCache()
|
|
|
|
func initializeQueryCache() map[string]*sync.Map {
|
|
cache := map[string]*sync.Map{}
|
|
for dname := range supportedDialects {
|
|
cache[dname] = &sync.Map{}
|
|
}
|
|
|
|
return cache
|
|
}
|
|
|
|
// DB represents the KSQL client responsible for
|
|
// interfacing with the "database/sql" package implementing
|
|
// the KSQL interface `ksql.Provider`.
|
|
type DB struct {
|
|
driver string
|
|
dialect Dialect
|
|
db DBAdapter
|
|
}
|
|
|
|
// DBAdapter is minimalistic interface to decouple our implementation
|
|
// from database/sql, i.e. if any struct implements the functions below
|
|
// with the exact same semantic as the sql package it will work with KSQL.
|
|
//
|
|
// To create a new client using this adapter use `ksql.NewWithAdapter()`
|
|
type DBAdapter interface {
|
|
ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error)
|
|
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
|
|
}
|
|
|
|
// TxBeginner needs to be implemented by the DBAdapter in order to make it possible
|
|
// to use the `ksql.Transaction()` function.
|
|
type TxBeginner interface {
|
|
BeginTx(ctx context.Context) (Tx, error)
|
|
}
|
|
|
|
// Result carries the information available after running an Exec query.
type Result interface {
	// LastInsertId returns the id generated by the statement, used by
	// Insert for dialects that rely on the last insert id.
	LastInsertId() (int64, error)

	// RowsAffected returns how many rows the statement changed; Delete
	// and Patch use it to detect that no record matched.
	RowsAffected() (int64, error)
}
|
|
|
|
// Rows is the cursor over the results returned by a call to Query().
// Its methods are expected to behave like the ones of *sql.Rows.
type Rows interface {
	// Scan reads the current row into the given destinations.
	Scan(...interface{}) error

	// Close releases the cursor.
	Close() error

	// Next advances to the next row, returning false when there is none
	// (check Err to tell exhaustion apart from a failure).
	Next() bool

	// Err returns the error, if any, that interrupted the iteration.
	Err() error

	// Columns returns the names of the result columns.
	Columns() ([]string, error)
}
|
|
|
|
// Tx represents a transaction and is expected to be returned by the DBAdapter.BeginTx function
|
|
type Tx interface {
|
|
DBAdapter
|
|
|
|
Rollback(ctx context.Context) error
|
|
Commit(ctx context.Context) error
|
|
}
|
|
|
|
// Config groups the optional arguments accepted
// by the `ksql.New()` function.
type Config struct {
	// MaxOpenConns defaults to 1 if not set
	MaxOpenConns int

	// TLSConfig is used by some adapters (such as kpgx),
	// where a nil value disables TLS
	TLSConfig *tls.Config
}

// SetDefaultValues fills every unset option with its default value.
// All adapters should call it before reading the config.
//
// Currently the only default is MaxOpenConns, which becomes 1 when it is 0.
func (c *Config) SetDefaultValues() {
	if c.MaxOpenConns != 0 {
		return
	}
	c.MaxOpenConns = 1
}
|
|
|
|
// NewWithAdapter allows the user to insert a custom implementation
|
|
// of the DBAdapter interface
|
|
func NewWithAdapter(
|
|
db DBAdapter,
|
|
dialectName string,
|
|
) (DB, error) {
|
|
dialect := supportedDialects[dialectName]
|
|
if dialect == nil {
|
|
return DB{}, fmt.Errorf("unsupported driver `%s`", dialectName)
|
|
}
|
|
|
|
return DB{
|
|
dialect: dialect,
|
|
driver: dialectName,
|
|
db: db,
|
|
}, nil
|
|
}
|
|
|
|
// Query queries several rows from the database,
|
|
// the input should be a slice of structs (or *struct) passed
|
|
// by reference and it will be filled with all the results.
|
|
//
|
|
// Note: it is very important to make sure the query will
|
|
// return a small known number of results, otherwise you risk
|
|
// of overloading the available memory.
|
|
func (c DB) Query(
|
|
ctx context.Context,
|
|
records interface{},
|
|
query string,
|
|
params ...interface{},
|
|
) error {
|
|
slicePtr := reflect.ValueOf(records)
|
|
slicePtrType := slicePtr.Type()
|
|
if slicePtrType.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("ksql: expected to receive a pointer to slice of structs, but got: %T", records)
|
|
}
|
|
sliceType := slicePtrType.Elem()
|
|
slice := slicePtr.Elem()
|
|
structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(sliceType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if isSliceOfPtrs {
|
|
// Truncate the slice so there is no risk
|
|
// of overwritting records that were already saved
|
|
// on the slice:
|
|
slice = slice.Slice(0, 0)
|
|
}
|
|
|
|
info, err := structs.GetTagInfo(structType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
firstToken := strings.ToUpper(getFirstToken(query))
|
|
if info.IsNestedStruct && firstToken == "SELECT" {
|
|
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
|
|
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
|
|
}
|
|
|
|
if firstToken == "FROM" {
|
|
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
query = selectPrefix + query
|
|
}
|
|
|
|
rows, err := c.db.QueryContext(ctx, query, params...)
|
|
if err != nil {
|
|
return fmt.Errorf("error running query: %s", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
for idx := 0; rows.Next(); idx++ {
|
|
// Allocate new slice elements
|
|
// only if they are not already allocated:
|
|
if slice.Len() <= idx {
|
|
var elemValue reflect.Value
|
|
elemValue = reflect.New(structType)
|
|
if !isSliceOfPtrs {
|
|
elemValue = elemValue.Elem()
|
|
}
|
|
slice = reflect.Append(slice, elemValue)
|
|
}
|
|
|
|
elemPtr := slice.Index(idx).Addr()
|
|
if isSliceOfPtrs {
|
|
// This is necessary since scanRows expects a *record not a **record
|
|
elemPtr = elemPtr.Elem()
|
|
}
|
|
|
|
err = scanRows(c.dialect, rows, elemPtr.Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
|
|
if err := rows.Close(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update the original slice passed by reference:
|
|
slicePtr.Elem().Set(slice)
|
|
|
|
return nil
|
|
}
|
|
|
|
// QueryOne queries one instance from the database,
|
|
// the input struct must be passed by reference
|
|
// and the query should return only one result.
|
|
//
|
|
// QueryOne returns a ErrRecordNotFound if
|
|
// the query returns no results.
|
|
func (c DB) QueryOne(
|
|
ctx context.Context,
|
|
record interface{},
|
|
query string,
|
|
params ...interface{},
|
|
) error {
|
|
v := reflect.ValueOf(record)
|
|
t := v.Type()
|
|
if t.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
|
}
|
|
|
|
if v.IsNil() {
|
|
return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
|
|
}
|
|
|
|
tStruct := t.Elem()
|
|
if tStruct.Kind() != reflect.Struct {
|
|
return fmt.Errorf("ksql: expected to receive a pointer to struct, but got: %T", record)
|
|
}
|
|
|
|
info, err := structs.GetTagInfo(tStruct)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
firstToken := strings.ToUpper(getFirstToken(query))
|
|
if info.IsNestedStruct && firstToken == "SELECT" {
|
|
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
|
|
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
|
|
}
|
|
|
|
if firstToken == "FROM" {
|
|
selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
query = selectPrefix + query
|
|
}
|
|
|
|
rows, err := c.db.QueryContext(ctx, query, params...)
|
|
if err != nil {
|
|
return fmt.Errorf("error running query: %s", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
return ErrRecordNotFound
|
|
}
|
|
|
|
err = scanRowsFromType(c.dialect, rows, record, t, v)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return rows.Close()
|
|
}
|
|
|
|
// QueryChunks is meant to perform queries that returns
|
|
// more results than would normally fit on memory,
|
|
// for others cases the Query and QueryOne functions are indicated.
|
|
//
|
|
// The ChunkParser argument has 4 attributes:
|
|
// (1) The Query;
|
|
// (2) The query args;
|
|
// (3) The chunk size;
|
|
// (4) A callback function called ForEachChunk, that will be called
|
|
// to process each chunk loaded from the database.
|
|
//
|
|
// Note that the signature of the ForEachChunk callback can be
|
|
// any function that receives a slice of structs or a slice of
|
|
// pointers to struct as its only argument and that reflection
|
|
// will be used to instantiate this argument and to fill it
|
|
// with the database rows.
|
|
func (c DB) QueryChunks(
|
|
ctx context.Context,
|
|
parser ChunkParser,
|
|
) error {
|
|
fnValue := reflect.ValueOf(parser.ForEachChunk)
|
|
chunkType, err := structs.ParseInputFunc(parser.ForEachChunk)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize)
|
|
|
|
structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(chunkType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
info, err := structs.GetTagInfo(structType)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
firstToken := strings.ToUpper(getFirstToken(parser.Query))
|
|
if info.IsNestedStruct && firstToken == "SELECT" {
|
|
// This error check is necessary, since if we can't build the select part of the query this feature won't work.
|
|
return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
|
|
}
|
|
|
|
if firstToken == "FROM" {
|
|
selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
|
|
if err != nil {
|
|
return err
|
|
}
|
|
parser.Query = selectPrefix + parser.Query
|
|
}
|
|
|
|
rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var idx = 0
|
|
for rows.Next() {
|
|
// Allocate new slice elements
|
|
// only if they are not already allocated:
|
|
if chunk.Len() <= idx {
|
|
var elemValue reflect.Value
|
|
elemValue = reflect.New(structType)
|
|
if !isSliceOfPtrs {
|
|
elemValue = elemValue.Elem()
|
|
}
|
|
chunk = reflect.Append(chunk, elemValue)
|
|
}
|
|
|
|
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if idx < parser.ChunkSize-1 {
|
|
idx++
|
|
continue
|
|
}
|
|
|
|
idx = 0
|
|
err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
|
|
if err != nil {
|
|
if err == ErrAbortIteration {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
if err := rows.Close(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// If Next() returned false because of an error:
|
|
if rows.Err() != nil {
|
|
return rows.Err()
|
|
}
|
|
|
|
// If no rows were found or idx was reset to 0
|
|
// on the last iteration skip this last call to ForEachChunk:
|
|
if idx > 0 {
|
|
chunk = chunk.Slice(0, idx)
|
|
|
|
err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
|
|
if err != nil {
|
|
if err == ErrAbortIteration {
|
|
return nil
|
|
}
|
|
return err
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Insert one or more instances on the database
|
|
//
|
|
// If the original instances have been passed by reference
|
|
// the ID is automatically updated after insertion is completed.
|
|
func (c DB) Insert(
|
|
ctx context.Context,
|
|
table Table,
|
|
record interface{},
|
|
) error {
|
|
v := reflect.ValueOf(record)
|
|
t := v.Type()
|
|
if err := assertStructPtr(t); err != nil {
|
|
return fmt.Errorf(
|
|
"ksql: expected record to be a pointer to struct, but got: %T",
|
|
record,
|
|
)
|
|
}
|
|
|
|
if v.IsNil() {
|
|
return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
|
|
}
|
|
|
|
if err := table.validate(); err != nil {
|
|
return fmt.Errorf("can't insert in ksql.Table: %s", err)
|
|
}
|
|
|
|
info, err := structs.GetTagInfo(t.Elem())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
query, params, scanValues, err := buildInsertQuery(c.dialect, table, t, v, info, record)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
switch table.insertMethodFor(c.dialect) {
|
|
case insertWithReturning, insertWithOutput:
|
|
err = c.insertReturningIDs(ctx, query, params, scanValues, table.idColumns)
|
|
case insertWithLastInsertID:
|
|
err = c.insertWithLastInsertID(ctx, t, v, info, record, query, params, table.idColumns[0])
|
|
case insertWithNoIDRetrieval:
|
|
err = c.insertWithNoIDRetrieval(ctx, query, params)
|
|
default:
|
|
// Unsupported drivers should be detected on the New() function,
|
|
// So we don't expect the code to ever get into this default case.
|
|
err = fmt.Errorf("code error: unsupported driver `%s`", c.driver)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (c DB) insertReturningIDs(
|
|
ctx context.Context,
|
|
query string,
|
|
params []interface{},
|
|
scanValues []interface{},
|
|
idNames []string,
|
|
) error {
|
|
rows, err := c.db.QueryContext(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer rows.Close()
|
|
|
|
if !rows.Next() {
|
|
err := fmt.Errorf("unexpected error when retrieving the id columns from the database")
|
|
if rows.Err() != nil {
|
|
err = rows.Err()
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
err = rows.Scan(scanValues...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return rows.Close()
|
|
}
|
|
|
|
func (c DB) insertWithLastInsertID(
|
|
ctx context.Context,
|
|
t reflect.Type,
|
|
v reflect.Value,
|
|
info structs.StructInfo,
|
|
record interface{},
|
|
query string,
|
|
params []interface{},
|
|
idName string,
|
|
) error {
|
|
result, err := c.db.ExecContext(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
vID := reflect.ValueOf(id)
|
|
tID := vID.Type()
|
|
|
|
fieldAddr := v.Elem().Field(info.ByName(idName).Index).Addr()
|
|
fieldType := fieldAddr.Type().Elem()
|
|
|
|
if !tID.ConvertibleTo(fieldType) {
|
|
return fmt.Errorf(
|
|
"can't convert last insert id of type int64 into field `%s` of type %v",
|
|
idName,
|
|
fieldType,
|
|
)
|
|
}
|
|
|
|
fieldAddr.Elem().Set(vID.Convert(fieldType))
|
|
return nil
|
|
}
|
|
|
|
func (c DB) insertWithNoIDRetrieval(
|
|
ctx context.Context,
|
|
query string,
|
|
params []interface{},
|
|
) error {
|
|
_, err := c.db.ExecContext(ctx, query, params...)
|
|
return err
|
|
}
|
|
|
|
// assertStructPtr returns an error unless t is the type
// of a pointer to a struct.
func assertStructPtr(t reflect.Type) error {
	switch {
	case t.Kind() != reflect.Ptr:
		return fmt.Errorf("expected a Kind of Ptr but got: %s", t)
	case t.Elem().Kind() != reflect.Struct:
		return fmt.Errorf("expected a Kind of Ptr to Struct but got: %s", t)
	default:
		return nil
	}
}
|
|
|
|
// Delete deletes one record from the database using the ID or IDs
|
|
// defined on the `ksql.Table` passed as second argument.
|
|
//
|
|
// For tables with a single ID column you can pass the record
|
|
// to be deleted as a struct, as a map or just pass the ID itself.
|
|
//
|
|
// For tables with composite keys you must pass the record
|
|
// as a struct or a map so that KSQL can read all the composite keys
|
|
// from it.
|
|
//
|
|
// The examples below should work for both types of tables:
|
|
//
|
|
// err := c.Delete(ctx, UsersTable, user)
|
|
//
|
|
// err := c.Delete(ctx, UserPostsTable, map[string]interface{}{
|
|
// "user_id": user.ID,
|
|
// "post_id": post.ID,
|
|
// })
|
|
//
|
|
// The example below is shorter but will only work for tables with a single primary key:
|
|
//
|
|
// err := c.Delete(ctx, UsersTable, user.ID)
|
|
//
|
|
func (c DB) Delete(
|
|
ctx context.Context,
|
|
table Table,
|
|
idOrRecord interface{},
|
|
) error {
|
|
if err := table.validate(); err != nil {
|
|
return fmt.Errorf("can't delete from ksql.Table: %s", err)
|
|
}
|
|
|
|
idMap, err := normalizeIDsAsMap(table.idColumns, idOrRecord)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var query string
|
|
var params []interface{}
|
|
query, params = buildDeleteQuery(c.dialect, table, idMap)
|
|
|
|
result, err := c.db.ExecContext(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
n, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("unable to check if the record was succesfully deleted: %s", err)
|
|
}
|
|
|
|
if n == 0 {
|
|
return ErrRecordNotFound
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func normalizeIDsAsMap(idNames []string, idOrMap interface{}) (idMap map[string]interface{}, err error) {
|
|
if len(idNames) == 0 {
|
|
return nil, fmt.Errorf("internal ksql error: missing idNames")
|
|
}
|
|
|
|
t := reflect.TypeOf(idOrMap)
|
|
if t.Kind() == reflect.Ptr {
|
|
v := reflect.ValueOf(idOrMap)
|
|
if v.IsNil() {
|
|
return nil, fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", idOrMap)
|
|
}
|
|
t = t.Elem()
|
|
}
|
|
|
|
switch t.Kind() {
|
|
case reflect.Struct:
|
|
idMap, err = ksqltest.StructToMap(idOrMap)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "could not get ID(s) from input record")
|
|
}
|
|
case reflect.Map:
|
|
var ok bool
|
|
idMap, ok = idOrMap.(map[string]interface{})
|
|
if !ok {
|
|
return nil, fmt.Errorf("expected map[string]interface{} but got %T", idOrMap)
|
|
}
|
|
default:
|
|
idMap = map[string]interface{}{
|
|
idNames[0]: idOrMap,
|
|
}
|
|
}
|
|
|
|
for _, idName := range idNames {
|
|
id, found := idMap[idName]
|
|
if !found {
|
|
return nil, fmt.Errorf("missing required id field `%s` on input record", idName)
|
|
}
|
|
|
|
if id == nil || reflect.ValueOf(id).IsZero() {
|
|
return nil, fmt.Errorf("invalid value '%v' received for id column: '%s'", id, idName)
|
|
}
|
|
}
|
|
|
|
return idMap, nil
|
|
}
|
|
|
|
// Update updates the given instances on the database by id.
|
|
//
|
|
// Partial updates are supported, i.e. it will ignore nil pointer attributes
|
|
//
|
|
// Deprecated: Use the Patch method instead
|
|
func (c DB) Update(
|
|
ctx context.Context,
|
|
table Table,
|
|
record interface{},
|
|
) error {
|
|
return c.Patch(ctx, table, record)
|
|
}
|
|
|
|
// Patch applies a partial update (explained below) to the given instance on the database by id.
|
|
//
|
|
// Partial updates will ignore any nil pointer attributes from the struct, updating only
|
|
// the non nil pointers and non pointer attributes.
|
|
func (c DB) Patch(
|
|
ctx context.Context,
|
|
table Table,
|
|
record interface{},
|
|
) error {
|
|
v := reflect.ValueOf(record)
|
|
t := v.Type()
|
|
tStruct := t
|
|
if t.Kind() == reflect.Ptr {
|
|
if v.IsNil() {
|
|
return fmt.Errorf("ksql: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
|
|
}
|
|
tStruct = t.Elem()
|
|
}
|
|
info, err := structs.GetTagInfo(tStruct)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
query, params, err := buildUpdateQuery(c.dialect, table.name, info, record, table.idColumns...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result, err := c.db.ExecContext(ctx, query, params...)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
n, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf(
|
|
"unexpected error: unable to fetch how many rows were affected by the update: %s",
|
|
err,
|
|
)
|
|
}
|
|
if n < 1 {
|
|
return ErrRecordNotFound
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func buildInsertQuery(
|
|
dialect Dialect,
|
|
table Table,
|
|
t reflect.Type,
|
|
v reflect.Value,
|
|
info structs.StructInfo,
|
|
record interface{},
|
|
) (query string, params []interface{}, scanValues []interface{}, err error) {
|
|
recordMap, err := ksqltest.StructToMap(record)
|
|
if err != nil {
|
|
return "", nil, nil, err
|
|
}
|
|
|
|
for _, fieldName := range table.idColumns {
|
|
field, found := recordMap[fieldName]
|
|
if !found {
|
|
continue
|
|
}
|
|
|
|
// Remove any ID field that was not set:
|
|
if reflect.ValueOf(field).IsZero() {
|
|
delete(recordMap, fieldName)
|
|
}
|
|
}
|
|
|
|
columnNames := []string{}
|
|
for col := range recordMap {
|
|
columnNames = append(columnNames, col)
|
|
}
|
|
|
|
params = make([]interface{}, len(recordMap))
|
|
valuesQuery := make([]string, len(recordMap))
|
|
for i, col := range columnNames {
|
|
recordValue := recordMap[col]
|
|
params[i] = recordValue
|
|
if info.ByName(col).SerializeAsJSON {
|
|
params[i] = jsonSerializable{
|
|
DriverName: dialect.DriverName(),
|
|
Attr: recordValue,
|
|
}
|
|
}
|
|
|
|
valuesQuery[i] = dialect.Placeholder(i)
|
|
}
|
|
|
|
// Escape all cols to be sure they will be interpreted as column names:
|
|
escapedColumnNames := []string{}
|
|
for _, col := range columnNames {
|
|
escapedColumnNames = append(escapedColumnNames, dialect.Escape(col))
|
|
}
|
|
|
|
var returningQuery, outputQuery string
|
|
switch dialect.InsertMethod() {
|
|
case insertWithReturning:
|
|
escapedIDNames := []string{}
|
|
for _, id := range table.idColumns {
|
|
escapedIDNames = append(escapedIDNames, dialect.Escape(id))
|
|
}
|
|
returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
|
|
|
|
for _, id := range table.idColumns {
|
|
scanValues = append(
|
|
scanValues,
|
|
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
|
)
|
|
}
|
|
case insertWithOutput:
|
|
escapedIDNames := []string{}
|
|
for _, id := range table.idColumns {
|
|
escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id))
|
|
}
|
|
outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
|
|
|
|
for _, id := range table.idColumns {
|
|
scanValues = append(
|
|
scanValues,
|
|
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Note that the outputQuery and the returningQuery depend
|
|
// on the selected driver, thus, they might be empty strings.
|
|
query = fmt.Sprintf(
|
|
"INSERT INTO %s (%s)%s VALUES (%s)%s",
|
|
dialect.Escape(table.name),
|
|
strings.Join(escapedColumnNames, ", "),
|
|
outputQuery,
|
|
strings.Join(valuesQuery, ", "),
|
|
returningQuery,
|
|
)
|
|
|
|
return query, params, scanValues, nil
|
|
}
|
|
|
|
func buildUpdateQuery(
|
|
dialect Dialect,
|
|
tableName string,
|
|
info structs.StructInfo,
|
|
record interface{},
|
|
idFieldNames ...string,
|
|
) (query string, args []interface{}, err error) {
|
|
recordMap, err := ksqltest.StructToMap(record)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
numAttrs := len(recordMap)
|
|
args = make([]interface{}, numAttrs)
|
|
numNonIDArgs := numAttrs - len(idFieldNames)
|
|
whereArgs := args[numNonIDArgs:]
|
|
|
|
whereQuery := make([]string, len(idFieldNames))
|
|
for i, fieldName := range idFieldNames {
|
|
whereArgs[i] = recordMap[fieldName]
|
|
whereQuery[i] = fmt.Sprintf(
|
|
"%s = %s",
|
|
dialect.Escape(fieldName),
|
|
dialect.Placeholder(i+numNonIDArgs),
|
|
)
|
|
|
|
delete(recordMap, fieldName)
|
|
}
|
|
|
|
keys := []string{}
|
|
for key := range recordMap {
|
|
keys = append(keys, key)
|
|
}
|
|
|
|
var setQuery []string
|
|
for i, k := range keys {
|
|
recordValue := recordMap[k]
|
|
if info.ByName(k).SerializeAsJSON {
|
|
recordValue = jsonSerializable{
|
|
DriverName: dialect.DriverName(),
|
|
Attr: recordValue,
|
|
}
|
|
}
|
|
args[i] = recordValue
|
|
setQuery = append(setQuery, fmt.Sprintf(
|
|
"%s = %s",
|
|
dialect.Escape(k),
|
|
dialect.Placeholder(i),
|
|
))
|
|
}
|
|
|
|
query = fmt.Sprintf(
|
|
"UPDATE %s SET %s WHERE %s",
|
|
dialect.Escape(tableName),
|
|
strings.Join(setQuery, ", "),
|
|
strings.Join(whereQuery, ", "),
|
|
)
|
|
|
|
return query, args, nil
|
|
}
|
|
|
|
// Exec just runs an SQL command on the database returning no rows.
|
|
func (c DB) Exec(ctx context.Context, query string, params ...interface{}) (Result, error) {
|
|
return c.db.ExecContext(ctx, query, params...)
|
|
}
|
|
|
|
// Transaction just runs an SQL command on the database returning no rows.
|
|
func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
|
|
switch txBeginner := c.db.(type) {
|
|
case Tx:
|
|
return fn(c)
|
|
case TxBeginner:
|
|
tx, err := txBeginner.BeginTx(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
rollbackErr := tx.Rollback(ctx)
|
|
if rollbackErr != nil {
|
|
r = errors.Wrap(rollbackErr,
|
|
fmt.Sprintf("unable to rollback after panic with value: %v", r),
|
|
)
|
|
}
|
|
panic(r)
|
|
}
|
|
}()
|
|
|
|
dbCopy := c
|
|
dbCopy.db = tx
|
|
|
|
err = fn(dbCopy)
|
|
if err != nil {
|
|
rollbackErr := tx.Rollback(ctx)
|
|
if rollbackErr != nil {
|
|
err = errors.Wrap(rollbackErr,
|
|
fmt.Sprintf("unable to rollback after error: %s", err.Error()),
|
|
)
|
|
}
|
|
return err
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
|
|
default:
|
|
return fmt.Errorf("can't start transaction: The DBAdapter doesn't implement the TxBeginner interface")
|
|
}
|
|
}
|
|
|
|
// Close implements the io.Closer interface
|
|
func (c DB) Close() error {
|
|
closer, ok := c.db.(io.Closer)
|
|
if ok {
|
|
return closer.Close()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// nopScanner is a scan target that discards whatever it receives.
// It is used for the columns that have no matching struct field.
type nopScanner struct{}

// nopScannerValue is a single *nopScanner, stored as an interface{},
// shared by all the scan arguments that need to ignore a column.
var nopScannerValue = reflect.ValueOf(new(nopScanner)).Interface()

// Scan ignores the received value and never fails.
func (nopScanner) Scan(value interface{}) error {
	return nil
}
|
|
|
|
func scanRows(dialect Dialect, rows Rows, record interface{}) error {
|
|
v := reflect.ValueOf(record)
|
|
t := v.Type()
|
|
return scanRowsFromType(dialect, rows, record, t, v)
|
|
}
|
|
|
|
func scanRowsFromType(
|
|
dialect Dialect,
|
|
rows Rows,
|
|
record interface{},
|
|
t reflect.Type,
|
|
v reflect.Value,
|
|
) error {
|
|
if t.Kind() != reflect.Ptr {
|
|
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
|
|
}
|
|
|
|
v = v.Elem()
|
|
t = t.Elem()
|
|
|
|
if t.Kind() != reflect.Struct {
|
|
return fmt.Errorf("ksql: expected record to be a pointer to struct, but got: %T", record)
|
|
}
|
|
|
|
info, err := structs.GetTagInfo(t)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var scanArgs []interface{}
|
|
if info.IsNestedStruct {
|
|
// This version is positional meaning that it expect the arguments
|
|
// to follow an specific order. It's ok because we don't allow the
|
|
// user to type the "SELECT" part of the query for nested structs.
|
|
scanArgs, err = getScanArgsForNestedStructs(dialect, rows, t, v, info)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
names, err := rows.Columns()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// Since this version uses the names of the columns it works
|
|
// with any order of attributes/columns.
|
|
scanArgs = getScanArgsFromNames(dialect, names, v, info)
|
|
}
|
|
|
|
return rows.Scan(scanArgs...)
|
|
}
|
|
|
|
func getScanArgsForNestedStructs(dialect Dialect, rows Rows, t reflect.Type, v reflect.Value, info structs.StructInfo) ([]interface{}, error) {
|
|
scanArgs := []interface{}{}
|
|
for i := 0; i < v.NumField(); i++ {
|
|
if !info.ByIndex(i).Valid {
|
|
continue
|
|
}
|
|
|
|
// TODO(vingarcia00): Handle case where type is pointer
|
|
nestedStructInfo, err := structs.GetTagInfo(t.Field(i).Type)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
nestedStructValue := v.Field(i)
|
|
for j := 0; j < nestedStructValue.NumField(); j++ {
|
|
fieldInfo := nestedStructInfo.ByIndex(j)
|
|
if !fieldInfo.Valid {
|
|
continue
|
|
}
|
|
|
|
valueScanner := nopScannerValue
|
|
if fieldInfo.Valid {
|
|
valueScanner = nestedStructValue.Field(fieldInfo.Index).Addr().Interface()
|
|
if fieldInfo.SerializeAsJSON {
|
|
valueScanner = &jsonSerializable{
|
|
DriverName: dialect.DriverName(),
|
|
Attr: valueScanner,
|
|
}
|
|
}
|
|
}
|
|
|
|
scanArgs = append(scanArgs, valueScanner)
|
|
}
|
|
}
|
|
|
|
return scanArgs, nil
|
|
}
|
|
|
|
func getScanArgsFromNames(dialect Dialect, names []string, v reflect.Value, info structs.StructInfo) []interface{} {
|
|
scanArgs := []interface{}{}
|
|
for _, name := range names {
|
|
fieldInfo := info.ByName(name)
|
|
|
|
valueScanner := nopScannerValue
|
|
if fieldInfo.Valid {
|
|
valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
|
|
if fieldInfo.SerializeAsJSON {
|
|
valueScanner = &jsonSerializable{
|
|
DriverName: dialect.DriverName(),
|
|
Attr: valueScanner,
|
|
}
|
|
}
|
|
}
|
|
|
|
scanArgs = append(scanArgs, valueScanner)
|
|
}
|
|
|
|
return scanArgs
|
|
}
|
|
|
|
func buildDeleteQuery(
|
|
dialect Dialect,
|
|
table Table,
|
|
idMap map[string]interface{},
|
|
) (query string, params []interface{}) {
|
|
whereQuery := []string{}
|
|
for i, idName := range table.idColumns {
|
|
whereQuery = append(whereQuery, fmt.Sprintf(
|
|
"%s = %s", dialect.Escape(idName), dialect.Placeholder(i),
|
|
))
|
|
params = append(params, idMap[idName])
|
|
}
|
|
|
|
return fmt.Sprintf(
|
|
"DELETE FROM %s WHERE %s",
|
|
dialect.Escape(table.name),
|
|
strings.Join(whereQuery, " AND "),
|
|
), params
|
|
}
|
|
|
|
// getFirstToken returns the first whitespace-delimited word of s, or an
// empty string if s is empty or contains only whitespace. Whitespace is
// whatever unicode.IsSpace reports as such.
//
// We implemented this function instead of using
// a regex or strings.Fields because we wanted
// to preserve the performance of the package:
// the token is a substring of s, so nothing is allocated or copied.
func getFirstToken(s string) string {
	s = strings.TrimLeftFunc(s, unicode.IsSpace)

	// The token ends right before the next space, if there is one:
	if end := strings.IndexFunc(s, unicode.IsSpace); end >= 0 {
		return s[:end]
	}
	return s
}
|
|
|
|
func buildSelectQuery(
|
|
dialect Dialect,
|
|
structType reflect.Type,
|
|
info structs.StructInfo,
|
|
selectQueryCache *sync.Map,
|
|
) (query string, err error) {
|
|
if data, found := selectQueryCache.Load(structType); found {
|
|
if selectQuery, ok := data.(string); !ok {
|
|
return "", fmt.Errorf("invalid cache entry, expected type string, found %T", data)
|
|
} else {
|
|
return selectQuery, nil
|
|
}
|
|
}
|
|
|
|
if info.IsNestedStruct {
|
|
query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
} else {
|
|
query = buildSelectQueryForPlainStructs(dialect, structType, info)
|
|
}
|
|
|
|
selectQueryCache.Store(structType, query)
|
|
return query, nil
|
|
}
|
|
|
|
func buildSelectQueryForPlainStructs(
|
|
dialect Dialect,
|
|
structType reflect.Type,
|
|
info structs.StructInfo,
|
|
) string {
|
|
var fields []string
|
|
for i := 0; i < structType.NumField(); i++ {
|
|
fieldInfo := info.ByIndex(i)
|
|
if !fieldInfo.Valid {
|
|
continue
|
|
}
|
|
|
|
fields = append(fields, dialect.Escape(fieldInfo.Name))
|
|
}
|
|
|
|
return "SELECT " + strings.Join(fields, ", ") + " "
|
|
}
|
|
|
|
func buildSelectQueryForNestedStructs(
|
|
dialect Dialect,
|
|
structType reflect.Type,
|
|
info structs.StructInfo,
|
|
) (string, error) {
|
|
var fields []string
|
|
for i := 0; i < structType.NumField(); i++ {
|
|
nestedStructInfo := info.ByIndex(i)
|
|
if !nestedStructInfo.Valid {
|
|
continue
|
|
}
|
|
|
|
nestedStructName := nestedStructInfo.Name
|
|
nestedStructType := structType.Field(i).Type
|
|
if nestedStructType.Kind() != reflect.Struct {
|
|
return "", fmt.Errorf(
|
|
"expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v",
|
|
nestedStructName, nestedStructType,
|
|
)
|
|
}
|
|
|
|
nestedStructTagInfo, err := structs.GetTagInfo(nestedStructType)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
for j := 0; j < structType.Field(i).Type.NumField(); j++ {
|
|
fieldInfo := nestedStructTagInfo.ByIndex(j)
|
|
if !fieldInfo.Valid {
|
|
continue
|
|
}
|
|
|
|
fields = append(
|
|
fields,
|
|
dialect.Escape(nestedStructName)+"."+dialect.Escape(fieldInfo.Name),
|
|
)
|
|
}
|
|
}
|
|
|
|
return "SELECT " + strings.Join(fields, ", ") + " ", nil
|
|
}
|