Dataset.py
import torch
from torch.utils.data import Dataset
clas Dataset(Dataset):
def __init__(self):
def __len__(self): return len
def __getitem(self)__:
def process_meta(self, filename):处理train.txt、val.txt给getitem方法提供一些所需的数据。
def reprocess(self, data, idxs):处理一batch数据,利用pad_2D、pad_1D将一batch数据统一长度并捆成到一个batch里。
def collate_fn(self, data):指派数据集中每一条数据的索引,并将数据集中所有数据分配为len(data)%batch_size + batch_size*n,然后在for batch in idx_array:调用reprocess方法处理每一batch数据。
在 PyTorch 中,Dataset 是一个抽象类,它规定了 PyTorch 中数据集的基本特征和行为。用户可以通过继承 Dataset 类,实现自己的数据集类,并根据需要定义自己的 getitem() 和 len() 方法,来处理数据集中的每个样本以及整个数据集的长度。通过继承 Dataset 类并实现自己的数据集类,我们就能够很方便地将数据集对象传递给 DataLoader,并调用其加载数据的方法,实现对数据集的方便管理和使用。
train.py
def main(args, configs):
preprocess_config, model_config, train_config = configs
#实例化自定义的Dataset类
dataset = Dataset(filename, preprocess_config, train_config, sort=True, drop_last=True)
batch_size = train_config["optimizer"]["batch_size"]
group_size = 4
assert batch_size * group_size < len(dataset)
#实例化DataLoader类
loader = DataLoader(dataset, batch_size=batch_size*group_size, shuffle=True, collate_fn=dataset.collate_fn,)
#定义模型
model, optimizer = get_model(args, configs, device, train=True)
#将模型放到服务器上并最大化利用gpu
model = nn.DataParallel(model)
#加载损失函数
Loss = 损失类名(preprocess_config, model_config).to(device)
省略一些关于超参数的定义
#传输真实数据了
for batchs in loader:
for batch in batchs:
batch = to_device(batch, device)
#forward
output = model(*(batch[2:]))
#计算损失
losses = Loss(batch, output)
#自定义损失类返回的是一个元组,第一个元素为total_loss
total_loss = loss[0]
#backward
# 因为grad_acc_step代表梯度累积步数(虚拟批量大小)
total_loss = total_loss / grad_acc_step
total_loss.backward()
if step % grad_acc_step == 0:
如果有超出阈值grad_clip_thresh的梯度就将所有梯度都按比例缩小直到没有超出阈值的了。
#更新参数和学习率 然后清空梯度
optimizer._step_and_updata()
optimizer.zero()
#如果到达打印损失信息步数就输出当前步数的损失信息
#到达保存模型步数:
if step%save_step == 0:
torch.save(
{
"model": model.module.state_dict(),
"optimizer": optimizer._optimizer.state_dict(),
},
os.path.join(
train_config["path"]["ckpt_path"],
"{}.pth.tar".format(step),
),
)
#到达总步数
if step == total_step:
quit()
step += 1
epoch += 1