ShuffleNetV2
创新点
核心创新:分组卷积配合通道重排(channel shuffle)——对输入特征层进行分组卷积以减少参数量,再打乱通道顺序,使各组之间的信息得以交流,最后继续卷积。
数据集
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Custom image dataset: RGB images loaded from paths, paired with int class labels.

    Bug fixes vs original: `def__init__` / `def__len__` / `def__getitem__` were
    missing the space after `def` (SyntaxError), and `__getitem__` read the
    non-existent attribute `self.image_path` (missing "s").
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path    # list of image file paths
        self.images_class = images_class  # list of int labels, parallel to images_path
        self.transform = transform        # optional torchvision-style transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        if img.mode != 'RGB':
            raise ValueError("image : {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        # torch.cat joins along an existing dim; torch.stack adds a new (batch) dim,
        # which is what we need to turn a list of CHW tensors into an NCHW batch.
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
#torch.cat()对tensors沿指定维度拼接,但返回的Tensor的维数不会变
#torch.stack()同样是对tensors沿指定维度拼接,但返回的Tensor会多一维
将训练与验证流程组件化编写。
utils(工具模块)
import os
import sys
import json
import pickle
import random
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
def read_split_data(root: str, val_rate: float = 0.2):
    """Scan a class-per-subdirectory image root and split it into train/val sets.

    Also writes the class-name -> index mapping to ./class_indices.json.

    Args:
        root: dataset root directory; each subdirectory is one class.
        val_rate: fraction of each class reserved for validation.

    Returns:
        (train_images_path, train_images_label, val_images_path, val_images_label)
    """
    random.seed(0)  # fixed seed so the split is reproducible
    # Bug fix: original wrote `assert cond."msg"` (a "." instead of ",").
    assert os.path.exists(root), "datasets root:{} dose not exists.".format(root)
    # Each subdirectory name is the corresponding class name.
    flower_class = [cla for cla in os.listdir(root)
                    if os.path.isdir(os.path.join(root, cla))]
    flower_class.sort()
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()),
                          indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    every_class_num = []
    supported = [".jpg", ".JPG", ".png", ".PNG"]
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # sorted() makes the seeded random split deterministic across filesystems
        # (os.listdir order is arbitrary).
        images = sorted(os.path.join(root, cla, i) for i in os.listdir(cla_path)
                        if os.path.splitext(i)[-1] in supported)
        # Bug fix: original referenced the undefined name `class_indicts`.
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        val_path = random.sample(images, k=int(len(images) * val_rate))
        for img_path in images:
            # every file in one class directory shares the same label
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # bug fix: original `else` was missing its colon
                train_images_path.append(img_path)
                train_images_label.append(image_class)
    print("{} images were found in the dataset.".format(sum(every_class_num)))
    # Bug fix: original printed the undefined `trian_images_path`.
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    plot_image = False  # flip to True to visualize the class distribution
    if plot_image:
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        plt.xticks(range(len(flower_class)), flower_class)
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()
    # Bug fix: original returned the undefined `val_image_paht`.
    return train_images_path, train_images_label, val_images_path, val_images_label
def plot_data_loader_image(data_loader):
    """Display up to 4 images per batch, un-normalized with ImageNet statistics.

    Assumes the loader yields (images, labels) with images normalized by the
    ImageNet mean/std — TODO confirm against the transform actually used.
    """
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)
    # Bug fix: must match the file written by read_split_data ('class_indices.json',
    # original pointed at a misspelled 'class_indicts.json').
    json_path = './class_indices.json'
    # Bug fix: original assert message had an unterminated string (full-width quote).
    assert os.path.exists(json_path), json_path + " does not exist."
    with open(json_path, 'r') as json_file:  # close the handle (original leaked it)
        class_indices = json.load(json_file)  # bug fix: was the undefined `json_load`
    for data in data_loader:  # bug fix: original wrote `for data int data_loader`
        images, labels = data
        for i in range(plot_num):
            # CHW tensor -> HWC numpy image; bug fix: original indexed `image[i]`
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo ImageNet normalization (multiply by std, add mean, scale to 0-255);
            # bug fix: original constants were garbled (0.334.0.225 / 0.465).
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])
            plt.yticks([])
            plt.imshow(img.astype('uint8'))
        plt.show()
