# XiYan MCP Server
# Official
# - src
# - xiyan_mcp_server
import logging
import os
import yaml # 添加yaml库导入
from mysql.connector import connect, Error
from mcp.server import FastMCP
from mcp.types import TextContent
from .utils.db_config import DBConfig
from .database_env import DataBaseEnv
from .utils.db_source import HITLSQLDatabase
from .utils.db_util import init_db_conn
from .utils.file_util import extract_sql_from_qwen
from .utils.llm_util import call_dashscope
# Single FastMCP application instance; the @mcp.resource / @mcp.tool decorators
# in this module attach their handlers to it.
mcp = FastMCP("xiyan")
# Configure root logging at INFO with a "time - logger - level - message" layout.
# NOTE(review): logging.basicConfig does nothing when the root logger already
# has handlers; if constructing FastMCP above installs handlers, this call may
# have no effect -- confirm the intended ordering.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-wide logger used by the functions in this file.
logger = logging.getLogger("xiyan_mcp_server")
def get_yml_config():
    """Load the server configuration from a YAML file.

    The path is taken from the ``YML`` environment variable; when that is not
    set, ``config_demo.yml`` in this module's directory is used.

    Returns:
        The parsed document exactly as returned by ``yaml.safe_load``.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        yaml.YAMLError: If the file cannot be parsed as YAML.
    """
    config_path = os.getenv("YML", os.path.join(os.path.dirname(__file__), "config_demo.yml"))
    try:
        # Decode as UTF-8 explicitly: without it the platform's locale encoding
        # is used, so a config with non-ASCII values can fail or be garbled on
        # machines whose default encoding is not UTF-8.
        with open(config_path, 'r', encoding='utf-8') as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        logger.error(f"Configuration file {config_path} not found.")
        raise
    except yaml.YAMLError as exc:
        logger.error(f"Error parsing configuration file {config_path}: {exc}")
        raise
def get_xiyan_config(db_config):
    """Build a ``DBConfig`` from a database settings mapping.

    Args:
        db_config: Mapping that provides the ``database``, ``user``,
            ``password``, ``host`` and ``port`` keys.

    Returns:
        A ``DBConfig`` populated from those keys; the dialect is always
        ``'mysql'``.
    """
    return DBConfig(
        dialect='mysql',
        db_name=db_config['database'],
        user_name=db_config['user'],
        db_pwd=db_config['password'],
        db_host=db_config['host'],
        port=db_config['port'],
    )
# Load the configuration once at import time; get_yml_config re-raises on a
# missing or unparsable file, so such a failure aborts the import.
global_config = get_yml_config()
# Model settings; the 'name', 'key' and 'url' entries are passed to call_dashscope.
model_config = global_config['model']
# Database connection settings, and the DBConfig derived from them.
global_db_config = global_config['database']
global_xiyan_db_config = get_xiyan_config(global_db_config)
# Schema resource, registered under "mysql://<configured database name>".
# The module-level name read_resource is rebound by the definition that follows;
# this handler presumably stays reachable only through the decorator's
# registration.
# NOTE(review): deliberately left without a docstring -- FastMCP appears to use
# a handler's docstring as the resource description, so adding one would change
# what clients see. Confirm before documenting it that way.
@mcp.resource('mysql://'+global_db_config['database'])
async def read_resource() -> str:
    # A new connection object and schema wrapper are built on every read.
    schema_source = HITLSQLDatabase(init_db_conn(global_xiyan_db_config))
    # Return whatever the wrapper's mschema renders via to_mschema().
    return schema_source.mschema.to_mschema()
def _quote_mysql_identifier(name: str) -> str:
    """Quote *name* so it can be embedded in SQL as a MySQL identifier.

    Each dot-separated part is wrapped in backticks and any backtick inside a
    part is doubled, so the value cannot terminate the identifier and inject
    further SQL. ``db.table`` becomes ```db`.`table```.
    """
    return ".".join("`" + part.replace("`", "``") + "`" for part in name.split("."))


# Table resource: returns up to 100 rows of the requested table as text, one
# comma-joined line per row, preceded by a comma-joined header of column names.
# Raises RuntimeError (chained to the connector's Error) on any database error.
# The docstring is kept to its original one-liner because FastMCP uses it as
# the description served to clients.
@mcp.resource("mysql://{table_name}")
async def read_resource(table_name) -> str:
    """Read table contents."""
    # table_name comes from the resource URI, i.e. from the client. Identifiers
    # cannot be bound as query parameters, so quote it instead of interpolating
    # it raw. NOTE: a name that is already backtick-quoted by the caller will
    # now be treated literally.
    quoted_table = _quote_mysql_identifier(str(table_name))
    try:
        with connect(**global_db_config) as conn, conn.cursor() as cursor:
            cursor.execute(f"SELECT * FROM {quoted_table} LIMIT 100")
            header = ",".join(desc[0] for desc in cursor.description)
            lines = [",".join(map(str, row)) for row in cursor.fetchall()]
            return "\n".join([header, *lines])
    except Error as e:
        # Keep the original exception type for callers, but preserve the cause.
        raise RuntimeError(f"Database error: {str(e)}") from e
