Add support to the `sqlserver` driver =]

pull/2/head
Vinícius Garcia 2021-05-09 22:33:03 -03:00
parent 2dd55131d5
commit 56aa77135c
8 changed files with 167 additions and 71 deletions

View File

@ -14,10 +14,12 @@ The goals were:
**Supported Drivers:**
Currently we only support 2 Drivers:
Currently we support 4 Drivers:
- `"postgres"`
- `"sqlite3"`
- `"mysql"`
- `"sqlserver"`
### Why KissSQL?

View File

@ -6,24 +6,31 @@ type insertMethod int
const (
insertWithReturning insertMethod = iota
insertWithOutput
insertWithLastInsertID
insertWithNoIDRetrieval
)
var supportedDialects = map[string]dialect{
"postgres": &postgresDialect{},
"sqlite3": &sqlite3Dialect{},
"mysql": &mysqlDialect{},
"postgres": &postgresDialect{},
"sqlite3": &sqlite3Dialect{},
"mysql": &mysqlDialect{},
"sqlserver": &sqlserverDialect{},
}
type dialect interface {
InsertMethod() insertMethod
Escape(str string) string
Placeholder(idx int) string
DriverName() string
}
type postgresDialect struct{}
func (postgresDialect) DriverName() string {
return "postgres"
}
func (postgresDialect) InsertMethod() insertMethod {
return insertWithReturning
}
@ -38,6 +45,10 @@ func (postgresDialect) Placeholder(idx int) string {
type sqlite3Dialect struct{}
func (sqlite3Dialect) DriverName() string {
return "sqlite3"
}
func (sqlite3Dialect) InsertMethod() insertMethod {
return insertWithLastInsertID
}
@ -52,6 +63,10 @@ func (sqlite3Dialect) Placeholder(idx int) string {
type mysqlDialect struct{}
func (mysqlDialect) DriverName() string {
return "mysql"
}
func (mysqlDialect) InsertMethod() insertMethod {
return insertWithLastInsertID
}
@ -63,3 +78,21 @@ func (mysqlDialect) Escape(str string) string {
func (mysqlDialect) Placeholder(idx int) string {
return "?"
}
type sqlserverDialect struct{}
func (sqlserverDialect) DriverName() string {
return "sqlserver"
}
func (sqlserverDialect) InsertMethod() insertMethod {
return insertWithOutput
}
func (sqlserverDialect) Escape(str string) string {
return `[` + str + `]`
}
func (sqlserverDialect) Placeholder(idx int) string {
return "@p" + strconv.Itoa(idx+1)
}

View File

@ -23,3 +23,13 @@ services:
- "127.0.0.1:3306:3306"
environment:
MYSQL_ROOT_PASSWORD: mysql
sqlserver:
image: microsoft/mssql-server-linux:2017-latest
restart: always
ports:
- "127.0.0.1:1433:1433"
- "127.0.0.1:1434:1434"
environment:
SA_PASSWORD: "Sqls3rv3r"
ACCEPT_EULA: "Y"

1
go.mod
View File

@ -3,6 +3,7 @@ module github.com/vingarcia/ksql
go 1.14
require (
github.com/denisenkom/go-mssqldb v0.10.0 // indirect
github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018
github.com/go-sql-driver/mysql v1.4.0 // indirect
github.com/golang/mock v1.5.0

6
go.sum
View File

@ -1,10 +1,14 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8=
github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU=
github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 h1:QsFkVafcKOaZoAB4WcyUHdkPbwh+VYwZgYJb/rU6EIM=
github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018/go.mod h1:5C3SWkut69TSdkerzRDxXMRM5x73PGWNcRLe/xKjXhs=
github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g=
github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
@ -28,6 +32,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk=
github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=

View File

@ -11,7 +11,8 @@ import (
// input attributes to be convertible to and from JSON
// before sending or receiving it from the database.
type jsonSerializable struct {
Attr interface{}
DriverName string
Attr interface{}
}
// Scan Implements the Scanner interface in order to load
@ -40,5 +41,9 @@ func (j *jsonSerializable) Scan(value interface{}) error {
// Value Implements the Valuer interface in order to save
// this field as JSON on the database.
func (j jsonSerializable) Value() (driver.Value, error) {
return json.Marshal(j.Attr)
b, err := json.Marshal(j.Attr)
if j.DriverName == "sqlserver" {
return string(b), err
}
return b, err
}

117
ksql.go
View File

@ -148,7 +148,7 @@ func (c DB) Query(
elemPtr = elemPtr.Elem()
}
err = scanRows(rows, elemPtr.Interface())
err = scanRows(c.dialect, rows, elemPtr.Interface())
if err != nil {
return err
}
@ -202,7 +202,7 @@ func (c DB) QueryOne(
return ErrRecordNotFound
}
err = scanRows(rows, record)
err = scanRows(c.dialect, rows, record)
if err != nil {
return err
}
@ -262,7 +262,7 @@ func (c DB) QueryChunks(
chunk = reflect.Append(chunk, elemValue)
}
err = scanRows(rows, chunk.Index(idx).Addr().Interface())
err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface())
if err != nil {
return err
}
@ -320,14 +320,14 @@ func (c DB) Insert(
return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method")
}
query, params, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...)
query, params, scanValues, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...)
if err != nil {
return err
}
switch c.insertMethod {
case insertWithReturning:
err = c.insertWithReturningID(ctx, record, query, params, c.idCols)
case insertWithReturning, insertWithOutput:
err = c.insertReturningIDs(ctx, record, query, params, scanValues, c.idCols)
case insertWithLastInsertID:
err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0])
case insertWithNoIDRetrieval:
@ -341,19 +341,14 @@ func (c DB) Insert(
return err
}
func (c DB) insertWithReturningID(
func (c DB) insertReturningIDs(
ctx context.Context,
record interface{},
query string,
params []interface{},
scanValues []interface{},
idNames []string,
) error {
escapedIDNames := []string{}
for _, id := range idNames {
escapedIDNames = append(escapedIDNames, c.dialect.Escape(id))
}
query += " RETURNING " + strings.Join(idNames, ", ")
rows, err := c.db.QueryContext(ctx, query, params...)
if err != nil {
return err
@ -369,21 +364,7 @@ func (c DB) insertWithReturningID(
return err
}
v := reflect.ValueOf(record)
t := v.Type()
if err = assertStructPtr(t); err != nil {
return errors.Wrap(err, "can't write id field")
}
info := structs.GetTagInfo(t.Elem())
var scanFields []interface{}
for _, id := range idNames {
scanFields = append(
scanFields,
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
)
}
err = rows.Scan(scanFields...)
err = rows.Scan(scanValues...)
if err != nil {
return err
}
@ -549,20 +530,25 @@ func buildInsertQuery(
dialect dialect,
tableName string,
record interface{},
idFieldNames ...string,
) (query string, params []interface{}, err error) {
idNames ...string,
) (query string, params []interface{}, scanValues []interface{}, err error) {
v := reflect.ValueOf(record)
t := v.Type()
if err = assertStructPtr(t); err != nil {
return "", nil, nil, fmt.Errorf(
"ksql: expected record to be a pointer to struct, but got: %T",
record,
)
}
info := structs.GetTagInfo(t.Elem())
recordMap, err := structs.StructToMap(record)
if err != nil {
return "", nil, err
return "", nil, nil, err
}
t := reflect.TypeOf(record)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
info := structs.GetTagInfo(t)
for _, fieldName := range idFieldNames {
for _, fieldName := range idNames {
// Remove any ID field that was not set:
if reflect.ValueOf(recordMap[fieldName]).IsZero() {
delete(recordMap, fieldName)
@ -580,7 +566,10 @@ func buildInsertQuery(
recordValue := recordMap[col]
params[i] = recordValue
if info.ByName(col).SerializeAsJSON {
params[i] = jsonSerializable{Attr: recordValue}
params[i] = jsonSerializable{
DriverName: dialect.DriverName(),
Attr: recordValue,
}
}
valuesQuery[i] = dialect.Placeholder(i)
@ -592,14 +581,48 @@ func buildInsertQuery(
escapedColumnNames = append(escapedColumnNames, dialect.Escape(col))
}
var returningQuery, outputQuery string
switch dialect.InsertMethod() {
case insertWithReturning:
escapedIDNames := []string{}
for _, id := range idNames {
escapedIDNames = append(escapedIDNames, dialect.Escape(id))
}
returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ")
for _, id := range idNames {
scanValues = append(
scanValues,
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
)
}
case insertWithOutput:
escapedIDNames := []string{}
for _, id := range idNames {
escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id))
}
outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ")
for _, id := range idNames {
scanValues = append(
scanValues,
v.Elem().Field(info.ByName(id).Index).Addr().Interface(),
)
}
}
// Note that the outputQuery and the returningQuery depend
// on the selected driver, thus, they might be empty strings.
query = fmt.Sprintf(
"INSERT INTO %s (%s) VALUES (%s)",
"INSERT INTO %s (%s)%s VALUES (%s)%s",
dialect.Escape(tableName),
strings.Join(escapedColumnNames, ", "),
outputQuery,
strings.Join(valuesQuery, ", "),
returningQuery,
)
return query, params, nil
return query, params, scanValues, nil
}
func buildUpdateQuery(
@ -644,7 +667,10 @@ func buildUpdateQuery(
for i, k := range keys {
recordValue := recordMap[k]
if info.ByName(k).SerializeAsJSON {
recordValue = jsonSerializable{Attr: recordValue}
recordValue = jsonSerializable{
DriverName: dialect.DriverName(),
Attr: recordValue,
}
}
args[i] = recordValue
setQuery = append(setQuery, fmt.Sprintf(
@ -753,7 +779,7 @@ func (nopScanner) Scan(value interface{}) error {
return nil
}
func scanRows(rows *sql.Rows, record interface{}) error {
func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error {
names, err := rows.Columns()
if err != nil {
return err
@ -782,7 +808,10 @@ func scanRows(rows *sql.Rows, record interface{}) error {
if fieldInfo.Valid {
valueScanner = v.Field(fieldInfo.Index).Addr().Interface()
if fieldInfo.SerializeAsJSON {
valueScanner = &jsonSerializable{Attr: valueScanner}
valueScanner = &jsonSerializable{
DriverName: dialect.DriverName(),
Attr: valueScanner,
}
}
}

View File

@ -9,6 +9,7 @@ import (
"strings"
"testing"
_ "github.com/denisenkom/go-mssqldb"
"github.com/ditointernet/go-assert"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
@ -1233,6 +1234,7 @@ func TestScanRows(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error())
}
dialect := supportedDialects["sqlite3"]
ctx := context.TODO()
db := connectDB(t, "sqlite3")
defer db.Close()
@ -1248,7 +1250,7 @@ func TestScanRows(t *testing.T) {
assert.Equal(t, true, rows.Next())
var u User
err = scanRows(rows, &u)
err = scanRows(dialect, rows, &u)
assert.Equal(t, nil, err)
assert.Equal(t, "User2", u.Name)
@ -1261,6 +1263,7 @@ func TestScanRows(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error())
}
dialect := supportedDialects["sqlite3"]
ctx := context.TODO()
db := connectDB(t, "sqlite3")
defer db.Close()
@ -1280,7 +1283,7 @@ func TestScanRows(t *testing.T) {
// Omitted for testing purposes:
// Name string `ksql:"name"`
}
err = scanRows(rows, &user)
err = scanRows(dialect, rows, &user)
assert.Equal(t, nil, err)
assert.Equal(t, 22, user.Age)
@ -1292,6 +1295,7 @@ func TestScanRows(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error())
}
dialect := supportedDialects["sqlite3"]
ctx := context.TODO()
db := connectDB(t, "sqlite3")
defer db.Close()
@ -1302,7 +1306,7 @@ func TestScanRows(t *testing.T) {
var u User
err = rows.Close()
assert.Equal(t, nil, err)
err = scanRows(rows, &u)
err = scanRows(dialect, rows, &u)
assert.NotEqual(t, nil, err)
})
@ -1312,6 +1316,7 @@ func TestScanRows(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error())
}
dialect := supportedDialects["sqlite3"]
ctx := context.TODO()
db := connectDB(t, "sqlite3")
defer db.Close()
@ -1320,7 +1325,7 @@ func TestScanRows(t *testing.T) {
assert.Equal(t, nil, err)
var u User
err = scanRows(rows, u)
err = scanRows(dialect, rows, u)
assert.NotEqual(t, nil, err)
})
@ -1330,6 +1335,7 @@ func TestScanRows(t *testing.T) {
t.Fatal("could not create test table!, reason:", err.Error())
}
dialect := supportedDialects["sqlite3"]
ctx := context.TODO()
db := connectDB(t, "sqlite3")
defer db.Close()
@ -1338,15 +1344,16 @@ func TestScanRows(t *testing.T) {
assert.Equal(t, nil, err)
var u map[string]interface{}
err = scanRows(rows, &u)
err = scanRows(dialect, rows, &u)
assert.NotEqual(t, nil, err)
})
}
var connectionString = map[string]string{
"postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable",
"sqlite3": "/tmp/ksql.db",
"mysql": "root:mysql@(127.0.0.1:3306)/ksql?timeout=30s",
"postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable",
"sqlite3": "/tmp/ksql.db",
"mysql": "root:mysql@(127.0.0.1:3306)/ksql?timeout=30s",
"sqlserver": "sqlserver://sa:Sqls3rv3r@127.0.0.1:1433?databaseName=ksql",
}
func createTable(driver string) error {
@ -1385,6 +1392,13 @@ func createTable(driver string) error {
name VARCHAR(50),
address JSON
)`)
case "sqlserver":
_, err = db.Exec(`CREATE TABLE users (
id INT IDENTITY(1,1) PRIMARY KEY,
age INT,
name VARCHAR(50),
address NVARCHAR(4000)
)`)
}
if err != nil {
return fmt.Errorf("failed to create new users table: %s", err.Error())
@ -1404,12 +1418,8 @@ func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB {
db: db,
tableName: tableName,
idCols: ids,
insertMethod: map[string]insertMethod{
"sqlite3": insertWithLastInsertID,
"postgres": insertWithReturning,
"mysql": insertWithLastInsertID,
}[driver],
idCols: ids,
insertMethod: supportedDialects[driver].InsertMethod(),
}
}
@ -1482,17 +1492,17 @@ func getUserByID(dbi sqlProvider, dialect dialect, result *User, id uint) error
return row.Err()
}
var rawAddr []byte
err := row.Scan(&result.ID, &result.Name, &result.Age, &rawAddr)
value := jsonSerializable{
DriverName: dialect.DriverName(),
Attr: &result.Address,
}
err := row.Scan(&result.ID, &result.Name, &result.Age, &value)
if err != nil {
return err
}
if rawAddr == nil {
return nil
}
return json.Unmarshal(rawAddr, &result.Address)
return nil
}
func getUserByName(dbi sqlProvider, dialect dialect, result *User, name string) error {