UniRepLKNet:大核卷积的领先性能

摘要:本文介绍UniRepLKNet论文及测试UniRepLKNet的性能,用到农业病害识别数据集做图像分类测试。

一、UniRepLKNet网络结构

LarK 块包括膨胀重参数块(Dilated Reparam Block)、SE 块、FFN 和批归一化层。SmaK 块和 LarK 块之间的唯一区别是:前者使用深度方向的 3×3 卷积层,代替后者的膨胀重参数块。

这个模型的创新性在于重新思考了传统模型中使用多个小卷积核的设计,并提出了一种基于大卷积核的架构。该模型通过将一个3×3的卷积层添加到小卷积核的网络中,同时实现了三个效果:1)扩大感受野;2)增加空间模式的抽象层次(例如从角度和纹理到物体形状);3)通过增加深度、引入更多可学习参数和非线性来提高模型的表示能力。与此相反,作者认为在大卷积核的架构中,这三个效果应该解耦,因为模型应该利用大卷积核的强大能力——能够广泛地观察而不需要增加深度。由于增加卷积核大小比堆叠更多层在扩大有效感受野方面更有效,可以用少量的大卷积核层来构建足够大的有效感受野,以便将计算资源节省下来用于其他更有效增加空间模式抽象层次或增加深度的结构。例如,当目标是从低级局部空间模式中提取更高级的局部空间模式时,3×3的卷积层可能比大卷积核层更适合。原因是后者需要更多计算量,并且可能导致模式不再局限于较小的局部区域,这在特定场景下可能是不可取的。

具体来说,该模型提出了四个大卷积核卷积网络的架构准则:1)使用高效的结构(如SE块[24])增加深度;2)使用提出的扩张重参数块对大卷积核卷积层进行重新参数化,以提高性能而不增加推理成本;3)根据下游任务决定卷积核大小,并通常仅在中高层使用大卷积核;4)在扩展模型深度时,使用3×3的卷积层而不是更多的大卷积核层。根据这些准则构建的卷积网络将前面提到的三个效果分别实现:使用适量的大卷积核来保证大的有效感受野,使用小卷积核更高效地提取更复杂的空间模式,并使用多个轻量级块进一步增加深度以增强表示能力。

膨胀重参数块(Dilated Reparam Block):

使用若干与大核卷积并行的膨胀卷积层,并将它们的输出相加;推理时可重参数化合并为单个大核卷积,从而不增加推理成本。

 二、测试

我使用一个病害识别的数据集来测试UniRepLKNet

相关细节:


数据集:Plantvillage数据集,"Apple___Black_rot", "Apple___Cedar_apple_rust", "Apple___healthy"

模型:采用官方的 UniRepLKNet-T 版本(代码见官方开源仓库)。

训练代码:

import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.utils import accuracy, AverageMeter, ModelEma
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from unireplknet import unireplknet_t
from torch.autograd import Variable
from torchvision import datasets
torch.backends.cudnn.benchmark = False
import warnings
warnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES']="0,1"


