基于百度paddlehub多种海洋鱼类的智能分类识别


      利用AI技术对海洋鱼类进行识别,不仅能很好地开发利用和保护鱼类资源,也为发展海洋渔业生产发挥了积极的作用,具有重大的学术研究和经济价值。

Fish4Knowledge数据集

      台湾电力公司、台湾海洋研究所和垦丁国家公园在2010年10月1日至2013年9月30日期间,在台湾南湾海峡、兰屿岛和胡比湖的水下观景台收集的鱼类图像数据集该数据集包括23类鱼种,共27370张鱼的图像。

本实验的任务

      本实践选取23种鱼类数据随机抽取数据进行迁移学习训练。
      由于个别样品数量大,使微调时长变长,微调也不需要这么多样本,因此对超过200个的样品进行随机抽样,抽样200个,然后和样品数量少于200的样品合并,组成新的数据集,然后在新的数据集里随机抽样形成训练集、测试集、验证集,数据无重复使用,使用pandas 的进行数据处理,过程优雅大方。完成最后Finetune训练及预测结果的输出。

经典的ResNet-50作为预训练模型

      本实践选在白AI studio平台上使用paddlehub 模块进行,选取经典的ResNet-50作为预训练模型来Finetune,对23种海洋鱼类进行分类,由于样品数量较大,随机抽取样品进行迁移学习,测试结果准确率100%。

项目链接:海洋鱼类识别分类 可以直接运行体验。


#CPU环境启动请务必执行该指令
%set_env CPU_NUM=2 
env: CPU_NUM=2
#安装paddlehub
!pip install paddlehub==1.6.0 -i https://pypi.tuna.tsinghua.edu.cn/simple
!hub install ernie
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting paddlehub==1.6.0

      Successfully uninstalled paddlehub-1.5.0
Successfully installed paddlehub-1.6.0
Downloading ernie
[==================================================] 100.00%
Uncompress /home/aistudio/.paddlehub/tmp/tmpb4n3ruh7/ernie
[==================================================] 100.00%
Successfully installed ernie-1.2.0
import os
import pandas as pd #用列表生成 DataFrame,便于到出text文档。
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont #显示图片

!mkdir data/fish_pic

解压数据

!unzip -o data/data32688/fish_data.zip -d data/fish_pic
  inflating: data/fish_pic/fish_data/label_list.txt  

制作数据准备函数。

23种海洋鱼类对应的学名

生成类别用于结果判定。

fish_name=['Dascyllus reticulatus',
 'Plectroglyphidodon dickii',
 'Chromis chrysura',
 'Amphiprion clarkii',
 'Chaetodon lunulatus',
 'Chaetodon trifascialis',
 'Myripristis kuntee',
 'Acanthurus nigrofuscus',
 'Hemigymnus fasciatus',
 'Neoniphon sammara',
 'Abudefduf vaigiensis',
 'Canthigaster valentini',
 'Pomacentrus moluccensis',
 'Zebrasoma scopas',
 'Hemigymnus melapterus',
 'Lutjanus fulvus',
 'Scolopsis bilineata',
 'Scaridae',
 'Pempheris vanicolensis',
 'Zanclus cornutus',
 'Neoglyphidodon nigroris',
 'Balistapus undulatus',
 'Siganus fuscescens']
fish_name

生成鱼类名字字典、标签字典


name_list=[]   #生成标签文档
with open ('data/label_list.txt','w+') as f:
    for i in range(1,24):
        name_list.append('fish_'+str(i))
        f.write('fish_'+str(i)+'\n')

name_dict={b:a for a,b in enumerate(name_list)}
name_dict
real_name={a:b for a,b in zip(name_list,fish_name)}
real_name

生成图片路径和标签列表对应的字典 ,用于生成DataFrame

#此函数生成包括路径和标签的字典,可以直接生成DataFrame,后面是用路径列表生成DataFrame.

# def data_list():
#     path="data/fish_pic/fish_data/fish_image23" #图片所在文件夹
#     address_list = []   #图片地址列表
#     label_list = []     #标签列表
#     for root, dirs, files in os.walk(path, topdown=False):
#         for name in files:
#             address = os.path.join(root, name) #获取图片路径
#             address_list.append(address)
#             label = address.split('/')[4]      #路径分割后,截取目录名即为标记名,开始的时候大脑里转的是map,lambda,还是apply!一直出不来 为啥我不早点想出来呢,
#             label_list.append(name_dict.get(label)) #截取目录名对应的标注
#     return {'address':address_list,'label':label_list} #生成字典
# data_list()

