利用mamba来训练猫狗数据集---------分类
目录
一、总的代码架构
分为model、dataset、train、inference和weight。
二、加载数据集
'''
Dataset preparation for the cat/dog classifier.
'''
from torch.utils.data import Dataset
import os
from PIL import Image
from torchvision import transforms

# Preprocessing pipeline shared with inference: resize to the 224x224
# input size, convert to a tensor, then scale each channel to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
class MyDataset(Dataset):
    """Cat/dog classification dataset.

    Expects a flat directory of images named like ``dog.123.jpg`` /
    ``cat.456.jpg`` (Kaggle cats-vs-dogs layout); the class label is
    parsed from the filename prefix (dog=0, anything else=1).
    """

    def __init__(self, train_path, test_path, train=True, transform=None):
        """
        Args:
            train_path: directory containing the training images.
            test_path: directory containing the test images.
            train: when True read from train_path, otherwise test_path.
            transform: optional torchvision transform applied per image.
        """
        self.train_data_path = train_path
        self.test_data_path = test_path
        self.transform = transform
        data_path = self.train_data_path if train else self.test_data_path
        # Store full paths so __getitem__ can open files directly.
        self.total_data_path = [os.path.join(data_path, file)
                                for file in os.listdir(data_path)]

    @staticmethod
    def _label_from_path(file_path):
        """Return 0 for 'dog.*' filenames, 1 otherwise (cat)."""
        # BUG FIX: use os.path.basename instead of splitting on '/' so
        # the label is parsed correctly with Windows path separators too.
        label_str = os.path.basename(file_path).split('.')[0]
        return 0 if label_str == 'dog' else 1

    def __getitem__(self, idx):
        """Return (image, label) for the idx-th file."""
        file_path = self.total_data_path[idx]
        img = Image.open(file_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self._label_from_path(file_path)

    def __len__(self):
        return len(self.total_data_path)
# def get_dataloder(train=True, transform=None):
# my_data = MyDataset(train, transform=transform)
# data_loder = DataLoader(my_data, batch_size=4, shuffle=True)
# return data_loder
# if __name__ == '__main__':
#
# for idx, (input, label) in enumerate(get_dataloder(train=True, transform=transform)):
# print(idx)
# print("input: ", input.shape)
# print("label: ", label)
# break
三、模型
from mamba_ssm import Mamba
import torch.nn as nn
class features(nn.Module):
    """One downsampling stage: Conv(3x3, stride 2) -> BatchNorm -> SiLU."""

    def __init__(self, in_c, out_c):
        super(features, self).__init__()
        # Keep the three ops inside one Sequential so the state_dict
        # keys remain fe.0 / fe.1 / fe.2.
        stage = [
            nn.Conv2d(in_c, out_c, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_c),
            nn.SiLU(),
        ]
        self.fe = nn.Sequential(*stage)

    def forward(self, x):
        """(B, in_c, H, W) -> (B, out_c, ceil(H/2), ceil(W/2))."""
        out = self.fe(x)
        return out
class cls_mamba(nn.Module):
    """Small CNN + Mamba classifier.

    A 1x1 patch embedding, `num_layers` stride-2 conv stages
    (224 -> 7 spatially with the default 5 layers), a Mamba block over
    the flattened spatial positions, and an MLP classifier head.
    """

    def __init__(self, channels=32, in_c=32, out_c=32, num_layers=5, num_class=2):
        super(cls_mamba, self).__init__()
        self.patch_embedding = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(32),
            nn.SiLU())
        self.layers = nn.ModuleList([features(in_c, out_c) for _ in range(num_layers)])
        # d_model must match the channel dim of the flattened feature map.
        self.mamba = Mamba(d_model=channels)
        self.classifier = nn.Sequential(
            nn.Linear(32 * 7 * 7, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_class)
        )

    def forward(self, x):
        x = self.patch_embedding(x)
        for layer in self.layers:
            x = layer(x)
        # (B, C, H, W) -> (B, H, W, C): channels become the model dim D.
        x = x.permute(0, 2, 3, 1).contiguous()
        B, H, W, C = x.shape  # e.g. H:7 W:7 C:32
        # BUG FIX: flatten per sample — view(B, H*W, C). The original
        # view(1, -1, C) concatenated the whole batch into one sequence,
        # letting tokens mix across different images.
        x_flat = x.view(B, H * W, C)
        x_flat = self.mamba(x_flat)
        x_ = x_flat.view(B, H, W, C)
        x_ = x_.permute(0, 3, 1, 2).contiguous()
        # BUG FIX: flatten x_ (the Mamba output), not x — the original
        # `x.view(...)` silently discarded the Mamba block's result.
        x_ = x_.view(x_.size(0), -1)
        return self.classifier(x_)