# 定义训练过程
def train(model, device, train_loader, optimizer, epoch, model_ema):
    """Train for one epoch and return (average loss, average top-1 accuracy).

    Relies on module-level globals: mixup_fn, use_amp, criterion_train,
    scaler (when use_amp), and CLIP_GRAD.
    """
    model.train()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    total_num = len(train_loader.dataset)
    print(total_num, len(train_loader))
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        # Fixed: mixup returns (samples, targets); the model must be fed the
        # mixed samples (the original fed raw `data`, so the SoftTarget loss
        # was computed against outputs that never saw the mixed inputs).
        samples, targets = mixup_fn(data, target)
        optimizer.zero_grad()
        if use_amp:
            # Fixed: the forward pass must run inside autocast for mixed
            # precision to actually apply (it was previously outside).
            with torch.cuda.amp.autocast():
                output = model(samples)
                # nan_to_num guards against a diverging fp16 loss step.
                loss = torch.nan_to_num(criterion_train(output, targets))
            scaler.scale(loss).backward()
            # Fixed: unscale before clipping so CLIP_GRAD applies to the
            # true gradient norm, not the loss-scaled one.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(samples)
            loss = criterion_train(output, targets)
            loss.backward()
            optimizer.step()

        if model_ema is not None:
            model_ema.update(model)
        torch.cuda.synchronize()
        lr = optimizer.state_dict()['param_groups'][0]['lr']
        # Fixed: topk=(1,) only — the dataset has 3 classes, so top-5 is
        # meaningless. Accuracy is measured against the hard labels even on
        # mixed batches (standard approximation during mixup training).
        acc1, = accuracy(output, target, topk=(1,))
        # Fixed: loss_meter was updated twice per batch, double-counting.
        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        if (batch_idx + 1) % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(
                epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),
                       100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))
    ave_loss = loss_meter.avg
    acc = acc1_meter.avg
    print('epoch:{}\tloss:{:.2f}\tacc:{:.2f}'.format(epoch, ave_loss, acc))
    return ave_loss, acc


# 验证过程
@torch.no_grad()
@torch.no_grad()
def val(model, device, test_loader):
    """Run validation, checkpoint the model, and return
    (true labels, predicted labels, average loss, average top-1 accuracy).

    Relies on module-level globals: Best_ACC, criterion_val, file_dir,
    epoch, use_ema. Saves `best.pth` when accuracy improves, plus a
    per-epoch state-dict checkpoint.
    """
    global Best_ACC
    model.eval()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    total_num = len(test_loader.dataset)
    print(total_num, len(test_loader))
    val_list = []
    pred_list = []

    for data, target in test_loader:
        # Collect ground-truth labels before moving the batch to device.
        val_list.extend(t.item() for t in target)
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        output = model(data)
        loss = criterion_val(output, target)
        _, pred = torch.max(output.data, 1)
        pred_list.extend(p.item() for p in pred)
        # Fixed: topk=(1,) only — there are just 3 classes, so top-5 is
        # meaningless.
        acc1, = accuracy(output, target, topk=(1,))
        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
    acc = acc1_meter.avg
    print('\nVal set: Average loss: {:.4f}\tAcc1:{:.3f}%\n'.format(
        loss_meter.avg, acc))

    # Fixed: unwrap DataParallel once instead of duplicating the entire save
    # logic in two nearly-identical branches (the old non-DP branch also
    # passed the archaic pickle_protocol=0, inconsistent with the DP branch).
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    if acc > Best_ACC:
        torch.save(net, file_dir + '/' + 'best.pth')
        Best_ACC = acc
    state = {
        'epoch': epoch,
        'state_dict': net.state_dict(),
        'Best_ACC': Best_ACC,
    }
    if use_ema:
        # NOTE(review): this stores the SAME weights under the EMA key; the
        # caller passes model_ema.ema here when use_ema is on, so this may be
        # intentional — confirm against the resume logic.
        state['state_dict_ema'] = net.state_dict()
    torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')
    return val_list, pred_list, loss_meter.avg, acc

def seed_everything(seed=0):
    """Seed Python hashing and torch RNGs for reproducible runs.

    Args:
        seed: integer seed applied to PYTHONHASHSEED and torch CPU/CUDA RNGs.
    """
    # Fixed typo: was 'PYHTONHASHSEED', so the hash seed was never set.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Force deterministic cuDNN kernels (slower but reproducible).
    torch.backends.cudnn.deterministic = True


