"""Microsoft Graph API client with automatic token refresh and rate limiting."""
import asyncio
import base64
import logging
from datetime import UTC, datetime, timedelta
from typing import Any
import httpx
from .auth import AuthManager
from .models import (
FileContent,
FileListResult,
FileMetadata,
Library,
ListItem,
SearchResult,
SharingLink,
Site,
UploadResult,
)
logger = logging.getLogger(__name__)
GRAPH_BASE = "https://graph.microsoft.com/v1.0"
MAX_FILE_SIZE = 10 * 1024 * 1024 # 10MB
class GraphClient:
"""Async client for Microsoft Graph API."""
def __init__(self, auth: AuthManager):
"""Initialize the Graph API client.
Args:
auth: Auth manager for token acquisition
"""
self.auth = auth
self._client: httpx.AsyncClient | None = None
async def _get_client(self) -> httpx.AsyncClient:
"""Get or create HTTP client with current auth token.
Returns:
Configured HTTP client
"""
token = await self.auth.get_access_token()
if self._client is None:
self._client = httpx.AsyncClient(
base_url=GRAPH_BASE,
timeout=30.0,
)
self._client.headers["Authorization"] = f"Bearer {token}"
return self._client
async def close(self) -> None:
"""Close the HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
async def _request(
self,
method: str,
path: str,
**kwargs: Any,
) -> dict[str, Any]:
"""Make authenticated request with retry logic.
Args:
method: HTTP method
path: API path
**kwargs: Additional arguments for httpx.request
Returns:
JSON response
Raises:
GraphAPIError: If request fails after retries
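
        Example (internal helper; wrappers below are the usual entry points):
            >>> data = await self._request("GET", "/me")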
"""
client = await self._get_client()
for attempt in range(3):
try:
response = await client.request(method, path, **kwargs)
                if response.status_code == 429:
                    # Rate limited - honor Retry-After, defaulting to 5s when the
                    # header is missing or is an HTTP-date rather than seconds
                    try:
                        retry_after = int(response.headers.get("Retry-After", 5))
                    except ValueError:
                        retry_after = 5
                    logger.warning(f"Rate limited, waiting {retry_after}s")
                    await asyncio.sleep(retry_after)
                    continue
if response.status_code == 401:
# Token expired - refresh and retry
logger.info("Token expired, refreshing")
await self.auth.ensure_authenticated()
client = await self._get_client()
continue
response.raise_for_status()
return response.json() if response.content else {}
            except httpx.HTTPStatusError as e:
                if attempt == 2:  # Last attempt
                    raise GraphAPIError(
                        f"Request failed: {e.response.status_code} {e.response.text}"
                    ) from e
                logger.warning(f"Request failed (attempt {attempt + 1}/3): {e}")
                await asyncio.sleep(1)
            except httpx.RequestError as e:
                # Network-level failures (timeouts, resets) are retried too, so
                # they surface as GraphAPIError per the docstring
                if attempt == 2:
                    raise GraphAPIError(f"Request failed: {e}") from e
                logger.warning(f"Request error (attempt {attempt + 1}/3): {e}")
                await asyncio.sleep(1)
raise GraphAPIError("Max retries exceeded")
# ========== Sites ==========
async def get_sites(self, search: str | None = None) -> list[Site]:
"""List SharePoint sites.
Args:
search: Optional search query to filter sites
Returns:
List of sites
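
        Example (illustrative query):
            >>> sites = await client.get_sites(search="engineering")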
"""
        if search:
            # Pass the query via params so httpx URL-encodes it
            data = await self._request("GET", "/sites", params={"search": search})
        else:
            data = await self._request("GET", "/sites", params={"search": "*"})
return [
Site(
id=s["id"],
name=s["displayName"],
url=s["webUrl"],
description=s.get("description"),
)
for s in data.get("value", [])
]
async def get_libraries(self, site_id: str) -> list[Library]:
"""List document libraries in a site.
Args:
site_id: Site identifier
Returns:
List of libraries
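
        Example (``site_id`` is a placeholder):
            >>> libraries = await client.get_libraries(site_id)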
"""
data = await self._request("GET", f"/sites/{site_id}/drives")
return [
Library(
id=d["id"],
name=d["name"],
web_url=d["webUrl"],
                # The drive resource has no item count; quota.used (bytes
                # consumed) is the closest readily available figure
                item_count=d.get("quota", {}).get("used", 0),
)
for d in data.get("value", [])
]
# ========== Files ==========
async def list_files(
self,
site_id: str,
library_id: str,
folder_path: str = "",
cursor: str | None = None,
limit: int = 50,
) -> FileListResult:
"""List files in a library/folder with pagination.
Args:
site_id: Site identifier
library_id: Library/drive identifier
folder_path: Path within library (empty for root)
cursor: Pagination cursor from previous response
limit: Maximum items to return
Returns:
Paginated file list
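
        Example (placeholder IDs):
            >>> page = await client.list_files(site_id, library_id, limit=50)
            >>> if page.next_cursor:
            ...     page = await client.list_files(
            ...         site_id, library_id, cursor=page.next_cursor
            ...     )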
"""
        if cursor:
            # The cursor is the full @odata.nextLink URL; strip the base so the
            # request goes through the shared client (httpx would also accept
            # the absolute URL as-is)
            data = await self._request("GET", cursor.replace(GRAPH_BASE, ""))
else:
path = f"/sites/{site_id}/drives/{library_id}/root"
if folder_path:
# Ensure folder_path doesn't start with /
folder_path = folder_path.lstrip("/")
path += f":/{folder_path}:"
path += f"/children?$top={limit}"
data = await self._request("GET", path)
files = []
for item in data.get("value", []):
# Only include files, not folders
if "file" not in item:
continue
files.append(
FileMetadata(
id=item["id"],
name=item["name"],
size=item.get("size", 0),
mime_type=item.get("file", {}).get("mimeType"),
web_url=item["webUrl"],
created_at=self._parse_datetime(item["createdDateTime"]),
modified_at=self._parse_datetime(item["lastModifiedDateTime"]),
created_by=item.get("createdBy", {}).get("user", {}).get("displayName"),
modified_by=item.get("lastModifiedBy", {})
.get("user", {})
.get("displayName"),
)
)
return FileListResult(
files=files,
next_cursor=data.get("@odata.nextLink"),
total_count=data.get("@odata.count"),
)
async def download_file(self, site_id: str, file_id: str) -> FileContent:
"""Download file content.
Args:
site_id: Site identifier
file_id: File identifier
Returns:
File content (text or base64)
Raises:
GraphAPIError: If file is too large or download fails
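
        Example (placeholder IDs):
            >>> fc = await client.download_file(site_id, file_id)
            >>> data = fc.content if fc.is_text else base64.b64decode(fc.content)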
"""
# Get metadata first
meta = await self.get_file_metadata(site_id, file_id)
if meta.size > MAX_FILE_SIZE:
raise GraphAPIError(f"File too large ({meta.size} bytes). Max: {MAX_FILE_SIZE}")
        # Download content with a raw GET (not _request: the body is file bytes,
        # not JSON); surface HTTP failures as GraphAPIError per the docstring
        client = await self._get_client()
        response = await client.get(f"/sites/{site_id}/drive/items/{file_id}/content")
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise GraphAPIError(f"Download failed: {e.response.status_code}") from e
        content = response.content
        # Return common text types as UTF-8; fall back to base64 if decoding fails
        is_text = bool(
            meta.mime_type
            and meta.mime_type.startswith(("text/", "application/json"))
        )
        text = None
        if is_text:
            try:
                text = content.decode("utf-8")
            except UnicodeDecodeError:
                is_text = False
        return FileContent(
            name=meta.name,
            mime_type=meta.mime_type or "application/octet-stream",
            content=text if is_text else base64.b64encode(content).decode(),
            is_text=is_text,
        )
async def get_file_metadata(self, site_id: str, file_id: str) -> FileMetadata:
"""Get file metadata.
Args:
site_id: Site identifier
file_id: File identifier
Returns:
File metadata
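
        Example (placeholder IDs):
            >>> meta = await client.get_file_metadata(site_id, file_id)
            >>> size_mb = meta.size / (1024 * 1024)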
"""
data = await self._request("GET", f"/sites/{site_id}/drive/items/{file_id}")
return FileMetadata(
id=data["id"],
name=data["name"],
size=data.get("size", 0),
mime_type=data.get("file", {}).get("mimeType"),
web_url=data["webUrl"],
created_at=self._parse_datetime(data["createdDateTime"]),
modified_at=self._parse_datetime(data["lastModifiedDateTime"]),
created_by=data.get("createdBy", {}).get("user", {}).get("displayName"),
modified_by=data.get("lastModifiedBy", {}).get("user", {}).get("displayName"),
)
async def upload_file(
self,
site_id: str,
library_id: str,
file_name: str,
content: str,
folder_path: str = "",
is_base64: bool = False,
) -> UploadResult:
"""Upload file to SharePoint.
Args:
site_id: Site identifier
library_id: Library/drive identifier
file_name: Name for the uploaded file
content: File content (text or base64-encoded)
folder_path: Destination folder path (empty for library root)
is_base64: Whether content is base64-encoded
Returns:
Upload result
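
        Example (placeholder IDs and names):
            >>> result = await client.upload_file(
            ...     site_id, library_id, "notes.txt", "hello world",
            ...     folder_path="Reports",
            ... )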
"""
file_content = base64.b64decode(content) if is_base64 else content.encode("utf-8")
path = f"/sites/{site_id}/drives/{library_id}/root"
if folder_path:
folder_path = folder_path.lstrip("/")
path += f":/{folder_path}/{file_name}:"
else:
path += f":/{file_name}:"
path += "/content"
        # Simple single-request upload; Graph caps this endpoint, so larger
        # files would need an upload session (createUploadSession)
        client = await self._get_client()
        response = await client.put(path, content=file_content)
        response.raise_for_status()
data = response.json()
return UploadResult(
id=data["id"],
name=data["name"],
web_url=data["webUrl"],
size=data["size"],
)
async def create_sharing_link(
self,
site_id: str,
file_id: str,
link_type: str = "view",
expiration_days: int | None = None,
) -> SharingLink:
"""Create sharing link for a file.
Args:
site_id: Site identifier
file_id: File identifier
link_type: Type of link - "view" or "edit"
expiration_days: Optional link expiration in days
Returns:
Generated sharing link
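
        Example (placeholder IDs):
            >>> link = await client.create_sharing_link(
            ...     site_id, file_id, link_type="edit", expiration_days=7
            ... )
            >>> url = link.link_url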
"""
body: dict[str, Any] = {
"type": link_type,
"scope": "organization",
}
if expiration_days:
expiry = datetime.now(UTC) + timedelta(days=expiration_days)
body["expirationDateTime"] = expiry.isoformat()
data = await self._request(
"POST",
f"/sites/{site_id}/drive/items/{file_id}/createLink",
json=body,
)
expires_at = None
if "expirationDateTime" in data:
expires_at = self._parse_datetime(data["expirationDateTime"])
return SharingLink(
link_url=data["link"]["webUrl"],
link_type=link_type,
expires_at=expires_at,
)
async def search_files(
self,
query: str,
site_id: str | None = None,
) -> list[SearchResult]:
"""Search for files across SharePoint.
Args:
query: Search query (supports KQL syntax)
site_id: Optional site ID to limit search scope
Returns:
List of search results
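
        Example (illustrative KQL):
            >>> hits = await client.search_files("budget filetype:xlsx")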
"""
        # Scope the query to one site via the SharePoint SiteId managed
        # property; the site-collection GUID is the second segment of a Graph
        # site id ("hostname,site-guid,web-guid")
        query_string = query
        if site_id:
            parts = site_id.split(",")
            site_guid = parts[1] if len(parts) > 1 else site_id
            query_string = f"{query} SiteId:{site_guid}"
        body = {
            "requests": [
                {
                    "entityTypes": ["driveItem"],
                    "query": {"queryString": query_string},
                }
            ]
        }
data = await self._request("POST", "/search/query", json=body)
results = []
for response_item in data.get("value", []):
for container in response_item.get("hitsContainers", []):
for hit in container.get("hits", []):
resource = hit.get("resource", {})
if not resource.get("id"):
continue
                    # The hit carries no site display name, so surface the
                    # site-collection GUID from the parent reference
                    # (siteId format: "hostname,site-guid,web-guid")
                    site_name = ""
                    if "parentReference" in resource and "siteId" in resource["parentReference"]:
                        site_parts = resource["parentReference"]["siteId"].split(",")
                        site_name = site_parts[1] if len(site_parts) > 1 else ""
results.append(
SearchResult(
id=resource.get("id", ""),
name=resource.get("name", ""),
site_name=site_name,
web_url=resource.get("webUrl", ""),
snippet=hit.get("summary"),
)
)
return results
# ========== Lists ==========
async def get_list_items(
self,
site_id: str,
list_id: str,
limit: int = 100,
) -> list[ListItem]:
"""Get items from a SharePoint list.
Args:
site_id: Site identifier
list_id: List identifier
limit: Maximum items to return
Returns:
List of items
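
        Example (placeholder IDs; ``Title`` is the default SharePoint column):
            >>> items = await client.get_list_items(site_id, list_id, limit=25)
            >>> titles = [i.fields.get("Title") for i in items]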
"""
data = await self._request(
"GET",
f"/sites/{site_id}/lists/{list_id}/items?expand=fields&$top={limit}",
)
return [
ListItem(
id=item["id"],
fields=item.get("fields", {}),
created_at=self._parse_datetime(item["createdDateTime"]),
modified_at=self._parse_datetime(item["lastModifiedDateTime"]),
)
for item in data.get("value", [])
]
async def create_list_item(
self,
site_id: str,
list_id: str,
fields: dict[str, Any],
) -> ListItem:
"""Create item in a SharePoint list.
Args:
site_id: Site identifier
list_id: List identifier
fields: Field values for the new item
Returns:
Created list item
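
        Example (placeholder IDs and field values):
            >>> item = await client.create_list_item(
            ...     site_id, list_id, fields={"Title": "New task"}
            ... )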
"""
data = await self._request(
"POST",
f"/sites/{site_id}/lists/{list_id}/items",
json={"fields": fields},
)
return ListItem(
id=data["id"],
fields=data.get("fields", {}),
created_at=self._parse_datetime(data["createdDateTime"]),
modified_at=self._parse_datetime(data["lastModifiedDateTime"]),
)
# ========== User ==========
async def get_current_user(self) -> dict[str, Any]:
"""Get current authenticated user info.
Returns:
User information
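
        Example:
            >>> me = await client.get_current_user()
            >>> upn = me.get("userPrincipalName")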
"""
return await self._request("GET", "/me")
# ========== Utilities ==========
@staticmethod
def _parse_datetime(dt_str: str) -> datetime:
"""Parse ISO datetime string, removing timezone info.
Args:
dt_str: ISO datetime string
Returns:
Timezone-naive datetime
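
        Example:
            >>> GraphClient._parse_datetime("2024-05-01T12:30:00Z")
            datetime.datetime(2024, 5, 1, 12, 30)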
"""
        # Strip a trailing 'Z' exactly once (rstrip would eat repeated Z's, and
        # older fromisoformat versions reject the suffix)
        dt_str = dt_str.removesuffix("Z")
        # Parse, then drop the offset so callers get naive datetimes
        dt = datetime.fromisoformat(dt_str)
        return dt.replace(tzinfo=None)
class GraphAPIError(Exception):
"""Graph API errors."""
pass
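
# Minimal end-to-end sketch (assumption: AuthManager() needs no constructor
# arguments here; adapt to the real auth configuration). Kept as a comment so
# importing this module has no side effects:
#
#     import asyncio
#
#     async def main() -> None:
#         client = GraphClient(AuthManager())
#         try:
#             for site in await client.get_sites(search="engineering"):
#                 print(site.name, site.url)
#         finally:
#             await client.close()
#
#     asyncio.run(main())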