一、前期工作
本文将采用CNN实现多云、下雨、晴、日出四种天气状态的识别。
%%time
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasets
import os, PIL, pathlib,random
1. 设置GPU
%%time
# Use the GPU when one is available, otherwise fall back to CPU.
# (The section is titled "set up GPU" but the original hard-coded CPU.)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
2. 导入数据
%%time
data_dir = './weather_photos/'
data_dir = pathlib.Path(data_dir)

# One sub-directory per weather class (cloudy / rain / shine / sunrise).
data_paths = list(data_dir.glob('*'))
data_paths

# Use Path.name instead of str(path).split('/')[1]: splitting on '/' breaks
# on Windows, where the separator is '\\'.
classNames = [path.name for path in data_paths]
classNames
%%time
total_datadir = './weather_photos/'
# Preprocessing pipeline applied to every image (train and test alike).
train_transforms = transforms.Compose([
transforms.Resize([224,224]), # resize every image to 224x224
transforms.ToTensor(), # HWC uint8 -> CHW float tensor scaled to [0, 1]
transforms.Normalize( # standardize each channel
mean = [0.485, 0.456, 0.406],
std = [0.229, 0.224, 0.225] # presumably the standard ImageNet statistics — TODO confirm
)
])
# ImageFolder: one class per sub-directory; labels assigned alphabetically.
total_data = datasets.ImageFolder(total_datadir,transform=train_transforms)
total_data
3. 划分数据集
%%time
# Hold out 20% of the images for testing; the rest is used for training.
split_ratio = 0.8
n_total = len(total_data)
train_size = int(split_ratio * n_total)
test_size = n_total - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    total_data, [train_size, test_size]
)
train_dataset, test_dataset
train_size, test_size
数据集一共分为 cloudy、rain、shine、sunrise 四类,分别存放于 weather_photos 文件夹中以各自名字命名的子文件夹中。
二、数据预处理
1. 加载数据
本文使用 torchvision 的 ImageFolder 将磁盘中的数据加载为 Dataset(即上文的 total_data)。标签将按子文件夹名的字母顺序自动分配,可以通过 classNames 查看数据集的标签。
2. 可视化数据
# Visualize the first 20 training images.
# NOTE(review): the original cell was TensorFlow leftover code referencing
# undefined names (`train_ds`, `class_names`); rewritten for the PyTorch
# objects defined above.
import matplotlib.pyplot as plt

plt.figure(figsize=(20, 10))
for i in range(20):
    image, label = train_dataset[i]          # CHW float tensor, class index
    ax = plt.subplot(5, 10, i + 1)
    # Move channels last and clip the normalized values into [0, 1] so that
    # imshow renders them without warnings.
    plt.imshow(image.permute(1, 2, 0).numpy().clip(0, 1))
    plt.title(classNames[label])
    plt.axis("off")
3. 再次检查数据
一个图像批次是形状为 (32, 3, 224, 224) 的张量:即 32 张 224x224 的 RGB 图片(PyTorch 的通道维在前)。对应的标签批次是形状 (32,) 的张量,与这 32 张图片一一对应。
4. 配置数据集
%%time
batch_size = 32

# Wrap both datasets in mini-batch loaders.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=1,
)
%%time
# Inspect a single batch to confirm the tensor shapes.
for images, labels in test_dataloader:
    print('Shape of X [N,C,H,W]', images.shape)
    print('Shape of y:', labels.shape, labels.dtype)
    break
三、构建CNN网络
特征提取网络
分类网络
torch.nn.Conv2d()
torch.nn.MaxPool2d()
torch.nn.Linear()
torch.nn.Flatten()
注意bn层的作用
%%time
import torch.nn.functional as F

