TransUnet复现过程(详细过程)

TransUNet
论文链接:https://arxiv.org/abs/2102.04306
GitHub链接:https://github.com/Beckschen/TransUNet

本文完整代码:https://download.csdn.net/download/qq_45806961/89222079(其中少了2Ddata文件夹,就是预处理第一阶段,生成的二维图像保存位置,可以自行创建即可)

一、处理数据集

提前准备好数据集,如下图所示:

 将数据转化为2D图像(根据自己的数据集进行更改下面的读取等操作,比较简单,并且添加了备注)

'''
coding:utf-8
@Software:PyCharm
@Time:2024/4/17 21:35
@Author:鹿长野
'''
# 如果更换数据集,就需要更改这里,根据数据集的相关名称进行更改
# 对nii格式的文件进行切片处理

import numpy as np
import nibabel as nib
import h5py
import os
from PIL import Image

data_path = "./predata"

def process_file(file_path):
    # 获取图像和标签
    img = nib.load(file_path)           # ./predata/MBAS_001_gt.nii.gz   --->   ./predata/MBAS_001_label.nii.gz
    label_path = file_path.replace('_gt.nii.gz', '_label.nii.gz')    # 找到标签的文件路径
    # print(label_path)
    label = nib.load(label_path)
    # print(label)

    img_data = img.get_fdata()                                # 获取像素数据
    label_data = label.get_fdata()
    # print(label_data)

    img_clipped = np.clip(img_data, -125, 275)
    img_normalised = (img_clipped-(-125))/(275-(-125))         # 图像做归一化,标签并没有

    # img_clipped = img_data
    # img_normalised = img_clipped

    for i in range(img_clipped.shape[2]):
        formatted_i ="{:04d}".format(i+1)                      # 将i+1四位数,不够的补零,我自己的理解,例如i是9,就是0010,所以加上1还是10,只是在前面补了两个零
        img_slice = img_normalised[:,:,i]
        label_slice = label_data[:,:,i]

        image = Image.fromarray(img_slice.astype(np.uint8))    # 将图像转化为PIL图像对象
        image = image.convert('L')                             # 转化为灰度图像

        label = Image.fromarray(label_slice.astype(np.uint8))
        label = label.convert('L')

        case_name = os.path.splitext(os.path.split(file_path)[1])[0]    # (从file_path开始入手)获取路径中的文件名,不包含扩展名  ----> MBAS_001_gt
        # print(case_name)       # 打印 MBAS_001_gt.nii
        case_name = case_name.replace("_gt.nii","")
        # print(case_name)       # 打印 MBAS_001
        case_number = "{:03d}".format(int(formatted_i))
        image.save(f"./2Ddata/{case_name}_{case_number}.png")
        label.save(f"./2Ddata/{case_name}_{case_number}_label.png")

for root,dirs,files in os.walk(data_path):
    for file in files:
        if file.endswith("gt.nii.gz"):              # 只读取图像,不读取标签文件
            file_path = os.path.join(root,file)
            process_file(file_path)

处理后的数据如下所示:(image+label放在了一起)

然后将image+label对应的两张图像叠加在一起,为了贴合作者的数据集而做的处理,代码如下:

'''
coding:utf-8
@Software:PyCharm
@Time:2024/4/18 10:12
@Author:鹿长野
'''
# 把img图像和对应的mask图像合并为一个npz文件
# 如果更换数据集,就需要更改这里,根据数据集的相关名称进行更改

import glob
import cv2
import numpy as np
from tqdm import tqdm

def npz():
    path = './2Ddata/*.png'                                     # 图像路径
    path2 = './data/train_npz/'                                      # 项目中存放训练所用的npz文件路径

    for i,img_path in tqdm(enumerate(glob.glob(path))):
        # print("i:",i)
        # print("img_path:",img_path)                             # 打印 ./2Ddata/MBAS_001_001.png
        if img_path.endswith('_label.png'):                     # 没有读取标签图片,提前给跳过
            continue

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        label_path = img_path.replace('.png', '_label.png')      # 读取标签图片
        # print(label_path)
        label = cv2.imread(label_path,flags=0)                   # flags 默认读取灰度图像

        filename = img_path.split('/')[-1].split('.')[0]         # img_path = ./2Ddata\MBAS_001_001.png,所以需要再次处理
        # print(filename)                                          # 2Ddata\MBAS_001_001,所以不得不根据实际问题进行结合

        filename = filename.replace('2Ddata\\', '')              # 两个\\才代表一个\
        # print(filename)
        np.savez(path2+filename+'.npz',image=image,label=label)

npz()

处理后的数据如下图所示:

上述代码将数据分成了训练集和测试集,如下图,根据地址自行更改:

同时也会生成TXT文件,代码中注明了文件保存位置,自行修改即可,

自此,数据集处理好啦

二、开始训练

下载预训练模型,存放地址如下,可以自行去网站下载

运行train.py

三、开始测试

需要修改的部分和训练过程中一样,训练后的结果会保存在test_save_dir中

测试结果

如有任何问题留言!

  • 15
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
要开始复现mmskeleton的过程,请按照以下步骤进行操作: 1. 确保你的系统已经安装了Python和pip。如果没有,请先安装它们。 2. 克隆mmskeleton的代码仓库。你可以在GitHub上找到mmskeleton的代码仓库,并使用以下命令克隆代码: ``` git clone https://github.com/open-mmlab/mmskeleton.git ``` 3. 进入克隆的代码仓库目录。 ``` cd mmskeleton ``` 4. 创建并激活一个虚拟环境(可选,但强烈推荐)。 5. 安装依赖项。可以使用以下命令安装必要的依赖项: ``` pip install -r requirements.txt ``` 6. 安装mmskeleton。可以使用以下命令安装mmskeleton: ``` python setup.py install ``` 7. 下载预训练模型。mmskeleton需要一些预训练模型来进行姿态估计等任务。你可以在mmskeleton的文档或代码仓库中找到相应的模型下载链接。下载并解压这些模型,并将它们放置在适当的目录中。 8. 准备数据。根据你的任务和数据集,准备好相应的数据。确保数据的路径与配置文件中指定的路径相匹配。 9. 配置文件设置。在`./configs/pose_estimation/`目录下,你可以找到一些已经配置好的示例配置文件。根据你的需求修改其中的一份配置文件,确保路径和参数设置正确。 10. 运行示例。使用以下命令来运行mmskeleton的示例: ``` python mmskl.py --config ./configs/pose_estimation/pose_demo.yaml ``` 这将使用指定的配置文件运行mmskeleton的姿态估计示例。根据你的配置文件和数据集,你可能需要进行相应的修改。 这些是复现mmskeleton的基本步骤。根据你的具体需求和任务,可能还需要进行其他设置和修改。请参考mmskeleton的文档和代码仓库以获取更详细的信息和指导。
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值