"""
Enhanced Clustering Trainer with Sentence Transformers
Clusters datasets into GOOD, LIMITATIONS, and HARMFUL categories
"""
import asyncio
import json
import logging
from pathlib import Path
from typing import Dict, List, Any, Tuple
from dataclasses import dataclass, asdict
from datetime import datetime
import pickle
import numpy as np
from sklearn.cluster import KMeans, DBSCAN
from sklearn.metrics import silhouette_score, davies_bouldin_score
from sklearn.preprocessing import StandardScaler
from collections import Counter
import re
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Try to import sentence transformers
try:
from sentence_transformers import SentenceTransformer
HAS_TRANSFORMERS = True
except ImportError:
logger.warning("sentence-transformers not installed. Install with: uv pip install sentence-transformers")
HAS_TRANSFORMERS = False
@dataclass
class ClusterResult:
"""Result of clustering analysis"""
cluster_id: int
size: int
# Cluster composition
category_distribution: Dict[str, float] # {good: 0.2, limitations: 0.8}
domain_distribution: Dict[str, int] # {mathematics: 50, medicine: 30}
# Quality metrics
purity: float # Homogeneity of cluster (0-1)
is_dangerous: bool # True if >70% limitations or harmful
# Representative examples
examples: List[str]
# Pattern description
pattern_description: str
detection_heuristic: str # Rule for detecting this pattern
# Top keywords
keywords: List[str]
@dataclass
class TrainingResult:
"""Complete training results"""
timestamp: str
model_type: str # "kmeans", "dbscan"
embedding_model: str # "all-MiniLM-L6-v2"
# Metrics
n_clusters: int
silhouette_score: float
davies_bouldin_score: float
# Clusters
clusters: List[ClusterResult]
dangerous_clusters: List[ClusterResult] # For ToGMAL tools
# Paths
model_path: str
embeddings_path: str
class EnhancedClusteringTrainer:
"""
Clustering trainer using sentence transformers
Goal: Separate GOOD, LIMITATIONS, and HARMFUL clusters clearly
"""
def __init__(
self,
embedding_model: str = "all-MiniLM-L6-v2",
output_dir: Path = Path("./models/clustering")
):
self.embedding_model_name = embedding_model
self.output_dir = output_dir
self.output_dir.mkdir(parents=True, exist_ok=True)
if HAS_TRANSFORMERS:
logger.info(f"Loading sentence transformer: {embedding_model}")
self.embedder = SentenceTransformer(embedding_model)
else:
logger.error("sentence-transformers not available!")
self.embedder = None
self.dangerous_threshold = 0.7 # >70% limitations/harmful = dangerous
async def train_clustering(
self,
dataset_entries: List[Dict[str, Any]],
n_clusters: int = 3,
method: str = "kmeans"
) -> TrainingResult:
"""
Train clustering model
Args:
dataset_entries: List of {text, cluster_category, domain, source}
n_clusters: Number of clusters (3 = good, limitations, harmful)
method: "kmeans" or "dbscan"
Returns:
TrainingResult with clusters and metrics
"""
if not self.embedder:
raise RuntimeError("Sentence transformers not available")
logger.info(f"\n{'='*60}")
logger.info(f"Training {method.upper()} Clustering")
logger.info(f"{'='*60}")
# Extract texts and labels
texts = [entry['text'] for entry in dataset_entries]
true_categories = [entry['cluster_category'] for entry in dataset_entries]
domains = [entry['domain'] for entry in dataset_entries]
logger.info(f"Total samples: {len(texts)}")
logger.info(f"Categories: {Counter(true_categories)}")
logger.info(f"Domains: {Counter(domains)}")
# Generate embeddings
logger.info("\n[1/4] Generating embeddings with sentence transformers...")
embeddings = await self._generate_embeddings(texts)
# Standardize
logger.info("[2/4] Standardizing embeddings...")
scaler = StandardScaler()
embeddings_scaled = scaler.fit_transform(embeddings)
# Perform clustering
logger.info(f"[3/4] Clustering with {method}...")
if method == "kmeans":
model, labels = self._cluster_kmeans(embeddings_scaled, n_clusters)
else: # dbscan
model, labels = self._cluster_dbscan(embeddings_scaled)
n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
# Calculate metrics
logger.info("[4/4] Analyzing clusters...")
silhouette = silhouette_score(embeddings_scaled, labels) if len(set(labels)) > 1 else 0.0
davies_bouldin = davies_bouldin_score(embeddings_scaled, labels) if len(set(labels)) > 1 else 999.0
logger.info(f"\nMetrics:")
logger.info(f" Clusters: {n_clusters}")
logger.info(f" Silhouette Score: {silhouette:.4f}")
logger.info(f" Davies-Bouldin Score: {davies_bouldin:.4f}")
# Analyze clusters
clusters = self._analyze_clusters(
labels, texts, true_categories, domains, dataset_entries
)
# Identify dangerous clusters
dangerous_clusters = [c for c in clusters if c.is_dangerous]
logger.info(f"\nDangerous clusters: {len(dangerous_clusters)}/{n_clusters}")
# Save model
model_path = self.output_dir / f"{method}_model.pkl"
self._save_model(model, scaler, model_path, clusters)
# Save embeddings
embeddings_path = self.output_dir / "embeddings.npy"
np.save(embeddings_path, embeddings)
return TrainingResult(
timestamp=datetime.now().isoformat(),
model_type=method,
embedding_model=self.embedding_model_name,
n_clusters=n_clusters,
silhouette_score=silhouette,
davies_bouldin_score=davies_bouldin,
clusters=clusters,
dangerous_clusters=dangerous_clusters,
model_path=str(model_path),
embeddings_path=str(embeddings_path)
)
async def _generate_embeddings(self, texts: List[str]) -> np.ndarray:
"""Generate embeddings using sentence transformers"""
embeddings = self.embedder.encode(
texts,
batch_size=32,
show_progress_bar=True,
convert_to_numpy=True,
normalize_embeddings=True # Important for cosine similarity
)
logger.info(f"Generated embeddings: {embeddings.shape}")
return embeddings
def _cluster_kmeans(
self, embeddings: np.ndarray, n_clusters: int
) -> Tuple[KMeans, np.ndarray]:
"""Perform K-Means clustering"""
model = KMeans(
n_clusters=n_clusters,
random_state=42,
n_init=20, # More initializations for better results
max_iter=500
)
labels = model.fit_predict(embeddings)
return model, labels
def _cluster_dbscan(
self, embeddings: np.ndarray, eps: float = 0.5, min_samples: int = 10
) -> Tuple[DBSCAN, np.ndarray]:
"""Perform DBSCAN clustering"""
model = DBSCAN(
eps=eps,
min_samples=min_samples,
metric='cosine',
n_jobs=-1
)
labels = model.fit_predict(embeddings)
n_noise = np.sum(labels == -1)
logger.info(f" DBSCAN noise points: {n_noise}")
return model, labels
def _analyze_clusters(
self,
labels: np.ndarray,
texts: List[str],
true_categories: List[str],
domains: List[str],
entries: List[Dict[str, Any]]
) -> List[ClusterResult]:
"""Analyze cluster composition and identify patterns"""
clusters = []
for cluster_id in set(labels):
if cluster_id == -1: # Skip noise in DBSCAN
continue
# Get cluster members
mask = labels == cluster_id
cluster_texts = [t for t, m in zip(texts, mask) if m]
cluster_categories = [c for c, m in zip(true_categories, mask) if m]
cluster_domains = [d for d, m in zip(domains, mask) if m]
# Category distribution
category_counts = Counter(cluster_categories)
total = len(cluster_categories)
category_dist = {cat: count/total for cat, count in category_counts.items()}
# Domain distribution
domain_dist = dict(Counter(cluster_domains))
# Calculate purity (max category %)
purity = max(category_dist.values()) if category_dist else 0.0
# Is this dangerous? (>70% limitations or harmful)
limitations_harmful_pct = (
category_dist.get('limitations', 0.0) +
category_dist.get('harmful', 0.0)
)
is_dangerous = limitations_harmful_pct > self.dangerous_threshold
# Extract keywords
keywords = self._extract_keywords(cluster_texts)
# Generate pattern description
primary_category = max(category_dist, key=category_dist.get)
primary_domain = max(domain_dist, key=domain_dist.get)
pattern_desc = f"{primary_category.upper()} cluster: {primary_domain}"
if is_dangerous:
pattern_desc += f" (DANGEROUS: {limitations_harmful_pct:.1%} limitations/harmful)"
# Generate detection heuristic
heuristic = self._generate_heuristic(
primary_category, primary_domain, keywords
)
# Representative examples
examples = cluster_texts[:5]
cluster_result = ClusterResult(
cluster_id=int(cluster_id),
size=len(cluster_texts),
category_distribution=category_dist,
domain_distribution=domain_dist,
purity=float(purity),
is_dangerous=is_dangerous,
examples=examples,
pattern_description=pattern_desc,
detection_heuristic=heuristic,
keywords=keywords
)
clusters.append(cluster_result)
# Log cluster info
logger.info(f"\nCluster {cluster_id}:")
logger.info(f" Size: {len(cluster_texts)}")
logger.info(f" Purity: {purity:.1%}")
logger.info(f" Categories: {category_dist}")
logger.info(f" Dangerous: {is_dangerous}")
logger.info(f" Pattern: {pattern_desc}")
return clusters
def _extract_keywords(self, texts: List[str], top_n: int = 10) -> List[str]:
"""Extract common keywords from cluster texts"""
all_text = " ".join(texts).lower()
words = re.findall(r'\b[a-z]{4,}\b', all_text)
# Remove common words
stopwords = {'this', 'that', 'with', 'from', 'have', 'what', 'which', 'would', 'could', 'should'}
words = [w for w in words if w not in stopwords]
word_counts = Counter(words)
return [word for word, count in word_counts.most_common(top_n)]
def _generate_heuristic(
self, category: str, domain: str, keywords: List[str]
) -> str:
"""Generate detection heuristic for this cluster"""
if category == "limitations":
if domain == "mathematics":
return "keyword_match: [integral, proof, theorem, equation] OR complexity_score > 0.7"
elif domain == "medicine":
return f"keyword_match: {keywords[:5]} AND domain=medicine"
elif domain == "coding":
return "contains_code AND (has_vulnerability OR cyclomatic_complexity > 10)"
else:
return f"keyword_match: {keywords[:5]}"
elif category == "harmful":
return f"safety_filter_trigger OR keyword_match: {keywords[:5]}"
else: # good
return f"domain={domain} AND low_complexity"
def _save_model(
self, model: Any, scaler: StandardScaler, path: Path, clusters: List[ClusterResult]
):
"""Save model with metadata"""
model_data = {
'model': model,
'scaler': scaler,
'clusters': [asdict(c) for c in clusters],
'dangerous_clusters': [c.cluster_id for c in clusters if c.is_dangerous],
'timestamp': datetime.now().isoformat(),
'embedding_model': self.embedding_model_name
}
with open(path, 'wb') as f:
pickle.dump(model_data, f)
logger.info(f"\nā Model saved to {path}")
async def main():
"""Main training pipeline"""
# Load dataset
dataset_path = Path("./data/datasets/combined_dataset.json")
if not dataset_path.exists():
logger.error(f"Dataset not found: {dataset_path}")
logger.info("Run enhanced_dataset_fetcher.py first!")
return
logger.info(f"Loading dataset from {dataset_path}")
with open(dataset_path, 'r') as f:
data = json.load(f)
# Flatten all entries
all_entries = []
for category, entries in data['categories'].items():
all_entries.extend(entries)
logger.info(f"Loaded {len(all_entries)} entries")
logger.info(f"Summary: {data['summary']}")
# Train clustering
trainer = EnhancedClusteringTrainer()
# Try K-Means with 3 clusters (good, limitations, harmful)
result = await trainer.train_clustering(all_entries, n_clusters=3, method="kmeans")
# Save results
results_path = Path("./data/training_results.json")
with open(results_path, 'w') as f:
        # asdict() recursively converts nested ClusterResult dataclasses to dicts
        json.dump(asdict(result), f, indent=2)
