#!/usr/bin/env python3
"""
Azure Security Agent - LangChain + Groq + MCP
Uses LangGraph ReAct agent with ChatGroq and tools from the Azure Security MCP server.
"""
import asyncio
import os
import sys
import traceback
from contextlib import AsyncExitStack
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence
def _format_exception(e: BaseException) -> str:
    """Render an exception as a readable multi-line report.

    Meant for errors escaping anyio task groups, which arrive wrapped in a
    ``BaseExceptionGroup`` (builtin on 3.11+, or the ``exceptiongroup``
    backport). The report holds the exception's type and message, its
    traceback frames (if it was raised), and every sub-exception of a group.
    Nested groups are walked recursively. Each child is labelled with its
    index path, e.g. ``[1][0]`` is the first child of the second child.

    Args:
        e: The exception to describe.

    Returns:
        The report as one newline-joined string.
    """
    lines = [f"{type(e).__name__}: {e}", ""]
    if e.__traceback__:
        lines.append("".join(traceback.format_tb(e.__traceback__)))

    def _children(exc):
        subs = getattr(exc, "exceptions", None)
        # Only real sequences count as groups; any other "exceptions" attribute is ignored.
        return subs if isinstance(subs, (list, tuple)) else ()

    seen = {id(e)}  # guards against a pathological group that contains itself

    def _walk(exc, path: str) -> None:
        for i, sub in enumerate(_children(exc)):
            label = f"{path}[{i}]"
            lines.append(f" {label} {type(sub).__name__}: {sub}")
            if getattr(sub, "__traceback__", None):
                lines.append("".join(traceback.format_tb(sub.__traceback__)))
            if id(sub) not in seen:
                seen.add(id(sub))
                _walk(sub, label)

    if _children(e):
        lines.append("Sub-exceptions:")
        _walk(e, "")
    return "\n".join(lines)
from langchain_core.tools import BaseTool, StructuredTool
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_groq import ChatGroq
from langgraph.prebuilt import create_react_agent
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from pydantic import BaseModel, create_model
from dotenv import load_dotenv
# Load KEY=VALUE pairs from a .env file into os.environ at import time, so main()
# can read GROQ_API_KEY and AZURE_SUBSCRIPTION_ID from it via os.getenv.
load_dotenv()
def _json_schema_to_pydantic(name: str, schema: Dict[str, Any]) -> type[BaseModel]:
    """Build a Pydantic model for a tool's arguments from its MCP/JSON schema.

    Only the top-level ``properties`` object is translated:

    * Each property's JSON ``type`` maps to the matching Python type
      (``string``->``str``, ``integer``->``int``, ``number``->``float``,
      ``boolean``->``bool``, ``array``->``list``, ``object``->``dict``).
      A missing or unrecognised type becomes ``Any``. For a type list such
      as ``["string", "null"]``, the first non-``"null"`` entry is used.
    * The property's ``description`` is kept, so it shows up in the argument
      schema handed to the LLM.
    * Names listed in ``required`` are mandatory. All other fields are
      ``Optional`` and default to ``None``.

    Args:
        name: Tool name; the generated model is called ``f"{name}_Args"``.
        schema: The tool's input JSON schema.

    Returns:
        The generated model class. It has no fields when ``schema`` has no
        ``properties`` dict.

    Raises:
        Whatever ``pydantic.create_model`` raises for unusable field names.
    """
    from pydantic import Field  # local import: only this helper needs it

    json_to_python = {
        "string": str,
        "integer": int,
        "number": float,
        "boolean": bool,
        "array": list,
        "object": dict,
    }
    properties = schema.get("properties")
    if not isinstance(properties, dict):
        return create_model(f"{name}_Args")
    required = set(schema.get("required") or [])
    fields: Dict[str, Any] = {}
    for key, spec in properties.items():
        if not isinstance(spec, dict):
            spec = {}
        json_type = spec.get("type")
        if isinstance(json_type, list):
            # Nullable unions like ["string", "null"]: use the concrete member.
            json_type = next((t for t in json_type if t != "null"), None)
        py_type = json_to_python.get(json_type, Any) if isinstance(json_type, str) else Any
        description = spec.get("description")
        if not isinstance(description, str):
            description = None
        if key in required:
            fields[key] = (py_type, Field(..., description=description))
        else:
            fields[key] = (Optional[py_type], Field(None, description=description))
    return create_model(f"{name}_Args", **fields)
