yunzer_go/server/mcp-server/main.go

388 lines
8.9 KiB
Go

package main
import (
"bufio"
"database/sql"
"encoding/json"
"fmt"
"os"
"strings"
_ "github.com/go-sql-driver/mysql"
)
// MCPRequest is a single request read from stdin: one line of JSON that main
// decodes and handleRequest dispatches.
type MCPRequest struct {
// JSONRPC receives the request's "jsonrpc" member.
JSONRPC string `json:"jsonrpc"`
// ID is the caller's request identifier; handlers pass it back through
// sendResponse / sendError so the reply carries the same id.
ID interface{} `json:"id"`
// Method is the name handleRequest switches on to pick a handler.
Method string `json:"method"`
// Params is left as raw JSON so that each handler can decode it into its
// own parameter type.
Params json.RawMessage `json:"params"`
}
// MCPResponse is the envelope written to stdout for a reply. sendResponse
// fills Result; sendError fills Error.
type MCPResponse struct {
// JSONRPC is set to "2.0" by both sendResponse and sendError.
JSONRPC string `json:"jsonrpc"`
// ID echoes the id of the request being answered.
ID interface{} `json:"id"`
// Result is the success payload. Because of omitempty it is left out of
// the JSON when the interface value is nil.
Result interface{} `json:"result,omitempty"`
// Error describes a failure. It is left out of the JSON when nil.
Error *MCPError `json:"error,omitempty"`
}
// MCPError is the error object carried in MCPResponse.Error.
type MCPError struct {
// Code is the numeric error code; this file uses -32700, -32601, -32602
// and -32603.
Code int `json:"code"`
// Message is a short, fixed description such as "Parse error".
Message string `json:"message"`
// Data carries extra detail (usually the underlying error text). It is
// left out of the JSON when empty.
Data string `json:"data,omitempty"`
}
// QueryParams is the parameter object decoded by handleQuery.
type QueryParams struct {
// SQL is the statement text; handleQuery rejects it unless it starts with
// SELECT.
SQL string `json:"sql"`
// Args are passed to db.Query as the statement's arguments.
Args []interface{} `json:"args"`
}
// ExecuteParams is the parameter object decoded by handleExecute.
type ExecuteParams struct {
// SQL is the statement text handed to db.Exec.
SQL string `json:"sql"`
// Args are passed to db.Exec as the statement's arguments.
Args []interface{} `json:"args"`
}
// db is the package-wide database handle. initDatabase assigns it and the
// request handlers run their statements through it.
var db *sql.DB
// main connects to the database and then serves requests from stdin until the
// input ends.
//
// Every request is one line of JSON. Blank lines are skipped, a line that does
// not decode is answered with a -32700 "Parse error", and everything else is
// handed to handleRequest. If the database cannot be initialised the error is
// printed to stderr and the process exits with status 1.
func main() {
	if err := initDatabase(); err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize database: %v\n", err)
		os.Exit(1)
	}
	defer db.Close()

	reader := bufio.NewReader(os.Stdin)
	for {
		line, readErr := reader.ReadString('\n')

		// ReadString hands back whatever it read together with the error, so
		// a final request that is not newline-terminated must be served
		// before the loop stops instead of being thrown away.
		if line = strings.TrimSpace(line); line != "" {
			var req MCPRequest
			if err := json.Unmarshal([]byte(line), &req); err != nil {
				sendError(nil, -32700, "Parse error", err.Error())
			} else {
				handleRequest(&req)
			}
		}

		// Any read error, end of input included, shuts the server down.
		if readErr != nil {
			break
		}
	}
}
// initDatabase opens the MySQL handle and stores it in the package variable db.
//
// Connection settings are read from MYSQL_USER, MYSQL_PASS, MYSQL_URLS
// (host:port) and MYSQL_DB; each falls back to a built-in default when the
// variable is unset or empty. The handle is published in db only after a
// successful Ping. If the Ping fails the handle is closed again and db is left
// as it was.
//
// It returns a wrapped error when the handle cannot be opened or the server
// cannot be reached.
//
// SECURITY: the fallback values below put a concrete host and password in the
// source. They are kept so that existing setups keep starting, but they should
// be moved out of the code and the credential rotated.
func initDatabase() error {
	user := getEnv("MYSQL_USER", "gotest")
	pass := getEnv("MYSQL_PASS", "2nZhRdMPCNZrdzsd")
	urls := getEnv("MYSQL_URLS", "212.64.112.158:3388")
	dbName := getEnv("MYSQL_DB", "gotest")

	dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=10s&readTimeout=30s&writeTimeout=30s",
		user, pass, urls, dbName)

	conn, err := sql.Open("mysql", dsn)
	if err != nil {
		return fmt.Errorf("failed to open database: %w", err)
	}

	// Pool limits are set before the first connection is made.
	conn.SetMaxIdleConns(10)
	conn.SetMaxOpenConns(100)

	// sql.Open does not necessarily connect; Ping verifies the server is
	// actually reachable with these settings.
	if err := conn.Ping(); err != nil {
		// Best effort: the ping failure is the error worth reporting.
		_ = conn.Close()
		return fmt.Errorf("failed to ping database: %w", err)
	}

	db = conn
	fmt.Fprintf(os.Stderr, "Database connected successfully\n")
	return nil
}
// handleRequest routes a decoded request to the handler registered for its
// method name. A method with no handler is answered with a -32601
// "Method not found" error that names the unknown method.
func handleRequest(req *MCPRequest) {
	handlers := map[string]func(*MCPRequest){
		"initialize":       handleInitialize,
		"query":            handleQuery,
		"execute":          handleExecute,
		"get_tables":       handleGetTables,
		"get_table_schema": handleGetTableSchema,
	}

	handle, known := handlers[req.Method]
	if !known {
		sendError(req.ID, -32601, "Method not found", fmt.Sprintf("Unknown method: %s", req.Method))
		return
	}
	handle(req)
}
// handleInitialize answers the "initialize" method with a fixed description of
// this server: protocol version "1.0", the four tools it offers (query,
// execute, get_tables, get_table_schema) with their input schemas, and the
// server name and version.
func handleInitialize(req *MCPRequest) {
	// Local shorthand only; the values built below are ordinary
	// map[string]interface{} maps.
	type object = map[string]interface{}

	// sqlTool describes a tool whose input is an SQL string plus optional
	// arguments; "query" and "execute" differ only in their texts.
	sqlTool := func(name, description, sqlDescription string) object {
		return object{
			"name":        name,
			"description": description,
			"inputSchema": object{
				"type": "object",
				"properties": object{
					"sql":  object{"type": "string", "description": sqlDescription},
					"args": object{"type": "array", "description": "Query parameters (optional)"},
				},
				"required": []string{"sql"},
			},
		}
	}

	tools := []object{
		sqlTool("query", "Execute a SELECT query and return results", "SQL SELECT query"),
		sqlTool("execute", "Execute an INSERT, UPDATE, or DELETE query", "SQL INSERT/UPDATE/DELETE query"),
		{
			"name":        "get_tables",
			"description": "Get all table names in the database",
			"inputSchema": object{"type": "object"},
		},
		{
			"name":        "get_table_schema",
			"description": "Get the schema of a specific table",
			"inputSchema": object{
				"type": "object",
				"properties": object{
					"table": object{"type": "string", "description": "Table name"},
				},
				"required": []string{"table"},
			},
		},
	}

	sendResponse(req.ID, object{
		"protocolVersion": "1.0",
		"capabilities":    object{"tools": tools},
		"serverInfo":      object{"name": "MySQL MCP Server", "version": "1.0.0"},
	})
}
// handleQuery runs a SELECT statement and replies with the rows it produced.
//
// req.Params must decode into QueryParams. The statement has to begin with
// SELECT (case-insensitive, surrounding whitespace ignored); anything else is
// rejected with -32602 before the database is touched. Args are passed to
// db.Query as the statement's arguments.
//
// On success the reply is {"rows": [...], "count": n}. Each row maps column
// name to value, with []byte values converted to strings. "rows" is always a
// JSON array, also when the query matches nothing. Database failures are
// reported as -32603.
//
// NOTE(review): the check is a prefix test only, so every statement that
// starts with SELECT is accepted. Confirm whether that is strict enough for
// the database account in use.
func handleQuery(req *MCPRequest) {
	var params QueryParams
	if err := json.Unmarshal(req.Params, &params); err != nil {
		sendError(req.ID, -32602, "Invalid params", err.Error())
		return
	}
	if params.SQL == "" {
		sendError(req.ID, -32602, "Invalid params", "SQL query is required")
		return
	}
	if !strings.HasPrefix(strings.ToUpper(strings.TrimSpace(params.SQL)), "SELECT") {
		sendError(req.ID, -32602, "Invalid query", "Only SELECT queries are allowed")
		return
	}

	rows, err := db.Query(params.SQL, params.Args...)
	if err != nil {
		sendError(req.ID, -32603, "Database error", err.Error())
		return
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		sendError(req.ID, -32603, "Database error", err.Error())
		return
	}

	// Start from an empty, non-nil slice: a nil slice would be encoded as
	// "rows": null instead of "rows": [] when nothing matches.
	results := make([]map[string]interface{}, 0)

	// One set of scan targets serves every row; Scan overwrites each cell and
	// the values are copied into a fresh map before the next row is read.
	values := make([]interface{}, len(columns))
	valuePtrs := make([]interface{}, len(columns))
	for i := range values {
		valuePtrs[i] = &values[i]
	}

	for rows.Next() {
		if err := rows.Scan(valuePtrs...); err != nil {
			sendError(req.ID, -32603, "Database error", err.Error())
			return
		}
		entry := make(map[string]interface{}, len(columns))
		for i, col := range columns {
			// Byte slices would be base64-encoded by encoding/json, so they
			// are turned into strings first.
			if b, ok := values[i].([]byte); ok {
				entry[col] = string(b)
			} else {
				entry[col] = values[i]
			}
		}
		results = append(results, entry)
	}
	if err := rows.Err(); err != nil {
		sendError(req.ID, -32603, "Database error", err.Error())
		return
	}

	sendResponse(req.ID, map[string]interface{}{
		"rows":  results,
		"count": len(results),
	})
}
// handleExecute runs a statement through db.Exec and replies with
// {"lastInsertId": ..., "rowsAffected": ...}.
//
// req.Params must decode into ExecuteParams and carry a non-empty sql string;
// otherwise the reply is a -32602 error. A failure reported by db.Exec is
// returned as -32603.
func handleExecute(req *MCPRequest) {
	params := ExecuteParams{}
	decodeErr := json.Unmarshal(req.Params, &params)
	switch {
	case decodeErr != nil:
		sendError(req.ID, -32602, "Invalid params", decodeErr.Error())
		return
	case params.SQL == "":
		sendError(req.ID, -32602, "Invalid params", "SQL query is required")
		return
	}

	outcome, execErr := db.Exec(params.SQL, params.Args...)
	if execErr != nil {
		sendError(req.ID, -32603, "Database error", execErr.Error())
		return
	}

	// The errors of both accessors are discarded; whatever numbers they
	// returned are reported as-is.
	insertID, _ := outcome.LastInsertId()
	affected, _ := outcome.RowsAffected()

	reply := map[string]interface{}{
		"lastInsertId": insertID,
		"rowsAffected": affected,
	}
	sendResponse(req.ID, reply)
}
// handleGetTables replies with {"tables": [...]}: the names of all tables in
// the schema the connection is currently using, read from
// INFORMATION_SCHEMA.TABLES.
//
// "tables" is always a JSON array, also when the schema has no tables. Any
// database failure, including one that interrupts the iteration, is reported
// as -32603 instead of returning a partial list.
func handleGetTables(req *MCPRequest) {
	rows, err := db.Query("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = DATABASE()")
	if err != nil {
		sendError(req.ID, -32603, "Database error", err.Error())
		return
	}
	defer rows.Close()

	// Non-nil so that an empty schema is encoded as [] rather than null.
	tables := make([]string, 0)
	for rows.Next() {
		var tableName string
		if err := rows.Scan(&tableName); err != nil {
			sendError(req.ID, -32603, "Database error", err.Error())
			return
		}
		tables = append(tables, tableName)
	}
	// Next returns false both at the end of the result set and on error;
	// Err tells the two apart.
	if err := rows.Err(); err != nil {
		sendError(req.ID, -32603, "Database error", err.Error())
		return
	}

	sendResponse(req.ID, map[string]interface{}{
		"tables": tables,
	})
}
// handleGetTableSchema replies with the column list of one table as reported
// by DESCRIBE: an array of objects with the keys field, type, null, key,
// default and extra.
//
// req.Params must be an object with a non-empty "table" member, optionally
// qualified as "schema.table"; otherwise the reply is a -32602 error. The name
// is quoted as an identifier before it is placed in the statement, so it cannot
// inject SQL. "default" is JSON null for columns whose default is NULL.
// Database failures are reported as -32603.
func handleGetTableSchema(req *MCPRequest) {
	var params struct {
		Table string `json:"table"`
	}
	if err := json.Unmarshal(req.Params, &params); err != nil {
		sendError(req.ID, -32602, "Invalid params", err.Error())
		return
	}
	table, ok := quoteTableName(params.Table)
	if !ok {
		sendError(req.ID, -32602, "Invalid params", "Table name is required")
		return
	}

	// An identifier cannot be sent as a statement argument, so the quoted
	// name is concatenated into the statement text.
	rows, err := db.Query("DESCRIBE " + table)
	if err != nil {
		sendError(req.ID, -32603, "Database error", err.Error())
		return
	}
	defer rows.Close()

	schema := make([]map[string]interface{}, 0)
	for rows.Next() {
		var field, typeStr, null, key, extra string
		// The Default column is NULL for every column without a default;
		// scanning NULL into a plain string makes Scan fail.
		var defaultVal sql.NullString
		if err := rows.Scan(&field, &typeStr, &null, &key, &defaultVal, &extra); err != nil {
			sendError(req.ID, -32603, "Database error", err.Error())
			return
		}

		var columnDefault interface{} // stays nil (JSON null) for a NULL default
		if defaultVal.Valid {
			columnDefault = defaultVal.String
		}

		schema = append(schema, map[string]interface{}{
			"field":   field,
			"type":    typeStr,
			"null":    null,
			"key":     key,
			"default": columnDefault,
			"extra":   extra,
		})
	}
	if err := rows.Err(); err != nil {
		sendError(req.ID, -32603, "Database error", err.Error())
		return
	}

	sendResponse(req.ID, schema)
}

