EfficientNetV2
Key idea: the network is built from MBConv and Fused-MBConv blocks.
Model definition (model.py)
from collections import OrderedDict
from functools import partial
from typing import Callable,Optional
import torch.nn as nn
import torch
from torch import Tensor
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Randomly zero whole samples (stochastic depth), rescaling survivors.

    Args:
        x: input tensor; dim 0 is the batch dimension.
        drop_prob: probability of dropping each sample's residual path.
        training: drop only in training mode; identity in eval mode.
    """
    # Identity when there is nothing to drop or we are evaluating.
    # (Original tested the function's own name and omitted `not`.)
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # Per-sample mask broadcastable over all non-batch dims: (B, 1, 1, ...).
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0
    # Divide by keep_prob so the expected activation magnitude is unchanged.
    output = x.div(keep_prob) * random_tensor
    return output
class DropPath(nn.Module):
    """Stochastic-depth module wrapper around the ``drop_path`` function."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Original passed `self.drop_path`, an attribute that does not exist;
        # the stored probability is `self.drop_prob`.
        return drop_path(x, self.drop_prob, self.training)
class ConvBNAct(nn.Module):
    """Conv2d -> norm -> activation, the basic EfficientNetV2 building block.

    Args:
        in_planes: input channel count.
        out_planes: output channel count (missing from the original signature).
        kernel_size: square kernel size; padding keeps spatial size at stride 1.
        stride: convolution stride.
        groups: convolution groups (== in_planes for a depthwise conv).
        norm_layer: norm constructor; defaults to nn.BatchNorm2d.
        activation_layer: activation constructor; defaults to nn.SiLU.
    """

    def __init__(self, in_planes: int,
                 out_planes: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 groups: int = 1,
                 norm_layer: Optional[Callable[..., nn.Module]] = None,
                 activation_layer: Optional[Callable[..., nn.Module]] = None):
        super(ConvBNAct, self).__init__()
        padding = (kernel_size - 1) // 2  # "same" padding for odd kernels
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.SiLU
        self.conv = nn.Conv2d(in_channels=in_planes,
                              out_channels=out_planes,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              groups=groups,
                              bias=False)  # BN follows, so conv bias is redundant
        self.bn = norm_layer(out_planes)
        self.act = activation_layer()

    def forward(self, x):
        result = self.conv(x)
        result = self.bn(result)
        result = self.act(result)
        return result
