
1. 代码详细解释
1. 第一段代码
这段代码首先定义了一些参数,包括编码器个数、输入维度、句子长度、词嵌入维度等。然后它保存了这些超参数到指定路径。接着,它加载训练和验证数据集,并创建了对应的数据加载器。之后,它定义了一个模型,使用了一个叫做DSCTransformer的模型,以及交叉熵损失函数和Adam 优化器。最后,它将模型移动到可用的设备(如果有 GPU 则移动到 GPU,否则移动到 CPU)。
def train(model_save_path, train_result_path, val_result_path, hp_save_path, epochs=100):
#定义参数
N = 4 # 编码器个数
input_dim = 1024 # 输入维度
seq_len = 16 # 句子长度
d_model = 64 # 词嵌入维度
d_ff = 256 # 全连接层维度
head = 4 # 注意力头数
dropout = 0.1 # Dropout 比率
lr = 3E-5 # 学习率
batch_size = 64 # 批大小
# 保存超参数
hyper_parameters = {
'任务编码器堆叠数: ': '{}'.format(N),
'全连接层维度: ': '{}'.format(d_ff),
'任务注意力头数: ': '{}'.format(head),
'dropout: ': '{}'.format(dropout),
'学习率: ': '{}'.format(lr),
'batch_size: ': '{}'.format(batch_size)}
fs = open(hp_save_path, 'w') # 打开文件以保存超参数
fs.write(str(hyper_parameters)) # 将超参数写入文件
fs.close() # 关闭文件
# 加载数据
train_path = r'.\data\train\train.csv' # 训练数据路径
val_path = r'.\data\val\val.csv' # 验证数据路径
train_dataset = MyDataset(train_path, 'fd') # 加载训练数据集
val_dataset = MyDataset(val_path, 'fd') # 加载验证数据集
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, drop_last=True) # 创建训练数据加载器
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True, drop_last=True) # 创建验证数据加载器
# 定义模型
model = DSCTransformer(input_dim=input_dim, num_classes=10, dim=d_model, depth=N,
heads=head, mlp_dim=d_ff, dim_head=d_model, emb_dropout=dropout, dropout=dropout) # 初始化模型
criterion = nn.CrossEntropyLoss() # 定义损失函数
params = [p for p in model.parameters() if p.requires_grad] # 获取模型参数
optimizer = optim.Adam(params, lr=lr) # 定义优化器
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # 判断是否有可用的 GPU
print("using {} device.".format(device)) # 打印使用的设备
model.to(device) # 将模型移动到对应的设备(GPU 或 CPU)
2. 第二段代码
这段代码是一个训练循环,它用于在每个训练周期(epoch)中训练模型,并在每个周期结束后评估模型的性能。在每个训练周期中,代码首先使用模型在训练数据上进行训练,然后使用模型在验证数据上进行验证,并打印出每个周期的训练损失、训练准确率、验证损失和验证准确率。
best_acc_fd = 0.0 # 初始化最佳准确率为0
train_result = [] # 记录训练结果
result_train_loss = [] # 记录训练损失
result_train_acc = [] # 记录训练准确率
val_result = [] # 记录验证结果
result_val_loss = [] # 记录验证损失
result_val_acc = [] # 记录验证准确率
# 训练循环
for epoch in range

最低0.47元/天 解锁文章
1527

被折叠的 条评论
为什么被折叠?



