tensorflow 读取图片 Dataset用法

目录

bmp Dataset.from_tensor_slices:

Dataset简单用法

png这个测试ok:

读图片,resize,预测

构建dataset png格式可以训练:


bmp Dataset.from_tensor_slices:

    augfiles = ['test_images/532_img_.bmp']
    gtfiles = ['test_images/532_img_.bmp']

    augImages = tf.constant(augfiles)
    gtImages = tf.constant(gtfiles)

    dataset = tf.data.Dataset.from_tensor_slices((augImages, gtImages))
    # dataset = dataset.shuffle(len(augImages))
    # dataset = dataset.repeat()
    dataset = dataset.map(parse_function).batch(1)

Dataset简单用法

一、Dataset使用
# from_tensor_slices:表示从张量中获取数据。
# make_one_shot_iterator():表示只将数据读取一次,然后就抛弃这个数据了。
input_data = [1,2,3,5,8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)
for e in dataset:
    print(e)

png这个测试ok:

    img_byte = tf.compat.v1.read_file(filename='test_images/532_img_.png')

    img_data_jpg = tf.image.decode_png(img_byte)  # 图像解码
    img_data_jpg = tf.image.convert_image_dtype(img_data_jpg, dtype=tf.uint8)  # 改变图像数据的类型

读图片,resize,预测

filename_image_string = tf.io.read_file(imgfile)
filename_image = tf.image.decode_png(filename_image_string, channels=3)
filename_image = tf.image.convert_image_dtype(filename_image, tf.float32)
filename_image = tf.image.resize(filename_image, (256, 256))
l, w, c = filename_image.shape
filename_image = tf.reshape(filename_image, [1, l, w, c])
output = model.predict(filename_image)
output = output.reshape((l, w, c)) * 255
cv2.imwrite(out_dir+ os.path.basename(imgfile), output)

构建dataset png格式可以训练:

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, SpatialDropout2D, ReLU, Input, Concatenate, Add
from tensorflow.keras.losses import MeanAbsoluteError, MeanSquaredError
from tensorflow.keras.optimizers import Adam
import os
import pandas as pd
import cv2


class UWCNN(tf.keras.Model):
    """UWCNN underwater-image enhancement network.

    Three densely-connected blocks, each made of three 16-filter 3x3
    conv+ReLU layers whose outputs are concatenated (along channels)
    with the block's input.  A final 3-channel conv produces a residual
    that is added to the original network input.
    """

    def __init__(self):
        super(UWCNN, self).__init__()
        # NOTE: attribute names and explicit layer names are kept as-is;
        # they determine the checkpoint variable layout.
        self.conv1 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze1")
        self.relu1 = ReLU()
        self.conv2 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze2")
        self.relu2 = ReLU()
        self.conv3 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze3")
        self.relu3 = ReLU()
        self.concat1 = Concatenate(axis=3)

        self.conv4 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze4")
        self.relu4 = ReLU()
        self.conv5 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze5")
        self.relu5 = ReLU()
        self.conv6 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze6")
        self.relu6 = ReLU()
        self.concat2 = Concatenate(axis=3)

        self.conv7 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze7")
        self.relu7 = ReLU()
        self.conv8 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze8")
        self.relu8 = ReLU()
        self.conv9 = Conv2D(16, 3, (1, 1), 'same', name="conv2d_dehaze9")
        self.relu9 = ReLU()
        self.concat3 = Concatenate(axis=3)

        self.conv10 = Conv2D(3, 3, (1, 1), 'same', name="conv2d_dehaze10")
        self.add1 = Add()

    def call(self, inputs):
        """Run the forward pass; returns a tensor shaped like `inputs`."""
        # Dense block 1 — concat order (conv outs first, block input last)
        # must match the trained checkpoints, so it is preserved exactly.
        x1 = self.relu1(self.conv1(inputs))
        x2 = self.relu2(self.conv2(x1))
        x3 = self.relu3(self.conv3(x2))
        cat1 = self.concat1([x1, x2, x3, inputs])

        # Dense block 2 — here the block input comes first in the concat.
        x4 = self.relu4(self.conv4(cat1))
        x5 = self.relu5(self.conv5(x4))
        x6 = self.relu6(self.conv6(x5))
        cat2 = self.concat2([cat1, x4, x5, x6])

        # Dense block 3 — same ordering as block 2.
        x7 = self.relu7(self.conv7(cat2))
        x8 = self.relu8(self.conv8(x7))
        x9 = self.relu9(self.conv9(x8))
        cat3 = self.concat3([cat2, x7, x8, x9])

        # Residual connection: predict a 3-channel correction, add to input.
        residual = self.conv10(cat3)
        return self.add1([inputs, residual])


