access_token: migrate to GORM and add tests (#6086)

* access_token: migrate to GORM

* Add tests

* Fix tests

* Fix test clock
pull/6087/head
ᴜɴᴋɴᴡᴏɴ 2020-04-11 01:25:19 +08:00 committed by GitHub
parent 5753d4cb87
commit 62dda96159
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 356 additions and 141 deletions

View File

@ -6,7 +6,6 @@ package auth
import (
"strings"
"time"
"github.com/go-macaron/session"
gouuid "github.com/satori/go.uuid"
@ -55,7 +54,6 @@ func SignedInID(c *macaron.Context, sess session.Store) (_ int64, isTokenAuth bo
}
return 0, false
}
t.Updated = time.Now()
if err = db.AccessTokens.Save(t); err != nil {
log.Error("UpdateAccessToken: %v", err)
}

View File

@ -6,27 +6,111 @@ package db
import (
"fmt"
"time"
"github.com/jinzhu/gorm"
gouuid "github.com/satori/go.uuid"
"gogs.io/gogs/internal/errutil"
"gogs.io/gogs/internal/tool"
)
// AccessTokensStore is the persistent interface for access tokens.
//
// NOTE: All methods are sorted in alphabetical order.
type AccessTokensStore interface {
// Create creates a new access token and persist to database.
// It returns ErrAccessTokenAlreadyExist when an access token
// with same name already exists for the user.
Create(userID int64, name string) (*AccessToken, error)
// DeleteByID deletes the access token by given ID.
// 🚨 SECURITY: The "userID" is required to prevent attacker
// deletes arbitrary access token that belongs to another user.
DeleteByID(userID, id int64) error
// GetBySHA returns the access token with given SHA1.
// It returns ErrAccessTokenNotExist when not found.
GetBySHA(sha string) (*AccessToken, error)
// List returns all access tokens belongs to given user.
List(userID int64) ([]*AccessToken, error)
// Save persists all values of given access token.
// The Updated field is set to current time automatically.
Save(t *AccessToken) error
}
var AccessTokens AccessTokensStore
// AccessToken is a personal access token.
type AccessToken struct {
ID int64
UserID int64 `xorm:"uid INDEX" gorm:"COLUMN:uid;INDEX"`
Name string
Sha1 string `xorm:"UNIQUE VARCHAR(40)" gorm:"TYPE:VARCHAR(40);UNIQUE"`
Created time.Time `xorm:"-" gorm:"-" json:"-"`
CreatedUnix int64
Updated time.Time `xorm:"-" gorm:"-" json:"-"`
UpdatedUnix int64
HasRecentActivity bool `xorm:"-" gorm:"-" json:"-"`
HasUsed bool `xorm:"-" gorm:"-" json:"-"`
}
// NOTE: This is a GORM create hook.
func (t *AccessToken) BeforeCreate() {
t.CreatedUnix = t.Created.Unix()
}
// NOTE: This is a GORM update hook.
func (t *AccessToken) BeforeUpdate() {
t.UpdatedUnix = t.Updated.Unix()
}
// NOTE: This is a GORM query hook.
func (t *AccessToken) AfterFind() {
t.Created = time.Unix(t.CreatedUnix, 0).Local()
t.Updated = time.Unix(t.UpdatedUnix, 0).Local()
t.HasUsed = t.Updated.After(t.Created)
t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(time.Now())
}
var _ AccessTokensStore = (*accessTokens)(nil)
type accessTokens struct {
*gorm.DB
clock func() time.Time
}
type ErrAccessTokenAlreadyExist struct {
args errutil.Args
}
func IsErrAccessTokenAlreadyExist(err error) bool {
_, ok := err.(ErrAccessTokenAlreadyExist)
return ok
}
func (err ErrAccessTokenAlreadyExist) Error() string {
return fmt.Sprintf("access token already exists: %v", err.args)
}
func (db *accessTokens) Create(userID int64, name string) (*AccessToken, error) {
err := db.Where("uid = ? AND name = ?", userID, name).First(new(AccessToken)).Error
if err == nil {
return nil, ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": userID, "name": name}}
} else if !gorm.IsRecordNotFoundError(err) {
return nil, err
}
token := &AccessToken{
UserID: userID,
Name: name,
Sha1: tool.SHA1(gouuid.NewV4().String()),
Created: db.clock(),
}
return token, db.DB.Create(token).Error
}
func (db *accessTokens) DeleteByID(userID, id int64) error {
return db.Where("id = ? AND uid = ?", id, userID).Delete(new(AccessToken)).Error
}
var _ errutil.NotFound = (*ErrAccessTokenNotExist)(nil)
@ -60,6 +144,12 @@ func (db *accessTokens) GetBySHA(sha string) (*AccessToken, error) {
return token, nil
}
func (db *accessTokens) List(userID int64) ([]*AccessToken, error) {
var tokens []*AccessToken
return tokens, db.Where("uid = ?", userID).Find(&tokens).Error
}
func (db *accessTokens) Save(t *AccessToken) error {
t.Updated = db.clock()
return db.DB.Save(t).Error
}

View File

@ -0,0 +1,201 @@
// Copyright 2020 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/assert"
"gogs.io/gogs/internal/conf"
"gogs.io/gogs/internal/errutil"
)
func Test_accessTokens(t *testing.T) {
if testing.Short() {
t.Skip()
}
t.Parallel()
dbpath := filepath.Join(os.TempDir(), fmt.Sprintf("gogs-%d.db", time.Now().Unix()))
gdb, err := openDB(conf.DatabaseOpts{
Type: "sqlite3",
Path: dbpath,
})
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
_ = gdb.Close()
if t.Failed() {
t.Logf("Database %q left intact for inspection", dbpath)
return
}
_ = os.Remove(dbpath)
})
err = gdb.AutoMigrate(new(AccessToken)).Error
if err != nil {
t.Fatal(err)
}
now := time.Now().Truncate(time.Second)
clock := func() time.Time { return now }
db := &accessTokens{DB: gdb, clock: clock}
for _, tc := range []struct {
name string
test func(*testing.T, *accessTokens)
}{
{"Create", test_accessTokens_Create},
{"DeleteByID", test_accessTokens_DeleteByID},
{"GetBySHA", test_accessTokens_GetBySHA},
{"List", test_accessTokens_List},
{"Save", test_accessTokens_Save},
} {
t.Run(tc.name, func(t *testing.T) {
t.Cleanup(func() {
err := deleteTables(gdb, new(AccessToken))
if err != nil {
t.Fatal(err)
}
})
tc.test(t, db)
})
}
}
func test_accessTokens_Create(t *testing.T, db *accessTokens) {
// Create first access token with name "Test"
token, err := db.Create(1, "Test")
if err != nil {
t.Fatal(err)
}
assert.Equal(t, int64(1), token.UserID)
assert.Equal(t, "Test", token.Name)
assert.Equal(t, 40, len(token.Sha1), "sha1 length")
assert.Equal(t, db.clock(), token.Created)
// Try create second access token with same name should fail
_, err = db.Create(token.UserID, token.Name)
expErr := ErrAccessTokenAlreadyExist{args: errutil.Args{"userID": token.UserID, "name": token.Name}}
assert.Equal(t, expErr, err)
}
func test_accessTokens_DeleteByID(t *testing.T, db *accessTokens) {
// Create an access token with name "Test"
token, err := db.Create(1, "Test")
if err != nil {
t.Fatal(err)
}
// We should be able to get it back
_, err = db.GetBySHA(token.Sha1)
if err != nil {
t.Fatal(err)
}
// Delete a token with mismatched user ID is noop
err = db.DeleteByID(2, token.ID)
if err != nil {
t.Fatal(err)
}
_, err = db.GetBySHA(token.Sha1)
if err != nil {
t.Fatal(err)
}
// Now delete this token with correct user ID
err = db.DeleteByID(token.UserID, token.ID)
if err != nil {
t.Fatal(err)
}
// We should get token not found error
_, err = db.GetBySHA(token.Sha1)
expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": token.Sha1}}
assert.Equal(t, expErr, err)
}
func test_accessTokens_GetBySHA(t *testing.T, db *accessTokens) {
// Create an access token with name "Test"
token, err := db.Create(1, "Test")
if err != nil {
t.Fatal(err)
}
// We should be able to get it back
_, err = db.GetBySHA(token.Sha1)
if err != nil {
t.Fatal(err)
}
// Try to get a non-existent token
_, err = db.GetBySHA("bad_sha")
expErr := ErrAccessTokenNotExist{args: errutil.Args{"sha": "bad_sha"}}
assert.Equal(t, expErr, err)
}
func test_accessTokens_List(t *testing.T, db *accessTokens) {
// Create two access tokens for user 1
_, err := db.Create(1, "user1_1")
if err != nil {
t.Fatal(err)
}
_, err = db.Create(1, "user1_2")
if err != nil {
t.Fatal(err)
}
// Create one access token for user 2
_, err = db.Create(2, "user2_1")
if err != nil {
t.Fatal(err)
}
// List all access tokens for user 1
tokens, err := db.List(1)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, 2, len(tokens), "number of tokens")
assert.Equal(t, int64(1), tokens[0].UserID)
assert.Equal(t, "user1_1", tokens[0].Name)
assert.Equal(t, int64(1), tokens[1].UserID)
assert.Equal(t, "user1_2", tokens[1].Name)
}
func test_accessTokens_Save(t *testing.T, db *accessTokens) {
// Create an access token with name "Test"
token, err := db.Create(1, "Test")
if err != nil {
t.Fatal(err)
}
// Updated field is zero now
assert.True(t, token.Updated.IsZero())
err = db.Save(token)
if err != nil {
t.Fatal(err)
}
// Get back from DB should have Updated set
token, err = db.GetBySHA(token.Sha1)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, db.clock(), token.Updated)
}

