关于LSTM-FCN程序的部署与运行

前面学习了LSTM-FCN的相关知识,现在针对该框架我们找到了一份代码资源,来通过对实现代码的解读进一步理解该模型。

实验环境

Windows10 python3.6 CUDA8.0 CuDNN5.1
GPU:GeForce GTX 960M

tensorflow-gpu>=1.2.0
keras>=2.0.4
scipy
numpy
pandas
scikit-learn>=0.18.2
h5py
matplotlib
joblib>=0.12

博主运行时使用了GPU加速,这可以大幅提高运行速度,如果没有GPU的话只需要python环境即可,只是CPU运行起来速度确实有些拉跨。当然我的显卡也很差劲,利用率直接拉满。
在这里插入图片描述

项目目录

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

代码实现

初始时我们指定选择使用GPU运行

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

构建LSTM-FCN模型

def generate_lstmfcn(MAX_SEQUENCE_LENGTH, NB_CLASS, NUM_CELLS=8):
    ip = Input(shape=(1, MAX_SEQUENCE_LENGTH))
    x = LSTM(NUM_CELLS)(ip)
#以一定概率丢弃一训练的参数,防止其过拟合
    x = Dropout(0.8)(x)#缀学层
#Permute可以同时多次交换tensor的维度
    y = Permute((2, 1))(ip)
    y = Conv1D(128, 8, padding='same', kernel_initializer='he_uniform')(y)
#批归一化 让我们的均值方差变化没有那么猛烈
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv1D(256, 5, padding='same', kernel_initializer='he_uniform')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv1D(128, 3, padding='same', kernel_initializer='he_uniform')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = GlobalAveragePooling1D()(y)
    x = concatenate([x, y])
    out = Dense(NB_CLASS, activation='softmax')(x)
    model = Model(ip, out)
    model.summary()
    # add load model code here to fine-tune
    return model

模型如下图所示:
在这里插入图片描述
区别在于这个模型中加入了注意力机制。

def generate_alstmfcn(MAX_SEQUENCE_LENGTH, NB_CLASS, NUM_CELLS=8):
    ip = Input(shape=(1, MAX_SEQUENCE_LENGTH))
    x = AttentionLSTM(NUM_CELLS)(ip)#注意力机制LSTM
    x = Dropout(0.8)(x)
    y = Permute((2, 1))(ip)
    y = Conv1D(128, 8, padding='same', kernel_initializer='he_uniform')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv1D(256, 5, padding='same', kernel_initializer='he_uniform')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = Conv1D(128, 3, padding='same', kernel_initializer='he_uniform')(y)
    y = BatchNormalization()(y)
    y = Activation('relu')(y)
    y = GlobalAveragePooling1D()(y)
    x = concatenate([x, y])
    out = Dense(NB_CLASS, activation='softmax')(x)
    model = Model(ip, out)
    model.summary()
    # add load model code here to fine-tune
    return model

关于模型的具体构建都在layer-utils.py中,这里就不再赘述。
这里指定我们的数据集名称集合,在项目中为方便运行,它使用了循环来执行127个数据集,但我目前还没有成功,希望我能够后期完成吧,该文件主要用于加载数据集中使用,名称顺序与constants.py相同,内分训练集与测试集,数据读取时使用constants.py的数据集目录。

