# main.py (2.67 kB)
import sys

import duckdb
import torch
from mcp.server.fastmcp import FastMCP

from common import duckdb_file, r_model, ja_tokens, v_model, v_tokenizer
# Server instance; tools are registered on it with the @mcp.tool() decorator.
# NOTE(review): stateless_http only matters for the streamable-HTTP transport,
# while main() calls mcp.run() with its default transport — confirm intent.
mcp = FastMCP(name="learning-react-mcp", stateless_http=True)
def fts_search(conn, query):
    """Run a BM25 full-text search over the ``sora_doc`` table.

    Args:
        conn: Open DuckDB connection with the ``fts`` extension loaded and an
            FTS index built for ``sora_doc`` (schema ``fts_main_sora_doc``).
        query: Raw user query text; it is tokenized with ``ja_tokens`` before
            matching. (Presumably ``ja_tokens`` returns a space-separated
            token string — TODO confirm.)

    Returns:
        A list of ``(id, score, content)`` tuples, ordered by descending
        BM25 score. Documents with no match (NULL score) are excluded.
    """
    q_tokens = ja_tokens(query)
    # Bind the tokens as a query parameter instead of splicing them into the
    # SQL text. A quote in the user's query would otherwise break the
    # statement or allow SQL injection.
    return conn.sql(
        """
        SELECT id, fts_main_sora_doc.match_bm25(id, ?) AS score, content
        FROM sora_doc
        WHERE score IS NOT NULL
        ORDER BY score DESC
        """,
        params=[q_tokens],
    ).fetchall()
def vss_search(conn, query):
    """Rank every ``sora_doc`` row by cosine distance to the query embedding.

    Args:
        conn: Open DuckDB connection. The SQL uses ``array_cosine_distance``
            on ``content_v``, which is cast-compared as ``FLOAT[2048]``.
        query: Raw user query text, embedded with ``v_model.encode_query``.

    Returns:
        A list of ``(id, distance, content)`` tuples, nearest first
        (ascending distance).
    """
    # Encode without autograd bookkeeping; only the forward pass is needed.
    with torch.inference_mode():
        embedding = v_model.encode_query(query, v_tokenizer)

    # Flatten to a plain Python list so DuckDB can bind it as FLOAT[2048].
    # NOTE(review): assumes the embedding squeezes to 2048 floats — confirm.
    vector = embedding.cpu().squeeze().numpy().tolist()

    relation = conn.sql(
        """
        SELECT id, array_cosine_distance(content_v, ?::FLOAT[2048]) as distance, content
        FROM sora_doc
        ORDER BY distance ASC
        """,
        params=[vector],
    )
    return relation.fetchall()
def reranking(query, vss_rows, fts_rows):
    """Rerank the union of vector and full-text candidates with ``r_model``.

    Candidates are de-duplicated by ``content``. When the same content appears
    more than once, the id from the later row wins, so ``fts_rows`` override
    ``vss_rows``. The original first-seen order is kept for scoring.

    Args:
        query: User query text, paired with each passage for scoring.
        vss_rows: ``(id, distance, content)`` rows from vector search.
        fts_rows: ``(id, score, content)`` rows from full-text search.

    Returns:
        A list of ``(id, rerank_score, content)`` tuples sorted by descending
        rerank score, or ``[]`` when there are no candidates.
    """
    # content -> id. Re-assigning an existing key updates the id but keeps
    # the key's original insertion position.
    passages = {}
    for doc_id, _, content in [*vss_rows, *fts_rows]:
        passages[content] = doc_id

    if not passages:
        # Nothing to rank. Also avoids calling r_model.predict with an empty
        # batch, which CrossEncoder-style predictors may reject (presumably
        # r_model is a sentence-transformers CrossEncoder — confirm).
        return []

    contents = list(passages)
    scores = r_model.predict([(query, content) for content in contents])
    return sorted(
        [(passages[content], score, content) for content, score in zip(contents, scores)],
        key=lambda row: row[1],
        reverse=True,
    )
def _log_rows(rows):
    """Print ``(id, score, content)`` rows to stderr for debugging."""
    for doc_id, score, content in rows:
        print(f"ID: {doc_id}, Score: {score:.4f}, Content: {content}", file=sys.stderr)


def hybrid_search(query):
    """Answer *query* with hybrid retrieval over the ``sora_doc`` table.

    Runs BM25 full-text search (``fts_search``) and vector similarity search
    (``vss_search``), reranks the union of both result sets (``reranking``),
    and returns the content of the best-scoring passage.

    Diagnostic output goes to stderr, never stdout. Under the stdio MCP
    transport, stdout carries the protocol stream.

    Args:
        query: Natural-language search text.

    Returns:
        The ``content`` of the highest-ranked passage, or ``""`` when neither
        search produced any candidates.
    """
    print("query:", query, file=sys.stderr)
    # Close the connection after every call so repeated tool invocations do
    # not accumulate open database handles.
    with duckdb.connect(duckdb_file) as conn:
        for extension in ("vss", "fts"):
            conn.install_extension(extension)
            conn.load_extension(extension)

        # FTS
        print("--- DuckDB-FTS + Lindera ---", file=sys.stderr)
        fts_rows = fts_search(conn, query)
        _log_rows(fts_rows)

        # VSS
        print("--- DuckDB-VSS + PLaMo ---", file=sys.stderr)
        vss_rows = vss_search(conn, query)
        _log_rows(vss_rows)

    # Reranking only needs the already-fetched rows, not the connection.
    print("--- Reranking ---", file=sys.stderr)
    reranking_rows = reranking(query, vss_rows, fts_rows)
    _log_rows(reranking_rows)
    if not reranking_rows:
        return ""
    return reranking_rows[0][2]
@mcp.tool()
def search(question: str) -> str:
    """React のドキュメントを参照して、回答を取得する"""
    # The Japanese docstring above is kept verbatim because FastMCP exposes it
    # as the tool description ("Consult the React documentation and get an
    # answer"). The returned text is the top reranked passage produced by
    # hybrid_search, not a generated answer.
    return hybrid_search(question)
def main():
    """Start the MCP server and block until it stops.

    Status messages are written to stderr. ``mcp.run()`` uses FastMCP's
    default (stdio) transport here, where stdout is the JSON-RPC channel, so
    anything else printed to stdout would corrupt the protocol stream.
    """
    print("MCP Server Start", file=sys.stderr)
    mcp.run()
    print("MCP Server End", file=sys.stderr)


if __name__ == "__main__":
    main()