def parse_function(filename, label):
    """Load an (input, ground-truth) image pair as float32 tensors in [0, 1].

    Args:
        filename: scalar string tensor — path of the degraded input image.
        label: scalar string tensor — path of the ground-truth image.

    Returns:
        Tuple (input_image, label_image), each an HxWx3 float32 tensor
        with values scaled to [0, 1].
    """
    input_bytes = tf.io.read_file(filename)
    label_bytes = tf.io.read_file(label)
    # decode_image handles PNG, JPEG and BMP alike — decode_png would raise
    # on the .bmp paths this script passes in model_test.  With
    # expand_animations=False the result is always a rank-3 tensor.
    input_image = tf.io.decode_image(input_bytes, channels=3,
                                     expand_animations=False)
    # convert_image_dtype rescales uint8 [0, 255] to float32 [0, 1].
    input_image = tf.image.convert_image_dtype(input_image, tf.float32)
    label_image = tf.io.decode_image(label_bytes, channels=3,
                                     expand_animations=False)
    label_image = tf.image.convert_image_dtype(label_image, tf.float32)
    return input_image, label_image


def combloss(y_actual, y_predicted):
    """Custom Keras loss: 4 * MSE + (1 - mean SSIM).

    Combines a pixel-wise L2 term (weighted by 4) with a structural
    similarity term so the model optimizes both fidelity and structure.

    :param y_actual: ground-truth image batch, float values in [0, 1]
    :param y_predicted: model output batch, same shape as y_actual
    :return: scalar loss tensor
    """
    # SSIM is a similarity in [-1, 1]; subtracting its batch mean from 1
    # turns it into a loss to minimize.
    mean_ssim = tf.reduce_mean(
        tf.image.ssim(y_actual, y_predicted, max_val=1, filter_size=13))
    lssim = tf.constant(1, dtype=tf.float32) - mean_ssim
    # Mean squared error averaged over the batch, then weighted by 4.
    lmse = MeanSquaredError(reduction='sum_over_batch_size')(y_actual, y_predicted)
    weighted_mse = tf.math.multiply(lmse, 4)
    return tf.math.add(weighted_mse, lssim)


def train(ckptpath="./train_type1/cp.ckpt", type='type1',
          augfiles=None, gtfiles=None):
    """Train UWCNN on (degraded, ground-truth) image pairs.

    Args:
        ckptpath: file path where ModelCheckpoint stores the weights.
        type: water-type tag; the trained model is saved under
            'save_model/<type>'.  (Parameter name kept for caller
            compatibility even though it shadows the builtin.)
        augfiles: list of input image paths; defaults to the demo image.
        gtfiles: list of ground-truth image paths; defaults to the demo label.
    """
    # Mutable defaults are resolved here rather than in the signature.
    if augfiles is None:
        augfiles = ['test_images/532_img_.png']
    if gtfiles is None:
        gtfiles = ['test_images/532_label_.png']

    dataset = tf.data.Dataset.from_tensor_slices(
        (tf.constant(augfiles), tf.constant(gtfiles)))
    # Shuffling/repeating omitted: with the single demo pair they are no-ops.
    dataset = dataset.map(parse_function).batch(1)

    # Save the model's weights after every epoch.
    cp_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=ckptpath, save_weights_only=True, verbose=1)

    model = UWCNN()
    model.compile(optimizer=Adam(), loss=combloss)
    model.fit(dataset, epochs=1, callbacks=[cp_callback])

    # Export the full model (architecture + weights) in SavedModel format.
    model.save('save_model/' + type)

