# server.py (10.9 kB)
import asyncio
from typing import Dict, List

import requests
from mcp.server.fastmcp import FastMCP

from config import ConfigManager
from observable import ObservableType
from services import (
    EnrichmentService,
    VirusTotalIP,
    VirusTotalDomain,
    AlienVaultIP,
    AlienVaultDomain,
    ShodanIP,
    ShodanDomain,
    HybridAnalysisIP,
    HybridAnalysisDomain,
    UrlscanIP,
    UrlscanDomain,
    UrlscanUrl,
    AbuseIPDBIP,
    HaveIBeenPwnedEmail,
)
# Module-wide settings, read via ConfigManager from the config file or from
# environment variables. The lookup tools below iterate its
# `enrichments.lookups.*` source lists.
config = ConfigManager().load()

# The FastMCP server instance that every @mcp.prompt / @mcp.tool below
# registers itself on.
mcp = FastMCP(name="enrichment-mcp")
# Registry of the enrichment service classes available for each observable
# type. The values are the classes themselves (not instances), so the element
# type is type[EnrichmentService].
# NOTE(review): the lookup tools in the visible code dispatch on source.name
# and do not read this mapping; confirm whether anything else relies on it.
ALL_SERVICES: Dict[ObservableType, List[type[EnrichmentService]]] = {
    ObservableType.IPV4: [
        AbuseIPDBIP,
        AlienVaultIP,
        HybridAnalysisIP,
        ShodanIP,
        UrlscanIP,
        VirusTotalIP,
    ],
    ObservableType.DOMAIN: [
        AlienVaultDomain,
        HybridAnalysisDomain,
        ShodanDomain,
        UrlscanDomain,
        VirusTotalDomain,
    ],
    ObservableType.URL: [
        UrlscanUrl,
    ],
    ObservableType.EMAIL: [
        HaveIBeenPwnedEmail,
    ],
}
@mcp.prompt(
    name="enrichment-prompt",
    description="Standard prompt string to support enrichment of provided observable(s)."
)
def setup_prompt(observable: str) -> str:
    """Build the analyst prompt asking for a risk determination of ``observable``.

    Args:
        observable: The observable value(s), interpolated verbatim at the end
            of the prompt.

    Returns:
        The prompt text. It begins and ends with a newline because of the
        triple-quoted literal.
    """
    return f"""
As a security analyst, detection engineer and network security engineer, you are responsible for making a risk level determination for one or more provided observables.
Draw on your knowledge of these diverse fields, including networking constructs and detection (security) reasoning, as well as the responses from third-party enrichment services.
Carefully consider the output from these services, along with historical knowledge both internal and external to the organization, to determine the risk of the provided
observable. Based on all of these factors, determine whether the observable is benign, suspicious, malicious or unknown. If it is unknown, suggest other relevant context
that may be needed in order to make the determination.
Your objective is to assist with the threat determination of a given observable. The observable is {observable}
"""
@mcp.tool(
    name="lookup-observable",
    description="A generic tool which takes any observable and passes it to the correct tool."
)
async def lookup(value: str) -> str:
    """Classify ``value`` and delegate to the matching typed lookup tool.

    The type comes from ``ObservableType.from_observable_value``. IPv4,
    domain, URL and email observables are forwarded to ``lookup_ipaddress``,
    ``lookup_domain``, ``lookup_url`` and ``lookup_email`` respectively.

    Args:
        value: The raw observable string.

    Returns:
        Whatever the delegate returns, or ``""`` when the observable type has
        no lookup tool.

    NOTE(review): the delegates return dicts, so the ``str`` annotation is
    inaccurate. It is left unchanged because FastMCP may derive the tool's
    output schema from it.
    """
    # Resolved at call time, so the tools defined later in the module are in
    # scope. @mcp.tool returns the undecorated coroutine function.
    handlers = {
        ObservableType.IPV4: lookup_ipaddress,
        ObservableType.DOMAIN: lookup_domain,
        ObservableType.URL: lookup_url,
        ObservableType.EMAIL: lookup_email,
    }
    handler = handlers.get(ObservableType.from_observable_value(value))
    if handler is None:
        return ""
    return await handler(value)
