// go-platform/services/login_verify_code.go
//
// (source-viewer metadata: 247 lines, 6.4 KiB, Go)

package services
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"strings"
"sync"
"time"
"server/models"
)
// loginCodeItem is a single issued verification code as kept in loginCodeStore.
type loginCodeItem struct {
	Code, Channel string    // the issued code and the channel ("sms"/"email") it was requested for
	ExpiredAt     time.Time // instant after which the code is rejected as expired
}

// loginCodeStore maps codeKey(account, channel) to a loginCodeItem. It is a
// package-level in-process map, so codes do not survive a restart and are not
// shared between separate server processes.
var loginCodeStore sync.Map
// codeKey builds the loginCodeStore key for an account/channel pair.
// The account is trimmed and lower-cased (so lookups ignore case), the channel
// is only trimmed, and the two are joined with "|".
func codeKey(account, channel string) string {
	normalizedAccount := strings.ToLower(strings.TrimSpace(account))
	normalizedChannel := strings.TrimSpace(channel)
	return strings.Join([]string{normalizedAccount, normalizedChannel}, "|")
}
// SendPlatformLoginCode issues a login verification code for a platform admin
// account and stores it in loginCodeStore for five minutes.
//
// account and channel are whitespace-trimmed. channel must be "sms" or "email".
// The admin user is looked up by account. The user must be enabled (Status != 0)
// and must have a non-blank phone (sms) or email (email) bound.
// Validation failures are returned as user-facing errors.
//
// NOTE(review): nothing here delivers the code, and it is not returned to the
// caller either. Confirm whether delivery is missing or happens elsewhere
// by reading loginCodeStore.
func SendPlatformLoginCode(account, channel string) error {
	account, channel = strings.TrimSpace(account), strings.TrimSpace(channel)
	if account == "" {
		return errors.New("账号不能为空")
	}
	switch channel {
	case "sms", "email":
	default:
		return errors.New("仅支持短信或邮箱验证码")
	}

	var admin models.AdminUser
	if err := models.Orm.QueryTable(new(models.AdminUser)).Filter("account", account).One(&admin); err != nil {
		return errors.New("用户不存在")
	}
	if admin.Status == 0 {
		return errors.New("账号已禁用")
	}

	// The chosen channel must have a usable destination bound to the account.
	switch channel {
	case "sms":
		if admin.Phone == nil || strings.TrimSpace(*admin.Phone) == "" {
			return errors.New("该账号未绑定手机号")
		}
	case "email":
		if admin.Email == nil || strings.TrimSpace(*admin.Email) == "" {
			return errors.New("该账号未绑定邮箱")
		}
	}

	// NOTE(review): reseeding math/rand's global source with the current time
	// makes these 6-digit codes guessable. crypto/rand is the appropriate
	// generator for one-time codes.
	rand.Seed(time.Now().UnixNano())
	code := fmt.Sprintf("%06d", rand.Intn(1000000))

	loginCodeStore.Store(codeKey(account, channel), loginCodeItem{
		Code:      code,
		Channel:   channel,
		ExpiredAt: time.Now().Add(5 * time.Minute),
	})
	return nil
}
func VerifyPlatformLoginCode(account, channel, code string) error {
account = strings.TrimSpace(account)
channel = strings.TrimSpace(channel)
code = strings.TrimSpace(code)
if account == "" || code == "" {
return errors.New("验证码不能为空")
}
val, ok := loginCodeStore.Load(codeKey(account, channel))
if !ok {
return errors.New("验证码不存在或已失效")
}
item, ok := val.(loginCodeItem)
if !ok {
return errors.New("验证码状态异常")
}
if time.Now().After(item.ExpiredAt) {
loginCodeStore.Delete(codeKey(account, channel))
return errors.New("验证码已过期")
}
if item.Code != code {
return errors.New("验证码错误")
}
loginCodeStore.Delete(codeKey(account, channel))
return nil
}
// SendBackendLoginCode issues a login verification code for a tenant user and
// stores it under codeKey(tenantName+"#"+account, channel) for five minutes.
//
// tenantName, account and channel are whitespace-trimmed. channel must be "sms"
// or "email", and account is interpreted as the phone number or the email address
// respectively. The tenant is looked up by tenant_name. The user is looked up
// within that tenant and must be enabled (Status != 0). For sms, the code is
// handed to enqueueSMSTaskForLogin. Errors returned are user-facing.
//
// NOTE(review): the email branch only validates the binding. No email is sent
// here, so that code is stored but never delivered. Confirm.
func SendBackendLoginCode(tenantName, account, channel string) error {
	tenantName, account, channel = strings.TrimSpace(tenantName), strings.TrimSpace(account), strings.TrimSpace(channel)
	if tenantName == "" || account == "" {
		return errors.New("租户名称和账号不能为空")
	}
	switch channel {
	case "sms", "email":
	default:
		return errors.New("仅支持短信或邮箱验证码")
	}

	var tenant models.SystemTenant
	if err := models.Orm.QueryTable(new(models.SystemTenant)).Filter("tenant_name", tenantName).One(&tenant); err != nil {
		return errors.New("租户不存在")
	}

	// NOTE(review): reseeding math/rand's global source with the current time
	// makes these codes guessable. crypto/rand is the appropriate generator
	// for one-time codes.
	rand.Seed(time.Now().UnixNano())
	code := fmt.Sprintf("%06d", rand.Intn(1000000))

	// findTenantUser loads the user of this tenant whose column `field` equals value.
	findTenantUser := func(field, value string) (models.SystemTenantUser, error) {
		var found models.SystemTenantUser
		err := models.Orm.QueryTable(new(models.SystemTenantUser)).
			Filter("tid", tenant.ID).
			Filter(field, value).
			One(&found)
		return found, err
	}

	switch channel {
	case "sms":
		user, err := findTenantUser("phone", account)
		if err != nil {
			return errors.New("该手机号非当前企业绑定号码,请重试")
		}
		if user.Status == 0 {
			return errors.New("账号已禁用")
		}
		if user.Phone == nil || strings.TrimSpace(*user.Phone) == "" {
			return errors.New("该手机号非当前企业绑定号码,请重试")
		}
		if err := enqueueSMSTaskForLogin(tenant.ID, account, "短信验证码:"+code, code); err != nil {
			return errors.New("短信发送失败,请重试")
		}
	case "email":
		user, err := findTenantUser("email", account)
		if err != nil {
			return errors.New("该账号未绑定邮箱")
		}
		if user.Status == 0 {
			return errors.New("账号已禁用")
		}
		if user.Email == nil || strings.TrimSpace(*user.Email) == "" {
			return errors.New("该账号未绑定邮箱")
		}
	}

	loginCodeStore.Store(codeKey(tenantName+"#"+account, channel), loginCodeItem{
		Code:      code,
		Channel:   channel,
		ExpiredAt: time.Now().Add(5 * time.Minute),
	})
	return nil
}
// VerifyBackendLoginCode checks a code issued by SendBackendLoginCode for the
// given tenant and account, consuming it on success via VerifyPlatformLoginCode.
//
// tenantName and account are trimmed before the store key is built, exactly as
// SendBackendLoginCode does. Otherwise surrounding whitespace would end up in
// the middle of the combined key (e.g. "acme #bob") and never match. Both must
// be non-empty.
func VerifyBackendLoginCode(tenantName, account, channel, code string) error {
	tenantName = strings.TrimSpace(tenantName)
	account = strings.TrimSpace(account)
	if tenantName == "" || account == "" {
		return errors.New("租户名称和账号不能为空")
	}
	return VerifyPlatformLoginCode(tenantName+"#"+account, channel, code)
}
// getDefaultSystemSMSConfig returns the SMS gateway base URL and API key, both
// whitespace-trimmed, from the models.SystemSMS table.
//
// It prefers the enabled default config (is_default=1, status=1) with the
// highest weight, then the highest id. If that lookup fails for any reason, it
// falls back to the newest row whose config_code is "custom". Only the
// fallback's error is ever returned.
//
// NOTE(review): the fallback does not filter on status, so a disabled "custom"
// row can still be selected. Confirm this is intended.
func getDefaultSystemSMSConfig() (backendURL string, apiKey string, err error) {
	var cfg models.SystemSMS

	primaryErr := models.Orm.QueryTable(new(models.SystemSMS)).
		Filter("is_default", 1).
		Filter("status", 1).
		OrderBy("-weight", "-id").
		Limit(1).
		One(&cfg)
	if primaryErr != nil {
		fallbackErr := models.Orm.QueryTable(new(models.SystemSMS)).
			Filter("config_code", "custom").
			OrderBy("-id").
			Limit(1).
			One(&cfg)
		if fallbackErr != nil {
			return "", "", fallbackErr
		}
	}

	return strings.TrimSpace(cfg.ApiURL), strings.TrimSpace(cfg.ApiKey), nil
}
// loginSMSGatewayClient is shared by every enqueueSMSTaskForLogin call so that
// connections to the gateway can be reused instead of building a new client
// (and transport pool) per request.
var loginSMSGatewayClient = &http.Client{Timeout: 10 * time.Second}

