ksql/ksql.go

1257 lines
31 KiB
Go

package ksql
import (
"context"
"crypto/tls"
"fmt"
"io"
"reflect"
"strings"
"sync"
"unicode"
"github.com/vingarcia/ksql/internal/modifiers"
"github.com/vingarcia/ksql/internal/structs"
"github.com/vingarcia/ksql/ksqlmodifiers"
"github.com/vingarcia/ksql/sqldialect"
)
// selectQueryCache keeps one cache per supported SQL dialect, indexed by the
// dialect driver name. Each inner sync.Map maps a struct reflect.Type to the
// SELECT prefix generated for it by buildSelectQuery.
var selectQueryCache = initializeQueryCache()

// initializeQueryCache returns a map holding an empty *sync.Map for every
// dialect listed in sqldialect.SupportedDialects.
func initializeQueryCache() map[string]*sync.Map {
	caches := make(map[string]*sync.Map, len(sqldialect.SupportedDialects))
	for driverName := range sqldialect.SupportedDialects {
		caches[driverName] = new(sync.Map)
	}
	return caches
}
// DB is the KSQL client. It implements the `ksql.Provider` interface by
// sending the queries it builds to a DBAdapter, usually a thin wrapper
// around the "database/sql" package.
type DB struct {
	// dialect builds the database-specific parts of each query
	// (placeholders, identifier escaping, insert method).
	dialect sqldialect.Provider
	// db executes the queries; inside Transaction it is replaced by a Tx.
	db DBAdapter
}
// DBAdapter is the minimal interface KSQL needs from a database driver.
// It decouples KSQL from "database/sql": any type implementing the methods
// below with the same semantics as the sql package works with KSQL.
//
// To create a new client using a custom adapter use `ksql.NewWithAdapter()`.
type DBAdapter interface {
	ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
}

// TxBeginner must be implemented by a DBAdapter so that `ksql.Transaction()`
// is able to start new transactions.
type TxBeginner interface {
	BeginTx(ctx context.Context) (Tx, error)
}

// Result describes the outcome of an Exec call.
type Result interface {
	LastInsertId() (int64, error)
	RowsAffected() (int64, error)
}

// Rows is the cursor over the results returned by QueryContext.
type Rows interface {
	Scan(dest ...interface{}) error
	Close() error
	Next() bool
	Err() error
	Columns() ([]string, error)
}
// ScanArgError is the error a Rows.Scan implementation is expected to return
// when it fails to scan one of its destination arguments.
//
// Returning this type allows KSQL to replace the positional ColumnIndex with
// the name of the struct attribute, producing a more readable error message.
type ScanArgError struct {
	// ColumnIndex is the position of the scan argument that failed.
	ColumnIndex int
	// Err is the underlying scan error.
	Err error
}

// Error implements the error interface.
func (s ScanArgError) Error() string {
	return fmt.Sprintf(
		"error scanning input attribute with index %d: %s",
		s.ColumnIndex, s.Err,
	)
}

// Unwrap returns the underlying scan error so that errors.Is and errors.As
// can look through a ScanArgError.
func (s ScanArgError) Unwrap() error {
	return s.Err
}

// ErrorWithStructNames returns a new error that identifies the failing
// argument as `structName.colName` instead of by its index. The returned
// error wraps Err.
func (s ScanArgError) ErrorWithStructNames(structName string, colName string) error {
	return fmt.Errorf(
		"error scanning %s.%s: %w",
		structName, colName, s.Err,
	)
}
// Tx is a database transaction, as returned by TxBeginner.BeginTx.
// Besides the regular DBAdapter methods it can be committed or rolled back.
type Tx interface {
	DBAdapter
	Commit(ctx context.Context) error
	Rollback(ctx context.Context) error
}
// Config holds the optional arguments accepted by the `ksql.New()` function.
type Config struct {
	// MaxOpenConns defaults to 1 when left unset (zero).
	MaxOpenConns int

	// TLSConfig is used by some adapters (such as kpgx), where a nil value
	// disables TLS.
	TLSConfig *tls.Config
}

