# main_db.py
import json
import logging
import re
from datetime import datetime
from datetime import date, time
from decimal import Decimal
from typing import Any, Dict, Optional, List, Tuple, Iterable

import nest_asyncio
from mcp.server.fastmcp import FastMCP
from sqlalchemy import text, inspect, func
from sqlalchemy.orm import Query

from database import SessionLocal, engine, Base # uses your existing DB wiring
from models import User, Product, Order # uses your existing models
# ---------- server ----------
# Server name advertised to MCP clients.
SERVER_NAME = "database"
# Every @mcp.tool() function below registers itself on this instance.
mcp = FastMCP(SERVER_NAME)
# ---------- logging ----------
# Dedicated logger for this server; only ERROR and above are emitted. The
# handler is attached only when none exists yet, so loading this module again
# does not duplicate output. StreamHandler() writes to stderr by default,
# which keeps stdout free for the stdio transport started at the bottom.
logger = logging.getLogger("database_mcp")
logger.setLevel(logging.ERROR)
if not logger.handlers:
    _stderr_handler = logging.StreamHandler()
    _stderr_handler.setLevel(logging.ERROR)
    logger.addHandler(_stderr_handler)
def _err(where: str, e: Exception) -> str:
    """Log ``e`` for the failing operation ``where`` and build a client message.

    The exception is logged via ``logger.exception`` (which includes the
    traceback when called from an ``except`` block). The returned text
    deliberately hides the exception details from the MCP client.
    """
    logger.exception("%s failed: %s", where, e)
    return "An error occurred in {}. See server logs.".format(where)
# ---------- tiny helpers ----------
def _to_primitive(v: Any) -> Any:
if isinstance(v, (datetime,)):
return v.isoformat()
return v
def _row(obj) -> Optional[Dict[str, Any]]:
    """Convert a SQLAlchemy model instance to a dict of its table's column values.

    Keys are the column names of ``obj.__table__``. Each value is read with
    ``getattr(obj, column_name)`` and passed through ``_to_primitive`` so that
    datetimes become strings. Returns ``None`` when ``obj`` is ``None``, which is
    why the return type is Optional.
    """
    if obj is None:
        return None
    return {c.name: _to_primitive(getattr(obj, c.name)) for c in obj.__table__.columns}
class DB:
    """Context manager that opens one ``SessionLocal`` session per unit of work.

    Usage: ``with DB() as db: ...``. The value bound by ``as`` is the session
    itself. When the block exits, the session is committed if no exception
    escaped and rolled back otherwise, and it is closed in every case.
    Exceptions are never suppressed.
    """

    def __enter__(self):
        self.db = SessionLocal()
        return self.db

    def __exit__(self, exc_type, exc, tb):
        session = self.db
        try:
            if not exc:
                session.commit()
            else:
                session.rollback()
        finally:
            session.close()
# Create any tables declared on Base that do not exist yet. checkfirst=True
# (SQLAlchemy's default) skips existing tables, so this is safe at every start-up.
Base.metadata.create_all(bind=engine, checkfirst=True)
# ---------- generic utilities (introspection, pagination, sampling) ----------
def _sortable_column(q: Query, name: str):
    """Return the mapped column attribute ``name`` of ``q``'s first entity, or None."""
    descriptions = q.column_descriptions
    if not descriptions:
        return None
    entity = descriptions[0].get("entity")
    if entity is None:  # e.g. queries over bare SQL expressions
        return None
    # Only real column attributes qualify. A plain hasattr() would also accept
    # relationships, methods and SQLAlchemy internals (e.g. "metadata"), whose
    # .asc()/.desc() either do not exist or do not make sense.
    if name not in inspect(entity).mapper.column_attrs.keys():
        return None
    return getattr(entity, name)


def _paginate(q: Query, limit: int = 100, offset: int = 0,
              order_by: Optional[str] = None, order_dir: str = "asc") -> Query:
    """Apply ordering, limit and offset to a SQLAlchemy ``Query``.

    - ``order_by`` is honoured only when it names a mapped column of the query's
      first entity; any other value is silently ignored.
    - ``order_dir`` equal to ``"desc"`` (case-insensitive) sorts descending;
      anything else, including ``None``, sorts ascending.
    - ``offset`` is clamped at 0, and ``limit`` is clamped to the range 0..1000.

    NOTE(review): a falsy ``limit`` (0/None) applies no LIMIT at all and so
    bypasses the 1000-row cap. This is kept for compatibility; confirm it is intended.
    """
    if order_by:
        col = _sortable_column(q, order_by)
        if col is not None:
            descending = str(order_dir or "asc").lower() == "desc"
            q = q.order_by(col.desc() if descending else col.asc())
    if offset:
        q = q.offset(max(0, int(offset)))
    if limit:
        q = q.limit(max(0, min(int(limit), 1000)))  # cap to 1000
    return q
