305 lines
7.5 KiB
Go
305 lines
7.5 KiB
Go
package models
|
||
|
||
import (
|
||
"crypto/rand"
|
||
"encoding/base64"
|
||
"errors"
|
||
"fmt"
|
||
|
||
"github.com/beego/beego/v2/client/orm"
|
||
"golang.org/x/crypto/scrypt"
|
||
|
||
beego "github.com/beego/beego/v2/server/web"
|
||
_ "github.com/go-sql-driver/mysql"
|
||
)
|
||
|
||
// User 用户模型,增加Salt字段存储每个用户的唯一盐值
|
||
type User struct {
|
||
Id int `orm:"auto"`
|
||
Username string `orm:"unique"`
|
||
Password string // 存储加密后的密码
|
||
Salt string // 存储该用户的唯一盐值
|
||
Email string
|
||
Avatar string
|
||
Nickname string // 昵称字段,与数据库表中的列名匹配
|
||
}
|
||
|
||
// TableName 设置表名,默认为yz_users
|
||
func (u *User) TableName() string {
|
||
return "yz_users"
|
||
}
|
||
|
||
// generateSalt 生成随机盐值
|
||
func generateSalt() (string, error) {
|
||
salt := make([]byte, 16)
|
||
_, err := rand.Read(salt)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return base64.URLEncoding.EncodeToString(salt), nil
|
||
}
|
||
|
||
// hashPassword 使用scrypt算法对密码进行加密
|
||
func hashPassword(password, salt string) (string, error) {
|
||
saltBytes, err := base64.URLEncoding.DecodeString(salt)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
const (
|
||
N = 16384
|
||
r = 8
|
||
p = 1
|
||
)
|
||
hashBytes, err := scrypt.Key([]byte(password), saltBytes, N, r, p, 32)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return base64.URLEncoding.EncodeToString(hashBytes), nil
|
||
}
|
||
|
||
// verifyPassword 验证密码是否正确
|
||
func verifyPassword(password, salt, storedHash string) bool {
|
||
hash, err := hashPassword(password, salt)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
return hash == storedHash
|
||
}
|
||
|
||
// ResetPassword 重置用户密码
|
||
func ResetPassword(username, superPassword string) error {
|
||
if superPassword != "Lzq920103" {
|
||
return fmt.Errorf("超级密码错误")
|
||
}
|
||
|
||
o := orm.NewOrm()
|
||
user := &User{Username: username}
|
||
err := o.Read(user, "Username")
|
||
if err != nil {
|
||
return fmt.Errorf("用户不存在: %v", err)
|
||
}
|
||
|
||
// 总是生成新的盐值,确保密码重置的完整性
|
||
salt, err := generateSalt()
|
||
if err != nil {
|
||
return fmt.Errorf("生成盐值失败: %v", err)
|
||
}
|
||
user.Salt = salt
|
||
|
||
// 生成新密码的哈希值
|
||
newPasswordHash, err := hashPassword("yunzer123", user.Salt)
|
||
if err != nil {
|
||
return fmt.Errorf("密码加密失败: %v", err)
|
||
}
|
||
|
||
user.Password = newPasswordHash
|
||
_, err = o.Update(user, "Password", "Salt")
|
||
if err != nil {
|
||
return fmt.Errorf("更新密码失败: %v", err)
|
||
}
|
||
|
||
fmt.Printf("用户 %s 密码重置成功,新密码: yunzer123\n", username)
|
||
return nil
|
||
}
|
||
|
||
// ChangePassword 修改用户密码
|
||
func ChangePassword(username, oldPassword, newPassword string) error {
|
||
user, err := GetUserByUsername(username)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if !verifyPassword(oldPassword, user.Salt, user.Password) {
|
||
return errors.New("旧密码不正确")
|
||
}
|
||
newPasswordHash, err := hashPassword(newPassword, user.Salt)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
user.Password = newPasswordHash
|
||
o := orm.NewOrm()
|
||
_, err = o.Update(user, "Password")
|
||
if err != nil {
|
||
return err
|
||
}
|
||
return err
|
||
}
|
||
|
||
// FindAllUsers 获取所有用户
|
||
func FindAllUsers() []*User {
|
||
o := orm.NewOrm()
|
||
var users []*User
|
||
_, err := o.QueryTable("yz_users").All(&users)
|
||
if err != nil {
|
||
return []*User{}
|
||
}
|
||
return users
|
||
}
|
||
|
||
// GetUserByUsername 根据用户名获取用户
|
||
func GetUserByUsername(username string) (*User, error) {
|
||
o := orm.NewOrm()
|
||
user := &User{Username: username}
|
||
err := o.Read(user, "Username")
|
||
if err == orm.ErrNoRows {
|
||
return nil, errors.New("用户不存在")
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return user, nil
|
||
}
|
||
|
||
// ValidateUser 验证用户登录信息
|
||
func ValidateUser(username, password string) (*User, error) {
|
||
user, err := GetUserByUsername(username)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if verifyPassword(password, user.Salt, user.Password) {
|
||
return user, nil
|
||
}
|
||
return nil, errors.New("密码不正确")
|
||
}
|
||
|
||
// AddUser 向数据库添加新用户(模型层核心方法)
|
||
func AddUser(username, password, email, nickname, avatar string) (*User, error) {
|
||
// 1. 检查用户是否已存在(避免用户名重复)
|
||
existingUser, err := GetUserByUsername(username)
|
||
if err == nil && existingUser != nil {
|
||
return nil, fmt.Errorf("用户名已存在")
|
||
}
|
||
if err != nil && err.Error() != "用户不存在" { // 排除"用户不存在"的正常错误
|
||
return nil, fmt.Errorf("查询用户失败: %v", err)
|
||
}
|
||
|
||
// 2. 生成盐值(每个用户唯一)
|
||
salt, err := generateSalt()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("生成盐值失败: %v", err)
|
||
}
|
||
|
||
// 3. 加密密码(结合盐值)
|
||
hashedPassword, err := hashPassword(password, salt)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("密码加密失败: %v", err)
|
||
}
|
||
|
||
// 4. 构建用户对象
|
||
user := &User{
|
||
Username: username,
|
||
Password: hashedPassword, // 存储加密后的密码
|
||
Salt: salt, // 存储盐值(用于后续验证)
|
||
Email: email,
|
||
Nickname: nickname,
|
||
Avatar: avatar,
|
||
}
|
||
|
||
// 5. 插入数据库
|
||
o := orm.NewOrm()
|
||
_, err = o.Insert(user)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("数据库插入失败: %v", err)
|
||
}
|
||
|
||
// 6. 返回新创建的用户对象
|
||
return user, nil
|
||
}
|
||
|
||
// UpdateUser 更新用户信息(模型层方法)
|
||
func UpdateUser(id int, username, email, nickname, avatar string) (*User, error) {
|
||
// 1. 根据ID查询用户是否存在
|
||
o := orm.NewOrm()
|
||
user := &User{Id: id}
|
||
err := o.Read(user)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("用户不存在: %v", err)
|
||
}
|
||
|
||
// 2. 仅更新非空字段(避免覆盖原有值)
|
||
if username != "" {
|
||
// 若更新用户名,需检查新用户名是否已被占用
|
||
existingUser, _ := GetUserByUsername(username)
|
||
if existingUser != nil && existingUser.Id != id {
|
||
return nil, fmt.Errorf("用户名已被占用")
|
||
}
|
||
user.Username = username
|
||
}
|
||
if email != "" {
|
||
user.Email = email
|
||
}
|
||
if nickname != "" {
|
||
user.Nickname = nickname
|
||
}
|
||
if avatar != "" {
|
||
user.Avatar = avatar
|
||
}
|
||
|
||
// 3. 执行数据库更新
|
||
_, err = o.Update(user)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("数据库更新失败: %v", err)
|
||
}
|
||
|
||
return user, nil
|
||
}
|
||
|
||
// DeleteUser 根据ID删除用户(模型层方法)
|
||
func DeleteUser(id int) error {
|
||
o := orm.NewOrm()
|
||
// 先查询用户是否存在
|
||
user := &User{Id: id}
|
||
err := o.Read(user)
|
||
if err != nil {
|
||
return fmt.Errorf("用户不存在: %v", err)
|
||
}
|
||
|
||
// 执行删除操作
|
||
_, err = o.Delete(user)
|
||
if err != nil {
|
||
return fmt.Errorf("数据库删除失败: %v", err)
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// Init 初始化数据库
|
||
func Init() {
|
||
orm.RegisterModel(new(User))
|
||
orm.RegisterModel(new(Menu))
|
||
orm.RegisterModel(new(ProgramCategory))
|
||
orm.RegisterModel(new(ProgramInfo))
|
||
orm.RegisterModel(new(FileInfo))
|
||
orm.RegisterModel(new(Knowledge))
|
||
orm.RegisterModel(new(KnowledgeCategory))
|
||
orm.RegisterModel(new(KnowledgeTag))
|
||
|
||
ormConfig, err := beego.AppConfig.String("orm")
|
||
if err != nil {
|
||
panic("无法获取orm配置: " + err.Error())
|
||
}
|
||
|
||
if ormConfig == "mysql" {
|
||
user, err1 := beego.AppConfig.String("mysqluser")
|
||
pass, err2 := beego.AppConfig.String("mysqlpass")
|
||
urls, err3 := beego.AppConfig.String("mysqlurls")
|
||
db, err4 := beego.AppConfig.String("mysqldb")
|
||
if err1 != nil || err2 != nil || err3 != nil || err4 != nil {
|
||
panic("数据库配置错误")
|
||
}
|
||
|
||
// 构建连接字符串
|
||
dsn := user + ":" + pass + "@tcp(" + urls + ")/" + db + "?charset=utf8mb4&parseTime=True&loc=Local"
|
||
fmt.Println("数据库连接字符串:", dsn)
|
||
|
||
// 注册数据库
|
||
err = orm.RegisterDataBase("default", "mysql", dsn)
|
||
if err != nil {
|
||
panic("数据库连接失败: " + err.Error())
|
||
}
|
||
|
||
// 测试连接
|
||
// 注意:Beego v2 中不需要显式调用 Using,默认使用 "default"
|
||
|
||
fmt.Println("数据库连接成功!")
|
||
}
|
||
}
|