class SqueezeExcite(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Args:
        input_c: the MBConv block's input channels — as in EfficientNet, the
            squeeze width is sized from the block input, not the expanded width.
        expand_c: channels of the tensor this module actually receives.
        se_ratio: squeeze ratio; squeeze_c = int(input_c * se_ratio).
    """

    def __init__(self, input_c: int,
                 expand_c: int,
                 se_ratio: float = 0.25):
        super(SqueezeExcite, self).__init__()
        squeeze_c = int(input_c * se_ratio)
        self.conv_reduce = nn.Conv2d(expand_c, squeeze_c, 1)
        self.act1 = nn.SiLU()
        self.conv_expand = nn.Conv2d(squeeze_c, expand_c, 1)
        self.act2 = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        scale = x.mean((2, 3), keepdim=True)  # global average pool ("squeeze")
        scale = self.conv_reduce(scale)
        scale = self.act1(scale)
        scale = self.conv_expand(scale)
        scale = self.act2(scale)
        return scale * x  # channel-wise reweighting ("excite")


# Backward-compatible alias: the original class name was misspelled, while the
# MBConv call site already used the correct spelling.
SequeezeExcite = SqueezeExcite
class MBConv(nn.Module):
    """Inverted-residual MBConv block: 1x1 expand -> depthwise -> SE -> 1x1 project.

    Args:
        kernel_size: depthwise conv kernel size.
        input_c: input channels.
        out_c: output channels.
        expand_ratio: channel expansion factor (must be != 1 here).
        stride: depthwise conv stride, 1 or 2.
        se_ratio: SE squeeze ratio; 0 disables SE.
        drop_rate: stochastic-depth drop probability for the residual path.
        norm_layer: norm constructor shared by all sub-convs.
    """

    def __init__(self, kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer: Callable[..., nn.Module]):
        super(MBConv, self).__init__()
        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        # Residual shortcut only when spatial size and channels are preserved.
        self.has_shortcut = (stride == 1 and input_c == out_c)
        activation_layer = nn.SiLU
        expanded_c = input_c * expand_ratio
        # EfficientNetV2 uses MBConv only with expansion != 1.
        assert expand_ratio != 1
        self.expand_conv = ConvBNAct(input_c, expanded_c,
                                     kernel_size=1,
                                     norm_layer=norm_layer,
                                     activation_layer=activation_layer)
        self.dwconv = ConvBNAct(expanded_c, expanded_c,
                                kernel_size=kernel_size,
                                stride=stride,
                                groups=expanded_c,  # depthwise
                                norm_layer=norm_layer,
                                activation_layer=activation_layer)
        self.se = SqueezeExcite(input_c, expanded_c, se_ratio) if se_ratio > 0 else nn.Identity()
        # Linear 1x1 projection (original omitted kernel_size=1, inheriting 3).
        self.project_conv = ConvBNAct(expanded_c, out_c,
                                      kernel_size=1,
                                      norm_layer=norm_layer,
                                      activation_layer=nn.Identity)
        self.out_channels = out_c
        self.drop_rate = drop_rate
        if self.has_shortcut and drop_rate > 0:
            self.dropout = DropPath(drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        result = self.expand_conv(x)
        result = self.dwconv(result)
        result = self.se(result)
        result = self.project_conv(result)
        if self.has_shortcut:
            if self.drop_rate > 0:
                result = self.dropout(result)
            result += x
        return result
class FusedMBConv(nn.Module):
    """Fused MBConv: the 1x1 expand + depthwise pair is fused into one full conv.

    Args:
        kernel_size: fused conv kernel size.
        input_c: input channels.
        out_c: output channels.
        expand_ratio: channel expansion factor; 1 means a single conv.
        stride: conv stride, 1 or 2.
        se_ratio: must be 0 — fused stages in EfficientNetV2 use no SE.
        drop_rate: stochastic-depth drop probability (was read but missing
            from the original signature, even though the builder passes it).
        norm_layer: norm constructor shared by all sub-convs.
    """

    def __init__(self, kernel_size: int,
                 input_c: int,
                 out_c: int,
                 expand_ratio: int,
                 stride: int,
                 se_ratio: float,
                 drop_rate: float,
                 norm_layer: Callable[..., nn.Module]):
        super(FusedMBConv, self).__init__()
        assert stride in [1, 2]
        assert se_ratio == 0
        self.has_shortcut = stride == 1 and input_c == out_c
        self.drop_rate = drop_rate
        self.has_expansion = expand_ratio != 1
        activation_layer = nn.SiLU
        expanded_c = input_c * expand_ratio
        if self.has_expansion:
            # Fused conv does the spatial convolution at the expanded width...
            self.expand_conv = ConvBNAct(input_c, expanded_c,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         norm_layer=norm_layer,
                                         activation_layer=activation_layer)
            # ...followed by a linear 1x1 projection.
            self.project_conv = ConvBNAct(expanded_c, out_c,
                                          kernel_size=1,
                                          norm_layer=norm_layer,
                                          activation_layer=nn.Identity)
        else:
            # No expansion: one conv does everything (with activation).
            self.project_conv = ConvBNAct(input_c, out_c,
                                          kernel_size=kernel_size,
                                          stride=stride,
                                          norm_layer=norm_layer,
                                          activation_layer=activation_layer)
        self.out_channels = out_c
        if self.has_shortcut and drop_rate > 0:
            self.dropout = DropPath(drop_rate)

    def forward(self, x: Tensor) -> Tensor:
        if self.has_expansion:
            result = self.expand_conv(x)
            result = self.project_conv(result)
        else:
            result = self.project_conv(x)
        if self.has_shortcut:
            if self.drop_rate > 0:
                result = self.dropout(result)
            result += x
        return result
class EfficientNetV2(nn.Module):
    """EfficientNetV2 backbone plus classification head.

    Args:
        model_cnf: per-stage config rows of
            [repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio];
            operator 0 selects FusedMBConv, anything else MBConv.
        num_classes: classifier output size.
        num_features: output channels of the 1x1 head conv.
        dropout_rate: dropout before the classifier.
        drop_connect_rate: maximum stochastic-depth rate; each block gets a
            linearly scaled fraction of it.
    """

    def __init__(self, model_cnf: list,
                 num_classes: int = 1000,
                 num_features: int = 1280,
                 dropout_rate: float = 0.2,
                 drop_connect_rate: float = 0.2):
        super(EfficientNetV2, self).__init__()
        for cnf in model_cnf:
            assert len(cnf) == 8
        norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.1)

        # Stem: 3x3 stride-2 conv into the first stage's input width.
        stem_filter_num = model_cnf[0][4]
        self.stem = ConvBNAct(3, stem_filter_num,
                              kernel_size=3, stride=2,
                              norm_layer=norm_layer)

        total_blocks = sum(i[0] for i in model_cnf)
        block_id = 0
        blocks = []
        for cnf in model_cnf:
            repeats = cnf[0]
            op = FusedMBConv if cnf[-2] == 0 else MBConv
            for i in range(repeats):
                blocks.append(op(kernel_size=cnf[1],
                                 # only the first repeat changes channels/stride
                                 input_c=cnf[4] if i == 0 else cnf[5],
                                 out_c=cnf[5],
                                 expand_ratio=cnf[3],
                                 stride=cnf[2] if i == 0 else 1,
                                 se_ratio=cnf[-1],
                                 # stochastic depth grows linearly with depth
                                 drop_rate=drop_connect_rate * block_id / total_blocks,
                                 norm_layer=norm_layer))
                block_id += 1
        self.blocks = nn.Sequential(*blocks)

        head_input_c = model_cnf[-1][-3]
        head = OrderedDict()
        head.update({"project_conv": ConvBNAct(head_input_c,
                                               num_features,
                                               kernel_size=1,
                                               norm_layer=norm_layer)})
        head.update({"avgpool": nn.AdaptiveAvgPool2d(1)})
        head.update({"flatten": nn.Flatten()})
        # The original head stopped at flatten, leaving num_classes and
        # dropout_rate unused: add the classifier so forward() yields logits.
        if dropout_rate > 0:
            head.update({"dropout": nn.Dropout(p=dropout_rate, inplace=True)})
        head.update({"classifier": nn.Linear(num_features, num_classes)})
        self.head = nn.Sequential(head)

        # Weight initialization matching the reference implementation.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def forward(self, x: Tensor) -> Tensor:
        x = self.stem(x)
        x = self.blocks(x)
        x = self.head(x)
        return x
def efficientnetv2_s(num_classes: int = 1000):
    """Build EfficientNetV2-S.

    Config rows: repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio.
    """
    cfg = [
        [2, 3, 1, 1, 24, 24, 0, 0],
        [4, 3, 2, 4, 24, 48, 0, 0],
        [4, 3, 2, 4, 48, 64, 0, 0],
        [6, 3, 2, 4, 64, 128, 1, 0.25],
        [9, 3, 1, 6, 128, 160, 1, 0.25],
        [15, 3, 2, 6, 160, 256, 1, 0.25],
    ]
    return EfficientNetV2(model_cnf=cfg,
                          num_classes=num_classes,
                          dropout_rate=0.2)
def efficientnetv2_m(num_classes: int = 1000):
    """Build EfficientNetV2-M.

    Config rows: repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio.
    """
    cfg = [
        [3, 3, 1, 1, 24, 24, 0, 0],
        [5, 3, 2, 4, 24, 48, 0, 0],
        [5, 3, 2, 4, 48, 80, 0, 0],
        [7, 3, 2, 4, 80, 160, 1, 0.25],
        [14, 3, 1, 6, 160, 176, 1, 0.25],
        [18, 3, 2, 6, 176, 304, 1, 0.25],
        [5, 3, 1, 6, 304, 512, 1, 0.25],
    ]
    return EfficientNetV2(model_cnf=cfg,
                          num_classes=num_classes,
                          dropout_rate=0.3)
def efficientnetv2_l(num_classes: int = 1000):
    """Build EfficientNetV2-L.

    Config rows: repeat, kernel, stride, expansion, in_c, out_c, operator, se_ratio.
    """
    cfg = [
        [4, 3, 1, 1, 32, 32, 0, 0],
        [7, 3, 2, 4, 32, 64, 0, 0],
        [7, 3, 2, 4, 64, 96, 0, 0],
        [10, 3, 2, 4, 96, 192, 1, 0.25],
        [19, 3, 1, 6, 192, 224, 1, 0.25],
        [25, 3, 2, 6, 224, 384, 1, 0.25],
        [7, 3, 1, 6, 384, 640, 1, 0.25],
    ]
    return EfficientNetV2(model_cnf=cfg,
                          num_classes=num_classes,
                          dropout_rate=0.4)
Training script (train.py)
import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
from model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data,train_one_epoch,evaluate
def main(args):
    """Train EfficientNetV2-S on the class-per-folder dataset at args.data_path."""
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(args)
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        # Original created "./weight" but checkpoints below go to "./weights".
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = \
        read_split_data(args.data_path)

    # {variant: [train size, eval size]}
    img_size = {"s": [300, 384], "m": [384, 480], "l": [384, 480]}
    num_model = "s"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
                                   transforms.CenterCrop(img_size[num_model][1]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])
    print('using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != '':
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            # Keep only tensors whose element counts match the current model
            # (drops e.g. a pretrained classifier of a different size).
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weight file :{}".format(args.weights))

    if args.freeze_layers:
        # Freeze everything except the classification head.
        for name, para in model.named_parameters():
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1e-4)
    # Cosine decay from lr down to lr * lrf over args.epochs.
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)
        scheduler.step()
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)
        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)
        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)
    parser.add_argument('--data-path', type=str, default="/data/flower_photos")
    parser.add_argument('--weights', type=str, default='./efficientnet.pth',
                        help='initial weight path')
    # NOTE(review): type=bool on argparse treats any non-empty string as True;
    # kept as-is for compatibility with existing command lines.
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id')
    opt = parser.parse_args()
    main(opt)
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Image-classification dataset over explicit path and label lists.

    Args:
        images_path: list of image file paths.
        images_class: list of integer class labels, parallel to images_path.
        transform: optional callable applied to each PIL image.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # The training pipeline assumes 3-channel RGB input.
        if img.mode != 'RGB':
            raise ValueError("image:{} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        """Stack (image, label) pairs into batched tensors for a DataLoader."""
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
import os
import sys
import json
import pickle
import random

import matplotlib.pyplot as plt
from tqdm import tqdm
def read_split_data(root: str, val_rate: float = 0.2):
    """Scan a class-per-folder dataset and split it into train/val lists.

    Also writes the index->class-name mapping to ./class_indices.json.

    Args:
        root: dataset root; each subdirectory is one class.
        val_rate: fraction of each class sampled for validation.

    Returns:
        (train_images_path, train_images_label, val_images_path, val_images_label)
    """
    random.seed(0)  # reproducible split across runs
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # One class per subdirectory; sort so label indices are stable.
    flower_class = [cla for cla in os.listdir(root)
                    if os.path.isdir(os.path.join(root, cla))]
    flower_class.sort()
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    # json.dumps (not json.dump) returns the string we write out below.
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()),
                          indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    every_class_num = []
    supported = [".jpg", ".JPG", ".png", ".PNG"]

    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        # Random per-class validation subset.
        val_path = random.sample(images, k=int(len(images) * val_rate))
        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation .".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        plt.xticks(range(len(flower_class)), flower_class)
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of image')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Train for one epoch; returns (mean loss per step, accuracy)."""
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    accu_num = torch.zeros(1).to(device)   # accumulated correct predictions
    optimizer.zero_grad()
    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]
        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()
        loss = loss_function(pred, labels.to(device))
        loss.backward()
        accu_loss += loss.detach()
        data_loader.desc = "[train epoch{}]loss:{},acc:{}".format(
            epoch, accu_loss.item() / (step + 1), accu_num.item() / sample_num)
        # Abort on NaN/Inf loss before the optimizer can corrupt the weights.
        if not torch.isfinite(loss):
            print('warning non ending training,', loss)
            sys.exit(1)
        optimizer.step()
        optimizer.zero_grad()
    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
