# ===== 1. train.py =====
import torch
import torchvision.datasets
from torch.utils.data import DataLoader
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Linear,Flatten
import time
from torch.utils.tensorboard import SummaryWriter
# ---- Device selection ----
# Always build a torch.device so `device` has a single consistent type
# (previously it was a torch.device on one branch and the str "cpu" on the other).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("---{}---".format(device))

# ---- TensorBoard ----
# Event files are written under ./logs_song.
writer = SummaryWriter("logs_song")

# ---- Step 1: prepare the data ----
# CIFAR-10 train/test splits are stored in (and downloaded to, if missing) ./data_set;
# ToTensor converts each PIL image into a tensor.
train_data = torchvision.datasets.CIFAR10(
    root="data_set",
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
test_data = torchvision.datasets.CIFAR10(
    root="data_set",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
print("训练集长度为:{}, 测试集长度为:{}".format(len(train_data), len(test_data)))

# ---- Step 2: wrap the datasets in loaders ----
# Reshuffle the training set every epoch; without it each epoch feeds the
# mini-batches in the same fixed order, which generally hurts SGD training.
# The test set is only evaluated, so its order does not matter.
train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64)
# ---- Step 3: define the model ----
class Song(nn.Module):
    """Small CNN that maps 3x32x32 CIFAR-10 images to 10 class scores.

    Three (5x5 conv with padding=2, 2x2 max-pool) stages shrink the spatial
    size 32 -> 16 -> 8 -> 4, then two linear layers map the flattened
    64*4*4 features to 10 outputs.

    The layer order inside ``self.module`` fixes the state_dict keys
    (``module.0.weight`` ...), so it must not change or previously saved
    checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        conv_opts = dict(kernel_size=5, stride=1, padding=2)
        self.module = Sequential(
            Conv2d(3, 32, **conv_opts),    # (3, 32, 32)  -> (32, 32, 32)
            MaxPool2d(2),                  # -> (32, 16, 16)
            Conv2d(32, 32, **conv_opts),   # -> (32, 16, 16)
            MaxPool2d(2),                  # -> (32, 8, 8)
            Conv2d(32, 64, **conv_opts),   # -> (64, 8, 8)
            MaxPool2d(2),                  # -> (64, 4, 4)
            Flatten(),                     # -> 1024 features
            Linear(64 * 4 * 4, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Return unnormalized class scores (no softmax), shape (batch, 10)."""
        return self.module(x)
song = Song()
song.to(device)

# ---- Step 4: loss function ----
loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)

# ---- Step 5: optimizer ----
learning_rate = 1e-2  # i.e. 0.01
optimizer = torch.optim.SGD(song.parameters(), lr=learning_rate)

# ---- Step 6: bookkeeping for the training run ----
total_train_step = 0  # optimizer steps taken so far (x-axis for train_loss)
total_test_step = 0   # test-set evaluations run so far (x-axis for test metrics)
epoch = 10
log_interval = 100    # print/log training progress every this many steps
start_time = time.time()

