- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊 | 接辅导、项目定制
- 🚀 文章来源:K同学的学习圈子
前言
主要介绍DenseNet模型,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接,通过特征图在channel上的连接来实现特征重用
一、模型介绍
相比ResNet, DenseNet提出了一个更为激进的密集连接机制: 即互相连接所有的层,具体来说就是每个层都会接受其前面所有层作为其额外的输入。
下图为ResNet网络的短路连接机制
下图为DenseNet网络的密集连接机制
二、网络结构
1.DenseNet网络结构
DenseNet网络中使用DenseBlock + Transition的结构,其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition层用于连接两个相邻的DenseBlock,并且通过pooling使特征图大小降低。
2. DenseBlock + Transition结构
3. DenseBlock 非线性结构
在DenseBlock中, 各个层的特征图大小一致, 可以在channel维度上连接,DenseBlock基本结构是BN + ReLU + 3×3 Conv的结构,如下图所示
由于后面层的输入会非常大, DenseBlock内部可以采用bottleneck层来减少计算量, 主要是在原有的结构中增加1×1的Conv, 即BN + ReLU + 1×1 Conv + BN + ReLU + 3×3 Conv, 称为DenseNet-B结构
三、代码实现
1.导入相关包
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import datasets, transforms
import torch.nn.functional as F
from torchsummary import summary
import os,PIL,pathlib,warnings
from collections import OrderedDict
2. DenseBlock 内部结构
class DenseLayer(nn.Module):
    """Single densely connected layer: 3x3 conv -> batch norm -> ReLU -> concat.

    The layer computes ``growth_rate`` new feature maps from its input and
    returns them appended to that input along the channel dimension, so the
    output carries ``in_channels + growth_rate`` channels.

    NOTE(review): the order implemented here is Conv -> BN -> ReLU, while the
    surrounding article describes the layer as BN + ReLU + Conv; confirm which
    ordering is intended.

    Args:
        in_channels: Number of channels of the incoming tensor.
        growth_rate: Number of new feature maps this layer contributes.
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        # A 3x3 kernel with padding=1 (default stride 1) keeps the spatial
        # size, which the channel-wise concatenation in forward() requires.
        self.conv = nn.Conv2d(
            in_channels,
            growth_rate,
            kernel_size=3,
            padding=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(growth_rate)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return ``x`` with the newly computed feature maps appended on dim 1."""
        new_features = self.relu(self.bn(self.conv(x)))
        return torch.cat((x, new_features), dim=1)
3.DenseBlock
class DenseBlock(nn.Module):
    """Stack of ``num_layers`` DenseLayer modules applied one after another.

    Each DenseLayer returns its input concatenated with ``growth_rate`` new
    channels, so the i-th layer (0-based) is built to accept
    ``in_channels + i * growth_rate`` channels and the block outputs
    ``in_channels + num_layers * growth_rate`` channels.

    Args:
        in_channels: Number of channels entering the block.
        growth_rate: Channels added by every layer.
        num_layers: How many DenseLayer modules the block contains.
    """

    def __init__(self, in_channels, growth_rate, num_layers):
        super().__init__()
        self.layers = nn.ModuleList()
        width = in_channels
        for _ in range(num_layers):
            self.layers.append(DenseLayer(width, growth_rate))
            # The next layer sees everything produced so far.
            width += growth_rate

    def forward(self, x):
        """Run the input through every layer in order and return the result."""
        features = x
        for dense_layer in self.layers:
            features = dense_layer(features)
        return features
4. Transition 层
class TransitionLayer(nn.Module):
    """Down-sampling stage: 1x1 conv -> batch norm -> ReLU -> 2x2 average pool.

    The 1x1 convolution maps ``in_channels`` to ``out_channels`` and the
    average pooling (kernel 2, stride 2) halves the spatial resolution.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the four stages in registration order and return the result."""
        result = x
        for stage in (self.conv, self.bn, self.relu, self.avgpool):
            result = stage(result)
        return result
