"""Model management helpers for the DPS Coach local LLM."""
from __future__ import annotations
import hashlib
from pathlib import Path
from typing import Any, Dict, Optional, Sequence
from .constants import (
DEFAULT_MODEL_MIN_SIZE_MB,
DEFAULT_MODEL_PATH,
DEFAULT_MODEL_SHA256,
RUNTIME_MODELS_DIR,
)
try:  # pragma: no cover - optional dependency guard
from llama_cpp import Llama # type: ignore
except ImportError as exc: # pragma: no cover - surfaced to UI
Llama = None # type: ignore
_LLAMA_IMPORT_ERROR = exc
else:
_LLAMA_IMPORT_ERROR = None
LlamaType = Any
class ModelManagerError(RuntimeError):
"""Raised when the bundled model fails validation or inference."""
def validate_model_file(
path: str | Path,
*,
min_size_mb: float = DEFAULT_MODEL_MIN_SIZE_MB,
expected_sha256: str | None = DEFAULT_MODEL_SHA256,
) -> None:
"""Validate that the GGUF file exists, has the right size, magic, and SHA."""
resolved = Path(path)
if not resolved.exists():
raise FileNotFoundError(f"Model file not found at {resolved}.")
size_mb = resolved.stat().st_size / (1024 * 1024)
if size_mb <= min_size_mb:
raise ValueError(
f"Model file is too small ({size_mb:.2f} MB). Expected > {min_size_mb} MB. File: {resolved}"
)
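    # GGUF containers start with the ASCII magic bytes "GGUF"; reading the first 4 bytes is enough to check.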
with resolved.open("rb") as handle:
header = handle.read(4)
if header != b"GGUF":
raise ValueError(f"Model file header mismatch (expected GGUF magic bytes). File: {resolved}")
if expected_sha256:
sha = hashlib.sha256()
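        # Stream the file in 1 MiB chunks so multi-gigabyte GGUF models are never read into memory at once.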
with resolved.open("rb") as handle:
for chunk in iter(lambda: handle.read(1024 * 1024), b""):
sha.update(chunk)
digest = sha.hexdigest()
if digest.lower() != expected_sha256.lower():
raise ValueError(
f"Model SHA256 mismatch. Expected {expected_sha256.upper()}, "
f"found {digest.upper()}. File: {resolved}"
)
class LocalModelManager:
"""Lightweight wrapper around llama-cpp that keeps a shared model handle."""
def __init__(
self,
model_path: str | Path = DEFAULT_MODEL_PATH,
*,
min_size_mb: float = DEFAULT_MODEL_MIN_SIZE_MB,
expected_sha256: str | None = DEFAULT_MODEL_SHA256,
n_ctx: int = 4096,
) -> None:
self._model_path = Path(model_path)
self._min_size_mb = min_size_mb
self._expected_sha256 = expected_sha256
self._n_ctx = n_ctx
self._llm: Optional[LlamaType] = None
# Ensure runtime models directory exists
RUNTIME_MODELS_DIR.mkdir(parents=True, exist_ok=True)
@property
def model_path(self) -> Path:
return self._model_path
@property
def min_size_mb(self) -> float:
return self._min_size_mb
@property
def expected_sha256(self) -> str | None:
return self._expected_sha256
def set_model_path(self, new_path: str | Path) -> None:
new_path = Path(new_path)
if new_path == self._model_path:
return
# Drop cached handle so llama reloads using the new file.
self._llm = None
self._model_path = new_path
def ensure_model(self) -> LlamaType:
if Llama is None:
raise ModelManagerError("llama-cpp-python is not installed.") from _LLAMA_IMPORT_ERROR
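        # Validation (size, magic bytes, and optional SHA-256 digest) runs on every call, even when a handle is already cached.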
validate_model_file(
self._model_path,
min_size_mb=self._min_size_mb,
expected_sha256=self._expected_sha256,
)
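        # Load lazily and cache the handle; subsequent calls reuse the same in-memory model.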
if self._llm is None:
self._llm = Llama(
model_path=str(self._model_path),
n_ctx=self._n_ctx,
logits_all=False,
embedding=False,
verbose=False,
)
return self._llm
def generate(
self,
messages: Sequence[Dict[str, str]],
*,
max_tokens: int,
temperature: float,
top_p: float,
repeat_penalty: float,
stop: Optional[Sequence[str]] = None,
) -> str:
llm = self.ensure_model()
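        # llama-cpp formats the chat messages (using the model's chat template when available) before sampling.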
completion = llm.create_chat_completion(
messages=list(messages),
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
repeat_penalty=repeat_penalty,
stop=stop,
)
content = completion["choices"][0]["message"]["content"]
return (content or "").strip()
def run_self_test(self) -> None:
messages = [
{"role": "system", "content": "You are a terse test helper."},
{"role": "user", "content": "Say OK."},
]
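        # Near-deterministic settings (temperature 0, tight token budget) so the reply should be a bare "OK".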
output = self.generate(
messages,
max_tokens=4,
temperature=0.0,
top_p=1.0,
repeat_penalty=1.0,
stop=["\n"],
)
normalized = self._normalize_self_test_output(output)
if normalized != "OK":
raise ModelManagerError(f"Model self-test mismatch: {output!r}")
@staticmethod
def _normalize_self_test_output(output: str) -> str:
cleaned = output.strip()
# Repeatedly strip quotes and punctuation until stable
prev = None
while prev != cleaned:
prev = cleaned
cleaned = cleaned.strip("\"'")
cleaned = cleaned.strip()
cleaned = cleaned.rstrip(".!?\"'")
cleaned = cleaned.strip()
return cleaned.upper()
__all__ = [
"LocalModelManager",
"ModelManagerError",
"validate_model_file",
"_LLAMA_IMPORT_ERROR",
]
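# Illustrative usage sketch, not part of the public API: assumes llama-cpp-python is
# installed and that the GGUF model referenced by DEFAULT_MODEL_PATH exists and
# passes validation.
if __name__ == "__main__":
    manager = LocalModelManager()
    manager.run_self_test()
    reply = manager.generate(
        [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello in one short sentence."},
        ],
        max_tokens=32,
        temperature=0.2,
        top_p=0.95,
        repeat_penalty=1.1,
    )
    print(reply)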