@mcp.tool()
def health_check() -> str:
    """
    Quick connectivity test. Returns {"ok": true} if the DB connection works.
    """
    # A trivial round-trip through a real session, including its commit,
    # proves the connection works end to end.
    try:
        with DB() as session:
            session.execute(text("SELECT 1"))
    except Exception as e:
        return _err("health_check", e)
    return json.dumps({"ok": True}, indent=2)
@mcp.tool()
def list_tables() -> str:
    """
    List table names in the connected database (via SQLAlchemy metadata).
    """
    # Reads the tables declared on Base.metadata (dependency-sorted), not a live
    # catalogue query, so only model-backed tables appear.
    try:
        ordered = Base.metadata.sorted_tables
        payload = json.dumps([tbl.name for tbl in ordered], indent=2)
    except Exception as e:
        return _err("list_tables", e)
    return payload
@mcp.tool()
def describe_table(table: str) -> str:
    """
    Describe columns for a table: name, type, primary_key, nullable, default,
    and foreign keys.
    """
    try:
        inspector = inspect(engine)
        columns = inspector.get_columns(table)
        pk_cols = set(inspector.get_pk_constraint(table).get("constrained_columns", []) or [])
        # column name -> target of the FK constraining it (a later constraint
        # on the same column overwrites an earlier one).
        fk_by_column = {
            col_name: {
                "referred_table": fk.get("referred_table"),
                "referred_columns": fk.get("referred_columns"),
            }
            for fk in inspector.get_foreign_keys(table)
            for col_name in fk.get("constrained_columns", [])
        }
        described = []
        for col in columns:
            col_name = col["name"]
            default = col.get("default")
            described.append({
                "name": col_name,
                "type": str(col.get("type")),
                "primary_key": col_name in pk_cols,
                "nullable": bool(col.get("nullable", True)),
                "default": None if default is None else str(default),
                "foreign_key": fk_by_column.get(col_name),
            })
        return json.dumps(described, indent=2)
    except Exception as e:
        return _err("describe_table", e)
@mcp.tool()
def table_count(table: str) -> str:
    """
    Return number of rows in a table.
    """
    try:
        with DB() as db:
            if table not in inspect(engine).get_table_names():
                return f"Unknown table: {table}"
            # The name is whitelisted above but must still be quoted. Reserved
            # words (e.g. "order") or mixed-case names break when interpolated raw.
            quoted = engine.dialect.identifier_preparer.quote(table)
            res = db.execute(text(f"SELECT COUNT(1) AS c FROM {quoted}"))
            n = res.scalar() or 0
            return json.dumps({"table": table, "count": int(n)}, indent=2)
    except Exception as e:
        return _err("table_count", e)
@mcp.tool()
def sample_rows(table: str, limit: int = 5) -> str:
    """
    Return up to N example rows from a table (unordered).
    """
    try:
        n = max(0, min(int(limit), 100))
        with DB() as db:
            # `table` comes from the MCP client. Interpolating it unchecked into
            # SQL was an injection vector, so only existing tables are accepted,
            # and the name is quoted as an identifier.
            if table not in inspect(engine).get_table_names():
                return f"Unknown table: {table}"
            quoted = engine.dialect.identifier_preparer.quote(table)
            r = db.execute(text(f"SELECT * FROM {quoted} LIMIT :n"), {"n": n})
            rows = [dict(x._mapping) for x in r]
            return json.dumps(rows, indent=2, default=_to_primitive)
    except Exception as e:
        return _err("sample_rows", e)