def write_pickle(list_info: list, file_name: str):
    """Serialize `list_info` to `file_name` using pickle."""
    with open(file_name, 'wb') as handle:
        pickle.dump(list_info, handle)
def read_pickle(file_name: str) -> list:
    """Load and return the pickled list stored in `file_name`."""
    with open(file_name, 'rb') as handle:
        return pickle.load(handle)
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train `model` for one epoch over `data_loader`.

    Args:
        model: the network to train (moved to `device` by the caller).
        optimizer: optimizer updating the model's parameters.
        data_loader: yields (images, labels) batches.
        device: torch.device to run on.
        epoch: epoch index, used only for the progress-bar caption.

    Returns:
        The running mean loss over the epoch as a float.
    """
    model.train()  # bug fix: original called the non-existent `model.trian()`
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer.zero_grad()
    mean_loss = torch.zeros(1).to(device)
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        loss = loss_function(pred, labels.to(device))
        loss.backward()
        # incremental running mean of the loss over all steps so far
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)
        # bug fix: "maen" typo in the progress-bar caption
        data_loader.desc = "[epoch{}] mean loss {}".format(epoch,
                                                          round(mean_loss.item(), 3))
        # abort training on NaN/Inf loss rather than corrupting the weights
        if not torch.isfinite(loss):
            print('WARNING:non-finite loss,ending training', loss)  # bug fix: "VARNING"
            sys.exit(1)
        optimizer.step()
        optimizer.zero_grad()  # bug fix: original called the undefined `optimizet`
    return mean_loss.item()
@torch.no_grad()
def evaluate(model, data_loader, device):
    """Compute top-1 accuracy of `model` over every batch in `data_loader`."""
    model.eval()
    total_num = len(data_loader.dataset)
    correct = torch.zeros(1).to(device)
    data_loader = tqdm(data_loader)
    for images, labels in data_loader:
        logits = model(images.to(device))
        predictions = logits.argmax(dim=1)
        correct += (predictions == labels.to(device)).sum()
    return correct.item() / total_num
模型搭建
from typing import List,Callable
import torch
import torch.nn as nn
from torch import Tensor
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave channels across `groups` (the ShuffleNetV2 channel shuffle).

    Reshapes (B, C, H, W) -> (B, groups, C//groups, H, W), swaps the two group
    axes, and flattens back, so information can flow between groups.
    """
    batch_size, num_channels, height, width = x.size()
    per_group = num_channels // groups
    grouped = x.view(batch_size, groups, per_group, height, width)
    shuffled = grouped.transpose(1, 2).contiguous()
    return shuffled.view(batch_size, -1, height, width)
