tensorflow自定义训练常用方法参数记录

最近需要复现某篇文章的神经网络结构,发现论文中没有给出代码或者是仅给出少量关键的代码,该如何下手准备复现呢
1.准备数据集+提取特征
数据集一般是论文提供或者自己搜集,提取特征的方法一般论文会给出,按给出的方法从数据集中提取即可。需要注意的是提取的特征最好是以numpy的形式存储,numpy的各种函数能比较方便的处理数据。此外最好在输入进神经网络前进行手工打乱,虽然model.fit()方法里也有shuffle选项,但该选项在validation_split参数生效之后,因此如果你的数据本身是有序的,需要先手工打乱再指定validation_split。

def shuffle_data(data_x, data_y, seed=207):
    # 对输入的numpy数组进行手动打乱
    np.random.seed(seed)
    np.random.shuffle(data_x)
    np.random.seed(seed)
    np.random.shuffle(data_y)
    return data_x, data_y

2.准备神经网络
TensorFlow搭建神经网络还是比较方便的,很多常用的结构tensorflow.keras.layers里都有集成的方法,需要对应的网络结构时直接调用就行。注意某些老版本的代码会直接调用keras包,程序会报错no moudle named keras,这时候把keras换成tensorflow.keras或者tensorflow.python.keras即可解决问题。
如我需要搭建如下图的神经网络,其中有些层的细节比如激活函数等会在论文里提到,最终生成的神经网络结果是这个样子
在这里插入图片描述

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Embedding, Conv1D, multiply, GlobalMaxPool1D, Input, Activation


def Malconv(max_len=200000, win_size=500, vocab_size=256):
    inp = Input((max_len,))
    emb = Embedding(vocab_size, 8)(inp)

    conv1 = Conv1D(kernel_size=(win_size), filters=128, strides=(win_size), padding='same')(emb)
    conv2 = Conv1D(kernel_size=(win_size), filters=128, strides=(win_size), padding='same')(emb)
    a = Activation('sigmoid', name='sigmoid')(conv2)

    mul = multiply([conv1, a])
    a = Activation('relu', name='relu')(mul)
    p = GlobalMaxPool1D()(a)
    d = Dense(64)(p)
    out = Dense(1, activation='sigmoid')(d)

    return Model(inp, out)

最近在使用的时候感觉TensorFlow在某些地方做了兼容,比如输入维度和实际设定的维度不一时也能运行,不过TensorFlow会发出报警

3.训练数据

从刚刚设定好的model中导入神经网络,设置loss,optimizer和metrics,一般metrics常用的就是acc(准确率)和ce(交叉熵),注意这两个可以放在一起使用
loss,optimizer,metrics介绍:
http://www.kaotop.com/it/18701.html

from Malconv import Malconv
model = Malconv()
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['acc','ce'])

然后就是把数据x和标签y输入到模型中,通过fit方法开始训练。这里介绍一些常用的参数
batch_size:控制一次输入进模型里的数据大小,默认为32,如果数据量过大的话可以把这个值调小
epochs:控制训练次数
shuffle:是否打乱,一般都是True
callbacks:设置回调函数,用于监控等目的
validation_split:自动分割输入的x,y作为验证集
workers:并行数量,一般设置为cpu总核数或者总核数的一半,和batch_size一样,设置小了训练慢,设置大了容易炸

  history = model.fit(x, y,batch_size=32,epochs=10000,shuffle=True,callbacks=callbacks,validation_split=0.2,workers=4)

如果不使用validation_split,那么就需要使用validation_data来指定测试数据

  history = model.fit(x, y,batch_size=32,epochs=10000,shuffle=True,callbacks=callbacks,validation_data=(x_val,y_val),workers=4)

其他的一些参数介绍:
https://blog.csdn.net/Aibiabcheng/article/details/117453782

然后是callback函数,一般常用的就是监控loss有没有持续下降,需不需要提前终止训练,比如这段代码就是监控训练数据上如果loss在early_stopping_rounds里没有持续下降那么就结束训练

callbacks = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    min_delta=1e-3,
    patience=early_stopping_rounds,
    verbose=0,
    mode='min',
    baseline=None,
    restore_best_weights=False
)

其他一些常用callback函数介绍:
https://blog.csdn.net/m0_47256162/article/details/122287628

最后是当训练完成后,需要预测数据和查看训练历史结果
预测数据:使用predict方法

result=model.predict(test_x)

训练的历史结果作为返回值保存在之前的history中,通过history.history可以获得一个字典,里面存储着loss等指标的变化

print(history.history)
  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
TensorFlow中,可以通过自定义损失函数来训练模型。自定义损失函数可以根据具体的问题和需求来设计,以更好地适应模型的训练目标。 下面是一个使用自定义损失函数训练模型的示例代码: ```python import tensorflow as tf def custom_loss(y_true, y_pred): # 自定义损失函数的计算逻辑 loss = tf.square(y_true - y_pred) # 这里以平方差作为损失函数 return loss if __name__ == "__main__": # 定义输入和输出张量 x = tf.constant(\[1., 2., 3.\]) y_true = tf.constant(\[4., 5., 6.\]) # 定义模型 y_pred = tf.Variable(\[0., 0., 0.\]) # 定义损失函数 loss = custom_loss(y_true, y_pred) # 创建一个优化器 optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) # 定义训练操作 train_op = optimizer.minimize(loss) # 创建一个会话并运行训练操作 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) for i in range(100): sess.run(train_op) # 打印训练结果 print("Final prediction:", y_pred.eval()) ``` 在上述代码中,我们定义了一个自定义损失函数`custom_loss`,并使用该损失函数来计算模型的损失。然后,我们使用梯度下降优化器来最小化损失,并进行模型的训练。最后,我们打印出训练结果。 请注意,这只是一个简单的示例,实际中的自定义损失函数可能会更加复杂,根据具体的问题和需求进行设计。 #### 引用[.reference_title] - *1* *2* *3* [TensorFlow自定义损失函数](https://blog.csdn.net/sinat_29957455/article/details/78369763)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值