根据论坛其他博主的博客,千辛万苦终于在Windows上安装好了vimamba,详情参考:http://t.csdnimg.cn/wncjP和http://t.csdnimg.cn/w4mFx等几篇博客,当然,在安装过程中也遇到了其他问题,但一步步终归是解决了。
运行如下代码时显示如图就算成功安装了:
import torch
from mamba_ssm import Mamba
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
# This module uses roughly 3 * expand * d_model^2 parameters
d_model=dim, # Model dimension d_model
d_state=16, # SSM state expansion factor
d_conv=4, # Local convolution width
expand=2, # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y.shape)
运行vim:修改main.py文件如下几处地方(先下载timm和mlflow两个包):
修改batch_size等参数以适配自己的数据集大小
修改数据集路径为自己的数据集:
例如我的:数据集格式如下:
子目录下放图片,我是做图像分类,必须要按照train和val的数据分开放,不然会报错
在此之前,最好先对数据预处理一下,将图片格式变为224*224大小
from PIL import Image
import os
def resize_and_convert(input_path, output_path, target_size = (224, 224)):
# Open the image
image = Image.open(input_path)
# Resize the image
resized_image = image.resize(target_size)
# Convert RGBA to RGB if the image has an alpha channel
if resized_image.mode == 'RGBA':
resized_image = resized_image.convert('RGB')
# Save the resized image as JPEG
resized_image.save(output_path)
# 替换为您的输入和输出路径
input_dir = 'path/to/input_dir'
output_dir = 'path/to/output_dir'
# 创建输出目录
os.makedirs(output_dir, exist_ok = True)
# 遍历输入目录中的所有文件
for filename in os.listdir(input_dir):
if filename.endswith('.jpg'):
# 构建输入和输出文件路径
input_path = os.path.join(input_dir, filename)
output_filename = os.path.splitext(filename)[0] + '_resized.jpg'
output_path = os.path.join(output_dir, output_filename)
# 调整大小并保存图像
resize_and_convert(input_path, output_path)
print("Resizing and conversion complete.")
到这一步运行时会报错:
修改vim文件夹下engine.py
将
outputs = model(...)中
if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank删掉,即:
with amp_autocast():
outputs = model(samples)
# outputs = model(samples)
然后就可以运行了
分析错误原因,可能是作者在定义VisionMamba块的时候,用的是Visiontransformer框架,而if_random_cls_token_position
和 if_random_token_rank
这两个参数是为vim定制的,VisionTransformer
的 forward
方法并不接受这两个参数,因此在调用模型时,将这两个参数给删掉了(有大佬说一下这样会有什么坏处吗?)...