5.DenseNet网络
class DenseNet(nn.Module):
    """DenseNet classifier assembled from DenseBlock and TransitionLayer stages.

    Layout of ``self.features``: a 3-channel-in, 64-channel-out 3x3 stem
    (conv, batch norm, ReLU), then one DenseBlock per entry of
    ``block_config``, with a TransitionLayer after every block except the
    last, and finally a batch norm and ReLU. The result is globally
    average-pooled to 1x1, flattened and fed to a linear classifier.

    Args:
        growth_rate: Channels added by each layer inside a DenseBlock.
        block_config: Number of layers in each DenseBlock, in order.
        num_classes: Size of the classifier output.
    """

    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=10):
        super().__init__()

        # Stem: stride-1 3x3 convolution, so the input resolution is kept.
        stem_channels = 64
        self.features = nn.Sequential(
            nn.Conv2d(3, stem_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(stem_channels),
            nn.ReLU(inplace=True),
        )

        # Track the running channel count while stacking the stages.
        channels = stem_channels
        for i, num_layers in enumerate(block_config):
            self.features.add_module(
                f'denseblock{i + 1}',
                DenseBlock(channels, growth_rate, num_layers),
            )
            channels += num_layers * growth_rate

            # No transition after the final dense block.
            if i == len(block_config) - 1:
                continue

            # Each transition halves the channel count (integer division).
            halved = channels // 2
            self.features.add_module(
                f'transition{i + 1}',
                TransitionLayer(channels, halved),
            )
            channels = halved

        # Closing normalisation/activation, then the classification head.
        self.features.add_module('bn', nn.BatchNorm2d(channels))
        self.features.add_module('relu', nn.ReLU(inplace=True))
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(channels, num_classes)

    def forward(self, x):
        """Return the classifier output (one row per sample) for a batch ``x``."""
        feature_maps = self.features(x)
        pooled = self.avgpool(feature_maps)
        # Collapse (N, C, 1, 1) to (N, C) before the linear layer.
        flattened = pooled.view(feature_maps.size(0), -1)
        return self.classifier(flattened)
6.网络模型
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network with its default configuration and move it to the device.
model = DenseNet()
model = model.to(device)

# Print a per-layer summary for a single 3x224x224 input.
summary(model, (3, 224, 224))
四、在CIFAR10上训练
1.加载CIFAR10
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Preprocessing pipeline applied to every image.
transform = transforms.Compose(
    [
        # Convert the PIL image to a tensor.
        transforms.ToTensor(),
        # Normalise each of the three channels with mean 0.5 and std 0.5.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 train and test splits, downloaded into ./data when missing.
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
test_dataset = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Data loaders: only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)
2.训练和测试函数
from tqdm import tqdm
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass over ``dataloader`` and report its metrics.

    Batches are moved to the module-level ``device``. This function does not
    switch the model into training mode itself; that is left to the caller.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches exposing
            ``.dataset`` and ``len()``.
        model: Network being optimised.
        loss_fn: Callable mapping ``(predictions, targets)`` to a scalar loss.
        optimizer: Optimiser holding the model's parameters.

    Returns:
        Tuple ``(accuracy, mean_loss)``: the number of correct predictions
        divided by the dataset size, and the summed batch losses divided by
        the number of batches.
    """
    num_samples = len(dataloader.dataset)
    num_batches = len(dataloader)

    train_loss = 0
    train_acc = 0

    progress = tqdm(dataloader)
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass.
        logits = model(inputs)
        loss = loss_fn(logits, targets)
        batch_loss = loss.item()

        # Bookkeeping for the epoch-level statistics.
        train_loss += batch_loss
        train_acc += (logits.argmax(1) == targets).type(torch.float).sum().item()

        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Show the loss of the batch that was just processed.
        progress.set_description(desc=f'loss={batch_loss:.4f}')

    train_acc /= num_samples
    train_loss /= num_batches
    return train_acc, train_loss
def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Batches are moved to the module-level ``device``. This function does not
    switch the model into evaluation mode itself; that is left to the caller.

    Args:
        dataloader: Iterable of ``(inputs, targets)`` batches exposing
            ``.dataset`` and ``len()``.
        model: Network being evaluated.
        loss_fn: Callable mapping ``(predictions, targets)`` to a scalar loss.

    Returns:
        Tuple ``(accuracy, mean_loss)``: the number of correct predictions
        divided by the dataset size, and the summed batch losses divided by
        the number of batches.
    """
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss = 0
    test_acc = 0
    with torch.no_grad():
        par = tqdm(dataloader)
        for x, y in par:
            x, y = x.to(device), y.to(device)
            pred = model(x)
            batch_loss = loss_fn(pred, y).item()
            test_loss += batch_loss
            test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
            # Report the loss of the current batch, as train() does; the
            # running sum would grow with every batch and is not a loss value.
            par.set_description(desc=f'loss={batch_loss:.4f}')
    test_acc /= size
    test_loss /= num_batches
    return test_acc, test_loss
3.正式训练
# Objective and optimiser.
loss_fn = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

epochs = 50

# Per-epoch history, one entry appended per epoch.
train_loss, train_acc = [], []
test_loss, test_acc = [], []

# NOTE(review): best_acc is initialised here but never updated in the visible
# code; confirm whether best-model tracking was intended.
best_acc = 0

# Template for the one-line report printed after every epoch.
tmp = "Epoch:{},lr:{:.6f},train_loss:{:.4f},train_acc:{:.4f},test_loss:{:.4f},test_acc:{:.4f}"

for epoch in range(1, epochs + 1):
    # Optimisation pass in training mode.
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_loader, model, loss_fn, optimizer)

    # Evaluation pass in eval mode.
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_loader, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    # Read the learning rate currently held by the first parameter group.
    lr = optimizer.param_groups[0]['lr']
    print(
        tmp.format(
            epoch,
            lr,
            epoch_train_loss,
            epoch_train_acc,
            epoch_test_loss,
            epoch_test_acc,
        )
    )
五、结果可视化
import matplotlib.pyplot as plt
import warnings
# Suppress warning messages.
warnings.filterwarnings("ignore")

plt.rcParams.update(
    {
        'font.sans-serif': ['SimHei'],  # so that Chinese labels render correctly
        'axes.unicode_minus': False,    # so that minus signs render correctly
        'figure.dpi': 600,              # figure resolution
    }
)

# One x position per recorded epoch.
epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

# Left panel: accuracy curves.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# Right panel: loss curves.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()