计算机视觉篇---图像分类实战+理论讲解(7)ShuffleNetV2

ShuffleNetV2

创新点
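在 ShuffleNetV1“分组卷积 + 通道打乱(channel shuffle)”减少参数量与计算量的基础上,ShuffleNetV2 以实际推理速度为目标总结了四条设计准则:输入输出通道数相等时内存访问量最小;过多分组卷积会增加内存访问量;网络碎片化会降低并行度;逐元素操作的开销不可忽略。据此提出通道拆分(channel split):将输入特征层按通道一分为二,一半直接保留,另一半依次经过 1x1 卷积、3x3 深度卷积、1x1 卷积(不再使用分组卷积),两部分拼接后再进行通道打乱,使信息在两组通道之间交流。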
数据集(my_dataset.py)

from PIL import Image
import torch
from torch.utils.data import Dataset

class MyDataSet(Dataset):
    """Image-classification dataset built from parallel lists of paths and labels.

    Args:
        images_path: file path of every image.
        images_class: class index for each entry of ``images_path``.
        transform: optional callable applied to each loaded PIL image.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        """Return ``(image, label)`` for index ``item``.

        Raises:
            ValueError: if the image file is not in RGB mode.
        """
        # The attribute is ``images_path`` (the original read ``self.image_path``).
        img = Image.open(self.images_path[item])
        if img.mode != 'RGB':
            raise ValueError("image : {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        """Collate a list of ``(image, label)`` pairs into ``(images, labels)`` tensors.

        torch.cat() joins tensors along an existing dimension (rank unchanged),
        while torch.stack() joins them along a NEW dimension (rank + 1); stack is
        used here to create the leading batch dimension.
        """
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
        

将训练与验证流程拆分为独立的组件(函数)分别编写:
utils.py

import os
import sys
import json
import pickle
import random 
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

def read_split_data(root: str, val_rate: float = 0.2):
    """Scan a class-per-folder dataset and split each class into train/val.

    Every sub-directory of ``root`` is one class and its folder name is the
    class name. Class indices follow the sorted folder names, and the
    index -> name mapping is written to ``class_indices.json`` in the current
    working directory. Within each class, ``int(n * val_rate)`` images (only
    .jpg/.JPG/.png/.PNG files) are sampled for validation with the global
    ``random`` module re-seeded to 0; the rest go to training.

    Args:
        root: dataset root directory.
        val_rate: fraction of every class reserved for validation.

    Returns:
        ``(train_images_path, train_images_label, val_images_path, val_images_label)``.

    Raises:
        AssertionError: if ``root`` does not exist.
    """
    random.seed(0)  # fixed seed so the split is reproducible
    assert os.path.exists(root), "datasets root:{} dose not exists.".format(root)

    # Each folder name is the name of the corresponding class.
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    flower_class.sort()
    # class name -> index
    class_indices = {cla: idx for idx, cla in enumerate(flower_class)}
    # Persist index -> class name so inference code can decode predictions.
    json_str = json.dumps({idx: cla for cla, idx in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    every_class_num = []
    supported = [".jpg", ".JPG", ".png", ".PNG"]
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # Sorted because os.listdir order is filesystem-dependent, which would
        # make the seeded sample differ between machines.
        images = sorted(os.path.join(root, cla, i) for i in os.listdir(cla_path)
                        if os.path.splitext(i)[-1] in supported)
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        # A set keeps the membership test below O(1) per image.
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))
        for img_path in images:
            # All images of one folder share the same class label.
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    # Summary and optional plot run once, after ALL classes are processed
    # (the original was indented inside the loop and returned after the first class).
    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        plt.xticks(range(len(flower_class)), flower_class)
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label
def plot_data_loader_image(data_loader, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show up to 4 de-normalized images with their class names for every batch.

    Reads the index -> class-name mapping from ``./class_indices.json`` (the
    file written by ``read_split_data``) and opens one matplotlib figure per batch.

    Args:
        data_loader: iterable of ``(images, labels)``; images are assumed to be
            normalized (C, H, W) CPU tensors — TODO confirm for your loader.
        mean: per-channel mean used by ``Normalize``; defaults to ImageNet stats.
        std: per-channel std used by ``Normalize``; defaults to ImageNet stats.
            Pass ``(0.5, 0.5, 0.5)`` for both to match the 0.5/0.5 transforms
            used by the training script.

    Raises:
        AssertionError: if ``./class_indices.json`` does not exist.
    """
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    # Must match the filename written by read_split_data.
    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for images, labels in data_loader:
        # The last batch may contain fewer than plot_num images.
        for i in range(min(plot_num, len(images))):
            # (C, H, W) -> (H, W, C) for matplotlib.
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo Normalize (x * std + mean) and rescale [0, 1] -> [0, 255].
            img = (img * list(std) + list(mean)) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])
            plt.yticks([])
            # Clip before the uint8 cast so out-of-range values do not wrap around.
            plt.imshow(img.clip(0, 255).astype('uint8'))
        plt.show()
def write_pickle(list_info: list, file_name: str):
    """Serialize ``list_info`` into the file ``file_name`` with pickle."""
    with open(file_name, mode='wb') as out_stream:
        pickle.dump(list_info, out_stream)
def read_pickle(file_name: str) -> list:
    """Load and return the object pickled in ``file_name``.

    Note: unpickling can execute arbitrary code; only read trusted files.
    """
    with open(file_name, mode='rb') as in_stream:
        loaded = pickle.load(in_stream)
    return loaded
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train ``model`` for one epoch with cross-entropy loss.

    Args:
        model: network to train (switched to train mode).
        optimizer: optimizer stepping ``model``'s parameters once per batch.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.
        epoch: epoch number, shown in the progress bar.

    Returns:
        float: running mean of the per-batch loss over the epoch.

    Exits the process with status 1 if a non-finite loss is encountered.
    """
    model.train()  # the original called the nonexistent ``model.trian()``
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer.zero_grad()

    mean_loss = torch.zeros(1).to(device)
    data_loader = tqdm(data_loader)
    for step, (images, labels) in enumerate(data_loader):
        pred = model(images.to(device))
        loss = loss_function(pred, labels.to(device))
        loss.backward()
        # Incremental mean: new_mean = (old_mean * n + x) / (n + 1).
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)
        data_loader.desc = "[epoch{}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()  # the original referenced the undefined name ``optimizet``

    return mean_loss.item()

@torch.no_grad()
def evaluate(model, data_loader, device):
    """Return the top-1 accuracy of ``model`` on ``data_loader``.

    The model is put in eval mode and gradients are disabled. Accuracy is the
    number of correct predictions divided by ``len(data_loader.dataset)``.
    """
    model.eval()
    n_samples = len(data_loader.dataset)
    n_correct = torch.zeros(1).to(device)
    progress = tqdm(data_loader)
    for images, labels in progress:
        logits = model(images.to(device))
        _, predicted = torch.max(logits, dim=1)
        n_correct += (predicted == labels.to(device)).sum()
    return n_correct.item() / n_samples
            
        


    


模型搭建(model.py)

from typing import List,Callable
import torch 
import torch.nn as nn
from torch import Tensor
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of ``x`` across ``groups`` groups.

    ``x`` is unpacked as (batch, channels, height, width); ``channels`` is
    expected to be divisible by ``groups``. Channel ``g * k + j`` (group ``g``,
    position ``j``) ends up at position ``j * groups + g``.
    """
    n, c, h, w = x.size()
    per_group = c // groups
    # (n, c, h, w) -> (n, groups, per_group, h, w)
    grouped = x.view(n, groups, per_group, h, w)
    # Swap the group and within-group axes, then flatten back to (n, c, h, w).
    swapped = grouped.transpose(1, 2).contiguous()
    return swapped.view(n, -1, h, w)
class InvertedResidual(nn.Module):
    """ShuffleNetV2 unit: basic unit (stride 1) or down-sampling unit (stride 2).

    stride == 1: the input is split channel-wise into two halves; the first
    half is passed through unchanged and the second goes through ``branch2``.
    stride == 2: the whole input feeds both ``branch1`` and ``branch2``, each
    producing ``output_c // 2`` channels at half the spatial resolution.
    In both cases the two halves are concatenated and channel-shuffled with
    2 groups.

    Args:
        input_c: number of input channels (must equal ``output_c`` for stride 1).
        output_c: number of output channels (must be even).
        stride: 1 or 2.

    Raises:
        ValueError: if ``stride`` is not 1 or 2.
    """

    def __init__(self, input_c: int, output_c: int, stride: int):
        super(InvertedResidual, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride

        # Each branch produces half of the output channels.
        assert output_c % 2 == 0
        branch_features = output_c // 2
        # With stride 1 the input is split in two, so input_c must equal
        # 2 * branch_features (``<< 1`` is a bit-shift multiply by 2).
        assert (self.stride != 1) or (input_c == branch_features << 1)

        if self.stride == 2:
            # Left branch: 3x3 depthwise conv (stride 2) -> BN -> 1x1 conv -> BN -> ReLU.
            # The depthwise conv keeps input_c channels (the original wrongly used output_c).
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True),
            )
        else:
            # Unused by forward() for stride 1: the first half is an identity path.
            self.branch1 = nn.Sequential()

        # Right branch: 1x1 conv -> BN -> ReLU -> 3x3 depthwise conv (stride) -> BN
        # -> 1x1 conv -> BN -> ReLU. The depthwise conv must be 3x3 and carry the
        # unit's stride, otherwise the two branches disagree in size when stride == 2.
        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features,
                      kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(input_c: int,
                       output_c: int,
                       kernel_s: int,
                       stride: int = 1,
                       padding: int = 0,
                       bias: bool = False) -> nn.Conv2d:
        """Return a depthwise ``nn.Conv2d`` (``groups=input_c``)."""
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            # Channel split: keep x1 as-is, transform x2.
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        # Mix information between the two halves.
        out = channel_shuffle(out, 2)
        return out
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 image classifier.

    Layout: conv1 (3x3, stride 2) -> maxpool (stride 2) -> stage2/3/4 (each
    starting with a stride-2 unit) -> conv5 (1x1) -> global average pooling -> fc.

    Args:
        stages_repeats: number of units in stage2, stage3 and stage4 (3 ints).
        stages_out_channels: output channels of conv1, stage2, stage3, stage4
            and conv5 (5 ints).
        num_classes: number of output logits.
        inverted_residual: factory ``(input_c, output_c, stride) -> nn.Module``
            used to build each unit.

    Raises:
        ValueError: if the lists do not have 3 and 5 entries respectively.
    """

    def __init__(self,
                 stages_repeats: List[int],
                 stages_out_channels: List[int],
                 num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # RGB input.
        input_channels = 3
        output_channels = self._stage_out_channels[0]
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Declared for type checkers; the modules are attached with setattr below.
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential

        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats,
                                                  self._stage_out_channels[1:]):
            # The first unit of every stage down-samples; the rest keep resolution.
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for _ in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global average pooling over height and width
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def shufflenet_v2_x1_0(num_classes=1000):
    """Build a ShuffleNetV2 with 1.0x channel width.

    Stage repeats are 4/8/4 and the channel plan is 24-116-232-464-1024.

    Args:
        num_classes: number of output logits (the original ignored this and
            always built a 1000-class head).
    """
    # The repeats/channels are passed positionally, matching ShuffleNetV2's first two parameters.
    model = ShuffleNetV2([4, 8, 4],
                         [24, 116, 232, 464, 1024],
                         num_classes=num_classes)
    return model
def shufflenet_v2_x0_5(num_classes=1000):
    """Build a ShuffleNetV2 with 0.5x channel width.

    Stage repeats are 4/8/4 and the channel plan is 24-48-96-192-1024.
    """
    repeats = [4, 8, 4]
    widths = [24, 48, 96, 192, 1024]
    return ShuffleNetV2(stages_repeats=repeats,
                        stages_out_channels=widths,
                        num_classes=num_classes)
        


训练

import os 
import math
import argparse
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import torch.optim.lr_scheduler as lr_scheduler

from model import shufflenet_v2_x1_0
from my_dataset import MyDataSet
from utils import read_split_data,train_one_epoch,evaluate


def main(args):
    """Train ShuffleNetV2 x1.0 on the class-per-folder dataset at ``args.data_path``.

    Reads ``device``, ``data_path``, ``batch_size``, ``num_classes``,
    ``weights``, ``freeze_layers``, ``lr``, ``lrf`` and ``epochs`` from ``args``.
    After each epoch it logs loss / accuracy / learning rate to TensorBoard
    and saves the weights to ``./weights/model-<epoch>.pth``.

    Raises:
        FileNotFoundError: if ``args.weights`` is non-empty but does not exist.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(args)
    tb_writer = SummaryWriter()
    os.makedirs("./weights", exist_ok=True)

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
    }

    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    # Number of DataLoader worker processes (0 = load in the main process);
    # os.cpu_count() may return None.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader worker every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)

    if args.weights != "":
        if not os.path.exists(args.weights):
            raise FileNotFoundError("NOT FOUND weights file:{}".format(args.weights))
        weights_dict = torch.load(args.weights, map_location=device)
        # Build the state dict once. Keep only tensors that exist in this model
        # with the same shape (e.g. drops a pretrained fc layer when num_classes
        # differs); a shape mismatch would make load_state_dict raise even with
        # strict=False.
        model_state = model.state_dict()
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if k in model_state and model_state[k].shape == v.shape}
        print(model.load_state_dict(load_weights_dict, strict=False))

    if args.freeze_layers:
        # Train only the final fully connected layer.
        for name, para in model.named_parameters():
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
    # Cosine schedule: the lr multiplier goes from 1 at epoch 0 to args.lrf at args.epochs.
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)
        scheduler.step()

        acc = evaluate(model=model, data_loader=val_loader, device=device)
        print("[epoch{}] accuracy:{}".format(epoch, round(acc, 3)))

        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))

    tb_writer.close()

if __name__ == '__main__':
    def _str2bool(value: str) -> bool:
        """Parse a command-line boolean; ``type=bool`` would turn "False" into True."""
        return value.strip().lower() in ('1', 'true', 'yes', 'y')

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    # Final lr multiplier of the cosine schedule.
    parser.add_argument('--lrf', type=float, default=0.1)
    # Dataset root: one sub-folder per class.
    parser.add_argument('--data-path', type=str, default="/data/flower_photos")
    # Pretrained weights; pass an empty string to train from scratch.
    parser.add_argument('--weights', type=str, default='./shufflenetv2_x1.pth',
                        help='initial weight path')
    parser.add_argument('--freeze-layers', type=_str2bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device to use, e.g. cuda:0 or cpu')

    opt = parser.parse_args()
    main(opt)

测试

import json
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from model import shufflenet_v2_x1_0
def main():
    """Classify ``./tuplig.jpg`` with trained ShuffleNetV2 x1.0 weights and display it.

    Reads the class mapping from ``./class_indices.json`` (written during
    training) and the weights from ``./weights/model-29.pth`` (the last
    checkpoint saved by the training script with its default 30 epochs).

    Raises:
        AssertionError: if the image or the class-index file is missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Must match the "val" preprocessing used in training (mean = std = 0.5 per channel).
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    img_path = "./tuplig.jpg"
    assert os.path.exists(img_path), "file :'{}' dose not exist.".format(img_path)
    # Convert so that grayscale/RGBA inputs still yield 3 channels for Normalize.
    img = Image.open(img_path).convert('RGB')
    plt.imshow(img)
    img = data_transform(img)
    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    img = torch.unsqueeze(img, dim=0)

    # Filename written by read_split_data during training.
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file:'{}' dose not exist.".format(json_path)
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # The head size follows the saved class mapping instead of a hard-coded 5.
    model = shufflenet_v2_x1_0(num_classes=len(class_indict)).to(device)
    model_weight_path = "./weights/model-29.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # Move to CPU so .item()/indexing work regardless of the device.
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class :{} prob:{:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].item())
    plt.title(print_res)
    print(print_res)
    plt.show()
if __name__ == "__main__":
    # Run the single-image prediction only when executed as a script.
    main()
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

free_girl_fang

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值