class InvertedResidual(nn.Module):
    """ShuffleNetV2 building block.

    stride == 1: the input is split channel-wise in half; one half passes
    through unchanged, the other goes through branch2, then the halves are
    concatenated and channel-shuffled.
    stride == 2: both branch1 (a downsampling depthwise path) and branch2
    process the full input and their outputs are concatenated.
    """

    def __init__(self, input_c: int, output_c: int, stride: int):
        super(InvertedResidual, self).__init__()
        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride
        # output channels are split evenly between the two branches
        assert output_c % 2 == 0
        branch_features = output_c // 2
        # for stride 1 the channel split requires input_c == 2 * branch_features
        # (<< 1 is multiply-by-two)
        assert (self.stride != 1) or (input_c == branch_features << 1)
        if self.stride == 2:
            # Bug fixes: original used `==` (comparison, so branch1 was never
            # assigned), passed output_c as the depthwise conv's out-channels
            # (must stay input_c to match the following BatchNorm2d(input_c)),
            # and was missing a comma after the first layer.
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3,
                                    stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1,
                          padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True))
        else:
            # identity for stride 1; bug fix: original assigned `self.brach1`
            self.branch1 = nn.Sequential()
        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features,
                      branch_features, kernel_size=1, stride=1, padding=0,
                      bias=False),
            nn.BatchNorm2d(branch_features),  # bug fix: was `brach_features`
            nn.ReLU(inplace=True),
            # Bug fix: the depthwise conv is 3x3 with the block's stride and
            # padding 1 (original passed the invalid keyword `kernel_size=1`).
            self.depthwise_conv(branch_features, branch_features, kernel_s=3,
                                stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            # Bug fix: the standard ShuffleNetV2 branch2 ends with a 1x1 conv +
            # BN + ReLU, which the original transcription dropped.
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1,
                      padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True))

    @staticmethod
    def depthwise_conv(input_c: int,
                       output_c: int,
                       kernel_s: int,
                       stride: int = 1,
                       padding: int = 0,
                       bias: bool = False) -> nn.Conv2d:
        """Depthwise convolution: groups == in_channels, one filter per channel."""
        return nn.Conv2d(in_channels=input_c, out_channels=output_c,
                         kernel_size=kernel_s, stride=stride, padding=padding,
                         bias=bias, groups=input_c)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:  # bug fix: original `if self.stride=1` (SyntaxError)
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        out = channel_shuffle(out, 2)
        return out
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 classifier: stem conv + maxpool + 3 shuffle stages + 1x1 conv + fc."""

    def __init__(self, stages_repeats: List[int],
                 stages_out_channels: List[int],
                 num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual):
        """
        Args:
            stages_repeats: number of blocks in stage2/3/4 (length 3).
                Bug fix: original signature misspelled this `stages_repeates`
                while the body used `stages_repeats`, so construction failed.
            stages_out_channels: channels of conv1, stage2-4 and conv5 (length 5).
            num_classes: size of the classifier output.
            inverted_residual: block class used to build each stage.
        """
        super(ShuffleNetV2, self).__init__()
        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels
        input_channels = 3
        output_channels = self._stage_out_channels[0]  # bug fix: was `slef.`
        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2,
                      padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        input_channels = output_channels
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # declared for static typing only; assigned via setattr below
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential
        # bug fix: original string literal `"stage{}.format(i)` was unterminated
        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        # bug fix: original iterated the undefined `stage_repeats`
        for name, repeats, output_channels in zip(stage_names, stages_repeats,
                                                  self._stage_out_channels[1:]):
            # first block of each stage downsamples (stride 2); the rest keep size
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for _ in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels
        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            # bug fix: was the undefined `output_channeles`
            nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1,
                      padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.conv1(x)  # bug fix: original passed the undefined name `X`
        x = self.maxpool(x)
        x = self.stage2(x)  # bug fix: original passed `X`
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global average pooling over H and W
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def shufflenet_v2_x1_0(num_classes=1000):
    """Build ShuffleNetV2 at 1.0x width.

    Bug fixes: the keyword is `stages_repeats` (original passed the invalid
    `stage_repeats`), and `num_classes` is now forwarded instead of being
    hard-coded to 1000, matching shufflenet_v2_x0_5.
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 116, 232, 464, 1024],
                         num_classes=num_classes)
    return model
def shufflenet_v2_x0_5(num_classes=1000):
    """Build ShuffleNetV2 at 0.5x width."""
    return ShuffleNetV2(stages_repeats=[4, 8, 4],
                        stages_out_channels=[24, 48, 96, 192, 1024],
                        num_classes=num_classes)
