302 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			302 lines
		
	
	
		
			7.4 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package models
 | ||
| 
 | ||
| import (
 | ||
| 	"crypto/rand"
 | ||
| 	"encoding/base64"
 | ||
| 	"errors"
 | ||
| 	"fmt"
 | ||
| 
 | ||
| 	"github.com/beego/beego/v2/client/orm"
 | ||
| 	"golang.org/x/crypto/scrypt"
 | ||
| 
 | ||
| 	beego "github.com/beego/beego/v2/server/web"
 | ||
| 	_ "github.com/go-sql-driver/mysql"
 | ||
| )
 | ||
| 
 | ||
| // User 用户模型,增加Salt字段存储每个用户的唯一盐值
 | ||
| type User struct {
 | ||
| 	Id       int    `orm:"auto"`
 | ||
| 	Username string `orm:"unique"`
 | ||
| 	Password string // 存储加密后的密码
 | ||
| 	Salt     string // 存储该用户的唯一盐值
 | ||
| 	Email    string
 | ||
| 	Avatar   string
 | ||
| 	Nickname string // 昵称字段,与数据库表中的列名匹配
 | ||
| }
 | ||
| 
 | ||
| // TableName 设置表名,默认为yz_users
 | ||
| func (u *User) TableName() string {
 | ||
| 	return "yz_users"
 | ||
| }
 | ||
| 
 | ||
| // generateSalt 生成随机盐值
 | ||
| func generateSalt() (string, error) {
 | ||
| 	salt := make([]byte, 16)
 | ||
| 	_, err := rand.Read(salt)
 | ||
| 	if err != nil {
 | ||
| 		return "", err
 | ||
| 	}
 | ||
| 	return base64.URLEncoding.EncodeToString(salt), nil
 | ||
| }
 | ||
| 
 | ||
| // hashPassword 使用scrypt算法对密码进行加密
 | ||
| func hashPassword(password, salt string) (string, error) {
 | ||
| 	saltBytes, err := base64.URLEncoding.DecodeString(salt)
 | ||
| 	if err != nil {
 | ||
| 		return "", err
 | ||
| 	}
 | ||
| 	const (
 | ||
| 		N = 16384
 | ||
| 		r = 8
 | ||
| 		p = 1
 | ||
| 	)
 | ||
| 	hashBytes, err := scrypt.Key([]byte(password), saltBytes, N, r, p, 32)
 | ||
| 	if err != nil {
 | ||
| 		return "", err
 | ||
| 	}
 | ||
| 	return base64.URLEncoding.EncodeToString(hashBytes), nil
 | ||
| }
 | ||
| 
 | ||
| // verifyPassword 验证密码是否正确
 | ||
| func verifyPassword(password, salt, storedHash string) bool {
 | ||
| 	hash, err := hashPassword(password, salt)
 | ||
| 	if err != nil {
 | ||
| 		return false
 | ||
| 	}
 | ||
| 	return hash == storedHash
 | ||
| }
 | ||
| 
 | ||
| // ResetPassword 重置用户密码
 | ||
| func ResetPassword(username, superPassword string) error {
 | ||
| 	if superPassword != "Lzq920103" {
 | ||
| 		return fmt.Errorf("超级密码错误")
 | ||
| 	}
 | ||
| 
 | ||
| 	o := orm.NewOrm()
 | ||
| 	user := &User{Username: username}
 | ||
| 	err := o.Read(user, "Username")
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("用户不存在: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 总是生成新的盐值,确保密码重置的完整性
 | ||
| 	salt, err := generateSalt()
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("生成盐值失败: %v", err)
 | ||
| 	}
 | ||
| 	user.Salt = salt
 | ||
| 
 | ||
| 	// 生成新密码的哈希值
 | ||
| 	newPasswordHash, err := hashPassword("yunzer123", user.Salt)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("密码加密失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	user.Password = newPasswordHash
 | ||
| 	_, err = o.Update(user, "Password", "Salt")
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("更新密码失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	fmt.Printf("用户 %s 密码重置成功,新密码: yunzer123\n", username)
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // ChangePassword 修改用户密码
 | ||
| func ChangePassword(username, oldPassword, newPassword string) error {
 | ||
| 	user, err := GetUserByUsername(username)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	if !verifyPassword(oldPassword, user.Salt, user.Password) {
 | ||
| 		return errors.New("旧密码不正确")
 | ||
| 	}
 | ||
| 	newPasswordHash, err := hashPassword(newPassword, user.Salt)
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	user.Password = newPasswordHash
 | ||
| 	o := orm.NewOrm()
 | ||
| 	_, err = o.Update(user, "Password")
 | ||
| 	if err != nil {
 | ||
| 		return err
 | ||
| 	}
 | ||
| 	return err
 | ||
| }
 | ||
| 
 | ||
| // FindAllUsers 获取所有用户
 | ||
| func FindAllUsers() []*User {
 | ||
| 	o := orm.NewOrm()
 | ||
| 	var users []*User
 | ||
| 	_, err := o.QueryTable("yz_users").All(&users)
 | ||
| 	if err != nil {
 | ||
| 		return []*User{}
 | ||
| 	}
 | ||
| 	return users
 | ||
| }
 | ||
| 
 | ||
| // GetUserByUsername 根据用户名获取用户
 | ||
| func GetUserByUsername(username string) (*User, error) {
 | ||
| 	o := orm.NewOrm()
 | ||
| 	user := &User{Username: username}
 | ||
| 	err := o.Read(user, "Username")
 | ||
| 	if err == orm.ErrNoRows {
 | ||
| 		return nil, errors.New("用户不存在")
 | ||
| 	}
 | ||
| 	if err != nil {
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 	return user, nil
 | ||
| }
 | ||
| 
 | ||
| // ValidateUser 验证用户登录信息
 | ||
| func ValidateUser(username, password string) (*User, error) {
 | ||
| 	user, err := GetUserByUsername(username)
 | ||
| 	if err != nil {
 | ||
| 		return nil, err
 | ||
| 	}
 | ||
| 	if verifyPassword(password, user.Salt, user.Password) {
 | ||
| 		return user, nil
 | ||
| 	}
 | ||
| 	return nil, errors.New("密码不正确")
 | ||
| }
 | ||
| 
 | ||
| // AddUser 向数据库添加新用户(模型层核心方法)
 | ||
| func AddUser(username, password, email, nickname, avatar string) (*User, error) {
 | ||
| 	// 1. 检查用户是否已存在(避免用户名重复)
 | ||
| 	existingUser, err := GetUserByUsername(username)
 | ||
| 	if err == nil && existingUser != nil {
 | ||
| 		return nil, fmt.Errorf("用户名已存在")
 | ||
| 	}
 | ||
| 	if err != nil && err.Error() != "用户不存在" { // 排除"用户不存在"的正常错误
 | ||
| 		return nil, fmt.Errorf("查询用户失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 2. 生成盐值(每个用户唯一)
 | ||
| 	salt, err := generateSalt()
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("生成盐值失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 3. 加密密码(结合盐值)
 | ||
| 	hashedPassword, err := hashPassword(password, salt)
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("密码加密失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 4. 构建用户对象
 | ||
| 	user := &User{
 | ||
| 		Username: username,
 | ||
| 		Password: hashedPassword, // 存储加密后的密码
 | ||
| 		Salt:     salt,           // 存储盐值(用于后续验证)
 | ||
| 		Email:    email,
 | ||
| 		Nickname: nickname,
 | ||
| 		Avatar:   avatar,
 | ||
| 	}
 | ||
| 
 | ||
| 	// 5. 插入数据库
 | ||
| 	o := orm.NewOrm()
 | ||
| 	_, err = o.Insert(user)
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("数据库插入失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 6. 返回新创建的用户对象
 | ||
| 	return user, nil
 | ||
| }
 | ||
| 
 | ||
| // UpdateUser 更新用户信息(模型层方法)
 | ||
| func UpdateUser(id int, username, email, nickname, avatar string) (*User, error) {
 | ||
| 	// 1. 根据ID查询用户是否存在
 | ||
| 	o := orm.NewOrm()
 | ||
| 	user := &User{Id: id}
 | ||
| 	err := o.Read(user)
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("用户不存在: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 2. 仅更新非空字段(避免覆盖原有值)
 | ||
| 	if username != "" {
 | ||
| 		// 若更新用户名,需检查新用户名是否已被占用
 | ||
| 		existingUser, _ := GetUserByUsername(username)
 | ||
| 		if existingUser != nil && existingUser.Id != id {
 | ||
| 			return nil, fmt.Errorf("用户名已被占用")
 | ||
| 		}
 | ||
| 		user.Username = username
 | ||
| 	}
 | ||
| 	if email != "" {
 | ||
| 		user.Email = email
 | ||
| 	}
 | ||
| 	if nickname != "" {
 | ||
| 		user.Nickname = nickname
 | ||
| 	}
 | ||
| 	if avatar != "" {
 | ||
| 		user.Avatar = avatar
 | ||
| 	}
 | ||
| 
 | ||
| 	// 3. 执行数据库更新
 | ||
| 	_, err = o.Update(user)
 | ||
| 	if err != nil {
 | ||
| 		return nil, fmt.Errorf("数据库更新失败: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	return user, nil
 | ||
| }
 | ||
| 
 | ||
| // DeleteUser 根据ID删除用户(模型层方法)
 | ||
| func DeleteUser(id int) error {
 | ||
| 	o := orm.NewOrm()
 | ||
| 	// 先查询用户是否存在
 | ||
| 	user := &User{Id: id}
 | ||
| 	err := o.Read(user)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("用户不存在: %v", err)
 | ||
| 	}
 | ||
| 
 | ||
| 	// 执行删除操作
 | ||
| 	_, err = o.Delete(user)
 | ||
| 	if err != nil {
 | ||
| 		return fmt.Errorf("数据库删除失败: %v", err)
 | ||
| 	}
 | ||
| 	return nil
 | ||
| }
 | ||
| 
 | ||
| // Init 初始化数据库
 | ||
| func Init() {
 | ||
| 	orm.RegisterModel(new(User))
 | ||
| 	orm.RegisterModel(new(Menu))
 | ||
| 	orm.RegisterModel(new(ProgramCategory))
 | ||
| 	orm.RegisterModel(new(ProgramInfo))
 | ||
| 	orm.RegisterModel(new(FileInfo))
 | ||
| 
 | ||
| 	ormConfig, err := beego.AppConfig.String("orm")
 | ||
| 	if err != nil {
 | ||
| 		panic("无法获取orm配置: " + err.Error())
 | ||
| 	}
 | ||
| 
 | ||
| 	if ormConfig == "mysql" {
 | ||
| 		user, err1 := beego.AppConfig.String("mysqluser")
 | ||
| 		pass, err2 := beego.AppConfig.String("mysqlpass")
 | ||
| 		urls, err3 := beego.AppConfig.String("mysqlurls")
 | ||
| 		db, err4 := beego.AppConfig.String("mysqldb")
 | ||
| 		if err1 != nil || err2 != nil || err3 != nil || err4 != nil {
 | ||
| 			panic("数据库配置错误")
 | ||
| 		}
 | ||
| 
 | ||
| 		// 构建连接字符串
 | ||
| 		dsn := user + ":" + pass + "@tcp(" + urls + ")/" + db + "?charset=utf8mb4&parseTime=True&loc=Local"
 | ||
| 		fmt.Println("数据库连接字符串:", dsn)
 | ||
| 
 | ||
| 		// 注册数据库
 | ||
| 		err = orm.RegisterDataBase("default", "mysql", dsn)
 | ||
| 		if err != nil {
 | ||
| 			panic("数据库连接失败: " + err.Error())
 | ||
| 		}
 | ||
| 
 | ||
| 		// 测试连接
 | ||
| 		// 注意:Beego v2 中不需要显式调用 Using,默认使用 "default"
 | ||
| 
 | ||
| 		fmt.Println("数据库连接成功!")
 | ||
| 	}
 | ||
| }
 |