// SetDefaultValues fills in the default value of every unset field.
// All adapters should call it before using the Config.
func (c *Config) SetDefaultValues() {
	if c.MaxOpenConns != 0 {
		return
	}
	c.MaxOpenConns = 1
}
// NewWithAdapter creates a KSQL client from a custom DBAdapter implementation
// and the sqldialect.Provider matching its database.
//
// It returns an error when dialect is nil.
func NewWithAdapter(
	adapter DBAdapter,
	dialect sqldialect.Provider,
) (DB, error) {
	var client DB
	if dialect == nil {
		return client, fmt.Errorf("expected a valid sqldialect.Provider as argument but got `nil`")
	}

	client.dialect = dialect
	client.db = adapter
	return client, nil
}
// Query queries several rows from the database,
// the input should be a slice of structs (or *struct) passed
// by reference and it will be filled with all the results.
//
// Any elements already present on the input slice are discarded: after a
// successful call the slice holds exactly one element per returned row
// (its capacity is reused when possible).
//
// If the query starts with the FROM keyword the SELECT part is generated
// from the struct tags; for nested structs this is mandatory.
//
// Note: it is very important to make sure the query will
// return a small known number of results, otherwise you risk
// of overloading the available memory.
func (c DB) Query(
	ctx context.Context,
	records interface{},
	query string,
	params ...interface{},
) error {
	slicePtr := reflect.ValueOf(records)
	// Kind() is reflect.Invalid for a nil interface, so this check also keeps
	// the Type() call below from panicking:
	if slicePtr.Kind() != reflect.Ptr {
		return fmt.Errorf("KSQL: expected to receive a pointer to slice of structs, but got: %T", records)
	}

	sliceType := slicePtr.Type().Elem()
	structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(sliceType)
	if err != nil {
		return err
	}

	if slicePtr.IsNil() {
		return fmt.Errorf("KSQL: expected a valid pointer to slice of structs as argument but received a nil pointer: %v", records)
	}

	info, err := structs.GetTagInfo(structType)
	if err != nil {
		return err
	}

	firstToken := strings.ToUpper(getFirstToken(query))
	if info.IsNestedStruct && firstToken == "SELECT" {
		// This error check is necessary, since if we can't build the select part of the query this feature won't work.
		return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
	}

	if firstToken == "FROM" {
		selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
		if err != nil {
			return err
		}
		query = selectPrefix + query
	}

	rows, err := c.db.QueryContext(ctx, query, params...)
	if err != nil {
		return fmt.Errorf("error running query: %w", err)
	}
	defer rows.Close()

	// Always start from an empty slice so that no element left on it by the
	// caller survives in the result; reflect.Append reuses the existing
	// capacity when possible. Every row is scanned into a newly appended,
	// zero-valued element, so records referenced by a slice of pointers are
	// never overwritten either.
	slice := slicePtr.Elem().Slice(0, 0)
	for idx := 0; rows.Next(); idx++ {
		elemValue := reflect.New(structType)
		if !isSliceOfPtrs {
			elemValue = elemValue.Elem()
		}
		slice = reflect.Append(slice, elemValue)

		elemPtr := slice.Index(idx).Addr()
		if isSliceOfPtrs {
			// This is necessary since scanRows expects a *record not a **record
			elemPtr = elemPtr.Elem()
		}

		err = scanRows(ctx, c.dialect, rows, elemPtr.Interface())
		if err != nil {
			return err
		}
	}

	if err := rows.Err(); err != nil {
		return fmt.Errorf("KSQL: unexpected error when parsing query result: %w", err)
	}

	if err := rows.Close(); err != nil {
		return fmt.Errorf("KSQL: unexpected error when closing query result rows: %w", err)
	}

	// Update the original slice passed by reference:
	slicePtr.Elem().Set(slice)

	return nil
}
// QueryOne queries one instance from the database,
// the input struct must be passed by reference
// and the query should return only one result.
//
// If the query starts with the FROM keyword the SELECT part is generated
// from the struct tags; for nested structs this is mandatory.
//
// QueryOne returns a ErrRecordNotFound if
// the query returns no results.
func (c DB) QueryOne(
	ctx context.Context,
	record interface{},
	query string,
	params ...interface{},
) error {
	v := reflect.ValueOf(record)
	// Kind() is reflect.Invalid for a nil interface, so this check also keeps
	// the Type() call below from panicking:
	if v.Kind() != reflect.Ptr {
		return fmt.Errorf("KSQL: expected to receive a pointer to struct, but got: %T", record)
	}

	if v.IsNil() {
		return fmt.Errorf("KSQL: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
	}

	t := v.Type()
	tStruct := t.Elem()
	if tStruct.Kind() != reflect.Struct {
		return fmt.Errorf("KSQL: expected to receive a pointer to struct, but got: %T", record)
	}

	info, err := structs.GetTagInfo(tStruct)
	if err != nil {
		return err
	}

	firstToken := strings.ToUpper(getFirstToken(query))
	if info.IsNestedStruct && firstToken == "SELECT" {
		// This error check is necessary, since if we can't build the select part of the query this feature won't work.
		return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
	}

	if firstToken == "FROM" {
		selectPrefix, err := buildSelectQuery(c.dialect, tStruct, info, selectQueryCache[c.dialect.DriverName()])
		if err != nil {
			return err
		}
		query = selectPrefix + query
	}

	rows, err := c.db.QueryContext(ctx, query, params...)
	if err != nil {
		return fmt.Errorf("error running query: %w", err)
	}
	defer rows.Close()

	if !rows.Next() {
		if err := rows.Err(); err != nil {
			return err
		}
		return ErrRecordNotFound
	}

	err = scanRowsFromType(ctx, c.dialect, rows, record, t, v)
	if err != nil {
		return err
	}

	return rows.Close()
}
// QueryChunks is meant to perform queries that returns
// more results than would normally fit on memory,
// for others cases the Query and QueryOne functions are indicated.
//
// The ChunkParser argument has 4 attributes:
// (1) The Query;
// (2) The query args;
// (3) The chunk size;
// (4) A callback function called ForEachChunk, that will be called
// to process each chunk loaded from the database.
//
// Note that the signature of the ForEachChunk callback can be
// any function that receives a slice of structs or a slice of
// pointers to struct as its only argument and that reflection
// will be used to instantiate this argument and to fill it
// with the database rows.
//
// A slice of structs is reused between calls to ForEachChunk, so the
// callback should copy any element it wants to keep. A slice of pointers is
// rebuilt with newly allocated structs for every chunk, so pointers kept by
// the callback are never overwritten.
//
// Returning ErrAbortIteration from ForEachChunk stops the iteration and
// makes QueryChunks return nil.
func (c DB) QueryChunks(
	ctx context.Context,
	parser ChunkParser,
) error {
	fnValue := reflect.ValueOf(parser.ForEachChunk)
	chunkType, err := structs.ParseInputFunc(parser.ForEachChunk)
	if err != nil {
		return err
	}

	// reflect.MakeSlice panics on a negative capacity:
	if parser.ChunkSize < 0 {
		return fmt.Errorf("KSQL: expected ChunkSize to be a non-negative number but got: %d", parser.ChunkSize)
	}
	chunk := reflect.MakeSlice(chunkType, 0, parser.ChunkSize)

	structType, isSliceOfPtrs, err := structs.DecodeAsSliceOfStructs(chunkType)
	if err != nil {
		return err
	}

	info, err := structs.GetTagInfo(structType)
	if err != nil {
		return err
	}

	firstToken := strings.ToUpper(getFirstToken(parser.Query))
	if info.IsNestedStruct && firstToken == "SELECT" {
		// This error check is necessary, since if we can't build the select part of the query this feature won't work.
		return fmt.Errorf("can't generate SELECT query for nested struct: when using this feature omit the SELECT part of the query")
	}

	if firstToken == "FROM" {
		selectPrefix, err := buildSelectQuery(c.dialect, structType, info, selectQueryCache[c.dialect.DriverName()])
		if err != nil {
			return err
		}
		parser.Query = selectPrefix + parser.Query
	}

	rows, err := c.db.QueryContext(ctx, parser.Query, parser.Params...)
	if err != nil {
		return err
	}
	defer rows.Close()

	idx := 0
	for rows.Next() {
		// Allocate new slice elements
		// only if they are not already allocated:
		if chunk.Len() <= idx {
			elemValue := reflect.New(structType)
			if !isSliceOfPtrs {
				elemValue = elemValue.Elem()
			}
			chunk = reflect.Append(chunk, elemValue)
		}

		recordPtr := chunk.Index(idx).Addr()
		if isSliceOfPtrs {
			// scanRows expects a *record, not a **record:
			recordPtr = recordPtr.Elem()
		}

		err = scanRows(ctx, c.dialect, rows, recordPtr.Interface())
		if err != nil {
			return err
		}

		if idx < parser.ChunkSize-1 {
			idx++
			continue
		}

		idx = 0
		err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
		if err != nil {
			if err == ErrAbortIteration {
				return nil
			}
			return err
		}

		if isSliceOfPtrs {
			// Start over with a new slice so the structs handed to the callback
			// are not overwritten by the rows of the next chunk:
			chunk = reflect.MakeSlice(chunkType, 0, parser.ChunkSize)
		}
	}

	if err := rows.Close(); err != nil {
		return err
	}

	// If Next() returned false because of an error:
	if err := rows.Err(); err != nil {
		return err
	}

	// If no rows were found or idx was reset to 0
	// on the last iteration skip this last call to ForEachChunk:
	if idx > 0 {
		chunk = chunk.Slice(0, idx)
		err, _ = fnValue.Call([]reflect.Value{chunk})[0].Interface().(error)
		if err != nil {
			if err == ErrAbortIteration {
				return nil
			}
			return err
		}
	}

	return nil
}
// Insert inserts a single record on the database.
//
// The record must be a non-nil pointer to struct, and the ID attribute(s)
// of that struct are automatically updated after the insertion is completed
// whenever the dialect and table allow retrieving them.
func (c DB) Insert(
	ctx context.Context,
	table Table,
	record interface{},
) error {
	v := reflect.ValueOf(record)
	// A nil interface has no type: report it like any other invalid input
	// instead of panicking on v.Type():
	if !v.IsValid() {
		return fmt.Errorf(
			"KSQL: expected record to be a pointer to struct, but got: %T",
			record,
		)
	}

	t := v.Type()
	if err := assertStructPtr(t); err != nil {
		return fmt.Errorf(
			"KSQL: expected record to be a pointer to struct, but got: %T",
			record,
		)
	}

	if v.IsNil() {
		return fmt.Errorf("KSQL: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
	}

	if err := table.validate(); err != nil {
		return fmt.Errorf("can't insert in ksql.Table: %w", err)
	}

	info, err := structs.GetTagInfo(t.Elem())
	if err != nil {
		return err
	}

	query, params, scanValues, err := buildInsertQuery(ctx, c.dialect, table, t, v, info, record)
	if err != nil {
		return err
	}

	switch table.insertMethodFor(c.dialect) {
	case sqldialect.InsertWithReturning, sqldialect.InsertWithOutput:
		err = c.insertReturningIDs(ctx, query, params, scanValues, table.idColumns)
	case sqldialect.InsertWithLastInsertID:
		err = c.insertWithLastInsertID(ctx, t, v, info, record, query, params, table.idColumns[0])
	case sqldialect.InsertWithNoIDRetrieval:
		err = c.insertWithNoIDRetrieval(ctx, query, params)
	default:
		// Unsupported drivers should be detected on the New() function,
		// So we don't expect the code to ever get into this default case.
		err = fmt.Errorf("code error: unsupported driver `%s`", c.dialect.DriverName())
	}

	return err
}
// insertReturningIDs runs an INSERT query whose first result row holds the
// generated id column(s), scanning that row into scanValues.
//
// It returns the rows error when no row comes back because of a failure,
// and a generic error when no row comes back at all. idNames is currently
// unused.
func (c DB) insertReturningIDs(
	ctx context.Context,
	query string,
	params []interface{},
	scanValues []interface{},
	idNames []string,
) error {
	rows, err := c.db.QueryContext(ctx, query, params...)
	if err != nil {
		return err
	}
	defer rows.Close()

	if rows.Next() {
		if err := rows.Scan(scanValues...); err != nil {
			return err
		}
		return rows.Close()
	}

	if rowsErr := rows.Err(); rowsErr != nil {
		return rowsErr
	}
	return fmt.Errorf("unexpected error when retrieving the id columns from the database")
}
// insertWithLastInsertID runs the INSERT query with ExecContext and then
// stores the id reported by Result.LastInsertId() on the record attribute
// mapped to the idName column, converting it to that attribute's type.
//
// v must be the reflect.Value of a pointer to the record struct.
//
// If the struct has no valid attribute for idName the insert still happens,
// but the generated id is not written back: an invalid FieldInfo's Index does
// not refer to the id attribute, so using it could overwrite an unrelated
// field of the record.
func (c DB) insertWithLastInsertID(
	ctx context.Context,
	t reflect.Type,
	v reflect.Value,
	info structs.StructInfo,
	record interface{},
	query string,
	params []interface{},
	idName string,
) error {
	result, err := c.db.ExecContext(ctx, query, params...)
	if err != nil {
		return err
	}

	idField := info.ByName(idName)
	if !idField.Valid {
		// There is no attribute to store the generated id in.
		return nil
	}

	id, err := result.LastInsertId()
	if err != nil {
		return err
	}

	vID := reflect.ValueOf(id)
	fieldAddr := v.Elem().Field(idField.Index).Addr()
	fieldType := fieldAddr.Type().Elem()
	if !vID.Type().ConvertibleTo(fieldType) {
		return fmt.Errorf(
			"can't convert last insert id of type int64 into field `%s` of type %v",
			idName,
			fieldType,
		)
	}

	fieldAddr.Elem().Set(vID.Convert(fieldType))
	return nil
}
// insertWithNoIDRetrieval executes the INSERT query and discards its Result.
// Insert uses it when the insert method is InsertWithNoIDRetrieval.
func (c DB) insertWithNoIDRetrieval(
	ctx context.Context,
	query string,
	params []interface{},
) error {
	if _, execErr := c.db.ExecContext(ctx, query, params...); execErr != nil {
		return execErr
	}
	return nil
}
// assertStructPtr returns an error unless t is a pointer to a struct type.
func assertStructPtr(t reflect.Type) error {
	switch {
	case t.Kind() != reflect.Ptr:
		return fmt.Errorf("expected a Kind of Ptr but got: %s", t)
	case t.Elem().Kind() != reflect.Struct:
		return fmt.Errorf("expected a Kind of Ptr to Struct but got: %s", t)
	default:
		return nil
	}
}
// Delete deletes one record from the database using the ID or IDs
// defined on the `ksql.Table` passed as second argument.
//
// For tables with a single ID column you can pass the record
// to be deleted as a struct, as a map or just pass the ID itself.
//
// For tables with composite keys you must pass the record
// as a struct or a map so that KSQL can read all the composite keys
// from it.
//
// The examples below should work for both types of tables:
//
//	err := c.Delete(ctx, UsersTable, user)
//
//	err := c.Delete(ctx, UserPostsTable, map[string]interface{}{
//		"user_id": user.ID,
//		"post_id": post.ID,
//	})
//
// The example below is shorter but will only work for tables with a single primary key:
//
//	err := c.Delete(ctx, UsersTable, user.ID)
//
// Delete returns ErrRecordNotFound when no row was affected.
func (c DB) Delete(
	ctx context.Context,
	table Table,
	idOrRecord interface{},
) error {
	if err := table.validate(); err != nil {
		return fmt.Errorf("can't delete from ksql.Table: %w", err)
	}

	idMap, err := normalizeIDsAsMap(table.idColumns, idOrRecord)
	if err != nil {
		return err
	}

	query, params := buildDeleteQuery(c.dialect, table, idMap)
	result, err := c.db.ExecContext(ctx, query, params...)
	if err != nil {
		return err
	}

	affected, err := result.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("unable to check if the record was succesfully deleted: %w", err)
	case affected == 0:
		return ErrRecordNotFound
	default:
		return nil
	}
}
// normalizeIDsAsMap converts idOrMap into a map from id column name to value.
//
// idOrMap may be a struct (or a non-nil pointer to one), a
// map[string]interface{}, or, for single-id tables, the id value itself,
// which is stored under idNames[0]. The resulting map is checked by
// validateIfAllIdsArePresent, which rejects missing, nil and zero ids.
func normalizeIDsAsMap(idNames []string, idOrMap interface{}) (idMap map[string]interface{}, err error) {
	if len(idNames) == 0 {
		return nil, fmt.Errorf("internal ksql error: missing idNames")
	}

	// A nil interface has no reflect.Type, so kind stays reflect.Invalid and
	// the value falls into the default branch below, where it is reported as
	// an invalid id instead of causing a panic.
	kind := reflect.Invalid
	if t := reflect.TypeOf(idOrMap); t != nil {
		kind = t.Kind()
		if kind == reflect.Ptr {
			if reflect.ValueOf(idOrMap).IsNil() {
				return nil, fmt.Errorf("KSQL: expected a valid pointer to struct as argument but received a nil pointer: %v", idOrMap)
			}
			kind = t.Elem().Kind()
		}
	}

	switch kind {
	case reflect.Struct:
		idMap, err = structs.StructToMap(idOrMap)
		if err != nil {
			return nil, fmt.Errorf("could not get ID(s) from input record: %w", err)
		}
	case reflect.Map:
		var ok bool
		idMap, ok = idOrMap.(map[string]interface{})
		if !ok {
			return nil, fmt.Errorf("expected map[string]interface{} but got %T", idOrMap)
		}
	default:
		idMap = map[string]interface{}{
			idNames[0]: idOrMap,
		}
	}

	return idMap, validateIfAllIdsArePresent(idNames, idMap)
}
// Patch applies a partial update (explained below) to the given instance on the database by id.
//
// Partial updates will ignore any nil pointer attributes from the struct, updating only
// the non nil pointers and non pointer attributes.
//
// The record may be a struct or a non-nil pointer to struct. Patch returns
// ErrRecordNotFound when no row was affected by the update.
func (c DB) Patch(
	ctx context.Context,
	table Table,
	record interface{},
) error {
	// Validate the table like Insert and Delete do, so that an invalid
	// ksql.Table is reported clearly instead of producing a broken query:
	if err := table.validate(); err != nil {
		return fmt.Errorf("can't update in ksql.Table: %w", err)
	}

	v := reflect.ValueOf(record)
	// A nil interface has no type: report it instead of panicking on v.Type():
	if !v.IsValid() {
		return fmt.Errorf("KSQL: expected a valid struct or pointer to struct as argument but got: %T", record)
	}

	t := v.Type()
	tStruct := t
	if t.Kind() == reflect.Ptr {
		if v.IsNil() {
			return fmt.Errorf("KSQL: expected a valid pointer to struct as argument but received a nil pointer: %v", record)
		}
		tStruct = t.Elem()
	}

	info, err := structs.GetTagInfo(tStruct)
	if err != nil {
		return err
	}

	recordMap, err := structs.StructToMap(record)
	if err != nil {
		return err
	}

	query, params, err := buildUpdateQuery(ctx, c.dialect, table.name, info, recordMap, table.idColumns...)
	if err != nil {
		return err
	}

	result, err := c.db.ExecContext(ctx, query, params...)
	if err != nil {
		return err
	}

	n, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf(
			"unexpected error: unable to fetch how many rows were affected by the update: %w",
			err,
		)
	}
	if n < 1 {
		return ErrRecordNotFound
	}

	return nil
}
// buildInsertQuery builds the INSERT query for record (a pointer to struct
// whose reflect.Value is v) and its positional params.
//
// ID attributes holding a nil or zero value are left out so the database can
// generate them, and so are attributes whose modifier sets SkipOnInsert.
//
// When the table's insert method (table.insertMethodFor, the same one Insert
// uses to execute the query) is InsertWithReturning or InsertWithOutput, a
// RETURNING/OUTPUT clause listing the id columns is added and scanValues
// holds one scan destination per id column: a pointer to the matching struct
// attribute, or nopScannerValue (discarding the value) when the struct has
// no valid attribute for that column. The parameter t is currently unused.
func buildInsertQuery(
	ctx context.Context,
	dialect sqldialect.Provider,
	table Table,
	t reflect.Type,
	v reflect.Value,
	info structs.StructInfo,
	record interface{},
) (query string, params []interface{}, scanValues []interface{}, err error) {
	recordMap, err := structs.StructToMap(record)
	if err != nil {
		return "", nil, nil, err
	}

	// Remove any ID field that was not set (nil values included, since
	// reflect.ValueOf(nil).IsZero() would panic):
	for _, fieldName := range table.idColumns {
		field, found := recordMap[fieldName]
		if !found {
			continue
		}
		if field == nil || reflect.ValueOf(field).IsZero() {
			delete(recordMap, fieldName)
		}
	}

	columnNames := make([]string, 0, len(recordMap))
	for col := range recordMap {
		if info.ByName(col).Modifier.SkipOnInsert {
			continue
		}
		columnNames = append(columnNames, col)
	}

	params = make([]interface{}, len(columnNames))
	valuesQuery := make([]string, len(columnNames))
	// Escape all cols to be sure they will be interpreted as column names:
	escapedColumnNames := make([]string, len(columnNames))
	for i, col := range columnNames {
		recordValue := recordMap[col]
		params[i] = recordValue

		valueFn := info.ByName(col).Modifier.Value
		if valueFn != nil {
			params[i] = modifiers.AttrValueWrapper{
				Ctx:     ctx,
				Attr:    recordValue,
				ValueFn: valueFn,
				OpInfo: ksqlmodifiers.OpInfo{
					DriverName: dialect.DriverName(),
					Method:     "Insert",
				},
			}
		}

		valuesQuery[i] = dialect.Placeholder(i)
		escapedColumnNames[i] = dialect.Escape(col)
	}

	var returningQuery, outputQuery string
	insertMethod := table.insertMethodFor(dialect)
	if insertMethod == sqldialect.InsertWithReturning || insertMethod == sqldialect.InsertWithOutput {
		escapedIDNames := make([]string, 0, len(table.idColumns))
		for _, id := range table.idColumns {
			escapedID := dialect.Escape(id)
			if insertMethod == sqldialect.InsertWithOutput {
				escapedID = "INSERTED." + escapedID
			}
			escapedIDNames = append(escapedIDNames, escapedID)

			// An invalid FieldInfo's Index does not refer to the id attribute,
			// so discard the returned value instead of scanning it into an
			// unrelated field:
			target := nopScannerValue
			if fieldInfo := info.ByName(id); fieldInfo.Valid {
				target = v.Elem().Field(fieldInfo.Index).Addr().Interface()
			}
			scanValues = append(scanValues, target)
		}

		if insertMethod == sqldialect.InsertWithOutput {
			outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
		} else {
			returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
		}
	}

	if len(columnNames) == 0 && dialect.DriverName() != "mysql" {
		query = fmt.Sprintf(
			"INSERT INTO %s%s DEFAULT VALUES%s",
			dialect.Escape(table.name),
			outputQuery,
			returningQuery,
		)
		return query, params, scanValues, nil
	}

	// Note that the outputQuery and the returningQuery depend
	// on the selected driver, thus, they might be empty strings.
	query = fmt.Sprintf(
		"INSERT INTO %s (%s)%s VALUES (%s)%s",
		dialect.Escape(table.name),
		strings.Join(escapedColumnNames, ", "),
		outputQuery,
		strings.Join(valuesQuery, ", "),
		returningQuery,
	)

	return query, params, scanValues, nil
}
// buildUpdateQuery builds the UPDATE statement used by Patch and its
// positional args.
//
// Each remaining entry of recordMap becomes either a `col = <placeholder>`
// pair on the SET clause or, for the names listed in idFieldNames, a
// condition on the WHERE clause (joined with AND). Attributes whose modifier
// sets SkipOnUpdate are left out, and attributes with a Value modifier are
// wrapped in a modifiers.AttrValueWrapper.
//
// NOTE: recordMap is modified in place: skipped attributes and the id
// attributes are deleted from it.
//
// It returns the error from validateIfAllIdsArePresent when an id is missing,
// nil or zero, and ErrNoValuesToUpdate when there is nothing besides the ids
// to update.
func buildUpdateQuery(
	ctx context.Context,
	dialect sqldialect.Provider,
	tableName string,
	info structs.StructInfo,
	recordMap map[string]interface{},
	idFieldNames ...string,
) (query string, args []interface{}, err error) {
	// Drop the attributes that must never be written by an update:
	for key := range recordMap {
		if info.ByName(key).Modifier.SkipOnUpdate {
			delete(recordMap, key)
		}
	}
	numAttrs := len(recordMap)
	// args holds the SET values first, followed by the id values used on the
	// WHERE clause, matching the placeholder indexes generated below.
	args = make([]interface{}, numAttrs)
	err = validateIfAllIdsArePresent(idFieldNames, recordMap)
	if err != nil {
		return "", nil, err
	}
	numNonIDArgs := numAttrs - len(idFieldNames)
	whereArgs := args[numNonIDArgs:]
	if numNonIDArgs == 0 {
		return "", nil, ErrNoValuesToUpdate
	}
	whereQuery := make([]string, len(idFieldNames))
	for i, fieldName := range idFieldNames {
		whereArgs[i] = recordMap[fieldName]
		// The id placeholders come after all the SET placeholders:
		whereQuery[i] = fmt.Sprintf(
			"%s = %s",
			dialect.Escape(fieldName),
			dialect.Placeholder(i+numNonIDArgs),
		)
		// Remove the id so the remaining keys are exactly the SET columns:
		delete(recordMap, fieldName)
	}
	keys := []string{}
	for key := range recordMap {
		keys = append(keys, key)
	}
	// The SET column order follows the map iteration order, which is not
	// deterministic; args[i] always matches Placeholder(i), so the query is
	// still correct.
	var setQuery []string
	for i, k := range keys {
		recordValue := recordMap[k]
		valueFn := info.ByName(k).Modifier.Value
		if valueFn != nil {
			recordValue = modifiers.AttrValueWrapper{
				Ctx:     ctx,
				Attr:    recordValue,
				ValueFn: valueFn,
				OpInfo: ksqlmodifiers.OpInfo{
					DriverName: dialect.DriverName(),
					Method:     "Update",
				},
			}
		}
		args[i] = recordValue
		setQuery = append(setQuery, fmt.Sprintf(
			"%s = %s",
			dialect.Escape(k),
			dialect.Placeholder(i),
		))
	}
	query = fmt.Sprintf(
		"UPDATE %s SET %s WHERE %s",
		dialect.Escape(tableName),
		strings.Join(setQuery, ", "),
		strings.Join(whereQuery, " AND "),
	)
	return query, args, nil
}
// validateIfAllIdsArePresent checks that idMap holds a usable value for every
// name in idNames. Missing keys, nil values and zero values (e.g. 0 or "")
// are all rejected with an error wrapping ErrRecordMissingIDs.
func validateIfAllIdsArePresent(idNames []string, idMap map[string]interface{}) error {
	for _, name := range idNames {
		value, present := idMap[name]
		if !present {
			return fmt.Errorf("missing required id field `%s` on input record: %w", name, ErrRecordMissingIDs)
		}

		isUnset := value == nil || reflect.ValueOf(value).IsZero()
		if isUnset {
			return fmt.Errorf("invalid value '%v' received for id column: '%s': %w", value, name, ErrRecordMissingIDs)
		}
	}
	return nil
}
// Exec runs an SQL command that returns no rows, handing back the adapter's
// Result and error unchanged.
func (c DB) Exec(ctx context.Context, query string, params ...interface{}) (Result, error) {
	result, err := c.db.ExecContext(ctx, query, params...)
	return result, err
}
// Transaction encapsulates several queries into a single transaction.
// All these queries should be made inside the input callback `fn`
// and they should use the input ksql.Provider.
//
// If the callback returns any errors the transaction will be rolled back,
// otherwise the transaction will be committed. If the callback panics the
// transaction is rolled back and the panic is propagated.
//
// If it happens that a second transaction is started inside a transaction
// callback the same transaction will be reused with no errors.
//
// An error is returned when the DBAdapter implements neither Tx nor
// TxBeginner.
func (c DB) Transaction(ctx context.Context, fn func(Provider) error) error {
	// The Tx case is checked first, so an adapter that is already a
	// transaction is reused even if it also implements TxBeginner.
	switch txBeginner := c.db.(type) {
	case Tx:
		return fn(c)
	case TxBeginner:
		tx, err := txBeginner.BeginTx(ctx)
		if err != nil {
			return fmt.Errorf("KSQL: error starting transaction: %w", err)
		}
		// If fn panics, roll back and re-panic. When the rollback also fails,
		// the panic value is replaced by an error describing both problems.
		defer func() {
			if r := recover(); r != nil {
				rollbackErr := tx.Rollback(ctx)
				if rollbackErr != nil {
					r = fmt.Errorf(
						"KSQL: unable to rollback after panic with value: %v, rollback error: %w",
						r, rollbackErr,
					)
				}
				panic(r)
			}
		}()
		// fn receives a copy of the client whose adapter is the transaction,
		// so every query it makes through the Provider runs inside tx.
		dbCopy := c
		dbCopy.db = tx
		err = fn(dbCopy)
		if err != nil {
			rollbackErr := tx.Rollback(ctx)
			if rollbackErr != nil {
				// NOTE: only rollbackErr is wrapped (%w); fn's error is formatted
				// with %s, so errors.Is/As won't match it on this path.
				err = fmt.Errorf(
					"KSQL: unable to rollback after error: %s, rollback error: %w",
					err, rollbackErr,
				)
			}
			return err
		}
		return tx.Commit(ctx)
	default:
		return fmt.Errorf("KSQL: can't start transaction: The DBAdapter doesn't implement the TxBeginner interface")
	}
}
// Close implements the io.Closer interface. It closes the underlying adapter
// when the adapter implements io.Closer itself, and is a no-op otherwise.
func (c DB) Close() error {
	if closer, isCloser := c.db.(io.Closer); isCloser {
		return closer.Close()
	}
	return nil
}
// nopScanner is a scan destination that accepts and discards any value.
// getScanArgsFromNames uses it for returned columns that have no matching
// struct attribute.
type nopScanner struct{}