dataset_map = [

                   ('ChlorineConcentration', 2),
                   ('InsectWingbeatSound', 3),
                   ('Lighting7', 4),
                   ('Wine', 5),
                   ('WordsSynonyms', 6),
                   ('50words', 7),
                   ('Beef', 8),
                   ('DistalPhalanxOutlineAgeGroup', 9),
                   ('DistalPhalanxOutlineCorrect', 10),
                   ('DistalPhalanxTW', 11),
                   ('ECG200', 12),
                   ('ECGFiveDays', 13),
                   ('BeetleFly', 14),
                   ('BirdChicken', 15),
                   ('ItalyPowerDemand', 16),
                   ('SonyAIBORobotSurface', 17),
                   ('SonyAIBORobotSurfaceII', 18),
                   ('MiddlePhalanxOutlineAgeGroup', 19),
                   ('MiddlePhalanxOutlineCorrect', 20),
                   ('MiddlePhalanxTW', 21),
                   ('ProximalPhalanxOutlineAgeGroup', 22),
                   ('ProximalPhalanxOutlineCorrect', 23),
                   ('ProximalPhalanxTW', 24),
                   ('MoteStrain', 25),
                   ('MedicalImages', 26),
                   ('Strawberry', 27),
                   ('ToeSegmentation1', 28),
                   ('Coffee', 29),
                   ('Cricket_X', 30),
                   ('Cricket_Y', 31),
                   ('Cricket_Z', 32),
                   ('uWaveGestureLibrary_X', 33),
                   ('uWaveGestureLibrary_Y', 34),
                   ('uWaveGestureLibrary_Z', 35),
                   ('ToeSegmentation2', 36),
                   ('DiatomSizeReduction', 37),
                   ('car', 38),
                   ('CBF', 39),
                   ('CinC_ECG_torso', 40),
                   ('Computers', 41),
                   ('Earthquakes', 42),
                   ('ECG5000', 43),
                   ('ElectricDevices', 44),
                   ('FaceAll', 45),
                   ('FaceFour', 46),
                   ('FacesUCR', 47),
                   ('Fish', 48),
                   ('FordA', 49),
                   ('FordB', 50),
                   ('Gun_Point', 51),
                   ('Ham', 52),
                   ('HandOutlines', 53),
                   ('Haptics', 54),
                   ('Herring', 55),
                   ('InlineSkate', 56),
                   ('LargeKitchenAppliances', 57),
                   ('Lighting2', 58),
                   ('MALLAT', 59),
                   ('Meat', 60),
                   ('NonInvasiveFatalECG_Thorax1', 61),
                   ('NonInvasiveFatalECG_Thorax2', 62),
                   ('OliveOil', 63),
                   ('OSULeaf', 64),
                   ('PhalangesOutlinesCorrect', 65),
                   ('Phoneme', 66),
                   ('plane', 67),
                   ('RefrigerationDevices', 68),
                   ('ScreenType', 69),
                   ('ShapeletSim', 70),
                   ('ShapesAll', 71),
                   ('SmallKitchenAppliances', 72),
                   ('StarlightCurves', 73),
                   ('SwedishLeaf', 74),
                   ('Symbols', 75),
                   ('synthetic_control', 76),
                   ('Trace', 77),
                   ('Patterns', 78),
                   ('TwoLeadECG', 79),
                   ('UWaveGestureLibraryAll', 80),
                   ('wafer', 81),
                   ('Worms', 82),
                   ('WormsTwoClass', 83),
                   ('yoga', 84),
                   ('ACSF1', 85),
                   ('AllGestureWiimoteX', 86),
                   ('AllGestureWiimoteY', 87),
                   ('AllGestureWiimoteZ', 88),
                   ('BME', 89),
                   ('Chinatown', 90),
                   ('Crop', 91),
                   ('DodgerLoopDay', 92),
                   ('DodgerLoopGame', 93),
                   ('DodgerLoopWeekend', 94),
                   ('EOGHorizontalSignal', 95),
                   ('EOGVerticalSignal', 96),
                   ('EthanolLevel', 97),
                   ('FreezerRegularTrain', 98),
                   ('FreezerSmallTrain', 99),
                   ('Fungi', 100),
                   ('GestureMidAirD1', 101),
                   ('GestureMidAirD2', 102),
                   ('GestureMidAirD3', 103),
                   ('GesturePebbleZ1', 104),
                   ('GesturePebbleZ2', 105),
                   ('GunPointAgeSpan', 106),
                   ('GunPointMaleVersusFemale', 107),
                   ('GunPointOldVersusYoung', 108),
                   ('HouseTwenty', 109),
                   ('InsectEPGRegularTrain', 110),
                   ('InsectEPGSmallTrain', 111),
                   ('MelbournePedestrian', 112),
                   ('MixedShapesRegularTrain', 113),
                   ('MixedShapesSmallTrain', 114),
                   ('PickupGestureWiimoteZ', 115),
                   ('PigAirwayPressure', 116),
                   ('PigArtPressure', 117),
                   ('PigCVP', 118),
                   ('PLAID', 119),
                   ('PowerCons', 120),
                   ('Rock', 121),
                   ('SemgHandGenderCh2', 122),
                   ('SemgHandMovementCh2', 123),
                   ('SemgHandSubjectCh2', 124),
                   ('ShakeGestureWiimoteZ', 125),
                   ('SmoothSubspace', 126),
                   ('UMD', 127)
                   ]

下面是具体实现的伪代码:

MODELS = [
        ('lstmfcn', generate_lstmfcn),
        ('alstmfcn', generate_alstmfcn),
    ]#指定两个模型
for model_id, (MODEL_NAME, model_fn) in enumerate(MODELS):#两个模型循环调用
	if not os.path.exists()#判断记录文件是否存在并打开准备写入;
	for dname, did in dataset_map:#循环数据集目录开始读取数据集
		load_data()#加载数据并完成预处理
		train_model()#训练模型
		evaluate_model()#评估模型
		写入实验结果;
关闭文件
画图	

实验结果

dataset_id,dataset_name,dataset_name_,test_accuracy
0,Adiac,lstmfcn_8_cells_weights/Adiac,0.849105
1,ArrowHead,lstmfcn_8_cells_weights/ArrowHead,0.822857
2,data/ChlorineConcentration,lstmfcn_8_cells_weights/ChlorineConcentration,0.821354

存在问题

  • 循环使用数据集仍未解决
  • InsectWingbeatSound数据集存在过拟合问题,目前在考虑是模型问题还是数据集划分问题,正在调参中。过拟合如图所示:
    出现过拟合问题,训练集准确性达到100%,但测试集最高到达46%后又从0开始上升到10%,经典的过拟合问题,开始啥都不会,后来随着学习逐渐
    掌握并朝着好的方向发展,但却开始一些细枝末节的学习,导致准确性下降

在这里插入图片描述

  • 3
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彭祥.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值