目录
一.nn.Module
pytorch中所有网络结构的一个父类
- 继承后可直接使用以下类
nn.Linear
nn.BatchNorm2d
nn.Conv2d
- nn.Module可嵌套
使用nn.Module的好处
1.提供很多现行的层
- Linear
- ReLU
- Sigmoid
- Conv2d
- ConvTransposed2d
- Dropout
- etc.
2.Container(Sequential类)
nn.Sequential()
可以调用pytorch的类,也可以调用自己写的类,只要这个类继承自nn.Module。这样只需要self.net(),就一次forward()完成所有类(比如152层)的forward()
3.parameters
optimizer = optim.SGD(net.parameters(),lr=le-3)
4.modules
5.to(device)
把类的操作转移到(网络结构搬到)CPU或者gpu
device = torch.device('cuda')#或者CPU
net = Net()
net = net.to(device)
6.save and load
加载和保存ckpt(),网络的一个中间状态;
device = torch.device('cuda')
net = Net()
net.to(device)
#模型开始时检查如果有ckpt,则加载
net.load_state_dict(torch.load('ckpt.mdl'))
# train ...
#把模型状态保存下来
torch.save(net.state_dict(),'ckpt.mdl')
7.train/test切换
#train
net.train()
...
#test
net.eval()
...
8.实现自己的类(继承自nn.Module)
只有class才能写到Sequential里面去
class Flatten(nn.Module):
def __init__(self):
super(Flatten,self).__init__()
def forward(self,input):
return input.view(input.size(0),-1)#把第一个维度保留,剩下维度打平
class TestNet(nn.Module):
def __init__(self):
super(TestNet,self).__init__()
self.net = nn.Sequential(nn.Conv2d(1,16,stride=1,padding=1,)
nn.MaxPool2d(2,2),
Flatten(),
nn.Linear(1*14*14,10))
def forward(self,x):
return self.net(x)
二.数据增强
The key to prevent Overfitting
Filp翻转
transforms.RandomHorizontalFlip(),#垂直
transforms.RandomVerticalFlip(),#水平
Rotate旋转
#-15°<0<15°旋转
transforms.RandomRotation(15),
#任选90°、180°、270°旋转
transforms.RandomRotation([90, 180, 270]),
Scale缩放
#把[28×28]改成[32×32]
transforms.Resize([32, 32]),
Crop Part裁剪/旋转部分
transforms.RandomCrop([28, 28]),
Noise
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(15),
transforms.RandomRotation([90, 180, 270]),
transforms.Resize([32, 32]),
transforms.RandomCrop([28, 28]),
transforms.ToTensor(),
# transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=batch_size, shuffle=True)