Garbage Classification Prediction with GoogLeNet


Dataset download: https://aistudio.baidu.com/aistudio/datasetDetail/16630

The main training script:

# coding=utf-8

"""
author:lei
function:image classification with GoogLeNet
"""

import tensorflow as tf
from tensorflow.contrib import slim

import numpy as np
import pickle
import random
import os

class BuildModel(object):
    def __init__(self, sort_path, data_path, model_path):
        self.sort_path = sort_path
        self.data_path = data_path
        self.model_path = model_path
        self.sort_size = 6  # number of classes
        self.one_sort_size = 500  # number of images per class
        self.batch_size = 30  # batch size, so each class contributes 5 images per batch
        self.one_sort_extract_num = 5
        self.epoch_size = 100  # 500 images per class / 5 per batch, hence epoch_size = 100
        self.frequence_size = 50  # number of full passes over the dataset
        self.image_size = 100
        self.image_channels = 3
        self.learning_rate = 0.00005

    # Load pickled data
    def extract_data(self, file_path):
        with open(file_path, "rb") as f:
            data = pickle.load(f)
        return data

    # Build one batch of features and labels
    def construct_data(self, sort_data, epoch):
        data_list = []
        key_list = []
        # Pair each image with its class key so features and labels line up
        for key, value_list in sort_data.items():
            for data in value_list[epoch * self.one_sort_extract_num: (epoch + 1) * self.one_sort_extract_num]:
                data_list.append(data)
                key_list.append(key)
        # print(np.array(data_list).shape)  # (30, 100, 100, 3)
        # print(key_list)  # 5 per class, 30 in total
        return data_list, key_list
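
    # Each call yields 6 classes x 5 images = 30 samples (batch_size), and
    # 500 images per class / 5 per slice = 100 slices, which is why epoch_size
    # is 100: one sweep over epoch 0..99 covers every image once per pass.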

    # Build one GoogLeNet inception module
    def inception(self, x, d0_area, d1_0area, d1_1area, d2_0area, d2_1area, d3_1area, scope, reuse=None):
        with tf.variable_scope(scope, reuse=reuse):
            with slim.arg_scope([slim.max_pool2d, slim.conv2d], stride=[1, 1]):
                with tf.variable_scope("net0"):
                    net0 = slim.conv2d(x, num_outputs=d0_area, kernel_size=[1, 1], padding="SAME", scope="net0_1")

                with tf.variable_scope("net1"):
                    net1 = slim.conv2d(x, num_outputs=d1_0area, kernel_size=[1, 1], padding="SAME", scope="net1_1")
                    net1 = slim.conv2d(net1, num_outputs=d1_1area, kernel_size=[3, 3], padding="SAME", scope="net3_3")

                with tf.variable_scope("net2"):
                    net2 = slim.conv2d(x, num_outputs=d2_0area, kernel_size=[1, 1], padding="SAME", scope="net2_1")
                    net2 = slim.conv2d(net2, num_outputs=d2_1area, kernel_size=[5, 5], padding="SAME", scope="net2_5")

                with tf.variable_scope("net3"):
                    net3_pool = slim.max_pool2d(x, kernel_size=[3, 3], padding="SAME", scope="net3_pool")
                    net3 = slim.conv2d(net3_pool, num_outputs=d3_1area, kernel_size=[1, 1], padding="SAME", scope="net3_1")

                # Concatenate the four branches along the channel axis
                google_net = tf.concat([net0, net1, net2, net3], axis=-1)

        return google_net
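
    # The concatenated output has d0_area + d1_1area + d2_1area + d3_1area
    # channels; e.g. block3's google_net2 below gives 128 + 192 + 96 + 64 = 480,
    # matching the printed shape (?, 13, 13, 480).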

    # Build the full GoogLeNet model
    def google_model(self, is_training):
        with tf.name_scope("data"):
            x = tf.placeholder(tf.float32, [None, self.image_size, self.image_size, self.image_channels], name="x")
            y_true = None
            y = None
            if is_training:
                y = tf.placeholder(tf.int32, [None])  # integer class indices
                y_true = tf.one_hot(y, depth=6, name="y_true")  # convert labels to one-hot encoding

        # Assemble the GoogLeNet body
        with slim.arg_scope([slim.max_pool2d, slim.conv2d, slim.avg_pool2d], padding="SAME", stride=1):
            with tf.variable_scope("block1", reuse=None):
                net = slim.conv2d(x, 64, [5, 5], stride=2, scope="conv5_5")
                # print(net)  # Tensor("block1/conv5_5/Relu:0", shape=(?, 50, 50, 64), dtype=float32)

            with tf.variable_scope("block2", reuse=None):
                net = slim.conv2d(net, 64, [1, 1], scope="conv1_1")
                net = slim.conv2d(net, 192, [3, 3], scope="conv3_3")
                net = slim.max_pool2d(net, [3, 3], stride=2, scope="max_pool")
                # print(net)  # Tensor("block2/max_pool/MaxPool:0", shape=(?, 25, 25, 192), dtype=float32)

            with tf.variable_scope("block3", reuse=None):
                net = self.inception(net, 64, 96, 128, 16, 32, 32, scope="google_net1")
                net = self.inception(net, 128, 128, 192, 32, 96, 64, scope="google_net2")
                net = slim.max_pool2d(net, [3, 3], stride=2, scope="max_pool")
                # print(net)  # Tensor("block3/max_pool/MaxPool:0", shape=(?, 13, 13, 480), dtype=float32)

            with tf.variable_scope("block4", reuse=None):
                net = self.inception(net, 192, 96, 208, 16, 48, 64, scope="google_net3")
                net = self.inception(net, 160, 112, 224, 24, 64, 64, scope="google_net4")
                net = self.inception(net, 128, 128, 256, 24, 64, 64, scope="google_net5")
                net = slim.max_pool2d(net, [3, 3], stride=2, scope="max_pool")
                # print(net)  # Tensor("block4/max_pool/MaxPool:0", shape=(?, 7, 7, 512), dtype=float32)

        # Fully connected head
        with tf.variable_scope("correct", reuse=None):
            net = slim.flatten(net)  # flatten to a vector
            # print(net)  # Tensor("correct/Flatten/flatten/Reshape:0", shape=(?, 25088), dtype=float32)
            if is_training:
                net = slim.dropout(net, keep_prob=0.5)  # apply dropout only during training
            # Final prediction logits
            logit = slim.fully_connected(net, self.sort_size, activation_fn=None, scope="logit")  # no activation; softmax is applied later
            return y_true, logit, x, y
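
    # Note on the flatten size: the input passes four stride-2 reductions,
    # 100 -> 50 -> 25 -> 13 -> 7, so the flattened vector has
    # 7 * 7 * 512 = 25088 features, matching the printed shape above.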

    def train_loss(self, y_true, logit):
        # Compute the loss
        with tf.name_scope("loss"):
            # During training we compute loss and accuracy; at test time we only predict
            entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=logit)
            loss = tf.reduce_mean(entropy)  # mean cross-entropy over the batch
        return loss

    def compute_y_label(self, logit):
        y_label = tf.argmax(tf.nn.softmax(logit), 1)  # index of the highest-probability class
        return y_label

    # Build the training op and accuracy metric
    def train_step(self, y_true, logit, loss):
        with tf.name_scope("train"):
            train_op = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(loss)
            equal = tf.equal(tf.argmax(logit, 1), tf.argmax(y_true, 1))
            accuracy = tf.reduce_mean(tf.cast(equal, tf.float32))

        return train_op, accuracy

    # Main training routine
    def run(self):
        # Build the model and produce logits
        y_true, logits, x, y = self.google_model(True)
        # Compute the loss
        loss = self.train_loss(y_true, logits)
        # Compute predicted labels
        y_label = self.compute_y_label(logits)
        # Build the training op and accuracy metric
        train_op, accuracy = self.train_step(y_true, logits, loss)

        # Initialize variables
        init_op = tf.compat.v1.global_variables_initializer()
        saver = tf.train.Saver()

        with tf.compat.v1.Session() as sess:
            sess.run(init_op)

            # 1. Load the prepared data
            sort_data = self.extract_data(self.data_path)
            sort_num = self.extract_data(self.sort_path)

            # Resume from the newest checkpoint if one exists. saver.save() below
            # writes model.ckpt-<step>, so restore via latest_checkpoint rather
            # than the bare model_path.
            ckpt = tf.train.latest_checkpoint("/home/aistudio/garbage_sort/model/")
            if ckpt:
                saver.restore(sess, save_path=ckpt)

            value = 0  # best accuracy seen so far
            i = 0
            for frequence in range(self.frequence_size):
                for epoch in range(self.epoch_size):
                    i += 1
                    features, labels = self.construct_data(sort_data, epoch)
                    sess.run(train_op, feed_dict={x: features, y: labels})
                    if epoch % 5 == 0:
                        acc = sess.run(accuracy, feed_dict={x: features, y: labels})
                        print("frequence: {}, epoch: {}, acc: {}".format(frequence, epoch, acc))

                        # Save unconditionally for the first 30 passes, then only on improvement
                        if frequence < 30:
                            saver.save(sess, save_path=self.model_path, global_step=i)
                            print("Model saved!")
                        elif frequence >= 30 and acc > value:
                            value = acc
                            saver.save(sess, save_path=self.model_path, global_step=i)
                            print("Model saved!")

def main():
    sort_path = "/home/aistudio/garbage_sort/data/sort_num.pkl"
    data_path = "/home/aistudio/garbage_sort/data/sort_data.pkl"
    model_path = "/home/aistudio/garbage_sort/model/model.ckpt"
    # sort_path = "./data/sort_num.pkl"
    # data_path = "./data/sort_data.pkl"
    # model_path = "./model/model.ckpt"
    model = BuildModel(sort_path, data_path, model_path)
    model.run()


if __name__ == '__main__':
    main()
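
Note: this script relies on tf.contrib.slim and tf.placeholder, both of which were removed in TensorFlow 2.x, so it needs a TensorFlow 1.x environment (for example, pip install tensorflow==1.15).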

Data extraction:

# coding=utf-8

"""
author:lei
function:extract the data
"""

import os
import tensorflow as tf
import numpy as np
import pickle


# Data-extraction class

class ExtractData(object):
    def __init__(self, file_path, save_path):
        self.file_path = file_path
        self.save_path = save_path
        self.sort_num = dict()
        self.sort_photo_data = dict()
        self.sort_data = dict()

    def extract_photo(self, photo_list):
        with tf.compat.v1.name_scope("extract_file"):
            # Build the file queue
            file_queue = tf.compat.v1.train.string_input_producer(photo_list)
            # Build the reader
            reader = tf.compat.v1.WholeFileReader()
            # Read each image file
            key, value = reader.read(file_queue)
            # Decode the image bytes
            image = tf.compat.v1.image.decode_jpeg(value)
            # Resize the image
            image_resize = tf.compat.v1.image.resize_images(image, [100, 100])
            # Pin the static shape to 3 channels
            image_resize.set_shape([100, 100, 3])
            # Batch the images
            image_batch = tf.compat.v1.train.batch([image_resize], batch_size=500, num_threads=1, capacity=500)

        with tf.compat.v1.Session() as sess:
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(sess, coord=coord)
            sort_data = sess.run([image_batch])
            coord.request_stop()
            coord.join(threads)

        # print(sort_data)
        return sort_data

    def save_data(self, file_path, data):
        with open(file_path, "wb") as f:
            pickle.dump(data, f)
        print("Saved successfully!")

    def get_sort(self):
        sort_list = os.listdir(self.file_path)
        # Map class index -> class name
        for key, value in enumerate(sort_list):
            self.sort_num[key] = value

        for key, sort in enumerate(sort_list):
            photo_dirs = self.file_path + "/" + sort + "/"
            photo_list = os.listdir(photo_dirs)
            # Build the full path for every image of this class
            file_list = [photo_dirs + photo_dir for photo_dir in photo_list]
            # print(file_list)
            sort_data = self.extract_photo(file_list)  # read all images of this class
            sort_data = np.array(sort_data).reshape([500, 100, 100, 3])
            self.sort_data[key] = sort_data
        # print(self.sort_data)
        # print(self.sort_num)  # {0: 'cardboard', 1: 'glass', 2: 'metal', 3: 'paper', 4: 'plastic', 5: 'trash'}
        self.save_data(self.save_path + "sort_num.pkl", self.sort_num)  # save the index -> name map
        self.save_data(self.save_path + "sort_data.pkl", self.sort_data)  # save the image data

    def run(self):
        # Extract the classes and their images
        self.get_sort()

def main():
    file_path = "./data/image_classify_resnet50/training_dataset"
    save_path = "./data/"
    ext = ExtractData(file_path, save_path)
    ext.run()


if __name__ == '__main__':
    main()
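
As a quick sanity check (a minimal sketch, assuming the pickles were written to ./data/ as above), you can reload the two files and confirm they have the structure the training script expects:

# coding=utf-8
# Minimal sketch: verify the pickles written by ExtractData (assumes save_path "./data/").
import pickle
import numpy as np

with open("./data/sort_num.pkl", "rb") as f:
    sort_num = pickle.load(f)
print(sort_num)  # e.g. {0: 'cardboard', 1: 'glass', 2: 'metal', 3: 'paper', 4: 'plastic', 5: 'trash'}

with open("./data/sort_data.pkl", "rb") as f:
    sort_data = pickle.load(f)
for key, images in sort_data.items():
    print(key, np.asarray(images).shape)  # expect (500, 100, 100, 3) per class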

Prediction module:

# coding=utf-8

"""
author:lei
function:test the trained model on test data
"""

import tensorflow as tf
import pickle
import random
import os
from 深度学习_slim.DeepLearning_slim.fish_sort.google_net import BuildModel  # path to the training script's BuildModel; adjust to your layout
import numpy as np


# Randomly draw one sample from the dataset

def product_data(sort_path, file_path):
    with open(sort_path, "rb") as f:
        sort = pickle.load(f)

    with open(file_path, "rb") as f:
        data = pickle.load(f)

    sort_key = random.randint(0, 5)
    sort_image = data[sort_key][random.randint(0, 499)]
    # print(sort_key)
    # print(sort_image)
    return sort, sort_key, sort_image

# Build the model and run prediction

def model_predict(model_path, sort_key, sort_image, sort):
    with tf.Graph().as_default():
        model = BuildModel(
            sort_path="/home/aistudio/garbage_sort/data/sort_num.pkl",
            data_path="/home/aistudio/garbage_sort/data/sort_data.pkl",
            model_path="/home/aistudio/garbage_sort/model/model.ckpt"
        )

        # Build the model and produce logits (is_training=False: no labels, no dropout)
        y_true, logits, x, y = model.google_model(False)
        # Compute the predicted label
        y_label = model.compute_y_label(logits)

        sort_image = np.reshape(sort_image, [1, 100, 100, 3])

        saver = tf.train.Saver()

        with tf.Session() as sess:
            print("restore model...")
            file_list = os.listdir("./model/")

            if "checkpoint" in file_list:
                saver.restore(sess, model_path)
            else:
                print("Model file not found!")
                exit()

            prediction = sess.run(y_label, feed_dict={x: sort_image})
            sort_prediction = sort[prediction[0]]
            print("Predicted class: {}, predicted index: {}".format(sort_prediction, prediction))
            print("True class: {}, true index: {}".format(sort[sort_key], sort_key))

def main():
    # Data paths
    # file_path = "/home/aistudio/garbage_sort/data/sort_data.pkl"
    # model_path = "/home/aistudio/garbage_sort/model/"
    sort_path = "./data/sort_num.pkl"
    file_path = "./data/sort_data.pkl"
    model_path = "./model/model.ckpt-3006"
    sort, sort_key, sort_image = product_data(sort_path, file_path)
    model_predict(model_path, sort_key, sort_image, sort)


if __name__ == "__main__":
    main()
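
A note on the checkpoint path: the -3006 suffix in model.ckpt-3006 is the global_step value passed to saver.save() during training, so use whichever step your own training run actually produced.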

Prediction results:

restore model...
Predicted class: metal, predicted index: [2]
True class: metal, true index: 2
