记录 pytorch 快速搭建模型

记录 pytorch 快速搭建模型


快速搭建pytorch模型模板(仅供参考)


前言

一个模型的搭建主要分为参数定义、网络模型定义,再到训练步骤、验证步骤、测试步骤。代码大致分为四个文件:

1.hparams.py

2.datasets.py

3.model.py

4.run.py


主要内容:
1.以类的方式定义参数
2.定义自己的模型
3.定义early_stop类(可选)
4.定义数据集处理类
5.设置 loss 优化器
6.开始训练
7.绘图(可选)
8.预测

一、参数定义

import argparse


class Hparams:
    """Hyper-parameter container.

    Exposes an ``argparse.ArgumentParser`` as a class attribute; callers
    fetch ``Hparams().parser`` and run ``parse_args()`` themselves.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=6, type=int, help="帮助解释")
#############################################################################
# Invoke the parameters defined in hparams.py
from hparams import Hparams

hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
batch_size = hp.batch_size  # fixed typo: was `hp.batch_szie`, which raises AttributeError

二、模型定义

class Mymodel(nn.Module):
    """Model skeleton: add layers in ``__init__`` and the computation in ``forward``."""

    def __init__(self):
        super().__init__()
        pass

    def forward(self, x):
        # placeholder — currently returns the input unchanged
        pass
        return x

三、Early_stop

防止过拟合,早停法

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0           # epochs elapsed since last improvement
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf  # fixed: `np.Inf` was removed in NumPy 2.0
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        # fixed: this line was tab-indented in the original (IndentationError)
        print("val_loss={}".format(val_loss))
        score = -val_loss  # higher score == lower validation loss
        if self.best_score is None:
            # first call: record baseline and save
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            # no improvement beyond delta: count toward patience
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # improvement: save checkpoint and reset the counter
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

四、数据集处理 Dataset, DataLoader


class Batch(TensorWrapper):
    """A thin wrapper bundling the tensors of one mini-batch."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __repr__(self):
        attr_names = ", ".join(self.__dict__)
        return "Batch with attributes: {}".format(attr_names)

    @property
    def size(self):
        # number of sequences in the batch (first dim of seq_lens)
        # assumes self.seq_lens is a 1-D tensor of lengths — TODO confirm
        return self.seq_lens.size(0)

    @property
    def step(self):
        # length of the longest sequence in the batch
        return self.seq_lens.max().item()

class Dataset(torch.utils.data.Dataset):
    """Map-style dataset for one split: 'train', 'valid' or 'test'."""

    def __init__(self, flag):
        # fixed: the original mixed tab and space indentation (IndentationError)
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        self.texts = get_data(flag)  # get the samples
        super().__init__()

    def __len__(self):
        # return the number of samples
        return len(self.texts)

    def __getitem__(self, i):
        '''Given one sample index, return the processed sample.'''
        pass
        return self.texts[i]

    def collate(self, batch_examples: List[dict]):
        '''How samples are batched: implement `function` to produce exactly
        the batch fields you need.
        '''
        # fixed: the original docstring was indented at `def` level (IndentationError)
        batch = function(batch_examples)  # function return [input_ids,labels]
        return Batch(**batch)  # wrap into a Batch


# Build the training split and its loader.
train_data = Dataset('train')
# fixed: `DataLoader` was never imported in this template — qualify it via torch
train_loader = torch.utils.data.DataLoader(train_data, batch_size, shuffle=True, collate_fn=train_data.collate)

五、设置 loss 优化器

import numpy as np
import random
import torch

# Set random seeds for reproducibility.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)


# Select the cuda device.
# NOTE(review): `with_cuda` / `cuda_devices` were referenced without being
# defined anywhere (NameError); template defaults are supplied here — move
# them into hparams if you prefer.
with_cuda = True       # set False to force CPU
cuda_devices = None    # None = use all visible GPUs with DataParallel
cuda_condition = torch.cuda.is_available() and with_cuda
device = torch.device("cuda:0" if cuda_condition else "cpu")


model = Mymodel().to(device)

# Distributed GPU training if CUDA can detect more than one GPU.
if with_cuda and torch.cuda.device_count() > 1:
    # fixed: the original body used tab indentation
    print("Using %d GPUS" % torch.cuda.device_count())
    model = nn.DataParallel(model, device_ids=cuda_devices)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=hp.learning_rate)
print("Total Parameters:", sum(p.nelement() for p in model.parameters()))

# Per-step and per-epoch loss histories for plotting later.
train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []

early_stopping = EarlyStopping(patience=hp.patience, verbose=True)

六、开始训练

for epoch in range(hp.epochs):
    # fixed: the original mixed tab and space indentation (IndentationError)
    model.train()
    train_epoch_loss = []
    batchs = tqdm(train_loader, leave=True)
    for batch in batchs:
        # initialize calculated gradients (from prev step)
        optimizer.zero_grad()  # fixed: was `optim`, an undefined name
        # pull all tensor batches required for training
        batch.to(device)
        input_ids = batch.input_ids
        labels = batch.labels
        outputs = model(input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()  # parameter update (fixed: was `optim.step()`)
        batchs.set_description(f'Epoch {epoch}')
        batchs.set_postfix(loss=loss.item())

        train_epoch_loss.append(loss.item())
        train_loss.append(loss.item())
    train_epochs_loss.append(np.average(train_epoch_loss))
    # =====================valid============================
    model.eval()
    valid_epoch_loss = []
    batchs = tqdm(valid_loader, leave=True)
    for batch in batchs:
        batch.to(device)
        outputs = model(batch.input_ids)  # fixed: was `Your_model`, undefined
        loss = criterion(outputs, batch.labels)
        valid_epoch_loss.append(loss.item())
        valid_loss.append(loss.item())
    valid_epochs_loss.append(np.average(valid_epoch_loss))
    # ==================early stopping======================
    # fixed: EarlyStopping.__call__(val_loss, model) accepts no `path` kwarg
    # (path is set in its constructor), and needs the trained *instance*,
    # not the Mymodel class
    early_stopping(valid_epochs_loss[-1], model=model)
    if early_stopping.early_stop:
        print("Early stopping")
        break

七、绘图

# NOTE(review): requires `import matplotlib.pyplot as plt` — matplotlib is
# never imported anywhere in this template; add it before running.
# Left panel: per-step training loss; right panel: per-epoch train/valid loss.
plt.figure(figsize=(12,4))
plt.subplot(121)
plt.plot(train_loss[:])
plt.title("train_loss")
plt.subplot(122)
plt.plot(train_epochs_loss[1:],'-o',label="train_loss")  # skip epoch 0 for a cleaner curve
plt.plot(valid_epochs_loss[1:],'-o',label="valid_loss")
plt.title("epochs_loss")
plt.legend()
plt.show()

八、预测

# Define a DataLoader for the prediction set here, or reshape your prediction
# data directly and use batch_size=1.
model.eval()  # fixed: was `Your_model.eval()` — an undefined name
# fixed: was `Mymodel(Dataloader)`, which calls the class instead of the
# trained instance. NOTE(review): `Dataloader` is a placeholder for your
# prediction input batch — feed tensors, not the loader object itself.
predict = model(Dataloader)
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值