migrations: add tests and remove XORM (#7050)

pull/7052/head
Joe Chen 2022-06-12 14:15:01 +08:00 committed by GitHub
parent 2e19f5a3c8
commit b772603d78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 451 additions and 297 deletions

View File

@ -116,7 +116,7 @@ jobs:
- name: Checkout code
uses: actions/checkout@v2
- name: Run tests with coverage
run: go test -v -race -coverprofile=coverage -covermode=atomic ./internal/db
run: go test -v -race -coverprofile=coverage -covermode=atomic ./internal/db/...
env:
GOGS_DATABASE_TYPE: postgres
PGPORT: 5432
@ -142,7 +142,7 @@ jobs:
- name: Checkout code
uses: actions/checkout@v2
- name: Run tests with coverage
run: go test -v -race -coverprofile=coverage -covermode=atomic ./internal/db
run: go test -v -race -coverprofile=coverage -covermode=atomic ./internal/db/...
env:
GOGS_DATABASE_TYPE: mysql
MYSQL_USER: root
@ -165,6 +165,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@v2
- name: Run tests with coverage
run: go test -v -race -parallel=1 -coverprofile=coverage -covermode=atomic ./internal/db
run: go test -v -race -parallel=1 -coverprofile=coverage -covermode=atomic ./internal/db/...
env:
GOGS_DATABASE_TYPE: sqlite

10
go.mod
View File

@ -62,7 +62,7 @@ require (
gopkg.in/macaron.v1 v1.4.0
gorm.io/driver/mysql v1.3.4
gorm.io/driver/postgres v1.3.7
gorm.io/driver/sqlite v1.3.2
gorm.io/driver/sqlite v1.3.4
gorm.io/driver/sqlserver v1.3.1
gorm.io/gorm v1.23.5
modernc.org/sqlite v1.17.3
@ -72,13 +72,5 @@ require (
xorm.io/xorm v0.8.0
)
// Temporary replace directives
// ============================
// These entries indicate temporary replace directives due to a pending pull request upstream
// or issues with specific versions.
// https://github.com/gogs/gogs/issues/7019
replace gorm.io/driver/sqlite => github.com/gogs/gorm-sqlite v1.3.3-0.20220608111034-1298ceb93369
// +heroku goVersion go1.16
// +heroku install ./

4
go.sum
View File

@ -173,8 +173,6 @@ github.com/gogs/go-gogs-client v0.0.0-20200128182646-c69cb7680fd4 h1:C7NryI/RQhs
github.com/gogs/go-gogs-client v0.0.0-20200128182646-c69cb7680fd4/go.mod h1:fR6z1Ie6rtF7kl/vBYMfgD5/G5B1blui7z426/sj2DU=
github.com/gogs/go-libravatar v0.0.0-20191106065024-33a75213d0a0 h1:K02vod+sn3M1OOkdqi2tPxN2+xESK4qyITVQ3JkGEv4=
github.com/gogs/go-libravatar v0.0.0-20191106065024-33a75213d0a0/go.mod h1:Zas3BtO88pk1cwUfEYlvnl/CRwh0ybDxRWSwRjG8I3w=
github.com/gogs/gorm-sqlite v1.3.3-0.20220608111034-1298ceb93369 h1:nlpH50ShzqBGnZwul13EYJc8l5quxdTkBmzewuLZHgs=
github.com/gogs/gorm-sqlite v1.3.3-0.20220608111034-1298ceb93369/go.mod h1:B+8GyC9K7VgzJAcrcXMRPdnMcck+8FgJynEehEPM16U=
github.com/gogs/minwinsvc v0.0.0-20170301035411-95be6356811a h1:8DZwxETOVWIinYxDK+i6L+rMb7eGATGaakD6ZucfHVk=
github.com/gogs/minwinsvc v0.0.0-20170301035411-95be6356811a/go.mod h1:TUIZ+29jodWQ8Gk6Pvtg4E09aMsc3C/VLZiVYfUhWQU=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
@ -907,6 +905,8 @@ gorm.io/driver/mysql v1.3.4 h1:/KoBMgsUHC3bExsekDcmNYaBnfH2WNeFuXqqrqMc98Q=
gorm.io/driver/mysql v1.3.4/go.mod h1:s4Tq0KmD0yhPGHbZEwg1VPlH0vT/GBHJZorPzhcxBUE=
gorm.io/driver/postgres v1.3.7 h1:FKF6sIMDHDEvvMF/XJvbnCl0nu6KSKUaPXevJ4r+VYQ=
gorm.io/driver/postgres v1.3.7/go.mod h1:f02ympjIcgtHEGFMZvdgTxODZ9snAHDb4hXfigBVuNI=
gorm.io/driver/sqlite v1.3.4 h1:NnFOPVfzi4CPsJPH4wXr6rMkPb4ElHEqKMvrsx9c9Fk=
gorm.io/driver/sqlite v1.3.4/go.mod h1:B+8GyC9K7VgzJAcrcXMRPdnMcck+8FgJynEehEPM16U=
gorm.io/driver/sqlserver v1.3.1 h1:F5t6ScMzOgy1zukRTIZgLZwKahgt3q1woAILVolKpOI=
gorm.io/driver/sqlserver v1.3.1/go.mod h1:w25Vrx2BG+CJNUu/xKbFhaKlGxT/nzRkhWCCoptX8tQ=
gorm.io/gorm v1.23.1/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=

View File

@ -13,6 +13,7 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"gogs.io/gogs/internal/dbtest"
"gogs.io/gogs/internal/errutil"
)
@ -50,7 +51,7 @@ func TestAccessTokens(t *testing.T) {
tables := []interface{}{new(AccessToken)}
db := &accessTokens{
DB: initTestDB(t, "accessTokens", tables...),
DB: dbtest.NewDB(t, "accessTokens", tables...),
}
for _, tc := range []struct {

View File

@ -20,6 +20,7 @@ import (
"gogs.io/gogs/internal/auth/github"
"gogs.io/gogs/internal/auth/pam"
"gogs.io/gogs/internal/cryptoutil"
"gogs.io/gogs/internal/dbtest"
"gogs.io/gogs/internal/lfsutil"
"gogs.io/gogs/internal/testutil"
)
@ -35,7 +36,7 @@ func TestDumpAndImport(t *testing.T) {
t.Fatalf("New table has added (want 4 got %d), please add new tests for the table and update this check", len(Tables))
}
db := initTestDB(t, "dumpAndImport", Tables...)
db := dbtest.NewDB(t, "dumpAndImport", Tables...)
setupDBToDump(t, db)
dumpTables(t, db)
importTables(t, db)

View File

@ -11,10 +11,6 @@ import (
"time"
"github.com/pkg/errors"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
@ -24,73 +20,6 @@ import (
"gogs.io/gogs/internal/dbutil"
)
// parsePostgreSQLHostPort parses given input in various forms defined in
// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
// and returns proper host and port number.
func parsePostgreSQLHostPort(info string) (host, port string) {
host, port = "127.0.0.1", "5432"
if strings.Contains(info, ":") && !strings.HasSuffix(info, "]") {
idx := strings.LastIndex(info, ":")
host = info[:idx]
port = info[idx+1:]
} else if len(info) > 0 {
host = info
}
return host, port
}
func parseMSSQLHostPort(info string) (host, port string) {
host, port = "127.0.0.1", "1433"
if strings.Contains(info, ":") {
host = strings.Split(info, ":")[0]
port = strings.Split(info, ":")[1]
} else if strings.Contains(info, ",") {
host = strings.Split(info, ",")[0]
port = strings.TrimSpace(strings.Split(info, ",")[1])
} else if len(info) > 0 {
host = info
}
return host, port
}
// newDSN takes given database options and returns parsed DSN.
func newDSN(opts conf.DatabaseOpts) (dsn string, err error) {
// In case the database name contains "?" with some parameters
concate := "?"
if strings.Contains(opts.Name, concate) {
concate = "&"
}
switch opts.Type {
case "mysql":
if opts.Host[0] == '/' { // Looks like a unix socket
dsn = fmt.Sprintf("%s:%s@unix(%s)/%s%scharset=utf8mb4&parseTime=true",
opts.User, opts.Password, opts.Host, opts.Name, concate)
} else {
dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s%scharset=utf8mb4&parseTime=true",
opts.User, opts.Password, opts.Host, opts.Name, concate)
}
case "postgres":
host, port := parsePostgreSQLHostPort(opts.Host)
dsn = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s' application_name='gogs'",
opts.User, opts.Password, host, port, opts.Name, opts.SSLMode, opts.Schema)
case "mssql":
host, port := parseMSSQLHostPort(opts.Host)
dsn = fmt.Sprintf("server=%s; port=%s; database=%s; user id=%s; password=%s;",
host, port, opts.Name, opts.User, opts.Password)
case "sqlite3", "sqlite":
dsn = "file:" + opts.Path + "?cache=shared&mode=rwc"
default:
return "", errors.Errorf("unrecognized dialect: %s", opts.Type)
}
return dsn, nil
}
func newLogWriter() (logger.Writer, error) {
sec := conf.File.Section("log.gorm")
w, err := log.NewFileWriter(
@ -108,32 +37,6 @@ func newLogWriter() (logger.Writer, error) {
return &dbutil.Logger{Writer: w}, nil
}
func openDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) {
dsn, err := newDSN(opts)
if err != nil {
return nil, errors.Wrap(err, "parse DSN")
}
var dialector gorm.Dialector
switch opts.Type {
case "mysql":
dialector = mysql.Open(dsn)
case "postgres":
dialector = postgres.Open(dsn)
case "mssql":
dialector = sqlserver.Open(dsn)
case "sqlite3":
dialector = sqlite.Open(dsn)
case "sqlite":
dialector = sqlite.Open(dsn)
dialector.(*sqlite.Dialector).DriverName = "sqlite"
default:
panic("unreachable")
}
return gorm.Open(dialector, cfg)
}
// Tables is the list of struct-to-table mappings.
//
// NOTE: Lines are sorted in alphabetical order, each letter in its own line.
@ -155,7 +58,7 @@ func Init(w logger.Writer) (*gorm.DB, error) {
LogLevel: level,
})
db, err := openDB(
db, err := dbutil.OpenDB(
conf.Database,
&gorm.Config{
SkipDefaultTransaction: true,

View File

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gogs.io/gogs/internal/dbtest"
"gogs.io/gogs/internal/errutil"
"gogs.io/gogs/internal/lfsutil"
)
@ -25,7 +26,7 @@ func TestLFS(t *testing.T) {
tables := []interface{}{new(LFSObject)}
db := &lfs{
DB: initTestDB(t, "lfs", tables...),
DB: dbtest.NewDB(t, "lfs", tables...),
}
for _, tc := range []struct {

View File

@ -17,6 +17,7 @@ import (
"gogs.io/gogs/internal/auth"
"gogs.io/gogs/internal/auth/github"
"gogs.io/gogs/internal/auth/pam"
"gogs.io/gogs/internal/dbtest"
"gogs.io/gogs/internal/errutil"
)
@ -85,7 +86,7 @@ func Test_loginSources(t *testing.T) {
tables := []interface{}{new(LoginSource), new(User)}
db := &loginSources{
DB: initTestDB(t, "loginSources", tables...),
DB: dbtest.NewDB(t, "loginSources", tables...),
}
for _, tc := range []struct {

View File

@ -5,22 +5,16 @@
package db
import (
"database/sql"
"flag"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
_ "modernc.org/sqlite"
log "unknwon.dev/clog/v2"
"gogs.io/gogs/internal/conf"
"gogs.io/gogs/internal/testutil"
)
@ -60,128 +54,3 @@ func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error {
}
return nil
}
func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
dbType := os.Getenv("GOGS_DATABASE_TYPE")
var dbName string
var dbOpts conf.DatabaseOpts
var cleanup func(db *gorm.DB)
switch dbType {
case "mysql":
dbOpts = conf.DatabaseOpts{
Type: "mysql",
Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
Name: dbName,
User: os.Getenv("MYSQL_USER"),
Password: os.Getenv("MYSQL_PASSWORD"),
}
dsn, err := newDSN(dbOpts)
require.NoError(t, err)
sqlDB, err := sql.Open("mysql", dsn)
require.NoError(t, err)
// Set up test database
dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
_, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
require.NoError(t, err)
_, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
require.NoError(t, err)
dbOpts.Name = dbName
cleanup = func(db *gorm.DB) {
db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
_ = sqlDB.Close()
}
case "postgres":
dbOpts = conf.DatabaseOpts{
Type: "postgres",
Host: os.ExpandEnv("$PGHOST:$PGPORT"),
Name: dbName,
Schema: "public",
User: os.Getenv("PGUSER"),
Password: os.Getenv("PGPASSWORD"),
SSLMode: os.Getenv("PGSSLMODE"),
}
dsn, err := newDSN(dbOpts)
require.NoError(t, err)
sqlDB, err := sql.Open("pgx", dsn)
require.NoError(t, err)
// Set up test database
dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
_, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
require.NoError(t, err)
_, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
require.NoError(t, err)
dbOpts.Name = dbName
cleanup = func(db *gorm.DB) {
db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
_ = sqlDB.Close()
}
case "sqlite":
dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
dbOpts = conf.DatabaseOpts{
Type: "sqlite",
Path: dbName,
}
cleanup = func(db *gorm.DB) {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
_ = os.Remove(dbName)
}
default:
dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
dbOpts = conf.DatabaseOpts{
Type: "sqlite3",
Path: dbName,
}
cleanup = func(db *gorm.DB) {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
_ = os.Remove(dbName)
}
}
now := time.Now().UTC().Truncate(time.Second)
db, err := openDB(
dbOpts,
&gorm.Config{
SkipDefaultTransaction: true,
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
NowFunc: func() time.Time {
return now
},
},
)
require.NoError(t, err)
t.Cleanup(func() {
if t.Failed() {
t.Logf("Database %q left intact for inspection", dbName)
return
}
cleanup(db)
})
err = db.Migrator().AutoMigrate(tables...)
require.NoError(t, err)
return db
}

View File

@ -0,0 +1,40 @@
// Copyright 2022 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package migrations
import (
"flag"
"fmt"
"os"
"testing"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
log "unknwon.dev/clog/v2"
"gogs.io/gogs/internal/testutil"
)
func TestMain(m *testing.M) {
flag.Parse()
level := logger.Silent
if !testing.Verbose() {
// Remove the primary logger and register a noop logger.
log.Remove(log.DefaultConsoleName)
err := log.New("noop", testutil.InitNoopLogger)
if err != nil {
fmt.Println(err)
os.Exit(1)
}
} else {
level = logger.Info
}
// NOTE: AutoMigrate does not respect logger passed in gorm.Config.
logger.Default = logger.Default.LogMode(level)
os.Exit(m.Run())
}

View File

@ -5,11 +5,9 @@
package migrations
import (
"fmt"
"github.com/pkg/errors"
"gorm.io/gorm"
log "unknwon.dev/clog/v2"
"xorm.io/xorm"
)
const minDBVersion = 19
@ -58,29 +56,32 @@ var migrations = []Migration{
NewMigration("migrate access tokens to store SHA56", migrateAccessTokenToSHA256),
}
// Migrate database to current version
func Migrate(x *xorm.Engine, db *gorm.DB) error {
if err := x.Sync(new(Version)); err != nil {
return fmt.Errorf("sync: %v", err)
}
currentVersion := &Version{ID: 1}
has, err := x.Get(currentVersion)
// Migrate migrates the database schema and/or data to the current version.
func Migrate(db *gorm.DB) error {
err := db.AutoMigrate(new(Version))
if err != nil {
return fmt.Errorf("get: %v", err)
} else if !has {
// If the version record does not exist we think
// it is a fresh installation and we can skip all migrations.
currentVersion.ID = 0
currentVersion.Version = int64(minDBVersion + len(migrations))
if _, err = x.InsertOne(currentVersion); err != nil {
return fmt.Errorf("insert: %v", err)
}
return errors.Wrap(err, `auto migrate "version" table`)
}
v := currentVersion.Version
if minDBVersion > v {
var current Version
err = db.Where("id = ?", 1).First(&current).Error
if err == gorm.ErrRecordNotFound {
err = db.Create(
&Version{
ID: 1,
Version: int64(minDBVersion + len(migrations)),
},
).Error
if err != nil {
return errors.Wrap(err, "create the version record")
}
return nil
} else if err != nil {
return errors.Wrap(err, "get the version record")
}
if minDBVersion > current.Version {
log.Fatal(`
Hi there, thank you for using Gogs for so long!
However, Gogs has stopped supporting auto-migration from your previously installed version.
@ -108,20 +109,22 @@ In case you're stilling getting this notice, go through instructions again until
return nil
}
if int(v-minDBVersion) > len(migrations) {
if int(current.Version-minDBVersion) > len(migrations) {
// User downgraded Gogs.
currentVersion.Version = int64(len(migrations) + minDBVersion)
_, err = x.Id(1).Update(currentVersion)
return err
current.Version = int64(len(migrations) + minDBVersion)
return db.Where("id = ?", current.ID).Updates(current).Error
}
for i, m := range migrations[v-minDBVersion:] {
for i, m := range migrations[current.Version-minDBVersion:] {
log.Info("Migration: %s", m.Description())
if err = m.Migrate(db); err != nil {
return fmt.Errorf("do migrate: %v", err)
return errors.Wrap(err, "do migrate")
}
currentVersion.Version = v + int64(i) + 1
if _, err = x.Id(1).Update(currentVersion); err != nil {
return err
current.Version += int64(i) + 1
err = db.Where("id = ?", current.ID).Updates(current).Error
if err != nil {
return errors.Wrap(err, "update the version record")
}
}
return nil

View File

@ -0,0 +1,70 @@
// Copyright 2022 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package migrations
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gogs.io/gogs/internal/dbtest"
)
type accessTokenPreV20 struct {
ID int64
UserID int64 `gorm:"COLUMN:uid;INDEX"`
Name string
Sha1 string `gorm:"TYPE:VARCHAR(40);UNIQUE"`
CreatedUnix int64
UpdatedUnix int64
}
func (*accessTokenPreV20) TableName() string {
return "access_token"
}
type accessTokenV20 struct {
ID int64
UserID int64 `gorm:"column:uid;index"`
Name string
Sha1 string `gorm:"type:VARCHAR(40);unique"`
SHA256 string `gorm:"type:VARCHAR(64);unique;not null"`
CreatedUnix int64
UpdatedUnix int64
}
func (*accessTokenV20) TableName() string {
return "access_token"
}
func TestMigrateAccessTokenToSHA256(t *testing.T) {
if testing.Short() {
t.Skip()
}
t.Parallel()
db := dbtest.NewDB(t, "migrateAccessTokenToSHA256", new(accessTokenPreV20))
err := db.Create(
&accessTokenPreV20{
ID: 1,
UserID: 1,
Name: "test",
Sha1: "73da7bb9d2a475bbc2ab79da7d4e94940cb9f9d5",
CreatedUnix: db.NowFunc().Unix(),
UpdatedUnix: db.NowFunc().Unix(),
},
).Error
require.NoError(t, err)
err = migrateAccessTokenToSHA256(db)
require.NoError(t, err)
var got accessTokenV20
err = db.Where("id = ?", 1).First(&got).Error
require.NoError(t, err)
assert.Equal(t, "73da7bb9d2a475bbc2ab79da7d4e94940cb9f9d5", got.Sha1)
assert.Equal(t, "ab144c7bd170691bb9bb995f1541c608e33a78b40174f30fc8a1616c0bc3a477", got.SHA256)
}

View File

@ -89,14 +89,14 @@ func getEngine() (*xorm.Engine, error) {
case "postgres":
conf.UsePostgreSQL = true
host, port := parsePostgreSQLHostPort(conf.Database.Host)
host, port := dbutil.ParsePostgreSQLHostPort(conf.Database.Host)
connStr = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s'",
conf.Database.User, conf.Database.Password, host, port, conf.Database.Name, conf.Database.SSLMode, conf.Database.Schema)
driver = "pgx"
case "mssql":
conf.UseMSSQL = true
host, port := parseMSSQLHostPort(conf.Database.Host)
host, port := dbutil.ParseMSSQLHostPort(conf.Database.Host)
connStr = fmt.Sprintf("server=%s; port=%s; database=%s; user id=%s; password=%s;", host, port, conf.Database.Name, conf.Database.User, conf.Database.Password)
case "sqlite3":
@ -187,7 +187,7 @@ func NewEngine() (err error) {
return err
}
if err = migrations.Migrate(x, db); err != nil {
if err = migrations.Migrate(db); err != nil {
return fmt.Errorf("migrate: %v", err)
}

View File

@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gogs.io/gogs/internal/dbtest"
)
func TestPerms(t *testing.T) {
@ -21,7 +23,7 @@ func TestPerms(t *testing.T) {
tables := []interface{}{new(Access)}
db := &perms{
DB: initTestDB(t, "perms", tables...),
DB: dbtest.NewDB(t, "perms", tables...),
}
for _, tc := range []struct {

View File

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gogs.io/gogs/internal/dbtest"
"gogs.io/gogs/internal/errutil"
)
@ -24,7 +25,7 @@ func TestRepos(t *testing.T) {
tables := []interface{}{new(Repository)}
db := &repos{
DB: initTestDB(t, "repos", tables...),
DB: dbtest.NewDB(t, "repos", tables...),
}
for _, tc := range []struct {

View File

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gogs.io/gogs/internal/dbtest"
"gogs.io/gogs/internal/errutil"
)
@ -24,7 +25,7 @@ func TestTwoFactors(t *testing.T) {
tables := []interface{}{new(TwoFactor), new(TwoFactorRecoveryCode)}
db := &twoFactors{
DB: initTestDB(t, "twoFactors", tables...),
DB: dbtest.NewDB(t, "twoFactors", tables...),
}
for _, tc := range []struct {

View File

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"gogs.io/gogs/internal/auth"
"gogs.io/gogs/internal/dbtest"
"gogs.io/gogs/internal/errutil"
)
@ -26,7 +27,7 @@ func TestUsers(t *testing.T) {
tables := []interface{}{new(User), new(EmailAddress)}
db := &users{
DB: initTestDB(t, "users", tables...),
DB: dbtest.NewDB(t, "users", tables...),
}
for _, tc := range []struct {

149
internal/dbtest/dbtest.go Normal file
View File

@ -0,0 +1,149 @@
// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbtest
import (
"database/sql"
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gogs.io/gogs/internal/conf"
"gogs.io/gogs/internal/dbutil"
)
// NewDB creates a new test database and initializes the given list of tables
// for the suite. The test database is dropped after testing is completed unless
// failed.
func NewDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB {
dbType := os.Getenv("GOGS_DATABASE_TYPE")
var dbName string
var dbOpts conf.DatabaseOpts
var cleanup func(db *gorm.DB)
switch dbType {
case "mysql":
dbOpts = conf.DatabaseOpts{
Type: "mysql",
Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"),
Name: dbName,
User: os.Getenv("MYSQL_USER"),
Password: os.Getenv("MYSQL_PASSWORD"),
}
dsn, err := dbutil.NewDSN(dbOpts)
require.NoError(t, err)
sqlDB, err := sql.Open("mysql", dsn)
require.NoError(t, err)
// Set up test database
dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
_, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName))
require.NoError(t, err)
_, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName))
require.NoError(t, err)
dbOpts.Name = dbName
cleanup = func(db *gorm.DB) {
db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName))
_ = sqlDB.Close()
}
case "postgres":
dbOpts = conf.DatabaseOpts{
Type: "postgres",
Host: os.ExpandEnv("$PGHOST:$PGPORT"),
Name: dbName,
Schema: "public",
User: os.Getenv("PGUSER"),
Password: os.Getenv("PGPASSWORD"),
SSLMode: os.Getenv("PGSSLMODE"),
}
dsn, err := dbutil.NewDSN(dbOpts)
require.NoError(t, err)
sqlDB, err := sql.Open("pgx", dsn)
require.NoError(t, err)
// Set up test database
dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix())
_, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
require.NoError(t, err)
_, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName))
require.NoError(t, err)
dbOpts.Name = dbName
cleanup = func(db *gorm.DB) {
db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName))
_ = sqlDB.Close()
}
case "sqlite":
dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
dbOpts = conf.DatabaseOpts{
Type: "sqlite",
Path: dbName,
}
cleanup = func(db *gorm.DB) {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
_ = os.Remove(dbName)
}
default:
dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix()))
dbOpts = conf.DatabaseOpts{
Type: "sqlite3",
Path: dbName,
}
cleanup = func(db *gorm.DB) {
sqlDB, err := db.DB()
if err == nil {
_ = sqlDB.Close()
}
_ = os.Remove(dbName)
}
}
now := time.Now().UTC().Truncate(time.Second)
db, err := dbutil.OpenDB(
dbOpts,
&gorm.Config{
SkipDefaultTransaction: true,
NamingStrategy: schema.NamingStrategy{
SingularTable: true,
},
NowFunc: func() time.Time {
return now
},
},
)
require.NoError(t, err)
t.Cleanup(func() {
if t.Failed() {
t.Logf("Database %q left intact for inspection", dbName)
return
}
cleanup(db)
})
err = db.Migrator().AutoMigrate(tables...)
require.NoError(t, err)
return db
}

116
internal/dbutil/dsn.go Normal file
View File

@ -0,0 +1,116 @@
// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbutil
import (
"fmt"
"strings"
"github.com/pkg/errors"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
"gorm.io/driver/sqlserver"
"gorm.io/gorm"
"gogs.io/gogs/internal/conf"
)
// ParsePostgreSQLHostPort parses given input in various forms defined in
// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
// and returns proper host and port number.
func ParsePostgreSQLHostPort(info string) (host, port string) {
host, port = "127.0.0.1", "5432"
if strings.Contains(info, ":") && !strings.HasSuffix(info, "]") {
idx := strings.LastIndex(info, ":")
host = info[:idx]
port = info[idx+1:]
} else if len(info) > 0 {
host = info
}
return host, port
}
// ParseMSSQLHostPort parses given input in various forms for MSSQL and returns
// proper host and port number.
func ParseMSSQLHostPort(info string) (host, port string) {
host, port = "127.0.0.1", "1433"
if strings.Contains(info, ":") {
host = strings.Split(info, ":")[0]
port = strings.Split(info, ":")[1]
} else if strings.Contains(info, ",") {
host = strings.Split(info, ",")[0]
port = strings.TrimSpace(strings.Split(info, ",")[1])
} else if len(info) > 0 {
host = info
}
return host, port
}
// NewDSN takes given database options and returns parsed DSN.
func NewDSN(opts conf.DatabaseOpts) (dsn string, err error) {
// In case the database name contains "?" with some parameters
concate := "?"
if strings.Contains(opts.Name, concate) {
concate = "&"
}
switch opts.Type {
case "mysql":
if opts.Host[0] == '/' { // Looks like a unix socket
dsn = fmt.Sprintf("%s:%s@unix(%s)/%s%scharset=utf8mb4&parseTime=true",
opts.User, opts.Password, opts.Host, opts.Name, concate)
} else {
dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s%scharset=utf8mb4&parseTime=true",
opts.User, opts.Password, opts.Host, opts.Name, concate)
}
case "postgres":
host, port := ParsePostgreSQLHostPort(opts.Host)
dsn = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s' application_name='gogs'",
opts.User, opts.Password, host, port, opts.Name, opts.SSLMode, opts.Schema)
case "mssql":
host, port := ParseMSSQLHostPort(opts.Host)
dsn = fmt.Sprintf("server=%s; port=%s; database=%s; user id=%s; password=%s;",
host, port, opts.Name, opts.User, opts.Password)
case "sqlite3", "sqlite":
dsn = "file:" + opts.Path + "?cache=shared&mode=rwc"
default:
return "", errors.Errorf("unrecognized dialect: %s", opts.Type)
}
return dsn, nil
}
// OpenDB opens a new database connection encapsulated as gorm.DB using given
// database options and GORM config.
func OpenDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) {
dsn, err := NewDSN(opts)
if err != nil {
return nil, errors.Wrap(err, "parse DSN")
}
var dialector gorm.Dialector
switch opts.Type {
case "mysql":
dialector = mysql.Open(dsn)
case "postgres":
dialector = postgres.Open(dsn)
case "mssql":
dialector = sqlserver.Open(dsn)
case "sqlite3":
dialector = sqlite.Open(dsn)
case "sqlite":
dialector = sqlite.Open(dsn)
dialector.(*sqlite.Dialector).DriverName = "sqlite"
default:
panic("unreachable")
}
return gorm.Open(dialector, cfg)
}

View File

@ -2,18 +2,19 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
package dbutil
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gogs.io/gogs/internal/conf"
)
func Test_parsePostgreSQLHostPort(t *testing.T) {
func TestParsePostgreSQLHostPort(t *testing.T) {
tests := []struct {
info string
expHost string
@ -28,14 +29,14 @@ func Test_parsePostgreSQLHostPort(t *testing.T) {
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
host, port := parsePostgreSQLHostPort(test.info)
host, port := ParsePostgreSQLHostPort(test.info)
assert.Equal(t, test.expHost, host)
assert.Equal(t, test.expPort, port)
})
}
}
func Test_parseMSSQLHostPort(t *testing.T) {
func TestParseMSSQLHostPort(t *testing.T) {
tests := []struct {
info string
expHost string
@ -47,16 +48,16 @@ func Test_parseMSSQLHostPort(t *testing.T) {
}
for _, test := range tests {
t.Run("", func(t *testing.T) {
host, port := parseMSSQLHostPort(test.info)
host, port := ParseMSSQLHostPort(test.info)
assert.Equal(t, test.expHost, host)
assert.Equal(t, test.expPort, port)
})
}
}
func Test_parseDSN(t *testing.T) {
func TestNewDSN(t *testing.T) {
t.Run("bad dialect", func(t *testing.T) {
_, err := newDSN(conf.DatabaseOpts{
_, err := NewDSN(conf.DatabaseOpts{
Type: "bad_dialect",
})
assert.Equal(t, "unrecognized dialect: bad_dialect", fmt.Sprintf("%v", err))
@ -140,10 +141,8 @@ func Test_parseDSN(t *testing.T) {
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
dsn, err := newDSN(test.opts)
if err != nil {
t.Fatal(err)
}
dsn, err := NewDSN(test.opts)
require.NoError(t, err)
assert.Equal(t, test.wantDSN, dsn)
})
}

View File

@ -1,3 +1,7 @@
// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package dbutil
import (