package middleware import ( "bytes" "fmt" "io" "server/models" "server/services" "strconv" "strings" "time" "github.com/beego/beego/v2/server/web/context" ) // OperationLogMiddleware 操作日志中间件 - 记录所有接口的调用记录 func OperationLogMiddleware(ctx *context.Context) { // 跳过静态资源和内部路由 url := ctx.Input.URL() if shouldSkipLogging(url) { return } method := ctx.Input.Method() // 获取用户信息和租户信息(由 JWT 中间件设置在 Input.Data 中) userId := 0 tenantId := 0 username := "" userType := "" // 用户类型:user(平台用户) 或 employee(租户员工) if v := ctx.Input.GetData("userId"); v != nil { if id, ok := v.(int); ok { userId = id } } if v := ctx.Input.GetData("tenantId"); v != nil { if id, ok := v.(int); ok { tenantId = id } } if v := ctx.Input.GetData("username"); v != nil { if s, ok := v.(string); ok { username = s } } if v := ctx.Input.GetData("userType"); v != nil { if s, ok := v.(string); ok { userType = s } } // 用户信息补全 if username == "" { username = "anonymous" } // 读取请求体(对于有请求体的方法) var requestBody string if method == "POST" || method == "PUT" || method == "PATCH" { body, err := io.ReadAll(ctx.Request.Body) if err == nil && len(body) > 0 { requestBody = string(body) // 重置请求体,使其可以被后续处理 ctx.Request.Body = io.NopCloser(bytes.NewBuffer(body)) } } startTime := time.Now() ipAddress := ctx.Input.IP() userAgent := ctx.Input.Header("User-Agent") queryString := ctx.Request.URL.RawQuery // 使用延迟函数来记录操作 defer func() { duration := time.Since(startTime) // 解析操作相关信息 operation := parseOperationType(method, url) module := parseModule(url) resourceType := parseResourceType(url) resourceId := parseResourceId(url) // 为所有接口都记录日志 log := &models.OperationLog{ TenantId: tenantId, UserId: userId, Username: username, Module: module, ResourceType: resourceType, Operation: operation, IpAddress: ipAddress, UserAgent: userAgent, RequestMethod: method, RequestUrl: url, Status: 1, // 默认成功 Duration: int(duration.Milliseconds()), CreateTime: time.Now(), } // 设置资源ID if resourceId > 0 { log.ResourceId = &resourceId } // 记录请求信息到Description var description strings.Builder if requestBody != "" { description.WriteString("Request: " + truncateString(requestBody, 500)) } if queryString != "" { if description.Len() > 0 { description.WriteString(" | ") } description.WriteString("Query: " + queryString) } log.Description = description.String() // 如果有请求体,作为NewValue保存 if requestBody != "" { log.NewValue = requestBody } // 添加用户类型信息到Description if userType != "" { if log.Description != "" { log.Description += " | " } log.Description += "UserType: " + userType } // 调用服务层保存日志 if err := services.AddOperationLog(log); err != nil { fmt.Printf("Failed to save operation log: %v\n", err) } }() } // parseOperationType 根据HTTP方法解析操作类型 func parseOperationType(method, url string) string { switch method { case "POST": // 检查URL是否包含特定的操作关键字 if strings.Contains(url, "login") { return "LOGIN" } if strings.Contains(url, "logout") { return "LOGOUT" } if strings.Contains(url, "add") || strings.Contains(url, "create") { return "CREATE" } return "CREATE" case "PUT", "PATCH": return "UPDATE" case "DELETE": return "DELETE" default: return "READ" } } // parseResourceType 根据URL解析资源类型 func parseResourceType(url string) string { parts := strings.Split(strings.TrimPrefix(url, "/api/"), "/") if len(parts) > 0 { // 移除复数形式的s resourceType := strings.TrimSuffix(parts[0], "s") return resourceType } return "unknown" } // parseResourceId 从URL中提取资源ID func parseResourceId(url string) int { parts := strings.Split(strings.TrimPrefix(url, "/api/"), "/") if len(parts) >= 2 { // 尝试解析第二个部分为ID if id, err := strconv.Atoi(parts[1]); err == nil { return id } } return 0 } // parseModule 根据URL解析模块名称 func parseModule(url string) string { // 返回与 sys_operation_log.module 字段匹配的短code(例如 dict、user 等) parts := strings.Split(strings.TrimPrefix(url, "/api/"), "/") if len(parts) > 0 { return strings.ToLower(parts[0]) } return "unknown" } // shouldSkipLogging 判断是否需要跳过日志记录 func shouldSkipLogging(url string) bool { // 跳过静态资源、健康检查等 skipPatterns := []string{ "/static/", "/uploads/", "/favicon.ico", "/health", "/ping", } for _, pattern := range skipPatterns { if strings.HasPrefix(url, pattern) { return true } } return false } // truncateString 截断字符串到指定长度 func truncateString(s string, maxLen int) string { if len(s) <= maxLen { return s } return s[:maxLen] + "..." }