目录
1 导入库
2 加载数据集
我们这次换为mnist数据集,这个数据集是手写数字的数据集
3 归一化
我们此时看一下train_iamge的shape
4 创建数据集
4.1 创建训练图片数据集
我们看一下我们的图片数据集
4.2 创建训练标签数据集
我们同样看一下
4.3 把这两个数据集合并到一起
我们想把这两个数据合并成元组的形式,所以要额外加一个括号
看一下合并后的数据集
5 处理训练集
5.1 乱序
乱序个数我们现在设置为10000,这个数是随便选的,我们一共有60000张图片,全部乱序会影响速度,乱序太少会没有效果
5.2 重复
我们令其无限次重复
5.3 设置批次
我们将其设置为一批64个
6 创建测试集
测试集后面给validation_data用的,能更加准确的获取acc与loss
我们在创建的时候这样写就行
看一下这个测试集
我们没有必要对test数据集进行乱序,因为测试集仅在前向转播时使用,不会对其余参数造成影响,由于测试集不用随机,那么测试集也就不用重复,训练集重复的意义是每一次重复都是随机顺序的数据,我们唯一就是需要将测试集的batch调整至与训练集相同的batch值
- 如果内存够大batch也不用改,对于mnist这个数据量小的数据集来讲batch是不需要添加的
7 建立模型
8 编译模型
sparse_categorical_crossentropy我们从上面的shape可以看出来,train_label是由单一的数组成的,所以此处的loss使用sparse_categorical_crossentropy
9 训练模型
因为我们上面定义了batch,所以我们在这里要使用steps_per_epochs(每个epochs的步数),如果我们不加入这个就会导致我们只训练64个数据,如果加上了这个就会第一批训练64个,然后第二批再训练64个
由于ds_test默认为无限循环,所以我们要加入validation_steps这个参数,这个参数的值是测试集个数 // batch
- train_image.shape[0]是训练图片的总个数,在这里是60000,test_image.shape[0]是测试图片的个数,这里是10000
- 步数这里注意一定要是整数,不然会报错
后面我们可以用这个模型进行评估与预测,我们就不在这赘述了