"""NetBox API client with Vault token integration."""
import logging
import time
from itertools import islice
from typing import Any, Dict, List, Optional

import pynetbox
from pynetbox.core.response import Record
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from src.vault_client import VaultClient
# Module-level logger; the NetBoxClient methods below report API and token
# failures through it.
logger: logging.Logger = logging.getLogger(__name__)
class NetBoxClient:
    """Client for the NetBox REST API with Vault-based authentication.

    The API token is minted through ``VaultClient`` on first use. Public
    read methods convert pynetbox ``Record`` objects into plain dictionaries
    and, on API errors, log the error and return an empty result (``[]`` or
    ``None``) instead of raising.
    """

    class _TimeoutHTTPAdapter(HTTPAdapter):
        """HTTP adapter that supplies a default timeout to every request.

        ``requests.Session`` has no session-wide timeout setting (assigning a
        ``timeout`` attribute on the session is silently ignored), so the
        default is injected here whenever no explicit timeout was given.
        """

        def __init__(
            self,
            *args: Any,
            default_timeout: Optional[float] = None,
            **kwargs: Any,
        ) -> None:
            self.default_timeout = default_timeout
            super().__init__(*args, **kwargs)

        def send(
            self,
            request: Any,
            stream: bool = False,
            timeout: Any = None,
            verify: Any = True,
            cert: Any = None,
            proxies: Any = None,
        ) -> Any:
            """Send ``request``, using ``default_timeout`` if none was given."""
            if timeout is None:
                timeout = self.default_timeout
            return super().send(
                request,
                stream=stream,
                timeout=timeout,
                verify=verify,
                cert=cert,
                proxies=proxies,
            )

    def __init__(
        self, netbox_url: str, vault_client: VaultClient, vault_path: str
    ):
        """Initialize NetBox client.

        No request is made here; the token is minted lazily on first use.

        Args:
            netbox_url: NetBox API URL
            vault_client: VaultClient instance for token minting
            vault_path: Vault path for NetBox tokens
        """
        self.netbox_url = netbox_url
        self.vault_client = vault_client
        self.vault_path = vault_path
        self._api: Optional[pynetbox.api] = None
        self._token: Optional[str] = None
        # In-memory result cache: cache key -> (result, time it was stored).
        self._cache: Dict[str, Any] = {}
        self._cache_ttl = 300  # seconds an entry stays valid
        self._last_cache_cleanup = time.time()
        self._session_config = self._create_session_config()

    def _create_session_config(self) -> Dict[str, Any]:
        """Build the HTTP transport settings applied to the pynetbox session.

        Returns:
            Dictionary with ``"adapter"`` (HTTP adapter with retries,
            connection pooling and a default request timeout), ``"timeout"``
            (seconds) and ``"verify"`` (TLS certificate verification flag).
        """
        timeout = 30  # seconds; enforced per request by the adapter
        # Retry transient failures with exponential backoff. POST is left
        # out on purpose: retrying a non-idempotent request after a 5xx
        # could apply the same write twice.
        retry_strategy = Retry(
            total=3,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=[
                "HEAD",
                "GET",
                "PUT",
                "DELETE",
                "OPTIONS",
                "TRACE",
            ],
        )
        adapter = self._TimeoutHTTPAdapter(
            max_retries=retry_strategy,
            pool_connections=10,  # number of connection pools to cache
            pool_maxsize=20,  # maximum connections kept per pool
            pool_block=False,  # open extra connections instead of blocking
            default_timeout=timeout,
        )
        return {
            "adapter": adapter,
            "timeout": timeout,
            "verify": True,  # SSL certificate verification
        }

    def _clean_cache(self) -> None:
        """Drop expired cache entries, sweeping at most once per 60 seconds."""
        now = time.time()
        if now - self._last_cache_cleanup <= 60:
            return
        expired_keys = [
            key
            for key, (_, stored_at) in self._cache.items()
            if now - stored_at > self._cache_ttl
        ]
        for key in expired_keys:
            del self._cache[key]
        self._last_cache_cleanup = now

    def _get_cache_key(self, method: str, **kwargs) -> str:
        """Build a cache key from the method name and its keyword arguments.

        Keyword arguments are sorted so that argument order does not matter.
        """
        sorted_kwargs = sorted(kwargs.items())
        return f"{method}:{str(sorted_kwargs)}"

    def _ensure_token(self) -> bool:
        """Mint a NetBox API token through Vault and store it on the client.

        Returns:
            True if a token was obtained, False otherwise.
        """
        token = self.vault_client.mint_netbox_token(self.vault_path)
        if token:
            self._token = token
            return True
        return False

    def _get_api(self) -> Optional[pynetbox.api]:
        """Return a pynetbox API instance configured with the tuned session.

        The instance is created on first use (and whenever the stored token
        differs from the one the instance was built with).

        NOTE(review): the token is minted only while ``self._token`` is
        unset; nothing here re-mints it if NetBox later rejects it (e.g. on
        expiry). Confirm the minted token outlives the client.

        Returns:
            pynetbox API instance, or None if no token could be obtained.
        """
        if not self._token and not self._ensure_token():
            logger.error("Failed to obtain NetBox API token")
            return None
        if self._api is None or self._api.token != self._token:
            self._api = pynetbox.api(self.netbox_url, token=self._token)
            session = self._api.http_session
            adapter = self._session_config["adapter"]
            # The adapter carries retries, pooling and the request timeout.
            session.mount("http://", adapter)
            session.mount("https://", adapter)
            session.verify = self._session_config["verify"]
        return self._api

    @staticmethod
    def _summarize_record(record: Record) -> Dict[str, Any]:
        """Summarize a nested record as ``{"id", "display", "url"}``.

        pynetbox wraps every nested JSON object in a ``Record``, including
        objects without their own URL (e.g. choice fields such as
        ``status``), which have no ``id``. Those are returned in full
        dictionary form instead of being summarized.
        """
        url = getattr(record, "url", None)
        if url is None:
            return dict(record)
        return {
            "id": getattr(record, "id", None),
            "display": str(record),
            "url": url,
        }

    def _record_to_dict(self, record: Record) -> Dict[str, Any]:
        """Convert a NetBox Record to a dictionary.

        Top-level fields are copied as-is. Nested records (e.g. device.rack)
        and lists of records are reduced to short summaries.

        Args:
            record: NetBox Record object

        Returns:
            Dictionary representation of the record ({} for a falsy record).
        """
        if not record:
            return {}
        data: Dict[str, Any] = {}
        # A Record has no dict-style items(); iterating it yields
        # (field, value) pairs. The raw attribute is read back so nested
        # Records can be summarized rather than fully expanded.
        for key, _ in record:
            value = getattr(record, key)
            if isinstance(value, Record):
                data[key] = self._summarize_record(value)
            elif isinstance(value, list):
                data[key] = [
                    self._summarize_record(item)
                    if isinstance(item, Record)
                    else item
                    for item in value
                ]
            else:
                data[key] = value
        return data

    def _fetch_records(
        self, endpoint: Any, filters: Dict[str, Any], limit: int
    ) -> List[Dict[str, Any]]:
        """Query ``endpoint`` and convert at most ``limit`` records to dicts.

        pynetbox's ``Endpoint.filter()`` rejects an empty filter set, and the
        ``RecordSet`` returned by ``all()``/``filter()`` is an iterator that
        supports neither ``filter()`` nor slicing. The query is therefore
        issued in one call and truncated with ``islice``, which also lets
        iteration stop early.

        Args:
            endpoint: pynetbox endpoint, e.g. ``api.dcim.devices``
            filters: NetBox query filters; empty means all records
            limit: Maximum number of records (negative values mean 0)

        Returns:
            List of record dictionaries. Exceptions raised by pynetbox
            propagate to the caller.
        """
        records = endpoint.filter(**filters) if filters else endpoint.all()
        return [
            self._record_to_dict(record)
            for record in islice(records, max(limit, 0))
        ]

    def _get_cached_result(
        self, method: str, **kwargs
    ) -> Optional[List[Dict[str, Any]]]:
        """Return a cached, unexpired result for this call, or None."""
        self._clean_cache()
        cache_key = self._get_cache_key(method, **kwargs)
        if cache_key in self._cache:
            result, timestamp = self._cache[cache_key]
            if time.time() - timestamp < self._cache_ttl:
                logger.debug("Cache hit for %s", method)
                return result
            del self._cache[cache_key]
        return None

    def _cache_result(
        self, method: str, result: List[Dict[str, Any]], **kwargs
    ) -> None:
        """Store ``result`` under the cache key for this call."""
        cache_key = self._get_cache_key(method, **kwargs)
        self._cache[cache_key] = (result, time.time())

    # Device (Host) Methods

    def list_devices(
        self,
        name: Optional[str] = None,
        primary_ip: Optional[str] = None,
        role: Optional[str] = None,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """List devices (hosts) from NetBox, caching successful results.

        Args:
            name: Filter by device name (case-insensitive partial match)
            primary_ip: Filter by primary IPv4 address
            role: Filter by device role slug
            limit: Maximum number of results

        Returns:
            List of device dictionaries ([] on error; errors are not cached).
        """
        call_args = {
            "name": name,
            "primary_ip": primary_ip,
            "role": role,
            "limit": limit,
        }
        cached_result = self._get_cached_result("list_devices", **call_args)
        if cached_result is not None:
            return cached_result

        api = self._get_api()
        if not api:
            return []

        filters: Dict[str, Any] = {}
        if name:
            filters["name__ic"] = name
        if primary_ip:
            # NOTE(review): confirm the deployed NetBox accepts this filter;
            # NetBox ignores unknown query parameters, which would return
            # unfiltered devices.
            filters["primary_ip4__address"] = primary_ip
        if role:
            # NetBox's device filter for a role slug is `role`.
            filters["role"] = role

        try:
            results = self._fetch_records(api.dcim.devices, filters, limit)
        except Exception as e:
            logger.error("Error listing devices: %s", e)
            return []

        self._cache_result("list_devices", results, **call_args)
        return results

    def get_device(self, name: str) -> Optional[Dict[str, Any]]:
        """Get a specific device by name.

        Args:
            name: Device name

        Returns:
            Device dictionary, or None if not found or on error (including
            pynetbox's error when more than one device matches).
        """
        api = self._get_api()
        if not api:
            return None
        try:
            device = api.dcim.devices.get(name=name)
            return self._record_to_dict(device) if device else None
        except Exception as e:
            logger.error("Error getting device %s: %s", name, e)
            return None

    def search_devices(
        self, query: str, limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Search devices by name or IP address.

        Name matches (case-insensitive substring) come first. If ``query``
        contains "." or ":", it is also treated as an IP address, and the
        devices owning a ``dcim.interface`` with that address are appended.

        Args:
            query: Search query
            limit: Maximum number of results

        Returns:
            List of matching device dictionaries ([] on error).
        """
        api = self._get_api()
        if not api:
            return []
        max_results = max(limit, 0)
        try:
            results: List[Dict[str, Any]] = []
            seen_ids = set()
            name_matches = api.dcim.devices.filter(name__ic=query)
            for device in islice(name_matches, max_results):
                if device.id not in seen_ids:
                    seen_ids.add(device.id)
                    results.append(self._record_to_dict(device))

            looks_like_ip = "." in query or ":" in query
            if looks_like_ip and len(results) < max_results:
                for ip_addr in api.ipam.ip_addresses.filter(address=query):
                    if ip_addr.assigned_object_type != "dcim.interface":
                        continue
                    nested_device = getattr(
                        ip_addr.assigned_object, "device", None
                    )
                    if nested_device is None or nested_device.id in seen_ids:
                        continue
                    seen_ids.add(nested_device.id)
                    # The nested device carries only brief fields; fetch
                    # the full record so IP hits look like name hits.
                    device = api.dcim.devices.get(nested_device.id)
                    if device:
                        results.append(self._record_to_dict(device))
            return results[:max_results]
        except Exception as e:
            logger.error("Error searching devices: %s", e)
            return []

    # Virtual Machine Methods

    def list_virtual_machines(
        self,
        name: Optional[str] = None,
        role: Optional[str] = None,
        primary_ip: Optional[str] = None,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """List virtual machines from NetBox.

        Args:
            name: Filter by VM name (case-insensitive partial match)
            role: Filter by VM role slug
            primary_ip: Filter by primary IPv4 address
            limit: Maximum number of results

        Returns:
            List of VM dictionaries ([] on error).
        """
        api = self._get_api()
        if not api:
            return []

        filters: Dict[str, Any] = {}
        if name:
            filters["name__ic"] = name
        if role:
            # NetBox's virtual-machine filter for a role slug is `role`.
            filters["role"] = role
        if primary_ip:
            # NOTE(review): confirm the deployed NetBox accepts this filter;
            # unknown query parameters are ignored by NetBox.
            filters["primary_ip4__address"] = primary_ip

        try:
            return self._fetch_records(
                api.virtualization.virtual_machines, filters, limit
            )
        except Exception as e:
            logger.error("Error listing virtual machines: %s", e)
            return []

    def get_virtual_machine(self, name: str) -> Optional[Dict[str, Any]]:
        """Get a specific virtual machine by name.

        Args:
            name: VM name

        Returns:
            VM dictionary, or None if not found or on error.
        """
        api = self._get_api()
        if not api:
            return None
        try:
            vm = api.virtualization.virtual_machines.get(name=name)
            return self._record_to_dict(vm) if vm else None
        except Exception as e:
            logger.error("Error getting VM %s: %s", name, e)
            return None

    def list_vm_interfaces(self, vm_name: str) -> List[Dict[str, Any]]:
        """List network interfaces of a VM, each with its IP addresses.

        Interfaces and addresses are looked up through their own endpoints,
        because VM and interface records do not carry those relations.

        Args:
            vm_name: Virtual machine name

        Returns:
            List of interface dictionaries, each with an added
            ``"ip_addresses"`` list ([] if the VM is missing or on error).
        """
        api = self._get_api()
        if not api:
            return []
        try:
            vm = api.virtualization.virtual_machines.get(name=vm_name)
            if not vm:
                return []
            interfaces = []
            vm_interfaces = api.virtualization.interfaces.filter(
                virtual_machine_id=vm.id
            )
            for interface in vm_interfaces:
                interface_data = self._record_to_dict(interface)
                interface_data["ip_addresses"] = [
                    self._record_to_dict(ip_addr)
                    for ip_addr in api.ipam.ip_addresses.filter(
                        vminterface_id=interface.id
                    )
                ]
                interfaces.append(interface_data)
            return interfaces
        except Exception as e:
            logger.error("Error listing VM interfaces for %s: %s", vm_name, e)
            return []

    # IP Address Methods

    def list_ip_addresses(
        self,
        address: Optional[str] = None,
        device: Optional[str] = None,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """List IP addresses from NetBox.

        Args:
            address: Filter by IP address
            device: Filter by device name
            limit: Maximum number of results

        Returns:
            List of IP address dictionaries ([] on error).
        """
        api = self._get_api()
        if not api:
            return []

        filters: Dict[str, Any] = {}
        if address:
            filters["address"] = address
        if device:
            filters["device"] = device

        try:
            return self._fetch_records(api.ipam.ip_addresses, filters, limit)
        except Exception as e:
            logger.error("Error listing IP addresses: %s", e)
            return []

    def get_ip_address(self, address: str) -> Optional[Dict[str, Any]]:
        """Get a specific IP address.

        Args:
            address: IP address (with CIDR notation if needed)

        Returns:
            IP address dictionary, or None if not found or on error
            (including pynetbox's error when the address exists more than
            once, e.g. in several VRFs).
        """
        api = self._get_api()
        if not api:
            return None
        try:
            ip_addr = api.ipam.ip_addresses.get(address=address)
            return self._record_to_dict(ip_addr) if ip_addr else None
        except Exception as e:
            logger.error("Error getting IP address %s: %s", address, e)
            return None

    def search_ip_addresses(
        self, query: str, limit: int = 50
    ) -> List[Dict[str, Any]]:
        """Search IP addresses by address or hostname.

        Uses NetBox's generic ``q`` search. The ``address`` filter has no
        ``__ic`` lookup, and NetBox ignores unknown parameters, so
        ``address__ic`` silently matched everything.

        Args:
            query: Search query (IP address or hostname)
            limit: Maximum number of results

        Returns:
            List of matching IP address dictionaries ([] on error).
        """
        api = self._get_api()
        if not api:
            return []
        try:
            return self._fetch_records(
                api.ipam.ip_addresses, {"q": query}, limit
            )
        except Exception as e:
            logger.error("Error searching IP addresses: %s", e)
            return []

    # VLAN Methods

    def list_vlans(
        self,
        vid: Optional[int] = None,
        name: Optional[str] = None,
        site: Optional[str] = None,
        limit: int = 100,
    ) -> List[Dict[str, Any]]:
        """List VLANs from NetBox.

        Args:
            vid: Filter by VLAN ID
            name: Filter by VLAN name (case-insensitive partial match)
            site: Filter by site slug
            limit: Maximum number of results

        Returns:
            List of VLAN dictionaries ([] on error).
        """
        api = self._get_api()
        if not api:
            return []

        filters: Dict[str, Any] = {}
        if vid is not None:
            filters["vid"] = vid
        if name:
            filters["name__ic"] = name
        if site:
            # NetBox's VLAN filter for a site slug is `site`.
            filters["site"] = site

        try:
            return self._fetch_records(api.ipam.vlans, filters, limit)
        except Exception as e:
            logger.error("Error listing VLANs: %s", e)
            return []

    def get_vlan(self, vid: int) -> Optional[Dict[str, Any]]:
        """Get a specific VLAN by ID.

        NOTE(review): VLAN IDs are not globally unique in NetBox. When
        several VLANs share ``vid``, pynetbox's ``get()`` raises, and this
        method logs the error and returns None.

        Args:
            vid: VLAN ID

        Returns:
            VLAN dictionary, or None if not found or on error.
        """
        api = self._get_api()
        if not api:
            return None
        try:
            vlan = api.ipam.vlans.get(vid=vid)
            return self._record_to_dict(vlan) if vlan else None
        except Exception as e:
            logger.error("Error getting VLAN %s: %s", vid, e)
            return None

    def list_vlan_ip_addresses(self, vid: int) -> List[Dict[str, Any]]:
        """List IP addresses inside the prefixes assigned to a VLAN.

        NetBox records no direct VLAN -> IP link, so the lookup goes
        VLAN -> prefixes (``vlan_id``) -> IPs contained in each prefix
        (``parent``). Addresses are de-duplicated across overlapping
        prefixes.

        Args:
            vid: VLAN ID

        Returns:
            List of IP address dictionaries ([] if the VLAN is missing or on
            error).
        """
        api = self._get_api()
        if not api:
            return []
        try:
            vlan = api.ipam.vlans.get(vid=vid)
            if not vlan:
                return []
            ip_addresses: List[Dict[str, Any]] = []
            seen_ids = set()
            for prefix in api.ipam.prefixes.filter(vlan_id=vlan.id):
                # NOTE(review): `parent` presumably matches by address only,
                # so overlapping space in other VRFs may be included; add a
                # VRF filter if that matters.
                contained = api.ipam.ip_addresses.filter(parent=prefix.prefix)
                for ip_addr in contained:
                    if ip_addr.id in seen_ids:
                        continue
                    seen_ids.add(ip_addr.id)
                    ip_addresses.append(self._record_to_dict(ip_addr))
            return ip_addresses
        except Exception as e:
            logger.error("Error listing IP addresses for VLAN %s: %s", vid, e)
            return []