mirror of https://github.com/pressly/goose.git
Normalize mysql dsn to always have parseTime=true
parent
f2dc36702f
commit
4621f19b3c
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
|
@ -73,7 +72,7 @@ func main() {
|
|||
default:
|
||||
}
|
||||
|
||||
db, err := sql.Open(driver, dbstring)
|
||||
db, err := createDBWithDriver(driver, dbstring)
|
||||
if err != nil {
|
||||
log.Fatalf("-dbstring=%q: %v\n", dbstring, err)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
// normalizeMySQLDSN parses the dsn used with the mysql driver to always have
|
||||
// the parameter `parseTime` set to true. This allows internal goose logic
|
||||
// to assume that DATETIME/DATE/TIMESTAMP can be scanned into the time.Time
|
||||
// type.
|
||||
func normalizeMySQLDSN(dsn string) (string, error) {
|
||||
config, err := mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
config.ParseTime = true
|
||||
return config.FormatDSN(), nil
|
||||
}
|
||||
|
||||
func createDBWithDriver(driver string, dbstring string) (*sql.DB, error) {
|
||||
switch driver {
|
||||
case "postgres", "sqlite3":
|
||||
return sql.Open(driver, dbstring)
|
||||
case "mysql":
|
||||
dsn, err := normalizeMySQLDSN(dbstring)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sql.Open(driver, dsn)
|
||||
default:
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported driver %s", driver)
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeMySQLDSN(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
in string
|
||||
out string
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
desc: "errors if dsn is invalid",
|
||||
in: "root:password@tcp(mysql:3306)", // forgot the database name
|
||||
expectedErr: "invalid DSN: missing the slash separating the database name",
|
||||
},
|
||||
{
|
||||
desc: "works when there are no query parameters supplied with the dsn",
|
||||
in: "root:password@tcp(mysql:3306)/db",
|
||||
out: "root:password@tcp(mysql:3306)/db?parseTime=true",
|
||||
},
|
||||
{
|
||||
desc: "works when parseTime is already set to true supplied with the dsn",
|
||||
in: "root:password@tcp(mysql:3306)/db?parseTime=true",
|
||||
out: "root:password@tcp(mysql:3306)/db?parseTime=true",
|
||||
},
|
||||
{
|
||||
desc: "persists other parameters if they are present",
|
||||
in: "root:password@tcp(mysql:3306)/db?allowCleartextPasswords=true&interpolateParams=true",
|
||||
out: "root:password@tcp(mysql:3306)/db?allowCleartextPasswords=true&interpolateParams=true&parseTime=true",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.desc, func(t *testing.T) {
|
||||
out, err := normalizeMySQLDSN(tc.in)
|
||||
if tc.expectedErr != "" {
|
||||
if err == nil {
|
||||
t.Errorf("expected an error, but did not have one, had (%#v, %#v)", out, err)
|
||||
} else if err.Error() != tc.expectedErr {
|
||||
t.Errorf("expected error %s but had %s", tc.expectedErr, err.Error())
|
||||
}
|
||||
} else if err != nil {
|
||||
t.Errorf("had unexpected error %s", err.Error())
|
||||
} else if out != tc.out {
|
||||
t.Errorf("had output mismatch, wanted %s but had %s", tc.out, out)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue