2025-10-28 16:08:40 +08:00

305 lines
7.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package models
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"github.com/beego/beego/v2/client/orm"
"golang.org/x/crypto/scrypt"
beego "github.com/beego/beego/v2/server/web"
_ "github.com/go-sql-driver/mysql"
)
// User 用户模型增加Salt字段存储每个用户的唯一盐值
type User struct {
Id int `orm:"auto"`
Username string `orm:"unique"`
Password string // 存储加密后的密码
Salt string // 存储该用户的唯一盐值
Email string
Avatar string
Nickname string // 昵称字段,与数据库表中的列名匹配
}
// TableName 设置表名默认为yz_users
func (u *User) TableName() string {
return "yz_users"
}
// generateSalt 生成随机盐值
func generateSalt() (string, error) {
salt := make([]byte, 16)
_, err := rand.Read(salt)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(salt), nil
}
// hashPassword 使用scrypt算法对密码进行加密
func hashPassword(password, salt string) (string, error) {
saltBytes, err := base64.URLEncoding.DecodeString(salt)
if err != nil {
return "", err
}
const (
N = 16384
r = 8
p = 1
)
hashBytes, err := scrypt.Key([]byte(password), saltBytes, N, r, p, 32)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(hashBytes), nil
}
// verifyPassword 验证密码是否正确
func verifyPassword(password, salt, storedHash string) bool {
hash, err := hashPassword(password, salt)
if err != nil {
return false
}
return hash == storedHash
}
// ResetPassword 重置用户密码
func ResetPassword(username, superPassword string) error {
if superPassword != "Lzq920103" {
return fmt.Errorf("超级密码错误")
}
o := orm.NewOrm()
user := &User{Username: username}
err := o.Read(user, "Username")
if err != nil {
return fmt.Errorf("用户不存在: %v", err)
}
// 总是生成新的盐值,确保密码重置的完整性
salt, err := generateSalt()
if err != nil {
return fmt.Errorf("生成盐值失败: %v", err)
}
user.Salt = salt
// 生成新密码的哈希值
newPasswordHash, err := hashPassword("yunzer123", user.Salt)
if err != nil {
return fmt.Errorf("密码加密失败: %v", err)
}
user.Password = newPasswordHash
_, err = o.Update(user, "Password", "Salt")
if err != nil {
return fmt.Errorf("更新密码失败: %v", err)
}
fmt.Printf("用户 %s 密码重置成功,新密码: yunzer123\n", username)
return nil
}
// ChangePassword 修改用户密码
func ChangePassword(username, oldPassword, newPassword string) error {
user, err := GetUserByUsername(username)
if err != nil {
return err
}
if !verifyPassword(oldPassword, user.Salt, user.Password) {
return errors.New("旧密码不正确")
}
newPasswordHash, err := hashPassword(newPassword, user.Salt)
if err != nil {
return err
}
user.Password = newPasswordHash
o := orm.NewOrm()
_, err = o.Update(user, "Password")
if err != nil {
return err
}
return err
}
// FindAllUsers 获取所有用户
func FindAllUsers() []*User {
o := orm.NewOrm()
var users []*User
_, err := o.QueryTable("yz_users").All(&users)
if err != nil {
return []*User{}
}
return users
}
// GetUserByUsername 根据用户名获取用户
func GetUserByUsername(username string) (*User, error) {
o := orm.NewOrm()
user := &User{Username: username}
err := o.Read(user, "Username")
if err == orm.ErrNoRows {
return nil, errors.New("用户不存在")
}
if err != nil {
return nil, err
}
return user, nil
}
// ValidateUser 验证用户登录信息
func ValidateUser(username, password string) (*User, error) {
user, err := GetUserByUsername(username)
if err != nil {
return nil, err
}
if verifyPassword(password, user.Salt, user.Password) {
return user, nil
}
return nil, errors.New("密码不正确")
}
// AddUser 向数据库添加新用户(模型层核心方法)
func AddUser(username, password, email, nickname, avatar string) (*User, error) {
// 1. 检查用户是否已存在(避免用户名重复)
existingUser, err := GetUserByUsername(username)
if err == nil && existingUser != nil {
return nil, fmt.Errorf("用户名已存在")
}
if err != nil && err.Error() != "用户不存在" { // 排除"用户不存在"的正常错误
return nil, fmt.Errorf("查询用户失败: %v", err)
}
// 2. 生成盐值(每个用户唯一)
salt, err := generateSalt()
if err != nil {
return nil, fmt.Errorf("生成盐值失败: %v", err)
}
// 3. 加密密码(结合盐值)
hashedPassword, err := hashPassword(password, salt)
if err != nil {
return nil, fmt.Errorf("密码加密失败: %v", err)
}
// 4. 构建用户对象
user := &User{
Username: username,
Password: hashedPassword, // 存储加密后的密码
Salt: salt, // 存储盐值(用于后续验证)
Email: email,
Nickname: nickname,
Avatar: avatar,
}
// 5. 插入数据库
o := orm.NewOrm()
_, err = o.Insert(user)
if err != nil {
return nil, fmt.Errorf("数据库插入失败: %v", err)
}
// 6. 返回新创建的用户对象
return user, nil
}
// UpdateUser 更新用户信息(模型层方法)
func UpdateUser(id int, username, email, nickname, avatar string) (*User, error) {
// 1. 根据ID查询用户是否存在
o := orm.NewOrm()
user := &User{Id: id}
err := o.Read(user)
if err != nil {
return nil, fmt.Errorf("用户不存在: %v", err)
}
// 2. 仅更新非空字段(避免覆盖原有值)
if username != "" {
// 若更新用户名,需检查新用户名是否已被占用
existingUser, _ := GetUserByUsername(username)
if existingUser != nil && existingUser.Id != id {
return nil, fmt.Errorf("用户名已被占用")
}
user.Username = username
}
if email != "" {
user.Email = email
}
if nickname != "" {
user.Nickname = nickname
}
if avatar != "" {
user.Avatar = avatar
}
// 3. 执行数据库更新
_, err = o.Update(user)
if err != nil {
return nil, fmt.Errorf("数据库更新失败: %v", err)
}
return user, nil
}
// DeleteUser 根据ID删除用户模型层方法
func DeleteUser(id int) error {
o := orm.NewOrm()
// 先查询用户是否存在
user := &User{Id: id}
err := o.Read(user)
if err != nil {
return fmt.Errorf("用户不存在: %v", err)
}
// 执行删除操作
_, err = o.Delete(user)
if err != nil {
return fmt.Errorf("数据库删除失败: %v", err)
}
return nil
}
// Init 初始化数据库
func Init() {
orm.RegisterModel(new(User))
orm.RegisterModel(new(Menu))
orm.RegisterModel(new(ProgramCategory))
orm.RegisterModel(new(ProgramInfo))
orm.RegisterModel(new(FileInfo))
orm.RegisterModel(new(Knowledge))
orm.RegisterModel(new(KnowledgeCategory))
orm.RegisterModel(new(KnowledgeTag))
ormConfig, err := beego.AppConfig.String("orm")
if err != nil {
panic("无法获取orm配置: " + err.Error())
}
if ormConfig == "mysql" {
user, err1 := beego.AppConfig.String("mysqluser")
pass, err2 := beego.AppConfig.String("mysqlpass")
urls, err3 := beego.AppConfig.String("mysqlurls")
db, err4 := beego.AppConfig.String("mysqldb")
if err1 != nil || err2 != nil || err3 != nil || err4 != nil {
panic("数据库配置错误")
}
// 构建连接字符串
dsn := user + ":" + pass + "@tcp(" + urls + ")/" + db + "?charset=utf8mb4&parseTime=True&loc=Local"
fmt.Println("数据库连接字符串:", dsn)
// 注册数据库
err = orm.RegisterDataBase("default", "mysql", dsn)
if err != nil {
panic("数据库连接失败: " + err.Error())
}
// 测试连接
// 注意Beego v2 中不需要显式调用 Using默认使用 "default"
fmt.Println("数据库连接成功!")
}
}