『pytorch』pytorch 代码积累


持续更新 …


定义自己的数据集 Dataset, Dataloader

class Dataset_name(Dataset):
    """Template for a custom PyTorch ``Dataset``.

    Parameters
    ----------
    flag : str
        Which split this dataset represents; one of ``'train'``,
        ``'test'`` or ``'valid'``.
    """

    def __init__(self, flag='train'):
        assert flag in ['train', 'test', 'valid']
        self.flag = flag
        self.__load_data__()

    def __getitem__(self, index):
        # Stub: return one (sample, label) pair for `index`.
        pass

    def __len__(self):
        # Stub: return the number of samples in this split.
        pass

    def __load_data__(self, csv_paths: list = None):
        # Fix: `csv_paths` now has a default — __init__ calls this method
        # with no arguments, so a required parameter raised TypeError on
        # every construction.
        pass
        # Fix: only report shapes once the data attributes actually exist;
        # the unconditional print crashed with AttributeError on this stub
        # because `pass` never assigns self.train_X / self.valid_X etc.
        if hasattr(self, 'train_X'):
            print("train_X.shape:{}\ntrain_Y.shape:{}\n" \
                  "valid_X.shape:{}\nvalid_Y.shape:{}\n".format(
                  self.train_X.shape, self.train_Y.shape, self.valid_X.shape, self.valid_Y.shape))

# Build the train/valid splits first, then wrap each in a shuffled
# DataLoader that yields mini-batches of 64 samples.
train_dataset = Dataset_name(flag='train')
valid_dataset = Dataset_name(flag='valid')
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(dataset=valid_dataset, batch_size=64, shuffle=True)

开始训练并调整学习率

# Main training loop: one train pass + one validation pass per epoch,
# followed by early stopping and a hard-coded learning-rate schedule.
for epoch in range(args.epochs):
    # ---------------------- train ----------------------
    Your_model.train()
    train_epoch_loss = []
    for idx, (data_x, data_y) in enumerate(train_dataloader, 0):
        data_x = data_x.to(torch.float32).to(args.device)
        data_y = data_y.to(torch.float32).to(args.device)
        outputs = Your_model(data_x)
        optimizer.zero_grad()
        # Fix: argument order was criterion(data_y, outputs) here but
        # criterion(outputs, data_y) in the validation loop below; the
        # conventional (prediction, target) order is used in both now —
        # this matters for asymmetric losses such as cross-entropy.
        loss = criterion(outputs, data_y)
        loss.backward()
        optimizer.step()
        train_epoch_loss.append(loss.item())
        train_loss.append(loss.item())
        # Fix: max(1, ...) guards against ZeroDivisionError when the
        # loader yields fewer than two batches.
        if idx % max(1, len(train_dataloader) // 2) == 0:
            print("epoch={}/{},{}/{}of train, loss={}".format(
                  epoch, args.epochs, idx, len(train_dataloader), loss.item()))

    train_epochs_loss.append(np.average(train_epoch_loss))

    # =====================valid============================
    Your_model.eval()
    valid_epoch_loss = []
    # Fix: evaluation does not need gradients — no_grad() avoids building
    # the autograd graph and saves memory.
    with torch.no_grad():
        for idx, (data_x, data_y) in enumerate(valid_dataloader, 0):
            data_x = data_x.to(torch.float32).to(args.device)
            data_y = data_y.to(torch.float32).to(args.device)
            outputs = Your_model(data_x)
            loss = criterion(outputs, data_y)
            valid_epoch_loss.append(loss.item())
            valid_loss.append(loss.item())
    valid_epochs_loss.append(np.average(valid_epoch_loss))

    # ==================early stopping======================
    # NOTE(review): r'c:\\your_model_to_save' is a raw string, so the
    # doubled backslash is stored literally — confirm the intended path.
    early_stopping(valid_epochs_loss[-1],
                   model=Your_model, path=r'c:\\your_model_to_save')

    if early_stopping.early_stop:
        print("Early stopping")
        break
    # ====================adjust lr========================
    # Fixed epoch -> learning-rate schedule (manual step decay).
    lr_adjust = {
            2: 5e-5,
            4: 1e-5,
            6: 5e-6,
            8: 1e-6,
            10: 5e-7,
            15: 1e-7
        }
    if epoch in lr_adjust.keys():
        lr = lr_adjust[epoch]
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        print('Updating learning rate to {}'.format(lr))

加载模型(load_model)

模型的参数 key 值与当前模型不一致:每个 key 多了一个 `module.` 前缀(通常是用 nn.DataParallel 保存模型时引入的),需要先去掉前缀再加载。

# Load a checkpoint saved by nn.DataParallel, whose parameter keys carry a
# 'module.' prefix, and remap the keys so a plain (non-parallel) model can
# load them.
state_dict = torch.load('checkpoint.pt')
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # Fix: the original `k[7:]` chopped the first 7 characters of EVERY
    # key, corrupting checkpoints whose keys have no 'module.' prefix.
    # Strip the prefix only when it is actually present.
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)  # reload the remapped weights

编写预测类

以下代码用于加载一个预先训练好的模型并进行预测,大多数场景可以直接使用,但不排除特定情况需要自行调整!

from ceevee.base import AbstractPredictor

class MySuperPredictor(AbstractPredictor):
    """Predictor that wraps a model restored from a weights file."""

    def __init__(self, weights_path: str):
        super().__init__()
        # Restore the model once at construction time.
        self.model = self._load_model(weights_path=weights_path)

    def process(self, x, *kw):
        # Inference only — no autograd graph is needed.
        with torch.no_grad():
            return self.model(x)

    @staticmethod
    def _load_model(weights_path):
        # Map to CPU so the checkpoint loads regardless of the device
        # it was saved on.
        net = ModelClass()
        net.load_state_dict(torch.load(weights_path, map_location='cpu'))
        return net
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

libo-coder

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值