Swintransformer详细设计文档

1、文件说明

Model.py:构建模型
My_dataset.py:数据集处理
Predict.py:预测图片分类类别
Train.py:训练网络
Utils.py:

2、项目结构和函数设计

Model.py 的类

class DropPath(nn.Module)
	def forward(self, x)
class PatchEmbed(nn.Module)
	def forward(self, x)
class PatchMerging(nn.Module):
	def forward(self, x, H, W)
class Mlp(nn.Module):
	def forward(self, x):
class WindowAttention(nn.Module):
	def forward(self, x, mask: Optional[torch.Tensor] = None):
class SwinTransformerBlock(nn.Module):
	def forward(self, x, attn_mask):
class BasicLayer(nn.Module):
	def create_mask(self, x, H, W):
	def forward(self, x, H, W):
class SwinTransformer(nn.Module):
	def _init_weights(self, m):
	def forward(self, x)

Model.py 的函数

def drop_path_f(x, drop_prob: float = 0., training: bool = False)
def window_partition(x, window_size: int)
def window_reverse(windows, window_size: int, H: int, W: int)
def swin_tiny_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_small_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window7_224(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window12_384(num_classes: int = 1000, **kwargs):
def swin_base_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
def swin_base_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):
def swin_large_patch4_window7_224_in22k(num_classes: int = 21841, **kwargs):
def swin_large_patch4_window12_384_in22k(num_classes: int = 21841, **kwargs):

My_dataset.py只有类

class MyDataSet(Dataset):
---def __len__(self):
---def __getitem__(self, item):
   @staticmethod
---def collate_fn(batch):

Predict.py只有函数

def main():
if __name__ == '__main__':
    main()

Train.py只有函数

def main(args):
if __name__ == '__main__':
	。。。
	main(opt)

Utils.py只有函数

def read_split_data(root: str, val_rate: float = 0.2):
def plot_data_loader_image(data_loader):
def write_pickle(list_info: list, file_name: str):
def read_pickle(file_name: str) -> list:
def train_one_epoch(model, optimizer, data_loader, device, epoch):
@torch.no_grad()
def evaluate(model, data_loader, device, epoch):

Swin-Transformer 论文代码介绍

1 开发环境

 Python 3.6
 torch 1.7.1
 GPU

2 功能设计

实验数据集的说明:
数据来源
http://download.tensorflow.org/example_images/flower_photos.tgz
5类花的图片做分类:
3670 images were found in the dataset.
2939 images for training.
731 images for validation.

Daisy:菊花
Dandelion:蒲公英
Roses:玫瑰
Sunflowers:向日葵
Tulips:郁金香

3 、文件说明

Model.py:构建模型
My_dataset.py:数据集处理
Predict.py:预测图片分类类别
Train.py:训练网络
Utils.py:功能类函数
Model.py 的类
DropPath:设置各模块内的dropout率
PatchEmbed:对图片像素进行划分patch
PatchMerging:对图进行petch的拼接和线性映射
Mlp:SwinTransformerBlock后面一段的使用的
WindowAttention:window内部计算attention
SwinTransformerBlock:构建单个SwinTransformerBlock模型,该模型中含有W-MSA和SW-MSA两个模块
SwinTransformer:构建整个分类模型,这个类调用其他类,共同组成整个模型,从Patchpartion到LinearEmbedding(即类PatchEmbed),到四个SwinTransformerBlock,以及在SwinTransformerBlock中使用是否使用PatchMerging,经过四个阶段的SwinTransformerBlock之后输出展平的向量。
Model.py 的函数
window_partition:对特征图进行划分,划分成一个一个没有重叠的window
window_reverse:将window还原成特征图
定义各种模型,用于实例化模型
swin_tiny_patch4_window7_224
swin_small_patch4_window7_224
swin_base_patch4_window7_224
swin_base_patch4_window12_384
swin_base_patch4_window7_224_in22k
swin_base_patch4_window12_384_in22k
swin_large_patch4_window7_224_in22k
swin_large_patch4_window12_384_in22k
My_dataset.py只有类
MyDataSet(Dataset):构建获取数据集中元素和大小的方法
@staticmethod
collate_fn(batch):用于单独调用使用,将一个批次的图片转为向量并拼在一起
Predict.py只有函数
main(): 创建预测图片类别的函数,展示预测的图片以及被预测图片属于每个类别的概率
if name == ‘main’:
main()
开始预测
Train.py只有函数
main(args)
获取训练集和验证集,对图片进行处理,调整两个数据集中图片的大小,实例化模型,训练模型,保存模型。
自定义参数,解析参数,调用并执行main(args),训练分类模型
Utils.py只有函数
read_split_data:读取图片和图片的类别,划分训练集和验证集
train_one_epoch:
定义损失函数:torch.nn.CrossEntropyLoss()
进行一个epoch的训练,返回损失和精确率
Evaluate

4 流程

运行train.py训练模型,训练了个epoch,最高精确率可到96.6%
在这里插入图片描述

5 效果演示

运行predict.py对单独一张图片进行预测类别
在这里插入图片描述

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

饿了就干饭

你的鼓励将是我创作的最大动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值