Pytorch resnet花朵识别(5种花)附完整代码

使用PyTorch的ResNet模型进行花卉图像识别,通过重写dataset类并调整参数来应对有限的GPU资源。尽管测试集准确率有待提高,可通过扩大数据集和增加训练迭代次数来提升。
摘要由CSDN通过智能技术生成

notebook运行结果图:
在这里插入图片描述

随机从各种花的图片集中抽取一定数量的图片
因为设备限制,用所有的图片,图片太多
如果用cpu跑太慢了
用gpu太多图片会out of memory
所以用小部分图片,了解方法就好了
在这里插入图片描述
在这里插入图片描述
重写dataset类
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

显示图片
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

设置resnet的参数
因为用的损失函数是NLL,所以网络最后要接一个logsoftmax
在这里插入图片描述
损失函数、优化器,训练函数
在这里插入图片描述
计算正确率函数
在这里插入图片描述

在这里插入图片描述

在这里插入图片描述
测试集正确率不高,可以增大数据集
增加训练次数

在这里插入图片描述
在这里插入图片描述
源代码:

# 要添加一个新单元,输入 '# %%'
# 要添加一个新的标记单元,输入 '# %% [markdown]'
# %%
from IPython import get_ipython

# %%
# -*- coding: utf-8 -*-
get_ipython().run_line_magic('matplotlib', 'inline')
import os
data_path='flowers'
flower_type=os.listdir(data_path)
print(flower_type)


# %%
import matplotlib.pyplot as plt
from matplotlib import style
import seaborn as sns
style.use('fivethirtyeight')
sns.set(style='whitegrid',color_codes=True)

import numpy as np
import torch 
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms,models
from torch.utils.data import Dataset,DataLoader

from pprint import pprint

import random

from sklearn.preprocessing import OneHotEncoder,LabelEncoder
from sklearn.model_selection import train_test_split

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from tqdm.notebook import tqdm
from tqdm.notebook import trange

import warnings
warnings.filterwarnings('ignore')


# %%
path_daisy=data_path+'/daisy'
path_dandelion=data_path+'/dandelion'
path_rose=data_path+'/rose'
path_sunflower=data_path+'/sunflower'
path_tulip=data_path+'/tulip'


# %%
def ImgPath(path):
    img_path=[]
    for img in os.listdir(path):
        img_path.append(os.path.join(path,img))
    
    return img_path


# %%
flower_daisy=ImgPath(path_daisy)
flower_dandelion=ImgPath(path_dandelion)
flower_rose=ImgPath(path_rose)
flower_sunflower=ImgPath(path_sunflower)
flower_tulip=ImgPath(path_tulip)

flower_daisy=random.sample(flower_daisy,200)
flower_dandelion=random.sample(flower_dandelion,200)
flower_rose=random.sample(flower_rose,200)
flower_sunflower=random.sample(flower_sunflower,200)
flower_tulip=random.sample(flower_tulip,200)


# %%
flower_type=np.array
好的,以下是一个基于 PyTorch ResNet50 模型的指纹识别登录代码示例。需要安装的 Python 库有:PyTorch, OpenCV, numpy, argparse。 ```python import torch import numpy as np import cv2 import argparse # 定义命令行参数 parser = argparse.ArgumentParser() parser.add_argument('--model_path', type=str, default='resnet50.pth', help='path to the trained model') parser.add_argument('--img_path', type=str, required=True, help='path to the input image') parser.add_argument('--threshold', type=float, default=0.5, help='threshold for fingerprint matching') args = parser.parse_args() # 加载模型 model = torch.load(args.model_path) # 打开图像文件 img = cv2.imread(args.img_path, cv2.IMREAD_GRAYSCALE) # 对图像进行预处理 img = cv2.resize(img, (224, 224)) img = img.astype(np.float32) / 255.0 img = np.expand_dims(img, axis=0) img = np.expand_dims(img, axis=0) # 使用模型进行指纹特征提取 with torch.no_grad(): features = model(torch.from_numpy(img)) # 加载已知指纹特征 known_features = np.load('known_features.npy') # 对比特征向量并返回结果 distances = np.linalg.norm(features.numpy() - known_features, axis=1) min_distance = np.min(distances) if min_distance < args.threshold: print('登录成功!') else: print('登录失败!') ``` 在运行该代码时,需要提供以下参数: - model_path:已经训练好的 PyTorch ResNet50 模型的路径。 - img_path:要识别的指纹图像的路径。 - threshold:判断指纹是否匹配的阈值,默认为 0.5。 另外,为了实现指纹识别登录,还需要保存已知用户的指纹特征向量。可以在系统中先让用户进行指纹录入,然后将其指纹特征向量保存在一个 numpy 数组中,以便在登录时进行比对。 希望这个示例能对你有所帮助!
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值