PyTorch Project Examples (3): Image Classification with a General-Purpose Classification Model (with Code and Usage Instructions)


Background: we need to load images from a dataset and train against their labels. The simple approach is to put the images into a folder, generate labels and training data, then train the model, i.e., implementing the transition from raw images to a trained model.

Code (runs as-is), GitHub: https://github.com/Xingxiangrui/image_classification_with_pytorch

You can also copy the code from Section 4 directly.

Contents

1. Running and debugging with a small sample

1.1 Dataset

1.2 Label format

1.3 Running on a minimal dataset

1.4 Possible errors and fixes (skip if none occur)

2. Generating a large dataset

2.1 Image layout

2.2 Label format

2.3 Batch generation of data and labels

Reading image names

Shuffling the list

Splitting one list into two

Reading and writing images

Generating labels

3. Training and validation

3.1 Loading data

3.2 Data loading functions

3.3 Model training and validation

3.4 Training and saving the model

4. Code


1. Running and debugging with a small sample

First run the program on a small sample to make sure it runs correctly, then move to the full-size experiment. The complete code is at the end.

1.1 Dataset

The code loads the dataset as follows:

    print("Load dataset......")
    image_datasets = {x: customData(img_path='data/',
                                    txt_path=('data/TxtFile/' + x + '.txt'),
                                    data_transforms=data_transforms,
                                    dataset=x) for x in ['train', 'val']}

1.2 Label format

Location: under the data/TxtFile/ folder.

Two files: train.txt and val.txt.

Each line of these files is an image path and label, e.g. *****121.jpg (tab) 0 (the label)

That is: image name, a tab, then the label.

Note: labels must start from 0, not 1, or the program will raise an error.
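Since a wrong label range only surfaces later as a CUDA assert, it can help to validate the txt file up front. A minimal sketch (the helper name is made up for this demo):

```python
# check_labels.py - verify labels are 0-based and contiguous (illustrative sketch)
def check_labels(lines):
    labels = sorted({int(line.strip().split("\t")[-1]) for line in lines})
    # labels must start at 0 and have no gaps, or CUDA error 59 may fire later
    return labels[0] == 0 and labels == list(range(len(labels)))

print(check_labels(["1.jpg\t0\n", "2.jpg\t1\n"]))   # well-formed file
print(check_labels(["1.jpg\t1\n", "2.jpg\t2\n"]))   # labels start at 1: bad
```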

1.3 Running on a minimal dataset

Put two files, 1.jpg and 2.jpg, directly into the data folder.

In the subfolder data/TxtFile, create two files: train.txt and val.txt.

Both files contain:

1.jpg	0
2.jpg	1

Then run python customData_train.py from the repository root and it should work.
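The minimal layout above can be created from the shell, assuming 1.jpg and 2.jpg are already in data/ (a sketch of the steps, not a script from the repo):

```shell
# build the minimal dataset layout used by the article
mkdir -p data/TxtFile
printf '1.jpg\t0\n2.jpg\t1\n' > data/TxtFile/train.txt
cp data/TxtFile/train.txt data/TxtFile/val.txt
# then, from the repo root:
# python customData_train.py
```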

1.4 Possible errors and fixes (skip if none occur)

Errors caused by the Python / torch version:

/home/xingxiangrui/env/lib/python3.6/site-packages/torchvision/transforms/transforms.py:563: UserWarning: The use of the transforms.RandomSizedCrop transform is deprecated, please use transforms.RandomResizedCrop instead.
"please use transforms.RandomResizedCrop instead.")
/home/xingxiangrui/env/lib/python3.6/site-packages/torchvision/transforms/transforms.py:188: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
"please use transforms.Resize instead.")

Fix: replace the two deprecated transforms with the suggested names (RandomResizedCrop and Resize).

RuntimeError: cuda runtime error (59) : device-side assert triggered at /home/lychee/mycode/pytorch/aten/src/THC/generic/THCTensorMath.cu:24

Fix: labels must start from 0, not 1, and preferably without gaps.

return loss.data[0]
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

Fix: change the statement in the program:

# original statement:
train_loss+=loss.data[0]
# after the fix:
train_loss+=loss.item()
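To see why .item() is needed: since torch 0.4 the loss is a 0-dim tensor, and indexing it with [0] raises the IndexError above. A quick illustration:

```python
import torch

loss = torch.tensor(0.25)   # a 0-dim tensor, like the loss returned by the criterion
print(loss.dim())           # 0: there is no index 0 to take
print(loss.item())          # .item() extracts the plain Python number
```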

2. Generating a large dataset

2.1 Image layout

    print("Load dataset......")
    image_datasets = {x: customData(img_path='data/',
                                    txt_path=('data/TxtFile/' + x + '.txt'),
                                    data_transforms=data_transforms,
                                    dataset=x) for x in ['train', 'val']}

The data/ folder holds the images.

data/TxtFile holds train.txt and val.txt.

2.2 Label format

train.txt holds the names and labels of the training images, and val.txt those of the validation set.

Format: one sample per line, file_name.jpg (tab) label_number

class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader):
        with open(txt_path) as input_file:
            lines = input_file.readlines()
            self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
            self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

For example:

1.jpg	0
2.jpg	1

2.3 Batch generation of data and labels

For large amounts of data, see also: Python in practice (2) on dataset construction: changing labels and batch image processing

Reading image names

source_image_list = os.listdir(source_image_dir)
# keep only .png / .jpg files; filtering into a new list avoids
# deleting elements while indexing into the same list
source_image_list = [name for name in source_image_list
                     if name.endswith('.png') or name.endswith('.jpg')]

Note: the original version deleted elements inside a loop over the indices; deleting while iterating shifts the remaining elements and skips entries, so building a new list is safer.

Shuffling the list

See: https://blog.csdn.net/amy_wumanfang/article/details/64483340

https://blog.csdn.net/matrix_google/article/details/72803741

random.shuffle() shuffles a list in place, which is very convenient.

For example:

# -*- coding: utf-8 -*-
import random
# shuffle the list in place (avoid shadowing the builtin name list)
items = list(range(10))
print(items)
random.shuffle(items)
print("shuffled list:", items)

Splitting one list into two

Split into a validation set and a training set, 1/4 and 3/4 respectively.

# train list and val list: first quarter for validation, rest for training
split = len(source_image_list) // 4
source_val_list = source_image_list[:split]
source_train_list = source_image_list[split:]
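Putting shuffle and split together, a small self-contained sketch (the file names and the fixed seed are only for the demo):

```python
import random

# illustrative file names; seed fixed so the demo is reproducible
names = ["img_%02d.jpg" % i for i in range(8)]
random.seed(0)
random.shuffle(names)

split = len(names) // 4          # 1/4 for validation
source_val_list = names[:split]
source_train_list = names[split:]
print(len(source_val_list), len(source_train_list))
```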

Reading and writing images

Each image is read into src_img, renamed, and then written out with save.

    # read source images and rename
    path_source_img = os.path.join(source_image_dir, source_image_name)
    src_img = Image.open(path_source_img)
    full_image_name=prefix+"_train_"+source_image_name
    print(full_image_name)
    # save renamed image to the target dir
    target_image_path=os.path.join(target_image_dir, full_image_name)
    src_img.save(target_image_path)

Generating labels

Create (or open) the files; mode "a" opens for appending.

# create label_file or write label file
txt_file_train_name="train.txt"
txt_file_val_name="val.txt"
txt_file_train_path=os.path.join(txt_file_dir, txt_file_train_name)
txt_file_val_path=os.path.join(txt_file_dir, txt_file_val_name)
train_txt_file= open(txt_file_train_path,"a")
val_txt_file= open(txt_file_val_path,"a")

Each line must be terminated with a "\n".

    # write image names and labels
    line_strings= full_image_name+"\t"+str(class_label)+"\n"
    train_txt_file.write(line_strings)
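The written lines and the parsing in customData must agree; a quick round-trip check (the image name is made up for the demo):

```python
# a line as written by the generator, parsed the way customData parses it
full_image_name = "poly_OK_train_sample_001.jpg"   # hypothetical name
class_label = 1
line = full_image_name + "\t" + str(class_label) + "\n"

name, label = line.strip().split("\t")
print(name, int(label))
```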

3. Training and validation

3.1 Loading data

Each line of the txt file gives one image and its label, which are then loaded for training.

    data_transforms = {
        'train': transforms.Compose([
            #transforms.RandomSizedCrop(224),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            #transforms.Scale(256),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    use_gpu = torch.cuda.is_available()

    batch_size = 32
    num_class = 3
    print("batch size:",batch_size,"num_classes:",num_class)

    print("Load dataset......")
    # image_datasets = {x: customData(img_path='sin_poly_defect_data/',
    #                                 txt_path=('sin_poly_defect_data/TxtFile/general_train.txt'),
    #                                 data_transforms=data_transforms,
    #                                 dataset=x) for x in ['train', 'total_val']}
    image_datasets={}
    image_datasets['train'] = customData(img_path='sin_poly_defect_data/',
                                         txt_path=('sin_poly_defect_data/TxtFile/general_train.txt'),
                                         data_transforms=data_transforms,
                                         dataset='train')
    image_datasets['val'] = customData(img_path='sin_poly_defect_data/',
                                       txt_path=('sin_poly_defect_data/TxtFile/real_poly_defect.txt'),
                                       data_transforms=data_transforms,
                                       dataset='val')
    # train_data=image_datasets.pop('general_train')
    # image_datasets['train']=train_data
    # val_data=image_datasets.pop('total_val')
    # image_datasets['val']=val_data

    # wrap your data and label into Tensor
    print("wrap data into Tensor......")
    dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=batch_size,
                                                 shuffle=True) for x in ['train', 'val']}

    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    print("total dataset size:",dataset_sizes)

3.2 Data loading functions

The data is loaded by the following functions:

def default_loader(path):
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except:
        print("Cannot read image: {}".format(path))

# define your Dataset. Assume each line in your .txt file is [name/tab/label], for example:0001.jpg 1
class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader):
        with open(txt_path) as input_file:
            lines = input_file.readlines()
            self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
            self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)

        if self.data_transforms is not None:
            try:
                img = self.data_transforms[self.dataset](img)
            except:
                print("Cannot transform image: {}".format(img_name))
        return img, label
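A standalone sanity check of the dataset class (the class body is repeated here so the snippet runs on its own; the temp directory and file names are only for the demo):

```python
import os
import tempfile

from PIL import Image
from torch.utils.data import Dataset

def default_loader(path):
    return Image.open(path).convert('RGB')

class customData(Dataset):
    # same logic as in the article, repeated so this snippet is self-contained
    def __init__(self, img_path, txt_path, dataset='', data_transforms=None, loader=default_loader):
        with open(txt_path) as input_file:
            lines = input_file.readlines()
            self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
            self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img = self.loader(self.img_name[item])
        if self.data_transforms is not None:
            img = self.data_transforms[self.dataset](img)
        return img, self.img_label[item]

# build a one-image dataset in a temp dir (names are made up for the demo)
tmp = tempfile.mkdtemp()
Image.new("RGB", (8, 8)).save(os.path.join(tmp, "1.jpg"))
with open(os.path.join(tmp, "train.txt"), "w") as f:
    f.write("1.jpg\t0\n")

ds = customData(img_path=tmp, txt_path=os.path.join(tmp, "train.txt"), dataset="train")
img, label = ds[0]
print(len(ds), label, img.size)
```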

3.3 Model training and validation

Define the loss function and optimizer:

    print("Define loss function and optimizer......")
    # define cost function
    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.005, momentum=0.9)

    # Decay LR by a factor of 0.2 every 5 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.2)

    # multi-GPU
    model_ft = torch.nn.DataParallel(model_ft, device_ids=[0])
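CrossEntropyLoss expects raw logits and 0-based class indices, which is why labels starting at 1 trigger the device-side assert mentioned earlier. A small check (with all-zero logits the loss is exactly log(num_class)):

```python
import math

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.zeros(2, 3)       # uniform scores over num_class = 3
labels = torch.tensor([0, 2])    # class indices must be 0-based
loss = criterion(logits, labels)
print(loss.item())               # log(3) for uniform logits
```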

3.4 Training and saving the model

    # train model
    print("start train_model......")
    model_ft = train_model(model=model_ft,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=15,
                           use_gpu=use_gpu)

    # save best model
    print("save model......")
    torch.save(model_ft,"output/resnet_on_PV_best_total_val.pkl")
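torch.save(model, ...) pickles the whole model object. An alternative sketch that saves only the state_dict, which tends to be more portable across code changes (the tiny model and file path are illustrative):

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
path = os.path.join(tempfile.mkdtemp(), "model_state.pkl")
torch.save(model.state_dict(), path)   # weights only, no class pickling

model2 = nn.Linear(4, 2)               # rebuild the architecture, then load weights
model2.load_state_dict(torch.load(path))
print(torch.equal(model.weight, model2.weight))
```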

4. Code

github地址:https://github.com/Xingxiangrui/image_classification_with_pytorch

# -*- coding: utf-8 -*
"""
Created by Xingxiangrui on 2019.5.9
This code is to :
    1. copy image from source_image_dir to the target_image_dir
    2. And generate .txt file for further training
        in which each line is : image_name.jpg  (tab)  image_label (from 0)
        such as:
            image_01.jpg    0
            image_02.jpg    1
            ...
            image_02.jpg    0

"""

# import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
import random

# variables that need to be changed
source_image_dir="/Users/Desktop/used/SuZhouRuiTu_dataset/single-poly-defect/poly_OK"
target_image_dir="/Users/Desktop/used/SuZhouRuiTu_dataset/data_for_resnet_classification"
txt_file_dir="/Users/Desktop/used/SuZhouRuiTu_dataset/data_for_resnet_classification/TxtFile"
prefix="poly_OK"
class_label=1
# label 0: single_OK ; label_1: poly_OK ; label 2: poly_defect

print("Program Start......")
print("-"*20)
print("-"*20)
print("-"*20)

# load image list in the source dir
source_image_list = os.listdir(source_image_dir)
# keep only .png / .jpg files; filtering into a new list avoids
# deleting elements while indexing into the same list
source_image_list = [name for name in source_image_list
                     if name.endswith('.png') or name.endswith('.jpg')]

# shuffle image list
print("initial list:")
print(source_image_list)
random.shuffle(source_image_list)
print("shuffled list:")
print(source_image_list)

# train list and val list: first quarter for validation, rest for training
split = len(source_image_list) // 4
source_val_list = source_image_list[:split]
source_train_list = source_image_list[split:]
print("train_list")
print(source_train_list)
print("val_list")
print(source_val_list)

# create label_file or write label file
txt_file_train_name="train.txt"
txt_file_val_name="val.txt"
txt_file_train_path=os.path.join(txt_file_dir, txt_file_train_name)
txt_file_val_path=os.path.join(txt_file_dir, txt_file_val_name)
train_txt_file= open(txt_file_train_path,"a")
val_txt_file= open(txt_file_val_path,"a")

# write train images and labels
print("write train images and labels......")
for source_image_name in source_train_list:
    print(source_image_name)

    # read source images and rename
    path_source_img = os.path.join(source_image_dir, source_image_name)
    src_img = Image.open(path_source_img)
    full_image_name=prefix+"_train_"+source_image_name
    print(full_image_name)
    # save renamed image to the target dir
    target_image_path=os.path.join(target_image_dir, full_image_name)
    src_img.save(target_image_path)
    # write image names and labels
    line_strings= full_image_name+"\t"+str(class_label)+"\n"
    train_txt_file.write(line_strings)

# write val images and labels
print("write val images and labels......")
for source_image_name in source_val_list:
    print(source_image_name)

    # read source images and rename
    path_source_img = os.path.join(source_image_dir, source_image_name)
    src_img = Image.open(path_source_img)
    full_image_name=prefix+"_val_"+source_image_name
    print(full_image_name)
    # save renamed image to the target dir
    target_image_path=os.path.join(target_image_dir, full_image_name)
    src_img.save(target_image_path)
    # write image names and labels
    line_strings= full_image_name+"\t"+str(class_label)+"\n"
    val_txt_file.write(line_strings)

print("source_image_dir:")
print(source_image_dir)
print("target_image_dir:")
print(target_image_dir)
print("prefix:")
print(prefix)
print("label:")
print(class_label)
print("image numbers:")
print(len(source_image_list))


'''
import numpy as np
from PIL import Image
import os
import random

# variables that need to be changed
source_image_dir="/Users/Desktop/used/SuZhouRuiTu_dataset/single-poly-defect/poly_defect_gen"
target_image_dir="/Users/Desktop/used/SuZhouRuiTu_dataset/data_for_resnet_classification"
txt_file_dir="/Users/Desktop/used/SuZhouRuiTu_dataset/data_for_resnet_classification/TxtFile"
prefix="gen_poly_defect"
class_label=2
# label 0: single_OK ; label_1: poly_OK ; label 2: poly_defect

print("Program Start......")
print("-"*20)
print("-"*20)
print("-"*20)

# load image list in the source dir
source_image_list = os.listdir(source_image_dir)
# keep only .png / .jpg files
source_image_list = [name for name in source_image_list
                     if name.endswith('.png') or name.endswith('.jpg')]



# create label_file or write label file
txt_file_train_name="train.txt"
# txt_file_val_name="val.txt"
txt_file_train_path=os.path.join(txt_file_dir, txt_file_train_name)
# txt_file_val_path=os.path.join(txt_file_dir, txt_file_val_name)
train_txt_file= open(txt_file_train_path,"a")
# val_txt_file= open(txt_file_val_path,"a")

# write train images and labels
print("write train images and labels......")
for source_image_name in source_image_list:
    print(source_image_name)

    # read source images and rename
    path_source_img = os.path.join(source_image_dir, source_image_name)
    src_img = Image.open(path_source_img)
    full_image_name=prefix+"_train_"+source_image_name
    print(full_image_name)
    # save renamed image to the target dir
    target_image_path=os.path.join(target_image_dir, full_image_name)
    src_img.save(target_image_path)
    # write image names and labels
    line_strings= full_image_name+"\t"+str(class_label)+"\n"
    train_txt_file.write(line_strings)

print("source_image_dir:")
print(source_image_dir)
print("target_image_dir:")
print(target_image_dir)
print("prefix:")
print(prefix)
print("label:")
print(class_label)

print("image numbers:")
print(len(source_image_list))

'''
# -*- coding: utf-8 -*
"""
created by xingxiangrui on 2019.5.9
This is the PyTorch code for image classification
python 3.6 and torch 0.4.1 are OK

dataset mode:
    folder data in which is jpg images
    folder data/TxtFile/ in which is train.txt and val.txt
        in train.txt each line :
        image_name.jpg  (tab)  image_label (from 0)
        such as:
        image_01.jpg    0
        image_02.jpg    1
        ...
        image_02.jpg    0

"""
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.autograd import Variable
from torchvision import  models, transforms
import time
import os
from torch.utils.data import Dataset
from PIL import Image

# use PIL Image to read image
def default_loader(path):
    try:
        img = Image.open(path)
        return img.convert('RGB')
    except:
        print("Cannot read image: {}".format(path))

# define your Dataset. Assume each line in your .txt file is [name/tab/label], for example:0001.jpg 1
class customData(Dataset):
    def __init__(self, img_path, txt_path, dataset = '', data_transforms=None, loader = default_loader):
        with open(txt_path) as input_file:
            lines = input_file.readlines()
            self.img_name = [os.path.join(img_path, line.strip().split('\t')[0]) for line in lines]
            self.img_label = [int(line.strip().split('\t')[-1]) for line in lines]
        self.data_transforms = data_transforms
        self.dataset = dataset
        self.loader = loader

    def __len__(self):
        return len(self.img_name)

    def __getitem__(self, item):
        img_name = self.img_name[item]
        label = self.img_label[item]
        img = self.loader(img_name)

        if self.data_transforms is not None:
            try:
                img = self.data_transforms[self.dataset](img)
            except:
                print("Cannot transform image: {}".format(img_name))
        return img, label

def train_model(model, criterion, optimizer, scheduler, num_epochs, use_gpu):
    since = time.time()

    best_model_wts = model.state_dict()
    best_acc = 0.0

    for epoch in range(num_epochs):
        begin_time = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            count_batch = 0
            if phase == 'train':
                scheduler.step()
                model.train(True)  # Set model to training mode
            else:
                model.train(False)  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0.0

            # Iterate over data.
            for data in dataloders[phase]:
                count_batch += 1
                # get the inputs
                inputs, labels = data

                # wrap them in Variable
                if use_gpu:
                    inputs = Variable(inputs.cuda())
                    labels = Variable(labels.cuda())
                else:
                    inputs, labels = Variable(inputs), Variable(labels)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                outputs = model(inputs)
                _, preds = torch.max(outputs.data, 1)
                loss = criterion(outputs, labels)

                # backward + optimize only if in training phase
                if phase == 'train':
                    loss.backward()
                    optimizer.step()
                # statistics
                running_loss += loss.item()
                running_corrects += torch.sum(preds == labels.data).to(torch.float32)

                # print result every 10 batch
                if count_batch%10 == 0:
                    batch_loss = running_loss / (batch_size*count_batch)
                    batch_acc = running_corrects / (batch_size*count_batch)
                    print('{} Epoch [{}] Batch [{}] Loss: {:.4f} Acc: {:.4f} Time: {:.4f}s'. \
                          format(phase, epoch, count_batch, batch_loss, batch_acc, time.time()-begin_time))
                    begin_time = time.time()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # save model
            if phase == 'train':
                if not os.path.exists('output'):
                    os.makedirs('output')
                torch.save(model, 'output/resnet_on_PV_epoch{}.pkl'.format(epoch))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = model.state_dict()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

if __name__ == '__main__':
    print("Program start","-"*10)

    print("Init data transforms......")
    data_transforms = {
        'train': transforms.Compose([
            #transforms.RandomSizedCrop(224),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            #transforms.Scale(256),
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    use_gpu = torch.cuda.is_available()

    batch_size = 32
    num_class = 3
    print("batch size:",batch_size,"num_classes:",num_class)

    print("Load dataset......")
    # image_datasets = {x: customData(img_path='sin_poly_defect_data/',
    #                                 txt_path=('sin_poly_defect_data/TxtFile/general_train.txt'),
    #                                 data_transforms=data_transforms,
    #                                 dataset=x) for x in ['train', 'total_val']}
    image_datasets={}
    image_datasets['train'] = customData(img_path='sin_poly_defect_data/',
                                         txt_path=('sin_poly_defect_data/TxtFile/general_train.txt'),
                                         data_transforms=data_transforms,
                                         dataset='train')
    image_datasets['val'] = customData(img_path='sin_poly_defect_data/',
                                       txt_path=('sin_poly_defect_data/TxtFile/real_poly_defect.txt'),
                                       data_transforms=data_transforms,
                                       dataset='val')
    # train_data=image_datasets.pop('general_train')
    # image_datasets['train']=train_data
    # val_data=image_datasets.pop('total_val')
    # image_datasets['val']=val_data

    # wrap your data and label into Tensor
    print("wrap data into Tensor......")
    dataloders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                                 batch_size=batch_size,
                                                 shuffle=True) for x in ['train', 'val']}

    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    print("total dataset size:",dataset_sizes)

    # get model and replace the original fc layer with your fc layer
    print("get resnet model and replace last fc layer...")
    model_ft = models.resnet50(pretrained=True)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, num_class)

    # if use gpu
    print("Use gpu:",use_gpu)
    if use_gpu:
        model_ft = model_ft.cuda()

    print("Define loss function and optimizer......")
    # define cost function
    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.005, momentum=0.9)

    # Decay LR by a factor of 0.2 every 5 epochs
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=5, gamma=0.2)

    # multi-GPU
    model_ft = torch.nn.DataParallel(model_ft, device_ids=[0])

    # train model
    print("start train_model......")
    model_ft = train_model(model=model_ft,
                           criterion=criterion,
                           optimizer=optimizer_ft,
                           scheduler=exp_lr_scheduler,
                           num_epochs=15,
                           use_gpu=use_gpu)
    # save best model
    print("save model......")
    torch.save(model_ft,"output/resnet_on_PV_best_total_val.pkl")