model_v2.py
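t: 扩展因子(expand_ratio),即 bottleneck 中隐藏层通道数相对输入通道数的倍数
c: 输出特征矩阵的深度(channel),实际通道数还会乘以 alpha 并调整为 round_nearest 的整数倍
n: bottleneck 的重复次数
s: 步距(stride),只作用于每组重复中的第一个 bottleneck,其余 bottleneck 的步距为 1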
import torch
from torch import nn
def _make_divisible(ch, divisor=8, min_ch=None):
if min_ch is None:
min_ch = divisor
new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
if new_ch < 0.9 * ch:
new_ch += divisor
return new_ch
class ConvBNReLU(nn.Sequential):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and in-place ``ReLU6``.

    Args:
        in_channel: Input channels of the convolution.
        out_channel: Output channels of the convolution and width of the
            batch-norm layer.
        kerne_size: Convolution kernel size. The spelling is part of the
            public signature (it is passed by keyword), so it is kept.
        stride: Convolution stride.
        groups: Convolution groups.
    """

    def __init__(self, in_channel, out_channel, kerne_size=3, stride=1, groups=1):
        conv = nn.Conv2d(
            in_channel,
            out_channel,
            kernel_size=kerne_size,
            stride=stride,
            # (k - 1) // 2 keeps the spatial size unchanged for odd kernels at stride 1.
            padding=(kerne_size - 1) // 2,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm2d(out_channel)
        activation = nn.ReLU6(inplace=True)
        super().__init__(conv, norm, activation)
class InvertedResidual(nn.Module):
    """Bottleneck block: optional 1x1 expansion, 3x3 depthwise conv, 1x1 projection.

    The projection (``Conv2d`` + ``BatchNorm2d``) has no activation after it.
    The input is added to the block output only when ``stride == 1`` and the
    input and output channel counts match.

    Args:
        in_channel: Channels of the incoming feature map.
        out_channel: Channels produced by the block.
        stride: Stride of the depthwise convolution.
        expand_ratio: Multiplier giving the hidden width
            (``in_channel * expand_ratio``); the expansion conv is skipped
            when it equals 1.
    """

    def __init__(self, in_channel, out_channel, stride, expand_ratio) -> None:
        super().__init__()
        expanded = in_channel * expand_ratio
        self.use_shortcut = in_channel == out_channel and stride == 1

        stages = []
        if expand_ratio != 1:
            # Pointwise conv that widens the features to the hidden width.
            stages.append(ConvBNReLU(in_channel, expanded, kerne_size=1))
        # Depthwise conv: one group per hidden channel.
        stages.append(ConvBNReLU(expanded, expanded, stride=stride, groups=expanded))
        # Pointwise projection back to the output width.
        stages.append(nn.Conv2d(expanded, out_channel, kernel_size=1, bias=False))
        stages.append(nn.BatchNorm2d(out_channel))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the block, adding ``x`` back when the shortcut is enabled."""
        out = self.conv(x)
        return x + out if self.use_shortcut else out
class MobileNetV2(nn.Module):
    """MobileNetV2 classifier assembled from ``ConvBNReLU`` and ``InvertedResidual``.

    Args:
        num_classes: Output size of the final linear layer.
        alpha: Width multiplier applied to every channel count.
        round_nearest: Channel counts are rounded to a multiple of this value.
    """

    def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8):
        super().__init__()
        block = InvertedResidual
        in_ch = _make_divisible(32 * alpha, round_nearest)
        head_ch = _make_divisible(1280 * alpha, round_nearest)

        # One row per stage: (t, c, n, s) =
        # (expand ratio, output channels, repeat count, stride of first block).
        stage_settings = (
            (1, 16, 1, 1),
            (6, 24, 2, 2),
            (6, 32, 3, 2),
            (6, 64, 4, 2),
            (6, 96, 3, 1),
            (6, 160, 3, 2),
            (6, 320, 1, 1),
        )

        # Stem: 3-channel input, stride-2 3x3 conv.
        layers = [ConvBNReLU(3, in_ch, stride=2)]
        for t, c, n, s in stage_settings:
            out_ch = _make_divisible(c * alpha, round_nearest)
            # Only the first block of a stage uses stride ``s``; the rest use 1.
            for stride in (s,) + (1,) * (n - 1):
                layers.append(block(in_ch, out_ch, stride, expand_ratio=t))
                in_ch = out_ch
        # Final 1x1 conv up to the classifier width.
        layers.append(ConvBNReLU(in_ch, head_ch, 1))

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(head_ch, num_classes),
        )

        # Weight initialisation, visited in ``self.modules()`` order.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
            else:
                continue
            # Every conv in this model is bias-free; BN and Linear biases start at zero.
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for image batch ``x``."""
        pooled = self.avgpool(self.features(x))
        flat = torch.flatten(pooled, 1)
        return self.classifier(flat)
train.py
迁移学习使用的 MobileNetV2 预训练参数下载地址:
url: https://download.pytorch.org/models/mobilenet_v2-b0353104.pth
下载后将文件重命名为 mobilenet_v2.pth,并放在运行 train.py 的目录下(代码读取的路径是 ./mobilenet_v2.pth)
import os
import sys
import json
import torch
import torch.nn as nn
from torch.utils import data
import torch.optim as optm
from torchvision import transforms,datasets
from tqdm import tqdm
from model_v2 import MobileNetV2
def main():
    """Fine-tune the MobileNetV2 classifier head on the flower dataset.

    Steps:
      1. Build train/val ``ImageFolder`` datasets under
         ``../data/data_set/flower_data`` and write the index -> class-name
         mapping to ``class_indices.json``.
      2. Create a 5-class ``MobileNetV2`` and copy in every tensor from
         ``./mobilenet_v2.pth`` whose key exists in the model and whose
         element count matches (the 1000-class classifier is thereby skipped).
      3. Freeze ``net.features`` and train the remaining parameters with Adam.
      4. After each epoch, evaluate on the validation set and save the
         state dict to ``./MobileNetV2.pth`` whenever accuracy improves.

    Raises:
        FileNotFoundError: If the dataset directory or the pretrained weight
            file is missing.
    """
    batch_size, epochs = 16, 5

    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    data_transform = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    data_root = '../data'
    image_path = os.path.join(data_root, 'data_set', 'flower_data')
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not os.path.exists(image_path):
        raise FileNotFoundError('{} path does not exist.'.format(image_path))

    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                         transform=data_transform['train'])
    train_num = len(train_dataset)

    # Persist index -> class name so prediction can map outputs back to labels.
    cla_dict = {idx: name for name, idx in train_dataset.class_to_idx.items()}
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json.dumps(cla_dict, indent=4))

    # os.cpu_count() may return None; fall back to loading in the main process.
    num_workers = os.cpu_count() or 0
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                               shuffle=True, num_workers=num_workers)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'),
                                            transform=data_transform['val'])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                                  shuffle=False, num_workers=num_workers)
    print("using {} images for training, {} images for validation.".format(train_num, val_num))

    net = MobileNetV2(num_classes=5)

    model_weight_path = './mobilenet_v2.pth'
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError('file {} dose not exist.'.format(model_weight_path))
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    pre_weights = torch.load(model_weight_path, map_location='cpu')
    # Build the model's state dict once, and skip checkpoint keys the model
    # does not have instead of failing with KeyError.
    model_state = net.state_dict()
    pre_dict = {
        key: value
        for key, value in pre_weights.items()
        if key in model_state and model_state[key].numel() == value.numel()
    }
    net.load_state_dict(pre_dict, strict=False)

    # Freeze the feature extractor; only parameters outside ``features`` are trained.
    for param in net.features.parameters():
        param.requires_grad = False

    loss_function = nn.CrossEntropyLoss()
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=0.001)

    best_acc = 0.0
    save_path = './MobileNetV2.pth'
    train_steps = len(train_loader)

    for epoch in range(epochs):
        # ---- training ----
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            logits = net(images)
            loss = loss_function(logits, labels)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value
            train_bar.desc = 'train epoch[{}/{}] loss{:.3f}'.format(epoch + 1, epochs, loss_value)

        # ---- validation ----
        net.eval()
        acc = 0.0
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images)
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels).sum().item()
                val_bar.desc = 'valid epoch[{}/{}]'.format(epoch + 1, epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('finished Training')
# Start training only when this file is run as a script, not when imported.
if __name__ == '__main__':
    main()
训练过程中验证集准确率最高的一轮模型参数会保存到 ./MobileNetV2.pth,predict.py 从该文件加载权重
predict.py
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model_v2 import MobileNetV2
def main():
    """Classify ``tulip.jpeg`` with the fine-tuned MobileNetV2 and plot the result.

    Loads the image, applies the same resize / center-crop / normalize
    pipeline used for validation, reads the index -> class-name mapping from
    ``./class_indices.json``, restores the weights from ``./MobileNetV2.pth``,
    prints the probability of every class, and shows the image titled with
    the top prediction.

    Raises:
        FileNotFoundError: If the image, the class-index JSON, or the weight
            file is missing.
    """
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    img_path = 'tulip.jpeg'
    # Explicit raise instead of assert: asserts are stripped under ``python -O``.
    if not os.path.exists(img_path):
        raise FileNotFoundError('file:{} does not exist.'.format(img_path))
    # Force 3 channels: the Normalize step above is configured for RGB, so
    # grayscale / palette / RGBA files would otherwise fail.
    img = Image.open(img_path).convert('RGB')
    plt.imshow(img)
    img = data_transform(img)
    # Add the batch dimension: (C, H, W) -> (1, C, H, W).
    img = torch.unsqueeze(img, dim=0)

    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError('file:{} dose not exist'.format(json_path))
    with open(json_path, 'r') as f:
        class_indict = json.load(f)
    print(class_indict)

    model = MobileNetV2(num_classes=5)
    model_weight_path = './MobileNetV2.pth'
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError('file:{} does not exist.'.format(model_weight_path))
    # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
    # host; the model itself stays on the CPU here.
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    model.eval()

    with torch.no_grad():
        output = torch.squeeze(model(img))
        print(output)
        predict = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(predict))

    # JSON object keys are strings, hence str(index) for the lookup.
    print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                               predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                                prob.item()))
    plt.show()
# Run the prediction only when this file is executed as a script.
if __name__ == '__main__':
    main()