# --- Remote Tokenizer Implementation ---
import json
from typing import List
import httpx
from indexer.abstract_tokenizer import AbstractTokenizer
def escape(text):
    r"""Return ``text`` JSON-escaped, without the enclosing double quotes.

    This is ``json.dumps(text)`` with its first and last characters (the
    quotes) removed, so quotes, backslashes and control characters come back
    as JSON escape sequences and, because ``json.dumps`` defaults to
    ``ensure_ascii=True``, non-ASCII characters come back as ``\uXXXX``.
    Example: ``say "hi"`` becomes ``say \"hi\"``.

    Intended for ``str`` input; for other JSON-serializable values the same
    first/last-character slice is applied to their JSON encoding.
    """
    encoded = json.dumps(text)
    # Strip the surrounding quote characters added by json.dumps.
    return encoded[1:len(encoded) - 1]
class RemoteTeiTokenizer(AbstractTokenizer):
    """
    A tokenizer that uses a remote TEI API to tokenize text.

    Every request is a single HTTP POST to ``api_url`` whose JSON body has the
    form ``{"inputs": [<text>, ...], "add_special_tokens": <true|false>}``.
    A fresh ``httpx.Client`` is opened and closed for each request.
    """

    def __init__(self,
                 api_url: str,
                 timeout: float = 60,
                 add_special_tokens: bool = True):
        """
        Initializes the RemoteTeiTokenizer.

        Args:
            api_url: The URL of the tokenization API endpoint.
            timeout: The timeout in seconds passed to httpx for each request.
            add_special_tokens: Value sent as the ``add_special_tokens`` field
                of every request payload. Expected to be a bool: it is written
                into the JSON as ``str(value).lower()``, which yields
                ``true``/``false`` for bools.
        """
        self.api_url = api_url
        self.timeout = timeout
        self.add_special_tokens = add_special_tokens
        self.headers = {'Content-Type': 'application/json'}

    def tokenize(self, text: str) -> List[float]:
        """
        Tokenizes a single string by calling the remote tokenization API.

        The text is sent as a one-element ``inputs`` list.

        Args:
            text: The input string to be tokenized.

        Returns:
            The first element of the JSON array returned by the API, i.e. the
            API's result for ``text``. NOTE(review): annotated ``List[float]``
            (presumably to match ``AbstractTokenizer``); the actual element
            type is whatever the API returns per input — confirm.

        Raises:
            RuntimeError: If the API request fails due to network or HTTP errors.
            ValueError: If the response is not valid JSON or is not a
                non-empty JSON array.
        """
        result = self._post([text])
        if not isinstance(result, list) or not result:
            # Indexing a dict or an empty list would leak KeyError/IndexError;
            # report malformed responses as the documented ValueError instead.
            raise ValueError(
                "Unexpected API response: expected a non-empty JSON array, "
                f"got {result!r:.200}"
            )
        return result[0]

    def tokenize_multiple(self, text_list: List[str]) -> List[List[str]]:
        """
        Tokenizes several strings with a single call to the remote API.

        Args:
            text_list: The input strings to be tokenized. Any iterable of
                strings is accepted; it is materialized into a list once.

        Returns:
            The JSON array returned by the API, presumably one entry per input
            in input order (NOTE(review): confirm against the API; the element
            type likewise depends on the API response). An empty input returns
            ``[]`` without contacting the API.

        Raises:
            RuntimeError: If the API request fails due to network or HTTP errors.
            ValueError: If the response is not valid JSON or is not a JSON array.
        """
        inputs = list(text_list)
        if not inputs:
            # Nothing to tokenize; avoid a pointless network round trip.
            return []
        result = self._post(inputs)
        if not isinstance(result, list):
            raise ValueError(
                "Unexpected API response: expected a JSON array, "
                f"got {result!r:.200}"
            )
        return result

    def _post(self, inputs: List[str]):
        """
        POST ``inputs`` to the API and return the decoded JSON response.

        Args:
            inputs: The strings to send in the ``inputs`` field.

        Returns:
            The response body decoded with ``response.json()``.

        Raises:
            RuntimeError: On a non-success HTTP status or a transport failure.
            ValueError: If the response body cannot be decoded as JSON.
        """
        # json.dumps escapes each string exactly once (quotes, backslashes,
        # control characters, non-ASCII). The items must NOT be pre-escaped:
        # doing both would double-escape them and the server would tokenize
        # the escape sequences instead of the original text.
        payload_string = (
            f'{{"inputs": {json.dumps(inputs)}, '
            f'"add_special_tokens": {str(self.add_special_tokens).lower()}}}'
        )
        try:
            with httpx.Client() as client:
                response = client.post(
                    self.api_url,
                    headers=self.headers,
                    content=payload_string,
                    timeout=self.timeout,
                )
                # Raises HTTPStatusError for non-success status codes.
                response.raise_for_status()
        except httpx.HTTPStatusError as e:
            raise RuntimeError(
                f"API request failed with status {e.response.status_code}: "
                f"{e.response.text}"
            ) from e
        except httpx.RequestError as e:
            # Transport-level failures: connection errors, all timeouts
            # (connect/read/write/pool), too many redirects, etc.
            raise RuntimeError(f"Network error during tokenization: {e}") from e
        try:
            return response.json()
        except ValueError as e:
            # json.JSONDecodeError is a ValueError subclass, as is the
            # UnicodeDecodeError that decoding a malformed body may raise.
            raise ValueError(f"Invalid JSON response from API: {e}") from e