Notes on quickly building a model in PyTorch
A quick-start PyTorch model template (for reference only)
Preface
Building a model breaks down into parameter definition, network definition, and then the training, validation, and testing steps. The code is roughly split across four files:
1. hparams.py
2. datasets.py
3. model.py
4. run.py
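A typical project layout for these files might look like this (the comments describe the role of each file, matching the list above):

project/
├── hparams.py   # hyperparameter definitions
├── datasets.py  # Dataset / DataLoader / collate logic
├── model.py     # the nn.Module definition
└── run.py       # training, validation, and prediction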
Main steps:
1. Define the hyperparameters as a class
2. Define your model
3. Define an early-stopping class (optional)
4. Define the dataset-processing class
5. Set up the loss and optimizer
6. Train
7. Plot the losses (optional)
8. Predict
1. Parameter definition
import argparse

class Hparams:
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=6, type=int, help="explanation of this argument")

#############################################################################
# Using the parameters
from hparams import Hparams

hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()
batch_size = hp.batch_size
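One caveat: parse_args() reads sys.argv, so inside a Jupyter notebook it will trip over the kernel's own command-line flags. Two standard argparse workarounds (a minimal sketch):

# In a notebook, sys.argv contains the kernel's own flags, which
# parse_args() would reject; pass an explicit (empty) argument list instead.
hp = parser.parse_args([])         # ignore sys.argv and use the declared defaults
# or:
hp, _ = parser.parse_known_args()  # parse recognized flags, ignore the rest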
2. Model definition
import torch.nn as nn

class Mymodel(nn.Module):
    def __init__(self):
        super().__init__()
        # define your layers here

    def forward(self, x):
        # define the forward computation here
        return x
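To make the skeleton concrete, here is what a filled-in version might look like, using a small multilayer perceptron (an illustrative sketch, not part of the original template):

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, in_dim=16, hidden_dim=32, out_dim=1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)

model = MLP()
print(model(torch.randn(4, 16)).shape)  # torch.Size([4, 1])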
3. Early stopping
Early stopping is used to prevent overfitting: training is halted once the validation loss stops improving.
import numpy as np
import torch

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        print("val_loss={}".format(val_loss))
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves the model when the validation loss decreases.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss
4. Dataset processing: Dataset and DataLoader
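The Batch class below inherits from a TensorWrapper base class that the post never shows. A minimal sketch of one possible implementation (assumed, since the original class is not given): it stores the keyword arguments as attributes and can move every tensor attribute to a device.

import torch

class TensorWrapper:
    # Hypothetical base class: holds the keyword tensors as attributes.
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def to(self, device):
        # Move every tensor attribute to the target device in place.
        for name, value in self.__dict__.items():
            if torch.is_tensor(value):
                setattr(self, name, value.to(device))
        return self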
class Batch(TensorWrapper):
    """A wrapper around one batch."""
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def __repr__(self):
        return "Batch with attributes: {}".format(", ".join(self.__dict__))

    @property
    def size(self):
        return self.seq_lens.size(0)

    @property
    def step(self):
        return self.seq_lens.max().item()
from typing import List
import torch
from torch.utils.data import DataLoader

class Dataset(torch.utils.data.Dataset):
    def __init__(self, flag):
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        self.texts = get_data(flag)  # get_data: your own function that loads the samples for this split
        super().__init__()

    def __len__(self):
        # return the number of samples
        return len(self.texts)

    def __getitem__(self, i):
        '''Receives one sample at a time and returns it after preprocessing.'''
        # preprocess the sample here
        return self.texts[i]

    def collate(self, batch_examples: List[dict]):
        '''Defines how samples are combined into a batch; implement your own
        function to build exactly the batch you need.'''
        batch = function(batch_examples)  # function returns a dict, e.g. {'input_ids': ..., 'labels': ...}
        return Batch(**batch)  # wrap the batch

train_data = Dataset('train')
train_loader = DataLoader(train_data, batch_size, shuffle=True, collate_fn=train_data.collate)
valid_data = Dataset('valid')
valid_loader = DataLoader(valid_data, batch_size, shuffle=False, collate_fn=valid_data.collate)
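The collate helper `function` is left abstract above. As one possible implementation (hypothetical; it assumes each example is a dict with 'input_ids' and 'label' keys), pad the variable-length sequences to the batch maximum, stack the labels, and also produce the seq_lens tensor that Batch expects:

import torch
from typing import List

def function(batch_examples: List[dict]):
    # Hypothetical collate helper: pad token-id sequences to the batch max length.
    seqs = [torch.tensor(ex['input_ids']) for ex in batch_examples]
    seq_lens = torch.tensor([len(s) for s in seqs])
    input_ids = torch.nn.utils.rnn.pad_sequence(seqs, batch_first=True, padding_value=0)
    labels = torch.tensor([ex['label'] for ex in batch_examples])
    return {'input_ids': input_ids, 'labels': labels, 'seq_lens': seq_lens}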
5. Setting up the loss and optimizer
import numpy as np
import random
import torch
import torch.nn as nn

# set random seeds for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# set up the CUDA device (with_cuda is a flag you define, e.g. in hparams)
with_cuda = True
cuda_condition = torch.cuda.is_available() and with_cuda
device = torch.device("cuda:0" if cuda_condition else "cpu")
model = Mymodel().to(device)

# if CUDA sees more than one GPU, train with data parallelism
cuda_devices = None  # None uses all visible GPUs; or pass a list of GPU ids, e.g. [0, 1]
if with_cuda and torch.cuda.device_count() > 1:
    print("Using %d GPUS" % torch.cuda.device_count())
    model = nn.DataParallel(model, device_ids=cuda_devices)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=hp.learning_rate)
print("Total Parameters:", sum([p.nelement() for p in model.parameters()]))

train_loss = []
valid_loss = []
train_epochs_loss = []
valid_epochs_loss = []
early_stopping = EarlyStopping(patience=hp.patience, verbose=True, path=hp.path)
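If you also need reproducibility on the GPU, CUDA kernels have their own sources of randomness; the following optional lines pin those down as well (note they can slow training):

torch.cuda.manual_seed_all(seed)           # seed all GPUs
torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False     # disable autotuning for reproducibility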
6. Training
from tqdm import tqdm

for epoch in range(hp.epochs):
    # =====================train=============================
    model.train()
    train_epoch_loss = []
    batchs = tqdm(train_loader, leave=True)
    for batch in batchs:
        # clear the gradients accumulated in the previous step
        optimizer.zero_grad()
        # pull all tensor batches required for training
        batch.to(device)
        input_ids = batch.input_ids
        labels = batch.labels
        outputs = model(input_ids)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()  # update the parameters
        batchs.set_description(f'Epoch {epoch}')
        batchs.set_postfix(loss=loss.item())
        train_epoch_loss.append(loss.item())
        train_loss.append(loss.item())
    train_epochs_loss.append(np.average(train_epoch_loss))
    # =====================valid=============================
    model.eval()
    valid_epoch_loss = []
    batchs = tqdm(valid_loader, leave=True)
    with torch.no_grad():
        for batch in batchs:
            batch.to(device)
            outputs = model(batch.input_ids)
            loss = criterion(outputs, batch.labels)
            valid_epoch_loss.append(loss.item())
            valid_loss.append(loss.item())
    valid_epochs_loss.append(np.average(valid_epoch_loss))
    # ==================early stopping=======================
    early_stopping(valid_epochs_loss[-1], model=model)
    if early_stopping.early_stop:
        print("Early stopping")
        break
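After the loop exits (via early stopping or after the last epoch), the weights in memory belong to the final epoch, not the best one. EarlyStopping saved the best weights to hp.path, so you would typically reload them before evaluating:

model.load_state_dict(torch.load(hp.path, map_location=device))  # restore the best checkpoint
model.eval()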
7. Plotting
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))
plt.subplot(121)
plt.plot(train_loss[:])
plt.title("train_loss")
plt.subplot(122)
plt.plot(train_epochs_loss[1:], '-o', label="train_loss")
plt.plot(valid_epochs_loss[1:], '-o', label="valid_loss")
plt.title("epochs_loss")
plt.legend()
plt.show()
8. Prediction
# Build a DataLoader for the prediction set (or reshape your data and use batch_size=1).
test_data = Dataset('test')
test_loader = DataLoader(test_data, batch_size, shuffle=False, collate_fn=test_data.collate)
model.eval()
with torch.no_grad():
    for batch in test_loader:
        batch.to(device)
        predictions = model(batch.input_ids)