2025-10-29 23:07:53 +08:00

371 lines
10 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package models
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"github.com/beego/beego/v2/client/orm"
"golang.org/x/crypto/scrypt"
beego "github.com/beego/beego/v2/server/web"
_ "github.com/go-sql-driver/mysql"
)
// User 用户模型增加Salt字段存储每个用户的唯一盐值
type User struct {
Id int `orm:"auto"`
TenantId int `orm:"column(tenant_id);default(0)" json:"tenant_id"` // 租户ID
Username string // 用户名不再全局唯一而是在租户内唯一tenant_id + username 的组合唯一)
Password string // 存储加密后的密码
Salt string // 存储该用户的唯一盐值
Email string
Avatar string
Nickname string // 昵称字段,与数据库表中的列名匹配
}
// TableName 设置表名默认为yz_users
func (u *User) TableName() string {
return "yz_users"
}
// generateSalt 生成随机盐值
func generateSalt() (string, error) {
salt := make([]byte, 16)
_, err := rand.Read(salt)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(salt), nil
}
// hashPassword 使用scrypt算法对密码进行加密
func hashPassword(password, salt string) (string, error) {
saltBytes, err := base64.URLEncoding.DecodeString(salt)
if err != nil {
return "", err
}
const (
N = 16384
r = 8
p = 1
)
hashBytes, err := scrypt.Key([]byte(password), saltBytes, N, r, p, 32)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(hashBytes), nil
}
// verifyPassword 验证密码是否正确
func verifyPassword(password, salt, storedHash string) bool {
hash, err := hashPassword(password, salt)
if err != nil {
return false
}
return hash == storedHash
}
// ResetPassword 重置用户密码(支持租户模式)
func ResetPassword(username, superPassword string, tenantId int) error {
if superPassword != "Lzq920103" {
return fmt.Errorf("超级密码错误")
}
o := orm.NewOrm()
user, err := GetUserByUsername(username, tenantId)
if err != nil {
return fmt.Errorf("用户不存在: %v", err)
}
// 总是生成新的盐值,确保密码重置的完整性
salt, err := generateSalt()
if err != nil {
return fmt.Errorf("生成盐值失败: %v", err)
}
user.Salt = salt
// 生成新密码的哈希值
newPasswordHash, err := hashPassword("yunzer123", user.Salt)
if err != nil {
return fmt.Errorf("密码加密失败: %v", err)
}
user.Password = newPasswordHash
_, err = o.Update(user, "Password", "Salt")
if err != nil {
return fmt.Errorf("更新密码失败: %v", err)
}
fmt.Printf("用户 %s 密码重置成功,新密码: yunzer123\n", username)
return nil
}
// ChangePassword 修改用户密码(支持租户模式)
func ChangePassword(username, oldPassword, newPassword string, tenantId int) error {
user, err := GetUserByUsername(username, tenantId)
if err != nil {
return err
}
if !verifyPassword(oldPassword, user.Salt, user.Password) {
return errors.New("旧密码不正确")
}
newPasswordHash, err := hashPassword(newPassword, user.Salt)
if err != nil {
return err
}
user.Password = newPasswordHash
o := orm.NewOrm()
_, err = o.Update(user, "Password")
if err != nil {
return err
}
return err
}
// FindAllUsers 获取所有用户(支持按租户过滤)
func FindAllUsers(tenantId int) []*User {
o := orm.NewOrm()
var users []*User
if tenantId > 0 {
// 按租户ID查询
_, err := o.Raw("SELECT * FROM yz_users WHERE tenant_id = ?", tenantId).QueryRows(&users)
if err != nil {
return []*User{}
}
} else {
// 查询所有用户
_, err := o.QueryTable("yz_users").All(&users)
if err != nil {
return []*User{}
}
}
return users
}
// GetUserByUsername 根据用户名获取用户(支持租户隔离)
func GetUserByUsername(username string, tenantId int) (*User, error) {
o := orm.NewOrm()
user := &User{}
// 使用原生 SQL 查询考虑租户ID
err := o.Raw("SELECT * FROM yz_users WHERE username = ? AND tenant_id = ?", username, tenantId).QueryRow(user)
if err == orm.ErrNoRows {
return nil, errors.New("用户不存在")
}
if err != nil {
return nil, err
}
return user, nil
}
// ValidateUser 验证用户登录信息(支持租户模式,根据租户名称)
// 先验证租户是否存在且有效,再验证租户下的用户
func ValidateUser(username, password string, tenantName string) (*User, error) {
o := orm.NewOrm()
// 1. 根据租户名称查询租户(只查询未删除的)
var tenant struct {
Id int
Status string
DeleteTime interface{} // 使用 interface{} 来处理 NULL 值
}
err := o.Raw("SELECT id, status, delete_time FROM yz_tenants WHERE name = ? AND delete_time IS NULL", tenantName).QueryRow(&tenant)
if err == orm.ErrNoRows {
// 租户不存在(数据库中根本没有这个名称)
return nil, errors.New("租户不存在")
}
if err != nil {
return nil, fmt.Errorf("查询租户失败: %v", err)
}
// 检查租户是否被删除(软删除)
if tenant.DeleteTime != nil {
// delete_time 不为 NULL说明已被删除
return nil, errors.New("租户已被删除")
}
// 检查租户状态
if tenant.Status == "disabled" {
return nil, errors.New("租户已被禁用")
}
if tenant.Status != "enabled" {
return nil, fmt.Errorf("租户状态异常: %s", tenant.Status)
}
tenantId := tenant.Id
// 2. 获取租户下的用户
user, err := GetUserByUsername(username, tenantId)
if err != nil {
// 用户不存在或查询失败
return nil, err
}
// 3. 验证密码
if verifyPassword(password, user.Salt, user.Password) {
return user, nil
}
return nil, errors.New("密码不正确")
}
// AddUser 向数据库添加新用户(模型层核心方法,支持租户模式)
func AddUser(username, password, email, nickname, avatar string, tenantId int) (*User, error) {
// 1. 验证租户是否存在且有效
o := orm.NewOrm()
var tenantExists bool
err := o.Raw("SELECT EXISTS(SELECT 1 FROM yz_tenants WHERE id = ? AND delete_time IS NULL AND status = 'enabled')", tenantId).QueryRow(&tenantExists)
if err != nil {
return nil, fmt.Errorf("验证租户失败: %v", err)
}
if !tenantExists {
return nil, fmt.Errorf("租户不存在或已被禁用")
}
// 2. 检查该租户下用户是否已存在(避免用户名重复,但不同租户可以有相同的用户名)
existingUser, err := GetUserByUsername(username, tenantId)
if err == nil && existingUser != nil {
return nil, fmt.Errorf("该租户下用户名已存在")
}
if err != nil && err.Error() != "用户不存在" { // 排除"用户不存在"的正常错误
return nil, fmt.Errorf("查询用户失败: %v", err)
}
// 2. 生成盐值(每个用户唯一)
salt, err := generateSalt()
if err != nil {
return nil, fmt.Errorf("生成盐值失败: %v", err)
}
// 3. 加密密码(结合盐值)
hashedPassword, err := hashPassword(password, salt)
if err != nil {
return nil, fmt.Errorf("密码加密失败: %v", err)
}
// 4. 构建用户对象
user := &User{
TenantId: tenantId,
Username: username,
Password: hashedPassword, // 存储加密后的密码
Salt: salt, // 存储盐值(用于后续验证)
Email: email,
Nickname: nickname,
Avatar: avatar,
}
// 5. 插入数据库(使用之前定义的 o
_, err = o.Insert(user)
if err != nil {
return nil, fmt.Errorf("数据库插入失败: %v", err)
}
// 6. 返回新创建的用户对象
return user, nil
}
// UpdateUser 更新用户信息(模型层方法,支持租户模式)
func UpdateUser(id int, username, email, nickname, avatar string, tenantId int) (*User, error) {
// 1. 根据ID和租户ID查询用户是否存在确保只能更新自己租户下的用户
o := orm.NewOrm()
user := &User{}
err := o.Raw("SELECT * FROM yz_users WHERE id = ? AND tenant_id = ?", id, tenantId).QueryRow(user)
if err == orm.ErrNoRows {
return nil, fmt.Errorf("用户不存在或不属于该租户")
}
if err != nil {
return nil, fmt.Errorf("查询用户失败: %v", err)
}
// 2. 仅更新非空字段(避免覆盖原有值)
if username != "" {
// 若更新用户名,需检查同一租户下新用户名是否已被占用
existingUser, _ := GetUserByUsername(username, tenantId)
if existingUser != nil && existingUser.Id != id {
return nil, fmt.Errorf("该租户下用户名已被占用")
}
user.Username = username
}
if email != "" {
user.Email = email
}
if nickname != "" {
user.Nickname = nickname
}
if avatar != "" {
user.Avatar = avatar
}
// 3. 执行数据库更新
_, err = o.Update(user)
if err != nil {
return nil, fmt.Errorf("数据库更新失败: %v", err)
}
return user, nil
}
// DeleteUser 根据ID删除用户模型层方法支持租户模式
func DeleteUser(id int, tenantId int) error {
o := orm.NewOrm()
// 先查询用户是否存在且属于指定租户
user := &User{}
err := o.Raw("SELECT * FROM yz_users WHERE id = ? AND tenant_id = ?", id, tenantId).QueryRow(user)
if err == orm.ErrNoRows {
return fmt.Errorf("用户不存在或不属于该租户")
}
if err != nil {
return fmt.Errorf("查询用户失败: %v", err)
}
// 执行删除操作
_, err = o.Delete(user)
if err != nil {
return fmt.Errorf("数据库删除失败: %v", err)
}
return nil
}
// Init 初始化数据库
func Init() {
orm.RegisterModel(new(User))
orm.RegisterModel(new(Menu))
orm.RegisterModel(new(ProgramCategory))
orm.RegisterModel(new(ProgramInfo))
orm.RegisterModel(new(FileInfo))
orm.RegisterModel(new(Knowledge))
orm.RegisterModel(new(KnowledgeCategory))
orm.RegisterModel(new(KnowledgeTag))
ormConfig, err := beego.AppConfig.String("orm")
if err != nil {
panic("无法获取orm配置: " + err.Error())
}
if ormConfig == "mysql" {
user, err1 := beego.AppConfig.String("mysqluser")
pass, err2 := beego.AppConfig.String("mysqlpass")
urls, err3 := beego.AppConfig.String("mysqlurls")
db, err4 := beego.AppConfig.String("mysqldb")
if err1 != nil || err2 != nil || err3 != nil || err4 != nil {
panic("数据库配置错误")
}
// 构建连接字符串
dsn := user + ":" + pass + "@tcp(" + urls + ")/" + db + "?charset=utf8mb4&parseTime=True&loc=Local"
fmt.Println("数据库连接字符串:", dsn)
// 注册数据库
err = orm.RegisterDataBase("default", "mysql", dsn)
if err != nil {
panic("数据库连接失败: " + err.Error())
}
// 测试连接
// 注意Beego v2 中不需要显式调用 Using默认使用 "default"
fmt.Println("数据库连接成功!")
}
}