tensorflow-keras使用tips(自用)
0x01 模型的保存和导入
0 . 定义模型(以随便一个模型为例)(然而实际上是在被迫读CW攻击时 从论文作者那里白嫖来的NN结构代码)
为啥会想要记录,是因为这个比较奇葩,自定义了loss函数名字叫fn,然后load的时候找不到fn这个损失函数。困扰很久(整整一个下午!),最后晚课下课终于找到了原因(tmd!),连夜在教室吹着台风把它记下来,顺便记一下load和save的相关代码。
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D
from tensorflow.keras.optimizers import SGD
import tensorflow as tf
tf = tf.compat.v1
def get_model(data, file_name, params, num_epochs=50, batch_size=128, train_temp=1, init=None):
"""
1. init是load的文件路径名,是文件夹的名字
2. file_name是save的文件夹名字
"""
model = Sequential()
print(data.train_data.shape)
model.add(Conv2D(params[0], (3, 3),
input_shape=data.train_data.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(params[1], (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(params[2], (3, 3)))
model.add(Activation('relu'))
model.add(Conv2D(params[3], (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(params[4]))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(params[5]))
model.add(Activation('relu'))
model.add(Dense(10))
def fn(correct, predicted):
return tf.nn.softmax_cross_entropy_with_logits(labels=correct,
logits=predicted / train_temp)
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss=fn,
optimizer=sgd,
metrics=['accuracy'])
model.fit(data.train_data, data.train_labels,
batch_size=batch_size,
validation_data=(data.validation_data, data.validation_labels),
epochs=num_epochs,
# nb_epoch=num_epochs,
shuffle=True)
return model
2 . save和load:导入load需要的包
from tensorflow.keras.models
import load_model
3 . save
# 在model.fit(...)的后边
if file_name != None: # 如果传了保存的文件夹的名字就会保存
model.save(file_name)
# 这样会保存为一个文件夹,下面有asset, blahblah的文件
4 . load_model
if init != None: # 如果传了load文件夹的名字 就会load
print("load weights from {}".format(init))
model = load_model(init, custom_objects={'fn': fn})
# 重点在 custom_objects={'fn': fn}!
# 只有加了这么个字典,才不会报Unknow blahblah fn 的错!tmd!
print("loaded!")