Swin Transformer 代码
先把代码跑起来吧(验证一个分类模型)
拿代码,配环境
按照Readme辛辛苦苦配了半天的环境,结果一直报错Numpy有问题。抱着绝望的心情随便试了一下之前项目的环境,一秒跑通,真的难以评价。跑通的环境配置为:
python 3.11.5
pytorch-cuda 12.1
NVIDIA GeForce RTX 4070(CUDA 11.8 + cudnn 8.9.7.29)
验证数据准备
我最开始是按照Readme去下载ImageNet数据集,但是这个数据集太大太劝退了,可以先按照Swin Transformer实战图像分类,去下载了猫狗大战数据集来试试手。数据准备过程也按照这篇博文来就可以。
之后,我去下载了ImageNet的验证集 val。验证集比较小只有6G。但是具体是哪个数据集不重要,只要先把程序跑起来这一步就成功了!这里先使用了standard folder dataset方式,不需要额外的标签文件,把图片按照不同的分类存放就可以了。
$ tree data
imagenet
├── train
│ ├── class1
│ │ ├── img1.jpeg
│ │ ├── img2.jpeg
│ │ └── …
│ ├── class2
│ │ ├── img3.jpeg
│ │ └── …
│ └── …
└── val
├── class1
│ ├── img4.jpeg
│ ├── img5.jpeg
│ └── …
├── class2
│ ├── img6.jpeg
│ └── …
└── …
不过这里有一个坑,我从官网直接download的 val 数据对应的 label 是这样的:(ILSVRC2012_validation_ground_truth.txt)
490
361
171
822
297
482
13
704
599
164
我按照这样的映射,将验证图片分类存放在对应的文件夹下,结果跑出来的正确率接近于零。果然ImageNet数据集简介里提到了,验证集的label需要进行重映射。我从这里ImageNet-1K直接下载了已经重新映射好的label文件,名为image1k_val_list.txt。
ILSVRC2012_val_00000001.JPEG 65
ILSVRC2012_val_00000002.JPEG 970
ILSVRC2012_val_00000003.JPEG 230
ILSVRC2012_val_00000004.JPEG 809
ILSVRC2012_val_00000005.JPEG 516
ILSVRC2012_val_00000006.JPEG 57
ILSVRC2012_val_00000007.JPEG 334
ILSVRC2012_val_00000008.JPEG 415
ILSVRC2012_val_00000009.JPEG 674
ILSVRC2012_val_00000010.JPEG 332
ILSVRC2012_val_00000011.JPEG 109
ILSVRC2012_val_00000012.JPEG 286
可以看到现在的label跟swin transformer的get_started.md里的演示就一致了。然后写一个程序来按照标签分类存放图片:
import os
import shutil
# 文件夹路径和标签文件路径
image_folder = "ILSVRC2012_img_val" # 替换为存放图像的文件夹路径
txt_file = "val_map.txt" # 替换为标签文件路径
output_folder = "output" # 输出文件夹路径,按标签分类的子文件夹会放在这里
# 获取图像文件列表并按文件名排序
image_files = sorted(os.listdir(image_folder))
# 读取标签文件
with open(txt_file, "r") as f:
labels = f.readlines()
# 检查图像文件数量和标签数量是否一致
if len(image_files) != len(labels):
print(f"Error: Number of images ({
len(image_files)}) and labels ({
len(labels)}) do not match.")
exit()
# 确保输出文件夹存在
os.makedirs(output_folder, exist_ok=True)
# 按标签将图像移动到对应的子文件夹
for image_file, label in zip(image_files, labels):
label = label.strip() # 去除标签中的换行符
label_parts = label.split() # 分割标签行
if len(label_parts) != 2:
print(f"Error: Invalid label format in line: {
label}")
continue
image_name, label_number = label_parts
label_number = int(label_number) # 将标签编号转换为整数
# 创建标签子文件夹(如果不存在),并在标签编号前添加前导零
label_folder = os.path.join(output_folder, str(label_number).zfill(3)) # 这里用 zfill(3) 补充前导零
os.makedirs(label_folder, exist_ok=True)
# 获取图像的路径
image_path = os.path.join(image_folder, image_name)
# 将图像文件移动到标签子文件夹
if os.path.exists(image_path):
shutil.move(image_path, os.path.join(label_folder, image_name))
print(f"Moved {
image_name} to {
label_folder}")
else:
print(f"Image {
image_name} not found in {
image_folder}. Skipping.")
print("All images have been categorized into their respective label folders.")
这里要将子文件夹的名字补零,确保是升序排列的(即"000"、“001”、…“999”)。不然ImageFolder自动映射的时候会出错。
可以在build.py里打印映射结果:
def build_dataset(is_train, config):
................
root = os.path.join(config.DATA.DATA_PATH, prefix)
print(f"Data path: {
root}")
dataset = datasets.ImageFolder(root, transform=transform)
print("Class to Index mapping:", dataset.class_to_idx)
................
验证集的映射结果:
Data path: data/imageNet\val
Class to Index mapping: {‘000’: 0, ‘001’: 1, ‘002’: 2, ‘003’: 3, ‘004’: 4, ‘005’: 5, ‘006’: 6, ‘007’: 7, ‘008’: 8, ‘009’: 9, ‘010’: 10, ‘011’: 11,…‘999’: 999}
映射是正确的。
下载预训练模型
略。
修改部分参数配置
学习率相关
把main.py的这几行预定义的关于学习率的代码注释掉,不然我们设置的新的学习率传递不进去:
# linear scale the learning rate according to total batch size, may not be optimal
linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * world_size / 512.0
linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * world_size / 512.0
linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * world_size / 512.0
# gradient accumulation also need to scale the learning rate
if config.TRAIN.ACCUMULATION_STEPS > 1:
linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
# Update config with new learning rates
config.defrost()
config.TRAIN.BASE_LR = linear_scaled_lr
config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
config.TRAIN.MIN_LR = linear_scaled_min_lr
config.freeze()
分布式训练相关
- 不需要分布式训练的话,所有涉及分布式训练的代码都改掉。main.py这里改成:
if __name__ == '__main__':
args, config = parse_option()

最低0.47元/天 解锁文章
311

被折叠的 条评论
为什么被折叠?