@mcp.tool()
def run_sql_select(sql: str, max_rows: int = 1000) -> str:
    """
    Run a single read-only SELECT and return rows as JSON. Rejects mutations and
    multiple statements; the transaction is always rolled back. Optional
    max_rows (1-10000) caps the number of returned rows.
    """
    try:
        # Strip trailing ';'/whitespace so an appended LIMIT stays valid SQL
        # ("SELECT ...; LIMIT n" is a syntax error).
        stmt = re.sub(r"[\s;]+$", "", sql.strip())
        low = stmt.lower()
        if not low.startswith("select"):
            return "Only SELECT allowed here. Use the CRUD tools for mutations."
        # A remaining ';' means stacked statements ("SELECT 1; DROP TABLE x"),
        # which some drivers execute. A ';' inside a string literal is
        # rejected too, which is an accepted trade-off.
        if ";" in stmt:
            return "Only a single SELECT statement is allowed."
        cap = max(1, min(int(max_rows), 10000))
        # Push a LIMIT to the database when the query has none (word-boundary
        # match, so "LIMIT\n" or "limit(" are recognised as well).
        if not re.search(r"\blimit\b", low):
            stmt = f"{stmt} LIMIT {cap}"
        db = SessionLocal()
        try:
            result = db.execute(text(stmt))
            # fetchmany() enforces max_rows even if the query's own LIMIT is larger.
            rows = [dict(r) for r in result.mappings().fetchmany(cap)]
        finally:
            # Read-only by construction: never commit, and discard anything
            # the statement may have changed or locked.
            db.rollback()
            db.close()
        return json.dumps(rows, indent=2, default=_to_primitive)
    except Exception as e:
        return _err("run_sql_select", e)
# ---------- USER tools ----------
@mcp.tool()
def user_create(name: str, email: str, password: str) -> str:
    """
    Create a user.
    """
    # NOTE(review): `password` is stored as given, and the returned dict holds
    # every User column, presumably including the password column. Confirm
    # hashing happens in the model and whether it should be exposed here.
    try:
        with DB() as db:
            new_user = User(name=name, email=email, password=password)
            db.add(new_user)
            # Flush so database-generated values (e.g. the id) are populated
            # before the row is serialised; the commit happens on leaving DB().
            db.flush()
            payload = _row(new_user)
            return json.dumps(payload, indent=2)
    except Exception as e:
        return _err("user_create", e)
@mcp.tool()
def user_get(id: Optional[int] = None, email: Optional[str] = None) -> str:
    """
    Get a user by id or email.
    """
    # id takes precedence over email; with neither, nothing is queried.
    try:
        with DB() as db:
            if id is not None:
                criterion = User.id == id
            elif email is not None:
                criterion = User.email == email
            else:
                return "Provide id or email."
            found = db.query(User).filter(criterion).first()
            return json.dumps(_row(found) if found else None, indent=2)
    except Exception as e:
        return _err("user_get", e)
@mcp.tool()
def user_list(limit: int = 100, offset: int = 0, q: Optional[str] = None) -> str:
    """
    List users with optional fuzzy search across name and email.
    """
    # Search is a case-insensitive substring match (ILIKE '%q%') on either column;
    # results are ordered by id and paged via _paginate.
    try:
        with DB() as db:
            users = db.query(User)
            if q:
                pattern = f"%{q}%"
                users = users.filter(User.name.ilike(pattern) | User.email.ilike(pattern))
            page = _paginate(users, limit=limit, offset=offset, order_by="id", order_dir="asc")
            return json.dumps([_row(usr) for usr in page.all()], indent=2)
    except Exception as e:
        return _err("user_list", e)
@mcp.tool()
def user_exists(email: str) -> str:
    """
    Return {"exists": true|false} for a given email.
    """
    # Selects only the id column; presence of any matching row is the answer.
    try:
        with DB() as db:
            match = db.query(User.id).filter(User.email == email).first()
            return json.dumps({"exists": match is not None}, indent=2)
    except Exception as e:
        return _err("user_exists", e)
@mcp.tool()
def user_update(id: int, updates: Dict[str, Any]) -> str:
    """
    Update user fields. Example updates={"name":"Alice"}
    Only non-primary-key table columns are updated; other keys are ignored.
    """
    try:
        with DB() as db:
            u = db.query(User).filter(User.id == id).first()
            if not u:
                return "User not found."
            # Whitelist real, non-PK columns. hasattr() also matched
            # relationships, methods and SQLAlchemy internals, and allowed
            # rewriting the primary key.
            writable = {c.name for c in User.__table__.columns if not c.primary_key}
            for k, v in (updates or {}).items():
                if k in writable:
                    setattr(u, k, v)
            db.flush()
            return json.dumps(_row(u), indent=2)
    except Exception as e:
        return _err("user_update", e)
@mcp.tool()
def user_delete(id: int) -> str:
    """
    Delete user by id.
    """
    # Bulk Query.delete(): reports the affected row count; committed by DB().
    try:
        with DB() as db:
            deleted = db.query(User).filter(User.id == id).delete()
            return "Deleted {} user(s).".format(deleted)
    except Exception as e:
        return _err("user_delete", e)
