
MemOS-MCP

by qinshu1109
hf.py (9.01 kB)
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DynamicCache,
    LogitsProcessorList,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

from memos.configs.llm import HFLLMConfig
from memos.llms.base import BaseLLM
from memos.llms.utils import remove_thinking_tags
from memos.log import get_logger
from memos.types import MessageList

logger = get_logger(__name__)


class HFLLM(BaseLLM):
    """
    HFLLM: Transformers LLM class supporting cache-augmented generation (CAG) and sampling.
    """

    def __init__(self, config: HFLLMConfig):
        """
        Initialize the HFLLM model and tokenizer, and set up logits processors for sampling.
        """
        self.config = config

        # Default model if not specified
        if not self.config.model_name_or_path:
            self.config.model_name_or_path = "Qwen/Qwen3-1.7B"

        # Initialize hf model
        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name_or_path, use_fast=True
        )

        # Logits processors for sampling
        processors = []
        if getattr(self.config, "temperature", 1.0) != 1.0:
            processors.append(TemperatureLogitsWarper(self.config.temperature))
        if getattr(self.config, "top_k", 0) > 0:
            processors.append(TopKLogitsWarper(self.config.top_k))
        if 0.0 < getattr(self.config, "top_p", 1.0) < 1.0:
            processors.append(TopPLogitsWarper(self.config.top_p))
        self.logits_processors = LogitsProcessorList(processors)

    def generate(self, messages: MessageList, past_key_values: DynamicCache | None = None):
        """
        Generate a response from the model. If past_key_values is provided, use cache-augmented generation.

        Args:
            messages (MessageList): Chat messages for prompt construction.
            past_key_values (DynamicCache | None): Optional KV cache for fast generation.

        Returns:
            str: Model response.
        """
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=self.config.add_generation_prompt
        )
        logger.info(f"HFLLM prompt: {prompt}")
        if past_key_values is None:
            return self._generate_full(prompt)
        else:
            return self._generate_with_cache(prompt, past_key_values)

    def _generate_full(self, prompt: str) -> str:
        """
        Generate output from scratch using the full prompt.

        Args:
            prompt (str): The input prompt string.

        Returns:
            str: Model response.
        """
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
        gen_kwargs = {
            "max_new_tokens": getattr(self.config, "max_tokens", 128),
            "do_sample": getattr(self.config, "do_sample", True),
        }
        if self.config.do_sample:
            gen_kwargs["temperature"] = self.config.temperature
            gen_kwargs["top_k"] = self.config.top_k
            gen_kwargs["top_p"] = self.config.top_p

        gen_ids = self.model.generate(
            **inputs,
            **gen_kwargs,
        )
        new_ids = [
            out_ids[len(src_ids) :]
            for src_ids, out_ids in zip(inputs.input_ids, gen_ids, strict=False)
        ]
        response = self.tokenizer.batch_decode(new_ids, skip_special_tokens=True)[0]
        logger.info(f"Full-gen raw response: {response}")
        return (
            remove_thinking_tags(response)
            if getattr(self.config, "remove_think_prefix", False)
            else response
        )

    def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
        """
        Generate output incrementally using an existing KV cache.

        Args:
            query (str): The new user query string.
            kv (DynamicCache): The prefilled KV cache.

        Returns:
            str: Model response.
        """
        query_ids = self.tokenizer(
            query, return_tensors="pt", add_special_tokens=False
        ).input_ids.to(self.model.device)

        logits, kv = self._prefill(query_ids, kv)
        next_token = self._select_next_token(logits)
        generated = [next_token]

        for _ in range(getattr(self.config, "max_tokens", 128) - 1):
            if self._should_stop(next_token):
                break
            logits, kv = self._prefill(next_token, kv)
            next_token = self._select_next_token(logits)
            generated.append(next_token)

        if generated:
            concat = torch.cat(generated, dim=-1)
            response = self.tokenizer.decode(concat[0], skip_special_tokens=True)
        else:
            response = ""
        logger.info(f"Cache-gen raw response: {response}")
        return (
            remove_thinking_tags(response)
            if getattr(self.config, "remove_think_prefix", False)
            else response
        )

    @torch.no_grad()
    def _prefill(
        self, input_ids: torch.Tensor, kv: DynamicCache
    ) -> tuple[torch.Tensor, DynamicCache]:
        """
        Forward the model once, returning last-step logits and updated KV cache.

        Args:
            input_ids (torch.Tensor): Input token IDs.
            kv (DynamicCache): Existing KV cache.

        Returns:
            tuple[torch.Tensor, DynamicCache]: (last-step logits, updated KV cache)
        """
        out = self.model(
            input_ids=input_ids,
            use_cache=True,
            past_key_values=kv,
            return_dict=True,
        )
        return out.logits[:, -1, :], out.past_key_values

    def _select_next_token(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Select the next token from logits using sampling or argmax, depending on config.

        Args:
            logits (torch.Tensor): Logits for the next token.

        Returns:
            torch.Tensor: Selected token ID(s).
        """
        if getattr(self.config, "do_sample", True):
            batch_size, _ = logits.size()
            dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=logits.device)
            filtered = self.logits_processors(dummy_ids, logits)
            probs = torch.softmax(filtered, dim=-1)
            return torch.multinomial(probs, num_samples=1)
        return torch.argmax(logits, dim=-1, keepdim=True)

    def _should_stop(self, token: torch.Tensor) -> bool:
        """
        Check if the given token is the EOS (end-of-sequence) token.

        Args:
            token (torch.Tensor): Token ID to check.

        Returns:
            bool: True if token is EOS, else False.
        """
        eos_id = self.tokenizer.eos_token_id
        return eos_id is not None and token.item() == eos_id

    def build_kv_cache(self, messages) -> DynamicCache:
        """
        Build a KV cache from chat messages via one forward pass.

        Supports the following input types:
        - str: Used as a system prompt.
        - list[str]: Concatenated and used as a system prompt.
        - list[dict]: Used directly as chat messages.

        The messages are always converted to a standard chat template.

        Raises:
            ValueError: If the resulting prompt is empty after template processing.

        Returns:
            DynamicCache: The constructed KV cache object.
        """
        # Accept multiple input types and convert to standard chat messages
        if isinstance(messages, str):
            messages = [
                {
                    "role": "system",
                    "content": f"Below is some information about the user.\n{messages}",
                }
            ]
        elif isinstance(messages, list) and messages and isinstance(messages[0], str):
            messages = [
                {
                    "role": "system",
                    "content": f"Below is some information about the user.\n{' '.join(messages)}",
                }
            ]

        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        inputs = self.tokenizer(prompt, return_tensors="pt")
        inputs["input_ids"] = inputs["input_ids"].to(self.model.device, dtype=torch.long)
        seq_len = inputs["input_ids"].size(-1)
        if seq_len == 0:
            raise ValueError(
                "Prompt after chat template is empty, cannot build KV cache. Check your messages input."
            )

        kv = DynamicCache()
        with torch.no_grad():
            self.model(**inputs, use_cache=True, past_key_values=kv)

        for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache, strict=False)):
            kv.key_cache[i] = k[:, :, :seq_len, :]
            kv.value_cache[i] = v[:, :, :seq_len, :]
        return kv
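
Below is a minimal usage sketch of the two generation paths exposed by this class. It assumes HFLLMConfig accepts, as constructor keywords, the fields that hf.py reads off the config (model_name_or_path, max_tokens, do_sample, add_generation_prompt, remove_think_prefix), and the import path memos.llms.hf is a guess for where this module lives; both may differ in the real MemOS package.

# Minimal usage sketch for HFLLM. The HFLLMConfig keyword arguments and the
# module path below are assumptions inferred from the attribute accesses in
# hf.py, not a confirmed MemOS API.
from memos.configs.llm import HFLLMConfig
from memos.llms.hf import HFLLM  # hypothetical import path for this file

config = HFLLMConfig(
    model_name_or_path="Qwen/Qwen3-1.7B",  # same default the class falls back to
    max_tokens=128,
    do_sample=False,               # greedy decoding in both generation paths
    add_generation_prompt=True,
    remove_think_prefix=True,      # strip <think> blocks from the decoded text
)
llm = HFLLM(config)

messages = [{"role": "user", "content": "What is my favorite color?"}]

# Path 1: full generation. The chat messages are templated into a prompt and
# decoded from scratch with model.generate().
print(llm.generate(messages))

# Path 2: cache-augmented generation. Background text is prefilled into a
# DynamicCache once; the later call only runs the templated query tokens
# through the model before decoding token by token.
kv = llm.build_kv_cache("The user's favorite color is teal.")
print(llm.generate(messages, past_key_values=kv))

The second call reuses the KV cache built from the background text, which is the cache-augmented-generation path: build_kv_cache pays the prefill cost once, and subsequent queries against the same context skip it.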
