from __future__ import annotations
import hashlib
import io
import json
import secrets
import threading
import time
import uuid
import logging
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from typing import Any, Callable, Optional
from .stata_client import StataClient
from .config import (
DEFAULT_HOST,
DEFAULT_PORT,
MAX_ARROW_LIMIT,
MAX_CHARS,
MAX_LIMIT,
MAX_REQUEST_BYTES,
MAX_VARS,
TOKEN_TTL_S,
VIEW_TTL_S,
)
logger = logging.getLogger("mcp_stata")
# Optional native argsort helpers (compiled extension). When the extension is
# unavailable, the names are set to None so callers (_try_native_argsort) can
# detect absence and fall back to the Polars implementation.
try:
    from .native_ops import argsort_numeric as _native_argsort_numeric
    from .native_ops import argsort_mixed as _native_argsort_mixed
except Exception:
    _native_argsort_numeric = None
    _native_argsort_mixed = None
def _try_native_argsort(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int] | None:
    """Sort via the optional native extension, if possible.

    Returns zero-based observation indices in sorted order, or None when the
    native kernels are unavailable, a column type is unsupported, the table
    lacks the "_n" column, or anything else goes wrong (best-effort).
    """
    if _native_argsort_numeric is None and _native_argsort_mixed is None:
        return None
    try:
        import numpy as np
        import pyarrow as pa

        is_string: list[bool] = []
        cols: list[object] = []
        for name in sort_cols:
            arr = table.column(name).combine_chunks()
            if pa.types.is_string(arr.type) or pa.types.is_large_string(arr.type):
                is_string.append(True)
                cols.append(arr.to_pylist())
            elif pa.types.is_floating(arr.type) or pa.types.is_integer(arr.type):
                is_string.append(False)
                values = arr.to_numpy(zero_copy_only=False)
                if values.dtype != np.float64:
                    values = values.astype(np.float64, copy=False)
                # Stata stores missing numeric values as huge sentinels; map to NaN.
                cols.append(np.where(values > 8.0e307, np.nan, values))
            else:
                # Unsupported column type for the native path.
                return None
        if "_n" not in table.column_names:
            return None
        obs = table.column("_n").to_numpy(zero_copy_only=False).astype(np.int64, copy=False)
        if not any(is_string) and _native_argsort_numeric is not None:
            order = _native_argsort_numeric(cols, descending, nulls_last)
        elif _native_argsort_mixed is not None:
            order = _native_argsort_mixed(cols, is_string, descending, nulls_last)
        else:
            return None
        # "_n" is 1-based; convert to zero-based indices.
        return [int(x) for x in (obs[order] - 1).tolist()]
    except Exception:
        return None
def _get_sorted_indices_polars(
    table: Any,
    sort_cols: list[str],
    descending: list[bool],
    nulls_last: list[bool],
) -> list[int]:
    """Sort *table* with Polars and return zero-based observation indices.

    The "_n" column (1-based observation number) must be present; it is passed
    through untouched while Stata missing-value sentinels in float columns are
    mapped to null before sorting.
    """
    import polars as pl

    df = pl.from_arrow(table)
    # Normalize Stata missing values for numeric columns
    cleaned = []
    for name, dtype in zip(df.columns, df.dtypes):
        expr = pl.col(name)
        if name != "_n" and dtype in (pl.Float32, pl.Float64):
            expr = (
                pl.when(pl.col(name) > 8.0e307)
                .then(None)
                .otherwise(pl.col(name))
                .alias(name)
            )
        cleaned.append(expr)
    df = df.select(cleaned)
    try:
        # Use expressions for arithmetic to avoid eager Series-scalar conversion issues
        # that have been observed in some environments with Int64 dtypes.
        res = df.select(
            idx=pl.arg_sort_by(
                [pl.col(c) for c in sort_cols],
                descending=descending,
                nulls_last=nulls_last,
            ),
            zero_based_n=pl.col("_n") - 1,
        )
        return res["zero_based_n"].take(res["idx"]).to_list()
    except Exception:
        # Fallback to eager sort if arg_sort_by fails or has issues
        sorted_df = df.sort(by=sort_cols, descending=descending, nulls_last=nulls_last)
        return sorted_df.select(pl.col("_n") - 1).to_series().to_list()
def _stable_hash(payload: dict[str, Any]) -> str:
return hashlib.sha1(json.dumps(payload, sort_keys=True).encode("utf-8")).hexdigest()
@dataclass
class UIChannelInfo:
    """Connection details a UI client needs to reach the local HTTP channel."""

    # Base URL of the channel server, e.g. "http://127.0.0.1:<port>".
    base_url: str
    # Bearer token to present in the Authorization header on every request.
    token: str
    # Token expiry in epoch milliseconds (written by UIChannelManager._ensure_token).
    expires_at: int
