"""工具函数模块 - 提供文本处理、token 计算等基础功能"""
from config import TOKEN_ENCODING
import re
import sys
import subprocess
import tiktoken
from pathlib import Path
from typing import List, Dict
# Prepend the parent of this file's directory to sys.path so modules located
# there can be imported.
# NOTE(review): `from config import TOKEN_ENCODING` is executed above this
# line, so this path tweak cannot help that import - confirm the intended order.
sys.path.insert(0, str(Path(__file__).parent.parent))
# Module-level tiktoken encoder built from the configured encoding name.
encoder = tiktoken.get_encoding(TOKEN_ENCODING)
def _listening_pids_from_netstat(netstat_output: str, port: int) -> List[str]:
    """Return the unique PIDs listening on exactly *port* in ``netstat -ano`` output."""
    wanted_port = str(port)
    pids: List[str] = []
    for line in netstat_output.splitlines():
        if 'LISTENING' not in line:
            continue
        fields = line.split()
        if len(fields) < 5:
            continue
        # fields[1] is the local address, e.g. "0.0.0.0:8080" or "[::]:8080".
        # Compare the port component exactly so that port 80 does not match
        # ":8080", and so the foreign-address column is never considered.
        local_port = fields[1].rpartition(':')[2]
        pid = fields[-1]
        # The same PID usually shows up on both the IPv4 and IPv6 lines.
        if local_port == wanted_port and pid not in pids:
            pids.append(pid)
    return pids


def kill_port_process(port: int) -> bool:
    """
    Terminate the process(es) occupying the given port.

    On Windows the listeners are discovered with ``netstat -ano`` and killed
    with ``taskkill /F``; on other platforms ``lsof -ti`` and ``kill -9`` are
    used. This is a best-effort helper: any failure is swallowed so the caller
    can carry on and let the operating system report a port conflict.

    Args:
        port: Port number to free.

    Returns:
        True if at least one process was successfully killed, False otherwise
        (nothing found, every kill failed, or the lookup itself failed).
    """
    try:
        if sys.platform == 'win32':
            lookup = subprocess.run(
                ['netstat', '-ano'],
                capture_output=True,
                text=True,
                encoding='gbk',
                # Do not abort the whole cleanup on consoles that are not GBK.
                errors='replace',
            )
            pids = _listening_pids_from_netstat(lookup.stdout, port)
            kill_command = ['taskkill', '/F', '/PID']
        else:
            # NOTE(review): `lsof -ti :PORT` may also report processes that only
            # hold client connections on this port - confirm whether restricting
            # to listeners (e.g. `-sTCP:LISTEN`) is wanted.
            lookup = subprocess.run(
                ['lsof', '-ti', f':{port}'],
                capture_output=True,
                text=True,
            )
            pids = lookup.stdout.split()
            kill_command = ['kill', '-9']

        killed = False
        for pid in pids:
            try:
                subprocess.run(
                    kill_command + [pid],
                    capture_output=True,
                    check=True,
                )
                killed = True
            except subprocess.CalledProcessError:
                # The process may already be gone or may not be ours to kill;
                # keep trying the remaining PIDs.
                pass
        return killed
    except Exception:
        # Cleanup failed (tool missing, decode problem, ...): keep starting up
        # and let the operating system report the conflict.
        return False
def calculate_tokens(text: str) -> int:
    """
    Count the tokens in a piece of text using the module-level tiktoken encoder.

    Args:
        text: Input text.

    Returns:
        Number of tokens; 0 for empty (falsy) input, in which case the encoder
        is not called at all.
    """
    return len(encoder.encode(text)) if text else 0
def normalize_chinese_punctuation(text: str) -> str:
    """
    Convert Chinese (full-width / CJK) punctuation to ASCII punctuation.

    Args:
        text: Input text.

    Returns:
        The text with every mapped punctuation character replaced; all other
        characters are left untouched.
    """
    # Single-character keys mapped to their ASCII replacement; none of the
    # replacements is itself a key, so one translate pass is sufficient.
    translation_table = str.maketrans({
        '\u3002': '.',    # ideographic full stop
        '\uFF01': '!',    # fullwidth exclamation mark
        '\uFF1F': '?',    # fullwidth question mark
        '\uFF0C': ',',    # fullwidth comma
        '\u3001': ',',    # ideographic comma
        '\uFF1B': ';',    # fullwidth semicolon
        '\uFF1A': ':',    # fullwidth colon
        '\u201C': '"',    # left double quotation mark
        '\u201D': '"',    # right double quotation mark
        '\u2018': "'",    # left single quotation mark
        '\u2019': "'",    # right single quotation mark
        '\uFF08': '(',    # fullwidth left parenthesis
        '\uFF09': ')',    # fullwidth right parenthesis
        '\u300A': '<',    # left double angle bracket
        '\u300B': '>',    # right double angle bracket
        '\u3010': '[',    # left black lenticular bracket
        '\u3011': ']',    # right black lenticular bracket
        '\u2026': '...',  # horizontal ellipsis
        '\u2014': '-',    # em dash
        '\uFF5E': '~',    # fullwidth tilde
    })
    return text.translate(translation_table)