View File

@ -122,6 +122,11 @@ func getLogWriter() (io.Writer, error) {
return w, nil
}
var tables = []interface{}{
new(AccessToken),
new(LFSObject),
}
func Init() error {
db, err := openDB(conf.Database)
if err != nil {
@ -150,16 +155,17 @@ func Init() error {
case "mssql":
conf.UseMSSQL = true
case "sqlite3":
conf.UseMySQL = true
conf.UseSQLite3 = true
}
err = db.AutoMigrate(new(LFSObject)).Error
err = db.AutoMigrate(tables...).Error
if err != nil {
return errors.Wrap(err, "migrate schemes")
}
clock := func() time.Time {return time.Now().UTC().Truncate(time.Microsecond)}
// Initialize stores, sorted in alphabetical order.
AccessTokens = &accessTokens{DB: db}
AccessTokens = &accessTokens{DB: db, clock: clock}
LoginSources = &loginSources{DB: db}
LFS = &lfs{DB: db}
Perms = &perms{DB: db}

View File

@ -1,16 +0,0 @@
package errors
import "fmt"
type AccessTokenNameAlreadyExist struct {
Name string
}
func IsAccessTokenNameAlreadyExist(err error) bool {
_, ok := err.(AccessTokenNameAlreadyExist)
return ok
}
func (err AccessTokenNameAlreadyExist) Error() string {
return fmt.Sprintf("access token already exist [name: %s]", err.Name)
}

View File

@ -37,7 +37,7 @@ type lfs struct {
// LFSObject is the relation between an LFS object and a repository.
type LFSObject struct {
RepoID int64 `gorm:"PRIMARY_KEY;AUTO_INCREMENT:false"`
OID lfsutil.OID `gorm:"PRIMARY_KEY;column:oid"`
OID lfsutil.OID `gorm:"PRIMARY_KEY;COLUMN:oid"`
Size int64 `gorm:"NOT NULL"`
Storage lfsutil.Storage `gorm:"NOT NULL"`
CreatedAt time.Time `gorm:"NOT NULL"`

View File

@ -10,6 +10,7 @@ import (
"os"
"testing"
"github.com/jinzhu/gorm"
log "unknwon.dev/clog/v2"
"gogs.io/gogs/internal/testutil"
@ -28,3 +29,13 @@ func TestMain(m *testing.M) {
}
os.Exit(m.Run())
}
func deleteTables(db *gorm.DB, tables ...interface{}) error {
for _, t := range tables {
err := db.Delete(t).Error
if err != nil {
return err
}
}
return nil
}

View File

@ -15,14 +15,29 @@ import (
var _ AccessTokensStore = (*MockAccessTokensStore)(nil)
type MockAccessTokensStore struct {
MockCreate func(userID int64, name string) (*AccessToken, error)
MockDeleteByID func(userID, id int64) error
MockGetBySHA func(sha string) (*AccessToken, error)
MockList func(userID int64) ([]*AccessToken, error)
MockSave func(t *AccessToken) error
}
func (m *MockAccessTokensStore) Create(userID int64, name string) (*AccessToken, error) {
return m.MockCreate(userID, name)
}
func (m *MockAccessTokensStore) DeleteByID(userID, id int64) error {
return m.MockDeleteByID(userID, id)
}
func (m *MockAccessTokensStore) GetBySHA(sha string) (*AccessToken, error) {
return m.MockGetBySHA(sha)
}
func (m *MockAccessTokensStore) List(userID int64) ([]*AccessToken, error) {
return m.MockList(userID)
}
func (m *MockAccessTokensStore) Save(t *AccessToken) error {
return m.MockSave(t)
}

View File

@ -42,13 +42,13 @@ type Engine interface {
var (
x *xorm.Engine
tables []interface{}
legacyTables []interface{}
HasEngine bool
)
func init() {
tables = append(tables,
new(User), new(PublicKey), new(AccessToken), new(TwoFactor), new(TwoFactorRecoveryCode),
legacyTables = append(legacyTables,
new(User), new(PublicKey), new(TwoFactor), new(TwoFactorRecoveryCode),
new(Repository), new(DeployKey), new(Collaboration), new(Access), new(Upload),
new(Watch), new(Star), new(Follow), new(Action),
new(Issue), new(PullRequest), new(Comment), new(Attachment), new(IssueUser),
@ -120,7 +120,7 @@ func NewTestEngine() error {
}
x.SetMapper(core.GonicMapper{})
return x.StoreEngine("InnoDB").Sync2(tables...)
return x.StoreEngine("InnoDB").Sync2(legacyTables...)
}
func SetEngine() (err error) {
@ -167,7 +167,7 @@ func NewEngine() (err error) {
return fmt.Errorf("migrate: %v", err)
}
if err = x.StoreEngine("InnoDB").Sync2(tables...); err != nil {
if err = x.StoreEngine("InnoDB").Sync2(legacyTables...); err != nil {
return fmt.Errorf("sync structs to database tables: %v\n", err)
}
@ -227,8 +227,9 @@ func DumpDatabase(dirPath string) error {
}
// Purposely create a local variable to not modify global variable
tables := append(tables, new(Version))
for _, table := range tables {
allTables := append(legacyTables, new(Version))
allTables = append(allTables, tables...)
for _, table := range allTables {
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
tableFile := path.Join(dirPath, tableName+".json")
f, err := os.Create(tableFile)
@ -257,8 +258,9 @@ func ImportDatabase(dirPath string, verbose bool) (err error) {
}
// Purposely create a local variable to not modify global variable
tables := append(tables, new(Version))
for _, table := range tables {
allTables := append(legacyTables, new(Version))
allTables = append(allTables, tables...)
for _, table := range allTables {
tableName := strings.TrimPrefix(fmt.Sprintf("%T", table), "*db.")
tableFile := path.Join(dirPath, tableName+".json")
if !com.IsExist(tableFile) {

View File

@ -1,81 +0,0 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"time"
gouuid "github.com/satori/go.uuid"
"xorm.io/xorm"
"gogs.io/gogs/internal/db/errors"
"gogs.io/gogs/internal/tool"
)
// AccessToken represents a personal access token.
type AccessToken struct {
ID int64
UserID int64 `xorm:"uid INDEX" gorm:"COLUMN:uid"`
Name string
Sha1 string `xorm:"UNIQUE VARCHAR(40)"`
Created time.Time `xorm:"-" gorm:"-" json:"-"`
CreatedUnix int64
Updated time.Time `xorm:"-" gorm:"-" json:"-"` // Note: Updated must below Created for AfterSet.
UpdatedUnix int64
HasRecentActivity bool `xorm:"-" gorm:"-" json:"-"`
HasUsed bool `xorm:"-" gorm:"-" json:"-"`
}
func (t *AccessToken) BeforeInsert() {
t.CreatedUnix = time.Now().Unix()
}
func (t *AccessToken) BeforeUpdate() {
t.UpdatedUnix = time.Now().Unix()
}
func (t *AccessToken) AfterSet(colName string, _ xorm.Cell) {
switch colName {
case "created_unix":
t.Created = time.Unix(t.CreatedUnix, 0).Local()
case "updated_unix":
t.Updated = time.Unix(t.UpdatedUnix, 0).Local()
t.HasUsed = t.Updated.After(t.Created)
t.HasRecentActivity = t.Updated.Add(7 * 24 * time.Hour).After(time.Now())
}
}
// NewAccessToken creates new access token.
func NewAccessToken(t *AccessToken) error {
t.Sha1 = tool.SHA1(gouuid.NewV4().String())
has, err := x.Get(&AccessToken{
UserID: t.UserID,
Name: t.Name,
})
if err != nil {
return err
} else if has {
return errors.AccessTokenNameAlreadyExist{Name: t.Name}
}
_, err = x.Insert(t)
return err
}
// ListAccessTokens returns a list of access tokens belongs to given user.
func ListAccessTokens(uid int64) ([]*AccessToken, error) {
tokens := make([]*AccessToken, 0, 5)
return tokens, x.Where("uid=?", uid).Desc("id").Find(&tokens)
}
// DeleteAccessTokenOfUserByID deletes access token by given ID.
func DeleteAccessTokenOfUserByID(userID, id int64) error {
_, err := x.Delete(&AccessToken{
ID: id,
UserID: userID,
})
return err
}

View File

@ -11,11 +11,10 @@ import (
"gogs.io/gogs/internal/context"
"gogs.io/gogs/internal/db"
"gogs.io/gogs/internal/db/errors"
)
func ListAccessTokens(c *context.APIContext) {
tokens, err := db.ListAccessTokens(c.User.ID)
tokens, err := db.AccessTokens.List(c.User.ID)
if err != nil {
c.Error(err, "list access tokens")
return
@ -29,12 +28,9 @@ func ListAccessTokens(c *context.APIContext) {
}
func CreateAccessToken(c *context.APIContext, form api.CreateAccessTokenOption) {
t := &db.AccessToken{
UserID: c.User.ID,
Name: form.Name,
}
if err := db.NewAccessToken(t); err != nil {
if errors.IsAccessTokenNameAlreadyExist(err) {
t, err := db.AccessTokens.Create(c.User.ID, form.Name)
if err != nil {
if db.IsErrAccessTokenAlreadyExist(err) {
c.ErrorStatus(http.StatusUnprocessableEntity, err)
} else {
c.Error(err, "new access token")

View File

@ -7,7 +7,6 @@ package lfs
import (
"net/http"
"strings"
"time"
"gopkg.in/macaron.v1"
log "unknwon.dev/clog/v2"
@ -83,7 +82,6 @@ func authenticate() macaron.Handler {
}
return
}
token.Updated = time.Now()
if err = db.AccessTokens.Save(token); err != nil {
log.Error("Failed to update access token: %v", err)
}

View File

@ -11,7 +11,6 @@ import (
"net/http/httptest"
"testing"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"gopkg.in/macaron.v1"
@ -124,9 +123,6 @@ func Test_authenticate(t *testing.T) {
return &db.AccessToken{}, nil
},
MockSave: func(t *db.AccessToken) error {
if t.Updated.IsZero() {
return errors.New("Updated is zero")
}
return nil
},
},

View File

@ -140,7 +140,6 @@ func HTTPContexter() macaron.Handler {
}
return
}
token.Updated = time.Now()
if err = db.AccessTokens.Save(token); err != nil {
log.Error("Failed to update access token: %v", err)
}

View File

@ -581,7 +581,7 @@ func SettingsApplications(c *context.Context) {
c.Title("settings.applications")
c.PageIs("SettingsApplications")
tokens, err := db.ListAccessTokens(c.User.ID)
tokens, err := db.AccessTokens.List(c.User.ID)
if err != nil {
c.Errorf(err, "list access tokens")
return
@ -596,7 +596,7 @@ func SettingsApplicationsPost(c *context.Context, f form.NewAccessToken) {
c.PageIs("SettingsApplications")
if c.HasError() {
tokens, err := db.ListAccessTokens(c.User.ID)
tokens, err := db.AccessTokens.List(c.User.ID)
if err != nil {
c.Errorf(err, "list access tokens")
return
@ -607,12 +607,9 @@ func SettingsApplicationsPost(c *context.Context, f form.NewAccessToken) {
return
}
t := &db.AccessToken{
UserID: c.User.ID,
Name: f.Name,
}
if err := db.NewAccessToken(t); err != nil {
if errors.IsAccessTokenNameAlreadyExist(err) {
t, err := db.AccessTokens.Create(c.User.ID, f.Name)
if err != nil {
if db.IsErrAccessTokenAlreadyExist(err) {
c.Flash.Error(c.Tr("settings.token_name_exists"))
c.RedirectSubpath("/user/settings/applications")
} else {
@ -627,7 +624,7 @@ func SettingsApplicationsPost(c *context.Context, f form.NewAccessToken) {
}
func SettingsDeleteApplication(c *context.Context) {
if err := db.DeleteAccessTokenOfUserByID(c.User.ID, c.QueryInt64("id")); err != nil {
if err := db.AccessTokens.DeleteByID(c.User.ID, c.QueryInt64("id")); err != nil {
c.Flash.Error("DeleteAccessTokenByID: " + err.Error())
} else {
c.Flash.Success(c.Tr("settings.delete_token_success"))

View File

@ -28,6 +28,7 @@ import (
)
// MD5Bytes encodes string to MD5 bytes.
// TODO: Move to hashutil.MD5Bytes.
func MD5Bytes(str string) []byte {
m := md5.New()
_, _ = m.Write([]byte(str))
@ -35,11 +36,13 @@ func MD5Bytes(str string) []byte {
}
// MD5 encodes string to MD5 hex value.
// TODO: Move to hashutil.MD5.
func MD5(str string) string {
return hex.EncodeToString(MD5Bytes(str))
}
// SHA1 encodes string to SHA1 hex value.
// TODO: Move to hashutil.SHA1.
func SHA1(str string) string {
h := sha1.New()
_, _ = h.Write([]byte(str))