生成图片路径列表 ,用于生成DataFrame

#生成图片路径列表
def data_list():
    path="data/fish_pic/fish_data/fish_image23" #图片所在文件夹
    path_list = []   #图片地址列表    
    for root, dirs, files in os.walk(path, topdown=False):
        for name in files:
            path = os.path.join(root, name) #获取图片路径
            path_list.append(path)            
    return path_list #生成路径列表
#data_list()

用列表生成 DataFrame,便于到出text文档。

df = pd.DataFrame(data_list(),columns=['filepath'])     #生成数据框。
df['filepath'] = df.filepath.str[5:]                     #按要求产生相对路径。只要工作目录下的相对路径 。
df['label']=df.filepath.map(lambda x:x.split('/')[3]).map(name_dict) #用映射生成标签df.head()     

简单出个图确认前面的工作是否正常。

grouped=df['filepath'].groupby(df['label']).count() #查看样品分布情况
type(grouped)
plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签
plt.rcParams['axes.unicode_minus']=False #用来正常显示负号
plt.figure(figsize=(16,6))

# 这里是调节横坐标的倾斜度,rotation是度数,以及设置刻度字体大小
plt.xticks(rotation=45,fontsize=20)
plt.yticks(fontsize=20)
plt.title('''Fish Category''',fontsize = 24)
plt.bar(grouped.index,grouped,color='r',tick_label=name_list,facecolor='#9999ff',edgecolor='white')
plt.savefig('/home/aistudio/work/bar_result.jpg')
# 可见我们数据很不整齐,全部投入使用吃力不讨好,因此,考虑随机抽样参与训练。

在这里插入图片描述

通过分析 样品个数大于两百的分类情况。

grouped[grouped.values>200].sum() #
grouped[grouped.values<200].sum()
959

随机采样,形成新的数据集

由于个别样品数量大,使微调时长边长,微调也不需要这么多样本,因此对超过200个的样品进行随机抽样,抽样200个,然后和样品数量少于200的样品合并,组成新的数据集,进行后续的Finetune训练。

'''
1、先找出样品个数大于200样品标签。
2、提取样品数量少于200个的样品,形成新的数据集。
3、从样品个数多于200个的品类种随机抽取200样品,加入到上述数据集。
4、最后共获得2959个样品的数据集。
'''

label_index = grouped[grouped.values>200].index # 样品数量大于200个标签值(label标签对应的 值)

extract=df.loc[~df['label'].isin(label_index)] #取非,样品数少于200的样品,全作为工作样品,
for i in label_index:
    #d=df.loc[df['label']==i].sample(200)
    d=df[df['label']==i].sample(200)           #样品数量大于200个的样品随机抽取200个作为工作样品。
    extract=extract.append(d)                  #分别补充至
len( extract)
2959
#新数据的数据分布。
grouped=extract['filepath'].groupby(df['label']).count() #样本数量统计
plt.bar(grouped.index,grouped)
<BarContainer object of 23 artists>

在这里插入图片描述

样品列表生成,关键操作

样品数据框随机打乱,按9:1比例随机生成 测试集、验证集、训练集列表文档。

df_new=extract.copy()                      #样品数据框随机打乱,按9:1比例生成 测试集,验证集,训练集列表文档。
df_new = df_new.sample(frac=1.0)           #打乱数据
df_validate = df_new.sample(frac=0.1)  #取同数量样本作为验证集
df_new.drop(index=df_validate.index,inplace=True) #去除测试集
df_test = df_new.sample(20)                          #随机抽取20个样本作为测试集
df_train = df_new.drop(index=df_test.index)         #剩下的为训练集
len(df_train.index)                                 #训练集样品总数为2643个。

2643
### 生成数据列表文件
df_test['filepath'] = 'data/'+df_test['filepath'] #测试集已跳出 深度学习微调环境需要补充完整路径,是个坑。
df_test.to_csv('data/test_list.txt', sep=' ', index=0,header=0) #导出 验证集列表
df_validate.to_csv('data/validate_list.txt', sep=' ', index=0,header=0) #导出 验证集列表
df_train.to_csv('data/train_list.txt', sep=' ', index=0,header=0) #导出 训练集列表

我们来看看 我们数据图片都长什么样子的。

像素渣渣的,能完成我们的分类任务吗?

import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
test_img = df_test.sample(1)['filepath'].tolist()[0]  #我们就从测试集里面先抽一张出来看看(提取路径)

