mxnet手写数字识别(2)

还可以写得简便一些的,是这个版本

import os,  sys
from utils import get_data

import mxnet as mx
import numpy as np
import logging
# 创建计算图
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(fc2, name='relu2', act_type="relu")
fc3 = mx.symbol.FullyConnected(act2, name='fc3', num_hidden=10)
# print(fc3)  这时候只是一个符号
softmax = mx.symbol.SoftmaxOutput(fc3, name = 'softmax')

n_epoch = 2
batch_size = 100
# 加载数据
basedir = os.path.dirname(__file__)
get_data.get_mnist(os.path.join(basedir, "data"))

train_dataiter = mx.io.MNISTIter(
        image=os.path.join(basedir, "data", "train-images-idx3-ubyte"),
        label=os.path.join(basedir, "data", "train-labels-idx1-ubyte"),
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False, seed=10)

val_dataiter = mx.io.MNISTIter(
        image=os.path.join(basedir, "data", "t10k-images-idx3-ubyte"),
        label=os.path.join(basedir, "data", "t10k-labels-idx1-ubyte"),
        data_shape=(784,),
        batch_size=batch_size, shuffle=True, flat=True, silent=False)
 
metric = mx.metric.create('acc')
 

mod = mx.mod.Module(softmax)
mod.fit(train_dataiter, eval_data=val_dataiter,
        optimizer_params={'learning_rate':0.01, 'momentum': 0.9}, num_epoch=n_epoch)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
MXNet可以通过使用RecordIO格式的数据来读取和处理im2rec文件。RecordIO是MXNet特有的一种数据格式,它将多个样本(例如图像及其对应的标签)存储在一个文件中,这样可以更高效地读取数据。使用RecordIO格式的数据可以提高数据读取的速度和效率,从而加快训练的速度。 要处理im2rec文件,可以使用MXNet提供的`mxnet.recordio`模块。该模块提供了一组函数,用于读取和写入RecordIO格式的数据。下面是一个示例代码,演示如何使用`mxnet.recordio`模块读取im2rec文件中的图像数据: ```python import mxnet as mx import numpy as np # 打开im2rec文件 record = mx.recordio.MXIndexedRecordIO('path/to/img.rec', 'path/to/img.idx', 'r') # 遍历文件中的所有图像 for i in range(len(record)): # 读取图像 item = record.read_idx(i) header, img = mx.recordio.unpack(item) # 将图像数据转换成numpy数组格式 nparr = np.frombuffer(img, dtype=np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # 对图像进行处理 # ... # 显示图像 cv2.imshow('image', img) cv2.waitKey(0) cv2.destroyAllWindows() ``` 在上面的代码中,我们首先打开了一个im2rec文件,然后使用`read_idx`函数逐个读取了文件中的所有图像数据。读取到的图像数据是一个二进制字符串,我们可以使用`unpack`函数将其解析成图像数据和标签数据。这里我们只对图像数据进行了处理,并使用OpenCV库将其显示出来。 需要注意的是,上面的示例代码只是一个简单的演示,实际使用中可能需要根据实际需求做一些修改和调整。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值