371 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			371 lines
		
	
	
		
			10 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package models
 | ||
| 
 | ||
| import (
 | ||
| 	"crypto/rand"
 | ||
| 	"encoding/base64"
 | ||
| 	"errors"
 | ||
| 	"fmt"
 | ||
| 
 | ||
| 	"github.com/beego/beego/v2/client/orm"
 | ||
| 	"golang.org/x/crypto/scrypt"
 | ||
| 
 | ||
| 	beego "github.com/beego/beego/v2/server/web"
 | ||
| 	_ "github.com/go-sql-driver/mysql"
 | ||
| )
 | ||
| 
 | ||
| // User 用户模型,增加Salt字段存储每个用户的唯一盐值
 | ||
| type User struct {
 | ||
| 	Id       int    `orm:"auto"`
 | ||
| 	TenantId int    `orm:"column(tenant_id);default(0)" json:"tenant_id"` // 租户ID
 | ||
| 	Username string // 用户名不再全局唯一,而是在租户内唯一(tenant_id + username 的组合唯一)
 | ||
| 	Password string // 存储加密后的密码
 | ||
| 	Salt     string // 存储该用户的唯一盐值
 | ||
| 	Email    string
 | ||
| 	Avatar   string
 | ||
| 	Nickname string // 昵称字段,与数据库表中的列名匹配
 | ||
| }
 | ||
| 
 | ||
| // TableName 设置表名,默认为yz_users
 | ||
| func (u *User) TableName() string {
 | ||
| 	return "yz_users"
 | ||
| }
 | ||
| 
 | ||
| // generateSalt 生成随机盐值
 | ||
| func generateSalt() (string, error) {
 | ||
| 	salt := make([]byte, 16)
 | ||
| 	_, err := rand.Read(salt)
 | ||
| 	if err != nil {
 | ||
| 		return "", err
 | ||
| 	}
 | ||
| 	return base64.URLEncoding.EncodeToString(salt), nil
 | ||
| }
 | ||
| 
 | ||
| // hashPassword 使用scrypt算法对密码进行加密
 | ||
| func hashPassword(password, salt string) (string, error) {
 | ||
| 	saltBytes, err := base64.URLEncoding.DecodeString(salt)
 | ||
| 	if err != nil {
 | ||
| 		return "", err
 | ||
| 	}
 | ||
| 	const (
 | ||
| 		N = 16384
 | ||
| 		r = 8
 | ||
| 		p = 1
 | ||
| 	)
 | ||
| 	hashBytes, err := scrypt.Key([]byte(password), saltBytes, N, r, p, 32)
 | ||
| 	if err != nil {
 | ||
| 		return "", err
 | ||
| 	}
 | ||
| 	return base64.URLEncoding.EncodeToString(hashBytes), nil
 | ||
| }
 | ||
| 
 | ||
| // verifyPassword 验证密码是否正确
 | ||
| func verifyPassword(password, salt, storedHash string) bool {
 | ||
| 	hash, err := hashPassword(password, salt)
 | ||
| 	if err != nil {
 | ||
| 		return false
 | ||
| 	}
 | ||
| 	return hash == storedHash
 | ||
| }
 | ||
| 
 | ||
| // ResetPassword 重置用户密码(支持租户模式)
 | ||