img = Image.open(test_img).resize((256,256)) # 放大来看看!
print(img.format,img.size)		 # 输出图片基本信息

# draw = ImageDraw.Draw(img)
# font=ImageFont.truetype('simhei.ttf',30)
# draw.text((20,20),'正确',(100,000,100), font=font)
plt.figure(figsize=(10,10))
plt.imshow(img)

在这里插入图片描述

Step2、加载预训练模型

接下来我们要在PaddleHub中选择合适的预训练模型来Finetune,由于是图像分类任务,因此我们使用经典的ResNet-50作为预训练模型。PaddleHub提供了丰富的图像分类预训练模型,包括了最新的神经网络架构搜索类的PNASNet,我们推荐您尝试不同的预训练模型来获得更好的性能。

import paddlehub as hub

module = hub.Module(name="resnet_v2_50_imagenet")
[32m[2020-05-04 00:04:35,275] [    INFO] - Installing resnet_v2_50_imagenet module[0m
[33m[2020-05-04 00:04:35,294] [ WARNING] - /home/aistudio/.paddlehub/modules/ernie/module_desc.pb does not exist, the module will be reinstalled[0m
[32m[2020-05-04 00:04:35,297] [    INFO] - Module resnet_v2_50_imagenet already installed in /home/aistudio/.paddlehub/modules/resnet_v2_50_imagenet[0m

Step3、数据准备

接着需要加载图片数据集。我们使用自定义的数据进行体验,请查看适配自定义数据

from paddlehub.dataset.base_cv_dataset import BaseCVDataset
   
class DemoDataset(BaseCVDataset):	
   def __init__(self):	
       # 数据集存放位置
       
       self.dataset_dir = "data"
       super(DemoDataset, self).__init__(
           base_path=self.dataset_dir,
           train_list_file="train_list.txt",
           validate_list_file="validate_list.txt",
           #test_list_file="test_list.txt",
           label_list_file="label_list.txt",
           )
dataset = DemoDataset()

Step4、生成数据读取器

接着生成一个图像分类的reader,reader负责将dataset的数据进行预处理,接着以特定格式组织并输入给模型进行训练。

当我们生成一个图像分类的reader时,需要指定输入图片的大小

data_reader = hub.reader.ImageClassificationReader(
    image_width=module.get_expected_image_width(),
    image_height=module.get_expected_image_height(),
    images_mean=module.get_pretrained_images_mean(),
    images_std=module.get_pretrained_images_std(),
    dataset=dataset)
[32m[2020-05-03 23:50:36,711] [    INFO] - Dataset label map = {'fish_1': 0, 'fish_2': 1, 'fish_3': 2, 'fish_4': 3, 'fish_5': 4, 'fish_6': 5, 'fish_7': 6, 'fish_8': 7, 'fish_9': 8, 'fish_10': 9, 'fish_11': 10, 'fish_12': 11, 'fish_13': 12, 'fish_14': 13, 'fish_15': 14, 'fish_16': 15, 'fish_17': 16, 'fish_18': 17, 'fish_19': 18, 'fish_20': 19, 'fish_21': 20, 'fish_22': 21, 'fish_23': 22}[0m

Step5、配置策略

在进行Finetune前,我们可以设置一些运行时的配置,例如如下代码中的配置,表示:

use_cuda:设置为False表示使用CPU进行训练。如果您本机支持GPU,且安装的是GPU版本的PaddlePaddle,我们建议您将这个选项设置为True;

epoch:迭代轮数;

batch_size:每次训练的时候,给模型输入的每批数据大小为32,模型训练时能够并行处理批数据,因此batch_size越大,训练的效率越高,但是同时带来了内存的负荷,过大的batch_size可能导致内存不足而无法训练,因此选择一个合适的batch_size是很重要的一步;

log_interval:每隔10 step打印一次训练日志;

eval_interval:每隔50 step在验证集上进行一次性能评估;

checkpoint_dir:将训练的参数和数据保存到cv_finetune_turtorial_demo目录中;

strategy:使用DefaultFinetuneStrategy策略进行finetune;

更多运行配置,请查看RunConfig

同时PaddleHub提供了许多优化策略,如AdamWeightDecayStrategy、ULMFiTStrategy、DefaultFinetuneStrategy等,详细信息参见策略

config = hub.RunConfig(
    use_cuda= True,                              #是否使用GPU训练,默认为False,高级算力环境用True,cpu环境写True会报错;
    num_epoch=5,                                #Fine-tune的轮数,cpu 环境用3玩玩就好,高级算力可以试试10;
    checkpoint_dir="cv_finetune_turtorial_demo",#模型checkpoint保存路径, 若用户没有指定,程序会自动生成;
    batch_size=3,                              #训练的批大小,如果使用GPU,请根据实际情况调整batch_size;
    #eval_interval=10,  
    log_interval=20,                         #模型评估的间隔,默认每100个step评估一次验证集;
    strategy=hub.finetune.strategy.DefaultFinetuneStrategy())  #Fine-tune优化策略;
[32m[2020-05-03 23:51:13,956] [    INFO] - Checkpoint dir: cv_finetune_turtorial_demo[0m

Step6、组建Finetune Task

有了合适的预训练模型和准备要迁移的数据集后,我们开始组建一个Task。

由于该数据设置是一个二分类的任务,而我们下载的分类module是在ImageNet数据集上训练的千分类模型,所以我们需要对模型进行简单的微调,把模型改造为一个二分类模型:

获取module的上下文环境,包括输入和输出的变量,以及Paddle Program;
从输出变量中找到特征图提取层feature_map;
在feature_map后面接入一个全连接层,生成Task;

input_dict, output_dict, program = module.context(trainable=True)
img = input_dict["image"]
feature_map = output_dict["feature_map"]
feed_list = [img.name]

task = hub.ImageClassifierTask(
    data_reader=data_reader,
    feed_list=feed_list,
    feature=feature_map,
    num_classes=dataset.num_labels,
    config=config)
[32m[2020-05-03 23:51:22,805] [    INFO] - 267 pretrained paramaters loaded by PaddleHub[0m

Step5、开始Finetune

我们选择finetune_and_eval接口来进行模型训练,这个接口在finetune的过程中,会周期性的进行模型效果的评估,以便我们了解整个训练过程的性能变化

run_states = task.finetune_and_eval()  
2020-05-03 23:51:25,264-WARNING: paddle.fluid.layers.py_reader() may be deprecated in the near future. Please use paddle.fluid.io.DataLoader.from_generator() instead.
[32m[2020-05-03 23:51:25,417] [    INFO] - Strategy with scheduler: {'warmup': 0.0, 'linear_decay': {'start_point': 1.0, 'end_learning_rate': 0.0}, 'noam_decay': False, 'discriminative': {'blocks': 0, 'factor': 2.6}, 'gradual_unfreeze': 0, 'slanted_triangle': {'cut_fraction': 0.0, 'ratio': 32}}, regularization: {'L2': 0.001, 'L2SP': 0.0, 'weight_decay': 0.0} and clip: {'GlobalNorm': 0.0, 'Norm': 0.0}[0m
[32m[2020-05-03 23:51:32,027] [    INFO] - Try loading checkpoint from cv_finetune_turtorial_demo/ckpt.meta[0m
[32m[2020-05-03 23:51:33,291] [    INFO] - PaddleHub model checkpoint loaded. current_epoch=6, global_step=4430, best_score=0.96970[0m
[32m[2020-05-03 23:51:33,292] [    INFO] - PaddleHub finetune start[0m
[32m[2020-05-03 23:51:33,293] [    INFO] - PaddleHub finetune finished.[0m

Step6、预测

当Finetune完成后,我们使用模型来进行预测,先通过以下命令来获取测试的图片

import numpy as np
import matplotlib.pyplot as plt 
import matplotlib.image as mpimg
import pandas as pd
from PIL import ImageEnhance
with open("data/test_list.txt","r") as f:
    filepath = f.readlines()

data =[filepath[i].split(" ")[0] for i in range(20)]


label_map = dataset.label_dict()
index = 0
run_states = task.predict(data=data)
results = [run_state.run_results for run_state in run_states]
address_dict={}
# address=[]
# label=[]
for batch_result in results:
    print(batch_result)
    batch_result = np.argmax(batch_result, axis=2)[0]
    print(batch_result)
    for result in batch_result:
        index += 1
        result = label_map[result]
        address_dict[data[index - 1]] = result
        # address.append(data[index - 1])
        # label.append(result)
        print("input %i is %s, and the predict result is %s" %
              (index, data[index - 1], result))
        if data[index - 1].split('/')[4]== result:
            print('识别正确')
            plt.figure(dpi=150)
            test_img = data[index - 1]  #提取路径
            img = Image.open(test_img).resize((256,256))
            img=ImageEnhance.Contrast(img).enhance(5)
            txt=real_name.get(result)
            plt.title('Name: '+txt) 
            # draw = ImageDraw.Draw(img)
            # font=ImageFont.truetype('simhei.ttf',10)
            # draw.text((20,20),txt,(0,0,0), font=font)
            plt.imshow(img)
    
[32m[2020-05-03 23:51:36,977] [    INFO] - Load the best model from cv_finetune_turtorial_demo/best_model[0m
2020-05-03 23:51:37,333-WARNING: paddle.fluid.layers.py_reader() may be deprecated in the near future. Please use paddle.fluid.io.DataLoader.from_generator() instead.
[32m[2020-05-03 23:51:37,900] [    INFO] - PaddleHub predict start[0m
[32m[2020-05-03 23:51:38,018] [    INFO] - PaddleHub predict finished.[0m


[array([[9.9838245e-01, 2.1424730e-05, 1.2656345e-03, 6.8090185e-07,
        5.4016009e-06, 2.1940001e-05, 7.9796619e-06, 6.7578861e-05,
        1.7682560e-05, 1.8652343e-06, 3.2142034e-05, 2.6007438e-06,
        3.7851554e-05, 5.2116568e-05, 1.8562126e-06, 1.6681779e-05,
        6.2078257e-06, 1.3050021e-05, 1.8019156e-06, 2.4178057e-06,
        2.1239804e-05, 1.7351467e-05, 2.0125642e-06],
       [2.0257776e-11, 4.5166450e-17, 9.2968464e-15, 1.0233489e-12,
        2.0787955e-15, 6.5610948e-14, 6.5626677e-15, 4.3768522e-15,
        2.3028883e-14, 3.7043131e-15, 3.5293830e-16, 6.6887305e-13,
        4.8553213e-13, 2.2345953e-11, 1.6440207e-11, 3.1280771e-16,
        1.7824198e-16, 1.5413876e-13, 1.0000000e+00, 3.6992192e-13,
        8.6250163e-15, 2.4732909e-12, 4.5483398e-16],
       [9.9933797e-01, 6.5640852e-06, 3.6345117e-04, 8.0837737e-07,
        2.2100830e-06, 1.6546834e-05, 8.9046262e-05, 2.1726248e-05,
        8.4476760e-06, 3.8937792e-06, 3.0621279e-05, 2.9351759e-06,
        6.0607208e-06, 4.2965123e-05, 2.1455371e-06, 1.5686514e-05,
        1.5031532e-05, 8.5749816e-06, 1.3915815e-06, 1.5548852e-06,
        1.6901115e-05, 4.1376074e-06, 1.0685205e-06]], dtype=float32)]
[ 0 18  0]
input 1 is data/fish_pic/fish_data/fish_image23/fish_1/fish_004247456949_24114.png, and the predict result is fish_1
识别正确
input 2 is data/fish_pic/fish_data/fish_image23/fish_19/fish_003470725337_22773.png, and the predict result is fish_19
识别正确
input 3 is data/fish_pic/fish_data/fish_image23/fish_1/fish_003827616157_13854.png, and the predict result is fish_1
识别正确
[array([[9.9999964e-01, 3.7161729e-11, 5.1297686e-12, 8.5146126e-12,
        2.9661232e-10, 1.1049771e-11, 3.1799271e-07, 9.8698849e-10,
        1.8987831e-09, 1.7776849e-10, 8.1057827e-09, 3.6866326e-12,
        8.9025315e-12, 1.0751304e-09, 4.6999678e-11, 2.6438218e-10,
        3.7267650e-10, 5.3801332e-13, 2.2687355e-12, 6.1629035e-11,
        4.7699635e-09, 1.3631640e-11, 1.3252716e-12],
       [1.9883251e-05, 1.3590470e-06, 2.2823211e-05, 2.7045410e-06,
        6.8420059e-07, 9.9988389e-01, 7.3688666e-06, 7.3757433e-06,
        5.3981339e-06, 1.2130232e-07, 2.6865629e-07, 3.0773895e-06,
        1.9892354e-07, 2.0082223e-05, 1.1809469e-05, 1.8486882e-06,
        3.9174307e-07, 4.4827340e-07, 2.6069340e-06, 5.8405692e-07,
        1.8210008e-06, 8.9928579e-07, 4.3121772e-06],
       [3.3536893e-05, 2.1879962e-06, 3.0115032e-06, 1.8361514e-05,
        9.1015572e-05, 2.1203257e-05, 1.2928950e-06, 3.1041309e-06,
        1.0498506e-05, 1.3147397e-05, 2.0758594e-05, 9.9889499e-01,
        6.8580404e-05, 2.1757260e-06, 4.4445287e-06, 4.7462072e-06,
        7.4281666e-04, 2.0619484e-06, 1.0554678e-06, 3.1086591e-05,
        2.4402172e-05, 4.2208244e-06, 1.4659241e-06]], dtype=float32)]
[ 0  5 11]
input 4 is data/fish_pic/fish_data/fish_image23/fish_1/fish_003891906306_20550.png, and the predict result is fish_1
识别正确
input 5 is data/fish_pic/fish_data/fish_image23/fish_6/fish_004531207347_18616.png, and the predict result is fish_6
识别正确
input 6 is data/fish_pic/fish_data/fish_image23/fish_12/fish_003424815280_26242.png, and the predict result is fish_12
识别正确
[array([[1.2961534e-03, 5.4979660e-06, 1.4224035e-03, 5.2722562e-06,
        9.8097971e-06, 2.1705056e-05, 6.0661719e-06, 9.6609758e-05,
        2.7608732e-04, 1.6215048e-06, 7.2433212e-04, 5.1265280e-05,
        7.8307952e-05, 4.7154540e-06, 8.0193913e-06, 3.2547416e-06,
        1.8457058e-04, 9.9558216e-01, 2.2624405e-05, 9.3184480e-07,
        1.0048455e-05, 1.6916872e-04, 1.9334246e-05],
       [9.8935507e-06, 1.2858195e-07, 5.1656318e-10, 2.8420001e-07,
        2.5945056e-07, 7.8136111e-09, 1.1749773e-08, 4.6601513e-06,
        2.8647599e-07, 9.9997127e-01, 6.1078431e-08, 1.5403988e-07,
        6.6497078e-06, 8.1421724e-08, 7.7015452e-09, 5.0121631e-08,
        7.7885772e-07, 3.0636620e-09, 6.6370617e-08, 4.1109041e-07,
        4.9194255e-06, 7.8788753e-09, 2.9121899e-10],
       [2.1295104e-05, 5.2133686e-07, 3.7809356e-09, 2.3203299e-07,
        4.1057524e-07, 2.9227353e-08, 7.4902204e-08, 6.3220824e-07,
        1.4168151e-06, 9.9995327e-01, 3.0123874e-07, 3.5497001e-07,
        1.4144449e-05, 3.6450098e-07, 3.1915722e-08, 1.2103615e-07,
        3.2462997e-06, 2.9788980e-08, 3.4511973e-07, 1.2656816e-06,
        1.7559584e-06, 4.3941871e-08, 1.4163211e-09]], dtype=float32)]
[17  9  9]
input 7 is data/fish_pic/fish_data/fish_image23/fish_18/fish_002730000118_02864.png, and the predict result is fish_18
识别正确
input 8 is data/fish_pic/fish_data/fish_image23/fish_10/fish_004402377135_11974.png, and the predict result is fish_10
识别正确
input 9 is data/fish_pic/fish_data/fish_image23/fish_10/fish_004402307135_13420.png, and the predict result is fish_10
识别正确
[array([[1.30250375e-03, 1.47289783e-03, 9.83981669e-01, 5.06734708e-04,
        2.83381407e-04, 1.70471694e-03, 1.01082120e-03, 5.90463518e-04,
        4.55910980e-04, 5.05301461e-04, 1.36937166e-03, 9.30775830e-04,
        4.08732129e-04, 5.33024897e-04, 8.67635536e-04, 8.68820818e-04,
        8.31899291e-04, 4.49043990e-04, 1.10402463e-04, 1.61008851e-04,
        3.47322406e-04, 1.79731098e-04, 1.12771604e-03],
       [1.01465457e-06, 1.49591484e-09, 1.55146893e-06, 2.15804175e-08,
        7.93890820e-07, 4.06607938e-08, 7.89408858e-08, 1.22617530e-05,
        3.54814023e-09, 1.09079714e-07, 5.37434232e-07, 4.00010805e-07,
        2.91290485e-06, 4.76586663e-08, 6.12473698e-08, 9.99979138e-01,
        1.49377719e-08, 2.19103100e-08, 2.15508722e-09, 4.95080457e-08,
        4.93648145e-08, 1.12482992e-08, 1.07428787e-06],
       [9.02851480e-06, 4.10470768e-09, 1.03806874e-08, 7.85372389e-09,
        2.31480035e-09, 9.51035588e-08, 9.99989748e-01, 9.05345914e-08,
        1.36873922e-07, 2.42446188e-08, 6.20813125e-08, 1.43789421e-07,
        3.76341003e-09, 1.50050866e-07, 1.86954168e-08, 1.52759068e-08,
        8.57645119e-08, 2.91081803e-09, 7.39863459e-09, 3.99737630e-08,
        1.80906156e-07, 3.28879479e-10, 1.74213667e-07]], dtype=float32)]
[ 2 15  6]
input 10 is data/fish_pic/fish_data/fish_image23/fish_3/fish_000045949596_04518.png, and the predict result is fish_3
识别正确
input 11 is data/fish_pic/fish_data/fish_image23/fish_16/fish_000010519594_03508.png, and the predict result is fish_16
识别正确
input 12 is data/fish_pic/fish_data/fish_image23/fish_7/fish_000006519596_03836.png, and the predict result is fish_7
识别正确
[array([[2.75442348e-04, 3.60757895e-06, 1.13664573e-05, 4.09912673e-06,
        5.49445504e-05, 9.86143860e-08, 8.97527843e-06, 7.97908706e-06,
        1.43690959e-05, 2.43335251e-07, 5.82377033e-05, 1.80042355e-06,
        9.99388814e-01, 2.28913686e-05, 9.58619694e-06, 8.81803430e-07,
        6.53982488e-06, 1.31239051e-06, 3.60530908e-06, 4.89416107e-06,
        5.69612939e-06, 1.14595336e-04, 9.81699912e-08],
       [6.92641474e-12, 1.11314624e-20, 1.77881701e-19, 2.12828648e-16,
        7.37920651e-20, 3.80046097e-20, 1.27085040e-20, 1.95941220e-17,
        4.07124861e-15, 2.03029430e-20, 1.00000000e+00, 2.80248727e-20,
        1.97141459e-18, 2.55192748e-20, 1.49389045e-20, 3.38860483e-20,
        2.66607008e-16, 4.44293076e-22, 7.22047226e-24, 3.50366788e-16,
        7.86875278e-17, 1.23529974e-18, 5.33098292e-20],
       [4.27446883e-10, 2.13035212e-09, 1.97862549e-09, 1.40187051e-09,
        6.03317785e-10, 1.17328174e-08, 5.48356205e-08, 9.99841332e-01,
        4.03730427e-09, 4.68428407e-09, 1.85640907e-08, 1.03480335e-08,
        1.79719439e-08, 6.18994918e-08, 1.62255476e-09, 6.59740929e-10,
        3.25996591e-11, 7.49306006e-10, 9.71088099e-10, 2.11265316e-09,
        1.58386625e-04, 1.93180090e-11, 3.41308981e-09]], dtype=float32)]
[12 10  7]
input 13 is data/fish_pic/fish_data/fish_image23/fish_13/fish_003681525783_11351.png, and the predict result is fish_13
识别正确
input 14 is data/fish_pic/fish_data/fish_image23/fish_11/fish_003456265312_13825.png, and the predict result is fish_11
识别正确
input 15 is data/fish_pic/fish_data/fish_image23/fish_8/fish_000026380001_02581.png, and the predict result is fish_8
识别正确
[array([[1.8622677e-08, 3.1088336e-09, 1.5033690e-08, 4.4234970e-11,
        1.0318230e-09, 1.5726628e-05, 3.7059522e-07, 2.3255339e-08,
        4.3291322e-09, 1.1215312e-10, 1.9198172e-09, 7.9109030e-09,
        3.9343062e-09, 5.0108373e-10, 5.3063523e-08, 6.7261881e-09,
        1.2498110e-06, 5.7462821e-09, 2.7107130e-10, 1.6177125e-09,
        1.0032165e-08, 8.8566310e-09, 9.9998248e-01],
       [7.6505904e-07, 5.7382927e-08, 5.1700813e-09, 9.9999630e-01,
        1.9003376e-06, 1.2523445e-08, 2.0266111e-07, 2.3070646e-07,
        5.7400985e-08, 3.0061098e-08, 6.7495840e-08, 1.3841272e-07,
        8.6737941e-09, 1.8047774e-07, 2.0499339e-08, 1.2384769e-08,
        2.9939660e-09, 2.2982520e-09, 2.8242265e-08, 3.0983173e-08,
        4.1395985e-08, 2.0453479e-09, 4.9067378e-10],
       [1.7996728e-05, 7.5644362e-05, 2.8876038e-05, 4.1958254e-05,
        9.0223562e-05, 1.9945921e-05, 2.3786501e-04, 5.8109668e-05,
        9.9842680e-01, 5.1697909e-05, 2.0967286e-04, 6.9470472e-05,
        3.5093661e-04, 1.5315780e-05, 3.2388525e-05, 6.9625945e-05,
        2.4851759e-05, 1.8152052e-05, 1.7606932e-05, 2.5869163e-05,
        8.9030436e-05, 8.4063722e-06, 1.9367717e-05]], dtype=float32)]
[22  3  8]
input 16 is data/fish_pic/fish_data/fish_image23/fish_23/fish_000704000123_00782.png, and the predict result is fish_23
识别正确
input 17 is data/fish_pic/fish_data/fish_image23/fish_4/fish_000015359599_08529.png, and the predict result is fish_4
识别正确
input 18 is data/fish_pic/fish_data/fish_image23/fish_9/fish_000027490001_02920.png, and the predict result is fish_9
识别正确
[array([[7.29732541e-09, 1.85295501e-10, 2.45942045e-10, 1.49187962e-09,
        9.99999881e-01, 2.33453590e-10, 4.92274310e-08, 1.29306184e-08,
        8.07934053e-10, 4.15306151e-10, 1.84998261e-09, 4.69380812e-10,
        1.06633976e-07, 4.79274342e-10, 4.78696971e-10, 4.94759362e-08,
        2.58550598e-10, 9.87123647e-11, 1.05043904e-10, 1.19140109e-09,
        2.20563345e-09, 2.12079604e-10, 1.19280397e-09],
       [1.83691354e-05, 5.73476837e-06, 4.10300618e-06, 2.98744538e-07,
        1.25064773e-07, 4.09575796e-06, 1.51911545e-05, 9.86044233e-06,
        2.79531878e-06, 3.03529248e-07, 7.41589020e-07, 1.42476290e-06,
        4.78221750e-07, 9.99907374e-01, 7.13641839e-07, 7.01478598e-07,
        3.11047010e-08, 7.37128573e-07, 7.57552016e-06, 4.14618597e-07,
        2.54211358e-07, 2.82010149e-07, 1.85032168e-05]], dtype=float32)]
[ 4 13]
input 19 is data/fish_pic/fish_data/fish_image23/fish_5/fish_003752925961_19406.png, and the predict result is fish_5
识别正确
input 20 is data/fish_pic/fish_data/fish_image23/fish_14/fish_000003219596_03754.png, and the predict result is fish_14
识别正确


/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/font_manager.py:1331: UserWarning: findfont: Font family ['sans-serif'] not found. Falling back to DejaVu Sans
  (prop.get_family(), self.defaultFamily[fontext]))

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

最后还可导出结果列表。

import pandas as pd
df=pd.DataFrame.from_dict(address_dict,orient='index') 
df =  df.reset_index() 
df.columns=['addr','lab']
#df['addr']=df.addr.str[5:]
df.to_csv('data/result.txt',sep=' ',index=0,header=0)  #导出结果,跟test_list比较下,是否一致。
  • 2
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
PyTorch是一个开源的深度学习框架,在深度学习任务中具有很高的灵活性和可扩展性。而Jupyter是一个交互式的编程环境,能够方便地组织和展示代码、图像以及文字等内容。 基于海洋鱼类分类这个任务,我们可以使用PyTorch和Jupyter来构建一个分类模型。首先,我们可以使用PyTorch库来加载、预处理和训练我们的数据集。我们可以使用PyTorch提供的数据加载器和转换工具来加载数据,并使用数据增强技术来扩充数据集,提高模型泛化能力。然后,使用PyTorch来构建一个深度学习模型,如卷积神经网络(CNN),用于鱼类图像分类任务。我们可以使用PyTorch提供的各种层和激活函数来定义网络架构,并使用梯度下降算法和优化器来训练模型。 在Jupyter中,我们可以使用Markdown单元格来记录和展示我们的代码、实验结果和解释。我们可以使用代码单元格来编写、运行和调试我们的PyTorch代码。此外,我们还可以使用Jupyter的可视化功能来展示模型训练的过程中的损失曲线、准确度等指标,并可视化模型在测试集上的预测结果。 使用PyTorch和Jupyter,我们可以方便地进行迭代实验,优化模型的性能,调整超参数并进行模型的可视化分析。通过PyTorch和Jupyter的强大功能和易用性,我们可以更有效地进行海洋鱼类分类任务的研究和实践。同时,这也为其他相关的深度学习任务提供了一个良好的平台。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值