输出最大值MXNet实现

网络结构,输入为2个数,先经过10个节点的全连接层,再经过10个节点的ReLu,再经过10个节点的全连接层,再经过1个节点的全连接层,最后输出。

#-*-coding:utf-8-*- 

import logging
import math
import random
import mxnet as mx # 导入 MXNet 库
import numpy as np # 导入 NumPy 库,这是 Python 常用的科学计算库

logging.getLogger().setLevel(logging.DEBUG) # 打开调试信息的显示

'''设置超参数'''
n_sample = 10000 # 训练用的数据点个数
batch_size = 10 # 批大小
learning_rate = 0.1 # 学习速率
n_epoch = 10 # 训练 epoch 数

'''生成训练数据'''
# 每个数据点是在 (0,1) 之间的 2 个随机数
train_in = [[ random.uniform(0, 1) for c in range(2)] for n in range(n_sample)] 
train_out = [0 for n in range(n_sample)] # 期望输出,先初始化为 0
for i in range(n_sample):
    # 每个数据点的期望输出是 2 个输入数中的大者
    train_out[i] = max(train_in[i][0], train_in[i][1])

'''定义train_iter为训练数据的迭代器,data为输入数据,label为标签对应train_out,shuffle代表每个epoch会随机打乱数据'''
train_iter = mx.io.NDArrayIter(data = np.array(train_in), label = {'reg_label':np.array(train_out)}, batch_size = batch_size, shuffle = True)

'''定义网络结构,src为输入层,fc1,fc2,fc3是全连接层,act1,act2是ReLu层,num_hidden代表神经元个数,data是输入数据,name是输出'''
src = mx.sym.Variable('data') # 输入层
fc1  = mx.sym.FullyConnected(data = src, num_hidden = 10, name = 'fc1') # 全连接层
act1 = mx.sym.Activation(data = fc1, act_type = "relu", name = 'act1') # ReLU层
fc2  = mx.sym.FullyConnected(data = act1, num_hidden = 10, name = 'fc2') # 全连接层
act2 = mx.sym.Activation(data = fc2, act_type = "relu", name = 'act2') # ReLU层
fc3  = mx.sym.FullyConnected(data = act2, num_hidden = 1, name = 'fc3') # 全连接层
'''定义net为输出层,采用线性回归输出,MXNet会自动使用MSE作为损失函数,输入数据为fc3,输出层命名为reg'''
net = mx.sym.LinearRegressionOutput(data = fc3, name = 'reg') # 输出层

'''定义变量module需训练的网络模组,网络的输出symbol为net,期望标签名label_names为reg_label'''
module = mx.mod.Module(symbol = net, label_names = (['reg_label']))

'''定义module.fit进行训练'''
module.fit(
    train_iter, # 训练数据的迭代器
    eval_data = None, # 在此只训练,不使用测试数据
    eval_metric = mx.metric.create('mse'), # 输出 MSE 损失信息
    #将权重和偏置初始化为在[-0.5,0.5]间均匀的随机数
    initializer=mx.initializer.Uniform(0.5),
    optimizer = 'sgd', # 梯度下降算法为 SGD
    # 设置学习速率
    optimizer_params = {'learning_rate': learning_rate}, 
    num_epoch = n_epoch, # 训练 epoch 数
    # 每经过 100 个 batch 输出训练速度 
    batch_end_callback = None, 
    epoch_end_callback = None, 
)

#输出最终参数
for k in module.get_params():
    print(k)

