# Create the model and build it with a fixed input shape
# (batch of 32, img_size x img_size, 3 channels) before loading weights.
model = create_model(num_classes=num_classes)
model.build((32, img_size, img_size, 3))

# Pre-trained weights converted in advance by the original author.
# Link: https://pan.baidu.com/s/1cHVwia2i3wD7-0Ueh2WmrQ  password: sq8c
pre_weights_path = './swin_small_patch4_window7_224.h5'
# Raise explicitly instead of using `assert`: assertions are removed when
# Python runs with -O, which would let a missing file slip through.
if not os.path.exists(pre_weights_path):
    raise FileNotFoundError("cannot find {}".format(pre_weights_path))
# Load weights into layers matched by name; layers whose weights do not
# fit are skipped rather than raising (presumably the classification head
# when num_classes differs from the checkpoint -- confirm).
model.load_weights(pre_weights_path, by_name=True, skip_mismatch=True)
# Optionally freeze every layer whose name does not contain "head";
# layers that stay trainable are reported on stdout.
if freeze_layers:
    for layer in model.layers:
        if "head" in layer.name:
            print("training {}".format(layer.name))
        else:
            layer.trainable = False

# Print the layer/parameter overview regardless of the freeze setting.
model.summary()
# Smoke test: push one random batch through the model and print the shapes.
# The spatial size must match the shape the model was built with
# (img_size), not a hard-coded 244. The array is cast to float32 so no
# float64 data is handed to the network.
dummy_batch = np.random.rand(32, img_size, img_size, 3).astype(np.float32)
print(dummy_batch.shape)
out = model(dummy_batch, training=False)
print(out.shape)
# Run inference on a single image: resize it to the network input size,
# add a leading batch axis and cast to float32 in one step (the result of
# tf.cast is already a float32 tensor, so no further conversion is needed).
net_size_image = cv2.resize(rgb_image, (img_size, img_size))
tf_net_image = tf.cast(tf.expand_dims(net_size_image, axis=0), dtype=tf.float32)

# Forward pass in inference mode, then flatten the result to a 1-D NumPy
# array for printing.
output = np.array(model(tf_net_image, training=False)).flatten()
print(output)