Swin Unet 代码运行教程(总结了多个训练和测试的问题)

发现网上尽然没有 swin unet 的运行教程呜呜,那我来出一份吧

环境准备

克隆 Swin Unet 项目地址(https://github.com/HuCaoFighting/Swin-Unet.git),按照项目介绍的 python=3.7 版本安装项目依赖

pip install -r requirements.txt

训练

获取训练的数据集,接下来以 Synapse 为例

通过仓库给的链接
在这里插入图片描述
得到 project_TransUNet,根据 ./datasets/README.md 文件的信息,将文件夹按照格式放到项目当中的位置,在其中我们也会得到Synapse处理好后的数据集(注意 TransUNet 无需安装依赖)

根据官网给的例子开始训练模型

python train.py --dataset Synapse --cfg configs/swin_tiny_patch4_window7_224_lite.yaml --root_path ./data/Synapse --list_dir ./lists/Synapse --max_epochs 150 --output_dir ./model_out --img_size 224 --base_lr 0.05 --batch_size 24

your DATA_DIR:./data/Synapse
your OUT_DIR:./model_out

遇到的问题

一、文件缺少

在lists/Synapse 目录中缺少 train.txt 和 val.txt 文件

这个需要自己创建,对应Synapse数据集中test_vol_h5和train_npz的文件名

● train.txt:包含 train_npz 文件夹中所有训练文件的名称或路径。
● val.txt:包含 test_vol_h5 文件夹中所有验证文件的名称或路径。

可以叫GPT生成脚本,或者用下面的脚本创建

import os

# 定义数据目录路径
train_dir = './data/Synapse/train_npz'
val_dir = './data/Synapse/test_vol_h5'

# 定义输出的列表文件路径
output_dir = './lists/Synapse'
os.makedirs(output_dir, exist_ok=True)  # 如果目录不存在则创建

# 生成 train.txt 文件
train_files = os.listdir(train_dir)
with open(os.path.join(output_dir, 'train.txt'), 'w') as train_f:
    for file_name in train_files:
        if file_name.endswith('.npz'):
            train_f.write(f"{
     file_name}\n")

Swin Transformer UNET是一种结合了卷积神经网络(CNN)Transformer架构的深度学习模型,它通常用于图像分割任务。UNET(U形网络)原先是为医学图像处理设计的,而Swin Transformer则是基于 Swin Transformer模块,该模块通过划分空间并引入局部注意力机制来提高计算效率。 在编写Swin Transformer UNET代码时,你会看到以下几个关键部分: 1. **基础结构**:首先导入必要的库,如PyTorchSwin Transformer模块。你需要定义SwinTransformerBlock作为基本构建块,并搭建SwinTransformerEncoderDecoder。 ```python import torch.nn as nn from einops.layers.torch import Rearrange class SwinTransformerBlock(nn.Module): # ... class SwinTransformerEncoder(nn.Module): def __init__(self, num_layers): super().__init__() # ... class SwinTransformerDecoder(nn.Module): def __init__(self, num_layers): super().__init__() # ... ``` 2. **连接编码器解码器**:将SwinTransformerEncoderU-Net式的上采样层、下采样层以及跳跃连接结合起来。 ```python class SwinUNET(nn.Module): def __init__(self, encoder, decoder, in_channels, out_channels): super().__init__() self.encoder = encoder self.decoder = decoder # 其他连接细节... def forward(self, x): # 编码阶段... encoded = self.encoder(x) # 解码阶段... decoded = self.decoder(encoded) return decoded ``` 3. **实例化模型**:创建具体的SwinTransformer EncoderDecoder,然后组合成完整的模型。 ```python encoder_config = ... # 定义SwinTransformerEncoder配置 decoder_config = ... # 定义SwinTransformerDecoder配置 swin_unet = SwinUNET(SwinTransformerEncoder(encoder_config), SwinTransformerDecoder(decoder_config), ...) # 初始化权重设置其他训练选项 swin_unet.apply(weights_init) # 初始化权重函数 ```
评论 19
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值