# ---------- PRODUCT tools ----------
@mcp.tool()
def product_create(name: str, price: float, stock: int) -> str:
    """
    Create a product.
    """
    try:
        with DB() as db:
            item = Product(name=name, price=price, stock=stock)
            db.add(item)
            # Flush so database-generated values (e.g. the id) are present in the reply.
            db.flush()
            created = _row(item)
            return json.dumps(created, indent=2)
    except Exception as e:
        return _err("product_create", e)
@mcp.tool()
def product_get(
    id: Optional[int] = None,
    name: Optional[str] = None,
    min_price: Optional[float] = None,
    max_price: Optional[float] = None,
    in_stock_only: bool = False,
    limit: int = 100,
    offset: int = 0,
    order_by: str = "id",
    order_dir: str = "asc",
) -> str:
    """
    Get products. If no id/name provided, returns a paginated list.
    Filters: min_price, max_price, in_stock_only. Sorting and pagination supported.
    """
    # Only `id` returns a single object (or null). Every other combination,
    # including an exact `name` match, returns a (possibly empty) list.
    try:
        with DB() as db:
            products = db.query(Product)
            if id is not None:
                hit = products.filter(Product.id == id).first()
                return json.dumps(_row(hit) if hit else None, indent=2)
            conditions = []
            if name is not None:
                conditions.append(Product.name == name)
            if min_price is not None:
                conditions.append(Product.price >= float(min_price))
            if max_price is not None:
                conditions.append(Product.price <= float(max_price))
            if in_stock_only:
                conditions.append(Product.stock > 0)
            if conditions:
                products = products.filter(*conditions)  # AND-ed together
            page = _paginate(products, limit=limit, offset=offset,
                             order_by=order_by, order_dir=order_dir)
            return json.dumps([_row(prod) for prod in page.all()], indent=2)
    except Exception as e:
        return _err("product_get", e)
@mcp.tool()
def product_search(q: str, limit: int = 25, offset: int = 0) -> str:
    """
    Fuzzy search products by name (ILIKE). Returns a paginated list.
    """
    try:
        with DB() as db:
            matches = db.query(Product).filter(Product.name.ilike("%" + q + "%"))
            page = _paginate(matches, limit=limit, offset=offset, order_by="id", order_dir="asc")
            found = [_row(prod) for prod in page.all()]
            return json.dumps(found, indent=2)
    except Exception as e:
        return _err("product_search", e)
@mcp.tool()
def product_update(name: str, field: str, new_value: str) -> str:
    """Update a single field of a product by name (price, stock, name)."""
    # The fields this tool may change, matching the error message below.
    # hasattr() also accepted "id", relationships and SQLAlchemy internals.
    allowed = ("name", "price", "stock")
    try:
        # DB() + _err like every other tool: commit/rollback/close are handled
        # by the context manager, and exception details go to the server log
        # instead of the client.
        with DB() as db:
            product = db.query(Product).filter(Product.name == name).first()
            if not product:
                return f"⚠️ No product found with name '{name}'"
            if field not in allowed:
                return f"⚠️ Invalid field '{field}'. Allowed: name, price, stock"
            # Convert data types for the numeric columns; malformed input is a
            # user error, reported as such rather than as a server failure.
            try:
                if field == "price":
                    new_value = float(new_value)
                elif field == "stock":
                    new_value = int(new_value)
            except (TypeError, ValueError):
                return f"⚠️ Invalid value '{new_value}' for field '{field}'"
            setattr(product, field, new_value)
            db.flush()
            db.refresh(product)
            return f"✅ Updated '{field}' of '{product.name}' to {new_value}"
    except Exception as e:
        return _err("product_update", e)
@mcp.tool()
def product_adjust_stock(name: str, delta: int) -> str:
    """
    Increment (or decrement) product stock by delta using product name.
    Returns the updated product.
    """
    # SELECT ... FOR UPDATE locks the row (where the backend supports it) so the
    # read-modify-write of `stock` is not interleaved with another writer.
    try:
        with DB() as db:
            locked = (
                db.query(Product)
                .filter(Product.name == name)
                .with_for_update()
                .first()
            )
            if not locked:
                return "Product not found."
            new_stock = int(locked.stock) + int(delta)
            locked.stock = new_stock
            db.flush()
            db.refresh(locked)
            return json.dumps(_row(locked), indent=2)
    except Exception as e:
        return _err("product_adjust_stock", e)
