"""
训练RobotTHP模型(带语义特征)
展示如何在EasyTPP框架中使用RobotTHP模型,并加载语义特征、偏差特征等
"""
import torch
from torch.utils.data import DataLoader
from easy_tpp.config_factory import DataSpecConfig
from easy_tpp.model import TorchRobotTHP
from easy_tpp.preprocess.robert_dataset import RobertTPPDataset
from easy_tpp.preprocess.robert_tokenizer import RobertEventTokenizer
from easy_tpp.preprocess.data_collator import TPPDataCollator


def prepare_robert_data():
    """
    Prepare sample Robert comment-thread data.

    In real use, load and process the data from a JSON file instead
    (see load_robert_data_from_json below for a sketch).
    """
    # Sample data
    time_seqs = [
        [0.0, 10.5, 25.3, 45.2],
        [0.0, 5.2, 12.8]
    ]
    type_seqs = [
        [0, 1, 2, 1],  # post, bot_comment, user_comment, user_comment
        [0, 1, 2]
    ]
    time_delta_seqs = [
        [0.0, 10.5, 14.8, 19.9],
        [0.0, 5.2, 7.6]
    ]
    # Semantic vectors (example: 768-dim BERT embeddings)
    semantic_vectors = [
        [[0.1] * 768, [0.2] * 768, [0.3] * 768, [0.4] * 768],
        [[0.1] * 768, [0.2] * 768, [0.3] * 768]
    ]
    # Deviation features (example: 3 dims [context deviation, sentiment deviation, perplexity])
    deviation_features = [
        [[0.0, 0.0, 0.0], [0.7, 0.5, 0.3], [0.2, 0.1, 0.1], [0.3, 0.2, 0.1]],
        [[0.0, 0.0, 0.0], [0.6, 0.4, 0.2], [0.1, 0.1, 0.1]]
    ]
    # Spontaneous / @-mentioned flag (-1 = not applicable, 0 = @-mentioned, 1 = spontaneous)
    is_spontaneous = [
        [-1.0, 1.0, -1.0, -1.0],  # original post n/a, Robert spontaneous, user comments n/a
        [-1.0, 0.0, -1.0]         # original post n/a, Robert @-mentioned, user comment n/a
    ]
    return {
        'time_seqs': time_seqs,
        'type_seqs': type_seqs,
        'time_delta_seqs': time_delta_seqs,
        'semantic_vectors': semantic_vectors,
        'deviation_features': deviation_features,
        'is_spontaneous': is_spontaneous
    }
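

# A minimal sketch of loading the same fields from a JSON file, as the
# docstring of prepare_robert_data suggests for real use. The file layout
# assumed here (a JSON list of per-sequence objects whose keys mirror
# prepare_robert_data's output) is hypothetical; adapt it to your actual
# export format.
def load_robert_data_from_json(path):
    """Load sequences from a JSON file and re-pack them into parallel lists."""
    import json
    with open(path, 'r', encoding='utf-8') as f:
        sequences = json.load(f)
    keys = ('time_seqs', 'type_seqs', 'time_delta_seqs',
            'semantic_vectors', 'deviation_features', 'is_spontaneous')
    # RobertTPPDataset takes a dict of parallel lists, one entry per sequence.
    return {key: [seq[key] for seq in sequences] for key in keys}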


def create_data_loader(data_dict, config, use_semantic=True, use_deviation=True):
    """
    Create a data loader.

    Args:
        data_dict: data dictionary
        config: data config
        use_semantic: whether to use semantic features
        use_deviation: whether to use deviation features

    Returns:
        DataLoader: the data loader
    """
    # Build the dataset
    dataset = RobertTPPDataset(data_dict)
    # Build the tokenizer
    tokenizer = RobertEventTokenizer(
        config,
        use_semantic=use_semantic,
        use_deviation=use_deviation,
        semantic_dim=768
    )
    # Build the data collator; fall back to sensible defaults when the
    # tokenizer does not specify padding/truncation strategies
    padding = True if tokenizer.padding_strategy is None else tokenizer.padding_strategy
    truncation = False if tokenizer.truncation_strategy is None else tokenizer.truncation_strategy
    data_collator = TPPDataCollator(
        tokenizer=tokenizer,
        return_tensors='pt',
        max_length=tokenizer.model_max_length,
        padding=padding,
        truncation=truncation
    )
    # Build the data loader
    data_loader = DataLoader(
        dataset,
        collate_fn=data_collator,
        batch_size=config.batch_size,
        shuffle=True
    )
    return data_loader


def main():
    """Main entry point."""
    print("=" * 60)
    print("Training the RobotTHP model (with semantic features)")
    print("=" * 60)

    # 1. Prepare data
    print("\n1. Preparing data...")
    data_dict = prepare_robert_data()
    print(f"   Number of sequences: {len(data_dict['time_seqs'])}")

    # 2. Create config
    print("\n2. Creating config...")
    config = DataSpecConfig.parse_from_yaml_config({
        'num_event_types': 4,
        'batch_size': 2,
        'pad_token_id': 4
    })

    # 3. Create data loader
    print("\n3. Creating data loader...")
    data_loader = create_data_loader(
        data_dict,
        config,
        use_semantic=True,
        use_deviation=True
    )

    # 4. Create model config
    print("\n4. Creating model...")
    from easy_tpp.config_factory import ModelConfig
    model_config = ModelConfig.parse_from_yaml_config({
        'hidden_size': 128,
        'num_layers': 3,
        'num_heads': 6,
        'dropout_rate': 0.1,
        'num_event_types': 4,
        'num_event_types_pad': 5,
        'pad_token_id': 4,
        'semantic_dim': 768,
        'use_semantic': True,
        'use_deviation': True,
        'use_structure_mask': False,
        'loss_integral_num_sample_per_step': 20,
        'use_mc_samples': True,
        'gpu': -1
    })
    model = TorchRobotTHP(model_config)
    print(f"   Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # 5. Test one batch
    print("\n5. Testing data loading...")
    for batch in data_loader:
        # batch is a BatchEncoding object; convert its values to an indexable list
        batch_values = list(batch.values())
        print(f"   Batch size: {len(batch_values[0])}")
        print(f"   Sequence length: {batch_values[0].shape[1]}")
        print(f"   Time sequence shape: {batch_values[0].shape}")
        print(f"   Event type shape: {batch_values[2].shape}")
        if len(batch_values) > 5:
            print(f"   Semantic vector shape: {batch_values[5].shape if batch_values[5] is not None else 'None'}")
        if len(batch_values) > 6:
            print(f"   Deviation feature shape: {batch_values[6].shape if batch_values[6] is not None else 'None'}")
        if len(batch_values) > 7:
            print(f"   Spontaneous flag shape: {batch_values[7].shape if batch_values[7] is not None else 'None'}")

        # 6. Test the forward pass
        print("\n6. Testing the forward pass...")
        model.eval()
        with torch.no_grad():
            loss, num_events = model.loglike_loss(batch_values)
            print(f"   Loss: {loss.item():.4f}")
            print(f"   Number of events: {num_events}")
        break
print("\n✅ 测试完成!")
print("\n使用说明:")
print("1. 将你的JSON数据转换为上述格式")
print("2. 使用RobertTPPDataset和RobertEventTokenizer加载数据")
print("3. 在EasyTPP配置文件中设置model_id为RobotTHP")
print("4. 运行训练即可")


if __name__ == '__main__':
    main()