import logging
import random  # Used for mock implementation
from typing import Any, Dict, List, Optional
# Mock implementation that doesn't require torch
# import torch
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# import torch.nn.functional as F
logger = logging.getLogger(__name__)


class BERTDocumentClassifier:
    """Mock document-type classifier exposing a BERT-style interface.

    No model is loaded. A document is scored by counting keyword hits in
    its first ``_MAX_CHARS`` characters, adding a small random jitter to
    each class score and normalising the scores so they sum to 1.

    Attributes:
        model_name: Name passed to the constructor; only stored and logged.
        threshold: Value passed to the constructor; stored for callers but
            not read by any method of this class.
        last_confidence: Highest probability produced by the most recent
            ``classify`` call (0.0 before the first call or after a failed
            one).
        classes: Labels that ``classify`` scores.
    """

    # Only this many leading characters of a document are inspected.
    _MAX_CHARS = 512

    # Label returned by ``predict`` when classification failed.
    _FALLBACK_CLASS = "other"

    # Substrings (matched against the lower-cased text) that vote for a class.
    _KEYWORDS = {
        "invoice": ("invoice", "bill", "payment", "amount due", "total", "qty", "quantity"),
        "contract": ("agreement", "contract", "terms", "parties", "clause", "signed", "obligations"),
        "email": ("from:", "to:", "subject:", "sent:", "received:", "dear", "regards", "sincerely"),
        "report": ("report", "analysis", "findings", "conclusion", "summary", "data", "results"),
    }

    # Source of jitter. Defaults to the global ``random`` module so that
    # callers who seed it globally keep getting the same sequence; replaced
    # by a private generator when a seed is given to ``__init__``.
    _rng = random

    def __init__(self, model_name: str, threshold: float = 0.5, *, seed: Optional[int] = None):
        """Create the classifier.

        Args:
            model_name: Identifier of the (mocked) model.
            threshold: Confidence threshold kept on the instance.
            seed: Optional seed for a private random generator, making the
                jitter added by ``classify`` reproducible for this instance.
                When omitted, the global ``random`` module is used.
        """
        self.model_name = model_name
        self.threshold = threshold
        self.last_confidence = 0.0
        # Define document classes
        self.classes = ["invoice", "contract", "email", "report", "other"]
        if seed is not None:
            self._rng = random.Random(seed)
        logger.info("Initialized mock BERTDocumentClassifier with model: %s", model_name)

    def predict(self, text: str) -> str:
        """Return the most probable class label for ``text``.

        When ``classify`` could not score the document (it then returns
        all-zero scores), the fallback label ``"other"`` is returned rather
        than whichever class happens to come first.
        """
        scores = self.classify(text)
        if not any(scores.values()):
            return self._FALLBACK_CLASS
        # max() keeps the first of equal scores, i.e. the order of self.classes.
        return max(scores, key=scores.__getitem__)

    def classify(self, text: str) -> Dict[str, float]:
        """Score ``text`` against every class.

        Args:
            text: Document text; only the first ``_MAX_CHARS`` characters
                are examined.

        Returns:
            A mapping from class label to probability. On success the values
            sum to 1 and ``last_confidence`` is set to the largest one. If
            the input cannot be processed (for example it is ``None`` or not
            a string), the error is logged, ``last_confidence`` is reset to
            0.0 and every class maps to 0.0.
        """
        try:
            # Truncation happens inside the guarded region so that bad
            # input follows the same logged fallback as any other failure.
            text_lower = text[:self._MAX_CHARS].lower()

            # Each keyword counts at most once, however often it occurs.
            counts = {
                label: sum(1 for kw in keywords if kw in text_lower)
                for label, keywords in self._KEYWORDS.items()
            }
            total_matches = sum(counts.values())

            # Every class starts from a low base probability.
            result = {cls: 0.1 for cls in self.classes}
            if total_matches > 0:
                # Share 0.8 of probability mass in proportion to the hits.
                for label, count in counts.items():
                    result[label] = 0.1 + (count / total_matches) * 0.8
            else:
                # If no keywords match, assign to "other" with high probability
                result["other"] = 0.8

            # Add some randomness to simulate ML behavior, then clamp so a
            # score can never reach 0 or 1 before normalisation.
            for cls in self.classes:
                jitter = self._rng.uniform(-0.05, 0.05)
                result[cls] = max(0.01, min(0.99, result[cls] + jitter))

            # Normalize to make sure probabilities sum to 1
            total_prob = sum(result.values())
            result = {cls: prob / total_prob for cls, prob in result.items()}

            # Set confidence as the highest probability
            self.last_confidence = max(result.values())
            return result
        except Exception as e:
            # Deliberate best-effort: report the failure and return a
            # recognisable all-zero result instead of raising.
            logger.exception("Error in document classification: %s", e)
            self.last_confidence = 0.0
            return {cls: 0.0 for cls in self.classes}

    def get_confidence(self) -> float:
        """Get the confidence score of the last prediction."""
        return self.last_confidence