解决windows版Vimamba运行main.py跑自己的数据集报错问题

根据论坛其他博主的博客,千辛万苦终于在Windows上安装好了vimamba,详情参考:http://t.csdnimg.cn/wncjPhttp://t.csdnimg.cn/w4mFx等几篇博客,当然,在安装过程中也遇到了其他问题,但一步步终归是解决了。

运行如下代码时显示如图就算成功安装了:

import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape
print(y.shape)

运行vim:修改main.py文件如下几处地方(先下载timm和mlflow两个包):

修改batch_size等参数以适配自己的数据集大小

修改数据集路径为自己的数据集

例如我的:数据集格式如下:

 子目录下放图片,我是做图像分类,必须要按照train和val的数据分开放,不然会报错

在此之前,最好先对数据预处理一下,将图片格式变为224*224大小

from PIL import Image
import os


def resize_and_convert(input_path, output_path, target_size = (224, 224)):
    # Open the image
    image = Image.open(input_path)

    # Resize the image
    resized_image = image.resize(target_size)

    # Convert RGBA to RGB if the image has an alpha channel
    if resized_image.mode == 'RGBA':
        resized_image = resized_image.convert('RGB')

    # Save the resized image as JPEG
    resized_image.save(output_path)


# 替换为您的输入和输出路径
input_dir = 'path/to/input_dir'
output_dir = 'path/to/output_dir'

# 创建输出目录
os.makedirs(output_dir, exist_ok = True)

# 遍历输入目录中的所有文件
for filename in os.listdir(input_dir):
    if filename.endswith('.jpg'):
        # 构建输入和输出文件路径
        input_path = os.path.join(input_dir, filename)
        output_filename = os.path.splitext(filename)[0] + '_resized.jpg'
        output_path = os.path.join(output_dir, output_filename)

        # 调整大小并保存图像
        resize_and_convert(input_path, output_path)

print("Resizing and conversion complete.")

到这一步运行时会报错:

修改vim文件夹下engine.py

outputs = model(...)中
if_random_cls_token_position=args.if_random_cls_token_position, if_random_token_rank=args.if_random_token_rank删掉,即:
        with amp_autocast():
            outputs = model(samples)
            # outputs = model(samples)
然后就可以运行了

分析错误原因,可能是作者在定义VisionMamba块的时候,用的是Visiontransformer框架,而if_random_cls_token_positionif_random_token_rank 这两个参数是为vim定制的,VisionTransformerforward 方法并不接受这两个参数,因此在调用模型时,将这两个参数给删掉了(有大佬说一下这样会有什么坏处吗?)...

  • 4
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值