Files
2026-04-10 18:25:21 +03:30

431 lines
11 KiB
Go

package auth
import (
"context"
"errors"
"github.com/google/uuid"
"go.uber.org/fx"
"gorm.io/gorm"
domainAuth "base/internal/domain/auth"
)
// userRepository is the GORM-backed implementation of
// domainAuth.UserRepository returned by NewUserRepository.
type userRepository struct {
	db *gorm.DB // database handle every method issues its queries through
}
// NewUserRepository wraps db in a userRepository and returns it as a
// domainAuth.UserRepository. It also registers an fx lifecycle hook whose
// start and stop callbacks currently do nothing.
func NewUserRepository(lc fx.Lifecycle, db *gorm.DB) domainAuth.UserRepository {
	noop := func(context.Context) error { return nil }
	lc.Append(fx.Hook{
		OnStart: noop,
		OnStop:  noop,
	})
	repo := &userRepository{db: db}
	return repo
}
// Create inserts user. On success, the persisted model is copied back into
// user via copyUserFromModel (presumably to pick up generated fields such as
// IDs/timestamps — confirm against copyUserFromModel). On failure the GORM
// error is returned and user is left untouched.
func (r *userRepository) Create(ctx context.Context, user *domainAuth.User) error {
	m := toUserModel(user)
	err := r.db.WithContext(ctx).Create(m).Error
	if err == nil {
		copyUserFromModel(user, m)
	}
	return err
}
// CreateWithAccount inserts user and account inside a single transaction. If
// either insert fails, the transaction is rolled back and the error returned.
//
// After the user row is written and copied back into user, account.UserID is
// set to user.ID. This guarantees the account references the user created
// here, even when the user's ID was only populated by the insert.
func (r *userRepository) CreateWithAccount(ctx context.Context, user *domainAuth.User, account *domainAuth.Account) error {
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		userModel := toUserModel(user)
		if err := tx.WithContext(ctx).Create(userModel).Error; err != nil {
			return err
		}
		copyUserFromModel(user, userModel)

		// Link the account to the user that was just inserted.
		account.UserID = user.ID

		accountModel := toAccountModel(account)
		if err := tx.WithContext(ctx).Create(accountModel).Error; err != nil {
			return err
		}
		copyAccountFromModel(account, accountModel)
		return nil
	})
}
// UpsertWithAccount ensures that a user with the given email exists and owns
// an account for account.Provider, all within one transaction.
//
//   - No user with that email: user is inserted, account is linked to it
//     (account.UserID = user.ID) and inserted, and true is returned.
//   - User already exists: user is refreshed from the stored row, so user.ID
//     is the existing user's ID. account is linked to that user, and the
//     user's account for the same provider is updated, or created when none
//     exists. false is returned.
//
// On any error the transaction is rolled back and (false, err) is returned.
func (r *userRepository) UpsertWithAccount(ctx context.Context, email string, user *domainAuth.User, account *domainAuth.Account) (bool, error) {
	isNewUser := false
	err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		var existingUserModel UserModel
		err := tx.WithContext(ctx).Where("email = ?", email).First(&existingUserModel).Error
		if err != nil {
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				return err
			}
			// No user with this email: create the user and its account.
			isNewUser = true
			userModel := toUserModel(user)
			if err := tx.WithContext(ctx).Create(userModel).Error; err != nil {
				return err
			}
			copyUserFromModel(user, userModel)
			account.UserID = user.ID
			accountModel := toAccountModel(account)
			if err := tx.WithContext(ctx).Create(accountModel).Error; err != nil {
				return err
			}
			copyAccountFromModel(account, accountModel)
			return nil
		}

		// The user already exists. Adopt the stored row so that user.ID is
		// the existing ID; the caller's user.ID may be unset or freshly
		// generated. Then point the account at that user.
		copyUserFromModel(user, &existingUserModel)
		account.UserID = user.ID

		var existingAccountModel AccountModel
		err = tx.WithContext(ctx).
			Where("user_id = ? AND provider = ?", user.ID, int(account.Provider)).
			First(&existingAccountModel).Error
		if err != nil {
			if !errors.Is(err, gorm.ErrRecordNotFound) {
				return err
			}
			// The user has no account for this provider yet: create one.
			accountModel := toAccountModel(account)
			if err := tx.WithContext(ctx).Create(accountModel).Error; err != nil {
				return err
			}
			copyAccountFromModel(account, accountModel)
			return nil
		}

		// Update the existing provider account in place. Pin the primary key
		// to the stored row so the update cannot rewrite the row's ID, and so
		// account reports the stored ID afterwards.
		accountModel := toAccountModel(account)
		accountModel.ID = existingAccountModel.ID
		if err := tx.WithContext(ctx).
			Model(&AccountModel{}).
			Where("id = ?", existingAccountModel.ID).
			Updates(accountModel).Error; err != nil {
			return err
		}
		copyAccountFromModel(account, accountModel)
		return nil
	})
	if err != nil {
		return false, err
	}
	return isNewUser, nil
}
// FindByID returns the user with the given id. Roles and accounts are loaded
// only when requested through opts. Otherwise they are set to empty, non-nil
// slices, matching FindByEmail and List.
//
// Errors from the lookup are returned unchanged (e.g. gorm.ErrRecordNotFound
// when no user has this id).
func (r *userRepository) FindByID(ctx context.Context, id uuid.UUID, opts ...domainAuth.UserQueryOption) (*domainAuth.User, error) {
	options := &domainAuth.UserQueryOptions{}
	for _, opt := range opts {
		opt(options)
	}

	var model UserModel
	if err := r.db.WithContext(ctx).Where("id = ?", id).First(&model).Error; err != nil {
		return nil, err
	}
	user := toUserDomain(&model)

	// Default to empty slices so callers (and JSON encoders) see [] rather
	// than nil when relations were not requested.
	user.Roles = []domainAuth.Role{}
	if options.LoadRoles {
		roles, err := r.loadUserRoles(ctx, id)
		if err != nil {
			return nil, err
		}
		user.Roles = roles
	}

	user.Accounts = []domainAuth.Account{}
	if options.LoadAccounts {
		accounts, err := r.loadUserAccounts(ctx, id)
		if err != nil {
			return nil, err
		}
		user.Accounts = accounts
	}
	return user, nil
}
// FindByEmail returns the user whose email column equals email. Roles and
// accounts are fetched only when the corresponding option is set in opts.
// Otherwise each is set to an empty, non-nil slice. Errors from the lookup
// or from loading relations are returned as-is with a nil user.
func (r *userRepository) FindByEmail(ctx context.Context, email string, opts ...domainAuth.UserQueryOption) (*domainAuth.User, error) {
	var cfg domainAuth.UserQueryOptions
	for _, apply := range opts {
		apply(&cfg)
	}

	var row UserModel
	err := r.db.WithContext(ctx).Where("email = ?", email).First(&row).Error
	if err != nil {
		return nil, err
	}
	found := toUserDomain(&row)

	roles := []domainAuth.Role{}
	if cfg.LoadRoles {
		if roles, err = r.loadUserRoles(ctx, found.ID); err != nil {
			return nil, err
		}
	}
	found.Roles = roles

	accounts := []domainAuth.Account{}
	if cfg.LoadAccounts {
		if accounts, err = r.loadUserAccounts(ctx, found.ID); err != nil {
			return nil, err
		}
	}
	found.Accounts = accounts

	return found, nil
}
// Update writes user's fields to the users row whose id equals user.ID.
// Because a struct is passed to GORM's Updates, zero-valued fields are
// skipped rather than written. A missing row is not reported as an error,
// since RowsAffected is not checked.
func (r *userRepository) Update(ctx context.Context, user *domainAuth.User) error {
	changes := toUserModel(user)
	query := r.db.WithContext(ctx).Model(&UserModel{}).Where("id = ?", user.ID)
	return query.Updates(changes).Error
}
// Delete removes the user with the given id via GORM's Delete on UserModel.
// Whether this is a soft or hard delete depends on whether UserModel carries
// a gorm.DeletedAt field (not visible here). Matching no row is not an error.
func (r *userRepository) Delete(ctx context.Context, id uuid.UUID) error {
	result := r.db.WithContext(ctx).Delete(&UserModel{}, "id = ?", id)
	return result.Error
}
// List returns one page of users (limit/offset). When requested through
// opts, roles and accounts for the whole page are batch-loaded with one
// loader call each. Users without loaded relations get empty, non-nil
// slices. An empty page yields an empty, non-nil slice.
func (r *userRepository) List(ctx context.Context, limit, offset int, opts ...domainAuth.UserQueryOption) ([]*domainAuth.User, error) {
	cfg := &domainAuth.UserQueryOptions{}
	for _, apply := range opts {
		apply(cfg)
	}

	// NOTE(review): there is no ORDER BY, so most databases do not guarantee
	// stable page contents across calls. Consider ordering by a stable column.
	var rows []UserModel
	if err := r.db.WithContext(ctx).Limit(limit).Offset(offset).Find(&rows).Error; err != nil {
		return nil, err
	}
	if len(rows) == 0 {
		return []*domainAuth.User{}, nil
	}

	users := make([]*domainAuth.User, 0, len(rows))
	ids := make([]uuid.UUID, 0, len(rows))
	for i := range rows {
		u := toUserDomain(&rows[i])
		users = append(users, u)
		ids = append(ids, u.ID)
	}

	// The maps stay nil when a relation is not requested; lookups in a nil
	// map simply miss, which yields the empty-slice default below.
	var rolesByUser map[uuid.UUID][]domainAuth.Role
	if cfg.LoadRoles {
		loaded, err := r.loadUsersRoles(ctx, ids)
		if err != nil {
			return nil, err
		}
		rolesByUser = loaded
	}

	var accountsByUser map[uuid.UUID][]domainAuth.Account
	if cfg.LoadAccounts {
		loaded, err := r.loadUsersAccounts(ctx, ids)
		if err != nil {
			return nil, err
		}
		accountsByUser = loaded
	}

	for _, u := range users {
		if roles, ok := rolesByUser[u.ID]; ok {
			u.Roles = roles
		} else {
			u.Roles = []domainAuth.Role{}
		}
		if accounts, ok := accountsByUser[u.ID]; ok {
			u.Accounts = accounts
		} else {
			u.Accounts = []domainAuth.Account{}
		}
	}
	return users, nil
}
// Count returns the number of rows GORM counts for UserModel. Any soft-delete
// filtering is whatever GORM applies for that model. On error it returns 0.
func (r *userRepository) Count(ctx context.Context) (int64, error) {
	var total int64
	err := r.db.WithContext(ctx).Model(&UserModel{}).Count(&total).Error
	if err != nil {
		return 0, err
	}
	return total, nil
}
// loadUserRoles returns the roles linked to userID through user_roles.
// Assignments and roles whose deleted_at is set are excluded. The returned
// slice is non-nil, even when the user has no roles.
func (r *userRepository) loadUserRoles(ctx context.Context, userID uuid.UUID) ([]domainAuth.Role, error) {
	var rows []RoleModel
	err := r.db.WithContext(ctx).
		Table("roles").
		Joins("INNER JOIN user_roles ON roles.id = user_roles.role_id").
		Where("user_roles.user_id = ? AND user_roles.deleted_at IS NULL AND roles.deleted_at IS NULL", userID).
		Find(&rows).Error
	if err != nil {
		return nil, err
	}

	out := make([]domainAuth.Role, 0, len(rows))
	for i := range rows {
		out = append(out, *toRoleDomain(&rows[i]))
	}
	return out, nil
}
// UserRoles returns the roles assigned to userID, excluding soft-deleted
// assignments and roles. It delegates to loadUserRoles, so the query lives in
// one place instead of being duplicated here.
func (r *userRepository) UserRoles(ctx context.Context, userID uuid.UUID) ([]domainAuth.Role, error) {
	return r.loadUserRoles(ctx, userID)
}
// loadUserAccounts returns every AccountModel row whose user_id equals
// userID, converted to domain accounts. The returned slice is non-nil, even
// when no rows match.
func (r *userRepository) loadUserAccounts(ctx context.Context, userID uuid.UUID) ([]domainAuth.Account, error) {
	var rows []AccountModel
	err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&rows).Error
	if err != nil {
		return nil, err
	}

	out := make([]domainAuth.Account, 0, len(rows))
	for i := range rows {
		out = append(out, *toAccountDomain(&rows[i]))
	}
	return out, nil
}
// UserAccounts returns all accounts whose user_id equals userID. It delegates
// to loadUserAccounts, so the query lives in one place instead of being
// duplicated here.
func (r *userRepository) UserAccounts(ctx context.Context, userID uuid.UUID) ([]domainAuth.Account, error) {
	return r.loadUserAccounts(ctx, userID)
}
// loadUsersRoles batch-loads roles for several users with at most two
// queries:
//   - one over user_roles, for assignments whose deleted_at is NULL;
//   - one over roles, for roles whose deleted_at is NULL.
//
// It returns a map from user ID to that user's roles, in the order the
// assignment rows were returned. Users with no active roles are absent from
// the map. An empty userIDs slice issues no queries and yields an empty map.
func (r *userRepository) loadUsersRoles(ctx context.Context, userIDs []uuid.UUID) (map[uuid.UUID][]domainAuth.Role, error) {
	rolesMap := make(map[uuid.UUID][]domainAuth.Role)
	if len(userIDs) == 0 {
		return rolesMap, nil
	}

	var userRoles []struct {
		UserID uuid.UUID `gorm:"column:user_id"`
		RoleID uuid.UUID `gorm:"column:role_id"`
	}
	if err := r.db.WithContext(ctx).
		Table("user_roles").
		Select("user_id, role_id").
		Where("user_id IN ? AND deleted_at IS NULL", userIDs).
		Find(&userRoles).Error; err != nil {
		return nil, err
	}
	if len(userRoles) == 0 {
		return rolesMap, nil
	}

	// Many users usually share the same few roles. De-duplicate the IDs so
	// the IN list has one entry per distinct role instead of one per
	// assignment; this also keeps it well below driver bind-parameter limits.
	seen := make(map[uuid.UUID]struct{}, len(userRoles))
	roleIDs := make([]uuid.UUID, 0, len(userRoles))
	for _, ur := range userRoles {
		if _, dup := seen[ur.RoleID]; dup {
			continue
		}
		seen[ur.RoleID] = struct{}{}
		roleIDs = append(roleIDs, ur.RoleID)
	}

	var roleModels []RoleModel
	if err := r.db.WithContext(ctx).
		Where("id IN ? AND deleted_at IS NULL", roleIDs).
		Find(&roleModels).Error; err != nil {
		return nil, err
	}

	// Index the active roles by ID.
	rolesByID := make(map[uuid.UUID]*domainAuth.Role, len(roleModels))
	for i := range roleModels {
		role := toRoleDomain(&roleModels[i])
		rolesByID[role.ID] = role
	}

	// Group by user. Assignments to roles that were filtered out
	// (soft-deleted) are skipped.
	for _, ur := range userRoles {
		if role, ok := rolesByID[ur.RoleID]; ok {
			rolesMap[ur.UserID] = append(rolesMap[ur.UserID], *role)
		}
	}
	return rolesMap, nil
}
// loadUsersAccounts fetches, in one query, all accounts whose user_id is in
// userIDs and groups them by that user ID. Users without accounts are absent
// from the map. An empty userIDs slice issues no query and yields an empty
// map.
func (r *userRepository) loadUsersAccounts(ctx context.Context, userIDs []uuid.UUID) (map[uuid.UUID][]domainAuth.Account, error) {
	byUser := make(map[uuid.UUID][]domainAuth.Account)
	if len(userIDs) == 0 {
		return byUser, nil
	}

	var rows []AccountModel
	err := r.db.WithContext(ctx).Where("user_id IN ?", userIDs).Find(&rows).Error
	if err != nil {
		return nil, err
	}

	for i := range rows {
		acc := toAccountDomain(&rows[i])
		owner := rows[i].UserID
		byUser[owner] = append(byUser[owner], *acc)
	}
	return byUser, nil
}