# ---- Step 7: train ----
for i in range(epoch):
    print("第{}轮训练开始:".format(i + 1))
    song.train()
    for imgs, targets in train_dataloader:
        imgs = imgs.to(device)
        targets = targets.to(device)
        outputs = song(imgs)
        loss = loss_fn(outputs, targets)

        # Step 8: backprop and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % log_interval == 0:
            end_time = time.time()
            print(end_time - start_time)  # seconds elapsed since training started
            print("训练次数:{},训练Loss:{}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)
            # Log sample images only at the logging interval: writing the whole
            # 64-image batch on every step bloats the event file and slows training.
            writer.add_images("train_data", imgs, total_train_step)

    # Step 9: evaluate on the test set
    song.eval()
    total_test_loss = 0.0  # sum of per-batch mean losses (not an average)
    total_accuracy = 0     # number of correctly classified test images
    with torch.no_grad():
        for imgs, targets in test_dataloader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = song(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            # .item() keeps the running count a plain Python int instead of a
            # device tensor, so it prints and logs as a normal number.
            total_accuracy += (outputs.argmax(1) == targets).sum().item()
    test_accuracy = total_accuracy / len(test_data)
    print("整个测试集的loss为:{},准确率为:{}".format(total_test_loss, test_accuracy))
    total_test_step += 1
    # `imgs` is the last test batch here.
    writer.add_images("test_data", imgs, total_test_step)
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar("test_accuracy", test_accuracy, total_test_step)

    # Save only the parameters (state_dict) after each epoch. The alternative,
    # torch.save(song, "song1_{}.pth".format(i + 1)), pickles the whole model object.
    torch.save(song.state_dict(), "song2_{}.pth".format(i + 1))
    print("模型已保存!")

# Flush any pending TensorBoard events and close the event file.
writer.close()
# ===== 2. test.py =====
from PIL import Image
import torchvision.transforms
import torch
from torch import nn
from torch.nn import Sequential,Conv2d,MaxPool2d,Linear,Flatten
# Image to classify (absolute path on the author's machine; adjust as needed).
img_path = "/home/slam/deeplearning/CIFAR10/R.jpg"
# Open inside a context manager so the file handle is released. convert("RGB")
# returns a new in-memory 3-channel image (e.g. from RGBA or grayscale input),
# matching the model's 3 input channels.
with Image.open(img_path) as img:
    image = img.convert("RGB")
print(image)
print(image.size)  # PIL size is (width, height); there is no .shape until ToTensor
channels = image.getbands()
print(channels)  # band names, e.g. ('R', 'G', 'B')
# Resize to the 32x32 input size the model expects, then convert to a tensor.
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((32, 32)),
    torchvision.transforms.ToTensor(),
])
image = transform(image)
print(image.shape)
class Song(nn.Module):
    """Small CNN that maps 3x32x32 CIFAR-10 images to 10 class scores.

    Must mirror the Song class used for training exactly: the layer order
    inside ``self.module`` fixes the state_dict keys (``module.0.weight`` ...)
    that ``load_state_dict`` matches against.
    """

    def __init__(self):
        super().__init__()
        conv_opts = dict(kernel_size=5, stride=1, padding=2)
        self.module = Sequential(
            Conv2d(3, 32, **conv_opts),    # (3, 32, 32)  -> (32, 32, 32)
            MaxPool2d(2),                  # -> (32, 16, 16)
            Conv2d(32, 32, **conv_opts),   # -> (32, 16, 16)
            MaxPool2d(2),                  # -> (32, 8, 8)
            Conv2d(32, 64, **conv_opts),   # -> (64, 8, 8)
            MaxPool2d(2),                  # -> (64, 4, 4)
            Flatten(),                     # -> 1024 features
            Linear(64 * 4 * 4, 64),
            Linear(64, 10),
        )

    def forward(self, x):
        """Return unnormalized class scores (no softmax), shape (batch, 10)."""
        return self.module(x)
# Two checkpoint formats can be produced by training:
#   1. "song1_*.pth": the whole pickled model, loadable directly with
#      torch.load("song1_1.pth", map_location="cpu").
#   2. "song2_*.pth": only the parameters (a state_dict), so a Song instance
#      must be created first and the parameters loaded into it (done here).
model = Song()
# map_location="cpu": training moves the model to CUDA when available, so the
# saved tensors may be CUDA tensors; without remapping, loading them on a
# CPU-only machine fails.
model.load_state_dict(torch.load("song2_1.pth", map_location="cpu"))
print(model)

# Add a batch dimension: (3, 32, 32) -> (1, 3, 32, 32).
image = torch.reshape(image, (1, 3, 32, 32))
model.eval()  # switch to evaluation mode
with torch.no_grad():
    output = model(image)
print(output)
predicted = output.argmax(1)
print(predicted)
# CIFAR-10 label order as used by torchvision.datasets.CIFAR10.
cifar10_classes = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")
print(cifar10_classes[predicted.item()])