第3周 T2 用TensorFlow实现cifar10数据集图像分类
导入必要的库
这一步比较基础, 按需求导入即可
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
import numpy as np
设置CPU
我电脑没有GPU,只能设置成CPU跑, 实测下来训练一个epoch需要70多秒, 15年的老笔记本了。
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
导入数据cifar10数据
没啥可说的
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
归一化
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images.shape, train_labels.shape, test_images.shape, test_labels.shape
可视化
class_names = ['airplane', 'automobile