// nopScannerValue is the shared *nopScanner used as a scan argument.
var nopScannerValue interface{} = &nopScanner{}

// Scan discards value and never fails.
func (nopScanner) Scan(value interface{}) error {
	return nil
}
// scanRows scans the current row of rows into record, which must be a
// pointer to struct. It derives record's reflect type and value and
// delegates the work to scanRowsFromType.
func scanRows(ctx context.Context, dialect sqldialect.Provider, rows Rows, record interface{}) error {
	recordValue := reflect.ValueOf(record)
	return scanRowsFromType(ctx, dialect, rows, record, recordValue.Type(), recordValue)
}
// scanRowsFromType scans the current row of rows into record, a pointer to
// struct whose reflect.Type and reflect.Value are passed as t and v.
//
// For nested structs (info.IsNestedStruct) the columns are matched by
// position, following the attribute order used to generate the SELECT part
// of the query. For any other struct the columns are matched by name, so
// their order doesn't matter and unknown columns are discarded.
//
// When rows.Scan fails with a ScanArgError whose ColumnIndex refers to one of
// the scan arguments, the error reports the struct and attribute names
// instead of the column index.
func scanRowsFromType(
	ctx context.Context,
	dialect sqldialect.Provider,
	rows Rows,
	record interface{},
	t reflect.Type,
	v reflect.Value,
) error {
	if t.Kind() != reflect.Ptr {
		return fmt.Errorf("KSQL: expected record to be a pointer to struct, but got: %T", record)
	}

	v = v.Elem()
	t = t.Elem()

	if t.Kind() != reflect.Struct {
		return fmt.Errorf("KSQL: expected record to be a pointer to struct, but got: %T", record)
	}

	info, err := structs.GetTagInfo(t)
	if err != nil {
		return err
	}

	var attrNames []string
	var scanArgs []interface{}
	if info.IsNestedStruct {
		// This version is positional meaning that it expect the arguments
		// to follow an specific order. It's ok because we don't allow the
		// user to type the "SELECT" part of the query for nested structs.
		attrNames, scanArgs, err = getScanArgsForNestedStructs(ctx, dialect, rows, t, v, info)
		if err != nil {
			return err
		}
	} else {
		colNames, err := rows.Columns()
		if err != nil {
			return fmt.Errorf("KSQL: unable to read columns from returned rows: %w", err)
		}
		// Since this version uses the names of the columns it works
		// with any order of attributes/columns.
		attrNames, scanArgs = getScanArgsFromNames(ctx, dialect, colNames, v, info)
	}

	err = rows.Scan(scanArgs...)
	if err == nil {
		return nil
	}

	// ColumnIndex comes from the adapter, so only use it when it refers to one
	// of our scan arguments; an out-of-range index would otherwise panic:
	if scanErr, ok := err.(ScanArgError); ok && scanErr.ColumnIndex >= 0 && scanErr.ColumnIndex < len(attrNames) {
		return fmt.Errorf(
			"KSQL: scan error: %w",
			scanErr.ErrorWithStructNames(t.Name(), attrNames[scanErr.ColumnIndex]),
		)
	}
	return fmt.Errorf("KSQL: scan error: %w", err)
}
// getScanArgsForNestedStructs builds the positional scan destinations for a
// struct whose valid attributes are themselves structs (one per table).
//
// For each valid outer attribute i and each valid attribute of the nested
// struct, it appends a pointer to the nested attribute (wrapped in a
// modifiers.AttrScanWrapper when it has a Scan modifier) and the name
// "<outer attr>.<nested attr>" used in error messages. The order matches the
// one used by buildSelectQueryForNestedStructs. The rows argument is unused.
func getScanArgsForNestedStructs(
	ctx context.Context,
	dialect sqldialect.Provider,
	rows Rows,
	t reflect.Type,
	v reflect.Value,
	info structs.StructInfo,
) (attrNames []string, scanArgs []interface{}, _ error) {
	for outerIdx := 0; outerIdx < v.NumField(); outerIdx++ {
		outerField := info.ByIndex(outerIdx)
		if !outerField.Valid {
			continue
		}

		// TODO(vingarcia00): Handle case where type is pointer
		nestedInfo, err := structs.GetTagInfo(t.Field(outerIdx).Type)
		if err != nil {
			return nil, nil, err
		}

		nestedValue := v.Field(outerIdx)
		for innerIdx := 0; innerIdx < nestedValue.NumField(); innerIdx++ {
			innerField := nestedInfo.ByIndex(innerIdx)
			if !innerField.Valid {
				continue
			}

			var target interface{} = nestedValue.Field(innerField.Index).Addr().Interface()
			if innerField.Modifier.Scan != nil {
				target = &modifiers.AttrScanWrapper{
					Ctx:     ctx,
					AttrPtr: target,
					ScanFn:  innerField.Modifier.Scan,
					OpInfo: ksqlmodifiers.OpInfo{
						DriverName: dialect.DriverName(),
						// We will not differentiate between Query, QueryOne and QueryChunks
						// if we did this could lead users to make very strange modifiers
						Method: "Query",
					},
				}
			}

			scanArgs = append(scanArgs, target)
			attrNames = append(attrNames, outerField.AttrName+"."+innerField.AttrName)
		}
	}

	return attrNames, scanArgs, nil
}
// getScanArgsFromNames builds one scan destination per returned column name.
//
// Columns matching a valid struct attribute are scanned into that attribute
// (wrapped in a modifiers.AttrScanWrapper when it has a Scan modifier); any
// other column is scanned into nopScannerValue and discarded. attrNames[i]
// is the AttrName of the FieldInfo looked up for column i.
func getScanArgsFromNames(
	ctx context.Context,
	dialect sqldialect.Provider,
	names []string,
	v reflect.Value,
	info structs.StructInfo,
) (attrNames []string, scanArgs []interface{}) {
	for _, columnName := range names {
		field := info.ByName(columnName)
		attrNames = append(attrNames, field.AttrName)

		if !field.Valid {
			scanArgs = append(scanArgs, nopScannerValue)
			continue
		}

		var target interface{} = v.Field(field.Index).Addr().Interface()
		if field.Modifier.Scan != nil {
			target = &modifiers.AttrScanWrapper{
				Ctx:     ctx,
				AttrPtr: target,
				ScanFn:  field.Modifier.Scan,
				OpInfo: ksqlmodifiers.OpInfo{
					DriverName: dialect.DriverName(),
					// We will not differentiate between Query, QueryOne and QueryChunks
					// if we did this could lead users to make very strange modifiers
					Method: "Query",
				},
			}
		}
		scanArgs = append(scanArgs, target)
	}

	return attrNames, scanArgs
}
// buildDeleteQuery builds a DELETE query matching every id column of table
// (conditions joined with AND) and returns the id values from idMap as its
// positional params, in table.idColumns order.
func buildDeleteQuery(
	dialect sqldialect.Provider,
	table Table,
	idMap map[string]interface{},
) (query string, params []interface{}) {
	conditions := make([]string, len(table.idColumns))
	for pos, idName := range table.idColumns {
		conditions[pos] = fmt.Sprintf("%s = %s", dialect.Escape(idName), dialect.Placeholder(pos))
		params = append(params, idMap[idName])
	}

	query = fmt.Sprintf(
		"DELETE FROM %s WHERE %s",
		dialect.Escape(table.name),
		strings.Join(conditions, " AND "),
	)
	return query, params
}
// getFirstToken returns the first whitespace-delimited token of s, ignoring
// leading whitespace (as defined by unicode.IsSpace). It returns "" when s is
// empty or contains only whitespace.
//
// It runs on every query to detect whether it starts with SELECT or FROM, so
// it avoids regexes and strings.Fields, and returns a substring of s instead
// of copying the token, which means it does not allocate.
func getFirstToken(s string) string {
	s = strings.TrimLeftFunc(s, unicode.IsSpace)
	if end := strings.IndexFunc(s, unicode.IsSpace); end >= 0 {
		return s[:end]
	}
	return s
}
// buildSelectQuery returns the "SELECT <columns> " prefix for structType,
// using the nested-struct or plain-struct builder according to info.
//
// Generated prefixes are cached on selectQueryCache, keyed by structType.
// selectQueryCache is nil for drivers missing from the package cache (it is
// only populated for sqldialect.SupportedDialects, while NewWithAdapter
// accepts any sqldialect.Provider); in that case the prefix is built on every
// call instead of panicking on a nil *sync.Map.
func buildSelectQuery(
	dialect sqldialect.Provider,
	structType reflect.Type,
	info structs.StructInfo,
	selectQueryCache *sync.Map,
) (query string, err error) {
	if selectQueryCache != nil {
		if data, found := selectQueryCache.Load(structType); found {
			selectQuery, ok := data.(string)
			if !ok {
				return "", fmt.Errorf("invalid cache entry, expected type string, found %T", data)
			}
			return selectQuery, nil
		}
	}

	if info.IsNestedStruct {
		query, err = buildSelectQueryForNestedStructs(dialect, structType, info)
		if err != nil {
			return "", err
		}
	} else {
		query = buildSelectQueryForPlainStructs(dialect, structType, info)
	}

	if selectQueryCache != nil {
		selectQueryCache.Store(structType, query)
	}
	return query, nil
}
// buildSelectQueryForPlainStructs returns "SELECT <cols> " listing the
// escaped column name of every valid attribute of structType, in field order.
func buildSelectQueryForPlainStructs(
	dialect sqldialect.Provider,
	structType reflect.Type,
	info structs.StructInfo,
) string {
	var columns []string
	for fieldIdx, numFields := 0, structType.NumField(); fieldIdx < numFields; fieldIdx++ {
		if field := info.ByIndex(fieldIdx); field.Valid {
			columns = append(columns, dialect.Escape(field.ColumnName))
		}
	}

	return "SELECT " + strings.Join(columns, ", ") + " "
}
// buildSelectQueryForNestedStructs returns "SELECT <cols> " for a struct
// whose valid attributes are structs, one per table. Each column is written
// as `<escaped table>.<escaped column>`, where the table name is the outer
// attribute's ColumnName, following the field order of both structs.
//
// It returns an error if a valid outer attribute is not a struct, or if the
// tags of a nested struct can't be read.
func buildSelectQueryForNestedStructs(
	dialect sqldialect.Provider,
	structType reflect.Type,
	info structs.StructInfo,
) (string, error) {
	var columns []string
	for outerIdx := 0; outerIdx < structType.NumField(); outerIdx++ {
		tableField := info.ByIndex(outerIdx)
		if !tableField.Valid {
			continue
		}

		tableName := tableField.ColumnName
		nestedType := structType.Field(outerIdx).Type
		if nestedType.Kind() != reflect.Struct {
			return "", fmt.Errorf(
				"expected nested struct with `tablename:\"%s\"` to be a kind of Struct, but got %v",
				tableName, nestedType,
			)
		}

		nestedInfo, err := structs.GetTagInfo(nestedType)
		if err != nil {
			return "", err
		}

		for innerIdx := 0; innerIdx < nestedType.NumField(); innerIdx++ {
			column := nestedInfo.ByIndex(innerIdx)
			if !column.Valid {
				continue
			}
			columns = append(columns, dialect.Escape(tableName)+"."+dialect.Escape(column.ColumnName))
		}
	}

	return "SELECT " + strings.Join(columns, ", ") + " ", nil
}