fs2_ori笔记

该文描述了在PyTorch中如何自定义数据集类,通过继承`Dataset`并实现`__init__`,`__len__`,`__getitem__`方法。接着,文章展示了如何实例化`DataLoader`进行数据加载,以及如何定义和使用模型、损失函数、优化器进行训练。训练过程中包含了梯度累积和梯度裁剪的策略,并定期保存模型。
摘要由CSDN通过智能技术生成

Dataset.py

import torch
from torch.utils.data import Dataset

clas Dataset(Dataset):
    def __init__(self):
    def __len__(self): return len
    def __getitem(self)__:
    def process_meta(self, filename):处理train.txt、val.txt给getitem方法提供一些所需的数据。
    def reprocess(self, data, idxs):处理一batch数据,利用pad_2D、pad_1D将一batch数据统一长度并捆成到一个batch里。
    def collate_fn(self, data):指派数据集中每一条数据的索引,并将数据集中所有数据分配为len(data)%batch_size + batch_size*n,然后在for batch in idx_array:调用reprocess方法处理每一batch数据。
    

在 PyTorch 中,Dataset 是一个抽象类,它规定了 PyTorch 中数据集的基本特征和行为。用户可以通过继承 Dataset 类,实现自己的数据集类,并根据需要定义自己的 getitem() 和 len() 方法,来处理数据集中的每个样本以及整个数据集的长度。通过继承 Dataset 类并实现自己的数据集类,我们就能够很方便地将数据集对象传递给 DataLoader,并调用其加载数据的方法,实现对数据集的方便管理和使用。

train.py

def main(args, configs):
    preprocess_config, model_config, train_config = configs
    
    #实例化自定义的Dataset类
    dataset = Dataset(filename, preprocess_config, train_config, sort=True, drop_last=True)
    
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4
    assert batch_size * group_size < len(dataset)
    
    #实例化DataLoader类
    loader = DataLoader(dataset, batch_size=batch_size*group_size, shuffle=True, collate_fn=dataset.collate_fn,)
    
    #定义模型
    model, optimizer = get_model(args, configs, device, train=True)
    
    #将模型放到服务器上并最大化利用gpu
    model = nn.DataParallel(model)
    
    #加载损失函数
    Loss = 损失类名(preprocess_config, model_config).to(device)
    省略一些关于超参数的定义
    
    #传输真实数据了
    for batchs in loader:
        for batch in batchs:
            batch = to_device(batch, device)
            
            #forward
            output = model(*(batch[2:]))
            
            #计算损失
            losses = Loss(batch, output)
            #自定义损失类返回的是一个元组,第一个元素为total_loss
            total_loss = loss[0]
            
            #backward
            # 因为grad_acc_step代表梯度累积步数(虚拟批量大小)
            total_loss = total_loss / grad_acc_step
            total_loss.backward()
            
            if step % grad_acc_step == 0:
                如果有超出阈值grad_clip_thresh的梯度就将所有梯度都按比例缩小直到没有超出阈值的了。
                
            #更新参数和学习率 然后清空梯度
            optimizer._step_and_updata()
            optimizer.zero()
            
            #如果到达打印损失信息步数就输出当前步数的损失信息
            
            #到达保存模型步数:
            if step%save_step == 0:
                torch.save(
                {
                    "model": model.module.state_dict(),
                    "optimizer": optimizer._optimizer.state_dict(),
                },
                os.path.join(
                    train_config["path"]["ckpt_path"],
                    "{}.pth.tar".format(step),
                ),
                )
                
            #到达总步数
            if step == total_step:
                quit()
            step += 1
        epoch += 1

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值