diff --git a/README.md b/README.md index 35b66e0..24d428d 100644 --- a/README.md +++ b/README.md @@ -14,10 +14,12 @@ The goals were: **Supported Drivers:** -Currently we only support 2 Drivers: +Currently we support 4 Drivers: - `"postgres"` - `"sqlite3"` +- `"mysql"` +- `"sqlserver"` ### Why KissSQL? diff --git a/dialect.go b/dialect.go index c46869d..d245da7 100644 --- a/dialect.go +++ b/dialect.go @@ -6,24 +6,31 @@ type insertMethod int const ( insertWithReturning insertMethod = iota + insertWithOutput insertWithLastInsertID insertWithNoIDRetrieval ) var supportedDialects = map[string]dialect{ - "postgres": &postgresDialect{}, - "sqlite3": &sqlite3Dialect{}, - "mysql": &mysqlDialect{}, + "postgres": &postgresDialect{}, + "sqlite3": &sqlite3Dialect{}, + "mysql": &mysqlDialect{}, + "sqlserver": &sqlserverDialect{}, } type dialect interface { InsertMethod() insertMethod Escape(str string) string Placeholder(idx int) string + DriverName() string } type postgresDialect struct{} +func (postgresDialect) DriverName() string { + return "postgres" +} + func (postgresDialect) InsertMethod() insertMethod { return insertWithReturning } @@ -38,6 +45,10 @@ func (postgresDialect) Placeholder(idx int) string { type sqlite3Dialect struct{} +func (sqlite3Dialect) DriverName() string { + return "sqlite3" +} + func (sqlite3Dialect) InsertMethod() insertMethod { return insertWithLastInsertID } @@ -52,6 +63,10 @@ func (sqlite3Dialect) Placeholder(idx int) string { type mysqlDialect struct{} +func (mysqlDialect) DriverName() string { + return "mysql" +} + func (mysqlDialect) InsertMethod() insertMethod { return insertWithLastInsertID } @@ -63,3 +78,21 @@ func (mysqlDialect) Escape(str string) string { func (mysqlDialect) Placeholder(idx int) string { return "?" } + +type sqlserverDialect struct{} + +func (sqlserverDialect) DriverName() string { + return "sqlserver" +} + +func (sqlserverDialect) InsertMethod() insertMethod { + return insertWithOutput +} + +func (sqlserverDialect) Escape(str string) string { + return `[` + str + `]` +} + +func (sqlserverDialect) Placeholder(idx int) string { + return "@p" + strconv.Itoa(idx+1) +} diff --git a/docker-compose.yml b/docker-compose.yml index 6d0932c..6e70b12 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -23,3 +23,13 @@ services: - "127.0.0.1:3306:3306" environment: MYSQL_ROOT_PASSWORD: mysql + + sqlserver: + image: microsoft/mssql-server-linux:2017-latest + restart: always + ports: + - "127.0.0.1:1433:1433" + - "127.0.0.1:1434:1434" + environment: + SA_PASSWORD: "Sqls3rv3r" + ACCEPT_EULA: "Y" diff --git a/go.mod b/go.mod index 2b82c83..f09a85a 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/vingarcia/ksql go 1.14 require ( + github.com/denisenkom/go-mssqldb v0.10.0 // indirect github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 github.com/go-sql-driver/mysql v1.4.0 // indirect github.com/golang/mock v1.5.0 diff --git a/go.sum b/go.sum index 2e20171..b80fabb 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,14 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/denisenkom/go-mssqldb v0.10.0 h1:QykgLZBorFE95+gO3u9esLd0BmbvpWp0/waNNZfHBM8= +github.com/denisenkom/go-mssqldb v0.10.0/go.mod h1:xbL0rPBG9cCiLr28tMa8zpbdarY27NDyej4t/EjAShU= github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018 h1:QsFkVafcKOaZoAB4WcyUHdkPbwh+VYwZgYJb/rU6EIM= github.com/ditointernet/go-assert v0.0.0-20200120164340-9e13125a7018/go.mod h1:5C3SWkut69TSdkerzRDxXMRM5x73PGWNcRLe/xKjXhs= github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk= github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= +github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/mock v1.5.0 h1:jlYHihg//f7RRwuPfptm04yp4s7O6Kw8EZiVYIGcH0g= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= @@ -28,6 +32,8 @@ github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/tj/assert v0.0.3 h1:Df/BlaZ20mq6kuai7f5z2TvPFiwC3xaWJSDQNiIS3Rk= github.com/tj/assert v0.0.3/go.mod h1:Ne6X72Q+TB1AteidzQncjw9PabbMp4PBMZ1k+vd1Pvk= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 h1:ObdrDkeb4kJdCP557AjRjq69pTHfNouLtWZG7j9rPN8= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= diff --git a/json.go b/json.go index aaf2ef9..03fdda1 100644 --- a/json.go +++ b/json.go @@ -11,7 +11,8 @@ import ( // input attributes to be convertible to and from JSON // before sending or receiving it from the database. type jsonSerializable struct { - Attr interface{} + DriverName string + Attr interface{} } // Scan Implements the Scanner interface in order to load @@ -40,5 +41,9 @@ func (j *jsonSerializable) Scan(value interface{}) error { // Value Implements the Valuer interface in order to save // this field as JSON on the database. func (j jsonSerializable) Value() (driver.Value, error) { - return json.Marshal(j.Attr) + b, err := json.Marshal(j.Attr) + if j.DriverName == "sqlserver" { + return string(b), err + } + return b, err } diff --git a/ksql.go b/ksql.go index 0386b68..86ba670 100644 --- a/ksql.go +++ b/ksql.go @@ -148,7 +148,7 @@ func (c DB) Query( elemPtr = elemPtr.Elem() } - err = scanRows(rows, elemPtr.Interface()) + err = scanRows(c.dialect, rows, elemPtr.Interface()) if err != nil { return err } @@ -202,7 +202,7 @@ func (c DB) QueryOne( return ErrRecordNotFound } - err = scanRows(rows, record) + err = scanRows(c.dialect, rows, record) if err != nil { return err } @@ -262,7 +262,7 @@ func (c DB) QueryChunks( chunk = reflect.Append(chunk, elemValue) } - err = scanRows(rows, chunk.Index(idx).Addr().Interface()) + err = scanRows(c.dialect, rows, chunk.Index(idx).Addr().Interface()) if err != nil { return err } @@ -320,14 +320,14 @@ func (c DB) Insert( return fmt.Errorf("the optional TableName argument was not provided to New(), can't use the Insert method") } - query, params, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...) + query, params, scanValues, err := buildInsertQuery(c.dialect, c.tableName, record, c.idCols...) if err != nil { return err } switch c.insertMethod { - case insertWithReturning: - err = c.insertWithReturningID(ctx, record, query, params, c.idCols) + case insertWithReturning, insertWithOutput: + err = c.insertReturningIDs(ctx, record, query, params, scanValues, c.idCols) case insertWithLastInsertID: err = c.insertWithLastInsertID(ctx, record, query, params, c.idCols[0]) case insertWithNoIDRetrieval: @@ -341,19 +341,14 @@ func (c DB) Insert( return err } -func (c DB) insertWithReturningID( +func (c DB) insertReturningIDs( ctx context.Context, record interface{}, query string, params []interface{}, + scanValues []interface{}, idNames []string, ) error { - escapedIDNames := []string{} - for _, id := range idNames { - escapedIDNames = append(escapedIDNames, c.dialect.Escape(id)) - } - query += " RETURNING " + strings.Join(idNames, ", ") - rows, err := c.db.QueryContext(ctx, query, params...) if err != nil { return err @@ -369,21 +364,7 @@ func (c DB) insertWithReturningID( return err } - v := reflect.ValueOf(record) - t := v.Type() - if err = assertStructPtr(t); err != nil { - return errors.Wrap(err, "can't write id field") - } - info := structs.GetTagInfo(t.Elem()) - - var scanFields []interface{} - for _, id := range idNames { - scanFields = append( - scanFields, - v.Elem().Field(info.ByName(id).Index).Addr().Interface(), - ) - } - err = rows.Scan(scanFields...) + err = rows.Scan(scanValues...) if err != nil { return err } @@ -549,20 +530,25 @@ func buildInsertQuery( dialect dialect, tableName string, record interface{}, - idFieldNames ...string, -) (query string, params []interface{}, err error) { + idNames ...string, +) (query string, params []interface{}, scanValues []interface{}, err error) { + v := reflect.ValueOf(record) + t := v.Type() + if err = assertStructPtr(t); err != nil { + return "", nil, nil, fmt.Errorf( + "ksql: expected record to be a pointer to struct, but got: %T", + record, + ) + } + + info := structs.GetTagInfo(t.Elem()) + recordMap, err := structs.StructToMap(record) if err != nil { - return "", nil, err + return "", nil, nil, err } - t := reflect.TypeOf(record) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } - info := structs.GetTagInfo(t) - - for _, fieldName := range idFieldNames { + for _, fieldName := range idNames { // Remove any ID field that was not set: if reflect.ValueOf(recordMap[fieldName]).IsZero() { delete(recordMap, fieldName) @@ -580,7 +566,10 @@ func buildInsertQuery( recordValue := recordMap[col] params[i] = recordValue if info.ByName(col).SerializeAsJSON { - params[i] = jsonSerializable{Attr: recordValue} + params[i] = jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: recordValue, + } } valuesQuery[i] = dialect.Placeholder(i) @@ -592,14 +581,48 @@ func buildInsertQuery( escapedColumnNames = append(escapedColumnNames, dialect.Escape(col)) } + var returningQuery, outputQuery string + switch dialect.InsertMethod() { + case insertWithReturning: + escapedIDNames := []string{} + for _, id := range idNames { + escapedIDNames = append(escapedIDNames, dialect.Escape(id)) + } + returningQuery = " RETURNING " + strings.Join(escapedIDNames, ", ") + + for _, id := range idNames { + scanValues = append( + scanValues, + v.Elem().Field(info.ByName(id).Index).Addr().Interface(), + ) + } + case insertWithOutput: + escapedIDNames := []string{} + for _, id := range idNames { + escapedIDNames = append(escapedIDNames, "INSERTED."+dialect.Escape(id)) + } + outputQuery = " OUTPUT " + strings.Join(escapedIDNames, ", ") + + for _, id := range idNames { + scanValues = append( + scanValues, + v.Elem().Field(info.ByName(id).Index).Addr().Interface(), + ) + } + } + + // Note that the outputQuery and the returningQuery depend + // on the selected driver, thus, they might be empty strings. query = fmt.Sprintf( - "INSERT INTO %s (%s) VALUES (%s)", + "INSERT INTO %s (%s)%s VALUES (%s)%s", dialect.Escape(tableName), strings.Join(escapedColumnNames, ", "), + outputQuery, strings.Join(valuesQuery, ", "), + returningQuery, ) - return query, params, nil + return query, params, scanValues, nil } func buildUpdateQuery( @@ -644,7 +667,10 @@ func buildUpdateQuery( for i, k := range keys { recordValue := recordMap[k] if info.ByName(k).SerializeAsJSON { - recordValue = jsonSerializable{Attr: recordValue} + recordValue = jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: recordValue, + } } args[i] = recordValue setQuery = append(setQuery, fmt.Sprintf( @@ -753,7 +779,7 @@ func (nopScanner) Scan(value interface{}) error { return nil } -func scanRows(rows *sql.Rows, record interface{}) error { +func scanRows(dialect dialect, rows *sql.Rows, record interface{}) error { names, err := rows.Columns() if err != nil { return err @@ -782,7 +808,10 @@ func scanRows(rows *sql.Rows, record interface{}) error { if fieldInfo.Valid { valueScanner = v.Field(fieldInfo.Index).Addr().Interface() if fieldInfo.SerializeAsJSON { - valueScanner = &jsonSerializable{Attr: valueScanner} + valueScanner = &jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: valueScanner, + } } } diff --git a/ksql_test.go b/ksql_test.go index 85a4691..b8132d5 100644 --- a/ksql_test.go +++ b/ksql_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + _ "github.com/denisenkom/go-mssqldb" "github.com/ditointernet/go-assert" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" @@ -1233,6 +1234,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1248,7 +1250,7 @@ func TestScanRows(t *testing.T) { assert.Equal(t, true, rows.Next()) var u User - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.Equal(t, nil, err) assert.Equal(t, "User2", u.Name) @@ -1261,6 +1263,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1280,7 +1283,7 @@ func TestScanRows(t *testing.T) { // Omitted for testing purposes: // Name string `ksql:"name"` } - err = scanRows(rows, &user) + err = scanRows(dialect, rows, &user) assert.Equal(t, nil, err) assert.Equal(t, 22, user.Age) @@ -1292,6 +1295,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1302,7 +1306,7 @@ func TestScanRows(t *testing.T) { var u User err = rows.Close() assert.Equal(t, nil, err) - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.NotEqual(t, nil, err) }) @@ -1312,6 +1316,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1320,7 +1325,7 @@ func TestScanRows(t *testing.T) { assert.Equal(t, nil, err) var u User - err = scanRows(rows, u) + err = scanRows(dialect, rows, u) assert.NotEqual(t, nil, err) }) @@ -1330,6 +1335,7 @@ func TestScanRows(t *testing.T) { t.Fatal("could not create test table!, reason:", err.Error()) } + dialect := supportedDialects["sqlite3"] ctx := context.TODO() db := connectDB(t, "sqlite3") defer db.Close() @@ -1338,15 +1344,16 @@ func TestScanRows(t *testing.T) { assert.Equal(t, nil, err) var u map[string]interface{} - err = scanRows(rows, &u) + err = scanRows(dialect, rows, &u) assert.NotEqual(t, nil, err) }) } var connectionString = map[string]string{ - "postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable", - "sqlite3": "/tmp/ksql.db", - "mysql": "root:mysql@(127.0.0.1:3306)/ksql?timeout=30s", + "postgres": "host=localhost port=5432 user=postgres password=postgres dbname=ksql sslmode=disable", + "sqlite3": "/tmp/ksql.db", + "mysql": "root:mysql@(127.0.0.1:3306)/ksql?timeout=30s", + "sqlserver": "sqlserver://sa:Sqls3rv3r@127.0.0.1:1433?databaseName=ksql", } func createTable(driver string) error { @@ -1385,6 +1392,13 @@ func createTable(driver string) error { name VARCHAR(50), address JSON )`) + case "sqlserver": + _, err = db.Exec(`CREATE TABLE users ( + id INT IDENTITY(1,1) PRIMARY KEY, + age INT, + name VARCHAR(50), + address NVARCHAR(4000) + )`) } if err != nil { return fmt.Errorf("failed to create new users table: %s", err.Error()) @@ -1404,12 +1418,8 @@ func newTestDB(db *sql.DB, driver string, tableName string, ids ...string) DB { db: db, tableName: tableName, - idCols: ids, - insertMethod: map[string]insertMethod{ - "sqlite3": insertWithLastInsertID, - "postgres": insertWithReturning, - "mysql": insertWithLastInsertID, - }[driver], + idCols: ids, + insertMethod: supportedDialects[driver].InsertMethod(), } } @@ -1482,17 +1492,17 @@ func getUserByID(dbi sqlProvider, dialect dialect, result *User, id uint) error return row.Err() } - var rawAddr []byte - err := row.Scan(&result.ID, &result.Name, &result.Age, &rawAddr) + value := jsonSerializable{ + DriverName: dialect.DriverName(), + Attr: &result.Address, + } + + err := row.Scan(&result.ID, &result.Name, &result.Age, &value) if err != nil { return err } - if rawAddr == nil { - return nil - } - - return json.Unmarshal(rawAddr, &result.Address) + return nil } func getUserByName(dbi sqlProvider, dialect dialect, result *User, name string) error {