@mcp.tool(
    name="lookup-ipaddress",
    description="Performs third-party enrichment lookup for the provided ipv4 address."
)
async def lookup_ipaddress(ipaddress: str) -> Dict[str, str]:
    """Enrich an IPv4 address against every configured IP lookup source.

    Each entry of ``config.enrichments.lookups.ipaddress`` is matched by
    ``source.name`` to a service class. The service builds the request, the
    source's API key is attached, and the parsed response (rendered with
    ``source.template``) is stored under the source name. A non-OK status,
    timeout or connection error stores an error string for that source
    instead, so one failing provider does not abort the others. Sources with
    an unrecognised name are skipped.

    Args:
        ipaddress: The observable to enrich.

    Returns:
        Mapping of source name to parsed result or error string, or ``""``
        when ``ipaddress`` is not classified as IPv4.
    """
    if ObservableType.from_observable_value(ipaddress) != ObservableType.IPV4:
        # NOTE(review): "" does not match the Dict annotation; kept for
        # backward compatibility with existing callers.
        return ""
    # Upper bound per provider request so an unresponsive API cannot hang
    # the tool forever.
    timeout_seconds = 30
    responses: Dict[str, str] = {}
    # One session for all providers, closed deterministically on exit.
    with requests.Session() as session:
        for source in config.enrichments.lookups.ipaddress:
            error_string: str = f"error occurred looking up ip {ipaddress} in {source.name}"
            if source.name == "virustotal":
                service = VirusTotalIP()
                req = service.get(ipaddress=ipaddress)
                req.headers["x-apikey"] = source.apikey
            elif source.name == "alienvault":
                service = AlienVaultIP()
                req = service.get(ipaddress=ipaddress)
                req.headers["X-OTX-API-KEY"] = source.apikey
            elif source.name == "shodan":
                service = ShodanIP()
                req = service.get(ipaddress=ipaddress)
                # The key is appended to the URL rather than sent as a header.
                # Presumably ShodanIP.get() returns a URL ending in a `key=`
                # placeholder; confirm.
                req.url = f"{req.url}{source.apikey}"
            elif source.name == "hybridanalysis":
                service = HybridAnalysisIP()
                req = service.get(ipaddress=ipaddress)
                # NOTE(review): confirm the header name expected by the
                # Hybrid Analysis API ("api_key" vs "api-key").
                req.headers["api_key"] = source.apikey
                # Form-encode the body and update Content-Length/Content-Type.
                # Assigning a raw dict to a prepared request's body does not
                # encode it.
                req.prepare_body(data={"host": ipaddress}, files=None)
            elif source.name == "urlscan":
                service = UrlscanIP()
                req = service.get(ipaddress=ipaddress)
                req.headers["API-Key"] = source.apikey
            elif source.name == "abuseipdb":
                service = AbuseIPDBIP()
                req = service.get(ipaddress=ipaddress)
                req.headers["Key"] = source.apikey
            else:
                continue
            try:
                # requests is blocking; run it in a worker thread so the
                # server's event loop keeps serving other messages.
                resp = await asyncio.to_thread(session.send, req, timeout=timeout_seconds)
            except requests.RequestException:
                responses[source.name] = error_string
                continue
            if resp.ok:
                responses[source.name] = service.parse_response(resp, source.template)
            else:
                responses[source.name] = error_string
    return responses
@mcp.tool(
    name="lookup-domain",
    description="Performs third-party enrichment lookup for the provided domain."
)
async def lookup_domain(domain: str) -> requests.Response:
    """Enrich a domain against every configured domain lookup source.

    Each entry of ``config.enrichments.lookups.domain`` is matched by
    ``source.name`` to a service class. The service builds the request, the
    source's API key is attached, and the parsed response (rendered with
    ``source.template``) is stored under the source name. A non-OK status,
    timeout or connection error stores an error string for that source
    instead. Sources with an unrecognised name are skipped.

    Args:
        domain: The observable to enrich.

    Returns:
        Mapping of source name to parsed result or error string, or ``""``
        when ``domain`` is not classified as a domain.

    NOTE(review): the ``requests.Response`` annotation does not match the
    returned dict. It is left unchanged because FastMCP may derive the tool
    schema from it.
    """
    if ObservableType.from_observable_value(domain) != ObservableType.DOMAIN:
        return ""
    # Upper bound per provider request so an unresponsive API cannot hang
    # the tool forever.
    timeout_seconds = 30
    responses: Dict[str, object] = {}
    # One session for all providers, closed deterministically on exit.
    with requests.Session() as session:
        for source in config.enrichments.lookups.domain:
            error_string: str = f"error occurred looking up domain {domain} in {source.name}"
            if source.name == "virustotal":
                service = VirusTotalDomain()
                req = service.get(domain=domain)
                req.headers["x-apikey"] = source.apikey
            elif source.name == "alienvault":
                service = AlienVaultDomain()
                req = service.get(domain=domain)
                # OTX authenticates with X-OTX-API-KEY, the same header the
                # IP lookup uses; "x-apikey" is VirusTotal's header.
                req.headers["X-OTX-API-KEY"] = source.apikey
            elif source.name == "shodan":
                service = ShodanDomain()
                req = service.get(domain=domain)
                # Shodan takes its key as a query parameter, not a header, so
                # append it to the URL as lookup_ipaddress does for ShodanIP.
                # NOTE(review): assumes ShodanDomain.get() returns a URL ending
                # in a `key=` placeholder like ShodanIP; confirm.
                req.url = f"{req.url}{source.apikey}"
            elif source.name == "hybridanalysis":
                service = HybridAnalysisDomain()
                req = service.get(domain=domain)
                # NOTE(review): confirm the header name expected by the
                # Hybrid Analysis API ("api_key" vs "api-key").
                req.headers["api_key"] = source.apikey
                # Form-encode the body and update Content-Length/Content-Type.
                # Assigning a raw dict to a prepared request's body does not
                # encode it.
                req.prepare_body(data={"domain": domain}, files=None)
            elif source.name == "urlscan":
                service = UrlscanDomain()
                req = service.get(domain=domain)
                req.headers["API-Key"] = source.apikey
            else:
                continue
            try:
                # requests is blocking; run it in a worker thread so the
                # server's event loop keeps serving other messages.
                resp = await asyncio.to_thread(session.send, req, timeout=timeout_seconds)
            except requests.RequestException:
                responses[source.name] = error_string
                continue
            if resp.ok:
                responses[source.name] = service.parse_response(resp, source.template)
            else:
                responses[source.name] = error_string
    return responses
