diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 03cf8dfdf..296732696 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -116,7 +116,7 @@ jobs: - name: Checkout code uses: actions/checkout@v2 - name: Run tests with coverage - run: go test -v -race -coverprofile=coverage -covermode=atomic ./internal/db + run: go test -v -race -coverprofile=coverage -covermode=atomic ./internal/db/... env: GOGS_DATABASE_TYPE: postgres PGPORT: 5432 @@ -142,7 +142,7 @@ jobs: - name: Checkout code uses: actions/checkout@v2 - name: Run tests with coverage - run: go test -v -race -coverprofile=coverage -covermode=atomic ./internal/db + run: go test -v -race -coverprofile=coverage -covermode=atomic ./internal/db/... env: GOGS_DATABASE_TYPE: mysql MYSQL_USER: root @@ -165,6 +165,6 @@ jobs: - name: Checkout code uses: actions/checkout@v2 - name: Run tests with coverage - run: go test -v -race -parallel=1 -coverprofile=coverage -covermode=atomic ./internal/db + run: go test -v -race -parallel=1 -coverprofile=coverage -covermode=atomic ./internal/db/... env: GOGS_DATABASE_TYPE: sqlite diff --git a/go.mod b/go.mod index 186cb5468..2a81440d3 100644 --- a/go.mod +++ b/go.mod @@ -62,7 +62,7 @@ require ( gopkg.in/macaron.v1 v1.4.0 gorm.io/driver/mysql v1.3.4 gorm.io/driver/postgres v1.3.7 - gorm.io/driver/sqlite v1.3.2 + gorm.io/driver/sqlite v1.3.4 gorm.io/driver/sqlserver v1.3.1 gorm.io/gorm v1.23.5 modernc.org/sqlite v1.17.3 @@ -72,13 +72,5 @@ require ( xorm.io/xorm v0.8.0 ) -// Temporary replace directives -// ============================ -// These entries indicate temporary replace directives due to a pending pull request upstream -// or issues with specific versions. - -// https://github.com/gogs/gogs/issues/7019 -replace gorm.io/driver/sqlite => github.com/gogs/gorm-sqlite v1.3.3-0.20220608111034-1298ceb93369 - // +heroku goVersion go1.16 // +heroku install ./ diff --git a/go.sum b/go.sum index 95b3fdffc..9b1c77290 100644 --- a/go.sum +++ b/go.sum @@ -173,8 +173,6 @@ github.com/gogs/go-gogs-client v0.0.0-20200128182646-c69cb7680fd4 h1:C7NryI/RQhs github.com/gogs/go-gogs-client v0.0.0-20200128182646-c69cb7680fd4/go.mod h1:fR6z1Ie6rtF7kl/vBYMfgD5/G5B1blui7z426/sj2DU= github.com/gogs/go-libravatar v0.0.0-20191106065024-33a75213d0a0 h1:K02vod+sn3M1OOkdqi2tPxN2+xESK4qyITVQ3JkGEv4= github.com/gogs/go-libravatar v0.0.0-20191106065024-33a75213d0a0/go.mod h1:Zas3BtO88pk1cwUfEYlvnl/CRwh0ybDxRWSwRjG8I3w= -github.com/gogs/gorm-sqlite v1.3.3-0.20220608111034-1298ceb93369 h1:nlpH50ShzqBGnZwul13EYJc8l5quxdTkBmzewuLZHgs= -github.com/gogs/gorm-sqlite v1.3.3-0.20220608111034-1298ceb93369/go.mod h1:B+8GyC9K7VgzJAcrcXMRPdnMcck+8FgJynEehEPM16U= github.com/gogs/minwinsvc v0.0.0-20170301035411-95be6356811a h1:8DZwxETOVWIinYxDK+i6L+rMb7eGATGaakD6ZucfHVk= github.com/gogs/minwinsvc v0.0.0-20170301035411-95be6356811a/go.mod h1:TUIZ+29jodWQ8Gk6Pvtg4E09aMsc3C/VLZiVYfUhWQU= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= @@ -907,6 +905,8 @@ gorm.io/driver/mysql v1.3.4 h1:/KoBMgsUHC3bExsekDcmNYaBnfH2WNeFuXqqrqMc98Q= gorm.io/driver/mysql v1.3.4/go.mod h1:s4Tq0KmD0yhPGHbZEwg1VPlH0vT/GBHJZorPzhcxBUE= gorm.io/driver/postgres v1.3.7 h1:FKF6sIMDHDEvvMF/XJvbnCl0nu6KSKUaPXevJ4r+VYQ= gorm.io/driver/postgres v1.3.7/go.mod h1:f02ympjIcgtHEGFMZvdgTxODZ9snAHDb4hXfigBVuNI= +gorm.io/driver/sqlite v1.3.4 h1:NnFOPVfzi4CPsJPH4wXr6rMkPb4ElHEqKMvrsx9c9Fk= +gorm.io/driver/sqlite v1.3.4/go.mod h1:B+8GyC9K7VgzJAcrcXMRPdnMcck+8FgJynEehEPM16U= gorm.io/driver/sqlserver v1.3.1 h1:F5t6ScMzOgy1zukRTIZgLZwKahgt3q1woAILVolKpOI= gorm.io/driver/sqlserver v1.3.1/go.mod h1:w25Vrx2BG+CJNUu/xKbFhaKlGxT/nzRkhWCCoptX8tQ= gorm.io/gorm v1.23.1/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= diff --git a/internal/db/access_tokens_test.go b/internal/db/access_tokens_test.go index b135a7b75..733f79137 100644 --- a/internal/db/access_tokens_test.go +++ b/internal/db/access_tokens_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" + "gogs.io/gogs/internal/dbtest" "gogs.io/gogs/internal/errutil" ) @@ -50,7 +51,7 @@ func TestAccessTokens(t *testing.T) { tables := []interface{}{new(AccessToken)} db := &accessTokens{ - DB: initTestDB(t, "accessTokens", tables...), + DB: dbtest.NewDB(t, "accessTokens", tables...), } for _, tc := range []struct { diff --git a/internal/db/backup_test.go b/internal/db/backup_test.go index 047a2dcae..1221dac36 100644 --- a/internal/db/backup_test.go +++ b/internal/db/backup_test.go @@ -20,6 +20,7 @@ import ( "gogs.io/gogs/internal/auth/github" "gogs.io/gogs/internal/auth/pam" "gogs.io/gogs/internal/cryptoutil" + "gogs.io/gogs/internal/dbtest" "gogs.io/gogs/internal/lfsutil" "gogs.io/gogs/internal/testutil" ) @@ -35,7 +36,7 @@ func TestDumpAndImport(t *testing.T) { t.Fatalf("New table has added (want 4 got %d), please add new tests for the table and update this check", len(Tables)) } - db := initTestDB(t, "dumpAndImport", Tables...) + db := dbtest.NewDB(t, "dumpAndImport", Tables...) setupDBToDump(t, db) dumpTables(t, db) importTables(t, db) diff --git a/internal/db/db.go b/internal/db/db.go index e67ffde9e..9845bda7f 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -11,10 +11,6 @@ import ( "time" "github.com/pkg/errors" - "gorm.io/driver/mysql" - "gorm.io/driver/postgres" - "gorm.io/driver/sqlite" - "gorm.io/driver/sqlserver" "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" @@ -24,73 +20,6 @@ import ( "gogs.io/gogs/internal/dbutil" ) -// parsePostgreSQLHostPort parses given input in various forms defined in -// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING -// and returns proper host and port number. -func parsePostgreSQLHostPort(info string) (host, port string) { - host, port = "127.0.0.1", "5432" - if strings.Contains(info, ":") && !strings.HasSuffix(info, "]") { - idx := strings.LastIndex(info, ":") - host = info[:idx] - port = info[idx+1:] - } else if len(info) > 0 { - host = info - } - return host, port -} - -func parseMSSQLHostPort(info string) (host, port string) { - host, port = "127.0.0.1", "1433" - if strings.Contains(info, ":") { - host = strings.Split(info, ":")[0] - port = strings.Split(info, ":")[1] - } else if strings.Contains(info, ",") { - host = strings.Split(info, ",")[0] - port = strings.TrimSpace(strings.Split(info, ",")[1]) - } else if len(info) > 0 { - host = info - } - return host, port -} - -// newDSN takes given database options and returns parsed DSN. -func newDSN(opts conf.DatabaseOpts) (dsn string, err error) { - // In case the database name contains "?" with some parameters - concate := "?" - if strings.Contains(opts.Name, concate) { - concate = "&" - } - - switch opts.Type { - case "mysql": - if opts.Host[0] == '/' { // Looks like a unix socket - dsn = fmt.Sprintf("%s:%s@unix(%s)/%s%scharset=utf8mb4&parseTime=true", - opts.User, opts.Password, opts.Host, opts.Name, concate) - } else { - dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s%scharset=utf8mb4&parseTime=true", - opts.User, opts.Password, opts.Host, opts.Name, concate) - } - - case "postgres": - host, port := parsePostgreSQLHostPort(opts.Host) - dsn = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s' application_name='gogs'", - opts.User, opts.Password, host, port, opts.Name, opts.SSLMode, opts.Schema) - - case "mssql": - host, port := parseMSSQLHostPort(opts.Host) - dsn = fmt.Sprintf("server=%s; port=%s; database=%s; user id=%s; password=%s;", - host, port, opts.Name, opts.User, opts.Password) - - case "sqlite3", "sqlite": - dsn = "file:" + opts.Path + "?cache=shared&mode=rwc" - - default: - return "", errors.Errorf("unrecognized dialect: %s", opts.Type) - } - - return dsn, nil -} - func newLogWriter() (logger.Writer, error) { sec := conf.File.Section("log.gorm") w, err := log.NewFileWriter( @@ -108,32 +37,6 @@ func newLogWriter() (logger.Writer, error) { return &dbutil.Logger{Writer: w}, nil } -func openDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) { - dsn, err := newDSN(opts) - if err != nil { - return nil, errors.Wrap(err, "parse DSN") - } - - var dialector gorm.Dialector - switch opts.Type { - case "mysql": - dialector = mysql.Open(dsn) - case "postgres": - dialector = postgres.Open(dsn) - case "mssql": - dialector = sqlserver.Open(dsn) - case "sqlite3": - dialector = sqlite.Open(dsn) - case "sqlite": - dialector = sqlite.Open(dsn) - dialector.(*sqlite.Dialector).DriverName = "sqlite" - default: - panic("unreachable") - } - - return gorm.Open(dialector, cfg) -} - // Tables is the list of struct-to-table mappings. // // NOTE: Lines are sorted in alphabetical order, each letter in its own line. @@ -155,7 +58,7 @@ func Init(w logger.Writer) (*gorm.DB, error) { LogLevel: level, }) - db, err := openDB( + db, err := dbutil.OpenDB( conf.Database, &gorm.Config{ SkipDefaultTransaction: true, diff --git a/internal/db/lfs_test.go b/internal/db/lfs_test.go index e650369c9..07518361b 100644 --- a/internal/db/lfs_test.go +++ b/internal/db/lfs_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gogs.io/gogs/internal/dbtest" "gogs.io/gogs/internal/errutil" "gogs.io/gogs/internal/lfsutil" ) @@ -25,7 +26,7 @@ func TestLFS(t *testing.T) { tables := []interface{}{new(LFSObject)} db := &lfs{ - DB: initTestDB(t, "lfs", tables...), + DB: dbtest.NewDB(t, "lfs", tables...), } for _, tc := range []struct { diff --git a/internal/db/login_sources_test.go b/internal/db/login_sources_test.go index ad09a8db2..c33bbf052 100644 --- a/internal/db/login_sources_test.go +++ b/internal/db/login_sources_test.go @@ -17,6 +17,7 @@ import ( "gogs.io/gogs/internal/auth" "gogs.io/gogs/internal/auth/github" "gogs.io/gogs/internal/auth/pam" + "gogs.io/gogs/internal/dbtest" "gogs.io/gogs/internal/errutil" ) @@ -85,7 +86,7 @@ func Test_loginSources(t *testing.T) { tables := []interface{}{new(LoginSource), new(User)} db := &loginSources{ - DB: initTestDB(t, "loginSources", tables...), + DB: dbtest.NewDB(t, "loginSources", tables...), } for _, tc := range []struct { diff --git a/internal/db/main_test.go b/internal/db/main_test.go index d5598c589..bc55a0d0e 100644 --- a/internal/db/main_test.go +++ b/internal/db/main_test.go @@ -5,22 +5,16 @@ package db import ( - "database/sql" "flag" "fmt" "os" - "path/filepath" "testing" - "time" - "github.com/stretchr/testify/require" "gorm.io/gorm" "gorm.io/gorm/logger" - "gorm.io/gorm/schema" _ "modernc.org/sqlite" log "unknwon.dev/clog/v2" - "gogs.io/gogs/internal/conf" "gogs.io/gogs/internal/testutil" ) @@ -60,128 +54,3 @@ func clearTables(t *testing.T, db *gorm.DB, tables ...interface{}) error { } return nil } - -func initTestDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB { - dbType := os.Getenv("GOGS_DATABASE_TYPE") - - var dbName string - var dbOpts conf.DatabaseOpts - var cleanup func(db *gorm.DB) - switch dbType { - case "mysql": - dbOpts = conf.DatabaseOpts{ - Type: "mysql", - Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"), - Name: dbName, - User: os.Getenv("MYSQL_USER"), - Password: os.Getenv("MYSQL_PASSWORD"), - } - - dsn, err := newDSN(dbOpts) - require.NoError(t, err) - - sqlDB, err := sql.Open("mysql", dsn) - require.NoError(t, err) - - // Set up test database - dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix()) - _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName)) - require.NoError(t, err) - - _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName)) - require.NoError(t, err) - - dbOpts.Name = dbName - - cleanup = func(db *gorm.DB) { - db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName)) - _ = sqlDB.Close() - } - case "postgres": - dbOpts = conf.DatabaseOpts{ - Type: "postgres", - Host: os.ExpandEnv("$PGHOST:$PGPORT"), - Name: dbName, - Schema: "public", - User: os.Getenv("PGUSER"), - Password: os.Getenv("PGPASSWORD"), - SSLMode: os.Getenv("PGSSLMODE"), - } - - dsn, err := newDSN(dbOpts) - require.NoError(t, err) - - sqlDB, err := sql.Open("pgx", dsn) - require.NoError(t, err) - - // Set up test database - dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix()) - _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName)) - require.NoError(t, err) - - _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName)) - require.NoError(t, err) - - dbOpts.Name = dbName - - cleanup = func(db *gorm.DB) { - db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName)) - _ = sqlDB.Close() - } - case "sqlite": - dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix())) - dbOpts = conf.DatabaseOpts{ - Type: "sqlite", - Path: dbName, - } - cleanup = func(db *gorm.DB) { - sqlDB, err := db.DB() - if err == nil { - _ = sqlDB.Close() - } - _ = os.Remove(dbName) - } - default: - dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix())) - dbOpts = conf.DatabaseOpts{ - Type: "sqlite3", - Path: dbName, - } - cleanup = func(db *gorm.DB) { - sqlDB, err := db.DB() - if err == nil { - _ = sqlDB.Close() - } - _ = os.Remove(dbName) - } - } - - now := time.Now().UTC().Truncate(time.Second) - db, err := openDB( - dbOpts, - &gorm.Config{ - SkipDefaultTransaction: true, - NamingStrategy: schema.NamingStrategy{ - SingularTable: true, - }, - NowFunc: func() time.Time { - return now - }, - }, - ) - require.NoError(t, err) - - t.Cleanup(func() { - if t.Failed() { - t.Logf("Database %q left intact for inspection", dbName) - return - } - - cleanup(db) - }) - - err = db.Migrator().AutoMigrate(tables...) - require.NoError(t, err) - - return db -} diff --git a/internal/db/migrations/main_test.go b/internal/db/migrations/main_test.go new file mode 100644 index 000000000..ee13c3e3e --- /dev/null +++ b/internal/db/migrations/main_test.go @@ -0,0 +1,40 @@ +// Copyright 2022 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package migrations + +import ( + "flag" + "fmt" + "os" + "testing" + + "gorm.io/gorm/logger" + _ "modernc.org/sqlite" + log "unknwon.dev/clog/v2" + + "gogs.io/gogs/internal/testutil" +) + +func TestMain(m *testing.M) { + flag.Parse() + + level := logger.Silent + if !testing.Verbose() { + // Remove the primary logger and register a noop logger. + log.Remove(log.DefaultConsoleName) + err := log.New("noop", testutil.InitNoopLogger) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + } else { + level = logger.Info + } + + // NOTE: AutoMigrate does not respect logger passed in gorm.Config. + logger.Default = logger.Default.LogMode(level) + + os.Exit(m.Run()) +} diff --git a/internal/db/migrations/migrations.go b/internal/db/migrations/migrations.go index 1e89883f8..4f93b9c5b 100644 --- a/internal/db/migrations/migrations.go +++ b/internal/db/migrations/migrations.go @@ -5,11 +5,9 @@ package migrations import ( - "fmt" - + "github.com/pkg/errors" "gorm.io/gorm" log "unknwon.dev/clog/v2" - "xorm.io/xorm" ) const minDBVersion = 19 @@ -58,29 +56,32 @@ var migrations = []Migration{ NewMigration("migrate access tokens to store SHA56", migrateAccessTokenToSHA256), } -// Migrate database to current version -func Migrate(x *xorm.Engine, db *gorm.DB) error { - if err := x.Sync(new(Version)); err != nil { - return fmt.Errorf("sync: %v", err) - } - - currentVersion := &Version{ID: 1} - has, err := x.Get(currentVersion) +// Migrate migrates the database schema and/or data to the current version. +func Migrate(db *gorm.DB) error { + err := db.AutoMigrate(new(Version)) if err != nil { - return fmt.Errorf("get: %v", err) - } else if !has { - // If the version record does not exist we think - // it is a fresh installation and we can skip all migrations. - currentVersion.ID = 0 - currentVersion.Version = int64(minDBVersion + len(migrations)) - - if _, err = x.InsertOne(currentVersion); err != nil { - return fmt.Errorf("insert: %v", err) - } + return errors.Wrap(err, `auto migrate "version" table`) } - v := currentVersion.Version - if minDBVersion > v { + var current Version + err = db.Where("id = ?", 1).First(¤t).Error + if err == gorm.ErrRecordNotFound { + err = db.Create( + &Version{ + ID: 1, + Version: int64(minDBVersion + len(migrations)), + }, + ).Error + if err != nil { + return errors.Wrap(err, "create the version record") + } + return nil + + } else if err != nil { + return errors.Wrap(err, "get the version record") + } + + if minDBVersion > current.Version { log.Fatal(` Hi there, thank you for using Gogs for so long! However, Gogs has stopped supporting auto-migration from your previously installed version. @@ -108,20 +109,22 @@ In case you're stilling getting this notice, go through instructions again until return nil } - if int(v-minDBVersion) > len(migrations) { + if int(current.Version-minDBVersion) > len(migrations) { // User downgraded Gogs. - currentVersion.Version = int64(len(migrations) + minDBVersion) - _, err = x.Id(1).Update(currentVersion) - return err + current.Version = int64(len(migrations) + minDBVersion) + return db.Where("id = ?", current.ID).Updates(current).Error } - for i, m := range migrations[v-minDBVersion:] { + + for i, m := range migrations[current.Version-minDBVersion:] { log.Info("Migration: %s", m.Description()) if err = m.Migrate(db); err != nil { - return fmt.Errorf("do migrate: %v", err) + return errors.Wrap(err, "do migrate") } - currentVersion.Version = v + int64(i) + 1 - if _, err = x.Id(1).Update(currentVersion); err != nil { - return err + + current.Version += int64(i) + 1 + err = db.Where("id = ?", current.ID).Updates(current).Error + if err != nil { + return errors.Wrap(err, "update the version record") } } return nil diff --git a/internal/db/migrations/v20_test.go b/internal/db/migrations/v20_test.go new file mode 100644 index 000000000..b95360de7 --- /dev/null +++ b/internal/db/migrations/v20_test.go @@ -0,0 +1,70 @@ +// Copyright 2022 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package migrations + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "gogs.io/gogs/internal/dbtest" +) + +type accessTokenPreV20 struct { + ID int64 + UserID int64 `gorm:"COLUMN:uid;INDEX"` + Name string + Sha1 string `gorm:"TYPE:VARCHAR(40);UNIQUE"` + CreatedUnix int64 + UpdatedUnix int64 +} + +func (*accessTokenPreV20) TableName() string { + return "access_token" +} + +type accessTokenV20 struct { + ID int64 + UserID int64 `gorm:"column:uid;index"` + Name string + Sha1 string `gorm:"type:VARCHAR(40);unique"` + SHA256 string `gorm:"type:VARCHAR(64);unique;not null"` + CreatedUnix int64 + UpdatedUnix int64 +} + +func (*accessTokenV20) TableName() string { + return "access_token" +} + +func TestMigrateAccessTokenToSHA256(t *testing.T) { + if testing.Short() { + t.Skip() + } + t.Parallel() + + db := dbtest.NewDB(t, "migrateAccessTokenToSHA256", new(accessTokenPreV20)) + err := db.Create( + &accessTokenPreV20{ + ID: 1, + UserID: 1, + Name: "test", + Sha1: "73da7bb9d2a475bbc2ab79da7d4e94940cb9f9d5", + CreatedUnix: db.NowFunc().Unix(), + UpdatedUnix: db.NowFunc().Unix(), + }, + ).Error + require.NoError(t, err) + + err = migrateAccessTokenToSHA256(db) + require.NoError(t, err) + + var got accessTokenV20 + err = db.Where("id = ?", 1).First(&got).Error + require.NoError(t, err) + assert.Equal(t, "73da7bb9d2a475bbc2ab79da7d4e94940cb9f9d5", got.Sha1) + assert.Equal(t, "ab144c7bd170691bb9bb995f1541c608e33a78b40174f30fc8a1616c0bc3a477", got.SHA256) +} diff --git a/internal/db/models.go b/internal/db/models.go index f0e4d5642..d892aeb34 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -89,14 +89,14 @@ func getEngine() (*xorm.Engine, error) { case "postgres": conf.UsePostgreSQL = true - host, port := parsePostgreSQLHostPort(conf.Database.Host) + host, port := dbutil.ParsePostgreSQLHostPort(conf.Database.Host) connStr = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s'", conf.Database.User, conf.Database.Password, host, port, conf.Database.Name, conf.Database.SSLMode, conf.Database.Schema) driver = "pgx" case "mssql": conf.UseMSSQL = true - host, port := parseMSSQLHostPort(conf.Database.Host) + host, port := dbutil.ParseMSSQLHostPort(conf.Database.Host) connStr = fmt.Sprintf("server=%s; port=%s; database=%s; user id=%s; password=%s;", host, port, conf.Database.Name, conf.Database.User, conf.Database.Password) case "sqlite3": @@ -187,7 +187,7 @@ func NewEngine() (err error) { return err } - if err = migrations.Migrate(x, db); err != nil { + if err = migrations.Migrate(db); err != nil { return fmt.Errorf("migrate: %v", err) } diff --git a/internal/db/perms_test.go b/internal/db/perms_test.go index d7ddb6d55..7182783d4 100644 --- a/internal/db/perms_test.go +++ b/internal/db/perms_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "gogs.io/gogs/internal/dbtest" ) func TestPerms(t *testing.T) { @@ -21,7 +23,7 @@ func TestPerms(t *testing.T) { tables := []interface{}{new(Access)} db := &perms{ - DB: initTestDB(t, "perms", tables...), + DB: dbtest.NewDB(t, "perms", tables...), } for _, tc := range []struct { diff --git a/internal/db/repos_test.go b/internal/db/repos_test.go index 324825063..0f7617dec 100644 --- a/internal/db/repos_test.go +++ b/internal/db/repos_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gogs.io/gogs/internal/dbtest" "gogs.io/gogs/internal/errutil" ) @@ -24,7 +25,7 @@ func TestRepos(t *testing.T) { tables := []interface{}{new(Repository)} db := &repos{ - DB: initTestDB(t, "repos", tables...), + DB: dbtest.NewDB(t, "repos", tables...), } for _, tc := range []struct { diff --git a/internal/db/two_factors_test.go b/internal/db/two_factors_test.go index acd9a576f..935844d47 100644 --- a/internal/db/two_factors_test.go +++ b/internal/db/two_factors_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gogs.io/gogs/internal/dbtest" "gogs.io/gogs/internal/errutil" ) @@ -24,7 +25,7 @@ func TestTwoFactors(t *testing.T) { tables := []interface{}{new(TwoFactor), new(TwoFactorRecoveryCode)} db := &twoFactors{ - DB: initTestDB(t, "twoFactors", tables...), + DB: dbtest.NewDB(t, "twoFactors", tables...), } for _, tc := range []struct { diff --git a/internal/db/users_test.go b/internal/db/users_test.go index 299a1be60..d4945aca6 100644 --- a/internal/db/users_test.go +++ b/internal/db/users_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/require" "gogs.io/gogs/internal/auth" + "gogs.io/gogs/internal/dbtest" "gogs.io/gogs/internal/errutil" ) @@ -26,7 +27,7 @@ func TestUsers(t *testing.T) { tables := []interface{}{new(User), new(EmailAddress)} db := &users{ - DB: initTestDB(t, "users", tables...), + DB: dbtest.NewDB(t, "users", tables...), } for _, tc := range []struct { diff --git a/internal/dbtest/dbtest.go b/internal/dbtest/dbtest.go new file mode 100644 index 000000000..1183d732e --- /dev/null +++ b/internal/dbtest/dbtest.go @@ -0,0 +1,149 @@ +// Copyright 2020 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbtest + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "gorm.io/gorm/schema" + + "gogs.io/gogs/internal/conf" + "gogs.io/gogs/internal/dbutil" +) + +// NewDB creates a new test database and initializes the given list of tables +// for the suite. The test database is dropped after testing is completed unless +// failed. +func NewDB(t *testing.T, suite string, tables ...interface{}) *gorm.DB { + dbType := os.Getenv("GOGS_DATABASE_TYPE") + + var dbName string + var dbOpts conf.DatabaseOpts + var cleanup func(db *gorm.DB) + switch dbType { + case "mysql": + dbOpts = conf.DatabaseOpts{ + Type: "mysql", + Host: os.ExpandEnv("$MYSQL_HOST:$MYSQL_PORT"), + Name: dbName, + User: os.Getenv("MYSQL_USER"), + Password: os.Getenv("MYSQL_PASSWORD"), + } + + dsn, err := dbutil.NewDSN(dbOpts) + require.NoError(t, err) + + sqlDB, err := sql.Open("mysql", dsn) + require.NoError(t, err) + + // Set up test database + dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix()) + _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`", dbName)) + require.NoError(t, err) + + _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE `%s`", dbName)) + require.NoError(t, err) + + dbOpts.Name = dbName + + cleanup = func(db *gorm.DB) { + db.Exec(fmt.Sprintf("DROP DATABASE `%s`", dbName)) + _ = sqlDB.Close() + } + case "postgres": + dbOpts = conf.DatabaseOpts{ + Type: "postgres", + Host: os.ExpandEnv("$PGHOST:$PGPORT"), + Name: dbName, + Schema: "public", + User: os.Getenv("PGUSER"), + Password: os.Getenv("PGPASSWORD"), + SSLMode: os.Getenv("PGSSLMODE"), + } + + dsn, err := dbutil.NewDSN(dbOpts) + require.NoError(t, err) + + sqlDB, err := sql.Open("pgx", dsn) + require.NoError(t, err) + + // Set up test database + dbName = fmt.Sprintf("gogs-%s-%d", suite, time.Now().Unix()) + _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName)) + require.NoError(t, err) + + _, err = sqlDB.Exec(fmt.Sprintf("CREATE DATABASE %q", dbName)) + require.NoError(t, err) + + dbOpts.Name = dbName + + cleanup = func(db *gorm.DB) { + db.Exec(fmt.Sprintf(`DROP DATABASE %q`, dbName)) + _ = sqlDB.Close() + } + case "sqlite": + dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix())) + dbOpts = conf.DatabaseOpts{ + Type: "sqlite", + Path: dbName, + } + cleanup = func(db *gorm.DB) { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + _ = os.Remove(dbName) + } + default: + dbName = filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%s-%d.db", suite, time.Now().Unix())) + dbOpts = conf.DatabaseOpts{ + Type: "sqlite3", + Path: dbName, + } + cleanup = func(db *gorm.DB) { + sqlDB, err := db.DB() + if err == nil { + _ = sqlDB.Close() + } + _ = os.Remove(dbName) + } + } + + now := time.Now().UTC().Truncate(time.Second) + db, err := dbutil.OpenDB( + dbOpts, + &gorm.Config{ + SkipDefaultTransaction: true, + NamingStrategy: schema.NamingStrategy{ + SingularTable: true, + }, + NowFunc: func() time.Time { + return now + }, + }, + ) + require.NoError(t, err) + + t.Cleanup(func() { + if t.Failed() { + t.Logf("Database %q left intact for inspection", dbName) + return + } + + cleanup(db) + }) + + err = db.Migrator().AutoMigrate(tables...) + require.NoError(t, err) + + return db +} diff --git a/internal/dbutil/dsn.go b/internal/dbutil/dsn.go new file mode 100644 index 000000000..302799884 --- /dev/null +++ b/internal/dbutil/dsn.go @@ -0,0 +1,116 @@ +// Copyright 2020 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package dbutil + +import ( + "fmt" + "strings" + + "github.com/pkg/errors" + "gorm.io/driver/mysql" + "gorm.io/driver/postgres" + "gorm.io/driver/sqlite" + "gorm.io/driver/sqlserver" + "gorm.io/gorm" + + "gogs.io/gogs/internal/conf" +) + +// ParsePostgreSQLHostPort parses given input in various forms defined in +// https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING +// and returns proper host and port number. +func ParsePostgreSQLHostPort(info string) (host, port string) { + host, port = "127.0.0.1", "5432" + if strings.Contains(info, ":") && !strings.HasSuffix(info, "]") { + idx := strings.LastIndex(info, ":") + host = info[:idx] + port = info[idx+1:] + } else if len(info) > 0 { + host = info + } + return host, port +} + +// ParseMSSQLHostPort parses given input in various forms for MSSQL and returns +// proper host and port number. +func ParseMSSQLHostPort(info string) (host, port string) { + host, port = "127.0.0.1", "1433" + if strings.Contains(info, ":") { + host = strings.Split(info, ":")[0] + port = strings.Split(info, ":")[1] + } else if strings.Contains(info, ",") { + host = strings.Split(info, ",")[0] + port = strings.TrimSpace(strings.Split(info, ",")[1]) + } else if len(info) > 0 { + host = info + } + return host, port +} + +// NewDSN takes given database options and returns parsed DSN. +func NewDSN(opts conf.DatabaseOpts) (dsn string, err error) { + // In case the database name contains "?" with some parameters + concate := "?" + if strings.Contains(opts.Name, concate) { + concate = "&" + } + + switch opts.Type { + case "mysql": + if opts.Host[0] == '/' { // Looks like a unix socket + dsn = fmt.Sprintf("%s:%s@unix(%s)/%s%scharset=utf8mb4&parseTime=true", + opts.User, opts.Password, opts.Host, opts.Name, concate) + } else { + dsn = fmt.Sprintf("%s:%s@tcp(%s)/%s%scharset=utf8mb4&parseTime=true", + opts.User, opts.Password, opts.Host, opts.Name, concate) + } + + case "postgres": + host, port := ParsePostgreSQLHostPort(opts.Host) + dsn = fmt.Sprintf("user='%s' password='%s' host='%s' port='%s' dbname='%s' sslmode='%s' search_path='%s' application_name='gogs'", + opts.User, opts.Password, host, port, opts.Name, opts.SSLMode, opts.Schema) + + case "mssql": + host, port := ParseMSSQLHostPort(opts.Host) + dsn = fmt.Sprintf("server=%s; port=%s; database=%s; user id=%s; password=%s;", + host, port, opts.Name, opts.User, opts.Password) + + case "sqlite3", "sqlite": + dsn = "file:" + opts.Path + "?cache=shared&mode=rwc" + + default: + return "", errors.Errorf("unrecognized dialect: %s", opts.Type) + } + + return dsn, nil +} + +// OpenDB opens a new database connection encapsulated as gorm.DB using given +// database options and GORM config. +func OpenDB(opts conf.DatabaseOpts, cfg *gorm.Config) (*gorm.DB, error) { + dsn, err := NewDSN(opts) + if err != nil { + return nil, errors.Wrap(err, "parse DSN") + } + + var dialector gorm.Dialector + switch opts.Type { + case "mysql": + dialector = mysql.Open(dsn) + case "postgres": + dialector = postgres.Open(dsn) + case "mssql": + dialector = sqlserver.Open(dsn) + case "sqlite3": + dialector = sqlite.Open(dsn) + case "sqlite": + dialector = sqlite.Open(dsn) + dialector.(*sqlite.Dialector).DriverName = "sqlite" + default: + panic("unreachable") + } + + return gorm.Open(dialector, cfg) +} diff --git a/internal/db/db_test.go b/internal/dbutil/dsn_test.go similarity index 90% rename from internal/db/db_test.go rename to internal/dbutil/dsn_test.go index 65f0c067f..0dd929015 100644 --- a/internal/db/db_test.go +++ b/internal/dbutil/dsn_test.go @@ -2,18 +2,19 @@ // Use of this source code is governed by a MIT-style // license that can be found in the LICENSE file. -package db +package dbutil import ( "fmt" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gogs.io/gogs/internal/conf" ) -func Test_parsePostgreSQLHostPort(t *testing.T) { +func TestParsePostgreSQLHostPort(t *testing.T) { tests := []struct { info string expHost string @@ -28,14 +29,14 @@ func Test_parsePostgreSQLHostPort(t *testing.T) { } for _, test := range tests { t.Run("", func(t *testing.T) { - host, port := parsePostgreSQLHostPort(test.info) + host, port := ParsePostgreSQLHostPort(test.info) assert.Equal(t, test.expHost, host) assert.Equal(t, test.expPort, port) }) } } -func Test_parseMSSQLHostPort(t *testing.T) { +func TestParseMSSQLHostPort(t *testing.T) { tests := []struct { info string expHost string @@ -47,16 +48,16 @@ func Test_parseMSSQLHostPort(t *testing.T) { } for _, test := range tests { t.Run("", func(t *testing.T) { - host, port := parseMSSQLHostPort(test.info) + host, port := ParseMSSQLHostPort(test.info) assert.Equal(t, test.expHost, host) assert.Equal(t, test.expPort, port) }) } } -func Test_parseDSN(t *testing.T) { +func TestNewDSN(t *testing.T) { t.Run("bad dialect", func(t *testing.T) { - _, err := newDSN(conf.DatabaseOpts{ + _, err := NewDSN(conf.DatabaseOpts{ Type: "bad_dialect", }) assert.Equal(t, "unrecognized dialect: bad_dialect", fmt.Sprintf("%v", err)) @@ -140,10 +141,8 @@ func Test_parseDSN(t *testing.T) { } for _, test := range tests { t.Run(test.name, func(t *testing.T) { - dsn, err := newDSN(test.opts) - if err != nil { - t.Fatal(err) - } + dsn, err := NewDSN(test.opts) + require.NoError(t, err) assert.Equal(t, test.wantDSN, dsn) }) } diff --git a/internal/dbutil/logger.go b/internal/dbutil/logger.go index 66426ae7c..c189949ea 100644 --- a/internal/dbutil/logger.go +++ b/internal/dbutil/logger.go @@ -1,3 +1,7 @@ +// Copyright 2020 The Gogs Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + package dbutil import (