// maxLoginSMSGatewayBodyBytes caps how much of the gateway's response body is
// read. The body is both echoed in errors and persisted as ReportRaw, so an
// unbounded read could pull an arbitrarily large payload into memory and the DB.
const maxLoginSMSGatewayBodyBytes = 64 << 10

// enqueueSMSTaskForLogin submits an outbound SMS task to the configured gateway
// and records it in yz_system_sms_tasks (models.SystemSMSTask).
//
// It POSTs {"phone", "content"} as JSON to <gateway>/api/v1/business/outbound-tasks,
// authenticating with the X-Api-Key header. A non-2xx response is returned as
// an error that includes the (possibly truncated) response body. After the
// gateway accepts the task, a task row with tid, phone, content, code and the
// response body is inserted. That insert is best-effort: its failure is not
// reported, because the message has already been handed to the gateway.
func enqueueSMSTaskForLogin(tid uint64, phone, content, code string) error {
	backendURL, apiKey, err := getDefaultSystemSMSConfig()
	if err != nil {
		return fmt.Errorf("loading sms gateway config: %w", err)
	}
	if backendURL == "" || apiKey == "" {
		return errors.New("短信网关未配置")
	}
	enqueueURL := strings.TrimRight(backendURL, "/") + "/api/v1/business/outbound-tasks"

	payload, err := json.Marshal(map[string]interface{}{
		"phone":   phone,
		"content": content,
	})
	if err != nil {
		return fmt.Errorf("encoding sms payload: %w", err)
	}

	req, err := http.NewRequest(http.MethodPost, enqueueURL, bytes.NewReader(payload))
	if err != nil {
		return fmt.Errorf("building sms gateway request: %w", err)
	}
	req.Header.Set("X-Api-Key", apiKey)
	req.Header.Set("Content-Type", "application/json; charset=utf-8")
	req.Header.Set("Accept", "application/json")

	resp, err := loginSMSGatewayClient.Do(req)
	if err != nil {
		return fmt.Errorf("calling sms gateway: %w", err)
	}
	defer resp.Body.Close()

	// A read error only loses the diagnostic body. The status code below still
	// decides success, as before.
	bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, maxLoginSMSGatewayBodyBytes))
	bodyStr := strings.TrimSpace(string(bodyBytes))
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return fmt.Errorf("gateway http status: %d, body: %s", resp.StatusCode, bodyStr)
	}

	now := time.Now()
	var report *string
	if bodyStr != "" {
		report = &bodyStr
	}
	task := &models.SystemSMSTask{
		Tid:        &tid,
		ApiKey:     apiKey,
		Phone:      phone,
		Content:    &content,
		Status:     3, // NOTE(review): meaning of status 3 is defined by models.SystemSMSTask — confirm
		Code:       code,
		ReportRaw:  report,
		CreateTime: &now,
		UpdateTime: &now,
	}
	// Best-effort bookkeeping: the gateway already accepted the SMS, so a failed
	// insert must not make the login-code send fail.
	_, _ = models.Orm.Insert(task)
	return nil
}