2021-09-13

tensorflow-keras使用tips(自用)

0x01 模型的保存和导入

0 . 定义模型(以随便一个模型为例)(然而实际上是在被迫读CW攻击时 从论文作者那里白嫖来的NN结构代码)

为啥会想要记录,是因为这个比较奇葩,自定义了loss函数名字叫fn,然后load的时候找不到fn这个损失函数。困扰很久(整整一个下午!),最后晚课下课终于找到了原因(tmd!),连夜在教室吹着台风把它记下来,顺便记一下load和save的相关代码。

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.optimizers import SGD

import tensorflow as tf

tf = tf.compat.v1

def get_model(data, file_name, params, num_epochs=50, batch_size=128, train_temp=1, init=None):
	"""
	1. init是load的文件路径名,是文件夹的名字
	2. file_name是save的文件夹名字
	"""

    model = Sequential()

    print(data.train_data.shape)

    model.add(Conv2D(params[0], (3, 3),
                     input_shape=data.train_data.shape[1:]))
    model.add(Activation('relu'))
    model.add(Conv2D(params[1], (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(params[2], (3, 3)))
    model.add(Activation('relu'))
    model.add(Conv2D(params[3], (3, 3)))
    model.add(Activation('relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Flatten())
    model.add(Dense(params[4]))
    model.add(Activation('relu'))
    model.add(Dropout(0.5))
    model.add(Dense(params[5]))
    model.add(Activation('relu'))
    model.add(Dense(10))

    def fn(correct, predicted):
        return tf.nn.softmax_cross_entropy_with_logits(labels=correct,
                                                       logits=predicted / train_temp)

    sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

    model.compile(loss=fn,
                  optimizer=sgd,
                  metrics=['accuracy'])

    model.fit(data.train_data, data.train_labels,
              batch_size=batch_size,
              validation_data=(data.validation_data, data.validation_labels),
              epochs=num_epochs,
              # nb_epoch=num_epochs,
              shuffle=True)
              
    return model

2 . save和load:导入load需要的包

from tensorflow.keras.models
 import load_model

3 . save

	# 在model.fit(...)的后边
    if file_name != None: # 如果传了保存的文件夹的名字就会保存
        model.save(file_name)
    # 这样会保存为一个文件夹,下面有asset, blahblah的文件

4 . load_model

    if init != None: # 如果传了load文件夹的名字 就会load
       print("load weights from {}".format(init))
       model = load_model(init, custom_objects={'fn': fn})
       # 重点在 custom_objects={'fn': fn}! 
       # 只有加了这么个字典,才不会报Unknow blahblah fn 的错!tmd!
       print("loaded!")
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值