topic_drift_detector.py•10.2 kB
#!/usr/bin/env python3
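"""
Lightweight topic drift detector.

Detects topic changes in a conversation by comparing each new query against a
sliding window of recent queries, using TF-IDF vectors and cosine similarity.
Drift is reported when the drift score (1 - average similarity to the window)
exceeds the drift threshold (default 0.5), or when the average similarity falls
below the minimum similarity. On drift, the caller is told to clear its
candidate set so the AI no longer cites unrelated memories.
A cheap lexical algorithm is used instead of LLM-based analysis.
"""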
import re
import math
from collections import defaultdict, deque
from typing import List, Dict, Set, Tuple, Optional
from datetime import datetime
class TopicDriftDetector:
    """Lightweight topic drift detector based on TF-IDF cosine similarity.

    Keeps a sliding window of recent tokenized queries and compares each new
    query against every query in the window. A low average similarity is
    reported as topic drift, which callers use to clear a candidate set.
    """

    # Stop words removed during preprocessing: common Chinese single
    # characters and common English function words.
    _STOP_WORDS = frozenset({
        '的', '了', '在', '是', '我', '你', '他', '她', '它', '们',
        'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at',
        'to', 'for', 'of', 'with', 'by', 'i', 'you', 'he', 'she', 'it',
    })
    # Anything that is not a word character, whitespace or CJK ideograph is
    # replaced by a space (i.e. treated as a separator).
    _SEPARATOR_RE = re.compile(r'[^\w\s\u4e00-\u9fff]')

    def __init__(self, window_size: int = 5, drift_threshold: float = 0.5,
                 min_similarity: float = 0.3):
        """Initialize the detector.

        Args:
            window_size: Sliding window size (number of most recent queries
                kept for comparison).
            drift_threshold: Drift is reported when
                ``1 - average_similarity`` exceeds this value.
            min_similarity: Drift is also reported when the average
                similarity falls below this value.
        """
        self.window_size = window_size
        self.drift_threshold = drift_threshold
        self.min_similarity = min_similarity
        # Sliding window holding the most recent tokenized queries.
        self.query_window = deque(maxlen=window_size)
        # Vocabulary and document frequencies.
        # NOTE(review): no method in this class populates these; they are
        # only cleared by reset(). Kept for backward compatibility.
        self.vocabulary: Set[str] = set()
        self.doc_freq: Dict[str, int] = defaultdict(int)
        # Running statistics.
        self.drift_count = 0
        self.total_queries = 0
        print("🔄 主题漂移检测器初始化完成")
        print(f" 窗口大小: {window_size}, 漂移阈值: {drift_threshold}, 最小相似度: {min_similarity}")

    @staticmethod
    def _is_cjk(char: str) -> bool:
        """Return True if the single character *char* is a CJK Unified Ideograph (U+4E00..U+9FFF)."""
        return '\u4e00' <= char <= '\u9fff'

    def _preprocess_text(self, text: str) -> List[str]:
        """Tokenize *text*: lowercase, split, drop stop words and noise tokens.

        CJK ideographs become one token per character. Runs of other
        alphanumeric characters become one token each. Single-character
        non-CJK tokens (e.g. "a", "x", "3") are dropped as noise. Single CJK
        characters are kept, because that is the only granularity at which
        Chinese is tokenized here.

        Args:
            text: Raw query text.

        Returns:
            List of tokens in their original order (duplicates kept).
        """
        text = self._SEPARATOR_RE.sub(' ', text.lower().strip())

        words = []
        current_word = ""
        for char in text:
            if self._is_cjk(char):
                # Each CJK character is its own token; flush any pending word.
                if current_word:
                    words.append(current_word)
                    current_word = ""
                words.append(char)
            elif char.isalnum():
                current_word += char
            else:
                # Whitespace, underscore, etc. act as separators.
                if current_word:
                    words.append(current_word)
                    current_word = ""
        if current_word:
            words.append(current_word)

        # The length filter applies only to non-CJK tokens. Applying it to
        # everything (as before) discarded every Chinese character, so
        # Chinese-only queries produced no tokens at all.
        return [
            word for word in words
            if word not in self._STOP_WORDS
            and (len(word) > 1 or self._is_cjk(word))
        ]

    def _calculate_tf_idf(self, words: List[str], all_docs: List[List[str]]) -> Dict[str, float]:
        """Compute the TF-IDF vector of *words* relative to *all_docs*.

        TF is the term count divided by ``len(words)``. IDF is smoothed as
        ``log((1 + N) / (1 + df)) + 1``, which is always positive. With the
        plain ``log(N / df)``, a term shared by every document got weight 0,
        so queries sharing only common terms produced all-zero vectors.
        Those vectors gave similarity 0 and were reported as drift even
        when the queries were identical.

        Args:
            words: Tokens of the document to vectorize.
            all_docs: Token lists of all documents in the current corpus.

        Returns:
            Mapping term -> TF-IDF weight; empty dict when *words* is empty.
        """
        if not words:
            return {}

        total_words = len(words)
        tf: Dict[str, float] = defaultdict(float)
        for word in words:
            tf[word] += 1.0 / total_words

        total_docs = len(all_docs)
        # Convert each document once so membership tests are O(1).
        doc_sets = [set(doc) for doc in all_docs]

        tf_idf = {}
        for word, tf_score in tf.items():
            df = sum(1 for doc in doc_sets if word in doc)
            idf = math.log((1 + total_docs) / (1 + df)) + 1.0
            tf_idf[word] = tf_score * idf
        return tf_idf

    def _cosine_similarity(self, vec1: Dict[str, float], vec2: Dict[str, float]) -> float:
        """Return the cosine similarity of two sparse vectors (0.0 if either is empty or zero)."""
        if not vec1 or not vec2:
            return 0.0
        dot_product = sum(score * vec2[word] for word, score in vec1.items() if word in vec2)
        norm1 = math.sqrt(sum(score ** 2 for score in vec1.values()))
        norm2 = math.sqrt(sum(score ** 2 for score in vec2.values()))
        if norm1 == 0.0 or norm2 == 0.0:
            return 0.0
        return dot_product / (norm1 * norm2)

    def detect_drift(self, query: str) -> Tuple[bool, float, Dict]:
        """Detect whether *query* drifts away from the recent topic.

        The query is compared with every query in the window. Drift is
        reported when the average cosine similarity is below
        ``min_similarity`` or ``1 - average`` exceeds ``drift_threshold``.
        On drift, the window is reset to contain only the current query. This
        way, follow-up queries on the new topic are not compared against
        stale old-topic history, which would re-trigger drift repeatedly.

        Args:
            query: Current query text.

        Returns:
            Tuple of (drift detected, average similarity, detail dict). For
            empty queries and during warm-up, the detail dict uses the
            ``"reason"`` key instead of the full set of keys.
        """
        self.total_queries += 1

        current_words = self._preprocess_text(query)
        if not current_words:
            return False, 1.0, {"reason": "空查询", "similarities": []}

        # Prior queries required before comparing: normally 2, but a window
        # that can hold only one query must still be able to detect drift.
        min_history = max(1, min(2, self.window_size))
        if len(self.query_window) < min_history:
            self.query_window.append(current_words)
            return False, 1.0, {"reason": "窗口未满", "window_size": len(self.query_window)}

        # The corpus for IDF is the window plus the current query.
        all_docs = list(self.query_window) + [current_words]
        current_tfidf = self._calculate_tf_idf(current_words, all_docs)
        similarities = [
            self._cosine_similarity(current_tfidf, self._calculate_tf_idf(window_words, all_docs))
            for window_words in self.query_window
        ]
        avg_similarity = sum(similarities) / len(similarities) if similarities else 0.0

        is_drift = False
        drift_reason = ""
        if avg_similarity < self.min_similarity:
            is_drift = True
            drift_reason = f"平均相似度过低: {avg_similarity:.3f} < {self.min_similarity}"
        elif (1.0 - avg_similarity) > self.drift_threshold:
            is_drift = True
            drift_reason = f"主题漂移超过阈值: {1.0 - avg_similarity:.3f} > {self.drift_threshold}"

        if is_drift:
            self.drift_count += 1
            print(f"🔄 检测到主题漂移: {drift_reason}")
            print(f" 当前查询: {query[:50]}...")
            print(f" 相似度分布: {[f'{s:.3f}' for s in similarities]}")
            # A new topic starts here: drop the old-topic history.
            self.query_window.clear()

        self.query_window.append(current_words)

        detail_info = {
            "similarities": similarities,
            "avg_similarity": avg_similarity,
            "drift_reason": drift_reason if is_drift else "无漂移",
            "window_queries": len(self.query_window),
            "current_words": len(current_words)
        }
        return is_drift, avg_similarity, detail_info

    def should_clear_candidates(self, query: str) -> Tuple[bool, str]:
        """Decide whether the candidate set should be cleared for *query*.

        Args:
            query: Current query text.

        Returns:
            Tuple of (clear or not, human-readable reason).
        """
        is_drift, _avg_similarity, detail_info = self.detect_drift(query)
        if is_drift:
            return True, f"主题漂移检测: {detail_info['drift_reason']}"
        return False, "主题连续,保持候选集"

    def get_statistics(self) -> Dict:
        """Return detection counters and the current configuration."""
        drift_rate = (self.drift_count / self.total_queries * 100) if self.total_queries > 0 else 0.0
        return {
            "total_queries": self.total_queries,
            "drift_count": self.drift_count,
            "drift_rate": f"{drift_rate:.1f}%",
            "window_size": self.window_size,
            "current_window_length": len(self.query_window),
            "drift_threshold": self.drift_threshold,
            "min_similarity": self.min_similarity
        }

    def reset(self):
        """Clear the window, vocabulary and counters."""
        self.query_window.clear()
        self.vocabulary.clear()
        self.doc_freq.clear()
        self.drift_count = 0
        self.total_queries = 0
        print("🔄 主题漂移检测器已重置")
def test_topic_drift_detector():
    """Run a scripted demo of TopicDriftDetector and print every decision.

    A fixed sequence of queries (programming -> machine learning -> small
    talk -> back to programming) is fed through ``should_clear_candidates``.
    The per-query decision is printed, followed by the detector statistics.

    Returns:
        bool: Always True.
    """
    print("🧪 主题漂移检测器测试")
    print("=" * 40)

    detector = TopicDriftDetector(window_size=3, drift_threshold=0.5,
                                  min_similarity=0.3)

    queries = [
        "Python编程基础知识",
        "Python函数和类的使用",
        "Python数据结构和算法",
        "机器学习算法原理",  # topic starts to drift
        "深度学习神经网络",
        "自然语言处理技术",
        "今天天气怎么样",  # completely unrelated topic
        "明天的计划安排",
        "Python web开发",  # back to programming
        "Django框架使用",
    ]

    print("\n📝 测试查询序列:")
    for position, text in zip(range(1, len(queries) + 1), queries):
        print(f" {position}. {text}")

    print("\n🔍 主题漂移检测结果:")
    print("-" * 50)

    position = 0
    for text in queries:
        position += 1
        clear, why = detector.should_clear_candidates(text)
        if clear:
            label = "🚨 清空候选集"
        else:
            label = "✅ 保持候选集"
        print(f"{position:2d}. {text}")
        print(f" {label}: {why}")
        print()

    print("📊 检测统计:")
    for name, val in detector.get_statistics().items():
        print(f" {name}: {val}")

    return True
if __name__ == '__main__':
    # Script entry point: run the demo when this file is executed directly.
    test_topic_drift_detector()