@mcp.tool(
    name="lookup-url",
    description="Performs third-party enrichment lookup for the provided URL."
)
async def lookup_url(url: str) -> requests.Response:
    """Enrich a URL against every configured URL lookup source.

    Only the ``urlscan`` source is handled; other configured sources are
    skipped. A non-OK status, timeout or connection error stores an error
    string for that source instead of a parsed result.

    Args:
        url: The observable to enrich.

    Returns:
        Mapping of source name to parsed result or error string, or ``""``
        when ``url`` is not classified as a URL.

    NOTE(review): the ``requests.Response`` annotation does not match the
    returned dict. It is left unchanged because FastMCP may derive the tool
    schema from it.
    """
    if ObservableType.from_observable_value(url) != ObservableType.URL:
        return ""
    # Upper bound per provider request so an unresponsive API cannot hang
    # the tool forever.
    timeout_seconds = 30
    responses: Dict[str, object] = {}
    # One session for all providers, closed deterministically on exit.
    with requests.Session() as session:
        for source in config.enrichments.lookups.url:
            if source.name != "urlscan":
                continue
            error_string: str = f"error occurred looking up url {url} in {source.name}"
            urlscan = UrlscanUrl()
            req = urlscan.get(url=url)
            req.headers["API-Key"] = source.apikey
            try:
                # requests is blocking; run it in a worker thread so the
                # server's event loop keeps serving other messages.
                resp = await asyncio.to_thread(session.send, req, timeout=timeout_seconds)
            except requests.RequestException:
                responses[source.name] = error_string
                continue
            if resp.ok:
                responses[source.name] = urlscan.parse_response(resp, source.template)
            else:
                responses[source.name] = error_string
    return responses
@mcp.tool(
    name="lookup-email",
    description="Performs third-party enrichment lookup for the provided email address."
)
async def lookup_email(email: str) -> requests.Response:
    """Enrich an email address against every configured email lookup source.

    Only the ``hibp`` (Have I Been Pwned) source is handled; other configured
    sources are skipped. A non-OK status, timeout or connection error stores
    an error string for that source instead of a parsed result.

    Args:
        email: The observable to enrich.

    Returns:
        Mapping of source name to parsed result or error string, or ``""``
        when ``email`` is not classified as an email address.

    NOTE(review): HIBP's breached-account endpoint answers 404 when an
    account appears in no breach. That case currently surfaces as the error
    string; confirm against the endpoint HaveIBeenPwnedEmail.get() uses. The
    ``requests.Response`` annotation does not match the returned dict and is
    left unchanged because FastMCP may derive the tool schema from it.
    """
    if ObservableType.from_observable_value(email) != ObservableType.EMAIL:
        return ""
    # Upper bound per provider request so an unresponsive API cannot hang
    # the tool forever.
    timeout_seconds = 30
    responses: Dict[str, object] = {}
    # One session for all providers, closed deterministically on exit.
    with requests.Session() as session:
        for source in config.enrichments.lookups.email:
            if source.name != "hibp":
                continue
            error_string: str = f"error occurred looking up email {email} in {source.name}"
            hibp = HaveIBeenPwnedEmail()
            req = hibp.get(email=email)
            req.headers["hibp-api-key"] = source.apikey
            try:
                # requests is blocking; run it in a worker thread so the
                # server's event loop keeps serving other messages.
                resp = await asyncio.to_thread(session.send, req, timeout=timeout_seconds)
            except requests.RequestException:
                responses[source.name] = error_string
                continue
            if resp.ok:
                responses[source.name] = hibp.parse_response(resp, source.template)
            else:
                responses[source.name] = error_string
    return responses
if __name__ == "__main__":
    # Script entry point.
    #   python server.py              -> serve the MCP tools over stdio
    #   python server.py <observable> -> one-off lookup printed to stdout
    #                                    (local debugging)
    # API keys must come from the config/environment, never from source code.
    import asyncio
    import sys

    if len(sys.argv) > 1:
        result = asyncio.run(lookup(sys.argv[1]))
        print(f"result: {result}")
    else:
        mcp.run(transport="stdio")