from __future__ import annotations
import datetime
import os
import uuid
from typing import Any, Dict, List, Optional
from fastmcp import FastMCP # FastMCP v2+ uses this import path
from aws_mcp_audit.auth.session import build_boto3_session
from aws_mcp_audit.collectors.regions import list_enabled_regions
from aws_mcp_audit.collectors.ec2 import collect_ec2_region
from aws_mcp_audit.collectors.elbv2 import collect_elbv2_region
from aws_mcp_audit.collectors.rds import collect_rds_region
from aws_mcp_audit.collectors.s3 import collect_s3
from aws_mcp_audit.collectors.telemetry import collect_cloudtrail, collect_cloudwatch_alarm_count
from aws_mcp_audit.collectors.cost_explorer import cost_explorer_summary, cost_explorer_by_dimension
from aws_mcp_audit.checks.exposure import (
check_sg_world_open,
check_public_instances,
check_unassociated_eips,
check_unattached_ebs,
)
from aws_mcp_audit.checks.health import check_unhealthy_targets
from aws_mcp_audit.checks.dataprotection import check_unencrypted_ebs, check_rds_public_or_low_backup
from aws_mcp_audit.checks.telemetry import check_cloudtrail_present, check_cloudwatch_alarm_signal
from aws_mcp_audit.storage.store import save_snapshot, save_findings, read_json, snapshot_dir, write_json
from aws_mcp_audit.reports.render import render_markdown, write_pdf_from_text
from aws_mcp_audit.utils.time import now_iso_utc
mcp = FastMCP("aws-mcp-audit")
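# Snapshots, findings, and reports are written under ./data relative to the current working directory.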
DATA_DIR = os.path.join(os.getcwd(), "data")
def _snapshot_id() -> str:
    # Timestamp plus a short random suffix keeps ids sortable and collision-resistant.
    return f"{datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H-%M-%SZ')}__{uuid.uuid4().hex[:6]}"
def _caller_identity(session: Any) -> Dict[str, Any]:
    # Plain helper (not a tool) so other tools such as collect_snapshot can reuse it
    # without invoking the decorated tool object.
    sts = session.client("sts", region_name=session.region_name or "us-east-1")
    ident = sts.get_caller_identity()
    return {
        "account": ident.get("Account"),
        "arn": ident.get("Arn"),
        "user_id": ident.get("UserId"),
    }
@mcp.tool
def aws_whoami(auth: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Return the caller identity (account, ARN, user id) for the configured credentials."""
    session = build_boto3_session(auth)
    return _caller_identity(session)
@mcp.tool
def collect_snapshot(scope: Dict[str, Any], auth: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
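    """Collect an inventory snapshot (EC2, ELBv2, RDS, S3, CloudTrail, CloudWatch alarms) and persist it.

    scope accepts {"all": true} or {"regions": [...]}; if neither is given, all enabled regions are scanned.
    """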
session = build_boto3_session(auth)
sid = _snapshot_id()
    who = _caller_identity(session)  # reuse the session built above instead of re-invoking the tool
requested = scope.get("regions")
if scope.get("all") is True or requested is None:
regions = list_enabled_regions(session)
else:
regions = list(requested)
ec2_by_region: Dict[str, Any] = {}
elbv2_by_region: Dict[str, Any] = {}
rds_by_region: Dict[str, Any] = {}
telemetry_cloudwatch: List[Dict[str, Any]] = []
# collectors (MVP: sequential; add concurrency later)
for r in regions:
ec2_by_region[r] = collect_ec2_region(session, r)
elbv2_by_region[r] = collect_elbv2_region(session, r)
rds_by_region[r] = collect_rds_region(session, r)
telemetry_cloudwatch.append(collect_cloudwatch_alarm_count(session, r))
# global-ish collectors
s3_blob = collect_s3(session)
# CloudTrail: use caller region or us-east-1; trails are global objects but API is regional
ct_region = session.region_name or "us-east-1"
cloudtrail_blob = collect_cloudtrail(session, ct_region)
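    # Sum alarm counts across regions; a None count means collection failed in that
    # region, so report the overall total as unknown rather than undercounting.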
total_alarm_count = 0
alarm_count_known = True
for row in telemetry_cloudwatch:
if row.get("alarm_count") is None:
alarm_count_known = False
else:
total_alarm_count += int(row.get("alarm_count"))
summary = {
"ec2_instances": sum(len(v.get("instances", [])) for v in ec2_by_region.values()),
"security_groups": sum(len(v.get("security_groups", [])) for v in ec2_by_region.values()),
"ebs_volumes": sum(len(v.get("volumes", [])) for v in ec2_by_region.values()),
"elastic_ips": sum(len(v.get("eips", [])) for v in ec2_by_region.values()),
"elbv2_load_balancers": sum(len(v.get("load_balancers", [])) for v in elbv2_by_region.values()),
"elbv2_target_groups": sum(len(v.get("target_groups", [])) for v in elbv2_by_region.values()),
"rds_instances": sum(len(v.get("instances", [])) for v in rds_by_region.values()),
"rds_clusters": sum(len(v.get("clusters", [])) for v in rds_by_region.values()),
"s3_buckets": len(s3_blob.get("buckets", [])),
}
snapshot: Dict[str, Any] = {
"meta": {
"snapshot_id": sid,
"account_id": who.get("account"),
"collected_at": now_iso_utc(),
"regions": regions,
},
"summary": summary,
"ec2_by_region": ec2_by_region,
"elbv2_by_region": elbv2_by_region,
"rds_by_region": rds_by_region,
"s3": s3_blob,
"telemetry": {
"cloudtrail": cloudtrail_blob,
"cloudwatch_alarms": {
"by_region": telemetry_cloudwatch,
"total_alarm_count": total_alarm_count if alarm_count_known else None,
},
},
}
save_snapshot(DATA_DIR, sid, snapshot)
return {"snapshot_id": sid, "regions": regions, "summary": summary}
@mcp.tool
def get_snapshot_summary(snapshot_id: str) -> Dict[str, Any]:
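    """Return the meta and summary sections of a stored snapshot."""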
p = os.path.join(snapshot_dir(DATA_DIR, snapshot_id), "snapshot.json")
snap = read_json(p)
return {"meta": snap.get("meta", {}), "summary": snap.get("summary", {})}
@mcp.tool
def run_checks(snapshot_id: str) -> Dict[str, Any]:
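    """Run all exposure, telemetry, data-protection, and health checks against a snapshot and persist the findings."""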
p = os.path.join(snapshot_dir(DATA_DIR, snapshot_id), "snapshot.json")
snap = read_json(p)
    findings: List[Dict[str, Any]] = []
    checks = [
        # Exposure
        check_sg_world_open,
        check_public_instances,
        check_unassociated_eips,
        check_unattached_ebs,
        # Telemetry signals
        check_cloudtrail_present,
        check_cloudwatch_alarm_signal,
        # Data protection
        check_unencrypted_ebs,
        check_rds_public_or_low_backup,
        # Health
        check_unhealthy_targets,
    ]
    for check in checks:
        findings.extend(f.__dict__ for f in check(snap))
save_findings(DATA_DIR, snapshot_id, findings)
# v1: 1:1 mapping of finding_set_id to snapshot_id
return {"finding_set_id": snapshot_id, "count": len(findings)}
@mcp.tool
def list_findings(finding_set_id: str, severity_min: str = "LOW") -> Dict[str, Any]:
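    """List stored findings for a finding set, filtered to severities at or above severity_min."""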
from aws_mcp_audit.checks.findings import severity_at_least
p = os.path.join(snapshot_dir(DATA_DIR, finding_set_id), "findings.json")
findings = read_json(p)
filt = [f for f in findings if severity_at_least(str(f.get("severity", "LOW")), severity_min)]
return {"count": len(filt), "findings": filt}
@mcp.tool
def cost_signals(snapshot_id: str) -> Dict[str, Any]:
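    """Derive tier-1 cost signals (instance type counts, stopped instances, unattached EBS, unassociated EIPs) from a snapshot."""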
p = os.path.join(snapshot_dir(DATA_DIR, snapshot_id), "snapshot.json")
snap = read_json(p)
by_type: Dict[str, int] = {}
unattached_gb = 0
unassoc_eips = 0
stopped_instances = 0
    for blob in snap.get("ec2_by_region", {}).values():
for inst in blob.get("instances", []):
t = inst.get("instance_type") or "unknown"
by_type[t] = by_type.get(t, 0) + 1
if inst.get("state") == "stopped":
stopped_instances += 1
for vol in blob.get("volumes", []):
if not vol.get("attached_instance_id"):
unattached_gb += int(vol.get("size_gb") or 0)
for e in blob.get("eips", []):
if not e.get("association_id") and not e.get("instance_id"):
unassoc_eips += 1
out = {
"ec2_instance_type_counts": dict(sorted(by_type.items(), key=lambda kv: kv[0])),
"stopped_instances": stopped_instances,
"unattached_ebs_gb": unattached_gb,
"unassociated_eips": unassoc_eips,
"note": "Tier-1 signals only (derived from inventory).",
}
write_json(os.path.join(snapshot_dir(DATA_DIR, snapshot_id), "cost.json"), out)
return out
@mcp.tool
def cost_explorer_summary_tool(days: int = 30, granularity: str = "DAILY", auth: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
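    """Summarize Cost Explorer spend for the trailing `days` window at the requested granularity."""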
session = build_boto3_session(auth)
out = cost_explorer_summary(session, days=days, granularity=granularity)
return out
@mcp.tool
def cost_explorer_by_service(days: int = 30, top_n: int = 10, auth: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
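    """Break down Cost Explorer spend by service for the trailing window, returning the top_n services."""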
session = build_boto3_session(auth)
return cost_explorer_by_dimension(session, days=days, dimension="SERVICE", top_n=top_n)
@mcp.tool
def cost_explorer_by_region(days: int = 30, top_n: int = 10, auth: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
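    """Break down Cost Explorer spend by region for the trailing window, returning the top_n regions."""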
session = build_boto3_session(auth)
return cost_explorer_by_dimension(session, days=days, dimension="REGION", top_n=top_n)
@mcp.tool
def generate_report(snapshot_id: str, finding_set_id: Optional[str] = None, format: str = "md", auth: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
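    """Render a Markdown report (optionally also PDF) from a snapshot, its findings, and any cached cost data."""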
s_path = os.path.join(snapshot_dir(DATA_DIR, snapshot_id), "snapshot.json")
snap = read_json(s_path)
fsid = finding_set_id or snapshot_id
f_path = os.path.join(snapshot_dir(DATA_DIR, fsid), "findings.json")
findings = read_json(f_path) if os.path.exists(f_path) else []
cost_path = os.path.join(snapshot_dir(DATA_DIR, snapshot_id), "cost.json")
cost = read_json(cost_path) if os.path.exists(cost_path) else {}
    # Cost Explorer data in the report is best-effort and only attempted when explicit auth is provided.
cost_ce = {}
try:
if auth is not None:
session = build_boto3_session(auth)
cost_ce = cost_explorer_summary(session, days=30, granularity="DAILY")
except Exception as e:
cost_ce = {"error": str(e), "results": []}
md = render_markdown(snap, findings, cost, cost_ce)
out_dir = snapshot_dir(DATA_DIR, snapshot_id)
md_path = os.path.join(out_dir, "report.md")
with open(md_path, "w", encoding="utf-8") as f:
f.write(md)
if format.lower() == "pdf":
pdf_path = os.path.join(out_dir, "report.pdf")
write_pdf_from_text(pdf_path=pdf_path, title="AWS Audit Snapshot", text=md)
return {"report_md": md_path, "report_pdf": pdf_path}
return {"report_md": md_path}
def main() -> None:
# stdio transport by default (Claude Desktop compatible)
mcp.run()
if __name__ == "__main__":
main()