if __name__ == '__main__':
    # ---------- output directory & hyper-parameters ----------
    file_dir = 'checkpoints/uniRepLKNet/'
    # Fixed: the original if/else called makedirs in BOTH branches; a single
    # exist_ok=True call is equivalent and race-free.
    os.makedirs(file_dir, exist_ok=True)

    model_lr = 1e-4
    BATCH_SIZE = 8
    EPOCHS = 200
    DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    use_amp = True            # mixed-precision forward/backward
    use_dp = True             # wrap in DataParallel when >1 GPU is visible
    classes = 3
    resume = None             # checkpoint path to resume from, or None
    CLIP_GRAD = 5.0
    Best_ACC = 0
    use_ema = True
    model_ema_decay = 0.9998
    start_epoch = 1
    seed = 0
    seed_everything(seed)

    # ---------- data pipeline ----------
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    transform_test = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
        prob=0.1, switch_prob=0.5, mode='batch',
        label_smoothing=0.1, num_classes=classes)

    dataset_train = datasets.ImageFolder('data/train', transform=transform)
    dataset_test = datasets.ImageFolder("data/val", transform=transform_test)
    # Persist the class -> index mapping for inference-time decoding.
    with open('class.txt', 'w') as file:
        file.write(str(dataset_train.class_to_idx))
    with open('class.json', 'w', encoding='utf-8') as file:
        file.write(json.dumps(dataset_train.class_to_idx))
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

    # ---------- model / loss / optimizer ----------
    criterion_train = SoftTargetCrossEntropy()   # matches Mixup's soft targets
    criterion_val = torch.nn.CrossEntropyLoss()
    model_ft = unireplknet_t()
    # Replace the ImageNet head with a 3-way classifier.
    num_fr = model_ft.head.in_features
    model_ft.head = nn.Linear(num_fr, classes)
    print(model_ft)
    if resume:
        checkpoint = torch.load(resume)
        print(checkpoint['state_dict'].keys())
        model_ft.load_state_dict(checkpoint['state_dict'], strict=False)
        Best_ACC = checkpoint['Best_ACC']
        start_epoch = checkpoint['epoch'] + 1
    model_ft.to(DEVICE)
    optimizer = optim.AdamW(model_ft.parameters(), lr=model_lr)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=30, eta_min=1e-9)
    if use_amp:
        scaler = torch.cuda.amp.GradScaler()
    if torch.cuda.device_count() > 1 and use_dp:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model_ft = torch.nn.DataParallel(model_ft)
    if use_ema:
        model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device=DEVICE,
            resume=resume)
    else:
        model_ema = None

    # ---------- training / validation loop ----------
    is_set_lr = False
    log_dir = {}
    train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch_list = [], [], [], [], []
    epoch_info = []
    if resume and os.path.isfile(file_dir + "result.json"):
        # Restore the curves so resumed runs keep a continuous history.
        with open(file_dir + 'result.json', 'r', encoding='utf-8') as file:
            logs = json.load(file)
            train_acc_list = logs['train_acc']
            train_loss_list = logs['train_loss']
            val_acc_list = logs['val_acc']
            val_loss_list = logs['val_loss']
            epoch_list = logs['epoch_list']
    for epoch in range(start_epoch, EPOCHS + 1):
        epoch_list.append(epoch)
        log_dir['epoch_list'] = epoch_list
        train_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch, model_ema)
        train_loss_list.append(train_loss)
        train_acc_list.append(train_acc)
        log_dir['train_acc'] = train_acc_list
        log_dir['train_loss'] = train_loss_list
        if use_ema:
            # When EMA is on, validate (and checkpoint) the EMA weights.
            val_list, pred_list, val_loss, val_acc = val(model_ema.ema, DEVICE, test_loader)
        else:
            val_list, pred_list, val_loss, val_acc = val(model_ft, DEVICE, test_loader)
        val_loss_list.append(val_loss)
        val_acc_list.append(val_acc)
        log_dir['val_acc'] = val_acc_list
        log_dir['val_loss'] = val_loss_list
        log_dir['best_acc'] = Best_ACC
        with open(file_dir + '/result.json', 'w', encoding='utf-8') as file:
            file.write(json.dumps(log_dir))
        # Fixed: classification_report expects a sequence of display names;
        # the original passed the class_to_idx dict, relying on dict
        # key-iteration order. An explicit list makes the intent clear.
        print(classification_report(val_list, pred_list,
                                    target_names=list(dataset_train.class_to_idx)))
        epoch_info.append({
            'epoch': epoch,
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc
        })
        # Rewrite the full epoch log each time so a crash loses nothing.
        with open('epoch_info.txt', 'w') as f:
            for epoch_data in epoch_info:
                f.write(f"Epoch: {epoch_data['epoch']}\n")
                f.write(f"Train Loss: {epoch_data['train_loss']}\n")
                f.write(f"Train Acc: {epoch_data['train_acc']}\n")
                f.write(f"Val Loss: {epoch_data['val_loss']}\n")
                f.write(f"Val Acc: {epoch_data['val_acc']}\n")
                f.write("\n")
        # NOTE: with EPOCHS=200 the cosine branch always runs; the else
        # branch only pins lr to 1e-6 for very long (>600 epoch) runs.
        if epoch < 600:
            cosine_schedule.step()
        else:
            if not is_set_lr:
                for param_group in optimizer.param_groups:
                    param_group["lr"] = 1e-6
                    is_set_lr = True

        # ---------- plots (rewritten every epoch) ----------
        fig = plt.figure(1)
        plt.plot(epoch_list, train_loss_list, 'r-', label=u'Train Loss')
        plt.plot(epoch_list, val_loss_list, 'b-', label=u'Val Loss')
        plt.legend(["Train Loss", "Val Loss"], loc="upper right")
        plt.xlabel(u'epoch')
        plt.ylabel(u'loss')
        plt.title('Model Loss ')
        plt.savefig(file_dir + "/loss.png")
        plt.close(1)
        fig2 = plt.figure(2)
        plt.plot(epoch_list, train_acc_list, 'g-', label=u'Train Acc')
        plt.plot(epoch_list, val_acc_list, 'y-', label=u'Val Acc')
        plt.legend(["Train Acc", "Val Acc"], loc="lower right")
        plt.title("Model Acc")
        plt.ylabel("acc")
        plt.xlabel("epoch")
        plt.savefig(file_dir + "/acc.png")
        plt.close(2)