def model_test(imgdir="./test_images/", imgfile="12433.png", ckdir="./train_type1/cp.ckpt", outdir="./results/",
               type='type1'):
    """Restore trained weights and enhance a single image.

    Reads imgdir/imgfile, resizes it to 460x620, runs it through UWCNN
    with the weights stored at ckdir, and writes the result to
    outdir/<type>_<imgfile>.

    Bug fixed: the previous version called model.fit with a
    ModelCheckpoint pointing at ckdir, which OVERWROTE the trained
    checkpoint (with one epoch trained on .bmp files that decode_png
    cannot even read) before loading it back.  A subclassed Keras model
    merely needs one forward pass to create its variables, so a dummy
    call is used instead.
    """
    model = UWCNN()
    # Build the model's variables without touching the checkpoint on disk.
    model(tf.zeros([1, 460, 620, 3], dtype=tf.float32))
    model.load_weights(ckdir)
    model.summary()

    image_bytes = tf.io.read_file(imgdir + imgfile)
    image = tf.image.decode_png(image_bytes, channels=3)
    image = tf.image.convert_image_dtype(image, tf.float32)  # uint8 -> [0, 1] float
    image = tf.image.resize(image, (460, 620))
    h, w, c = image.shape
    image = tf.reshape(image, [1, h, w, c])  # add the batch dimension

    output = model.predict(image)
    # Scale back to [0, 255] for writing to disk.
    output = output.reshape((h, w, c)) * 255
    # NOTE(review): cv2.imwrite expects BGR channel order while the tensor
    # was decoded as RGB — colors may be swapped in the saved file; confirm.
    cv2.imwrite(outdir + type + "_" + imgfile, output)


if __name__ == "__main__":
    # Train the 'type1' variant and save its checkpoint/model, then stop.
    train(ckptpath="./train_type1/cp.ckpt", type='type1')
    exit(0)
    # Everything below is unreachable while the exit(0) above is in place —
    # presumably a deliberate toggle; remove exit(0) to run inference.
    type = "type1"
    ckdir = "./train_type1/cp.ckpt"
    model_test(imgdir="./test_images/", imgfile="532_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="602_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="617_img_.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="12422.png", ckdir=ckdir, outdir="./results/", type=type)
    # model_test(imgdir="./test_images/", imgfile="12433.png", ckdir=ckdir, outdir="./results/", type=type)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
TensorFlow是一个开源的机器学习框架,通常用于创建神经网络模型。在训练模型之前,需要准备好数据集,本文将介绍如何使用TensorFlow读取数据。 TensorFlow提供了多种读取数据的方法,其中最常用的是使用tf.data模块。首先,我们需要定义一个数据集对象,并通过读取文件的方式将数据加载进来。TensorFlow支持多种文件格式,如csv、txt、json、tfrecord等,可以根据自己的需求选择合适的格式。 加载数据后,我们可以对数据进行一些预处理,比如做数据增强、进行归一化等操作。预处理完数据后,我们需要将数据转化为张量类型,并将其打包成batch。通过这种方式,我们可以在每次训练中同时处理多个数据。 随后,我们可以使用tf.data.Dataset中的shuffle()函数打乱数据集顺序,防止模型只学习到特定顺序下的模式,然后使用batch()函数将数据划分成批次。最后,我们可以使用repeat()函数让数据集每次可以被使用多次,达到更好的效果。 在TensorFlow中,我们可以通过输入函数将数据集传入模型中,使模型能够直接从数据集中读取数据。使用输入函数还有一个好处,即能够在模型训练时动态地修改数据的内容,特别是在使用estimator模块进行模型训练时,输入函数是必须要的。 总结一下,在TensorFlow读取数据的流程如下:定义数据集对象-读取文件-预处理数据-打包数据为batch-打乱数据集-划分批次数据-重复数据集-使用输入函数读取数据。 在实际应用过程中,我们还可以通过其他方式来读取数据,如使用numpy、pandas等工具库,也可以自定义数据集类来处理数据。无论使用何种方式,读取数据都是机器学习训练中重要的一步,需要仔细处理。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值