class Network_bn(nn.Module):
    """Simple CNN for 4-way weather classification.

    Two conv(+BatchNorm) stages, each followed by 2x2 max-pooling, then a
    single linear classifier. Expects 3x224x224 input; the spatial trace
    (224 -> 220 -> 216 -> 108 -> 104 -> 100 -> 50) yields a 24x50x50
    feature map, hence fc1's 24*50*50 input features.
    """

    def __init__(self):
        super(Network_bn, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(12)
        self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(12)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(24)
        self.conv4 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
        self.bn4 = nn.BatchNorm2d(24)
        # One output unit per discovered class directory.
        self.fc1 = nn.Linear(24 * 50 * 50, len(classNames))

    def forward(self, x):
        """Return raw class logits for a batch of 3x224x224 images."""
        x = F.relu(self.bn1(self.conv1(x)))   # 224 -> 220
        x = F.relu(self.bn2(self.conv2(x)))   # 220 -> 216
        x = self.pool(x)                      # 216 -> 108
        x = F.relu(self.bn3(self.conv3(x)))   # 108 -> 104
        x = F.relu(self.bn4(self.conv4(x)))   # 104 -> 100
        x = self.pool(x)                      # 100 -> 50
        # Flatten per sample. Keeping the batch dimension explicit (instead of
        # view(-1, 24*50*50)) raises a clear size-mismatch error on unexpected
        # input sizes rather than silently reshaping across the batch axis.
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x
四、编译
在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:
- 损失函数(loss):用于衡量模型预测值与真实标签之间的误差,训练的目标是使其最小化。
- 优化器(optimizer):决定模型如何根据其看到的数据和自身的损失函数进行更新。
- 指标(metrics):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。
device = torch.device('cpu')
%%time
# Instantiate the network and move its parameters to the chosen device.
model = Network_bn().to(device)
model
output:
Network_bn(
(conv1): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
(bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))
(bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv3): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
(bn3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv4): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))
(bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(fc1): Linear(in_features=60000, out_features=5, bias=True)
)
五、训练模型
1. 设置超参数
# Cross-entropy combines log-softmax and NLL, so the model outputs raw logits.
loss_fn = nn.CrossEntropyLoss()
learn_rate = 1e-4
# Plain SGD (no momentum); passed to train() once per epoch below.
opt = torch.optim.SGD(model.parameters(),lr=learn_rate)
2. 编写训练函数
def train(dataloader, model, loss_fn, optimizer):
    """Run one training epoch.

    Returns (accuracy over all samples, mean loss per batch).
    The caller is expected to have called model.train() beforehand.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    # Derive the device from the model itself instead of relying on a hidden
    # module-level `device` global, so the function is self-contained.
    device = next(model.parameters()).device
    train_loss, train_acc = 0, 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        # Standard step: clear stale gradients, backprop, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    train_acc /= size
    train_loss /= num_batches
    return train_acc, train_loss
3. 编写测试函数
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, test_acc = 0,0
with torch.no_grad():
for imgs, target in dataloader:
imgs, target = imgs.to(device), target.to(device)
target_pred = model(imgs)
loss = loss_fn(target_pred, target)
test_loss += loss.item()
test_acc += (target_pred.argmax(1)==target).type(torch.float).sum().item()
test_acc /= size
test_loss /= num_batches
return test_acc, test_loss
4. 正式训练
%%time
epochs = 20
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    model.train()    # enable BatchNorm running-stat updates for training
    epoch_train_acc, epoch_train_loss = train(train_dataloader, model, loss_fn, opt)

    model.eval()     # freeze BatchNorm statistics for evaluation
    epoch_test_acc, epoch_test_loss = test(test_dataloader, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Format specs fixed: the original '{:1f}'/'{:3f}' set the field WIDTH,
    # not the precision; '.1f'/'.3f' is what was intended.
    template = ('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f},Test_acc:{:.1f}%,Test_loss:{:.3f}')
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Done!')
六、结果可视化
%%time
import matplotlib.pyplot as plt
import warnings

warnings.filterwarnings('ignore')

# Matplotlib defaults: a CJK-capable font so the Chinese labels render,
# plus readable minus signs and a higher DPI.
plt.rcParams['font.sans-serif'] = ['PingFang HK']
#plt.rcParams['font.sans-serif'] = ['Hiragino Sans GB']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['figure.dpi'] = 100

epochs_range = range(epochs)
plt.figure(figsize=(12, 3))

# Left panel: accuracy curves.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='训练集正确率')
plt.plot(epochs_range, test_acc, label='测试集正确率')
plt.legend(loc='lower right')
plt.title('训练和测试正确率')

# Right panel: loss curves.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='训练集损失率')
plt.plot(epochs_range, test_loss, label='测试集损失率')
plt.legend(loc='upper right')
plt.title('训练和测试损失率')
出了个问题, 实际测试集正确率太低,待查找原因。
七、保存训练结果
%%time
# Persist only the learned parameters (state_dict), not the whole module.
PATH = './p3.pth'
torch.save(model.state_dict(),PATH)
# Reload immediately to verify the checkpoint round-trips on this device.
model.load_state_dict(torch.load(PATH, map_location = device))