训练
import os
import math
import argparse
import torch
import torch.optim as optim
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriter
import torch.optim.lr_scheduler as lr_scheduler
from model import shufflenet_v2_x1_0
from my_dataset import MyDataSet
from utils import read_split_data,train_one_epoch,evaluate
def main(args):
    """Train shufflenet_v2_x1_0 on an image-folder dataset, logging to TensorBoard.

    Saves one checkpoint per epoch to ./weights/model-{epoch}.pth.
    """
    # bug fix: original called the non-existent `torch.cuda.is_availabel()`
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(args)
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")
    train_images_path, train_images_label, val_images_path, val_images_label = \
        read_split_data(args.data_path)
    data_transform = {
        # Bug fixes: the transform is RandomResizedCrop (was "RandomResizeCorp")
        # and both Compose lists were missing their closing brackets.
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5],
                                                          [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5],
                                                        [0.5, 0.5, 0.5])])
    }
    # Bug fixes: the class is MyDataSet (was "MyDataset"), the dict is
    # data_transform (was "transform"), and the key is "train" (was "trian").
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])
    batch_size = args.batch_size
    # dataloader workers: bounded by CPU count, batch size, and 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('Using {} dataloader worker every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True, pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
    # bug fix: original wrote `torch.util.data` (missing "s")
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False, pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)
    model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            # keep only tensors whose element counts match the current model
            # (drops e.g. the classifier head when num_classes differs);
            # bug fixes: `weight_dict` / `load_weight_dict` name mismatches
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            # bug fix: "NOT FOUNT" typo in the error message
            raise FileNotFoundError("not found weights file:{}".format(args.weights))
    if args.freeze_layers:
        # freeze everything except the classifier head
        for name, para in model.named_parameters():
            if "fc" not in name:
                para.requires_grad_(False)  # bug fix: was `requiers_grad_`
    pg = [p for p in model.parameters() if p.requires_grad]  # bug fix: `parametes`
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
    # cosine learning-rate schedule decaying from lr down to lr * lrf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
    for epoch in range(args.epochs):
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,  # bug fix: `optimizet`
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)
        scheduler.step()
        acc = evaluate(model=model, data_loader=val_loader,
                       device=device)  # bug fix: was the undefined `evice`
        print("[epoch{}] accuracy:{}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)
        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
# Script entry point: parse hyperparameters and launch training.
if __name__ == '__main__':  # bug fix: `if__name__` was missing the space
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)  # bug fix: stray trailing "\"
    parser.add_argument('--lrf', type=float, default=0.1)  # final lr fraction for cosine decay
    # bug fix: dataset directory name was garbled ("loower_photos")
    parser.add_argument('--data-path', type=str, default="/data/flower_photos")
    # bug fix: extension used a comma ("shufflenetv2_x1,pth")
    parser.add_argument('--weights', type=str, default='./shufflenetv2_x1.pth',
                        help='initial weight path')
    parser.add_argument('--freeze-layers', type=bool, default=False)
    # bug fix: the help string literal was unterminated
    parser.add_argument('--device', default='cuda:0',
                        help='device id (i.e. cuda:0 or cpu)')
    opt = parser.parse_args()
    main(opt)
测试
# Imports for the single-image prediction script.
import os
import json

import torch
from PIL import Image
from torchvision import transforms  # bug fix: original wrote `form torchvision`
import matplotlib.pyplot as plt

from model import shufflenet_v2_x1_0
def main():
    """Classify a single image with trained ShuffleNetV2 weights and show the result."""
    # bug fixes: `is_availabel` typo and a stray `9` at the end of the line
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         # bug fix: mean must match the training normalization ([0.5]*3; the
         # original had [0.5, 0.4, 0.4])
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
         ])
    # bug fix: variable was misspelled `img_paht` (every later use reads img_path);
    # NOTE(review): the filename looks like a typo of "tulip.jpg" — confirm on disk
    img_path = "./tuplig.jpg"
    # bug fix: original assert used "." instead of "," and had an unbalanced ")"
    assert os.path.exists(img_path), "file :'{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)  # add the batch dimension
    # bug fix: must match the file written during training ('class_indices.json')
    json_path = './class_indices.json'
    # bug fix: `.fromat` typo
    assert os.path.exists(json_path), "file:'{}' dose not exist.".format(json_path)
    with open(json_path, "r") as json_file:  # close the handle (original leaked it)
        class_indict = json.load(json_file)
    model = shufflenet_v2_x1_0(num_classes=5).to(device)  # bug fix: extra ")"
    # bug fixes: the string literal was unterminated and used ".pt" while
    # training saves ".pth" checkpoints
    model_weight_path = "./weights/model-29.pth"
    # bug fix: method is load_state_dict (was `loade_state_dict`)
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # bug fix: original line was missing two closing parentheses
        output = torch.squeeze(model(img.to(device)))
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
    print_res = "class :{} prob:{:.3}".format(class_indict[str(predict_cla)],
                                              predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()