package main import ( "bufio" "database/sql" "encoding/json" "fmt" "os" "strings" _ "github.com/go-sql-driver/mysql" ) // MCPRequest 表示 MCP 请求 type MCPRequest struct { JSONRPC string `json:"jsonrpc"` ID interface{} `json:"id"` Method string `json:"method"` Params json.RawMessage `json:"params"` } // MCPResponse 表示 MCP 响应 type MCPResponse struct { JSONRPC string `json:"jsonrpc"` ID interface{} `json:"id"` Result interface{} `json:"result,omitempty"` Error *MCPError `json:"error,omitempty"` } // MCPError 表示 MCP 错误 type MCPError struct { Code int `json:"code"` Message string `json:"message"` Data string `json:"data,omitempty"` } // QueryParams 表示查询参数 type QueryParams struct { SQL string `json:"sql"` Args []interface{} `json:"args"` } // ExecuteParams 表示执行参数 type ExecuteParams struct { SQL string `json:"sql"` Args []interface{} `json:"args"` } var db *sql.DB func main() { // 初始化数据库 if err := initDatabase(); err != nil { fmt.Fprintf(os.Stderr, "Failed to initialize database: %v\n", err) os.Exit(1) } defer db.Close() // 启动 MCP 服务器 reader := bufio.NewReader(os.Stdin) for { line, err := reader.ReadString('\n') if err != nil { break } line = strings.TrimSpace(line) if line == "" { continue } // 解析请求 var req MCPRequest if err := json.Unmarshal([]byte(line), &req); err != nil { sendError(nil, -32700, "Parse error", err.Error()) continue } // 处理请求 handleRequest(&req) } } func initDatabase() error { // 从环境变量或默认值读取配置 user := getEnv("MYSQL_USER", "gotest") pass := getEnv("MYSQL_PASS", "2nZhRdMPCNZrdzsd") urls := getEnv("MYSQL_URLS", "212.64.112.158:3388") dbName := getEnv("MYSQL_DB", "gotest") dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=10s&readTimeout=30s&writeTimeout=30s", user, pass, urls, dbName) var err error db, err = sql.Open("mysql", dsn) if err != nil { return fmt.Errorf("failed to open database: %w", err) } // 测试连接 if err := db.Ping(); err != nil { return fmt.Errorf("failed to ping database: %w", err) } // 配置连接池 db.SetMaxIdleConns(10) db.SetMaxOpenConns(100) fmt.Fprintf(os.Stderr, "Database connected successfully\n") return nil } func handleRequest(req *MCPRequest) { switch req.Method { case "initialize": handleInitialize(req) case "query": handleQuery(req) case "execute": handleExecute(req) case "get_tables": handleGetTables(req) case "get_table_schema": handleGetTableSchema(req) default: sendError(req.ID, -32601, "Method not found", fmt.Sprintf("Unknown method: %s", req.Method)) } } func handleInitialize(req *MCPRequest) { result := map[string]interface{}{ "protocolVersion": "1.0", "capabilities": map[string]interface{}{ "tools": []map[string]interface{}{ { "name": "query", "description": "Execute a SELECT query and return results", "inputSchema": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "sql": map[string]interface{}{ "type": "string", "description": "SQL SELECT query", }, "args": map[string]interface{}{ "type": "array", "description": "Query parameters (optional)", }, }, "required": []string{"sql"}, }, }, { "name": "execute", "description": "Execute an INSERT, UPDATE, or DELETE query", "inputSchema": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "sql": map[string]interface{}{ "type": "string", "description": "SQL INSERT/UPDATE/DELETE query", }, "args": map[string]interface{}{ "type": "array", "description": "Query parameters (optional)", }, }, "required": []string{"sql"}, }, }, { "name": "get_tables", "description": "Get all table names in the database", "inputSchema": map[string]interface{}{ "type": "object", }, }, { "name": "get_table_schema", "description": "Get the schema of a specific table", "inputSchema": map[string]interface{}{ "type": "object", "properties": map[string]interface{}{ "table": map[string]interface{}{ "type": "string", "description": "Table name", }, }, "required": []string{"table"}, }, }, }, }, "serverInfo": map[string]interface{}{ "name": "MySQL MCP Server", "version": "1.0.0", }, } sendResponse(req.ID, result) } func handleQuery(req *MCPRequest) { var params QueryParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { sendError(req.ID, -32602, "Invalid params", err.Error()) return } if params.SQL == "" { sendError(req.ID, -32602, "Invalid params", "SQL query is required") return } // 确保是 SELECT 查询 if !strings.HasPrefix(strings.ToUpper(strings.TrimSpace(params.SQL)), "SELECT") { sendError(req.ID, -32602, "Invalid query", "Only SELECT queries are allowed") return } rows, err := db.Query(params.SQL, params.Args...) if err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } defer rows.Close() // 获取列名 columns, err := rows.Columns() if err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } // 读取数据 var results []map[string]interface{} for rows.Next() { values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range columns { valuePtrs[i] = &values[i] } if err := rows.Scan(valuePtrs...); err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } entry := make(map[string]interface{}) for i, col := range columns { val := values[i] b, ok := val.([]byte) if ok { entry[col] = string(b) } else { entry[col] = val } } results = append(results, entry) } if err := rows.Err(); err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } sendResponse(req.ID, map[string]interface{}{ "rows": results, "count": len(results), }) } func handleExecute(req *MCPRequest) { var params ExecuteParams if err := json.Unmarshal(req.Params, ¶ms); err != nil { sendError(req.ID, -32602, "Invalid params", err.Error()) return } if params.SQL == "" { sendError(req.ID, -32602, "Invalid params", "SQL query is required") return } result, err := db.Exec(params.SQL, params.Args...) if err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } lastID, _ := result.LastInsertId() rowsAffected, _ := result.RowsAffected() sendResponse(req.ID, map[string]interface{}{ "lastInsertId": lastID, "rowsAffected": rowsAffected, }) } func handleGetTables(req *MCPRequest) { rows, err := db.Query("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = DATABASE()") if err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } defer rows.Close() var tables []string for rows.Next() { var tableName string if err := rows.Scan(&tableName); err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } tables = append(tables, tableName) } sendResponse(req.ID, map[string]interface{}{ "tables": tables, }) } func handleGetTableSchema(req *MCPRequest) { var params struct { Table string `json:"table"` } if err := json.Unmarshal(req.Params, ¶ms); err != nil { sendError(req.ID, -32602, "Invalid params", err.Error()) return } rows, err := db.Query(fmt.Sprintf("DESCRIBE %s", params.Table)) if err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } defer rows.Close() var schema []map[string]interface{} for rows.Next() { var field, typeStr, null, key, defaultVal, extra string if err := rows.Scan(&field, &typeStr, &null, &key, &defaultVal, &extra); err != nil { sendError(req.ID, -32603, "Database error", err.Error()) return } schema = append(schema, map[string]interface{}{ "field": field, "type": typeStr, "null": null, "key": key, "default": defaultVal, "extra": extra, }) } sendResponse(req.ID, schema) } func sendResponse(id interface{}, result interface{}) { response := MCPResponse{ JSONRPC: "2.0", ID: id, Result: result, } data, _ := json.Marshal(response) fmt.Println(string(data)) } func sendError(id interface{}, code int, message string, data string) { response := MCPResponse{ JSONRPC: "2.0", ID: id, Error: &MCPError{ Code: code, Message: message, Data: data, }, } jsonData, _ := json.Marshal(response) fmt.Println(string(jsonData)) } func getEnv(key, defaultValue string) string { if value := os.Getenv(key); value != "" { return value } return defaultValue }