@mcp.tool()
def product_delete(name: str) -> str:
    """
    Delete product(s) by name. Returns number of rows deleted.
    """
    # Names are matched exactly; every product sharing the name is removed.
    try:
        with DB() as db:
            removed = db.query(Product).filter(Product.name == name).delete()
            return "Deleted {} product(s).".format(removed)
    except Exception as e:
        return _err("product_delete", e)
@mcp.tool()
def product_bulk_upsert(products: List[Dict[str, Any]]) -> str:
    """
    Bulk upsert by (name): for each item with keys {name, price, stock},
    update if name exists, else insert. Returns the resulting rows.
    """
    # All items share one transaction: any failure rolls back the whole batch.
    try:
        out = []
        with DB() as db:
            for itm in products or []:
                nm = itm.get("name")
                if not nm:
                    continue  # an item without a name cannot be matched; skip it
                p = db.query(Product).filter(Product.name == nm).first()
                if p:
                    # Coerce exactly like the insert branch, so e.g. "9.99" is
                    # stored as a number whether the product is new or not.
                    if "price" in itm:
                        p.price = float(itm["price"])
                    if "stock" in itm:
                        p.stock = int(itm["stock"])
                else:
                    p = Product(
                        name=nm,
                        price=float(itm.get("price", 0.0)),
                        stock=int(itm.get("stock", 0)),
                    )
                    db.add(p)
                # Flush per item so generated ids are serialised and a later
                # duplicate name in the same batch finds this row.
                db.flush()
                out.append(_row(p))
            return json.dumps(out, indent=2)
    except Exception as e:
        return _err("product_bulk_upsert", e)
# ---------- ORDER tools (table: order_list) ----------
@mcp.tool()
def order_create(
    user_id: int,
    product_name: Optional[str] = None,
    quantity: int = 1,
    product_id: Optional[int] = None,
) -> str:
    """
    Create an order (order_list).
    Supply either product_name (preferred) or product_id for backwards compatibility.
    If both are given, product_name wins. quantity must be at least 1.
    """
    try:
        qty = int(quantity)
        # Zero/negative quantities would corrupt the totals in order_summary_for_user.
        if qty < 1:
            return "Quantity must be at least 1."
        with DB() as db:
            # product_name is the preferred identifier, so it takes precedence
            # over product_id (previously an explicit product_id silently won).
            if product_name:
                p = db.query(Product).filter(Product.name == product_name).first()
                if not p:
                    return "Invalid product_name."
                product_id = p.id
            if product_id is None:
                return "Provide product_name or product_id."
            # optional: basic FK sanity check
            if not db.query(User).filter(User.id == user_id).first():
                return "Invalid user_id."
            if not db.query(Product).filter(Product.id == product_id).first():
                return "Invalid product_id."
            o = Order(user_id=user_id, product_id=product_id, quantity=qty)
            db.add(o)
            db.flush()  # populate database-generated values (e.g. the id) before serialising
            return json.dumps(_row(o), indent=2)
    except Exception as e:
        return _err("order_create", e)
@mcp.tool()
def order_get(id: Optional[int] = None, user_id: Optional[int] = None) -> str:
    """
    Get one order by id OR list orders for a user_id.
    """
    # id wins over user_id. With user_id, every order of that user is returned
    # (no pagination).
    try:
        with DB() as db:
            orders = db.query(Order)
            if id is not None:
                match = orders.filter(Order.id == id).first()
                return json.dumps(_row(match) if match else None, indent=2)
            if user_id is None:
                return "Provide id or user_id."
            owned = orders.filter(Order.user_id == user_id).all()
            return json.dumps([_row(o) for o in owned], indent=2)
    except Exception as e:
        return _err("order_get", e)
@mcp.tool()
def order_list(
    user_id: Optional[int] = None,
    product_id: Optional[int] = None,
    date_from: Optional[str] = None,
    date_to: Optional[str] = None,
    limit: int = 100,
    offset: int = 0,
    order_by: str = "id",
    order_dir: str = "desc",
) -> str:
    """
    List orders with filters. date_* should be ISO strings; if model has created_at,
    they will be applied to that column when available.
    """
    try:
        with DB() as db:
            orders = db.query(Order)
            if user_id is not None:
                orders = orders.filter(Order.user_id == user_id)
            if product_id is not None:
                orders = orders.filter(Order.product_id == product_id)
            # Date bounds are inclusive and silently ignored when the Order
            # model has no created_at attribute.
            created_col = getattr(Order, "created_at", None)
            if created_col is not None:
                if date_from:
                    orders = orders.filter(created_col >= datetime.fromisoformat(date_from))
                if date_to:
                    orders = orders.filter(created_col <= datetime.fromisoformat(date_to))
            page = _paginate(orders, limit=limit, offset=offset,
                             order_by=order_by, order_dir=order_dir)
            return json.dumps([_row(o) for o in page.all()], indent=2)
    except Exception as e:
        return _err("order_list", e)