def sql_gen_and_execute(db_env, query: str):
    """Turn a natural-language question into SQL (text-to-SQL) and execute it.

    The model is prompted with the database schema and the question. The SQL
    extracted from its reply is executed; if that execution reports failure,
    ``sql_fix`` is asked to repair the statement, up to three times.

    Args:
        db_env: Environment object providing ``dialect``, ``mschema_str`` and
            ``database`` (which offers ``fetch``, ``fetch_truncated`` and
            ``trunc_result_to_markdown``).
        query: Natural-language question to answer from the database, e.g.
            "What were the monthly sales of each Cayenne dealer in 2024?"

    Returns:
        The stripped markdown rendering of the result fetched with
        ``max_rows=100``, or the text of the exception if the model call,
        SQL extraction or execution raised.
    """
    # Prompt (Chinese): "You are a {dialect} data-analysis expert; write correct
    # SQL for the user's question from the schema, fenced by ```sql and ```."
    # The opening fence previously had only two backticks, which told the model
    # to emit a malformed code fence and disagreed with the sql_fix prompt.
    prompt = f"""你现在是一名{db_env.dialect}数据分析专家,你的任务是根据参考的数据库schema和用户的问题,编写正确的SQL来回答用户的问题,生成的SQL用```sql 和```包围起来。
【数据库schema】
{db_env.mschema_str}
【问题】
{query}
"""
    messages = [
        {"role": "system", "content": prompt},
        {"role": "user", "content": f"用户的问题是: {query}"}
    ]
    param = {"model": model_config['name'], "messages": messages,"key":model_config['key'],"url":model_config['url']}
    try:
        response = call_dashscope(**param)
        content = response.choices[0].message.content
        sql_query = extract_sql_from_qwen(content)
        status, res = db_env.database.fetch(sql_query)
        if not status:
            # On failure `res` is handed to sql_fix as the error information.
            for _ in range(3):
                sql_query = sql_fix(db_env.dialect, db_env.mschema_str, query, sql_query, res)
                status, res = db_env.database.fetch(sql_query)
                if status:
                    break
        # NOTE(review): if all repair attempts fail, the last (still failing)
        # statement is executed again here; whatever fetch_truncated does with
        # it is what the caller receives -- confirm this is intended.
        sql_res = db_env.database.fetch_truncated(sql_query,max_rows=100)
        markdown_res = db_env.database.trunc_result_to_markdown(sql_res)
        logger.info(f"SQL query: {sql_query}\nSQL result: {sql_res}")
        return markdown_res.strip()
    except Exception as e:
        # Best effort by design: the error text is returned to the caller, but
        # record the traceback so failures are not invisible in the logs.
        logger.exception("SQL generation or execution failed")
        return str(e)
def sql_fix(dialect: str, mschema: str, query: str, sql_query: str, error_info: str):
    """Ask the model to repair a SQL statement whose execution failed.

    Args:
        dialect: SQL dialect name interpolated into the system prompt.
        mschema: Schema text shown to the model.
        query: The user's original question.
        sql_query: The SQL statement to be checked and repaired.
        error_info: Error information returned when ``sql_query`` was executed.

    Returns:
        The SQL that ``extract_sql_from_qwen`` pulls out of the model's reply.
    """
    # System prompt (Chinese): "You are a {dialect} data-analysis expert. Given
    # the question, the schema, the SQL under review and the syntax error the
    # database returned, fix only the syntax error, do not change the SQL's
    # logic, and wrap the result in ```sql ... ```."
    system_template = '''现在你是一个{dialect}数据分析专家,需要阅读一个客户的问题,参考的数据库schema,该问题对应的待检查SQL,以及执行该SQL时数据库返回的语法错误,请你仅针对其中的语法错误进行修复,输出修复后的SQL。
注意:
1、仅修复语法错误,不允许改变SQL的逻辑。
2、生成的SQL用```sql 和```包围起来。
【数据库schema】
{schema}
'''
    # User prompt (Chinese) sections: question, SQL to check, error information.
    user_template = '''【问题】
{question}
【待检查SQL】
{sql}
【错误信息】
{sql_res}'''
    conversation = [
        {"role": "system", "content": system_template.format(dialect=dialect, schema=mschema)},
        {"role": "user", "content": user_template.format(question=query, sql=sql_query, sql_res=error_info)},
    ]
    reply = call_dashscope(
        model=model_config['name'],
        messages=conversation,
        key=model_config['key'],
        url=model_config['url'],
    )
    return extract_sql_from_qwen(reply.choices[0].message.content)
def call_xiyan(query: str)-> str:
    """Answer a natural-language question against the configured database.

    Args:
        query: The question in natural language.

    Returns:
        ``str()`` of what ``sql_gen_and_execute`` returns, or a "database
        connection failed" message (in Chinese) followed by the error text
        when the database objects could not be created.
    """
    logger.info(f"Calling tool with arguments: {query}")
    try:
        db_source = HITLSQLDatabase(init_db_conn(global_xiyan_db_config))
    except Exception as conn_err:
        # The literal means "database connection failed"; it is returned to the
        # caller rather than raised.
        return "数据库连接失败"+str(conn_err)
    logger.info("Calling xiyan")
    answer = sql_gen_and_execute(DataBaseEnv(db_source), query)
    return str(answer)
# MCP tool entry point. FastMCP uses this function's docstring as the tool
# description served to clients, so the docstring is client-facing text; only
# its spelling ("natual" -> "natural") is corrected here.
@mcp.tool()
def get_data_via_natural_language(query: str)-> list[TextContent]:
    """Fetch the data from database through a natural language query

    Args:
        query: The query in natural language
    """
    # Delegate to call_xiyan and wrap its text in a single TextContent item.
    answer = call_xiyan(query)
    return [TextContent(type="text", text=answer)]
def main():
    """Run the module-level FastMCP application via ``mcp.run()``."""
    mcp.run()


# Call main() when this module is executed as a script.
if __name__ == "__main__":
    main()