gogs/internal/db/follows.go

128 lines
3.3 KiB
Go

// Copyright 2022 The Gogs Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"context"
"github.com/pkg/errors"
"gorm.io/gorm"
)
// FollowsStore is the persistent interface for user follows.
//
// NOTE: All methods are sorted in alphabetical order.
type FollowsStore interface {
	// Follow marks the user with userID as following the user with followID.
	Follow(ctx context.Context, userID, followID int64) error
	// IsFollowing returns true if the user with userID is following the user
	// with followID.
	IsFollowing(ctx context.Context, userID, followID int64) bool
	// Unfollow removes the mark that the user with userID follows the user
	// with followID.
	Unfollow(ctx context.Context, userID, followID int64) error
}
var (
	// Follows is the package-level FollowsStore exposed to callers.
	// NOTE(review): it is not assigned in this file — presumably it is set
	// during database initialization elsewhere; confirm.
	Follows FollowsStore

	// Compile-time check that *follows satisfies FollowsStore.
	_ FollowsStore = (*follows)(nil)
)

// follows is the gorm-backed implementation of FollowsStore. Embedding
// *gorm.DB makes every gorm method directly callable on the receiver.
type follows struct {
	*gorm.DB
}
// NewFollowsStore builds a FollowsStore whose queries all run through the
// given database connection.
func NewFollowsStore(db *gorm.DB) FollowsStore {
	store := new(follows)
	store.DB = db
	return store
}
// updateFollowingCount recomputes the denormalized follow counters of the two
// users in a relation, issuing every statement through tx (the callers in this
// file pass their transaction handle):
//
//   - "num_followers" of the user followID becomes the number of follow rows
//     whose follow_id is followID;
//   - "num_following" of the user userID becomes the number of follow rows
//     whose user_id is userID.
//
// Each counter is overwritten with a COUNT(*) subquery rather than being
// incremented or decremented. The followers counter is updated first; if that
// fails, the following counter is left untouched.
func (*follows) updateFollowingCount(tx *gorm.DB, userID, followID int64) error {
	// recount sets column on the user row identified by id to the number of
	// follow rows whose fkColumn equals id.
	//
	// Equivalent SQL for PostgreSQL:
	//
	//	UPDATE "user"
	//	SET <column> = (
	//		SELECT COUNT(*) FROM follow WHERE <fkColumn> = @id
	//	)
	//	WHERE id = @id
	recount := func(id int64, column, fkColumn string) error {
		err := tx.Model(&User{}).
			Where("id = ?", id).
			Update(
				column,
				tx.Model(&Follow{}).Select("COUNT(*)").Where(fkColumn+" = ?", id),
			).
			Error
		if err != nil {
			return errors.Wrap(err, `update "`+column+`"`)
		}
		return nil
	}

	if err := recount(followID, "num_followers", "follow_id"); err != nil {
		return err
	}
	return recount(userID, "num_following", "user_id")
}
// Follow records that the user with userID follows the user with followID.
// Following oneself is a successful no-op. The lookup-or-insert of the relation
// and the counter recomputation run inside a single transaction, and the
// counters are only recomputed when the insert reports affected rows.
func (db *follows) Follow(ctx context.Context, userID, followID int64) error {
	// Self-follows are ignored rather than rejected.
	if userID == followID {
		return nil
	}

	return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		relation := &Follow{UserID: userID, FollowID: followID}

		// The same value serves as destination and as lookup conditions, so an
		// existing (userID, followID) row is found instead of inserted.
		res := tx.FirstOrCreate(relation, relation)
		switch {
		case res.Error != nil:
			return errors.Wrap(res.Error, "upsert")
		case res.RowsAffected <= 0:
			// Relation already exists. NOTE(review): relies on gorm's
			// FirstOrCreate reporting zero RowsAffected when the row was
			// found rather than created — confirm for the pinned gorm version.
			return nil
		}
		return db.updateFollowingCount(tx, userID, followID)
	})
}
// IsFollowing reports whether a follow row exists from userID to followID.
// Any query error, including "record not found", yields false; errors are not
// surfaced to the caller.
func (db *follows) IsFollowing(ctx context.Context, userID, followID int64) bool {
	err := db.WithContext(ctx).
		Where("user_id = ? AND follow_id = ?", userID, followID).
		First(&Follow{}).
		Error
	return err == nil
}
// Unfollow deletes the follow relation from userID to followID, if any, and
// recomputes both users' follow counters in the same transaction. Unfollowing
// oneself is a successful no-op.
func (db *follows) Unfollow(ctx context.Context, userID, followID int64) error {
	// Mirrors Follow: self-relations are never stored, so there is nothing to do.
	if userID == followID {
		return nil
	}

	return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		// The delete's RowsAffected is not inspected, so the counters are
		// recomputed even when no row matched.
		if err := tx.Where("user_id = ? AND follow_id = ?", userID, followID).Delete(&Follow{}).Error; err != nil {
			return errors.Wrap(err, "delete")
		}
		return db.updateFollowingCount(tx, userID, followID)
	})
}
// Follow represents relations of users and their followers: each row records
// that the user UserID follows the user FollowID.
type Follow struct {
	ID int64 `gorm:"primaryKey"`
	// UserID is the ID of the follower. Together with FollowID it forms a
	// unique pair (xorm "follow" / gorm "follow_user_follow_unique").
	UserID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"`
	// FollowID is the ID of the user being followed.
	FollowID int64 `xorm:"UNIQUE(follow)" gorm:"uniqueIndex:follow_user_follow_unique;not null"`
}