# server.py (6.11 kB)
from typing import Any
from mcp.server.fastmcp import FastMCP
from starlette.applications import Starlette
from mcp.server.sse import SseServerTransport
from starlette.requests import Request
from starlette.routing import Mount, Route
from mcp.server import Server
from mcp.server.fastmcp.prompts import base
from psycopg_pool import ConnectionPool
import cocoindex
from numpy.typing import NDArray
import numpy as np
from pgvector.psycopg import register_vector
import os
from dotenv import load_dotenv
# Shared FastMCP server instance for codebase search. The tools and the prompt
# defined in this file register themselves on it through the @mcp.tool() and
# @mcp.prompt() decorators. The __main__ entry point runs it over stdio, while
# create_starlette_app() builds an SSE-based Starlette app around a server.
mcp = FastMCP("CodebaseRagMCP")
@cocoindex.transform_flow()
def code_to_embedding(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[NDArray[np.float32]]:
    """
    Embed the text with the Voyage "voyage-code-3" embedding model.

    The same flow is evaluated in `search` to embed the query string.
    """
    # Describe the embedding step first, then apply it to the incoming slice.
    embed_step = cocoindex.functions.EmbedText(
        api_type=cocoindex.LlmApiType.VOYAGE,
        model="voyage-code-3",
    )
    return text.transform(embed_step)
def search(pool: ConnectionPool, query: str, top_k: int = 5) -> list[dict[str, Any]]:
    """Return the stored code chunks whose embeddings are closest to ``query``.

    Args:
        pool: Connection pool used to reach the database.
        query: Text that is embedded and compared against stored embeddings.
        top_k: Maximum number of rows to return.

    Returns:
        One dict per row, ordered by ascending distance, with the keys
        ``filename``, ``code``, ``score`` (``1.0 - distance``), ``start``
        and ``end``.

    Raises:
        RuntimeError: If the ``EMBEDDING_TABLE`` environment variable is
            unset or empty.
    """
    # Resolve and check the table name first, so a misconfigured environment
    # fails with a clear message before the embedding call is made and before
    # "None" ends up interpolated into the SQL text.
    table_name = os.getenv("EMBEDDING_TABLE")
    if not table_name:
        raise RuntimeError("EMBEDDING_TABLE environment variable is not set.")

    # Evaluate the transform flow defined in this file with the input query,
    # to get the embedding.
    query_vector = code_to_embedding.eval(query)

    with pool.connection() as conn:
        # Lets the connection send/receive vector values.
        register_vector(conn)
        with conn.cursor() as cur:
            # The table name is an identifier and cannot be sent as a bound
            # parameter; it comes from trusted configuration (environment),
            # never from the caller. Values are bound with %s placeholders.
            cur.execute(
                f"""
                SELECT filename, code, embedding <=> %s AS distance, start, "end"
                FROM {table_name} ORDER BY distance LIMIT %s
                """,
                (query_vector, top_k),
            )
            return [
                {
                    "filename": row[0],
                    "code": row[1],
                    # Convert the distance into a "higher is better" score.
                    "score": 1.0 - row[2],
                    "start": row[3],
                    "end": row[4],
                }
                for row in cur.fetchall()
            ]
def list_files(pool: ConnectionPool) -> list[dict[str, Any]]:
    """Return every ``source_key`` recorded in the tracking table.

    Args:
        pool: Connection pool used to reach the database.

    Returns:
        One dict per row with the single key ``filename`` holding the row's
        ``source_key`` value.

    Raises:
        RuntimeError: If the ``TRACKING_TABLE`` environment variable is
            unset or empty.
    """
    # Fail with a clear message instead of interpolating "None" into the SQL.
    table_name = os.getenv("TRACKING_TABLE")
    if not table_name:
        raise RuntimeError("TRACKING_TABLE environment variable is not set.")

    with pool.connection() as conn:
        with conn.cursor() as cur:
            # The table name is an identifier taken from trusted configuration
            # (environment); it cannot be passed as a bound parameter.
            cur.execute(
                f"""
                SELECT source_key
                FROM {table_name}
                """
            )
            return [{"filename": row[0]} for row in cur.fetchall()]
def query_file_content(pool: ConnectionPool, filename: str) -> list[dict[str, Any]]:
    """Return the stored code chunks of one file, ordered by their location.

    Args:
        pool: Connection pool used to reach the database.
        filename: Value matched exactly against the ``filename`` column.

    Returns:
        One dict per matching row with the keys ``filename`` and ``code``,
        ordered by ``lower(location)``.

    Raises:
        RuntimeError: If the ``EMBEDDING_TABLE`` environment variable is
            unset or empty.
    """
    # Fail with a clear message instead of interpolating "None" into the SQL.
    table_name = os.getenv("EMBEDDING_TABLE")
    if not table_name:
        raise RuntimeError("EMBEDDING_TABLE environment variable is not set.")

    with pool.connection() as conn:
        with conn.cursor() as cur:
            # The table name is an identifier taken from trusted configuration
            # (environment); the caller-supplied filename is bound with %s.
            cur.execute(
                f"""
                SELECT filename, code, start, "end", location
                FROM {table_name} WHERE filename = %s
                ORDER BY lower(location)
                """,
                (filename,),
            )
            # Only the first two selected columns are returned to the caller.
            return [
                {
                    "filename": row[0],
                    "code": row[1],
                }
                for row in cur.fetchall()
            ]
@mcp.tool()
async def search_codebase(query: str) -> str:
    """Search for related code from the complete codebase.
    Args:
        query: Search term to find relevant code snippets
    """
    # NOTE(review): the docstring wording is kept unchanged because FastMCP
    # presumably exposes it to clients as the tool description - confirm
    # before rewording it.
    #
    # Returns a newline-joined listing: one "[score] filename (Lstart-Lend)"
    # header, the code, and a "---" separator for every hit, or a short
    # message when there is no query or there are no results.

    # Reject an empty query before any database resource is created.
    if query == "":
        return "No query provided."

    # Use the pool as a context manager so its connections and worker threads
    # are released when this call finishes, instead of leaking one pool per
    # tool invocation.
    with ConnectionPool(os.getenv("DATABASE_URL")) as pool:
        results = search(pool, query)

    if not results:
        return "No results found."

    output_parts = []
    for result in results:
        output_parts.append(
            f"[{result['score']:.3f}] {result['filename']} (L{result['start']['line']}-L{result['end']['line']})"
        )
        output_parts.append(f" {result['code']}")
        output_parts.append("---")
    return "\n".join(output_parts)
@mcp.tool()
async def get_files() -> str:
    """Get all the files in the codebase.
    """
    # NOTE(review): the docstring wording is kept unchanged because FastMCP
    # presumably exposes it to clients as the tool description - confirm
    # before rewording it.
    #
    # Returns the filenames joined by newlines, or "No results found." when
    # the tracking table yields no rows.

    # Use the pool as a context manager so its connections and worker threads
    # are released when this call finishes, instead of leaking one pool per
    # tool invocation.
    with ConnectionPool(os.getenv("DATABASE_URL")) as pool:
        results = list_files(pool)

    if not results:
        return "No results found."
    return "\n".join(f"{result['filename']}" for result in results)
@mcp.tool()
async def get_file_content(filename: str) -> str:
    """Get the complete content by filename.
    Args:
        filename: The filename to get the content for
    """
    # NOTE(review): the docstring wording is kept unchanged because FastMCP
    # presumably exposes it to clients as the tool description - confirm
    # before rewording it.
    #
    # Returns the file's stored code chunks joined by newlines, or
    # "No results found." when no row matches the filename.

    # Use the pool as a context manager so its connections and worker threads
    # are released when this call finishes, instead of leaking one pool per
    # tool invocation.
    with ConnectionPool(os.getenv("DATABASE_URL")) as pool:
        results = query_file_content(pool, filename)

    if not results:
        return "No results found."
    return "\n".join(f"{result['code']}" for result in results)
@mcp.prompt()
def get_initial_prompts() -> list[base.Message]:
    # Single opening user message that frames the assistant's role. Documented
    # with a comment rather than a docstring so the function's __doc__ (which
    # the prompt decorator may read) stays exactly as it was: empty.
    opening_message = base.UserMessage(
        "You are a helpful assistant that can help with codebase questions."
    )
    return [opening_message]
def create_starlette_app(mcp_server: Server, *, debug: bool = False) -> Starlette:
    """Create a Starlette application that can serve the provided MCP server with SSE.

    Args:
        mcp_server: The server whose ``run`` loop is driven for each SSE
            connection.
        debug: Passed through to ``Starlette(debug=...)``.

    Returns:
        A Starlette app with two routes: ``/sse`` opens the event stream and
        ``/messages/`` receives the client's POSTed messages.
    """
    # Imported locally so this block needs no new file-level import; starlette
    # is already a dependency of this file.
    from starlette.responses import Response

    sse = SseServerTransport("/messages/")

    async def handle_sse(request: Request) -> Response:
        # connect_sse takes the raw ASGI callables, hence the use of the
        # request's private ``_send`` attribute.
        async with sse.connect_sse(
            request.scope,
            request.receive,
            request._send,
        ) as (read_stream, write_stream):
            await mcp_server.run(
                read_stream,
                write_stream,
                mcp_server.create_initialization_options(),
            )
        # Starlette invokes whatever a Route endpoint returns as the ASGI
        # response. Returning None made it fail with "'NoneType' object is
        # not callable" once the SSE session ended, so hand back an empty
        # response instead.
        return Response()

    return Starlette(
        debug=debug,
        routes=[
            Route("/sse", endpoint=handle_sse),
            Mount("/messages/", app=sse.handle_post_message),
        ],
    )
def _main() -> None:
    # Script entry point. The order matters: environment variables are loaded
    # from the .env file first, then cocoindex is initialised, and only then
    # does the MCP server start serving over stdio.
    load_dotenv()
    cocoindex.init()
    mcp.run(transport="stdio")


if __name__ == "__main__":
    _main()