import torch
from torch import nn
import torchvision
from torchvision import datasets
from torch.nn import functional as F
class Lenet5(nn.Module):
    """LeNet-5-style convolutional network for 3x32x32 images (e.g. CIFAR-10).

    Args:
        num_classes: number of output logits. Defaults to 10 (CIFAR-10).
    """

    def __init__(self, num_classes=10):
        super(Lenet5, self).__init__()
        # Feature extractor, shapes for a [b, 3, 32, 32] input:
        #   [b, 3, 32, 32] -conv5-> [b, 6, 28, 28] -pool2-> [b, 6, 14, 14]
        #   -conv5-> [b, 16, 10, 10] -pool2-> [b, 16, 5, 5]
        # ReLU is applied after each Conv2d in forward() instead of being stored
        # here, so the Sequential indices (and thus state_dict keys) stay as-is.
        self.conv_unit = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )
        # Classifier head: [b, 16 * 5 * 5] -> [b, num_classes]
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Map images [b, 3, 32, 32] to unnormalized logits [b, num_classes].

        The logits are raw scores; nn.CrossEntropyLoss applies log-softmax itself.
        """
        for layer in self.conv_unit:
            x = layer(x)
            if isinstance(layer, nn.Conv2d):
                # Without a nonlinearity, conv -> avgpool -> conv -> avgpool is a
                # composition of linear maps, i.e. a single linear layer.
                x = F.relu(x)
        # [b, 16, 5, 5] -> [b, 16 * 5 * 5]; flatten also copes with
        # non-contiguous tensors, unlike view().
        x = torch.flatten(x, 1)
        return self.fc_unit(x)
from torchvision import transforms
from torch.utils.data import DataLoader
def main(epochs=1000, batchsz=32):
    """Train Lenet5 on CIFAR-10, printing loss and test accuracy every epoch.

    CIFAR-10 is stored under ./cifar and downloaded on first use.

    Args:
        epochs: number of training epochs. Defaults to 1000.
        batchsz: mini-batch size for both the train and test loaders.
    """
    # Train and test data share the same preprocessing.
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    cifar_train = DataLoader(
        datasets.CIFAR10('cifar', True, transform=transform, download=True),
        batch_size=batchsz,
        shuffle=True,
    )
    cifar_test = DataLoader(
        datasets.CIFAR10('cifar', False, transform=transform, download=True),
        batch_size=batchsz,
        shuffle=False,
    )

    # Sanity-check the shapes of one batch. DataLoader iterators no longer
    # provide a .next() method; use the builtin next().
    x, label = next(iter(cifar_train))
    print("x", x.shape, " label", label.shape)

    # Fall back to CPU so the script still runs without a CUDA device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = Lenet5().to(device)
    # CrossEntropyLoss combines log-softmax and NLL loss, so the network
    # outputs raw logits.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    for epoch in range(epochs):
        net.train()
        for x, label in cifar_train:
            x, label = x.to(device), label.to(device)
            logits = net(x)  # [b, 10]; label: [b]
            loss = criterion(logits, label)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()        # compute gradients
            optimizer.step()       # update the weights
        # Loss of the last mini-batch of this epoch, as a Python float.
        print(epoch, loss.item())

        # Evaluate on the test set.
        net.eval()
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                x, label = x.to(device), label.to(device)
                pred = net(x).argmax(dim=1)  # [b]
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
            print("epoch ", epoch, " acc ", total_correct / total_num)
if __name__ == '__main__':
    # Run training/evaluation only when executed as a script, not on import.
    main()
    # Optional shape smoke test that needs no dataset download (uncomment to use):
    # net = Lenet5()
    # tmp = torch.randn(2,3,32,32)
    # out = net(tmp)
    # # [2,10]
    # print(out.shape)
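# cifar10实战图像分类 (CIFAR-10 in practice: image classification)
# 最新推荐文章于 2024-09-23 14:07:18 发布 (blog-page residue: "latest recommended article published 2024-09-23 14:07:18")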