CIFAR-10: dataset download, training, model export, weight freezing, and prediction

I. Introduction to the CIFAR-10 dataset

Description of the CIFAR-10 dataset (from the official website):

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

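That covers the dataset itself. You could download the dataset directly, read it with pickle, and then convert it into TFRecord files yourself. TensorFlow's slim framework already does all of this for us, so this article only covers how slim handles the CIFAR-10 data.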
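For reference, the "python version" of the dataset on the official site stores each batch (data_batch_1 to data_batch_5, and test_batch) as a pickled dict: b'data' is an N x 3072 uint8 array (1024 red values, then 1024 green, then 1024 blue, each in row-major order) and b'labels' is a list of N class indices. Below is a minimal sketch of the pickle step that slim automates, assuming that documented layout. `load_cifar10_batch` is a hypothetical helper name. Writing the arrays out as TFRecords is left to slim's own conversion script.

```python
import pickle

import numpy as np


def load_cifar10_batch(path):
  """Loads one batch file of the CIFAR-10 python version (hypothetical helper).

  Assumes the documented layout: b'data' is N x 3072 uint8 (R plane, G plane,
  B plane, each 32x32 row-major) and b'labels' is a list of N ints in 0..9.
  """
  with open(path, 'rb') as f:
    # encoding='bytes' lets Python 3 read batches that were pickled by Python 2.
    batch = pickle.load(f, encoding='bytes')
  data = np.asarray(batch[b'data'], dtype=np.uint8)
  # N x 3072 -> N x 3 x 32 x 32 (channels first) -> N x 32 x 32 x 3 (HWC).
  images = data.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)
  labels = np.asarray(batch[b'labels'], dtype=np.int64)
  return images, labels
```

The channels-last layout matches what image decoders such as `tf.image.decode_jpeg` produce, which is the layout the rest of this pipeline works with.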


II. Data download, conversion, training, and evaluation

Run the script train_cifarnet_on_cifar10.sh in master/research/slim/scripts.

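Notes:

<1> Environment variables

TRAIN_DIR    where the trained model (checkpoints and logs) is stored

DATASET_DIR   where the dataset is stored

<2> Use python or python3 as appropriate, and decide whether to train on a GPU. Without a GPU, set clone_on_cpu to true.

Once train_cifarnet_on_cifar10.sh has finished:

Data download

The directory pointed to by DATASET_DIR contains the downloaded and converted data: cifar10_test.tfrecord  cifar10_train.tfrecord  labels.txt

Training

The directory pointed to by TRAIN_DIR contains the checkpoints produced by training: checkpoint  model.ckpt-100000.data-00000-of-00001  model.ckpt-100000.index  model.ckpt-100000.meta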
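freeze_graph.py later needs --input_checkpoint set to the checkpoint prefix (model.ckpt-100000), not to one of the .data, .index, or .meta files. Within TensorFlow, tf.train.latest_checkpoint(train_dir) returns that prefix. The standard-library sketch below shows where the prefix comes from. It assumes the usual text format of the `checkpoint` file, and `latest_checkpoint_prefix` is a hypothetical helper name.

```python
import os
import re


def latest_checkpoint_prefix(train_dir):
  """Returns the prefix recorded in train_dir/checkpoint, or None (hypothetical helper).

  Assumes the text format written by tf.train.Saver, e.g.:
    model_checkpoint_path: "model.ckpt-100000"
    all_model_checkpoint_paths: "model.ckpt-90000"
  """
  state_file = os.path.join(train_dir, 'checkpoint')
  if not os.path.exists(state_file):
    return None
  with open(state_file) as f:
    for line in f:
      # re.match anchors at the start, so all_model_checkpoint_paths is skipped.
      match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
      if match:
        path = match.group(1)
        # Relative paths are interpreted relative to the checkpoint directory.
        return path if os.path.isabs(path) else os.path.join(train_dir, path)
  return None
```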

Evaluation

eval/Recall_5[0.993]
eval/Accuracy[0.8539]
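Recall_5 looks much higher than Accuracy. As far as I know, slim's eval_image_classifier.py computes it as recall at k=5. With one true label per image, that is simply top-5 accuracy: an image counts as correct if its true class appears anywhere among the five highest scores. With only 10 classes, the top 5 covers half of them, so a value near 1 is expected. A small NumPy sketch of both metrics follows. The function name is hypothetical.

```python
import numpy as np


def accuracy_and_recall_at_k(logits, labels, k=5):
  """Top-1 accuracy and recall@k for single-label classification (sketch).

  logits: [N, num_classes] scores; labels: [N] true class indices.
  """
  logits = np.asarray(logits)
  labels = np.asarray(labels)
  # Top-1: the highest-scoring class must equal the true label.
  accuracy = np.mean(np.argmax(logits, axis=1) == labels)
  # Indices of the k largest scores per row (their order does not matter here).
  top_k = np.argsort(logits, axis=1)[:, -k:]
  # Recall@k: the true label appears somewhere among those k classes.
  recall_k = np.mean(np.any(top_k == labels[:, None], axis=1))
  return float(accuracy), float(recall_k)
```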

III. Model export and weight freezing

    Model export

Original code (export_inference_graph.py):

def main(_):
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    image_size = FLAGS.image_size or network_fn.default_image_size
    placeholder = tf.placeholder(name='input', dtype=tf.float32,
                                 shape=[FLAGS.batch_size, image_size,
                                        image_size, 3])
    network_fn(placeholder)
    graph_def = graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())


Modified as follows:

def main(_):
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default() as graph:
    dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train',
                                          FLAGS.dataset_dir)
    # Needs `from preprocessing import preprocessing_factory` among the imports.
    preprocessing_name = FLAGS.model_name
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        preprocessing_name,
        is_training=False)

    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    image_size = FLAGS.image_size or network_fn.default_image_size
    # Unlike the original, the input placeholder now accepts a single image
    # (as JPEG-encoded bytes) instead of a fixed-size float batch.
    # placeholder = tf.placeholder(name='input', dtype=tf.float32,
    #                              shape=[image_size,
    #                                     image_size, 3])
    placeholder = tf.placeholder(name='input', dtype=tf.string)
    # Decode the JPEG bytes into a uint8 [height, width, 3] tensor.
    image = tf.image.decode_jpeg(placeholder, channels=3)
    # Apply the model's evaluation-mode preprocessing.
    image = image_preprocessing_fn(image, image_size, image_size)
    # The network expects a batch, so add a leading batch dimension:
    # [height, width, 3] -> [1, height, width, 3].
    x = tf.expand_dims(image, axis=0)
    # x = tf.expand_dims(placeholder, axis=0)
    logits, end_points = network_fn(x)
    # Name the softmax node 'output' so freeze_graph.py can refer to it.
    prediction = tf.nn.softmax(logits, name='output')
    graph_def = graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())
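The exported graph now includes the preprocessing, so it helps to know what happens to the JPEG before the network sees it. The NumPy sketch below follows my reading of slim's cifarnet eval preprocessing (cifarnet_preprocessing.preprocess_for_eval): cast to float, center-crop or zero-pad to 32x32 with tf.image.resize_image_with_crop_or_pad, and then apply tf.image.per_image_standardization. Treat this as an assumption to check against your checkout. The function name `cifarnet_eval_preprocess` is hypothetical. One consequence worth verifying: crop-or-pad does not rescale, so a large photo would be center-cropped to a 32x32 patch rather than shrunk.

```python
import numpy as np


def cifarnet_eval_preprocess(image, size=32):
  """NumPy sketch of the preprocessing baked into the exported graph.

  Assumption: mirrors slim's cifarnet eval preprocessing (center crop or
  zero-pad without rescaling, then per-image standardization).
  """
  image = np.asarray(image, dtype=np.float32)
  h, w = image.shape[:2]
  # Center crop along any dimension that is larger than the target.
  top, left = max((h - size) // 2, 0), max((w - size) // 2, 0)
  image = image[top:top + size, left:left + size]
  # Center zero-pad along any dimension that is smaller than the target.
  ch, cw = image.shape[:2]
  pad_top, pad_left = (size - ch) // 2, (size - cw) // 2
  image = np.pad(image, ((pad_top, size - ch - pad_top),
                         (pad_left, size - cw - pad_left), (0, 0)))
  # Per-image standardization: (x - mean) / max(std, 1 / sqrt(num_elements)).
  adjusted_std = max(float(image.std()), 1.0 / np.sqrt(image.size))
  image = (image - image.mean()) / adjusted_std
  # Add the batch dimension, as tf.expand_dims(image, axis=0) does.
  return image[np.newaxis, ...]
```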

The command used to export the model:

# Export the model
python3 export_inference_graph.py \
--model_name=cifarnet \
--batch_size=1 \
--dataset_name=cifar10 \
--output_file=cifarnet_graph_def.pb \
--dataset_dir=./cifar10/
    

    Weight freezing

The key code is below (I will study it in more detail later):

freeze_graph.py

    if input_meta_graph_def:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_meta_graph_def.graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
    else:
      output_graph_def = graph_util.convert_variables_to_constants(
          sess,
          input_graph_def,
          output_node_names.split(","),
          variable_names_whitelist=variable_names_whitelist,
          variable_names_blacklist=variable_names_blacklist)
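As far as I understand it, convert_variables_to_constants does two things. It keeps only the part of the graph needed to compute output_node_names, and it replaces every variable in that part with a Const node holding the value read from the session (into which freeze_graph.py has restored --input_checkpoint). The toy model below illustrates the idea. It uses plain dicts as a hypothetical stand-in for a GraphDef, and it is not TensorFlow's implementation.

```python
def freeze(graph, checkpoint, output_names):
  """Toy model of graph freezing (not TensorFlow's implementation).

  graph: dict mapping node name -> {'op': op type, 'inputs': [node names]}.
  checkpoint: dict mapping variable node name -> saved value.
  Returns a dict holding only the nodes the outputs depend on, with each
  variable node replaced by a constant that stores its saved value.
  """
  # 1. Walk backwards from the output nodes to collect everything they need.
  needed, stack = set(), list(output_names)
  while stack:
    name = stack.pop()
    if name not in needed:
      needed.add(name)
      stack.extend(graph[name]['inputs'])
  # 2. Copy the needed nodes, turning variables into constants.
  frozen = {}
  for name in needed:
    node = graph[name]
    if node['op'] in ('Variable', 'VariableV2'):
      frozen[name] = {'op': 'Const', 'inputs': [], 'value': checkpoint[name]}
    else:
      frozen[name] = dict(node)
  return frozen
```

In the real function, variable_names_whitelist and variable_names_blacklist further restrict which variables get converted. Training-only nodes such as the global step disappear because the outputs never depend on them.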

The command:

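# Freeze the model.
# --input_binary=true is required because cifarnet_graph_def.pb was written in binary form; otherwise freeze_graph.py reports an error.
# --output_node_names=output because the export script names the softmax node: prediction = tf.nn.softmax(logits, name='output')
python3 freeze_graph.py \
--input_graph=cifarnet_graph_def.pb \
--input_binary=true \
--input_checkpoint="./cifarnet-model/model.ckpt-100000" \
--output_graph=freezed_cifarnet.pb \
--output_node_names=output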
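Note the two naming levels. --output_node_names takes node (op) names such as `output`, whereas the prediction script fetches the tensor `output:0` and feeds `input:0`. A tensor name is the producing op's name followed by `:` and the index of that output. A tiny sketch of the convention follows. The helper name is hypothetical.

```python
def split_tensor_name(tensor_name):
  """Splits a tensor name such as 'output:0' into ('output', 0) (hypothetical helper).

  A tensor name is '<op name>:<output index>'; a bare op name refers to output 0.
  """
  op_name, sep, index = tensor_name.rpartition(':')
  if not sep:
    return tensor_name, 0
  return op_name, int(index)
```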

IV. Loading the model and checking the results

The full code (classify_image.py):

"""Simple image classification with a frozen CIFAR-10 (cifarnet) graph.

Adapted from TensorFlow's Inception classify_image.py example.

This program loads a frozen GraphDef protocol buffer, feeds a JPEG image
into the 'input' node, and prints human-readable labels for the top
predictions together with their probabilities.

Change the --image_file argument to any JPEG image to classify it.

The original Inception tutorial:

https://tensorflow.org/tutorials/image_recognition/
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys

import numpy as np
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS


class NodeLookup(object):
  """Converts integer node ID's to human readable labels."""

  def __init__(self,
               label_path=None):
    if not label_path:
      raise ValueError('Please specify the label file with --label_file.')
    self.node_lookup = self.load(label_path)

  def load(self, label_path):
    """Loads a human-readable name for each softmax node.

    Args:
      label_path: path to labels.txt, one "<id>:<name>" entry per line.

    Returns:
      dict from integer node ID to human-readable string.
    """
    if not tf.gfile.Exists(label_path):
      raise ValueError('File does not exist %s' % label_path)

    # Loads the mapping from integer ID to human-readable string.
    lines = tf.gfile.GFile(label_path).readlines()
    id_to_human = {}
    for line in lines:
      if line.find(':') < 0:
        continue
      # Split on the first ':' only, in case a label itself contains one.
      _id, human = line.rstrip('\n').split(':', 1)
      id_to_human[int(_id)] = human

    return id_to_human

  def id_to_string(self, node_id):
    if node_id not in self.node_lookup:
      return ''
    return self.node_lookup[node_id]


def create_graph(model_file=None):
  """Imports the frozen GraphDef in model_file into the default graph."""
  # Creates graph from saved graph_def.pb.
  if not model_file:
    model_file = FLAGS.model_file
  with open(model_file, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')


def run_inference_on_image(image, model_file=None):
  """Runs inference on an image.

  Args:
    image: Image file name.
    model_file: Path to the frozen .pb file; defaults to --model_file.

  Returns:
    A tuple (predictions, top_k, top_names): the softmax probabilities,
    the indices of the top classes, and their human-readable names.
  """
  if not tf.gfile.Exists(image):
    raise ValueError('File does not exist %s' % image)
  with open(image, 'rb') as f:
    image_data = f.read()

  # Creates graph from saved GraphDef.
  create_graph(model_file)

  with tf.Session() as sess:
    # Tensors of the exported graph:
    # 'input:0': string placeholder holding the JPEG-encoded image bytes.
    # 'output:0': softmax probabilities over the 10 CIFAR-10 classes.
    # Runs the softmax tensor by feeding the image_data as input to the graph.
    softmax_tensor = sess.graph.get_tensor_by_name('output:0')
    predictions = sess.run(softmax_tensor,
                           {'input:0': image_data})
    predictions = np.squeeze(predictions)

    # Creates node ID --> English string lookup.
    node_lookup = NodeLookup(FLAGS.label_file)

    top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
    top_names = []
    for node_id in top_k:
      human_string = node_lookup.id_to_string(node_id)
      top_names.append(human_string)
      score = predictions[node_id]
      print('id:[%d] name:[%s] (score = %.5f)' % (node_id, human_string, score))
  return predictions, top_k, top_names


def main(_):
  if not FLAGS.image_file:
    raise ValueError('You must supply the image to classify with --image_file')
  run_inference_on_image(FLAGS.image_file)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--model_file',
      type=str,
      default='./freezed_cifarnet.pb',
      help='Path to the .pb file that contains the frozen weights.'
  )
  parser.add_argument(
      '--label_file',
      type=str,
      default='',
      help='Absolute path to label file.'
  )
  parser.add_argument(
      '--image_file',
      type=str,
      default='',
      help='Absolute path to image file.'
  )
  parser.add_argument(
      '--num_top_predictions',
      type=int,
      default=5,
      help='Display this many predictions.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Command to run:

python3 classify_image.py \
--model_file=./freezed_cifarnet.pb \
--image_file=./timg.jpeg \
--label_file=./cifar10/labels.txt

The results:

id:[2] name:[bird] (score = 0.74933)
id:[3] name:[cat] (score = 0.09537)
id:[4] name:[deer] (score = 0.09519)
id:[0] name:[airplane] (score = 0.02756)
id:[1] name:[automobile] (score = 0.01199)