测试结果:代码在100轮时acc为95.88%,200轮时达到97.89%,loss图和acc图如下

混淆矩阵:

from torchvision import transforms
import numpy as np
from prettytable import PrettyTable
import matplotlib.pyplot as plt
from data.dataset import SeedlingData
import torch
from unireplknet import unireplknet_t
import torch.nn as nn
class ConfusionMatrix(object):
    """Accumulates a confusion matrix and reports per-class metrics.

    Orientation: matrix[row, col] counts samples PREDICTED as `row`
    whose ground-truth label is `col`.
    """

    def __init__(self, num_classes: int, labels: list):
        self.num_classes = num_classes
        self.labels = labels
        self.matrix = np.zeros((num_classes, num_classes))

    def update(self, preds, labels):
        """Add one batch of (prediction, ground-truth) pairs to the counts."""
        for pred, truth in zip(preds, labels):
            self.matrix[pred, truth] += 1

    def summary(self):
        """Print overall accuracy and a per-class metrics table."""
        total = np.sum(self.matrix)
        correct = sum(self.matrix[k, k] for k in range(self.num_classes))
        acc = correct / total
        print("the model accuracy is ", acc)

        # One-vs-rest precision / recall / specificity / F1 per class.
        table = PrettyTable()
        table.field_names = ["", "accuracy", "Precision", "Recall", "Specificity", "F1 Score"]
        for k in range(self.num_classes):
            tp = self.matrix[k, k]
            fp = np.sum(self.matrix[k, :]) - tp   # predicted k, actually other
            fn = np.sum(self.matrix[:, k]) - tp   # actually k, predicted other
            tn = total - tp - fp - fn
            denom = tp + tn + fp + fn
            accuracy_k = round((tp + tn) / denom, 3) if denom != 0 else 0.
            precision = round(tp / (tp + fp), 3) if tp + fp != 0 else 0.
            recall = round(tp / (tp + fn), 3) if tp + fn != 0 else 0.
            specificity = round(tn / (tn + fp), 3) if tn + fp != 0 else 0.
            f1 = round((2 * tp) / (2 * tp + fn + fp), 3) if (2 * tp + fn + fp) != 0 else 0.
            table.add_row([self.labels[k], accuracy_k, precision, recall, specificity, f1])
        print(table)

    def plot(self):
        """Render the accumulated matrix as an annotated heatmap."""
        matrix = self.matrix
        print(matrix)
        plt.imshow(matrix, cmap=plt.cm.Blues)

        # Axis ticks carry the class names.
        plt.xticks(range(self.num_classes), self.labels, fontsize=8)
        plt.yticks(range(self.num_classes), self.labels, fontsize=10)
        plt.colorbar()
        # Columns are ground truth, rows are predictions (see class docstring).
        plt.xlabel('True Labels', fontsize=15)
        plt.ylabel('Predicted Labels', fontsize=15)
        plt.title('Confusion matrix', fontsize=15)

        # Annotate each cell; text color flips on dark cells for contrast.
        # Screen position (x, y) shows matrix[y, x], i.e. matrix[pred, true].
        thresh = matrix.max() / 2
        for col in range(self.num_classes):
            for row in range(self.num_classes):
                count = int(matrix[row, col])
                plt.text(col, row, count,
                         verticalalignment='center',
                         horizontalalignment='center',
                         color="white" if count > thresh else "black", fontsize=20)
        plt.tight_layout()
        plt.show()

