Pytorch学习(二十) ------ 可能有用的代码合集

总说

记录一些比较有用的pytorch代码(有些是自己写的, 有些是从网上看到的)
本文和Pytorch学习(十五)------ 杂项知识汇总估计会经常更新, 记录一些琐碎的东西.

目录

  • 提取网络特征(适用于sequential构建的网络)
  • 修改复杂Pretrained的网络(如ResNet等)
  • 预训练模型的简单更改(如ResNet等)
  • Sequential网络中插入自定义层
  • 高斯模糊
  • 自定义操作放入Transform中
  • 网络不同层用不同学习率的方法

网络不同层用不同学习率的方法

https://github.com/hytseng0509/CrossDomainFewShot/blob/master/methods/LFTNet.py

class LFTNet(nn.Module):
  def __init__(self, params, tf_path=None, change_way=True):
    super(LFTNet, self).__init__()
    #
    # Some code here
    #
    model_params, ft_params = self.split_model_parameters()
    self.model_optim = torch.optim.Adam(model_params)
    self.ft_optim = torch.optim.Adam(ft_params, weight_decay=1e-8, lr=1e-3)

    # total epochs
    self.total_epoch = params.stop_epoch

  # split the parameters of feature-wise transforamtion layers and others
  def split_model_parameters(self):
    model_params = []
    ft_params = []
    for n, p in self.model.named_parameters():
      n = n.split('.')
      if n[-1] == 'gamma' or n[-1] == 'beta':
        ft_params.append(p)
      else:
        model_params.append(p)
    return model_params, ft_params

这种方法,就是直接在定义网络时,将参数进行分离,并用独立的优化器。另一种方法可以参考:
Pytorch学习(二十三)---- 不同layer用不同学习率(高级版本)

提取网络特征

对于内部是sequence构建的网络

class VGG16FeatureExtractor(nn.Module):
    def __init__(self):
        super().__init__()
        vgg16 = models.vgg16(pretrained=True)
        self.enc_1 = nn.Sequential(*vgg16.features[:5])
        self.enc_2 = nn.Sequential(*vgg16.features[5:10])
        self.enc_3 = nn.Sequential(*vgg16.features[10:17])

        # fix the encoder
        for i in range(3):
            # 这种写法挺好的啊!!!!!!!!!!!
            for param in getattr(self, 'enc_{:d}'.format(i + 1)).parameters():
                param.requires_grad = False

    def forward(self, image):
        results = [image]
        for i in range(3):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
return results[1:]

对于内部不是单纯用sequential构建的网络

import torch
import torchvision.models as models
import torch.nn as nn


class ResNetFeat(nn.Module):
    def __init__(self)
  • 5
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值