microsoft/simMIM-visualize 視覺化(原創)

microsoft/simMIM原碼:
https://github.com/microsoft/SimMIM/tree/main​​​​​​

microsoft/Swim-Transformer:

https://github.com/microsoft/Swin-Transformer/tree/main

上面兩個repo的模型應該是一樣的,模型使用models/simmim.py,因為都沒有提供視覺化的程式,我使用MAE的代碼改寫,另外Xiang Li等人的UM-MAE也有對應的代碼。

( config請根據模型自己建立一個dict )

( build_simmim(config) 從 models/simmim.py 拿)

simMIM(改寫輸出)

class SimMIM(nn.Module):
    def __init__(self, config, encoder, encoder_stride, in_chans, patch_size):
        super().__init__()
        self.config = config
        self.encoder = encoder
        self.encoder_stride = encoder_stride

        self.decoder = nn.Sequential(
            nn.Conv2d(
                in_channels=self.encoder.num_features,
                out_channels=self.encoder_stride ** 2 * 3, kernel_size=1),
            nn.PixelShuffle(self.encoder_stride),
        )

        self.in_chans = in_chans
        self.patch_size = patch_size

    def forward(self, x, mask):
        z = self.encoder(x, mask)
        x_rec = self.decoder(z)

        mask = mask.repeat_interleave(self.patch_size, 1).repeat_interleave(self.patch_size, 2).unsqueeze(1).contiguous()
        
        # norm target as prompted
        if self.config['MODEL']['norm_target']:
            x = norm_targets(x, self.config['MODEL']['norm_patch_size'])
        
        loss_recon = F.l1_loss(x, x_rec, reduction='none')
        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans
        """
        注意,這裡要改寫成回傳x_rec作為輸出
        """
        return x_rec,loss

Utilities

from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as T

#自行替換成符合你的資料集的mean與std
image_mean = np.array([0.485, 0.456, 0.406])
image_std = np.array([0.229, 0.224, 0.225])

class MyTransform:
    def __init__(self, config, mask_ratio):
        self.transform_img = T.ToTensor()
        model_patch_size=config['MODEL']['patch_size']
        self.mask_generator = MaskGenerator(
            input_size=config['DATA']['input_size'],
            mask_patch_size=config['DATA']['mask_patch_size'],
            model_patch_size=model_patch_size,
            mask_ratio=mask_ratio,
        )
    
    def __call__(self, img):
        img = self.transform_img(img)
        mask = self.transform_img(self.mask_generator())
        
        return img, mask

def show_image(image, title=''):
    # image is [H, W, 3]
    assert image.shape[2] == 3
    plt.imshow(torch.clip((image * image_std + image_mean) * 255, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')
    return

def prepare_model(chkpt_dir):
    # build model
    model = build_simmim(config)
    # load model
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    rpe_mlp_keys = [k for k in checkpoint['model'].keys() if "rpe_mlp" in k]
    for k in rpe_mlp_keys:
        checkpoint['model'][k.replace('rpe_mlp', 'cpb_mlp')] = checkpoint['model'].pop(k)
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    del checkpoint
    model.eval()
    return model

def run_one_image(img, model,mask_ratio=0.65):
    #transform and mask
    transform = MyTransform(config, mask_ratio)
    x,mask = transform(img)

    # run simMIM
    y,_ = model(x.unsqueeze(dim=0).float(), mask)
    y = y.detach().squeeze(0)
    print(y.shape)

    # visualize the mask
    mask = mask.repeat_interleave(model.patch_size, 1).repeat_interleave(model.patch_size, 2).contiguous()
    im_masked = x * (1 - mask)

    # Reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    plt.subplot(1, 4, 1)
    show_image(torch.einsum('chw->hwc', x), "original")

    plt.subplot(1, 4, 2)
    show_image(torch.einsum('chw->hwc', im_masked), "masked")

    plt.subplot(1, 4, 3)
    show_image(torch.einsum('chw->hwc', y), "reconstruction")

    plt.subplot(1, 4, 4)
    show_image(torch.einsum('chw->hwc', im_paste), "reconstruction + visible")

    plt.show()
#如果有checkpoint自行替換路徑
model = prepare_model('../out_dir/pretrain/simMIM_pt_base_192_w6-45.pth')
#Prepare Image
transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            #T.RandomCrop((192,192)),
            T.RandomResizedCrop((192,192)),
        ])

#準備你要測試的圖片
img = Image.open('../Fabrics/Quixel/001/oi2uhyp_2K_Roughness.jpg')
img = transform(img)
img = np.array(img) / 255.

assert img.shape == (192, 192, 3)

# normalize by ImageNet mean and std
img = img - image_mean
img = img / image_std

plt.rcParams['figure.figsize'] = [3,3]
show_image(torch.tensor(img))
#Visualize
torch.manual_seed(123456)
print('simMIM with pixel reconstruction:')
run_one_image(img,model,mask_ratio=0.65)

轉載請標記出處。

另外我有將程式碼改寫成單卡可在jupyter上跑的simMIM,如果有需要可以詢問。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值