import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DynamicCache,
LogitsProcessorList,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)

from memos.configs.llm import HFLLMConfig
from memos.llms.base import BaseLLM
from memos.llms.utils import remove_thinking_tags
from memos.log import get_logger
from memos.types import MessageList

logger = get_logger(__name__)


class HFLLM(BaseLLM):
"""
HFLLM: Transformers LLM class supporting cache-augmented generation (CAG) and sampling.
"""

    def __init__(self, config: HFLLMConfig):
"""
Initialize the HFLLM model and tokenizer, and set up logits processors for sampling.
"""
self.config = config
# Default model if not specified
if not self.config.model_name_or_path:
self.config.model_name_or_path = "Qwen/Qwen3-1.7B"
# Initialize hf model
self.model = AutoModelForCausalLM.from_pretrained(
self.config.model_name_or_path, torch_dtype="auto", device_map="auto"
)
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.model_name_or_path, use_fast=True
)
# Logits processors for sampling
processors = []
if getattr(self.config, "temperature", 1.0) != 1.0:
processors.append(TemperatureLogitsWarper(self.config.temperature))
if getattr(self.config, "top_k", 0) > 0:
processors.append(TopKLogitsWarper(self.config.top_k))
if 0.0 < getattr(self.config, "top_p", 1.0) < 1.0:
processors.append(TopPLogitsWarper(self.config.top_p))
self.logits_processors = LogitsProcessorList(processors)

    def generate(self, messages: MessageList, past_key_values: DynamicCache | None = None):
"""
Generate a response from the model. If past_key_values is provided, use cache-augmented generation.
Args:
messages (MessageList): Chat messages for prompt construction.
past_key_values (DynamicCache | None): Optional KV cache for fast generation.
Returns:
str: Model response.
"""
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=self.config.add_generation_prompt
)
logger.info(f"HFLLM prompt: {prompt}")
if past_key_values is None:
return self._generate_full(prompt)
else:
return self._generate_with_cache(prompt, past_key_values)

    def _generate_full(self, prompt: str) -> str:
"""
Generate output from scratch using the full prompt.
Args:
prompt (str): The input prompt string.
Returns:
str: Model response.
"""
inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)
gen_kwargs = {
"max_new_tokens": getattr(self.config, "max_tokens", 128),
"do_sample": getattr(self.config, "do_sample", True),
}
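        # Forward sampling hyperparameters only when sampling is enabled.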
        if gen_kwargs["do_sample"]:
gen_kwargs["temperature"] = self.config.temperature
gen_kwargs["top_k"] = self.config.top_k
gen_kwargs["top_p"] = self.config.top_p
gen_ids = self.model.generate(
**inputs,
**gen_kwargs,
)
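        # Keep only the newly generated tokens by slicing off the prompt portion.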
new_ids = [
out_ids[len(src_ids) :]
for src_ids, out_ids in zip(inputs.input_ids, gen_ids, strict=False)
]
response = self.tokenizer.batch_decode(new_ids, skip_special_tokens=True)[0]
logger.info(f"Full-gen raw response: {response}")
return (
remove_thinking_tags(response)
if getattr(self.config, "remove_think_prefix", False)
else response
)

    def _generate_with_cache(self, query: str, kv: DynamicCache) -> str:
"""
Generate output incrementally using an existing KV cache.
Args:
query (str): The new user query string.
kv (DynamicCache): The prefilled KV cache.
Returns:
str: Model response.
"""
query_ids = self.tokenizer(
query, return_tensors="pt", add_special_tokens=False
).input_ids.to(self.model.device)
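        # Run the new tokens through the model on top of the prefilled cache to get
        # logits for the first generated token.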
logits, kv = self._prefill(query_ids, kv)
next_token = self._select_next_token(logits)
generated = [next_token]
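        # Decode autoregressively: feed each newly chosen token back in, reusing and
        # extending the KV cache instead of re-running the full context.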
for _ in range(getattr(self.config, "max_tokens", 128) - 1):
if self._should_stop(next_token):
break
logits, kv = self._prefill(next_token, kv)
next_token = self._select_next_token(logits)
generated.append(next_token)
if generated:
concat = torch.cat(generated, dim=-1)
response = self.tokenizer.decode(concat[0], skip_special_tokens=True)
else:
response = ""
logger.info(f"Cache-gen raw response: {response}")
return (
remove_thinking_tags(response)
if getattr(self.config, "remove_think_prefix", False)
else response
)

    @torch.no_grad()
def _prefill(
self, input_ids: torch.Tensor, kv: DynamicCache
) -> tuple[torch.Tensor, DynamicCache]:
"""
Forward the model once, returning last-step logits and updated KV cache.
Args:
input_ids (torch.Tensor): Input token IDs.
kv (DynamicCache): Existing KV cache.
Returns:
tuple[torch.Tensor, DynamicCache]: (last-step logits, updated KV cache)
"""
out = self.model(
input_ids=input_ids,
use_cache=True,
past_key_values=kv,
return_dict=True,
)
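        # Only the logits for the final position are needed to pick the next token;
        # the returned cache now also covers the tokens just fed in.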
return out.logits[:, -1, :], out.past_key_values

    def _select_next_token(self, logits: torch.Tensor) -> torch.Tensor:
"""
Select the next token from logits using sampling or argmax, depending on config.
Args:
logits (torch.Tensor): Logits for the next token.
Returns:
torch.Tensor: Selected token ID(s).
"""
if getattr(self.config, "do_sample", True):
batch_size, _ = logits.size()
dummy_ids = torch.zeros((batch_size, 1), dtype=torch.long, device=logits.device)
filtered = self.logits_processors(dummy_ids, logits)
probs = torch.softmax(filtered, dim=-1)
return torch.multinomial(probs, num_samples=1)
return torch.argmax(logits, dim=-1, keepdim=True)

    def _should_stop(self, token: torch.Tensor) -> bool:
"""
Check if the given token is the EOS (end-of-sequence) token.
Args:
token (torch.Tensor): Token ID to check.
Returns:
bool: True if token is EOS, else False.
"""
eos_id = self.tokenizer.eos_token_id
return eos_id is not None and token.item() == eos_id

    def build_kv_cache(self, messages: str | list[str] | MessageList) -> DynamicCache:
"""
Build a KV cache from chat messages via one forward pass.
Supports the following input types:
- str: Used as a system prompt.
- list[str]: Concatenated and used as a system prompt.
- list[dict]: Used directly as chat messages.
The messages are always converted to a standard chat template.
Raises:
ValueError: If the resulting prompt is empty after template processing.
Returns:
DynamicCache: The constructed KV cache object.
"""
# Accept multiple input types and convert to standard chat messages
if isinstance(messages, str):
messages = [
{
"role": "system",
"content": f"Below is some information about the user.\n{messages}",
}
]
elif isinstance(messages, list) and messages and isinstance(messages[0], str):
messages = [
{
"role": "system",
"content": f"Below is some information about the user.\n{' '.join(messages)}",
}
]
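        # Render the messages with the model's chat template; no generation prompt is
        # added here, since generation happens later on top of this cache.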
prompt = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=False
)
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
seq_len = inputs["input_ids"].size(-1)
if seq_len == 0:
raise ValueError(
"Prompt after chat template is empty, cannot build KV cache. Check your messages input."
)
kv = DynamicCache()
with torch.no_grad():
self.model(**inputs, use_cache=True, past_key_values=kv)
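        # Trim each layer's keys/values to the prompt length so the returned cache
        # covers exactly the prompt tokens.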
for i, (k, v) in enumerate(zip(kv.key_cache, kv.value_cache, strict=False)):
kv.key_cache[i] = k[:, :, :seq_len, :]
kv.value_cache[i] = v[:, :, :seq_len, :]
return kv
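

# Example usage (illustrative sketch, not part of the library's tested surface):
# assumes HFLLMConfig accepts these keyword fields, which may vary across versions.
if __name__ == "__main__":
    config = HFLLMConfig(
        model_name_or_path="Qwen/Qwen3-1.7B",
        temperature=0.7,
        top_p=0.9,
        max_tokens=64,
    )
    llm = HFLLM(config)
    # Prefill background context once, then reuse the cache for follow-up queries.
    kv = llm.build_kv_cache("The user's name is Alice and she enjoys hiking.")
    answer = llm.generate(
        [{"role": "user", "content": "What does the user enjoy?"}],
        past_key_values=kv,
    )
    print(answer)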