简单的transfer learning(图像分类)

1.准备数据


使用requests下载数据,并把数据以标准形式储存起来——例如以名字为苹果的目录下,就只储存苹果的图片。
然后通过这些数据目录来创建dataloader
例如:

train_dataloader, test_dataloader,class_names = data_setup.create_dataloaders(train_dir=train_dir,
                                                                               test_dir=test_dir,
                                                                               transform=auto_trainsform,
                                                                               batch_size=32)

2.准备模型

1.导入预训练好的模型

weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
model = torchvision.models.efficientnet_b0(weights=weights).to(device)

2.冻结模型主干部分的参数(base layers)

for param in model.features.parameters():
    param.requires_grad = False

3.更改输出层(output_layer)

model.classifier = nn.Sequential(
    torch.nn.Dropout(p=0.2,inplace=True),
    torch.nn.Linear(in_features=1280,
                    out_features=output_shape,
                    bias=True).to(device)
)

3.开始训练

准备 loss_fn,optimizer

其它的略

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=0.001)

4.评估模型

把accuracy 和loss曲线画出来

5.应用模型

1.构建函数,接受输入的图片链接,输出对应的图片、预测的标签(label)、概率。

from modulefinder import IMPORT_NAME
from typing import List ,Tuple
from PIL import Image

def pred_and_plot_image(model:torch.nn.Module,
                        image_path:str,
                        class_names:List[str],
                        device:torch.device=device,
                        image_size:Tuple[int,int]= (224,224),
                        transform:torchvision.transforms = None):
  #open image
  img = Image.open(image_path)

  #transformation for image
  if transform:
    image_transform = transform
  else:
    image_transform=transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

  #go to device
  model.to(device)

  #evaluation mode
  with torch.inference_mode():
    transformed_image=image_transform(img).unsqueeze(dim=0)
    target_image_pred=model(transformed_image.to(device))

  p=torch.softmax(target_image_pred,dim=1)
  label=torch.argmax(p,dim=1)

  #Plot image with label and probability
  plt.figure()
  plt.imshow(img)
  plt.title(f"Pred: {class_names[label]} | Prob: {p.max():.3f}")
  plt.axis(False)

特别的,关于如何从提取链接,可以这么做:

from pathlib import Path
my_paths = list(Path().glob("微信图片_*.jpg"))

去年夏天,我做了一个饼(我知道很丑),但至少被成功识别出来啦!

  • 10
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值