最近在尝试复现 SimMIM,但一直没有成功。论文中展示的重建和预测图像,看起来效果都还不错。
论文中给出的重建和预测结果图如下:
我使用的是 GitHub 上提供的预训练权重文件 simmim_pretrain__swin_large__img192_window12__800ep.pth。
我复现后得到的结果如下:
# 代码如下(模型部分没有做任何改动):
def run_one_image(img, mask, model):
    """Run SimMIM on a batch and plot the first sample's reconstruction.

    Four panels are drawn for sample 0 of the batch:
      1. the original image,
      2. the image with the masked pixels zeroed out,
      3. the raw model reconstruction,
      4. the reconstruction pasted into the masked region of the original.

    Args:
        img: Input batch in NCHW layout (it is permuted with
            ``'nchw->nhwc'`` below). NOTE(review): presumably normalized the
            same way as during pre-training (ImageNet mean/std), with
            ``show_image`` undoing that normalization for display -- confirm;
            plotting normalized values directly makes every panel look wrong.
        mask: Patch-level mask passed unchanged to ``model``. NOTE(review):
            assumes 1 = masked / 0 = visible at the model's patch resolution
            -- confirm against the mask generator used in pre-training.
        model: Callable returning ``(reconstruction, loss, pixel_mask)``, where
            ``reconstruction`` and ``pixel_mask`` are NCHW tensors and
            ``pixel_mask`` broadcasts against the image. NOTE(review): upstream
            SimMIM's ``forward`` returns only the loss, so this presumably
            relies on a locally modified model -- confirm.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Inference only: avoid building an autograd graph for the forward pass.
    # NOTE(review): if ``model`` is still in train mode, drop-path/dropout
    # layers (if any) will perturb the output; call ``model.eval()`` first.
    with torch.no_grad():
        y, _, pixel_mask = model(img, mask)

    # Bring every tensor to CPU in NHWC layout. ``img`` may live on the GPU,
    # and mixing it with the CPU copies of ``y``/``pixel_mask`` below would
    # raise a device-mismatch error.
    y = torch.einsum('nchw->nhwc', y).detach().cpu()
    pixel_mask = torch.einsum('nchw->nhwc', pixel_mask).detach().cpu()
    x = torch.einsum('nchw->nhwc', img).detach().cpu()

    # Masked input: zero out the pixels the model had to predict (mask == 1).
    im_masked = x * (1 - pixel_mask)
    # SimMIM reconstruction pasted into the masked region; visible pixels kept.
    im_paste = im_masked + y * pixel_mask

    panels = [
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]
    # A dedicated figure instead of mutating the global plt.rcParams.
    plt.figure(figsize=(24, 24))
    for position, (image, title) in enumerate(panels, start=1):
        plt.subplot(1, len(panels), position)
        show_image(image, title)
    plt.show()