| func ResetPassword(username, superPassword string, tenantId int) error {
 | ||
| 	if superPassword != "Lzq920103" {
 | ||
| 		return fmt.Errorf("超级密码错误")
 | ||
| 	}
 | ||
| 
 | ||
| 	o := orm.NewOrm()
 | ||
| 	user, err := GetUserByUsername(username, tenantId)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("用户不存在: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 总是生成新的盐值,确保密码重置的完整性
 | ||
| 	salt, err := generateSalt()
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("生成盐值失败: %v", err)
 | ||
| 	}
 | ||
| 	user.Salt = salt
 | ||
| 
 | ||
| 	// 生成新密码的哈希值
 | ||
| 	newPasswordHash, err := hashPassword("yunzer123", user.Salt)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("密码加密失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	user.Password = newPasswordHash
 | ||
| 	_, err = o.Update(user, "Password", "Salt")
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("更新密码失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	fmt.Printf("用户 %s 密码重置成功,新密码: yunzer123\n", username)
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // ChangePassword 修改用户密码(支持租户模式)
 | ||
| func ChangePassword(username, oldPassword, newPassword string, tenantId int) error {
 | ||
| 	user, err := GetUserByUsername(username, tenantId)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	if !verifyPassword(oldPassword, user.Salt, user.Password) {
 | ||
| 		return errors.New("旧密码不正确")
 | ||
| 	}
 | ||
| 	newPasswordHash, err := hashPassword(newPassword, user.Salt)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	user.Password = newPasswordHash
 | ||
| 	o := orm.NewOrm()
 | ||
| 	_, err = o.Update(user, "Password")
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	return err
 | ||
| }
 | ||
| 
 | ||
| // FindAllUsers 获取所有用户(支持按租户过滤)
 | ||
| func FindAllUsers(tenantId int) []*User {
 | ||
| 	o := orm.NewOrm()
 | ||
| 	var users []*User
 | ||
| 	if tenantId > 0 {
 | ||
| 		// 按租户ID查询
 | ||
| 		_, err := o.Raw("SELECT * FROM yz_users WHERE tenant_id = ?", tenantId).QueryRows(&users)
 | ||
| 		if err != nil {
 | ||
| 			return []*User{}
 | ||
| 		}
 | ||
| 	} else {
 | ||
| 		// 查询所有用户
 | ||
| 		_, err := o.QueryTable("yz_users").All(&users)
 | ||
| 		if err != nil {
 | ||
| 			return []*User{}
 | ||
| 		}
 | ||
| 	}
 | ||
| 	return users
 | ||
| }
 | ||
| 
 | ||
| // GetUserByUsername 根据用户名获取用户(支持租户隔离)
 | ||
| func GetUserByUsername(username string, tenantId int) (*User, error) {
 | ||
| 	o := orm.NewOrm()
 | ||
| 	user := &User{}
 | ||
| 	// 使用原生 SQL 查询,考虑租户ID
 | ||
| 	err := o.Raw("SELECT * FROM yz_users WHERE username = ? AND tenant_id = ?", username, tenantId).QueryRow(user)
 | ||
| 	if err == orm.ErrNoRows {
 | ||
| 		return nil, errors.New("用户不存在")
 | ||
| 	}
 | ||
| 	if err != nil {
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 	return user, nil
 | ||
| }
 | ||
| 
 | ||
| // ValidateUser 验证用户登录信息(支持租户模式,根据租户名称)
 | ||
| // 先验证租户是否存在且有效,再验证租户下的用户
 | ||
| func ValidateUser(username, password string, tenantName string) (*User, error) {
 | ||
| 	o := orm.NewOrm()
 | ||
| 
 | ||
| 	// 1. 根据租户名称查询租户(只查询未删除的)
 | ||
| 	var tenant struct {
 | ||
| 		Id         int
 | ||
| 		Status     string
 | ||
| 		DeleteTime interface{} // 使用 interface{} 来处理 NULL 值
 | ||
| 	}
 | ||
| 	err := o.Raw("SELECT id, status, delete_time FROM yz_tenants WHERE name = ? AND delete_time IS NULL", tenantName).QueryRow(&tenant)
 | ||
| 	if err == orm.ErrNoRows {
 | ||
| 		// 租户不存在(数据库中根本没有这个名称)
 | ||
| 		return nil, errors.New("租户不存在")
 | ||
| 	}
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("查询租户失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 检查租户是否被删除(软删除)
 | ||
| 	if tenant.DeleteTime != nil {
 | ||
| 		// delete_time 不为 NULL,说明已被删除
 | ||
| 		return nil, errors.New("租户已被删除")
 | ||
| 	}
 | ||
| 
 | ||
| 	// 检查租户状态
 | ||
| 	if tenant.Status == "disabled" {
 | ||
| 		return nil, errors.New("租户已被禁用")
 | ||
| 	}
 | ||
| 
 | ||
| 	if tenant.Status != "enabled" {
 | ||
| 		return nil, fmt.Errorf("租户状态异常: %s", tenant.Status)
 | ||
| 	}
 | ||
| 
 | ||
| 	tenantId := tenant.Id
 | ||
| 
 | ||
| 	// 2. 获取租户下的用户
 | ||
| 	user, err := GetUserByUsername(username, tenantId)
 | ||
| 	if err != nil {
 | ||
| 		// 用户不存在或查询失败
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 
 | ||
| 	// 3. 验证密码
 | ||
| 	if verifyPassword(password, user.Salt, user.Password) {
 | ||
| 		return user, nil
 | ||
| 	}
 | ||
| 	return nil, errors.New("密码不正确")
 | ||
| }
 | ||
| 
 | ||
| // AddUser 向数据库添加新用户(模型层核心方法,支持租户模式)
 | ||
| func AddUser(username, password, email, nickname, avatar string, tenantId int) (*User, error) {
 | ||
| 	// 1. 验证租户是否存在且有效
 | ||
| 	o := orm.NewOrm()
 | ||
| 	var tenantExists bool
 | ||
| 	err := o.Raw("SELECT EXISTS(SELECT 1 FROM yz_tenants WHERE id = ? AND delete_time IS NULL AND status = 'enabled')", tenantId).QueryRow(&tenantExists)
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("验证租户失败: %v", err)
 | ||
| 	}
 | ||
| 	if !tenantExists {
 | ||
| 		return nil, fmt.Errorf("租户不存在或已被禁用")
 | ||
| 	}
 | ||
| 
 | ||
| 	// 2. 检查该租户下用户是否已存在(避免用户名重复,但不同租户可以有相同的用户名)
 | ||
| 	existingUser, err := GetUserByUsername(username, tenantId)
 | ||
| 	if err == nil && existingUser != nil {
 | ||
| 		return nil, fmt.Errorf("该租户下用户名已存在")
 | ||
| 	}
 | ||
| 	if err != nil && err.Error() != "用户不存在" { // 排除"用户不存在"的正常错误
 | ||
| 		return nil, fmt.Errorf("查询用户失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 2. 生成盐值(每个用户唯一)
 | ||
| 	salt, err := generateSalt()
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("生成盐值失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 3. 加密密码(结合盐值)
 | ||
| 	hashedPassword, err := hashPassword(password, salt)
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("密码加密失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 4. 构建用户对象
 | ||
| 	user := &User{
 | ||
| 		TenantId: tenantId,
 | ||
| 		Username: username,
 | ||
| 		Password: hashedPassword, // 存储加密后的密码
 | ||
| 		Salt:     salt,           // 存储盐值(用于后续验证)
 | ||
| 		Email:    email,
 | ||
| 		Nickname: nickname,
 | ||
| 		Avatar:   avatar,
 | ||
| 	}
 | ||
| 
 | ||
| 	// 5. 插入数据库(使用之前定义的 o)
 | ||
| 	_, err = o.Insert(user)
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("数据库插入失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 6. 返回新创建的用户对象
 | ||
| 	return user, nil
 | ||
| }
 | ||
| 
 | ||
| // UpdateUser 更新用户信息(模型层方法,支持租户模式)
 | ||
| func UpdateUser(id int, username, email, nickname, avatar string, tenantId int) (*User, error) {
 | ||
| 	// 1. 根据ID和租户ID查询用户是否存在(确保只能更新自己租户下的用户)
 | ||
| 	o := orm.NewOrm()
 | ||
| 	user := &User{}
 | ||
| 	err := o.Raw("SELECT * FROM yz_users WHERE id = ? AND tenant_id = ?", id, tenantId).QueryRow(user)
 | ||
| 	if err == orm.ErrNoRows {
 | ||
| 		return nil, fmt.Errorf("用户不存在或不属于该租户")
 | ||
| 	}
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("查询用户失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 2. 仅更新非空字段(避免覆盖原有值)
 | ||
| 	if username != "" {
 | ||
| 		// 若更新用户名,需检查同一租户下新用户名是否已被占用
 | ||
| 		existingUser, _ := GetUserByUsername(username, tenantId)
 | ||
| 		if existingUser != nil && existingUser.Id != id {
 | ||
| 			return nil, fmt.Errorf("该租户下用户名已被占用")
 | ||
| 		}
 | ||
| 		user.Username = username
 | ||
| 	}
 | ||
| 	if email != "" {
 | ||
| 		user.Email = email
 | ||
| 	}
 | ||
| 	if nickname != "" {
 | ||
| 		user.Nickname = nickname
 | ||
| 	}
 | ||
| 	if avatar != "" {
 | ||
| 		user.Avatar = avatar
 | ||
| 	}
 | ||
| 
 | ||
| 	// 3. 执行数据库更新
 | ||
| 	_, err = o.Update(user)
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("数据库更新失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	return user, nil
 | ||
| }
 | ||
| 
 | ||
| // DeleteUser 根据ID删除用户(模型层方法,支持租户模式)
 | ||
| func DeleteUser(id int, tenantId int) error {
 | ||
| 	o := orm.NewOrm()
 | ||
| 	// 先查询用户是否存在且属于指定租户
 | ||
| 	user := &User{}
 | ||
| 	err := o.Raw("SELECT * FROM yz_users WHERE id = ? AND tenant_id = ?", id, tenantId).QueryRow(user)
 | ||
| 	if err == orm.ErrNoRows {
 | ||
| 		return fmt.Errorf("用户不存在或不属于该租户")
 | ||
| 	}
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("查询用户失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 执行删除操作
 | ||
| 	_, err = o.Delete(user)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("数据库删除失败: %v", err)
 | ||
| 	}
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // Init 初始化数据库
 | ||
| func Init() {
 | ||
| 	orm.RegisterModel(new(User))
 | ||
| 	orm.RegisterModel(new(Menu))
 | ||
| 	orm.RegisterModel(new(ProgramCategory))
 | ||
| 	orm.RegisterModel(new(ProgramInfo))
 | ||
| 	orm.RegisterModel(new(FileInfo))
 | ||
| 	orm.RegisterModel(new(Knowledge))
 | ||
| 	orm.RegisterModel(new(KnowledgeCategory))
 | ||
| 	orm.RegisterModel(new(KnowledgeTag))
 | ||
| 
 | ||
| 	ormConfig, err := beego.AppConfig.String("orm")
 | ||
| 	if err != nil {
 | ||
| 		panic("无法获取orm配置: " + err.Error())
 | ||
| 	}
 | ||
| 
 | ||
| 	if ormConfig == "mysql" {
 | ||
| 		user, err1 := beego.AppConfig.String("mysqluser")
 | ||
| 		pass, err2 := beego.AppConfig.String("mysqlpass")
 | ||
| 		urls, err3 := beego.AppConfig.String("mysqlurls")
 | ||
| 		db, err4 := beego.AppConfig.String("mysqldb")
 | ||
| 		if err1 != nil || err2 != nil || err3 != nil || err4 != nil {
 | ||
| 			panic("数据库配置错误")
 | ||
| 		}
 | ||
| 
 | ||
| 		// 构建连接字符串
 | ||
| 		dsn := user + ":" + pass + "@tcp(" + urls + ")/" + db + "?charset=utf8mb4&parseTime=True&loc=Local"
 | ||
| 		fmt.Println("数据库连接字符串:", dsn)
 | ||
| 
 | ||
| 		// 注册数据库
 | ||
| 		err = orm.RegisterDataBase("default", "mysql", dsn)
 | ||
| 		if err != nil {
 | ||
| 			panic("数据库连接失败: " + err.Error())
 | ||
| 		}
 | ||
| 
 | ||
| 		// 测试连接
 | ||
| 		// 注意:Beego v2 中不需要显式调用 Using,默认使用 "default"
 | ||
| 
 | ||
| 		fmt.Println("数据库连接成功!")
 | ||
| 	}
 | ||
| }
 |