# lenet_train.py
import numpy as np
import oneflow as flow
import oneflow.typing as tp
# LeNet network structure.
def lenet(data, train=False):
    """Build the LeNet graph: two conv+pool stages followed by two dense layers.

    Args:
        data: input image batch in NCHW layout (grayscale 28x28 per the caller's
            decode pipeline — confirm against ofrecord_decode).
        train: when True, apply dropout (rate 0.5) before the final dense layer.

    Returns:
        Unactivated 10-way logits (softmax is applied by the loss, not here).
    """
    initializer = flow.truncated_normal(0.1)
    conv1 = flow.layers.conv2d(
        data,
        32,
        5,
        padding="SAME",
        activation=flow.nn.relu,
        name="conv1",
        kernel_initializer=initializer,
    )
    pool1 = flow.nn.max_pool2d(
        conv1, ksize=2, strides=2, padding="SAME", name="pool1", data_format="NCHW"
    )
    conv2 = flow.layers.conv2d(
        pool1,
        64,
        5,
        padding="SAME",
        activation=flow.nn.relu,
        name="conv2",
        kernel_initializer=initializer,
    )
    pool2 = flow.nn.max_pool2d(
        conv2, ksize=2, strides=2, padding="SAME", name="pool2", data_format="NCHW"
    )
    # Flatten (N, C, H, W) feature maps to (N, C*H*W) for the dense layers.
    reshape = flow.reshape(pool2, [pool2.shape[0], -1])
    hidden = flow.layers.dense(
        reshape,
        512,
        activation=flow.nn.relu,
        kernel_initializer=initializer,
        name="dense1",
    )
    if train:
        hidden = flow.nn.dropout(hidden, rate=0.5, name="dropout")
    # Return the unactivated logits.
    return flow.layers.dense(hidden, 10, kernel_initializer=initializer, name="dense2")
# Decode OFRecord files into image and label batches.
def ofrecord_decode():
    """Read the MNIST OFRecord dataset and decode one batch of images/labels.

    Returns:
        (images, labels): normalized float images (pixel values scaled by 1/255)
        and int32 label tensors of shape (batch_size, 1).
    """
    batch_size = 25  # samples per batch
    color_space = "GRAY"  # image color space
    ofrecord = flow.data.ofrecord_reader(
        "./dataset/mnist-50/",  # OFRecord dataset directory
        batch_size=batch_size,
        data_part_num=5,  # number of OFRecord part files
        part_name_suffix_length=-1,  # -1: part-file name suffix is not zero-padded
        random_shuffle=True,  # shuffle samples while reading
        shuffle_after_epoch=True,  # reshuffle after each epoch
    )
    # Decode image bytes with a slight random crop (mild data augmentation).
    image = flow.data.OFRecordImageDecoderRandomCrop(
        ofrecord,
        "images",
        color_space=color_space,
        random_area=(0.95, 1.0),
        random_aspect_ratio=(0.99, 1.0),
    )
    # Decode the label field.
    labels = flow.data.OFRecordRawDecoder(
        ofrecord, "labels", shape=(1,), dtype=flow.int32
    )
    # Resize images to 28x28. Resize also returns the applied scale and the new
    # size; both are unused here.
    rsz, _scale, _new_size = flow.image.Resize(
        image, target_size=(28, 28), channels=1
    )
    # Normalize: subtract mean 0 and divide by 255, i.e. map pixels to [0, 1].
    normal = flow.image.CropMirrorNormalize(
        rsz,
        color_space=color_space,
        mean=[0.0],
        std=[255.0],
        output_dtype=flow.float,
    )
    return normal, labels
# Training job function.
@flow.global_function(type="train")
def train_job() -> tp.Numpy:
    """One training step: forward pass, softmax loss, Adam update.

    Returns:
        The batch-mean loss as a numpy value.
    """
    # Decoded image and label batches.
    images, labels = ofrecord_decode()
    # Network output (unactivated logits).
    logits = lenet(images, train=True)
    # Softmax cross-entropy between logits and sparse integer labels.
    loss = flow.nn.sparse_softmax_cross_entropy_with_logits(
        labels, logits, name="softmax_loss"
    )
    # Reduce per-sample losses to their mean.
    loss = flow.math.reduce_mean(loss)
    # Constant learning rate of 0.001 (empty boundary list = no decay steps).
    lr_scheduler = flow.optimizer.PiecewiseConstantScheduler([], [0.001])
    # Minimize the loss with Adam.
    flow.optimizer.Adam(lr_scheduler).minimize(loss)
    # Return this step's loss.
    return loss
if __name__ == "__main__":
    # Initialize model variables from an existing checkpoint directory.
    flow.load_variables(flow.checkpoint.get("./model/"))
    # Train for 25 * 600 = 15000 iterations. (The original comment claimed
    # "300 * 25", which did not match this loop bound.)
    for epoch in range(25 * 600):
        loss = train_job()
        # Print the loss every 5 steps.
        if epoch % 5 == 0:
            print(loss)
    # NOTE(review): checkpoint saving is disabled, so the trained weights are
    # never persisted. Uncomment to save after training:
    # flow.checkpoint.save("./model/")
# (Blog-page residue, commented out to keep the file valid Python:)
# "Training a LeNet model in OneFlow from an OFRecord dataset"
# Latest recommended article published 2024-07-19 16:36:18