pytorch6--nn.Module、数据增强

目录

一.nn.Module

使用nn.Module的好处

1.提供很多现行的层

2.Container(Sequential类)

3.parameters

4.modules

5.to(device)

6.save and load

7.train/test切换

8.实现自己的类(继承自nn.Module)

二.数据增强

Filp翻转

Rotate旋转

Scale缩放

Crop Part裁剪/旋转部分

Noise


一.nn.Module

pytorch中所有网络结构的一个父类

  • 继承后可直接使用以下类

nn.Linear

nn.BatchNorm2d

nn.Conv2d

  • nn.Module可嵌套

使用nn.Module的好处

1.提供很多现行的层

  • Linear
  • ReLU
  • Sigmoid
  • Conv2d
  • ConvTransposed2d
  • Dropout
  • etc.

2.Container(Sequential类)

nn.Sequential()

可以调用pytorch的类,也可以调用自己写的类,只要这个类继承自nn.Module。这样只需要self.net(),就一次forward()完成所有类(比如152层)的forward()

3.parameters

optimizer = optim.SGD(net.parameters(),lr=le-3)

4.modules

5.to(device)

把类的操作转移到(网络结构搬到)CPU或者gpu

  • 对于一个tensor a, 需要写成 a = a.to(device)
  • 对于一个网络net, 可以直接写成 net.to(device)
   device = torch.device('cuda')#或者CPU
   net = Net()
   net = net.to(device)

6.save and load

加载和保存ckpt(),网络的一个中间状态;

device = torch.device('cuda')
net = Net()
net.to(device)

#模型开始时检查如果有ckpt,则加载
net.load_state_dict(torch.load('ckpt.mdl'))

# train ...

#把模型状态保存下来
torch.save(net.state_dict(),'ckpt.mdl')

7.train/test切换

#train
net.train()
...

#test
net.eval()
...

8.实现自己的类(继承自nn.Module)

只有class才能写到Sequential里面去

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten,self).__init__()

    def forward(self,input):
        return input.view(input.size(0),-1)#把第一个维度保留,剩下维度打平
        
class TestNet(nn.Module):
    
    def __init__(self):
        super(TestNet,self).__init__()
        self.net = nn.Sequential(nn.Conv2d(1,16,stride=1,padding=1,)
                   nn.MaxPool2d(2,2),
                   Flatten(),
                   nn.Linear(1*14*14,10))

    def forward(self,x):
        return self.net(x)

二.数据增强

The key to prevent Overfitting

Filp翻转

transforms.RandomHorizontalFlip(),#垂直
transforms.RandomVerticalFlip(),#水平

Rotate旋转

#-15°<0<15°旋转
transforms.RandomRotation(15),
#任选90°、180°、270°旋转
transforms.RandomRotation([90, 180, 270]),

Scale缩放

#把[28×28]改成[32×32]
transforms.Resize([32, 32]),

Crop Part裁剪/旋转部分

transforms.RandomCrop([28, 28]),

Noise

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.RandomHorizontalFlip(),
                       transforms.RandomVerticalFlip(),
                       transforms.RandomRotation(15),
                       transforms.RandomRotation([90, 180, 270]),
                       transforms.Resize([32, 32]),
                       transforms.RandomCrop([28, 28]),
                       transforms.ToTensor(),
                       # transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值