1、文件说明
Model.py:构建模型
My_dataset.py:数据集处理
Predict.py:预测图片分类类别
Train.py:训练网络
Utils.py:
2、项目结构和函数设计
Model.py 的类
class DropPath(nn.Module)
def forward(self, x)
class PatchEmbed(nn.Module)
def forward(self, x)
class PatchMerging(nn.Module):
def forward(self, x, H, W)
class Mlp(nn.Module):
def forward(self, x):
class WindowAttention(nn.Module):
def forward(self, x, mask: Optional[torch.Tensor] = None):
class SwinTransformerBlock(nn.Module):
def forward(self, x, attn_mask):
class BasicLayer(nn.Module):
def create_mask(self, x, H, W):
def forward(self, x, H, W):
class SwinTransformer(nn.Module):
def _init_weights(self, m):
def forward(self, x)
Model.py 的函数
def drop_path_f(x, drop_prob: float = 0., training: bool = False)
def window_partition(x, window_size: int)
def window_reverse(windows, window_size: int, H: int, W: int)
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
My_dataset.py只有类
class MyDataSet(Dataset):
---def __len__(self):
---def __getitem__(self, item):
@staticmethod
---def collate_fn(batch):
Predict.py只有函数
def main():
if __name__ == '__main__':
main()
Train.py只有函数
def main(args):
if __name__ == '__main__':
。。。
main(opt)
Utils.py只有函数
def read_split_data(root: str, val_rate: float = 0.2):
def plot_data_loader_image(data_loader):
def write_pickle(list_info: list, file_name: str):
def read_pickle(file_name: str) -> list:
def train_one_epoch(model, optimizer, data_loader, device, epoch):
@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
Swin-Transformer 论文代码介绍
1 开发环境
Python 3.6
torch 1.7.1
GPU
2 功能设计
实验数据集的说明:
数据来源
http://download.tensorflow.org/example_images/flower_photos.tgz
5类花的图片做分类:
3670 images were found in the dataset.
2939 images for training.
731 images for validation.
Daisy:菊花
Dandelion:蒲公英
Roses:玫瑰
Sunflowers:向日葵
Tulips:郁金香
3 、文件说明
Model.py:构建模型
My_dataset.py:数据集处理
Predict.py:预测图片分类类别
Train.py:训练网络
Utils.py:功能类函数
Model.py 的类
DropPath:设置各模块内的dropout率
PatchEmbed:对图片像素进行划分patch
PatchMerging:对图进行petch的拼接和线性映射
Mlp:SwinTransformerBlock后面一段的使用的
WindowAttention:window内部计算attention
SwinTransformerBlock:构建单个SwinTransformerBlock模型,该模型中含有W-MSA和SW-MSA两个模块
SwinTransformer:构建整个分类模型,这个类调用其他类,共同组成整个模型,从Patchpartion到LinearEmbedding(即类PatchEmbed),到四个SwinTransformerBlock,以及在SwinTransformerBlock中使用是否使用PatchMerging,经过四个阶段的SwinTransformerBlock之后输出展平的向量。
Model.py 的函数
window_partition:对特征图进行划分,划分成一个一个没有重叠的window
window_reverse:将window还原成特征图
定义各种模型,用于实例化模型
swin_tiny_patch4_window7_224
swin_small_patch4_window7_224
swin_base_patch4_window7_224
swin_base_patch4_window12_384
swin_base_patch4_window7_224_in22k
swin_base_patch4_window12_384_in22k
swin_large_patch4_window7_224_in22k
swin_large_patch4_window12_384_in22k
My_dataset.py只有类
MyDataSet(Dataset):构建获取数据集中元素和大小的方法
@staticmethod
collate_fn(batch):用于单独调用使用,将一个批次的图片转为向量并拼在一起
Predict.py只有函数
main(): 创建预测图片类别的函数,展示预测的图片以及被预测图片属于每个类别的概率
if name == ‘main’:
main()
开始预测
Train.py只有函数
main(args)
获取训练集和验证集,对图片进行处理,调整两个数据集中图片的大小,实例化模型,训练模型,保存模型。
自定义参数,解析参数,调用并执行main(args),训练分类模型
Utils.py只有函数
read_split_data:读取图片和图片的类别,划分训练集和验证集
train_one_epoch:
定义损失函数:torch.nn.CrossEntropyLoss()
进行一个epoch的训练,返回损失和精确率
Evaluate
4 流程
运行train.py训练模型,训练了个epoch,最高精确率可到96.6%
5 效果演示
运行predict.py对单独一张图片进行预测类别