logger.info(f"\nā Results saved to {results_path}")
# Export to ToGMAL ML tools cache
await export_to_ml_tools_cache(result)
async def export_to_ml_tools_cache(result: TrainingResult):
"""Export dangerous clusters to ToGMAL ML tools cache"""
patterns = []
for cluster in result.dangerous_clusters:
# Extract primary domain
primary_domain = max(cluster.domain_distribution, key=cluster.domain_distribution.get)
pattern = {
"id": f"cluster_{cluster.cluster_id}",
"domain": primary_domain,
"description": cluster.pattern_description,
"confidence": float(cluster.purity),
"heuristic": cluster.detection_heuristic,
"examples": cluster.examples[:3],
"keywords": cluster.keywords,
"metadata": {
"cluster_size": cluster.size,
"category_distribution": cluster.category_distribution,
"discovered_at": result.timestamp
}
}
patterns.append(pattern)
# Save to ML tools cache
ml_tools_cache = {
"updated_at": result.timestamp,
"patterns": patterns,
"metadata": {
"embedding_model": result.embedding_model,
"silhouette_score": result.silhouette_score,
"n_clusters": result.n_clusters,
"total_patterns": len(patterns)
}
}
cache_path = Path("./data/ml_discovered_tools.json")
cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_path, 'w') as f:
json.dump(ml_tools_cache, f, indent=2)
logger.info(f"\nā Exported {len(patterns)} patterns to {cache_path}")
logger.info("\nDangerous patterns discovered:")
for pattern in patterns:
logger.info(f" - {pattern['domain']}: {pattern['description']}")
if __name__ == "__main__":
asyncio.run(main())