oneflow利用ofrecord数据集文件训练lenet模型

# lenet_train.py
import numpy as np
import oneflow as flow
import oneflow.typing as tp

#lenet网络结构
def lenet(data, train=False):
    """Build the LeNet-style network on top of ``data`` and return its output.

    The graph is: two (5x5 conv + ReLU -> 2x2 max-pool) stages with 32 and 64
    filters, a flatten, a 512-unit ReLU dense layer, optional dropout, and a
    final 10-unit dense layer.

    Args:
        data: Input blob fed to the first convolution. The pooling layers are
            configured with ``data_format="NCHW"``, so the input is presumably
            laid out as (batch, channel, height, width) -- confirm with caller.
        train: When true, a dropout layer with ``rate=0.5`` is inserted after
            the first dense layer.

    Returns:
        The output of the last dense layer. No activation is passed to that
        layer, so the values are intended to be used as unactivated logits.
    """
    # One initializer object is shared by every conv/dense kernel.
    weight_init = flow.truncated_normal(0.1)

    # The two feature-extraction stages differ only in filter count and in
    # the names given to their layers.
    stages = (
        (32, "conv1", "pool1"),
        (64, "conv2", "pool2"),
    )
    features = data
    for filters, conv_name, pool_name in stages:
        features = flow.layers.conv2d(
            features,
            filters,
            5,
            padding="SAME",
            activation=flow.nn.relu,
            name=conv_name,
            kernel_initializer=weight_init,
        )
        features = flow.nn.max_pool2d(
            features,
            ksize=2,
            strides=2,
            padding="SAME",
            name=pool_name,
            data_format="NCHW",
        )

    # Keep the leading dimension and collapse everything else into one axis.
    flat = flow.reshape(features, [features.shape[0], -1])

    hidden = flow.layers.dense(
        flat,
        512,
        activation=flow.nn.relu,
        kernel_initializer=weight_init,
        name="dense1",
    )
    if train:
        hidden = flow.nn.dropout(hidden, rate=0.5, name="dropout")

    # Final layer is created without an activation argument.
    return flow.layers.dense(
        hidden, 10, kernel_initializer=weight_init, name="dense2"
    )


#ofrecord文件解码成图像数据
def ofrecord_decode():
    """Read an OFRecord dataset and decode it into image and label blobs.

    Records are read from ``./dataset/mnist-50/`` in batches of 25. The
    "images" field is decoded as a grayscale image with a light random crop,
    resized to 28x28 and normalized; the "labels" field is decoded as a raw
    int32 value of shape ``(1,)``.

    Returns:
        A ``(images, labels)`` tuple: the normalized float image blob and the
        int32 label blob.
    """
    dataset_dir = "./dataset/mnist-50/"
    batch = 25
    gray = "GRAY"

    records = flow.data.ofrecord_reader(
        dataset_dir,
        batch_size=batch,
        data_part_num=5,  # number of OFRecord part files in the directory
        part_name_suffix_length=-1,  # -1: part file index is not zero-padded
        random_shuffle=True,  # read samples in shuffled order
        shuffle_after_epoch=True,  # reshuffle after every pass over the data
    )

    # Decode the encoded image bytes, applying a random crop.
    decoded = flow.data.OFRecordImageDecoderRandomCrop(
        records,
        "images",
        color_space=gray,
        random_area=(0.95, 1.0),
        random_aspect_ratio=(0.99, 1.0),
    )

    # Labels are stored as raw values rather than encoded images.
    labels = flow.data.OFRecordRawDecoder(
        records, "labels", shape=(1,), dtype=flow.int32
    )

    # Resize yields three values here; only the resized image is used.
    resized, _scale, _new_size = flow.image.Resize(
        decoded, target_size=(28, 28), channels=1
    )

    # Normalize with mean 0 and std 255 and emit floats
    # (presumably maps 0..255 pixel values into [0, 1] -- TODO confirm).
    normalized = flow.image.CropMirrorNormalize(
        resized,
        color_space=gray,
        mean=[0.0],
        std=[255.0],
        output_dtype=flow.float,
    )
    return normalized, labels

#训练作业函数
@flow.global_function(type="train")
def train_job() -> tp.Numpy:
    """Training job: decode a batch, run LeNet, and minimize the mean loss.

    Returns:
        The mean sparse-softmax cross-entropy loss of the batch, delivered to
        the caller as a NumPy value (per the ``tp.Numpy`` annotation).
    """
    # Decoded images and labels for one batch.
    batch_images, batch_labels = ofrecord_decode()

    # Forward pass with dropout enabled.
    logits = lenet(batch_images, train=True)

    # Per-sample softmax cross-entropy between labels and logits.
    per_sample_loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
        batch_labels, logits, name="softmax_loss"
    )
    # Collapse the per-sample losses into their mean.
    mean_loss = flow.math.reduce_mean(per_sample_loss)

    # Piecewise-constant schedule with no boundaries and one value: 0.001.
    schedule = flow.optimizer.PiecewiseConstantScheduler([], [0.001])
    # Adam optimizer driven by that schedule minimizes the mean loss.
    optimizer = flow.optimizer.Adam(schedule)
    optimizer.minimize(mean_loss)

    return mean_loss


if __name__ == "__main__":
    # Load variables from the checkpoint stored under ./model/ before training.
    saved_variables = flow.checkpoint.get("./model/")
    flow.load_variables(saved_variables)

    total_steps = 25 * 600  # number of train_job() invocations
    report_interval = 5  # print the loss once every this many steps

    for step in range(total_steps):
        loss = train_job()
        if step % report_interval == 0:
            print(loss)

    # NOTE: saving the trained variables is currently disabled:
    # flow.checkpoint.save("./model/")
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值