#!/usr/bin/env python3
"""
Simple ONNX conversion using local sentence-transformers model.
"""
import json
import torch
import onnx
from pathlib import Path
def convert_local_model():
"""Convert locally loaded sentence-transformers model to ONNX."""
try:
from sentence_transformers import SentenceTransformer
print("Loading sentence-transformers model locally...")
# Force CPU to avoid MPS issues
device = torch.device('cpu')
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', device='cpu')
# Get the transformer component
transformer = model[0]
pytorch_model = transformer.auto_model
tokenizer = transformer.tokenizer
# Move to CPU and set to evaluation mode
pytorch_model = pytorch_model.to(device)
pytorch_model.eval()
# Create dummy input for tracing
dummy_input_ids = torch.randint(0, 1000, (1, 128), dtype=torch.long)
dummy_attention_mask = torch.ones(1, 128, dtype=torch.long)
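        # The (1, 128) shape is arbitrary and used only for tracing; the
        # dynamic_axes mapping in the export call below keeps batch size and
        # sequence length flexible at inference time.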
print("Converting to ONNX...")
        kb_path = Path("/Users/thypon/kb")
        kb_path.mkdir(parents=True, exist_ok=True)  # export fails if the target directory is missing
        onnx_path = kb_path / "sentence_model.onnx"
# Export to ONNX
torch.onnx.export(
pytorch_model,
(dummy_input_ids, dummy_attention_mask),
str(onnx_path),
input_names=['input_ids', 'attention_mask'],
output_names=['last_hidden_state'],
dynamic_axes={
'input_ids': {0: 'batch_size', 1: 'sequence'},
'attention_mask': {0: 'batch_size', 1: 'sequence'},
'last_hidden_state': {0: 'batch_size', 1: 'sequence'}
},
opset_version=14,
do_constant_folding=True
)
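        # Optional parity check (a minimal sketch, assuming onnxruntime is
        # installed): run the exported graph on the dummy inputs and compare
        # against the PyTorch model to catch tracing problems early.
        try:
            import onnxruntime as ort
            import numpy as np
            sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
            onnx_out = sess.run(None, {
                'input_ids': dummy_input_ids.numpy(),
                'attention_mask': dummy_attention_mask.numpy()
            })[0]
            with torch.no_grad():
                torch_out = pytorch_model(dummy_input_ids, dummy_attention_mask).last_hidden_state.numpy()
            print(f"Max ONNX/PyTorch deviation: {np.abs(onnx_out - torch_out).max():.2e}")
        except ImportError:
            print("onnxruntime not installed; skipping parity check")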
# Save tokenizer
tokenizer_path = kb_path / "tokenizer"
tokenizer_path.mkdir(exist_ok=True)
tokenizer.save_pretrained(str(tokenizer_path))
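        # save_pretrained() writes the vocab/tokenizer files that
        # AutoTokenizer.from_pretrained() reloads in test_onnx_inference().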
# Save config
config = {
"model_name": "all-MiniLM-L6-v2",
"max_seq_length": 256,
"embedding_dim": 384,
"tokenizer_path": str(tokenizer_path)
}
with open(kb_path / "model_config.json", 'w') as f:
json.dump(config, f, indent=2)
print("✅ ONNX conversion complete!")
print(f"Model: {onnx_path}")
print(f"Tokenizer: {tokenizer_path}")
print(f"Config: {kb_path / 'model_config.json'}")
# Verify the model
try:
onnx_model = onnx.load(str(onnx_path))
onnx.checker.check_model(onnx_model)
print("✅ ONNX model verification passed!")
except Exception as e:
print(f"⚠️ ONNX verification warning: {e}")
return True
except Exception as e:
print(f"Error during conversion: {e}")
return False
def test_onnx_inference():
"""Test ONNX model inference."""
try:
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
kb_path = Path("/Users/thypon/kb")
model_path = kb_path / "sentence_model.onnx"
tokenizer_path = kb_path / "tokenizer"
        if not model_path.exists():
            print(f"ONNX model not found at {model_path}; run with --convert first")
            return False
print("Testing ONNX inference...")
# Load tokenizer and session
tokenizer = AutoTokenizer.from_pretrained(str(tokenizer_path))
        # Explicit provider keeps this working on onnxruntime-gpu installs,
        # which require providers to be named
        session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
# Test sentences
texts = ["Hello world", "This is a test"]
for text in texts:
# Tokenize
inputs = tokenizer(text, return_tensors="np", padding=True, truncation=True, max_length=256)
# Run ONNX inference
outputs = session.run(None, {
'input_ids': inputs['input_ids'].astype(np.int64),
'attention_mask': inputs['attention_mask'].astype(np.int64)
})
# Mean pooling (same as sentence-transformers)
token_embeddings = outputs[0] # last_hidden_state
attention_mask = inputs['attention_mask']
# Apply mask and mean pool
input_mask_expanded = np.expand_dims(attention_mask, axis=-1)
masked_embeddings = token_embeddings * input_mask_expanded
sum_embeddings = np.sum(masked_embeddings, axis=1)
            # Clamp the divisor to avoid division by zero on all-padding rows,
            # as sentence-transformers does
            sum_mask = np.clip(np.sum(attention_mask, axis=1, keepdims=True), 1e-9, None)
            mean_embeddings = sum_embeddings / sum_mask
# Normalize
norm = np.linalg.norm(mean_embeddings, axis=1, keepdims=True)
normalized_embeddings = mean_embeddings / norm
print(f"Text: '{text}'")
print(f"Embedding shape: {normalized_embeddings.shape}")
print(f"First 5 values: {normalized_embeddings[0][:5]}")
print()
print("✅ ONNX inference test passed!")
return True
except Exception as e:
print(f"Error testing ONNX: {e}")
return False
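def encode_texts(texts, kb_path=Path("/Users/thypon/kb")):
    """Hypothetical convenience wrapper (not part of the original flow):
    batch-encode a list of texts with the exported artifacts, mirroring the
    pooling and normalization logic in test_onnx_inference()."""
    import numpy as np
    import onnxruntime as ort
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(str(kb_path / "tokenizer"))
    session = ort.InferenceSession(str(kb_path / "sentence_model.onnx"), providers=["CPUExecutionProvider"])
    inputs = tokenizer(texts, return_tensors="np", padding=True, truncation=True, max_length=256)
    hidden = session.run(None, {
        'input_ids': inputs['input_ids'].astype(np.int64),
        'attention_mask': inputs['attention_mask'].astype(np.int64)
    })[0]
    # Masked mean pooling followed by L2 normalization
    mask = np.expand_dims(inputs['attention_mask'], axis=-1)
    summed = np.sum(hidden * mask, axis=1)
    counts = np.clip(np.sum(mask, axis=1), 1e-9, None)
    embeddings = summed / counts
    return embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)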
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Convert all-MiniLM-L6-v2 to ONNX and test inference.")
    parser.add_argument("--convert", action="store_true", help="export the model, tokenizer, and config")
    parser.add_argument("--test", action="store_true", help="run a smoke test against the exported model")
    args = parser.parse_args()
    if args.convert:
        success = convert_local_model()
        if success and args.test:
            test_onnx_inference()
    elif args.test:
        test_onnx_inference()
    else:
        parser.print_help()