// quoteTableName turns a caller-supplied table name into a backtick-quoted
// MySQL identifier.
//
// An optional schema qualifier is recognised at the first dot that is not
// inside backticks, and the two parts are quoted separately. A part that is
// already wrapped in backticks is unwrapped first, so "users", "`users`" and
// "app.users" are all accepted. Backticks inside a part are doubled. ok is
// false when the name, or either part of a qualified name, is empty.
func quoteTableName(name string) (quoted string, ok bool) {
	name = strings.TrimSpace(name)

	// Find the qualifier separator, ignoring dots inside a quoted section.
	// A doubled backtick toggles the flag twice and so leaves it unchanged.
	dot := -1
	inQuotes := false
	for i := 0; i < len(name) && dot < 0; i++ {
		switch name[i] {
		case '`':
			inQuotes = !inQuotes
		case '.':
			if !inQuotes {
				dot = i
			}
		}
	}
	parts := []string{name}
	if dot >= 0 {
		parts = []string{name[:dot], name[dot+1:]}
	}

	out := make([]string, 0, len(parts))
	for _, part := range parts {
		part = strings.TrimSpace(part)
		if len(part) >= 2 && strings.HasPrefix(part, "`") && strings.HasSuffix(part, "`") {
			// Already quoted by the caller: unwrap and undo the doubling so
			// the escaping below is applied exactly once.
			part = strings.ReplaceAll(part[1:len(part)-1], "``", "`")
		}
		if part == "" {
			return "", false
		}
		out = append(out, "`"+strings.ReplaceAll(part, "`", "``")+"`")
	}
	return strings.Join(out, "."), true
}
// sendResponse writes a success reply for request id to stdout as one line of
// JSON, with result as the payload.
//
// If result cannot be encoded, a -32603 "Internal error" reply carrying the
// encoder's message is sent instead, so the caller still gets an answer for
// this id rather than an empty line.
func sendResponse(id interface{}, result interface{}) {
	data, err := json.Marshal(MCPResponse{
		JSONRPC: "2.0",
		ID:      id,
		Result:  result,
	})
	if err != nil {
		sendError(id, -32603, "Internal error", err.Error())
		return
	}
	fmt.Println(string(data))
}
// sendError writes an error reply for request id to stdout as one line of
// JSON. code and message fill the error object; data is optional detail and is
// left out of the JSON when empty.
//
// Should the reply itself fail to encode, the failure is logged to stderr and
// a fixed internal-error reply with a null id is written, so that a line of
// valid JSON is always produced.
func sendError(id interface{}, code int, message string, data string) {
	payload, err := json.Marshal(MCPResponse{
		JSONRPC: "2.0",
		ID:      id,
		Error: &MCPError{
			Code:    code,
			Message: message,
			Data:    data,
		},
	})
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to encode error response: %v\n", err)
		// The id is the likely culprit, so the fallback does not repeat it.
		fmt.Println(`{"jsonrpc":"2.0","id":null,"error":{"code":-32603,"message":"Internal error"}}`)
		return
	}
	fmt.Println(string(payload))
}
// getEnv returns the value of the environment variable key, or defaultValue
// when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	fromEnv := os.Getenv(key)
	if fromEnv == "" {
		return defaultValue
	}
	return fromEnv
}