@dataclass
class ViewHandle:
    """A cached, filtered view of the current dataset for one session."""

    # Opaque id ("view_" + uuid4 hex) returned to the client.
    view_id: str
    # Dataset id the view was created against; stale ids are rejected with 409.
    dataset_id: str
    # Frame name the view belongs to.
    frame: str
    # Filter expression the view was built from.
    filter_expr: str
    # Observation indices matching the filter (presumably zero-based — TODO
    # confirm against StataClient.compute_view_indices).
    obs_indices: list[int]
    # len(obs_indices), cached at creation time.
    filtered_n: int
    # Epoch seconds. last_access drives TTL eviction in _evict_expired_locked.
    created_at: float
    last_access: float
class UIChannelManager:
    """Coordinates the localhost HTTP channel used by the data-browser UI.

    Responsibilities visible in this class:
    - lazily starts a ThreadingHTTPServer and issues/validates bearer tokens;
    - computes per-session dataset-id digests and caches them by version;
    - keeps small LRU caches of sort index lists and sort-column Arrow tables;
    - creates and TTL-evicts filtered ViewHandles keyed by (session_id, view_id).

    All mutable shared state is guarded by the single self._lock.
    """

    def __init__(
        self,
        client: StataClient,
        *,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        token_ttl_s: int = TOKEN_TTL_S,
        view_ttl_s: int = VIEW_TTL_S,
        max_limit: int = MAX_LIMIT,
        max_vars: int = MAX_VARS,
        max_chars: int = MAX_CHARS,
        max_request_bytes: int = MAX_REQUEST_BYTES,
        max_arrow_limit: int = MAX_ARROW_LIMIT,
    ):
        """Store limits/configuration; the HTTP server itself starts lazily."""
        self._client = client
        self._host = host
        self._port = port
        self._token_ttl_s = token_ttl_s
        self._view_ttl_s = view_ttl_s
        self._max_limit = max_limit
        self._max_vars = max_vars
        self._max_chars = max_chars
        self._max_request_bytes = max_request_bytes
        self._max_arrow_limit = max_arrow_limit
        self._lock = threading.Lock()
        self._httpd: ThreadingHTTPServer | None = None
        self._thread: threading.Thread | None = None
        self._token: str | None = None
        # Token expiry in epoch milliseconds (see _ensure_token / validate_token).
        self._expires_at: int = 0
        self._dataset_version: int = 0
        # NOTE(review): the two attributes below appear unused; the per-session
        # _dataset_id_caches dict further down seems to supersede them — confirm
        # and consider removing.
        self._dataset_id_cache: str | None = None
        self._dataset_id_cache_at_version: int = -1
        self._views: dict[str, dict[str, ViewHandle]] = {}  # session_id -> {view_id -> ViewHandle}
        self._sort_index_cache: dict[tuple[str, str, tuple[str, ...]], list[int]] = {}  # (session_id, dataset_id, sort_spec)
        self._sort_cache_order: list[tuple[str, str, tuple[str, ...]]] = []
        self._sort_cache_max_entries: int = 10
        self._sort_table_cache: dict[tuple[str, str, tuple[str, ...]], Any] = {}  # (session_id, dataset_id, sort_cols)
        self._sort_table_order: list[tuple[str, str, tuple[str, ...]]] = []
        self._sort_table_max_entries: int = 4
        self._dataset_id_caches: dict[str, tuple[str, int]] = {}  # session_id -> (digest, version)

    def notify_potential_dataset_change(self, session_id: str = "default") -> None:
        """Invalidate all cached state for *session_id* after the dataset may have changed.

        NOTE(review): this does not bump self._dataset_version; invalidation
        relies on popping the per-session digest so it is recomputed from the
        current dataset state — confirm that is sufficient when the recomputed
        state payload happens to be identical.
        """
        with self._lock:
            self._dataset_id_caches.pop(session_id, None)
            if session_id in self._views:
                self._views[session_id].clear()
            # Clear caches for this session
            self._sort_cache_order = [k for k in self._sort_cache_order if k[0] != session_id]
            self._sort_index_cache = {k: v for k, v in self._sort_index_cache.items() if k[0] != session_id}
            self._sort_table_order = [k for k in self._sort_table_order if k[0] != session_id]
            self._sort_table_cache = {k: v for k, v in self._sort_table_cache.items() if k[0] != session_id}

    @staticmethod
    def _normalize_sort_spec(sort_spec: list[str]) -> tuple[str, ...]:
        """Canonicalize user sort specs to a tuple of '+var' / '-var' strings.

        Raises ValueError for non-string, empty/whitespace-only, or sign-only
        entries.
        """
        normalized: list[str] = []
        for spec in sort_spec:
            if not isinstance(spec, str) or not spec:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            raw = spec.strip()
            if not raw:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            sign = "-" if raw.startswith("-") else "+"
            varname = raw.lstrip("+-")
            if not varname:
                raise ValueError(f"Invalid sort specification: {spec!r}")
            normalized.append(f"{sign}{varname}")
        return tuple(normalized)

    def _get_cached_sort_indices(
        self, session_id: str, dataset_id: str, sort_spec: tuple[str, ...]
    ) -> list[int] | None:
        """LRU lookup of a previously computed sort index list (None on miss)."""
        key = (session_id, dataset_id, sort_spec)
        with self._lock:
            cached = self._sort_index_cache.get(key)
            if cached is None:
                return None
            # Move key to the most-recently-used end.
            if key in self._sort_cache_order:
                self._sort_cache_order.remove(key)
            self._sort_cache_order.append(key)
            return cached

    def _set_cached_sort_indices(
        self, session_id: str, dataset_id: str, sort_spec: tuple[str, ...], indices: list[int]
    ) -> None:
        """Insert/refresh a sort index list, evicting LRU entries beyond the cap."""
        key = (session_id, dataset_id, sort_spec)
        with self._lock:
            if key in self._sort_index_cache:
                self._sort_cache_order.remove(key)
            self._sort_index_cache[key] = indices
            self._sort_cache_order.append(key)
            while len(self._sort_cache_order) > self._sort_cache_max_entries:
                evict = self._sort_cache_order.pop(0)
                self._sort_index_cache.pop(evict, None)

    def _get_cached_sort_table(
        self, session_id: str, dataset_id: str, sort_cols: tuple[str, ...]
    ) -> Any | None:
        """LRU lookup of a cached Arrow table of sort columns (None on miss)."""
        key = (session_id, dataset_id, sort_cols)
        with self._lock:
            cached = self._sort_table_cache.get(key)
            if cached is None:
                return None
            if key in self._sort_table_order:
                self._sort_table_order.remove(key)
            self._sort_table_order.append(key)
            return cached

    def _set_cached_sort_table(
        self, session_id: str, dataset_id: str, sort_cols: tuple[str, ...], table: Any
    ) -> None:
        """Insert/refresh a cached sort-column table, evicting beyond the cap."""
        key = (session_id, dataset_id, sort_cols)
        with self._lock:
            if key in self._sort_table_cache:
                self._sort_table_order.remove(key)
            self._sort_table_cache[key] = table
            self._sort_table_order.append(key)
            while len(self._sort_table_order) > self._sort_table_max_entries:
                evict = self._sort_table_order.pop(0)
                self._sort_table_cache.pop(evict, None)

    def _get_sort_table(self, session_id: str, dataset_id: str, sort_cols: list[str]) -> Any:
        """Fetch (or reuse) an Arrow table containing just the sort columns + "_n".

        Returns None when the dataset is empty.
        """
        sort_cols_key = tuple(sort_cols)
        cached = self._get_cached_sort_table(session_id, dataset_id, sort_cols_key)
        if cached is not None:
            return cached
        # Use an appropriate client for the session
        proxy = self._get_proxy_for_session(session_id)
        state = proxy.get_dataset_state()
        n = int(state.get("n", 0) or 0)
        if n <= 0:
            return None
        # Pull full columns once via Arrow stream (Stata -> Arrow), then sort in Polars.
        arrow_bytes = proxy.get_arrow_stream(
            offset=0,
            limit=n,
            vars=sort_cols,
            include_obs_no=True,
            obs_indices=None,
        )
        import pyarrow as pa
        with pa.ipc.open_stream(io.BytesIO(arrow_bytes)) as reader:
            table = reader.read_all()
        self._set_cached_sort_table(session_id, dataset_id, sort_cols_key, table)
        return table

    def get_channel(self) -> UIChannelInfo:
        """Start the server if needed and return (base_url, token, expiry)."""
        self._ensure_http_server()
        with self._lock:
            self._ensure_token()
            assert self._httpd is not None
            # Read the bound port from the server (handles port=0 auto-assignment).
            port = self._httpd.server_address[1]
            base_url = f"http://{self._host}:{port}"
            return UIChannelInfo(base_url=base_url, token=self._token or "", expires_at=self._expires_at)

    def capabilities(self) -> dict[str, bool]:
        """Advertise the feature set of this channel implementation."""
        return {"dataBrowser": True, "filtering": True, "sorting": True, "arrowStream": True}

    def _get_proxy_for_session(self, session_id: str) -> StataClient:
        """Return the Stata client to use for *session_id*."""
        # Prefer the injected client when present (used by unit tests and single-session setups).
        client = getattr(self, "_client", None)
        if client is not None:
            from .server import StataClientProxy
            if isinstance(client, StataClientProxy):
                # Injected client is itself a proxy: build one bound to this session.
                return StataClientProxy(session_id=session_id or "default")
            return client
        from .server import StataClientProxy
        return StataClientProxy(session_id=session_id or "default")

    def current_dataset_id(self, session_id: str = "default") -> str:
        """Return a stable digest identifying the session's current dataset.

        The digest covers (version, frame, n, k, sortlist) and is cached per
        session until the dataset version changes or the cache is invalidated.
        """
        with self._lock:
            cached = self._dataset_id_caches.get(session_id)
            if cached:
                digest, version = cached
                if version == self._dataset_version:
                    return digest
        # Recompute outside the lock (get_dataset_state may be slow / RPC-bound).
        proxy = self._get_proxy_for_session(session_id)
        state = proxy.get_dataset_state()
        payload = {
            "version": self._dataset_version,
            "frame": state.get("frame"),
            "n": state.get("n"),
            "k": state.get("k"),
            "sortlist": state.get("sortlist"),
        }
        digest = _stable_hash(payload)
        with self._lock:
            self._dataset_id_caches[session_id] = (digest, self._dataset_version)
        return digest

    def get_view(self, session_id: str, view_id: str) -> Optional[ViewHandle]:
        """Look up a live view, touching its last_access (None if absent/expired)."""
        now = time.time()
        with self._lock:
            self._evict_expired_locked(now)
            session_views = self._views.get(session_id)
            if session_views is None:
                return None
            view = session_views.get(view_id)
            if view is None:
                return None
            view.last_access = now
            return view

    def create_view(self, *, session_id: str, dataset_id: str, frame: str, filter_expr: str) -> ViewHandle:
        """Create a filtered view over the current dataset.

        Raises DatasetChangedError when *dataset_id* is stale, InvalidFilterError
        for a bad filter expression, and NoDataInMemoryError when Stata has no
        dataset loaded.
        """
        current_id = self.current_dataset_id(session_id)
        if dataset_id != current_id:
            raise DatasetChangedError(current_id)
        proxy = self._get_proxy_for_session(session_id)
        try:
            obs_indices = proxy.compute_view_indices(filter_expr)
        except ValueError as e:
            raise InvalidFilterError(str(e))
        except RuntimeError as e:
            msg = str(e) or "No data in memory"
            if "no data" in msg.lower():
                raise NoDataInMemoryError(msg)
            raise
        now = time.time()
        view_id = f"view_{uuid.uuid4().hex}"
        view = ViewHandle(
            view_id=view_id,
            dataset_id=current_id,
            frame=frame,
            filter_expr=filter_expr,
            obs_indices=obs_indices,
            filtered_n=len(obs_indices),
            created_at=now,
            last_access=now,
        )
        with self._lock:
            self._evict_expired_locked(now)
            if session_id not in self._views:
                self._views[session_id] = {}
            self._views[session_id][view_id] = view
        return view

    def delete_view(self, session_id: str, view_id: str) -> bool:
        """Remove a view; True if it existed."""
        with self._lock:
            session_views = self._views.get(session_id)
            if session_views is None:
                return False
            return session_views.pop(view_id, None) is not None

    def validate_token(self, header_value: str | None) -> bool:
        """Check an Authorization header against the current bearer token.

        Uses a constant-time comparison; an expired token fails (and
        _ensure_token will have rotated it).
        """
        if not header_value:
            return False
        if not header_value.startswith("Bearer "):
            return False
        token = header_value[len("Bearer ") :].strip()
        with self._lock:
            self._ensure_token()
            if self._token is None:
                return False
            if time.time() * 1000 >= self._expires_at:
                return False
            return secrets.compare_digest(token, self._token)

    def limits(self) -> tuple[int, int, int, int]:
        """Return (max_limit, max_vars, max_chars, max_request_bytes)."""
        return self._max_limit, self._max_vars, self._max_chars, self._max_request_bytes

    def _ensure_token(self) -> None:
        """Create or rotate the bearer token when missing/expired. Caller holds the lock."""
        now_ms = int(time.time() * 1000)
        if self._token is None or now_ms >= self._expires_at:
            self._token = secrets.token_urlsafe(32)
            self._expires_at = int((time.time() + self._token_ttl_s) * 1000)

    def _evict_expired_locked(self, now: float) -> None:
        """Drop views idle longer than the view TTL. Caller must hold self._lock."""
        for session_id in list(self._views.keys()):
            session_views = self._views[session_id]
            expired: list[str] = []
            for key, view in session_views.items():
                if now - view.last_access >= self._view_ttl_s:
                    expired.append(key)
            for key in expired:
                session_views.pop(key, None)
            if not session_views:
                self._views.pop(session_id, None)

    def _ensure_http_server(self) -> None:
        """Start the threaded HTTP server once (idempotent, lock-protected)."""
        with self._lock:
            if self._httpd is not None:
                return
            manager = self

            class Handler(BaseHTTPRequestHandler):
                """Request handler for the UI channel's small REST surface."""

                def _send_json(self, status: int, payload: dict[str, Any]) -> None:
                    """Serialize *payload* and send it with Content-Length set."""
                    data = json.dumps(payload).encode("utf-8")
                    self.send_response(status)
                    self.send_header("Content-Type", "application/json")
                    self.send_header("Content-Length", str(len(data)))
                    self.end_headers()
                    self.wfile.write(data)

                def _send_binary(self, status: int, data: bytes, content_type: str) -> None:
                    """Send raw bytes (used for Arrow IPC streams)."""
                    self.send_response(status)
                    self.send_header("Content-Type", content_type)
                    self.send_header("Content-Length", str(len(data)))
                    self.end_headers()
                    self.wfile.write(data)

                def _error(self, status: int, code: str, message: str, *, stata_rc: int | None = None) -> None:
                    """Send a JSON error envelope.

                    Internal errors are logged server-side and the message sent
                    to the client is replaced with a generic one (avoids leaking
                    internals).
                    """
                    if status >= 500 or code == "internal_error":
                        logger.error("UI HTTP error %s: %s", code, message)
                        message = "Internal server error"
                    body: dict[str, Any] = {"error": {"code": code, "message": message}}
                    if stata_rc is not None:
                        body["error"]["stataRc"] = stata_rc
                    self._send_json(status, body)

                def _require_auth(self) -> bool:
                    """Validate the bearer token; send 401 and return False on failure."""
                    if manager.validate_token(self.headers.get("Authorization")):
                        return True
                    self._error(401, "auth_failed", "Unauthorized")
                    return False

                def _read_json(self) -> dict[str, Any] | None:
                    """Read and parse the request body as a JSON object.

                    Returns {} for an empty body, or None after having sent a
                    400 error response (too large / invalid JSON / non-object).
                    """
                    max_limit, max_vars, max_chars, max_bytes = manager.limits()
                    _ = (max_limit, max_vars, max_chars)
                    length = int(self.headers.get("Content-Length", "0") or "0")
                    if length <= 0:
                        return {}
                    if length > max_bytes:
                        self._error(400, "request_too_large", "Request too large")
                        return None
                    raw = self.rfile.read(length)
                    try:
                        parsed = json.loads(raw.decode("utf-8"))
                    except Exception:
                        self._error(400, "invalid_request", "Invalid JSON")
                        return None
                    if not isinstance(parsed, dict):
                        self._error(400, "invalid_request", "Expected JSON object")
                        return None
                    return parsed

                def do_GET(self) -> None:
                    """GET /v1/dataset and /v1/vars: dataset state and variable metadata."""
                    if not self._require_auth():
                        return
                    if self.path.startswith("/v1/dataset"):
                        from urllib.parse import urlparse, parse_qs
                        parsed_url = urlparse(self.path)
                        params = parse_qs(parsed_url.query)
                        session_id = params.get("sessionId", ["default"])[0]
                        try:
                            proxy = manager._get_proxy_for_session(session_id)
                            state = proxy.get_dataset_state()
                            dataset_id = manager.current_dataset_id(session_id)
                            self._send_json(
                                200,
                                {
                                    "dataset": {
                                        "id": dataset_id,
                                        "frame": state.get("frame"),
                                        "n": state.get("n"),
                                        "k": state.get("k"),
                                        "changed": state.get("changed"),
                                    }
                                },
                            )
                            return
                        except NoDataInMemoryError as e:
                            self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return
                    if self.path.startswith("/v1/vars"):
                        from urllib.parse import urlparse, parse_qs
                        parsed_url = urlparse(self.path)
                        params = parse_qs(parsed_url.query)
                        session_id = params.get("sessionId", ["default"])[0]
                        try:
                            proxy = manager._get_proxy_for_session(session_id)
                            state = proxy.get_dataset_state()
                            dataset_id = manager.current_dataset_id(session_id)
                            variables = proxy.list_variables_rich()
                            self._send_json(
                                200,
                                {
                                    "dataset": {"id": dataset_id, "frame": state.get("frame")},
                                    "variables": variables,
                                },
                            )
                            return
                        except NoDataInMemoryError as e:
                            self._error(400, "no_data_in_memory", str(e), stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return
                    self._error(404, "not_found", "Not found")

                def do_POST(self) -> None:
                    """POST endpoints: pages, Arrow streams, view creation, filter validation."""
                    if not self._require_auth():
                        return
                    if self.path == "/v1/arrow":
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp_bytes = handle_arrow_request(manager, body, view_id=None)
                            self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return
                    if self.path == "/v1/page":
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp = handle_page_request(manager, body, view_id=None)
                            self._send_json(200, resp)
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return
                    if self.path == "/v1/views":
                        body = self._read_json()
                        if body is None:
                            return
                        dataset_id = str(body.get("datasetId", ""))
                        frame = str(body.get("frame", "default"))
                        filter_expr = str(body.get("filterExpr", ""))
                        session_id = str(body.get("sessionId", "default"))
                        if not dataset_id or not filter_expr:
                            self._error(400, "invalid_request", "datasetId and filterExpr are required")
                            return
                        try:
                            view = manager.create_view(session_id=session_id, dataset_id=dataset_id, frame=frame, filter_expr=filter_expr)
                            self._send_json(
                                200,
                                {
                                    "dataset": {"id": view.dataset_id, "frame": view.frame},
                                    "view": {"id": view.view_id, "filteredN": view.filtered_n},
                                },
                            )
                            return
                        except DatasetChangedError as e:
                            self._error(409, "dataset_changed", f"Dataset changed for session {session_id}")
                            return
                        except ValueError as e:
                            self._error(400, "invalid_filter", str(e))
                            return
                        except RuntimeError as e:
                            msg = str(e) or "No data in memory"
                            if "no data" in msg.lower():
                                self._error(400, "no_data_in_memory", msg)
                                return
                            self._error(500, "internal_error", msg)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return
                    if self.path.startswith("/v1/views/") and self.path.endswith("/page"):
                        # Path shape: /v1/views/<view_id>/page
                        parts = self.path.split("/")
                        if len(parts) != 5:
                            self._error(404, "not_found", "Not found")
                            return
                        view_id = parts[3]
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp = handle_page_request(manager, body, view_id=view_id)
                            self._send_json(200, resp)
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return
                    if self.path.startswith("/v1/views/") and self.path.endswith("/arrow"):
                        # Path shape: /v1/views/<view_id>/arrow
                        parts = self.path.split("/")
                        if len(parts) != 5:
                            self._error(404, "not_found", "Not found")
                            return
                        view_id = parts[3]
                        body = self._read_json()
                        if body is None:
                            return
                        try:
                            resp_bytes = handle_arrow_request(manager, body, view_id=view_id)
                            self._send_binary(200, resp_bytes, "application/vnd.apache.arrow.stream")
                            return
                        except HTTPError as e:
                            self._error(e.status, e.code, e.message, stata_rc=e.stata_rc)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return
                    if self.path == "/v1/filters/validate":
                        body = self._read_json()
                        if body is None:
                            return
                        filter_expr = str(body.get("filterExpr", ""))
                        session_id = str(body.get("sessionId", "default"))
                        if not filter_expr:
                            self._error(400, "invalid_request", "filterExpr is required")
                            return
                        try:
                            proxy = manager._get_proxy_for_session(session_id)
                            proxy.validate_filter_expr(filter_expr)
                            self._send_json(200, {"ok": True})
                            return
                        except ValueError as e:
                            self._error(400, "invalid_filter", str(e))
                            return
                        except RuntimeError as e:
                            msg = str(e) or "No data in memory"
                            if "no data" in msg.lower():
                                self._error(400, "no_data_in_memory", msg)
                                return
                            self._error(500, "internal_error", msg)
                            return
                        except Exception as e:
                            self._error(500, "internal_error", str(e))
                            return
                    self._error(404, "not_found", "Not found")

                def do_DELETE(self) -> None:
                    """DELETE /v1/views/<view_id>?sessionId=... : drop a view."""
                    if not self._require_auth():
                        return
                    if self.path.startswith("/v1/views/"):
                        # NOTE(review): self.path still contains the query string
                        # here, so when ?sessionId=... is present parts[3] (the
                        # view_id) includes "?sessionId=..." and the lookup will
                        # miss — strip the query (urlparse().path) before split.
                        parts = self.path.split("/")
                        if len(parts) != 4:
                            self._error(404, "not_found", "Not found")
                            return
                        from urllib.parse import urlparse, parse_qs
                        parsed_url = urlparse(self.path)
                        params = parse_qs(parsed_url.query)
                        session_id = params.get("sessionId", ["default"])[0]
                        view_id = parts[3]
                        if manager.delete_view(session_id, view_id):
                            self._send_json(200, {"ok": True})
                        else:
                            self._error(404, "not_found", f"View {view_id} not found in session {session_id}")
                        return
                    self._error(404, "not_found", "Not found")

                def log_message(self, format: str, *args: Any) -> None:
                    # Silence BaseHTTPRequestHandler's default stderr logging.
                    return

            httpd = ThreadingHTTPServer((self._host, self._port), Handler)
            t = threading.Thread(target=httpd.serve_forever, daemon=True)
            t.start()
            self._httpd = httpd
            self._thread = t
class HTTPError(Exception):
def __init__(self, status: int, code: str, message: str, *, stata_rc: int | None = None):
super().__init__(message)
self.status = status
self.code = code
self.message = message
self.stata_rc = stata_rc
class DatasetChangedError(Exception):
    """Raised when a request references a stale dataset id; carries the current one."""

    def __init__(self, current_dataset_id: str):
        self.current_dataset_id = current_dataset_id
        super().__init__("dataset_changed")
class NoDataInMemoryError(Exception):
def __init__(self, message: str = "No data in memory", *, stata_rc: int | None = None):
super().__init__(message)
self.stata_rc = stata_rc
class InvalidFilterError(Exception):
def __init__(self, message: str, *, stata_rc: int | None = None):
super().__init__(message)
self.message = message
self.stata_rc = stata_rc
def _resolve_proxy(manager: UIChannelManager, session_id: str) -> StataClient:
    """Resolve the Stata client proxy, preferring injected clients for tests."""
    injected = getattr(manager, "_client", None)
    if injected is None:
        # Nothing injected: let the manager build a per-session proxy.
        return manager._get_proxy_for_session(session_id)
    from .server import StataClientProxy
    if isinstance(injected, StataClientProxy):
        # Injected client is itself a proxy; rebuild one bound to this session.
        return manager._get_proxy_for_session(session_id)
    return injected
def handle_page_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> dict[str, Any]:
    """Serve a JSON page of rows for POST /v1/page and /v1/views/<id>/page.

    Validates pagination, variable-list, maxChars and sortBy inputs against the
    manager's limits, applies (cached) sort indices and the view's filter, and
    returns the dataset/view/vars/rows/display envelope.

    Raises:
        HTTPError: with a 4xx/5xx status and machine-readable code for any
            validation failure, stale dataset id, or backend error.
    """
    max_limit, max_vars, max_chars, _ = manager.limits()
    session_id = str(body.get("sessionId", "default"))
    view: ViewHandle | None = None
    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
    else:
        view = manager.get_view(session_id, view_id)
        if view is None:
            raise HTTPError(404, "not_found", f"View {view_id} not found in session {session_id}")
        dataset_id = view.dataset_id
    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "offset must be a valid integer")
    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "limit must be a valid integer")
    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))
    sort_by = body.get("sortBy", [])
    max_chars_raw = body.get("maxChars", max_chars)
    try:
        max_chars_req = int(max_chars_raw or max_chars)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "maxChars must be a valid integer")
    if offset < 0:
        raise HTTPError(400, "invalid_request", "offset must be >= 0")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", f"limit must be > 0 (got: {limit})")
    if limit > max_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {max_limit}")
    if max_chars_req <= 0:
        raise HTTPError(400, "invalid_request", "maxChars must be > 0")
    if max_chars_req > max_chars:
        raise HTTPError(400, "request_too_large", f"maxChars must be <= {max_chars}")
    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")
    if sort_by and (not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by)):
        raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")
    current_id = manager.current_dataset_id(session_id)
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", f"Dataset changed for session {session_id}")
    if view is None:
        obs_indices = None
        filtered_n = None
    else:
        obs_indices = view.obs_indices
        filtered_n = view.filtered_n
    try:
        if sort_by:
            sort_spec = manager._normalize_sort_spec(sort_by)
            obs_indices_sorted = manager._get_cached_sort_indices(session_id, dataset_id, sort_spec)
            if obs_indices_sorted is None:
                sort_cols = [s.lstrip("+-") for s in sort_spec]
                descending = [s.startswith("-") for s in sort_spec]
                nulls_last = [False] * len(sort_spec)
                table = manager._get_sort_table(session_id, dataset_id, sort_cols)
                if table is not None:
                    # Try the native kernel first, then fall back to Polars.
                    obs_indices_sorted = _try_native_argsort(table, sort_cols, descending, nulls_last)
                    if obs_indices_sorted is None:
                        obs_indices_sorted = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)
                    manager._set_cached_sort_indices(session_id, dataset_id, sort_spec, obs_indices_sorted)
            if obs_indices_sorted:
                # Fix: test for None explicitly — a filtered view that matched
                # zero rows has obs_indices == [], and a truthiness test here
                # would wrongly replace it with the full sorted index list.
                if obs_indices is not None:
                    filter_set = set(obs_indices)
                    # Keep sorted order, restricted to the filtered rows.
                    obs_indices = [idx for idx in obs_indices_sorted if idx in filter_set]
                else:
                    obs_indices = obs_indices_sorted
        proxy = _resolve_proxy(manager, session_id)
        dataset_state = proxy.get_dataset_state()
        page = proxy.get_page(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            max_chars=max_chars_req,
            obs_indices=obs_indices,
        )
        view_obj: dict[str, Any] = {
            "offset": offset,
            "limit": limit,
            "returned": page["returned"],
            "filteredN": filtered_n,
        }
        if view_id is not None:
            view_obj["viewId"] = view_id
        return {
            "dataset": {
                "id": current_id,
                "frame": dataset_state.get("frame"),
                "n": dataset_state.get("n"),
                "k": dataset_state.get("k"),
            },
            "view": view_obj,
            "vars": page["vars"],
            "rows": page["rows"],
            "display": {
                "maxChars": max_chars_req,
                "truncatedCells": page["truncated_cells"],
                "missing": ".",
            },
        }
    except HTTPError:
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg)
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg)
        raise HTTPError(500, "internal_error", msg)
    except ValueError as e:
        msg = str(e)
        # Substring match (not startswith) for consistency with handle_arrow_request.
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg)
        raise HTTPError(400, "invalid_request", msg)
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e))
def handle_arrow_request(manager: UIChannelManager, body: dict[str, Any], *, view_id: str | None) -> bytes:
    """Serve rows as an Arrow IPC stream for POST /v1/arrow and /v1/views/<id>/arrow.

    Mirrors handle_page_request's validation, but returns raw Arrow bytes and
    enforces the (larger) arrow chunk limit instead of the JSON page limit.

    Raises:
        HTTPError: with a 4xx/5xx status and machine-readable code for any
            validation failure, stale dataset id, or backend error.
    """
    _, max_vars, _, _ = manager.limits()
    chunk_limit = getattr(manager, "_max_arrow_limit", 1_000_000)
    session_id = str(body.get("sessionId", "default"))
    view: ViewHandle | None = None
    if view_id is None:
        dataset_id = str(body.get("datasetId", ""))
    else:
        view = manager.get_view(session_id, view_id)
        if view is None:
            raise HTTPError(404, "not_found", f"View {view_id} not found in session {session_id}")
        dataset_id = view.dataset_id
    try:
        offset = int(body.get("offset") or 0)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "offset must be a valid integer")
    limit_raw = body.get("limit")
    if limit_raw is None:
        raise HTTPError(400, "invalid_request", "limit is required")
    try:
        limit = int(limit_raw)
    except (ValueError, TypeError):
        raise HTTPError(400, "invalid_request", "limit must be a valid integer")
    vars_req = body.get("vars", [])
    include_obs_no = bool(body.get("includeObsNo", False))
    sort_by = body.get("sortBy", [])
    if offset < 0:
        raise HTTPError(400, "invalid_request", "offset must be >= 0")
    if limit <= 0:
        raise HTTPError(400, "invalid_request", f"limit must be > 0 (got: {limit})")
    if limit > chunk_limit:
        raise HTTPError(400, "request_too_large", f"limit must be <= {chunk_limit}")
    if not isinstance(vars_req, list) or not all(isinstance(v, str) for v in vars_req):
        raise HTTPError(400, "invalid_request", "vars must be a list of strings")
    if len(vars_req) > max_vars:
        raise HTTPError(400, "request_too_large", f"vars length must be <= {max_vars}")
    # Validate sortBy before the try-block, consistent with handle_page_request.
    if sort_by and (not isinstance(sort_by, list) or not all(isinstance(s, str) for s in sort_by)):
        raise HTTPError(400, "invalid_request", "sortBy must be an array of strings")
    current_id = manager.current_dataset_id(session_id)
    if dataset_id != current_id:
        raise HTTPError(409, "dataset_changed", f"Dataset changed for session {session_id}")
    obs_indices = None if view is None else view.obs_indices
    try:
        if sort_by:
            sort_spec = manager._normalize_sort_spec(sort_by)
            obs_indices_sorted = manager._get_cached_sort_indices(session_id, dataset_id, sort_spec)
            if obs_indices_sorted is None:
                sort_cols = [s.lstrip("+-") for s in sort_spec]
                descending = [s.startswith("-") for s in sort_spec]
                nulls_last = [False] * len(sort_spec)
                table = manager._get_sort_table(session_id, dataset_id, sort_cols)
                if table is not None:
                    # Try the native kernel first, then fall back to Polars.
                    obs_indices_sorted = _try_native_argsort(table, sort_cols, descending, nulls_last)
                    if obs_indices_sorted is None:
                        obs_indices_sorted = _get_sorted_indices_polars(table, sort_cols, descending, nulls_last)
                    manager._set_cached_sort_indices(session_id, dataset_id, sort_spec, obs_indices_sorted)
            if obs_indices_sorted:
                # Fix: test for None explicitly — a filtered view that matched
                # zero rows has obs_indices == [], and a truthiness test here
                # would wrongly replace it with the full sorted index list.
                if obs_indices is not None:
                    filter_set = set(obs_indices)
                    obs_indices = [idx for idx in obs_indices_sorted if idx in filter_set]
                else:
                    obs_indices = obs_indices_sorted
        proxy = _resolve_proxy(manager, session_id)
        return proxy.get_arrow_stream(
            offset=offset,
            limit=limit,
            vars=vars_req,
            include_obs_no=include_obs_no,
            obs_indices=obs_indices,
        )
    except HTTPError:
        # Fix: preserve the original status/code instead of re-wrapping every
        # in-try HTTPError as a 500 internal_error via the bare Exception arm.
        raise
    except RuntimeError as e:
        msg = str(e) or "No data in memory"
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg)
        if "no data" in msg.lower():
            raise HTTPError(400, "no_data_in_memory", msg)
        raise HTTPError(500, "internal_error", msg)
    except ValueError as e:
        msg = str(e)
        if "invalid variable" in msg.lower():
            raise HTTPError(400, "invalid_variable", msg)
        raise HTTPError(400, "invalid_request", msg)
    except Exception as e:
        raise HTTPError(500, "internal_error", str(e))