这里调用的是mamba模型。由于mamba的d_model与输入x的channel数有关,因此这里取值为上一层的输出通道数。
注:这里x_flat的形状是(B, H*W, C),在进入mamba之后会转换成(B, C, H*W),也就是说mamba论文里面的L对应这里的H*W,D对应的C,最后输出y再转换成(B, H*W, C)的shape。
CV:(B,C,H,W )通过view操作===>(B,H*W,C)通过rearrange操作===>(B,C,H*W)
相当于
y = selective_scan_fn(
x,
dt,
A,
B,
C,
self.D.float(),
z=z,
delta_bias=self.dt_proj.bias.float(),
delta_softplus=True,
return_last_state=ssm_state is not None,
)
这个代码在执行选择性扫描函数(selective_scan_fn)输入x的shape--->(B,D,L)。
x.shape的变换:进入selective_scan_fn之前,x通过rearrange从(B, L, D)变为(B, D, L);输出y再通过rearrange从(B, D, L)变回(B, L, D)。
综上:首先进入mamba的x的shape是(B,H*W,C)也就是(B,L,D),其次通过rearrange操作将x.shape成(B,D,L),因此进入selective_scan_fn的shape是(B,C,H*W)也就是(B,D,L),然后输出y是(B,D,L),最后再经过rearrange操作将输出y.shape成(B,L,D)。正如上面的伪代码一样。
四、训练
import argparse
import torch
import torch.nn as nn
from dataset import MyDataset, transform
from model.VGG import VGG16net
from model.my_mamba import cls_mamba
from model.pretrain_mamba import pre_cls_mamba
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def get_optimizer(model, opt):
    """Build the optimizer named by opt.optimizer over the classifier head.

    NOTE(review): only model.classifier.parameters() are optimized, so the
    backbone (convs / mamba) never receives gradient updates — confirm this
    fine-tuning setup is intended.

    Raises:
        ValueError: if opt.optimizer is not one of Adam / SGD / AdamW.
    """
    params = model.classifier.parameters()
    lr = opt.learning_rate
    builders = {
        'Adam': lambda: torch.optim.Adam(params, lr=lr),
        'SGD': lambda: torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=5e-4),
        'AdamW': lambda: torch.optim.AdamW(params, lr=lr, weight_decay=5e-4),
    }
    try:
        builder = builders[opt.optimizer]
    except KeyError:
        raise ValueError(f"Unsupported optimizer: {opt.optimizer}") from None
    return builder()
def train(opt):
    """Train cls_mamba on the cat/dog dataset and save the final weights.

    Args:
        opt: parsed CLI options (batch_size, learning_rate, num_epochs,
             train_path, test_path, optimizer).
    """
    model = cls_mamba().to(device)
    model.train()  # ensure BatchNorm uses batch statistics during training
    # Dataset / loader.
    # BUG FIX: pass opt.test_path for test_path — the original wired
    # opt.train_path into both slots.
    train_dataset = MyDataset(train_path=opt.train_path, test_path=opt.test_path,
                              train=True, transform=transform)
    train_loder = torch.utils.data.DataLoader(train_dataset,
                                              batch_size=opt.batch_size,
                                              shuffle=True)
    # Loss and optimizer.
    criterion = nn.CrossEntropyLoss()
    optimizer = get_optimizer(model, opt)
    for epoch in range(opt.num_epochs):
        total_step = len(train_loder)
        train_epoch_loss = 0
        for i, (images, labels) in enumerate(train_loder):
            optimizer.zero_grad()
            images = images.to(device)
            labels = labels.to(device)
            # Forward pass.
            output = model(images)
            loss = criterion(output, labels)
            # Backward pass.
            loss.backward()
            optimizer.step()
            train_epoch_loss += loss.item()
            if (i + 1) % 2 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.5f}'
                      .format(epoch + 1, opt.num_epochs, i + 1, total_step, loss.item()))
        # Report the mean epoch loss (it was accumulated but unused before).
        print('Epoch [{}/{}], Avg Loss: {:.5f}'
              .format(epoch + 1, opt.num_epochs, train_epoch_loss / max(total_step, 1)))
    # NOTE(review): hard-coded absolute save path — consider a CLI option.
    torch.save(obj=model.state_dict(), f='/home/sbliu/code/mycode/Code_template/weight/model_mamba.pth')