@torch.no_grad()
def evaluate(model, data_loader, device, epoch):
    """Evaluate over one pass of data_loader; returns (mean loss, accuracy)."""
    loss_function = torch.nn.CrossEntropyLoss()
    model.eval()
    accu_num = torch.zeros(1).to(device)   # accumulated correct predictions
    accu_loss = torch.zeros(1).to(device)  # accumulated loss
    sample_num = 0
    data_loader = tqdm(data_loader)
    for step, data in enumerate(data_loader):
        images, labels = data
        sample_num += images.shape[0]
        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()
        loss = loss_function(pred, labels.to(device))
        accu_loss += loss
        data_loader.desc = "[valid epoch{}]loss:{},acc:{}".format(
            epoch, accu_loss.item() / (step + 1), accu_num.item() / sample_num)
    return accu_loss.item() / (step + 1), accu_num.item() / sample_num
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import efficientnetv2_s as create_model


def main():
    """Classify a single image with a trained EfficientNetV2-S checkpoint."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # {variant: [train size, eval size]}; inference uses the eval size.
    img_size = {"s": [300, 384],
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "s"

    data_transform = transforms.Compose(
        [transforms.Resize(img_size[num_model][1]),
         transforms.CenterCrop(img_size[num_model][1]),
         transforms.ToTensor(),
         transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

    # Load image. NOTE(review): "typle.jpg" is likely a typo for "tulip.jpg" —
    # kept as-is; confirm against the actual file on disk.
    img_path = "./typle.jpg"
    assert os.path.exists(img_path), "file: {} does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)  # add batch dimension

    # Index -> class-name mapping written during training by read_split_data.
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: {} does not exist.".format(json_path)
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    model = create_model(num_classes=5).to(device)
    model_weight_path = "./weight.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].numpy())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()