Fully Connected Networks
The logistic regression problem is really just a very small fully connected network. Simply put, in a fully connected network every node in one layer is connected to every node in the next layer.
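To make that concrete, here is a minimal sketch of my own (not from the book) of logistic regression written as a one-layer fully connected network with a sigmoid on top:

import torch
from torch import nn

# logistic regression == one Linear layer plus a sigmoid
logistic = nn.Sequential(
    nn.Linear(2, 1),   # 2 input features -> 1 output score
    nn.Sigmoid()       # squash the score into a probability in (0, 1)
)

x = torch.randn(4, 2)   # a batch of 4 samples with 2 features each
prob = logistic(x)      # shape (4, 1)
print(prob)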
Below is the code for a simple three-layer fully connected network, followed by versions of it that progressively add improvements for learning.
import torch
from torch import nn, optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class simpleNet(nn.Module):  # plain fully connected network
    def __init__(self, in_dims, n_hidden1, n_hidden2, out_dims):
        super(simpleNet, self).__init__()
        self.layer1 = nn.Linear(in_dims, n_hidden1)
        self.layer2 = nn.Linear(n_hidden1, n_hidden2)
        self.layer3 = nn.Linear(n_hidden2, out_dims)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
class Activation_Net(nn.Module):  # fully connected network with activation functions
    def __init__(self, in_dims, n_hidden1, n_hidden2, out_dims):
        super(Activation_Net, self).__init__()
        self.layer1 = nn.Sequential(  # linear layer followed by a nonlinearity
            nn.Linear(in_dims, n_hidden1), nn.ReLU(True)
        )
        self.layer2 = nn.Sequential(  # linear layer followed by a nonlinearity
            nn.Linear(n_hidden1, n_hidden2), nn.ReLU(True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(n_hidden2, out_dims)  # no activation on the output layer
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
class Batch_Net(nn.Module):  # adds batch normalization
    def __init__(self, in_dims, n_hidden1, n_hidden2, out_dims):
        super(Batch_Net, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(in_dims, n_hidden1),
            nn.BatchNorm1d(n_hidden1), nn.ReLU(True)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(n_hidden1, n_hidden2),
            nn.BatchNorm1d(n_hidden2), nn.ReLU(True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(n_hidden2, out_dims)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x
data_tf = transforms.Compose(  # preprocessing: to tensor, then normalize to [-1, 1]
    [transforms.ToTensor(),
     transforms.Normalize([0.5], [0.5])]
)

# download the dataset
train_dataset = datasets.MNIST(
    root='./data', train=True, transform=data_tf, download=True
)
test_dataset = datasets.MNIST(
    root='./data', train=False, transform=data_tf
)

# data loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
model = simpleNet(28 * 28, 300, 100, 10)  # the class defined above (no `net.` module prefix)
criterion = nn.CrossEntropyLoss()  # cross-entropy loss
optimizer = optim.SGD(model.parameters(), lr=0.01)  # stochastic gradient descent

for epoch in range(10):  # train the network
    for i, (inputs, labels) in enumerate(train_loader, 0):
        inputs = inputs.view(inputs.size(0), -1)  # flatten each 28x28 image into a vector
        inputs = Variable(inputs)  # Variable is a no-op in PyTorch >= 0.4; kept in the book's style
        labels = Variable(labels)
        out = model(inputs)
        loss = criterion(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(epoch, loss.item())
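The loop above only trains, and the `test_loader` built earlier goes unused. As a quick check one can also run the test set through the model; a minimal evaluation sketch of my own, reusing the `model` and `test_loader` defined above:

model.eval()  # switch to eval mode (important for Batch_Net, whose BatchNorm layers behave differently at test time)
correct = 0
total = 0
with torch.no_grad():  # no gradients needed for evaluation
    for inputs, labels in test_loader:
        inputs = inputs.view(inputs.size(0), -1)
        out = model(inputs)
        pred = out.argmax(dim=1)  # index of the largest logit is the predicted class
        correct += (pred == labels).sum().item()
        total += labels.size(0)
print('test accuracy:', correct / total)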
Convolutional Networks
I won't explain convolutional networks here; plenty of books and online material already cover them well. I'll just note the functions needed to build convolutional layers.
- nn.Conv2d()
  This function takes five main parameters (the real parameter names are plural, and stride comes before padding in the positional order):
  in_channels: number of channels of the input volume
  out_channels: number of channels of the output volume
  kernel_size: size of the convolution kernel
  stride: stride of the sliding window
  padding: amount of zero padding
- nn.MaxPool2d()
  This function is max pooling. We generally use max pooling; average pooling only comes up in special cases. A quick shape check follows this list.
  kernel_size: size of the pooling window
  stride: stride of the sliding window
  padding: amount of zero padding
  return_indices: whether to return the indices of the max values
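To see what these parameters do to tensor shapes, here is a small sketch of my own (the 3-channel 32x32 input is an assumption for illustration): a 3x3 convolution with stride 1 and padding 1 preserves the spatial size, while a 2x2 max pool with stride 2 halves it.

import torch
from torch import nn

conv = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)  # padding=1 keeps 32x32
pool = nn.MaxPool2d(kernel_size=2, stride=2)                 # halves height and width

x = torch.randn(1, 3, 32, 32)   # (batch, channels, height, width)
print(conv(x).shape)            # torch.Size([1, 32, 32, 32])
print(pool(conv(x)).shape)      # torch.Size([1, 32, 16, 16])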
import torch
from torch import nn

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        layer1 = nn.Sequential()
        layer1.add_module('conv1', nn.Conv2d(3, 32, 3, 1, padding=1))
        layer1.add_module('relu1', nn.ReLU(True))
        layer1.add_module('pool1', nn.MaxPool2d(2, 2))
        self.layer1 = layer1

        layer2 = nn.Sequential()
        layer2.add_module('conv2', nn.Conv2d(32, 64, 3, 1, padding=1))
        layer2.add_module('relu2', nn.ReLU(True))
        layer2.add_module('pool2', nn.MaxPool2d(2, 2))
        self.layer2 = layer2

        layer3 = nn.Sequential()
        layer3.add_module('conv3', nn.Conv2d(64, 128, 3, 1, padding=1))
        layer3.add_module('relu3', nn.ReLU(True))
        layer3.add_module('pool3', nn.MaxPool2d(2, 2))
        self.layer3 = layer3

        # fully connected layers; 2048 = 128 * 4 * 4, which assumes 32x32 input images
        layer4 = nn.Sequential()
        layer4.add_module('fc1', nn.Linear(2048, 512))
        layer4.add_module('fc_relu1', nn.ReLU(True))  # module names must be unique, or add_module replaces the earlier one
        layer4.add_module('fc2', nn.Linear(512, 64))
        layer4.add_module('fc_relu2', nn.ReLU(True))
        layer4.add_module('fc3', nn.Linear(64, 10))
        self.layer4 = layer4

    def forward(self, x):
        conv1 = self.layer1(x)
        conv2 = self.layer2(conv1)
        conv3 = self.layer3(conv2)
        fc_input = conv3.view(conv3.size(0), -1)  # flatten the feature maps into one vector per sample
        fc_out = self.layer4(fc_input)
        return fc_out
if __name__ == '__main__':
    model = SimpleCNN()
    print(model)
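To confirm that the 2048 in fc1 lines up, here is a quick sanity check of my own (the 3x32x32 input size is an assumption; it is what the layer sizes imply):

import torch

model = SimpleCNN()
x = torch.randn(4, 3, 32, 32)  # a fake batch of four 32x32 RGB images
out = model(x)
print(out.shape)  # torch.Size([4, 10]); three 2x poolings leave 128 * 4 * 4 = 2048 features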
It's been ages since my last post; I'm getting through this book so slowly...