1、网络的介绍参照下面
MobileNet系列(4):MobileNetv3网络详解_@BangBang的博客-CSDN博客_mobilenetv3
和下面的文字
https://www.toutiao.com/article/6972051587388785184/?app=news_article&timestamp=1665491325&use_new_style=1&req_id=202210112028440101320380301D113152&group_id=6972051587388785184&tt_from=mobile_qq&utm_source=mobile_qq&utm_medium=toutiao_android&utm_campaign=client_share&share_token=b18046b3-f26e-473e-b539-1fc65dfd400d&source=m_redirect
TorchVision v0.9 中新增了一系列移动端友好的模型,可用于处理分类、目标检测、语义分割等任务。
2、研究mobilenetv3最后的输出层
3、修改输出层的方法
3.1 输出层修改实现方法一
# Custom classification head: flattens the incoming feature map, expands the
# 576 features to 1024, applies Hardswish and dropout (p=0.2), and finally
# projects 1024 -> 1000 -> 2 outputs.
# NOTE(review): the last two Linear layers have no non-linearity between them,
# so together they act as a single linear map from 1024 to 2 — confirm this is
# intentional.
_head_layers = [
    nn.Flatten(),
    nn.Linear(576, 1024),
    nn.Hardswish(inplace=True),
    nn.Dropout(p=0.2, inplace=True),
    nn.Linear(1024, 1000),
    nn.Linear(1000, 2),
]
chen_set_mode_out = nn.Sequential(*_head_layers).to(device)
# Load MobileNetV3-Small with its default pretrained weights. An earlier
# version of this snippet used the older call form
# `mobilenet_v3_small(pretrained=True)`.
trained_model = mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)
# Note: out of the box the network produces a 1000-class output, e.g.
# [256, 1000] for a batch of 256 images.
# Method 1: keep every child module except the last one and append the custom
# head `chen_set_mode_out` in its place. That head flattens its input and its
# first Linear layer takes 576 features, so the truncated backbone is expected
# to emit [b, 576, 1, 1].
_backbone_children = list(trained_model.children())[:-1]
model = nn.Sequential(*_backbone_children, chen_set_mode_out).to(device)
3.2 输出层修改实现方法二
# Method 2: keep the pretrained network as-is and replace only the final layer
# of its classifier head.
print("陈打印最后的分类模块---------------------")
# Show every layer of the stock classifier before it is modified (the original
# snippet printed entries 0-3 one by one).
for _classifier_layer in trained_model.classifier:
    print(_classifier_layer)
# Swap classifier layer 3 for a 2-class Linear layer. This uses nn.Sequential's
# public indexing rather than the private `_modules` dict, and reads the input
# width from the layer being replaced instead of hard-coding 1024.
_old_output_layer = trained_model.classifier[3]
trained_model.classifier[3] = nn.Linear(_old_output_layer.in_features, 2)
# NOTE(review): unlike method 1, this snippet does not move the model to
# `device`; make sure that happens before training.
model = trained_model
4、代码实现
4.1方法一的代码
import torch
from torch import optim, nn
# import visdom
# from tensorboardX import SummaryWriter #(1)引入tensorboardX
from torch.utils.tensorboard import SummaryWriter
# import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
# from utils import Flatten
# from pokemon import Pokemon
# from resnet import ResNet18
# from torchvision.models import resnet18
from torchvision.models import mobilenet_v3_small
from PIL import Image
from tqdm import tqdm
from torchinfo import summary
import os
# Hyper-parameters for the run (an earlier variant used a batch size of 32).
batch_size = 256
lr = 0.001
epochs = 10
# Side length, in pixels, of the square image the transform pipeline crops to.
img_resize = 224
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
# For debugging, torch.cuda.device_count(), torch.cuda.current_device() and
# torch.cuda.get_device_name(0) report what CUDA hardware is visible.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Seed PyTorch's random number generator with a fixed value.
torch.manual_seed(1234)
# Image preprocessing pipeline; the input is expected to be a PIL.Image.
# Steps, in order:
#   1. resize to a square 1.25x larger than the target size,
#   2. rotate by a random angle (RandomRotation(15)),
#   3. centre-crop back down to img_resize x img_resize,
#   4. convert to a tensor,
#   5. normalise each channel with the mean/std values below
#      (these look like the usual ImageNet statistics — confirm).
# If samples arrive as file paths rather than images, open them first with
# Image.open(path).convert('RGB').
# A simpler earlier variant only resized to 224x224 and converted to a tensor.
_enlarged_side = int(img_resize * 1.25)
_channel_mean = [0.485, 0.456, 0.406]
_channel_std = [0.229, 0.224, 0.225]
tf = transforms.Compose([
    transforms.Resize((_enlarged_side, _enlarged_side)),
    transforms.RandomRotation(15),
    transforms.CenterCrop(img_resize),
    transforms.ToTensor(),
    transforms.Normalize(mean=_channel_mean, std=_channel_std),
])
# db = torchvision.datasets.ImageFolder(root=&#