# FictionAgent/core/text_processor.py
import re
from typing import List, Dict  # Dict is used in the return-type annotations below
from tqdm import tqdm
from utils.text_utils import TextUtils
from config import Config
class TextProcessor:
"""大规模文本处理器"""
def __init__(self):
self.text_utils = TextUtils()
def chunk_text(self, text: str, chunk_size: int = None,
overlap: int = None) -> List[Dict]:
"""将长文本分块,保持语义完整性
Args:
text: 输入文本
chunk_size: 每块的最大字符数
overlap: 块之间的重叠字符数
Returns:
分块结果列表,每个元素包含 text, start, end, chunk_id
"""
chunk_size = chunk_size or Config.MAX_CHUNK_SIZE
overlap = overlap or Config.CHUNK_OVERLAP
        # Split into paragraphs first
paragraphs = text.split('\n\n')
chunks = []
current_chunk = ""
current_start = 0
total_processed = 0
print(f"开始分块处理 (块大小: {chunk_size}, 重叠: {overlap})...")
for para in tqdm(paragraphs, desc="分块进度"):
para = para.strip()
if not para:
continue
            # If adding this paragraph would push the current chunk over the limit
if len(current_chunk) + len(para) + 2 > chunk_size: # +2 for \n\n
if current_chunk:
                    # Save the current chunk
chunks.append({
'text': current_chunk.strip(),
'start': current_start,
'end': current_start + len(current_chunk),
'chunk_id': len(chunks)
})
                    # Work out the overlap to carry into the next chunk
if len(current_chunk) > overlap:
                        # Take the overlap from the end of the current chunk
overlap_text = current_chunk[-overlap:]
                        # Prefer to cut at a sentence boundary
sentences = self.text_utils.split_into_sentences(overlap_text)
if sentences:
overlap_text = sentences[-1] if len(sentences) == 1 else ' '.join(sentences[-2:])
else:
overlap_text = current_chunk
                    # Advance the running character offset
total_processed += len(current_chunk) - len(overlap_text)
current_start = total_processed
                    # Start a new chunk seeded with the overlap text
current_chunk = overlap_text + "\n\n" + para
else:
                    # Current chunk is empty; start it with this paragraph
current_chunk = para
current_start = total_processed
else:
                # Append the paragraph to the current chunk
if current_chunk:
current_chunk += "\n\n" + para
else:
current_chunk = para
        # Append the final chunk
if current_chunk:
chunks.append({
'text': current_chunk.strip(),
'start': current_start,
'end': current_start + len(current_chunk),
'chunk_id': len(chunks)
})
print(f"✓ 文本分块完成: 总共 {len(chunks)} 块")
return chunks
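    # Hypothetical usage sketch for chunk_text (assumes Config.MAX_CHUNK_SIZE and
    # Config.CHUNK_OVERLAP supply the defaults; `novel_text` is any long string):
    #   processor = TextProcessor()
    #   chunks = processor.chunk_text(novel_text, chunk_size=2000, overlap=200)
    #   for c in chunks[:3]:
    #       print(c['chunk_id'], c['start'], c['end'], len(c['text']))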
def chunk_text_by_tokens(self, text: str, max_tokens: int = 1500,
overlap_tokens: int = 150) -> List[Dict]:
"""按 token 数量分块(更精确但较慢)
Args:
text: 输入文本
max_tokens: 每块的最大 token 数
overlap_tokens: 重叠的 token 数
Returns:
分块结果列表
"""
sentences = self.text_utils.split_into_sentences(text)
chunks = []
current_chunk = []
current_tokens = 0
current_start = 0
print(f"按 token 分块处理 (最大: {max_tokens} tokens)...")
for sentence in tqdm(sentences, desc="处理句子"):
sentence_tokens = self.text_utils.count_tokens(sentence)
if current_tokens + sentence_tokens > max_tokens and current_chunk:
                # Save the current chunk
chunk_text = ' '.join(current_chunk)
chunks.append({
'text': chunk_text,
'start': current_start,
'end': current_start + len(chunk_text),
'chunk_id': len(chunks),
'token_count': current_tokens
})
                # Build the overlap from the tail sentences of the chunk just saved
overlap_chunk = []
overlap_tokens_count = 0
for s in reversed(current_chunk):
s_tokens = self.text_utils.count_tokens(s)
if overlap_tokens_count + s_tokens <= overlap_tokens:
overlap_chunk.insert(0, s)
overlap_tokens_count += s_tokens
else:
break
current_chunk = overlap_chunk + [sentence]
current_tokens = overlap_tokens_count + sentence_tokens
current_start += len(chunk_text) - len(' '.join(overlap_chunk))
else:
current_chunk.append(sentence)
current_tokens += sentence_tokens
        # Append the final chunk
if current_chunk:
chunk_text = ' '.join(current_chunk)
chunks.append({
'text': chunk_text,
'start': current_start,
'end': current_start + len(chunk_text),
'chunk_id': len(chunks),
'token_count': current_tokens
})
print(f"✓ Token 分块完成: 总共 {len(chunks)} 块")
return chunks
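    # Hypothetical usage sketch for chunk_text_by_tokens (token counts come from
    # TextUtils.count_tokens, so exact chunk sizes depend on its tokenizer):
    #   token_chunks = processor.chunk_text_by_tokens(novel_text, max_tokens=1500,
    #                                                 overlap_tokens=150)
    #   print(token_chunks[0]['token_count'], token_chunks[0]['chunk_id'])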
def extract_dialogues(self, text: str) -> List[Dict]:
"""提取对话片段
Args:
text: 输入文本
Returns:
对话列表,每个元素包含 content, attribution, position
"""
        # Detect the language of the text
language = self.text_utils.detect_language(text)
dialogues = []
if language == "zh":
            # Chinese dialogue patterns
patterns = [
(r'"([^"]+)"[,,]?\s*([^说道讲告诉问答叫喊]*(?:说|道|讲|告诉|问|答|叫|喊))', 'chinese_quote'),
(r'「([^」]+)」[,,]?\s*([^说道讲]*(?:说|道|讲))', 'chinese_bracket'),
(r'"([^"]+)"', 'simple_quote'),
]
else:
            # English dialogue patterns
patterns = [
(r'"([^"]+)",?\s+([A-Z][a-z]+\s+(?:said|asked|replied|shouted|whispered|muttered|exclaimed))', 'english_quote_said'),
(r'"([^"]+)"', 'simple_quote'),
(r"'([^']+)',?\s+([A-Z][a-z]+\s+said)", 'english_single_quote'),
]
for pattern, pattern_type in patterns:
matches = re.finditer(pattern, text, re.IGNORECASE)
for match in matches:
dialogue = {
'content': match.group(1).strip(),
'attribution': match.group(2).strip() if len(match.groups()) > 1 else '',
'position': match.start(),
'type': pattern_type
}
                # Filter out dialogues that are too short
if len(dialogue['content']) > 5:
dialogues.append(dialogue)
        # Sort by position in the text
dialogues.sort(key=lambda x: x['position'])
return dialogues
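    # Hypothetical usage sketch for extract_dialogues; quotes of 5 characters or
    # fewer are dropped, and a quote may appear twice if it matches both an
    # attributed pattern and the simple-quote fallback:
    #   for d in processor.extract_dialogues(chapter_text)[:5]:
    #       print(d['position'], d['type'], repr(d['content']), '--', d['attribution'])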
def split_by_chapters(self, text: str) -> List[Dict]:
"""按章节分割文本
Args:
text: 输入文本
Returns:
章节列表,每个元素包含 title, content, chapter_num
"""
        # Patterns that mark chapter headings
chapter_patterns = [
r'Chapter\s+(\d+)[:\s]*([^\n]*)', # English: Chapter 1: Title
r'第([一二三四五六七八九十百千零\d]+)章[:\s]*([^\n]*)', # Chinese: 第一章:标题
r'CHAPTER\s+([IVXLCDM]+)[:\s]*([^\n]*)', # Roman numerals
]
chapters = []
for pattern in chapter_patterns:
matches = list(re.finditer(pattern, text, re.IGNORECASE | re.MULTILINE))
if matches:
for i, match in enumerate(matches):
start = match.start()
end = matches[i + 1].start() if i + 1 < len(matches) else len(text)
chapters.append({
'chapter_num': match.group(1),
'title': match.group(2).strip() if len(match.groups()) > 1 else '',
'content': text[start:end].strip(),
'start': start,
'end': end
})
                break  # Stop after the first pattern that produces matches
        # If no chapter markers were found, return the whole text as a single chapter
if not chapters:
chapters.append({
'chapter_num': '1',
'title': 'Full Text',
'content': text,
'start': 0,
'end': len(text)
})
return chapters
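    # Hypothetical usage sketch for split_by_chapters; only the first heading style
    # that matches ("Chapter 1", "第一章", or Roman-numeral "CHAPTER I") is used:
    #   for ch in processor.split_by_chapters(novel_text):
    #       print(ch['chapter_num'], ch['title'], len(ch['content']))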
def get_statistics(self, text: str) -> Dict:
"""获取文本统计信息
Args:
text: 输入文本
Returns:
统计信息字典
"""
        # Basic counts
total_length = len(text)
total_tokens = self.text_utils.count_tokens(text)
        # Paragraph statistics
paragraphs = [p for p in text.split('\n\n') if p.strip()]
paragraph_count = len(paragraphs)
        # Sentence statistics
sentences = self.text_utils.split_into_sentences(text)
sentence_count = len(sentences)
        # Word/character statistics
words = re.findall(r'\b\w+\b', text)
word_count = len(words)
        # Language detection
language = self.text_utils.detect_language(text)
        # Dialogue statistics
        dialogues = self.extract_dialogues(text[:10000])  # Only scan the first 10,000 characters
dialogue_count = len(dialogues)
        # Chapter detection
chapters = self.split_by_chapters(text)
chapter_count = len(chapters)
return {
'total_length': total_length,
'total_tokens': total_tokens,
'paragraphs': paragraph_count,
'sentences': sentence_count,
'words': word_count,
'language': language,
'dialogues': dialogue_count,
'chapters': chapter_count,
'avg_paragraph_length': total_length // paragraph_count if paragraph_count > 0 else 0,
'avg_sentence_length': total_length // sentence_count if sentence_count > 0 else 0,
}
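    # Hypothetical usage sketch for get_statistics; dialogue counting only scans
    # the first 10,000 characters, so it underestimates on long texts:
    #   stats = processor.get_statistics(novel_text)
    #   print(stats['language'], stats['chapters'], stats['avg_sentence_length'])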
def clean_text(self, text: str,
remove_extra_whitespace: bool = True,
normalize_quotes: bool = True) -> str:
"""清理文本
Args:
text: 输入文本
remove_extra_whitespace: 是否移除多余空白
normalize_quotes: 是否标准化引号
Returns:
清理后的文本
"""
cleaned = text
        # Remove redundant whitespace
if remove_extra_whitespace:
            # Strip leading/trailing whitespace on each line
cleaned = '\n'.join(line.strip() for line in cleaned.split('\n'))
            # Collapse runs of blank lines into a single blank line
cleaned = re.sub(r'\n{3,}', '\n\n', cleaned)
            # Replace tabs with spaces
cleaned = cleaned.replace('\t', ' ')
            # Collapse runs of spaces
cleaned = re.sub(r' {2,}', ' ', cleaned)
        # Normalize quotation marks
        if normalize_quotes:
            # Map Chinese bracket quotes to double quotes
            cleaned = cleaned.replace('『', '"').replace('』', '"')
            cleaned = cleaned.replace('「', '"').replace('」', '"')
            # Map curly double/single quotes to straight ASCII equivalents
            cleaned = cleaned.replace('\u201c', '"').replace('\u201d', '"')
            cleaned = cleaned.replace('\u2018', "'").replace('\u2019', "'")
return cleaned
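    # Hypothetical usage sketch for clean_text; with both flags enabled it strips
    # per-line whitespace, collapses blank lines and spaces, and maps bracket and
    # curly quotes to plain ASCII quotes:
    #   cleaned = processor.clean_text(raw_text, remove_extra_whitespace=True,
    #                                  normalize_quotes=True)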
def extract_metadata(self, text: str) -> Dict:
"""提取文本元数据(标题、作者等)
Args:
text: 输入文本
Returns:
元数据字典
"""
metadata = {
'title': None,
'author': None,
'year': None,
}
        # Try to extract the title and author from the beginning of the text
        lines = text.split('\n')[:20]  # Only inspect the first 20 lines
for line in lines:
line = line.strip()
            # Try to match a title
if not metadata['title'] and len(line) > 5 and len(line) < 100:
                # Treat all-caps or title-case lines as candidate titles
if line.isupper() or line.istitle():
metadata['title'] = line
            # Try to match an author
author_patterns = [
r'by\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
r'作者[::]\s*(.+)',
r'Author[:\s]+(.+)',
]
for pattern in author_patterns:
match = re.search(pattern, line, re.IGNORECASE)
if match:
metadata['author'] = match.group(1).strip()
break
            # Try to match a year
year_match = re.search(r'\b(19|20)\d{2}\b', line)
if year_match:
metadata['year'] = year_match.group(0)
return metadata
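    # Hypothetical usage sketch for extract_metadata; only the first 20 lines are
    # scanned, and any field that cannot be found stays None:
    #   meta = processor.extract_metadata(novel_text)
    #   print(meta['title'], meta['author'], meta['year'])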
def sample_text(self, text: str, sample_size: int = 1000,
strategy: str = 'random') -> str:
"""从文本中采样
Args:
text: 输入文本
sample_size: 采样大小(字符数)
strategy: 采样策略 ('start', 'random', 'distributed')
Returns:
采样的文本
"""
if len(text) <= sample_size:
return text
if strategy == 'start':
            # Sample from the beginning
return text[:sample_size]
elif strategy == 'random':
            # Sample from a random position
import random
start = random.randint(0, len(text) - sample_size)
return text[start:start + sample_size]
elif strategy == 'distributed':
            # Distributed sampling: take pieces from different parts of the text
num_samples = 3
sample_per_part = sample_size // num_samples
samples = []
for i in range(num_samples):
start = (len(text) // num_samples) * i
end = min(start + sample_per_part, len(text))
samples.append(text[start:end])
return '\n...\n'.join(samples)
else:
return text[:sample_size]
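
# A minimal, hypothetical smoke test. It assumes this module's own dependencies
# (utils.text_utils.TextUtils and config.Config) are importable; the demo text
# and sizes are placeholders, not project fixtures.
if __name__ == "__main__":
    demo_text = (
        "Chapter 1: The Beginning\n\n" + "A quiet village sat by the river. " * 100
        + "\n\nChapter 2: The Journey\n\n" + "They set out at dawn. " * 100
    )
    processor = TextProcessor()
    stats = processor.get_statistics(demo_text)
    print("Language:", stats['language'], "| Chapters:", stats['chapters'])
    for strategy in ("start", "random", "distributed"):
        sample = processor.sample_text(demo_text, sample_size=300, strategy=strategy)
        print(f"[{strategy}] {len(sample)} chars -> {sample[:60]!r}...")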