GPU的设置
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
tf.config.set_visible_devices([gpu0],"GPU")
导入数据集并且显示
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
import cv2
import numpy as np
# 导入mnist数据,依次分别为训练集图片、训练集标签、测试集图片、测试集标签
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
print('训练集',train_images.shape)
##图片的显示
#数据可视化
num = 40
col = 8
row = int(num / 8)
index = np.random.randint(1, len(train_images), num)
for i in range(num):
for j in range(8):
# plt.figure()
plt.subplot(row, col, i + 1)
plt.xticks([]) # 去掉x轴的刻度
plt.yticks([]) #