转载于:https://www.cnblogs.com/cold-city/p/10460392.html

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
求高人解答有关BP神经网络输入训练时出现最大值和最小值-neiqian lun.xls   正在做毕设,训练样本为表格形式,在不同车速和方向盘转角输入给定下的车轮角速度。有400多个训练样本,训练时出现了输入最大最小之相等的情况; p=[0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190     200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 0     10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200     210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400    0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190    200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390    400 410 420 430 440 0 10 20 30 40 50 60 70 80 90 100 110 120 130 140    150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340    350 360 370 380 390 400 410 420 430 440 450 0 10 20 30 40 50 60 70 80    90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280    290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 0 10 20    30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220    230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420    430 440 450 0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160    170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360    370 380 390 400 410 420 430 440 450 0 10 20 30 40 50 60 70 80 90 100    110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300    310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 0 10 20 30 40    50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240    250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440    450 0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180    190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380;   15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15     15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 15 20     20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20     20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20 20     25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25     25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25 25     25 25 25 25 25 30 30 30 30 30 30 30 30 30 30 30 30 30 30 30     30 30 30 30 30 30 30 30 30 30 30 30 30 30 30 30 30 30 30 30     30 30 30 30 30 30 30 30 30 30 30 35 35 35 35 35 35 35 35 35     35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35     35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 35 40 40 40     40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40     40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40 40     40 40 40 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45     45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45 45     45 45 45 45 45 45 45 45 45 50 50 50 50 50 50 50 50 50 50 50     50 50 50 50 50 50 50 50 50 50 50 50 50 50 50 50 50 50 50 50     50 50 50 50 50 50 50 50 50 50 50 50 50 50 50 55 55 55 55 55     55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55     55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55 55     55 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60     60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60 60]; t=[740.93 738.49 736.11 733.76 731.41 729.01 726.49 723.8 720.86 717.59 713.94 709.81 705.15 699.89 693.98 687.37 679.99 671.79 662.67 652.55    641.37 628.83 614.93 599.54 582.5 563.66 542.85 519.9 494.63 466.87 436.42 403.08 366.64 326.86 283.46 236.15 184.59 128.36 66.968 987.9    984.7 981.43 978.23 975.07 971.91 968.69 965.34 961.78 957.94 953.74 949.11 943.99 938.34 932.12 925.3 917.85 909.7 900.81 891.1 880.48    868.86 856.12 842.15 826.82 810.01 791.61 771.51 749.64 725.89 700.19 672.47 642.61 610.53 576.1 539.16 499.5 456.91 411.17 362.05 309.29      1234.9 1230.9 1226.6 1222.5 1218.4 1214.4 1210.4 1206.3 1202.1 1197.5 1192.7 1187.4 1181.6 1175.3 1168.4 1160.9 1152.8 1144 1134.3 1123.8    1112.3 1099.8 1086 1071.1 1054.7 1036.9 1017.5 996.5 973.81 949.4 923.22 895.22 865.37 833.63 799.93 764.19 726.31 686.17 643.66 598.63    550.96 500.52 447.13 390.51 330.46 1481.9 1477.1 1471.7 1466.4 1461.3 1456.4 1451.5 1446.5 1441.4 1436    1430.2 1423.9 1417 1409.3 1400.9    1391.7 1381.5 1370.3 1358.2 1344.9 1330.5 1314.8 1297.8 1279.3 1259.4 1237.8 1214.7 1189.8 1163.4 1135.3 1105.6 1074.4 1041.6 1007.1 970.93    932.99 893.17 851.37 807.5 761.47 713.18 662.54 609.47 553.88 495.72 434.9 1728.8 1723.2 1716.5 1710 1703.7 1697.7 1691.8 1685.8 1679.6     1672.9 1665.7 1657.7 1648.7 1638.5 1627.1 1614.4 1600.3 1584.9 1567.9 1549.5 1529.5 1508 1484.9 1460 1433.5 1405.3 1375.4 1343.9 1310.8     1276.2 1240.4 1203.2 1164.8 1125.2 1084.3 1042 998.24 952.8 905.55 856.36 805.15 751.88 696.46 638.86 579.01 516.88 1975.8 1969.4 1961.2     1953.1 1945.6 1938.3 1931.2 1923.8 1916.1 1907.6 1898.1 1887.2 1874.7 1860.5 1844.2 1826 1805.5 1783 1758.4 1731.7 1703.1 1672.6 1640.3     1606.1 1570.3 1532.9 1494 1453.7 1412 1369.2 1325.5 1281 1235.8 1190 1143.4 1095.9 1047.4 997.47 946.03 892.87 837.87 780.98 722.17     661.38 598.59 533.76 2222.8 2215.5 2205.5 2195.8 2186.8 2178.2 2169.5 2160.4 2150.5 2139.3 2126.4 2111.3 2093.8 2073.4 2050.1 2023.6 1994     1961.4 1925.9 1888 1847.7 1805.6 1761.7 1716.4 1669.7 1621.9 1573 1523.2 1472.7 1421.8 1370.8 1319.7 1268.6 1217.2 1165.5 1113.2 1060     1005.9 950.4 893.37 834.66 774.2 711.91 647.75 581.79 514.01 2469.8 2461.7 2449.7 2438 2427.4 2417.2  2406.7 2395.3 2382.4 2367.5 2349.8     2329 2304.5 2275.8 2242.8 2205.5 2164.1 2119 2070.8 2019.8 1966.8 1912.1 1856.1 1799.4 1742 1684.2 1626.1 1567.9 1510.2 1453 1396.5     1340.4 1284.7 1229.1 1173.6 1117.6 1061.1 1003.7 945.04 884.96 823.3 759.99 694.98 628.21 559.64 489.36 2716.7 2707.8 2693.5 2679.6 2667.4     2655.4 2642.7 2628.4 2611.6 2591.7 2568 2539.6 2506 2467.1 2422.9 2373.6 2319.8 2262.3 2201.6 2138.6 2073.7 2007.5 1940.4 1872.9 1805.5     1738.3 1671.6 1606 1541.7 1478.9 1417.5 1357.1 1297.5 1238.6 1179.8 1120.9 1061.6 1001.5 940.25 877.67 813.6 748     680.81 611.99 541.5     469.35 2963.7 2953.9 2937.1 2920.7 2906.9 2892.9 2877.5 2859.5 2837.9 2811.9 2780.4 2743.1 2699.8 2650.2 2595 2534.9 2470.9 2403.6 2333.5     2261.2 2187 2111.4 2034.8 1957.7 1880.6 1804.2 1728.5 1654.5 1582.9 1513.9 1447.2 1382.6 1319.4 1257.2 1195.5 1134.1 1072.4 1010.1 946.91]; [pn,minp,maxp,tn,mint,maxt] = premnmx; net=newff,[14,14,1],{'tansig','tansig','purelin'},'trainlm'); net.trainParam.show=5; net.trainParam.epochs=1000; net.trainParam.goal=1e-5; net=init; [net,tr]=train; 源程序如上,求高人告诉该怎么改动。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值