Normalize mysql dsn to always have parseTime=true

pull/156/head
Nathan James Tindall 2018-03-29 20:56:49 -07:00 committed by Vojtech Vitek
parent f2dc36702f
commit 4621f19b3c
3 changed files with 91 additions and 2 deletions

View File

@ -1,7 +1,6 @@
package main
import (
"database/sql"
"flag"
"fmt"
"log"
@ -73,7 +72,7 @@ func main() {
default:
}
db, err := sql.Open(driver, dbstring)
db, err := createDBWithDriver(driver, dbstring)
if err != nil {
log.Fatalf("-dbstring=%q: %v\n", dbstring, err)
}

36
cmd/goose/sql.go Normal file
View File

@ -0,0 +1,36 @@
package main
import (
"database/sql"
"fmt"
"github.com/go-sql-driver/mysql"
)
// normalizeMySQLDSN parses the dsn used with the mysql driver to always have
// the parameter `parseTime` set to true. This allows internal goose logic
// to assume that DATETIME/DATE/TIMESTAMP can be scanned into the time.Time
// type.
func normalizeMySQLDSN(dsn string) (string, error) {
config, err := mysql.ParseDSN(dsn)
if err != nil {
return "", err
}
config.ParseTime = true
return config.FormatDSN(), nil
}
func createDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
switch driver {
case "postgres", "sqlite3":
return sql.Open(driver, dbstring)
case "mysql":
dsn, err := normalizeMySQLDSN(dbstring)
if err != nil {
return nil, err
}
return sql.Open(driver, dsn)
default:
}
return nil, fmt.Errorf("unsupported driver %s", driver)
}

54
cmd/goose/sql_test.go Normal file
View File

@ -0,0 +1,54 @@
package main
import (
"testing"
)
func TestNormalizeMySQLDSN(t *testing.T) {
t.Parallel()
testCases := []struct {
desc string
in string
out string
expectedErr string
}{
{
desc: "errors if dsn is invalid",
in: "root:password@tcp(mysql:3306)", // forgot the database name
expectedErr: "invalid DSN: missing the slash separating the database name",
},
{
desc: "works when there are no query parameters supplied with the dsn",
in: "root:password@tcp(mysql:3306)/db",
out: "root:password@tcp(mysql:3306)/db?parseTime=true",
},
{
desc: "works when parseTime is already set to true supplied with the dsn",
in: "root:password@tcp(mysql:3306)/db?parseTime=true",
out: "root:password@tcp(mysql:3306)/db?parseTime=true",
},
{
desc: "persists other parameters if they are present",
in: "root:password@tcp(mysql:3306)/db?allowCleartextPasswords=true&interpolateParams=true",
out: "root:password@tcp(mysql:3306)/db?allowCleartextPasswords=true&interpolateParams=true&parseTime=true",
},
}
for _, tc := range testCases {
t.Run(tc.desc, func(t *testing.T) {
out, err := normalizeMySQLDSN(tc.in)
if tc.expectedErr != "" {
if err == nil {
t.Errorf("expected an error, but did not have one, had (%#v, %#v)", out, err)
} else if err.Error() != tc.expectedErr {
t.Errorf("expected error %s but had %s", tc.expectedErr, err.Error())
}
} else if err != nil {
t.Errorf("had unexpected error %s", err.Error())
} else if out != tc.out {
t.Errorf("had output mismatch, wanted %s but had %s", tc.out, out)
}
})
}
}