主要引用自 https://www.heywhale.com/mw/project/5ea53f81105d91002d509ead
python行尾的'\'主要意义是该行输入未结束
导入图片的代码
'''tensorflow原生方法'''
import tensorflow as tf
from tensorflow.keras import datasets,layers,models
BATCH_SIZE = 100
def load_image(img_path,size = (32,32)):
label = tf.constant(1,tf.int8) if tf.strings.regex_full_match(img_path,".*automobile.*") \
else tf.constant(0,tf.int8)
img = tf.io.read_file(img_path)
img = tf.image.decode_jpeg(img) #注意此处为jpeg格式
img = tf.image.resize(img,size)/255.0
return(img,label)
#使用并行化预处理num_parallel_calls 和预存数据prefetch来提升性能
ds_train = tf.data.Dataset.list_files("/home/kesci/input/data3483/data/cifar2/train/*/*.jpg") \
.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
.shuffle(buffer_size = 1000).batch(BATCH_SIZE) \
.prefetch(tf.data.experimental.AUTOTUNE)
ds_test = tf.data.Dataset.list_files("/home/kesci/input/data3483/data/cifar2/test/*/*.jpg") \
.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
.batch(BATCH_SIZE) \
.prefetch(tf.data.experimental.AUTOTUNE)
在这里会遇到一个问题即python使用gbk编码无法解读文件地址的问题,也就是
"tf.data.Dataset.list_files"报如下或者是类似的错误
UnicodeDecodeError: 'utf-8' codec can't decode byte 0xd5
此时需要通过系统的设置-时间和语言-语言-管理语言设置(右侧)的“更改系统区域设置”使用UTF-8提供全球语言支持选项,弃用gkb编码
通过上图设置编码,设置完进行电脑重启
然后将“tf.data.Dataset.list_files”中的文件路径改为绝对路径而且使用 '\\' 替代 '/' 方可执行
然后定义模型
'''定义模型'''
tf.keras.backend.clear_session() #清空会话
inputs = layers.Input(shape=(32,32,3))
x = layers.Conv2D(32,kernel_size=(3,3))(inputs)
x = layers.MaxPool2D()(x)
x = layers.Conv2D(64,kernel_size=(5,5))(x)
x = layers.MaxPool2D()(x)
x = layers.Dropout(rate=0.1)(x)
x = layers.Flatten()(x)
x = layers.Dense(32,activation='relu')(x)
outputs = layers.Dense(1,activation = 'sigmoid')(x)
model = models.Model(inputs = inputs,outputs = outputs)
model.summary()
训练模型
import datetime
import os
stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = os.path.join('data', 'autograph', stamp)
## 在 Python3 下建议使用 pathlib 修正各操作系统的路径
# from pathlib import Path
# stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# logdir = str(Path('./data/autograph/' + stamp))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=tf.keras.losses.binary_crossentropy,
metrics=["accuracy"]
)
history = model.fit(ds_train,epochs= 10,validation_data=ds_test,
callbacks = [tensorboard_callback],workers = 4)
查看模型
from tensorboard import notebook
notebook.list()
#在tensorboard中查看模型
notebook.start("--logdir /home/kesci/input/data3483/data/keras_model")
import pandas as pd
dfhistory = pd.DataFrame(history.history)
dfhistory.index = range(1,len(dfhistory) + 1)
dfhistory.index.name = 'epoch'
print(dfhistory)
import matplotlib.pyplot as plt
def plot_metric(history, metric):
train_metrics = history.history[metric]
val_metrics = history.history['val_'+metric]
epochs = range(1, len(train_metrics) + 1)
plt.plot(epochs, train_metrics, 'bo--')
plt.plot(epochs, val_metrics, 'ro-')
plt.title('Training and validation '+ metric)
plt.xlabel("Epochs")
plt.ylabel(metric)
plt.legend(["train_"+metric, 'val_'+metric])
plt.show()
plot_metric(history,"loss")
plot_metric(history,"accuracy")
#可以使用evaluate对数据进行评估
val_loss,val_accuracy = model.evaluate(ds_test,workers=4)
print(val_loss,val_accuracy)
使用模型
'''使用模型'''
model.predict(ds_test)
for x,y in ds_test.take(1):
print(model.predict_on_batch(x[0:20]))
保存模型
# 保存权重,该方式仅仅保存权重张量
model.save_weights('./data/tf_model_weights.ckpt',save_format = "tf")
# 保存模型结构与模型参数到文件,该方式保存的模型具有跨平台性便于部署
model.save('./data/tf_model_savedmodel', save_format="tf")
print('export saved model.')
model_loaded = tf.keras.models.load_model('./data/tf_model_savedmodel')
model_loaded.evaluate(ds_test)
假如需要导入更复杂的image-label数据的话,就不能使用上面的方法,具体的举例如下(bounding box)类型的label
def load_image_train(img_path, label, size=(108, 192)):
img = tf.io.read_file(img_path)
img = tf.image.decode_png(img)
img = tf.image.resize(img, size) / 255.0
return (img, label)
ds_train = tf.data.Dataset.from_tensor_slices((filepath, labeldata)) \
.map(load_image_train, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
.batch(BATCH_SIZE) \
.prefetch(tf.data.experimental.AUTOTUNE)
'''filepath是文件的绝对地址字符串组成的python List,
labeldata也是标出的四个数字的数组组成的python List'''
'''上述方法可以避免麻烦而且很难做对的Tensorflow向量和字符串操作'''