def split_by_sentence_endings(text: str) -> List[str]:
    """
    Split text into sentences at runs of sentence-ending punctuation.

    A run of ``.``, ``!``, ``?`` or ``;`` ends a sentence and stays attached
    to the text that precedes it.

    Args:
        text: Input text.

    Returns:
        List of stripped sentences; pieces whose text part is blank are dropped
        (together with their punctuation).
    """
    # Because the pattern has a capture group, the split result alternates
    # text, punctuation, text, ..., and always ends with a text piece.
    pieces = re.split(r'([.!?;]+)', text)
    bodies = pieces[0::2]
    # The final text piece has no trailing punctuation; pad so the pairs align.
    endings = pieces[1::2] + ['']
    return [
        (body + ending).strip()
        for body, ending in zip(bodies, endings)
        if body.strip()
    ]
def filter_meaningless(sentences: List[str]) -> List[str]:
    """
    Drop sentences that carry no meaningful content.

    A sentence is discarded when it is blank, consists only of punctuation /
    whitespace, or is shorter than 3 characters once stripped.

    Args:
        sentences: List of sentences.

    Returns:
        The surviving sentences, unchanged and in their original order.
    """
    punctuation_only = re.compile(r'^[.,!?;:\s\-]+$')

    def is_meaningful(candidate: str) -> bool:
        # Blank strings have stripped length 0, so this one length check also
        # covers the "empty sentence" case.
        if len(candidate.strip()) < 3:
            return False
        return punctuation_only.match(candidate) is None

    return [candidate for candidate in sentences if is_meaningful(candidate)]
def group_by_max_chars(sentences: List[str], max_chars: int) -> List[str]:
    """
    Merge consecutive sentences into paragraphs of at most ``max_chars`` characters.

    Sentences are concatenated without a separator. A sentence that is longer
    than ``max_chars`` on its own becomes a paragraph by itself.

    Args:
        sentences: List of sentences.
        max_chars: Maximum number of characters per paragraph.

    Returns:
        List of merged paragraphs, each stripped of surrounding whitespace.
    """
    merged: List[str] = []
    buffer = ''

    for piece in sentences:
        if len(buffer) + len(piece) <= max_chars:
            # Still room in the paragraph being built.
            buffer += piece
            continue

        # The piece does not fit: close the paragraph in progress, if any.
        if buffer:
            merged.append(buffer.strip())

        if len(piece) > max_chars:
            # Oversized sentence: emit it alone and start from scratch.
            merged.append(piece.strip())
            buffer = ''
        else:
            buffer = piece

    # Flush whatever is left after the last sentence.
    if buffer:
        merged.append(buffer.strip())

    return merged
def count_words(text: str) -> int:
    """
    Count the characters in a text, ignoring all whitespace.

    Every character for which ``str.isspace()`` is true is skipped. That
    covers spaces, tabs and newlines as well as carriage returns (Windows line
    endings) and the ideographic space ``\\u3000`` commonly used to indent
    Chinese paragraphs, so none of them inflate the count.

    Args:
        text: Input text.

    Returns:
        Number of non-whitespace characters.
    """
    return sum(1 for char in text if not char.isspace())
def extract_chapter_number(filename: str) -> int:
    """
    Extract the chapter number from a file name.

    Expected file name format: ``<chapter number>-<word count>改写.txt``.
    The text before the first ``-`` is tried as an integer first; if that is
    not possible (or there is no ``-``), the leading run of digits is used.

    Args:
        filename: File name.

    Returns:
        The chapter number, or 999999 when none can be extracted so that such
        files sort last.
    """
    try:
        if '-' in filename:
            try:
                return int(filename.split('-')[0])
            except ValueError:
                # The prefix is not a pure integer (e.g. "12章"); fall through
                # to the leading-digits match instead of giving up.
                pass
        leading_digits = re.match(r'^(\d+)', filename)
        if leading_digits:
            return int(leading_digits.group(1))
    except (ValueError, AttributeError):
        pass
    # Extraction failed: return a very large number so the file sorts last.
    return 999999