388 lines
8.9 KiB
Go
388 lines
8.9 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"os"
|
|
"strings"
|
|
|
|
_ "github.com/go-sql-driver/mysql"
|
|
)
|
|
|
|
// MCPRequest is one JSON-RPC 2.0 request, decoded from a single line of
// stdin. ID holds whatever JSON value the client sent (it decodes to nil when
// absent or null) and is echoed back in the reply. Params is kept as raw JSON
// so each method handler can decode its own parameter shape. JSONRPC is read
// but not validated.
type MCPRequest struct {
	JSONRPC string          `json:"jsonrpc"`
	ID      interface{}     `json:"id"`
	Method  string          `json:"method"`
	Params  json.RawMessage `json:"params"`
}
|
|
|
|
// Reply wire types. encoding/json emits struct fields in declaration order,
// so the field order below is also the key order on the wire.
type (
	// MCPResponse is a JSON-RPC 2.0 reply. Result and Error are each omitted
	// from the JSON when nil; ID is always emitted (as null when nil).
	MCPResponse struct {
		JSONRPC string      `json:"jsonrpc"`
		ID      interface{} `json:"id"`
		Result  interface{} `json:"result,omitempty"`
		Error   *MCPError   `json:"error,omitempty"`
	}

	// MCPError is the "error" member of a JSON-RPC reply. Data carries
	// optional detail text and is omitted when empty.
	MCPError struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Data    string `json:"data,omitempty"`
	}
)
|
|
|
|
// Parameter payloads for the "query" and "execute" methods. Both carry the
// SQL text plus optional positional arguments, which the handlers pass to
// database/sql as placeholder values.
type (
	// QueryParams holds the params of a "query" (SELECT-only) request.
	QueryParams struct {
		SQL  string        `json:"sql"`
		Args []interface{} `json:"args"`
	}

	// ExecuteParams holds the params of an "execute" request.
	ExecuteParams struct {
		SQL  string        `json:"sql"`
		Args []interface{} `json:"args"`
	}
)
|
|
|
|
// db is the process-wide MySQL connection pool. initDatabase assigns it, and
// the request handlers use it.
var (
	db *sql.DB
)
|
|
|
|
func main() {
|
|
// 初始化数据库
|
|
if err := initDatabase(); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Failed to initialize database: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer db.Close()
|
|
|
|
// 启动 MCP 服务器
|
|
reader := bufio.NewReader(os.Stdin)
|
|
for {
|
|
line, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
line = strings.TrimSpace(line)
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
// 解析请求
|
|
var req MCPRequest
|
|
if err := json.Unmarshal([]byte(line), &req); err != nil {
|
|
sendError(nil, -32700, "Parse error", err.Error())
|
|
continue
|
|
}
|
|
|
|
// 处理请求
|
|
handleRequest(&req)
|
|
}
|
|
}
|
|
|
|
func initDatabase() error {
|
|
// 从环境变量或默认值读取配置
|
|
user := getEnv("MYSQL_USER", "gotest")
|
|
pass := getEnv("MYSQL_PASS", "2nZhRdMPCNZrdzsd")
|
|
urls := getEnv("MYSQL_URLS", "212.64.112.158:3388")
|
|
dbName := getEnv("MYSQL_DB", "gotest")
|
|
|
|
dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&timeout=10s&readTimeout=30s&writeTimeout=30s",
|
|
user, pass, urls, dbName)
|
|
|
|
var err error
|
|
db, err = sql.Open("mysql", dsn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open database: %w", err)
|
|
}
|
|
|
|
// 测试连接
|
|
if err := db.Ping(); err != nil {
|
|
return fmt.Errorf("failed to ping database: %w", err)
|
|
}
|
|
|
|
// 配置连接池
|
|
db.SetMaxIdleConns(10)
|
|
db.SetMaxOpenConns(100)
|
|
|
|
fmt.Fprintf(os.Stderr, "Database connected successfully\n")
|
|
return nil
|
|
}
|
|
|
|
func handleRequest(req *MCPRequest) {
|
|
switch req.Method {
|
|
case "initialize":
|
|
handleInitialize(req)
|
|
case "query":
|
|
handleQuery(req)
|
|
case "execute":
|
|
handleExecute(req)
|
|
case "get_tables":
|
|
handleGetTables(req)
|
|
case "get_table_schema":
|
|
handleGetTableSchema(req)
|
|
default:
|
|
sendError(req.ID, -32601, "Method not found", fmt.Sprintf("Unknown method: %s", req.Method))
|
|
}
|
|
}
|
|
|
|
func handleInitialize(req *MCPRequest) {
|
|
result := map[string]interface{}{
|
|
"protocolVersion": "1.0",
|
|
"capabilities": map[string]interface{}{
|
|
"tools": []map[string]interface{}{
|
|
{
|
|
"name": "query",
|
|
"description": "Execute a SELECT query and return results",
|
|
"inputSchema": map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"sql": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "SQL SELECT query",
|
|
},
|
|
"args": map[string]interface{}{
|
|
"type": "array",
|
|
"description": "Query parameters (optional)",
|
|
},
|
|
},
|
|
"required": []string{"sql"},
|
|
},
|
|
},
|
|
{
|
|
"name": "execute",
|
|
"description": "Execute an INSERT, UPDATE, or DELETE query",
|
|
"inputSchema": map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"sql": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "SQL INSERT/UPDATE/DELETE query",
|
|
},
|
|
"args": map[string]interface{}{
|
|
"type": "array",
|
|
"description": "Query parameters (optional)",
|
|
},
|
|
},
|
|
"required": []string{"sql"},
|
|
},
|
|
},
|
|
{
|
|
"name": "get_tables",
|
|
"description": "Get all table names in the database",
|
|
"inputSchema": map[string]interface{}{
|
|
"type": "object",
|
|
},
|
|
},
|
|
{
|
|
"name": "get_table_schema",
|
|
"description": "Get the schema of a specific table",
|
|
"inputSchema": map[string]interface{}{
|
|
"type": "object",
|
|
"properties": map[string]interface{}{
|
|
"table": map[string]interface{}{
|
|
"type": "string",
|
|
"description": "Table name",
|
|
},
|
|
},
|
|
"required": []string{"table"},
|
|
},
|
|
},
|
|
},
|
|
},
|
|
"serverInfo": map[string]interface{}{
|
|
"name": "MySQL MCP Server",
|
|
"version": "1.0.0",
|
|
},
|
|
}
|
|
sendResponse(req.ID, result)
|
|
}
|
|
|
|
func handleQuery(req *MCPRequest) {
|
|
var params QueryParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
sendError(req.ID, -32602, "Invalid params", err.Error())
|
|
return
|
|
}
|
|
|
|
if params.SQL == "" {
|
|
sendError(req.ID, -32602, "Invalid params", "SQL query is required")
|
|
return
|
|
}
|
|
|
|
// 确保是 SELECT 查询
|
|
if !strings.HasPrefix(strings.ToUpper(strings.TrimSpace(params.SQL)), "SELECT") {
|
|
sendError(req.ID, -32602, "Invalid query", "Only SELECT queries are allowed")
|
|
return
|
|
}
|
|
|
|
rows, err := db.Query(params.SQL, params.Args...)
|
|
if err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
// 获取列名
|
|
columns, err := rows.Columns()
|
|
if err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
|
|
// 读取数据
|
|
var results []map[string]interface{}
|
|
for rows.Next() {
|
|
values := make([]interface{}, len(columns))
|
|
valuePtrs := make([]interface{}, len(columns))
|
|
for i := range columns {
|
|
valuePtrs[i] = &values[i]
|
|
}
|
|
|
|
if err := rows.Scan(valuePtrs...); err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
|
|
entry := make(map[string]interface{})
|
|
for i, col := range columns {
|
|
val := values[i]
|
|
b, ok := val.([]byte)
|
|
if ok {
|
|
entry[col] = string(b)
|
|
} else {
|
|
entry[col] = val
|
|
}
|
|
}
|
|
results = append(results, entry)
|
|
}
|
|
|
|
if err := rows.Err(); err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
|
|
sendResponse(req.ID, map[string]interface{}{
|
|
"rows": results,
|
|
"count": len(results),
|
|
})
|
|
}
|
|
|
|
func handleExecute(req *MCPRequest) {
|
|
var params ExecuteParams
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
sendError(req.ID, -32602, "Invalid params", err.Error())
|
|
return
|
|
}
|
|
|
|
if params.SQL == "" {
|
|
sendError(req.ID, -32602, "Invalid params", "SQL query is required")
|
|
return
|
|
}
|
|
|
|
result, err := db.Exec(params.SQL, params.Args...)
|
|
if err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
|
|
lastID, _ := result.LastInsertId()
|
|
rowsAffected, _ := result.RowsAffected()
|
|
|
|
sendResponse(req.ID, map[string]interface{}{
|
|
"lastInsertId": lastID,
|
|
"rowsAffected": rowsAffected,
|
|
})
|
|
}
|
|
|
|
func handleGetTables(req *MCPRequest) {
|
|
rows, err := db.Query("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = DATABASE()")
|
|
if err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tables []string
|
|
for rows.Next() {
|
|
var tableName string
|
|
if err := rows.Scan(&tableName); err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
tables = append(tables, tableName)
|
|
}
|
|
|
|
sendResponse(req.ID, map[string]interface{}{
|
|
"tables": tables,
|
|
})
|
|
}
|
|
|
|
func handleGetTableSchema(req *MCPRequest) {
|
|
var params struct {
|
|
Table string `json:"table"`
|
|
}
|
|
if err := json.Unmarshal(req.Params, ¶ms); err != nil {
|
|
sendError(req.ID, -32602, "Invalid params", err.Error())
|
|
return
|
|
}
|
|
|
|
rows, err := db.Query(fmt.Sprintf("DESCRIBE %s", params.Table))
|
|
if err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
defer rows.Close()
|
|
|
|
var schema []map[string]interface{}
|
|
for rows.Next() {
|
|
var field, typeStr, null, key, defaultVal, extra string
|
|
if err := rows.Scan(&field, &typeStr, &null, &key, &defaultVal, &extra); err != nil {
|
|
sendError(req.ID, -32603, "Database error", err.Error())
|
|
return
|
|
}
|
|
schema = append(schema, map[string]interface{}{
|
|
"field": field,
|
|
"type": typeStr,
|
|
"null": null,
|
|
"key": key,
|
|
"default": defaultVal,
|
|
"extra": extra,
|
|
})
|
|
}
|
|
|
|
sendResponse(req.ID, schema)
|
|
}
|
|
|
|
func sendResponse(id interface{}, result interface{}) {
|
|
response := MCPResponse{
|
|
JSONRPC: "2.0",
|
|
ID: id,
|
|
Result: result,
|
|
}
|
|
data, _ := json.Marshal(response)
|
|
fmt.Println(string(data))
|
|
}
|
|
|
|
func sendError(id interface{}, code int, message string, data string) {
|
|
response := MCPResponse{
|
|
JSONRPC: "2.0",
|
|
ID: id,
|
|
Error: &MCPError{
|
|
Code: code,
|
|
Message: message,
|
|
Data: data,
|
|
},
|
|
}
|
|
jsonData, _ := json.Marshal(response)
|
|
fmt.Println(string(jsonData))
|
|
}
|
|
|
|
// getEnv returns the value of the environment variable key. It returns
// defaultValue when the variable is unset or set to the empty string.
func getEnv(key, defaultValue string) string {
	value, found := os.LookupEnv(key)
	if !found || value == "" {
		return defaultValue
	}
	return value
}
|