import paddle import paddle.nn.functional as F from paddle.nn import Conv2D, Linear, MaxPool2D import numpy as np import os import json import random print(paddle.__version__) def data_generator(imgs, labels, batch_size, mode='train'): imgs_length = len(imgs) index_list = list(range(imgs_length)) if mode == 'train': random.shuffle(index_list) imgs_list, labels_list = [], [] for i in index_list: img = np.array(imgs[i].astype('float32')) label = np.reshape(labels[i], [1]).astype('int64') imgs_list.append(img) labels_list.append(label) if len(imgs_list) == batch_size: yield np.array(imgs_list), np.array(labels_list) imgs_list, labels_list = [], [] if len(imgs_list) > 0: yield np.array(imgs_list), np.array(labels_list) def load_data(mode='train'): with open('data/mnist.json','r') as f: data = json.load(f) train_set, val_set, eval_set = data if mode == 'train': imgs, labels = train_set[0], train_set[1] elif mode == 'valid': imgs, labels = val_set[0], val_set[1] elif mode == 'eval': imgs, labels = eval_set[0], eval_set[1] else: raise Exception(
在学习人工智能深度学习综合实践一书中进行项目二时导入了mnist数据集却无法识别文件
最新推荐文章于 2024-10-17 08:00:29 发布
在学习人工智能深度学习的实践中,遇到项目二的一个挑战:成功导入MNIST数据集后,程序却无法正确识别文件。这可能是由于数据加载、预处理或numpy操作方面的问题。进一步排查和理解数据格式及Python代码的执行流程是解决此问题的关键。
摘要由CSDN通过智能技术生成