def parse_opt():
    """Parse training CLI options.

    BUG FIX: numeric options now carry explicit types. Without type=...,
    values supplied on the command line arrived as strings (e.g.
    ``--batch_size 16`` yielded batch_size == '16', breaking DataLoader).
    """
    parser = argparse.ArgumentParser()
    # Input Parameters
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--train_path', type=str, default='/home/sbliu/code/Kaggle_Dog_Cat/train')
    parser.add_argument('--test_path', type=str, default='/home/sbliu/code/Kaggle_Dog_Cat/test1')
    parser.add_argument('--optimizer', type=str, default='AdamW',
                        choices=['Adam', 'SGD', 'AdamW'], help='optimizer')
    return parser.parse_args()
if __name__ == "__main__":
opt = parse_opt()
train(opt)
五、推理
import torch
from model.VGG import VGG16net
from model.my_mamba import cls_mamba
from model.pretrain_mamba import pre_cls_mamba
from torchvision import transforms
from PIL import Image
import argparse
# Name -> index mapping; predict() inverts it via key order (dog=0, cat=1).
class_name = {'dog' : 0, 'cat' : 1}
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Must match the preprocessing used at training time.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
def predict(opt):
    """Load trained weights and print the predicted class for one image.

    Args:
        opt: parsed CLI options with .data (image path) and .weights
             (state_dict path).
    """
    model = cls_mamba().to(device)
    # BUG FIX: map_location lets GPU-saved weights load on a CPU-only host.
    model.load_state_dict(torch.load(opt.weights, map_location=device))
    # BUG FIX: switch to eval mode so BatchNorm uses running statistics;
    # the original ran BN in training mode on a single image.
    model.eval()
    with torch.no_grad():
        img = Image.open(opt.data).convert('RGB')
        img = transform(img).unsqueeze(0).to(device)
        output = model(img)
        _, pre = torch.max(output, 1)
        pre_class = pre[0].item()
        # class_name maps name -> index; invert via insertion order.
        print(list(class_name.keys())[pre_class])
def parse_opt():
    """Parse inference CLI options: --data (image) and --weights (checkpoint)."""
    arg_parser = argparse.ArgumentParser()
    # Input Parameters
    arg_parser.add_argument('--data', type=str,
                            default='/home/sbliu/code/Kaggle_Dog_Cat/test1/13.jpg')
    arg_parser.add_argument('--weights', type=str,
                            default='/home/sbliu/code/mycode/Code_template/weight/model_mamba.pth')
    return arg_parser.parse_args()
# Script entry point: parse CLI options, then run single-image inference.
if __name__ == '__main__':
    opt = parse_opt()
    predict(opt)
注:这里只是为了测试下mamba是否可以训练,因此不关注最终测试的准确率!!!
如果想要效果好点,可以结合VGG16进行训练和推理,模型如下。
from mamba_ssm import Mamba
import torch.nn as nn
from torchvision import models
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all of model's parameters when feature_extracting is True.

    BUG FIX: the original assigned ``param.required_grad`` — a typo that
    merely created an unused attribute. The real flag is ``requires_grad``,
    so no parameter was ever actually frozen.
    """
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False
class pre_cls_mamba(nn.Module):
    """Pretrained VGG16 feature extractor + Mamba + MLP classifier head."""

    def __init__(self, channels=512, num_class=2):
        super(pre_cls_mamba, self).__init__()
        # VGG16 convolutional backbone: (B, 3, 224, 224) -> (B, 512, 7, 7).
        model = models.vgg16(pretrained=True)
        self.features = model.features
        # d_model must equal the backbone's output channel count (512).
        self.mamba = Mamba(d_model=channels)
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, num_class)
        )

    def forward(self, x):
        x = self.features(x)
        # (B, C, H, W) -> (B, H, W, C): channels become the model dim D.
        x = x.permute(0, 2, 3, 1).contiguous()
        B, H, W, C = x.shape
        # BUG FIX: keep samples separate — view(B, H*W, C). The original
        # view(1, -1, C) merged the whole batch into a single sequence.
        x_flat = x.view(B, H * W, C)
        x_flat = self.mamba(x_flat)
        x_ = x_flat.view(B, H, W, C)
        x_ = x_.permute(0, 3, 1, 2).contiguous()
        # BUG FIX: flatten x_ (the Mamba output), not x — the original
        # `x.view(...)` discarded the Mamba block's result entirely.
        x_ = x_.view(x_.size(0), -1)
        return self.classifier(x_)