@mcp.tool()
def order_update(id: int, updates: Dict[str, Any]) -> str:
    """
    Update fields on an order (e.g., quantity or status if present).
    Only non-primary-key table columns are updated; other keys are ignored.
    """
    try:
        with DB() as db:
            o = db.query(Order).filter(Order.id == id).first()
            if not o:
                return "Order not found."
            # Whitelist real, non-PK columns. hasattr() also matched
            # relationships, methods and SQLAlchemy internals, and allowed
            # rewriting the primary key.
            writable = {c.name for c in Order.__table__.columns if not c.primary_key}
            for k, v in (updates or {}).items():
                if k in writable:
                    setattr(o, k, v)
            db.flush()
            return json.dumps(_row(o), indent=2)
    except Exception as e:
        return _err("order_update", e)
@mcp.tool()
def order_delete(id: int) -> str:
    """
    Delete order by id.
    """
    # Bulk Query.delete(): reports the affected row count; committed by DB().
    try:
        with DB() as db:
            removed = db.query(Order).filter(Order.id == id).delete()
            return "Deleted {} order(s).".format(removed)
    except Exception as e:
        return _err("order_delete", e)
@mcp.tool()
def order_summary_for_user(user_id: int) -> str:
    """
    Aggregate summary for a user: total_orders, total_quantity, and optional
    total_value if Product.price exists.
    """
    try:
        with DB() as db:
            belongs_to_user = Order.user_id == user_id
            order_count = db.query(func.count(Order.id)).filter(belongs_to_user).scalar() or 0
            qty_sum = (
                db.query(func.coalesce(func.sum(Order.quantity), 0))
                .filter(belongs_to_user)
                .scalar()
            ) or 0
            # Value = sum(quantity * price) over the user's orders joined to
            # their products; only computed when Product has a price attribute.
            value = None
            if hasattr(Product, "price"):
                value_query = (
                    db.query(func.coalesce(func.sum(Order.quantity * Product.price), 0.0))
                    .join(Product, Product.id == Order.product_id)
                    .filter(belongs_to_user)
                )
                value = float(value_query.scalar() or 0.0)
            summary = {
                "user_id": user_id,
                "total_orders": int(order_count),
                "total_quantity": int(qty_sum),
                "total_value": value,
            }
            return json.dumps(summary, indent=2)
    except Exception as e:
        return _err("order_summary_for_user", e)
@mcp.tool()
def order_joined(limit: int = 50, offset: int = 0) -> str:
    """
    Return orders joined with user and product fields for convenient display.
    Keys are prefixed: order_, user_, product_. Results are ordered by order id.
    """
    try:
        with DB() as db:
            q = (
                db.query(Order, User, Product)
                .join(User, User.id == Order.user_id)
                .join(Product, Product.id == Order.product_id)
            )
            # LIMIT/OFFSET without ORDER BY gives unstable pages (rows can repeat
            # or be skipped between calls), so sort by the order's id. "id"
            # resolves against Order, the query's first entity.
            q = _paginate(q, limit=limit, offset=offset, order_by="id", order_dir="asc")
            out: List[Dict[str, Any]] = []
            for o, u, p in q.all():
                row: Dict[str, Any] = {}
                for prefix, obj in (("order", o), ("user", u), ("product", p)):
                    for k, v in _row(obj).items():
                        row[f"{prefix}_{k}"] = v
                out.append(row)
            return json.dumps(out, indent=2)
    except Exception as e:
        return _err("order_joined", e)
# ---------- run ----------
if __name__ == "__main__":
    # Imported here: only needed when the module is executed as a script.
    import asyncio

    # NOTE(review): nest_asyncio is presumably applied so asyncio.run() works
    # even inside an already-running event loop (e.g. notebooks); confirm.
    nest_asyncio.apply()
    server_main = mcp.run_stdio_async()
    asyncio.run(server_main)