def mcp_tools_to_langchain(
    mcp_tools: Sequence[Any],
    execute_fn: Any,
) -> List[BaseTool]:
    """Wrap MCP tool descriptors as LangChain ``StructuredTool`` objects.

    Args:
        mcp_tools: Tool descriptors such as those from ``ClientSession.list_tools()``.
            Each one is read for ``name``, ``description`` and, optionally,
            ``inputSchema``.
        execute_fn: Coroutine function ``execute_fn(tool_name, args_dict) -> str``
            that actually runs a tool. Every wrapper awaits it.

    Returns:
        One async-only LangChain tool per MCP tool, in the same order.
    """

    def _make_runner(tool_name: str):
        # Factory so each coroutine captures its own tool_name. A closure over
        # the loop variable would late-bind and see only the last tool's name.
        async def _run(**kwargs: Any) -> str:
            return await execute_fn(tool_name, kwargs)

        return _run

    tools: List[BaseTool] = []
    for t in mcp_tools:
        name = t.name
        # inputSchema may be absent or None; both are treated as "no arguments".
        schema = getattr(t, "inputSchema", None) or {}
        try:
            args_schema = _json_schema_to_pydantic(name, schema)
        except Exception as exc:
            # Best effort: keep the tool, but report why its arguments are untyped
            # instead of failing silently.
            print(
                f"Warning: could not build argument schema for tool '{name}': {exc}",
                file=sys.stderr,
                flush=True,
            )
            args_schema = None
        tools.append(
            StructuredTool.from_function(
                name=name,
                description=t.description or "",
                func=None,
                coroutine=_make_runner(name),
                args_schema=args_schema,
            )
        )
    return tools