# ---- Evaluation script: build a confusion matrix over the validation set ----
num_classes = 3   # number of classes
batch_size = 1
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
dataset_test = SeedlingData("data/val", transforms=transform_test, train=False)
class_names = ['Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy']
confusion_matrix = ConfusionMatrix(num_classes, class_names)
model_ft = unireplknet_t()
num_fr = model_ft.head.in_features
model_ft.head = nn.Linear(num_fr, num_classes)
# Fixed: the original called torch.load('') — an empty path that crashes
# immediately. Point this at a checkpoint saved by the training script.
WEIGHT_PATH = ''  # TODO: set to your checkpoint, e.g. 'checkpoints/uniRepLKNet/model_200_....pth'
checkpoint = torch.load(WEIGHT_PATH)
model_ft.load_state_dict(checkpoint['state_dict'], strict=False)
model_ft.to(device)

with torch.no_grad():
    model_ft.eval()
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=1, shuffle=False)
    for batch in test_loader:
        # Fixed: the loop variable was named `labels`, shadowing the list of
        # class names defined above; renamed to `targets` for clarity.
        images, targets = batch
        images = images.to(device)
        targets = targets.to(device)

        preds = model_ft(images)
        _, predicted_labels = torch.max(preds, dim=1)
        predicted_labels = predicted_labels.cpu().numpy()
        targets = targets.cpu().numpy()
        confusion_matrix.update(predicted_labels, targets)
# NOTE(review): plot() ends with plt.show(); on interactive backends the
# figure may be cleared before this savefig, producing a blank image —
# consider saving inside plot() before show().
confusion_matrix.plot()
save_path = 'confusion_matrix.png'
plt.savefig(save_path)
confusion_matrix.summary()

 

测试数据:

the model accuracy is  0.9931506849315068
+--------------------------+----------+-----------+--------+-------------+----------+
|                          | accuracy | Precision | Recall | Specificity | F1 Score |
+--------------------------+----------+-----------+--------+-------------+----------+
|    Apple___Black_rot     |   1.0    |    1.0    |  1.0   |     1.0     |   1.0    |
| Apple___Cedar_apple_rust |  0.993   |   0.977   |  1.0   |     0.99    |  0.988   |
|     Apple___healthy      |  0.993   |    1.0    | 0.984  |     1.0     |  0.992   |
+--------------------------+----------+-----------+--------+-------------+----------+

以上即为全部内容!

  • 19
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值