Training Mamba on the Cats-vs-Dogs Dataset (Classification)

Using Mamba to train a classifier on the cats-vs-dogs dataset.


Contents

1. Overall code structure

2. Loading the dataset

3. Model

4. Training

5. Inference


1. Overall code structure

The project is split into five parts: model, dataset, train, inference, and weight. A possible layout is sketched below.
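
The module names below follow the imports used later (dataset, model.VGG, model.my_mamba, model.pretrain_mamba); the exact file names of the training and inference scripts are assumptions:

Code_template/
├── dataset.py              # MyDataset and the shared transform
├── model/
│   ├── VGG.py              # VGG16net
│   ├── my_mamba.py         # cls_mamba
│   └── pretrain_mamba.py   # pre_cls_mamba (VGG16 features + Mamba)
├── train.py                # training script (assumed file name)
├── inference.py            # inference script (assumed file name)
└── weight/                 # saved checkpoints, e.g. model_mamba.pth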

2. Loading the dataset

'''
Dataset preparation
'''
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

class MyDataset(Dataset):
    def __init__(self, train_path, test_path, train=True, transform=None):
        self.train_data_path = train_path
        self.test_data_path = test_path

        self.transform = transform

        data_path = self.train_data_path if train else self.test_data_path

        files = os.listdir(data_path)

        self.total_data_path = [os.path.join(data_path, file) for file in files]

    def __getitem__(self, idx):
        file_path = self.total_data_path[idx]

        # label and img
        img = Image.open(file_path).convert('RGB')
        if self.transform:
            img = self.transform(img)

        # Kaggle naming convention: training files are named dog.<id>.jpg / cat.<id>.jpg,
        # so the class can be read from the file name.
        label_str = os.path.basename(file_path).split('.')[0]
        label = 0 if label_str == 'dog' else 1

        return img, label


    def __len__(self):
        return len(self.total_data_path)



def get_dataloader(train_path, test_path, train=True, transform=None):
    my_data = MyDataset(train_path, test_path, train=train, transform=transform)
    data_loader = DataLoader(my_data, batch_size=4, shuffle=True)
    return data_loader


if __name__ == '__main__':
    # Quick sanity check: load one batch and print its shapes.
    loader = get_dataloader(train_path='/home/sbliu/code/Kaggle_Dog_Cat/train',
                            test_path='/home/sbliu/code/Kaggle_Dog_Cat/test1',
                            train=True, transform=transform)
    for idx, (input, label) in enumerate(loader):
        print(idx)
        print("input: ", input.shape)
        print("label: ", label)
        break

3. Model

from mamba_ssm import Mamba
import torch.nn as nn

class features(nn.Module):
    # One downsampling stage: a 3x3 conv with stride 2 halves the spatial resolution.
    def __init__(self, in_c, out_c):
        super(features, self).__init__()
        self.fe = nn.Sequential(nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
                                nn.BatchNorm2d(out_c),
                                nn.SiLU())
    def forward(self, x):
        return self.fe(x)

class cls_mamba(nn.Module):
    def __init__(self, channels=32, in_c=32, out_c=32, num_layers=5, num_class=2):
        super(cls_mamba, self).__init__()

        self.patch_embedding = nn.Sequential(nn.Conv2d(3, 32, kernel_size=1, stride=1, padding=0),
                                      nn.BatchNorm2d(32),
                                      nn.SiLU())

        self.layers = nn.ModuleList([features(in_c, out_c) for _ in range(num_layers)])

        self.mamba = Mamba(d_model=channels)

        self.classifier = nn.Sequential(
                                        nn.Linear(32 * 7 * 7, 1024),
                                        nn.ReLU(),
                                        nn.Linear(1024, 1024),
                                        nn.ReLU(),
                                        nn.Linear(1024, num_class)
                                        )
    def forward(self, x):
        x = self.patch_embedding(x)
        for layer in self.layers:
            x = layer(x)

        x = x.permute(0, 2, 3, 1).contiguous()

        B, H, W, C = x.shape  # B:32 H:7 W:7 C:32

        x_flat = x.view(B, -1, C)   # (B, H*W, C) = (B, L, D); with batch 32 this is (32, 49, 32)
        x_flat = self.mamba(x_flat)
        x_ = x_flat.view(B, H, W, C)
        x_ = x_.permute(0, 3, 1, 2).contiguous()
        x_ = x_.view(x_.size(0), -1)  # flatten the Mamba output for the classifier

        return self.classifier(x_)

The Mamba block is invoked here. Because Mamba's d_model must match the channel dimension of the input x, it is set to the output channel count of the preceding layer (32 in this case).
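
A minimal sketch of this constraint (the tensor sizes are illustrative; note that mamba_ssm needs a CUDA device):

import torch
from mamba_ssm import Mamba

m = Mamba(d_model=32).to('cuda')            # d_model must match the channel dimension of the input
x = torch.randn(4, 49, 32, device='cuda')   # (B, L, D) = (batch, sequence length, channels)
y = m(x)                                    # output has the same shape as the input: (4, 49, 32)
print(y.shape)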


Note: x_flat has shape (B, H*W, C). After entering Mamba it is converted to (B, C, H*W); in other words, L in the Mamba paper corresponds to H*W here and D corresponds to C. The output y is finally converted back to shape (B, H*W, C).

For a CV feature map: (B, C, H, W) ==(permute + view)==> (B, H*W, C) ==(rearrange inside Mamba)==> (B, C, H*W)

This corresponds to the following call inside the Mamba block:

            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )

Here the input x passed to the selective scan function (selective_scan_fn) has shape (B, D, L).


(The snippets showing the x.shape and y.shape transformations in the mamba_ssm source are omitted here.)

In summary: the x entering Mamba has shape (B, H*W, C), i.e. (B, L, D); a rearrange operation then turns x into (B, D, L), so the tensor entering selective_scan_fn has shape (B, C, H*W), i.e. (B, D, L); the output y is (B, D, L); and a final rearrange turns y back into (B, L, D), exactly as in the call shown above. A small sketch of this shape flow follows.
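
A minimal, self-contained sketch of that shape flow (it only demonstrates the reshape/rearrange steps, not the scan itself; the identity assignment stands in for selective_scan_fn):

import torch
from einops import rearrange

B, C, H, W = 32, 32, 7, 7
feat = torch.randn(B, C, H, W)                      # CNN feature map (B, C, H, W)

x = feat.permute(0, 2, 3, 1).reshape(B, H * W, C)   # (B, L, D): what is fed into Mamba
x_inner = rearrange(x, 'b l d -> b d l')            # (B, D, L): layout expected by selective_scan_fn
y_inner = x_inner                                   # stand-in for the scan output, also (B, D, L)
y = rearrange(y_inner, 'b d l -> b l d')            # back to (B, L, D)
print(x.shape, x_inner.shape, y.shape)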

4. Training

import argparse
import torch
import torch.nn as nn

from dataset import MyDataset, transform
from model.VGG import VGG16net
from model.my_mamba import cls_mamba
from model.pretrain_mamba import pre_cls_mamba

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_optimizer(model, opt):
    # Optimise all parameters (conv stem, Mamba block and classifier), not just the classifier.
    params = model.parameters()
    if opt.optimizer == 'Adam':
        optimizer = torch.optim.Adam(params, lr=opt.learning_rate)
    elif opt.optimizer == 'SGD':
        optimizer = torch.optim.SGD(params, lr=opt.learning_rate, momentum=0.9, weight_decay=5e-4)
    elif opt.optimizer == 'AdamW':
        optimizer = torch.optim.AdamW(params, lr=opt.learning_rate, weight_decay=5e-4)
    else:
        raise ValueError(f"Unsupported optimizer: {opt.optimizer}")
    return optimizer

def train(opt):
    model = cls_mamba().to(device)

    # Dataset and dataloader
    train_dataset = MyDataset(train_path=opt.train_path, test_path=opt.test_path, train=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True)

    # Loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(model, opt)

    for epoch in range(opt.num_epochs):

        total_step = len(train_loader)

        train_epoch_loss = 0

        for i, (images, labels) in enumerate(train_loader):

            optimizer.zero_grad()

            images = images.to(device)
            labels = labels.to(device)

            # Forward pass
            output = model(images)
            loss = criterion(output, labels)

            # Backward pass and parameter update
            loss.backward()
            optimizer.step()

            train_epoch_loss += loss.item()

            if (i + 1) % 2 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.5f}'
                      .format(epoch + 1, opt.num_epochs, i + 1, total_step, loss.item()))

        # torch.save(obj=model.state_dict(), f='/home/sbliu/code/mycode/Code_template/weight/model_mamba.pth')

    torch.save(obj=model.state_dict(), f='/home/sbliu/code/mycode/Code_template/weight/model_mamba.pth')

def parse_opt():
    parser = argparse.ArgumentParser()

    # Input Parameters
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--train_path', default='/home/sbliu/code/Kaggle_Dog_Cat/train')
    parser.add_argument('--test_path', default='/home/sbliu/code/Kaggle_Dog_Cat/test1')
    parser.add_argument('--optimizer', default= 'AdamW',  choices=['Adam', 'SGD', 'AdamW'], help='optimizer')

    return parser.parse_args()

if __name__ == "__main__":
    opt = parse_opt()
    train(opt)
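
Assuming the training code above is saved as train.py (the file name is an assumption), it can be launched from the command line, for example:

python train.py --optimizer AdamW --batch_size 32 --learning_rate 0.001 --num_epochs 200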

5. Inference

import torch
from model.VGG import VGG16net
from model.my_mamba import cls_mamba
from model.pretrain_mamba import pre_cls_mamba
from torchvision import transforms
from PIL import Image
import argparse

class_name = {'dog' : 0, 'cat' : 1}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

def predict(opt):
    model = cls_mamba().to(device)
    model.load_state_dict(torch.load(opt.weights, map_location=device))
    model.eval()  # BatchNorm layers must run in eval mode for single-image inference

    with torch.no_grad():
        img = Image.open(opt.data).convert('RGB')
        img = transform(img).unsqueeze(0).to(device)

        output = model(img)
        _, pre = torch.max(output, 1)

        pre_class = pre[0].item()

    print(list(class_name.keys())[pre_class])

def parse_opt():
    parser = argparse.ArgumentParser()
    # Input Parameters
    parser.add_argument('--data', type=str, default='/home/sbliu/code/Kaggle_Dog_Cat/test1/13.jpg')
    parser.add_argument('--weights', type=str, default='/home/sbliu/code/mycode/Code_template/weight/model_mamba.pth')
    opt = parser.parse_args()
    return opt

if __name__ == '__main__':

    opt = parse_opt()
    predict(opt)
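
Assuming the inference code is saved as inference.py (again an assumed file name), a single image can be classified with:

python inference.py --data /home/sbliu/code/Kaggle_Dog_Cat/test1/13.jpg --weights /home/sbliu/code/mycode/Code_template/weight/model_mamba.pth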

Note: the point here is only to verify that Mamba can be trained at all, so the final test accuracy is not a concern!

For better results, Mamba can be combined with a pretrained VGG16 for training and inference; the model is shown below.

from mamba_ssm import Mamba
import torch.nn as nn
from torchvision import models

def set_parameter_requires_grad(model, feature_extracting):
    # Freeze all parameters when the model is used as a fixed feature extractor.
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False

class pre_cls_mamba(nn.Module):
    def __init__(self, channels=512, num_class=2):
        super(pre_cls_mamba, self).__init__()

        model = models.vgg16(pretrained=True)
        self.features = model.features
        self.mamba = Mamba(d_model=channels)
        self.classifier = nn.Sequential(
                                        nn.Linear(512 * 7 * 7, 1024),
                                        nn.ReLU(),
                                        nn.Linear(1024, 1024),
                                        nn.ReLU(),
                                        nn.Linear(1024, num_class)
                                        )
    def forward(self, x):
        x = self.features(x)

        x = x.permute(0, 2, 3, 1).contiguous()

        B, H, W, C = x.shape

        x_flat = x.view(B, -1, C)   # (B, H*W, C) = (B, L, D); here C = 512, H = W = 7
        x_flat = self.mamba(x_flat)
        x_ = x_flat.view(B, H, W, C)
        x_ = x_.permute(0, 3, 1, 2).contiguous()
        x_ = x_.view(x_.size(0), -1)  # flatten the Mamba output for the classifier

        return self.classifier(x_)
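
set_parameter_requires_grad is defined above but never called in this snippet; a minimal sketch of how it could be used to freeze the pretrained VGG16 backbone so that only the Mamba block and classifier are trained (the call site and optimizer choice are assumptions):

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = pre_cls_mamba().to(device)
set_parameter_requires_grad(model.features, feature_extracting=True)   # freeze the VGG16 features
# Only the parameters that still require gradients (Mamba block + classifier) go to the optimizer.
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)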