class AzureSecurityAgent:
    """Azure security analyst agent using LangChain (ChatGroq + LangGraph) and MCP tools.

    Lifecycle:

    1. ``connect_to_mcp_server`` opens an MCP ``ClientSession`` on a pair of
       transport streams and wraps every server tool as a LangChain tool.
    2. ``analyze_infrastructure`` (or the ``run_interactive_demo`` loop) sends
       queries through a LangGraph ReAct agent backed by ChatGroq.
    3. ``cleanup`` closes the session. It must be awaited while the transport
       that produced the streams is still open; ``main`` does this inside its
       ``stdio_client`` block.
    """

    def __init__(self, groq_api_key: str):
        """Store configuration; no network or MCP work happens here."""
        self.groq_api_key = groq_api_key
        self.model_name = "llama-3.3-70b-versatile"
        self.session: Optional[ClientSession] = None
        self._tools: List[BaseTool] = []
        self._graph = None
        # Subscription ID the cached graph was built for (its system prompt embeds it).
        self._graph_subscription_id: Optional[str] = None
        # Owns the entered ClientSession context so cleanup() can exit it.
        self._exit_stack = AsyncExitStack()

    async def connect_to_mcp_server(self, read_stream: Any, write_stream: Any) -> None:
        """Open an MCP session on the given streams and build LangChain tools from it.

        Any previously opened session is closed first. The new session stays
        open until ``cleanup`` is awaited.
        """
        await self.cleanup()
        # ClientSession must be entered as an async context manager. Entering it
        # starts the background task that reads replies from read_stream; without
        # it, initialize() waits forever for a response nobody reads.
        self.session = await self._exit_stack.enter_async_context(
            ClientSession(read_stream, write_stream)
        )
        await self.session.initialize()
        tools_list = await self.session.list_tools()
        self._tools = mcp_tools_to_langchain(tools_list.tools, self._execute_tool)
        # The tool set changed, so any cached agent graph is stale.
        self._graph = None
        self._graph_subscription_id = None
        print(f"✓ Connected to MCP server with {len(self._tools)} tools", flush=True)
        for t in self._tools:
            desc = (t.description or "")[:70]
            print(f" - {t.name}: {desc}", flush=True)

    @staticmethod
    def _tool_result_to_text(result: Any) -> str:
        """Join the content of an MCP ``call_tool`` result into one string.

        The ``text`` of every content item is kept, not just the first item's.
        Items without text (e.g. images) are replaced by a short placeholder
        instead of raising. When the result's ``isError`` flag is set, the text
        is prefixed so the agent can tell a failure from data.
        """
        pieces: List[str] = []
        for item in getattr(result, "content", None) or []:
            text = getattr(item, "text", None)
            if isinstance(text, str):
                pieces.append(text)
            else:
                pieces.append(f"[{getattr(item, 'type', 'non-text')} content omitted]")
        if not pieces:
            return "No result returned"
        joined = "\n".join(pieces)
        if getattr(result, "isError", False):
            return f"Tool reported an error: {joined}"
        return joined

    async def _execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
        """Call ``tool_name`` on the MCP server and return its output as text.

        Failures are returned as strings, never raised, so the agent can read
        them and carry on.
        """
        if not self.session:
            return "Error: not connected to MCP server."
        try:
            result = await self.session.call_tool(tool_name, arguments)
        except Exception as e:
            return f"Error executing tool: {str(e)}"
        return self._tool_result_to_text(result)

    def _build_agent(self, subscription_id: str):
        """Build the LangGraph ReAct agent for ``subscription_id`` and cache it."""
        llm = ChatGroq(
            api_key=self.groq_api_key,
            model=self.model_name,
            temperature=0.1,
            max_tokens=2000,
        )
        # NOTE: the prompt text is deliberately flush-left; indenting it would
        # change the prompt sent to the model.
        system_prompt = f"""You are an expert Azure cloud security analyst. Your job is to:
1. Discover and analyze Azure infrastructure for security issues
2. Use the available tools to gather information
3. Provide clear, actionable security recommendations
4. Follow CIS benchmarks and security best practices
Current subscription ID: {subscription_id}
Discovery workflow (use when user asks for "full scan", "discover", or does not provide resource names):
- First call azure_list_resource_groups to see what exists
- Then use azure_list_nsgs, azure_list_storage_accounts, azure_list_resources as needed
- Then run azure_check_nsg_rules on each NSG, azure_check_storage_security on each storage account, azure_list_public_ips for exposure
- For VMs found in resources, use azure_check_vm_security when relevant
When analyzing security:
- Check for open ports and permissive NSG rules
- Verify encryption settings
- Look for public access configurations
- Review identity and access management
- Provide severity ratings (CRITICAL, HIGH, MEDIUM, LOW)
Be thorough but concise. If doing a full discovery, analyze at least one resource group end-to-end and summarize findings."""
        self._graph = create_react_agent(
            llm,
            self._tools,
            prompt=system_prompt,
        )
        self._graph_subscription_id = subscription_id

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten LangChain message content into plain text.

        ``content`` is either a string or a list of parts. A part may be a
        string, a dict with a ``"text"`` key, or an object with a ``.text``
        attribute. Parts without text are rendered with ``str()``.
        """
        if not isinstance(content, list):
            return content if isinstance(content, str) else str(content)
        pieces: List[str] = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
                continue
            text = part.get("text") if isinstance(part, dict) else getattr(part, "text", None)
            pieces.append(text if isinstance(text, str) else str(part))
        return "\n".join(pieces)

    async def analyze_infrastructure(self, user_query: str, subscription_id: str) -> str:
        """Run the agent on ``user_query`` and return the text of its final message.

        The graph is rebuilt only when the subscription ID changes (it is part
        of the system prompt) or the tools were reloaded. Agent failures are
        returned as an ``"Agent error: ..."`` string.
        """
        if self._graph is None or self._graph_subscription_id != subscription_id:
            self._build_agent(subscription_id)
        print(f"\n🤖 Agent analyzing: {user_query}\n", flush=True)
        inputs = {"messages": [HumanMessage(content=user_query)]}
        try:
            result = await self._graph.ainvoke(inputs)
            messages = result.get("messages", [])
            if not messages:
                return "No response from agent."
            last = messages[-1]
            content = getattr(last, "content", None) or str(last)
            text = self._content_to_text(content)
            print(f"\n{'='*60}\nFINAL ANALYSIS\n{'='*60}\n", flush=True)
            return text
        except Exception as e:
            return f"Agent error: {str(e)}"

    async def run_interactive_demo(self, subscription_id: str) -> None:
        """Interactive menu loop that turns menu choices into agent queries.

        Exits on ``exit``, Ctrl+C, or end of input (Ctrl+D / closed stdin).
        Any other error is printed and the loop continues.
        """
        print("\n" + "=" * 60, flush=True)
        print("AZURE SECURITY AGENT (LangChain + Groq + MCP)", flush=True)
        print("=" * 60, flush=True)
        print(f"\nSubscription ID: {subscription_id}", flush=True)
        print("\nAvailable commands:", flush=True)
        print(" 1. Analyze a specific resource group", flush=True)
        print(" 2. Check all NSG rules in a resource group", flush=True)
        print(" 3. Audit storage accounts", flush=True)
        print(" 4. List all public IPs", flush=True)
        print(" 5. Custom query", flush=True)
        print(" 6. Full discovery and security scan", flush=True)
        print(" 'exit' to quit", flush=True)
        print("\n" + "=" * 60 + "\n", flush=True)
        while True:
            try:
                choice = input("Enter command (1-6 or 'exit'): ").strip()
                if choice.lower() == "exit":
                    print("\nExiting agent...", flush=True)
                    break
                query = None
                if choice == "1":
                    rg = input("Enter resource group name: ").strip()
                    query = f"List and analyze all resources in the resource group '{rg}' for security issues"
                elif choice == "2":
                    rg = input("Enter resource group name: ").strip()
                    nsg = input("Enter NSG name: ").strip()
                    query = f"Analyze the NSG '{nsg}' in resource group '{rg}' and report all security issues"
                elif choice == "3":
                    rg = input("Enter resource group name: ").strip()
                    storage = input("Enter storage account name: ").strip()
                    query = f"Audit the security configuration of storage account '{storage}' in resource group '{rg}'"
                elif choice == "4":
                    query = "List all public IP addresses in my subscription and identify any security concerns"
                elif choice == "5":
                    query = input("Enter your security query: ").strip()
                elif choice == "6":
                    query = (
                        "Perform a full discovery and security scan of my Azure subscription. "
                        "First list resource groups, then for each resource group (or the first one if many) "
                        "discover NSGs, storage accounts, and public IPs. Run security checks on each NSG and "
                        "storage account found. Summarize all security issues with severity and recommendations."
                    )
                else:
                    print("Invalid choice. Please try again.\n", flush=True)
                    continue
                if query:
                    result = await self.analyze_infrastructure(query, subscription_id)
                    print(result, flush=True)
                    print("\n" + "-" * 60 + "\n", flush=True)
            except (KeyboardInterrupt, EOFError):
                # EOFError means stdin is closed. Without exiting here, every later
                # input() would raise again and the generic handler would loop forever.
                print("\n\nExiting agent...", flush=True)
                break
            except Exception as e:
                print(f"\nError: {str(e)}\n", flush=True)

    async def cleanup(self) -> None:
        """Close the MCP session if one is open, then forget it. Safe to call repeatedly."""
        try:
            await self._exit_stack.aclose()
        finally:
            self._exit_stack = AsyncExitStack()
            self.session = None


async def main() -> None:
    """Read configuration, start server.py over stdio, and run the interactive agent."""
    print("Azure Security Agent — starting (LangChain)...", flush=True)
    groq_api_key = os.getenv("GROQ_API_KEY")
    azure_subscription_id = os.getenv("AZURE_SUBSCRIPTION_ID")
    if not groq_api_key:
        print("Error: GROQ_API_KEY not found in environment variables")
        print("Please create a .env file with: GROQ_API_KEY=your_api_key_here")
        return
    if not azure_subscription_id:
        print("Warning: AZURE_SUBSCRIPTION_ID not found in .env")
        try:
            azure_subscription_id = input("Enter your Azure Subscription ID: ").strip()
        except EOFError:
            azure_subscription_id = ""
        if not azure_subscription_id:
            print("Error: an Azure Subscription ID is required")
            return
    agent = AzureSecurityAgent(groq_api_key)
    server_script = Path(__file__).resolve().parent / "server.py"
    if not server_script.is_file():
        print(f"Error: MCP server script not found: {server_script}", file=sys.stderr, flush=True)
        return
    server_params = StdioServerParameters(
        command=sys.executable,
        args=[str(server_script)],
        env=os.environ.copy(),
    )
    print("Starting MCP server (server.py)...", flush=True)
    try:
        async with stdio_client(server_params) as (read_stream, write_stream):
            try:
                print("Connecting to MCP server...", flush=True)
                await agent.connect_to_mcp_server(read_stream, write_stream)
                await agent.run_interactive_demo(azure_subscription_id)
            finally:
                # Close the MCP session while the stdio transport is still open.
                # Closing it after stdio_client exits would unwind the nested
                # anyio task groups out of order.
                await agent.cleanup()
    except BaseException as e:
        msg = _format_exception(e)
        print("\n" + msg, file=sys.stderr, flush=True)
        traceback.print_exc(file=sys.stderr)


if __name__ == "__main__":
    # Script entry point: run the async main() on a fresh event loop.
    asyncio.run(main())