ECCV 2022 fine-grained image retrieval: SEMICON code study notes

Code: GitHub - aassxun/SEMICON

Environment setup

# Create & activate a virtual environment
conda create -n semicon python==3.8.5
conda activate semicon

# Install dependencies (this PyTorch build is CPU-only)
conda install pytorch==1.10.0 torchvision==0.11.1 torchaudio==0.10.0 cpuonly -c pytorch
pip install numpy==1.19.2
pip install loguru==0.5.3
pip install tqdm==4.54.1
pip install pandas
pip install scipy

In SEMICON_train.py, SEMICON.py, Hash_mAP.py, baseline_train.py and baseline.py, replace

import models.resnet as resnet
from models.resnet import *

with

import models.resnet_torch as resnet
from models.resnet_torch import *
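If you would rather not edit the five files by hand, the replacement can be scripted. The following is a minimal sketch, and patch_imports is a hypothetical helper that is not part of the repo. It matches the files by name anywhere under repo_root because the text does not say which folder each file lives in, and it rewrites only the two import lines listed above.

```python
from pathlib import Path

# The five files named in the text, matched by file name anywhere under repo_root
TARGET_FILES = {"SEMICON_train.py", "SEMICON.py", "Hash_mAP.py", "baseline_train.py", "baseline.py"}
REPLACEMENTS = [
    ("import models.resnet as resnet", "import models.resnet_torch as resnet"),
    ("from models.resnet import *", "from models.resnet_torch import *"),
]


def patch_imports(repo_root):
    """Rewrite the resnet imports in place and return the paths that changed."""
    changed = []
    for path in sorted(Path(repo_root).rglob("*.py")):
        if path.name not in TARGET_FILES:
            continue
        text = path.read_text(encoding="utf-8")
        new_text = text
        for old, new in REPLACEMENTS:
            new_text = new_text.replace(old, new)
        if new_text != text:  # files that are already patched are left untouched
            path.write_text(new_text, encoding="utf-8")
            changed.append(path)
    return changed
```

Run it once from the repo root, for example with patch_imports("."). A second run changes nothing, because the patched lines no longer match the old patterns.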

Download the CUB_200_2011 dataset

Reference blog: "Downloading the CUB-200-2011 bird dataset and loading it with PyTorch", 景唯acr's blog (CSDN)
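Cub2011.init in data/cub_2011.py reads images.txt, image_class_labels.txt and train_test_split.txt from --root. __getitem__ then opens files under images/. For this reason, --root has to point at the extracted CUB_200_2011 folder itself. The following is a quick pre-flight check, written as a sketch. missing_cub_entries is a hypothetical helper, not part of the repo.

```python
import os

# Entries read by Cub2011.init and Cub2011.__getitem__ in data/cub_2011.py
REQUIRED_ENTRIES = ["images.txt", "image_class_labels.txt", "train_test_split.txt", "images"]


def missing_cub_entries(root):
    """Return the required CUB_200_2011 entries that are missing under root."""
    root = os.path.expanduser(root)
    return [name for name in REQUIRED_ENTRIES if not os.path.exists(os.path.join(root, name))]
```

For example, missing_cub_entries('/dataset/CUB2011/CUB_200_2011') should return an empty list before you launch run.py.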

Running the code

1) Training

python run.py --dataset cub-2011 --root /dataset/CUB2011/CUB_200_2011 --max-epoch 30 --gpu 0 --arch semicon --batch-size 16 --max-iter 40 --code-length 12,24,32,48 --lr 2.5e-4 --wd 1e-4 --optim SGD --lr-step 40 --num-samples 2000 --info 'CUB-SEMICON' --momen=0.91

2) Testing

python run.py --dataset cub-2011 --root /dataset/CUB2011/CUB_200_2011 --gpu 0 --arch test --batch-size 16 --code-length 12,24,32,48 --wd 1e-4 --info 'CUB-SEMICON'

To run without a GPU, just set the --gpu argument to False.

Code walkthrough

1) Fixing the random seed

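As in YOLOX, the random seed is fixed, and later experiments (ablations, etc.) all run under this same seed. This improves reproducibility, though in my view only for that particular seed: with a different seed the results may well differ.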

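torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark: the former makes cuDNN use deterministic convolution algorithms, so the same input produces the same output on every run. The latter makes cuDNN search for the best-suited convolution algorithm for each convolutional layer, which speeds the network up. Benchmark mode pays off when the network structure is fixed and the input shape (batch size, image size, number of channels) does not change, which is the usual case. Conversely, if the convolution settings keep changing, the program keeps re-optimizing and ends up spending more time.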

Reference blogs:

[pytorch] torch.backends.cudnn.deterministic, Xhfei1224's blog (CSDN)

torch.backends.cudnn.benchmark, Wanderer001's blog (CSDN)

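import os
import random

import numpy as np
import torch


def seed_everything(seed):
    random.seed(seed)                         # Python's built-in RNG
    os.environ['PYTHONHASHSEED'] = str(seed)  # only affects Python processes started after this point
    np.random.seed(seed)                      # NumPy's global RNG
    torch.manual_seed(seed)                   # PyTorch RNG
    torch.cuda.manual_seed(seed)              # current GPU (silently ignored without CUDA)
    torch.backends.cudnn.deterministic = True
    # Kept as in the repo. benchmark=True lets cuDNN pick convolution algorithms by timing,
    # and that choice can vary between runs; use False for strict reproducibility.
    torch.backends.cudnn.benchmark = True


seed_everything(68)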
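The following CPU-only check shows what seeding buys you and illustrates the caveat about other seeds. Re-seeding with the same value replays exactly the same random numbers, while a different seed produces a different sequence. The check covers only the RNGs. GPU determinism additionally depends on the cuDNN flags discussed above. The helper draw is made up for this illustration.

```python
import random

import numpy as np
import torch


def draw(seed):
    """Seed Python, NumPy and PyTorch, then draw one value from each."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return random.random(), float(np.random.rand()), torch.rand(3)


first, second, other = draw(68), draw(68), draw(69)
# Same seed -> identical draws; a different seed -> a different sequence
same = first[:2] == second[:2] and torch.equal(first[2], second[2])
different = not torch.equal(first[2], other[2])
print(same, different)  # True True
```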

2) Data loading

Corresponding script: data/cub_2011.py

load_data() returns three dataloaders: query_dataloader, train_dataloader and retrieval_dataloader.

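# Split into training and test sets
Cub2011.init(root)
# Define the query, training and retrieval datasets; the data augmentations are in data/transform.py
query_dataset = Cub2011(root, 'query', query_transform())
train_dataset = Cub2011(root, 'train', train_transform())
retrieval_dataset = Cub2011(root, 'retrieval', query_transform())

import os

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader

# encode_onehot is a one-hot helper defined elsewhere in the repo (not shown here)


class Cub2011(Dataset):

    def __init__(self, root, mode, transform=None, loader=default_loader):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.loader = loader  # honour the loader argument instead of always using default_loader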

        if mode == 'train':
            self.data = Cub2011.TRAIN_DATA
            self.targets = Cub2011.TRAIN_TARGETS
        elif mode == 'query':
            self.data = Cub2011.QUERY_DATA
            self.targets = Cub2011.QUERY_TARGETS
        elif mode == 'retrieval':
            self.data = Cub2011.RETRIEVAL_DATA
            self.targets = Cub2011.RETRIEVAL_TARGETS
        else:
            raise ValueError("Invalid arguments: mode, can't load dataset!")

    @staticmethod
    def init(root):
        images = pd.read_csv(os.path.join(root, 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(root, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(root, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        all_data = data.merge(train_test_split, on='img_id')
        all_data['filepath'] = 'images/' + all_data['filepath']
        train_data = all_data[all_data['is_training_img'] == 1]
        test_data = all_data[all_data['is_training_img'] == 0]

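        # Split dataset: the test split becomes the query set, and the training split
        # is used both for training and as the retrieval database. Labels in the txt
        # files are 1-based, hence target - 1 before one-hot encoding over 200 classes.
        Cub2011.QUERY_DATA = test_data['filepath'].to_numpy()
        Cub2011.QUERY_TARGETS = encode_onehot((test_data['target'] - 1).tolist(), 200)

        Cub2011.TRAIN_DATA = train_data['filepath'].to_numpy()
        Cub2011.TRAIN_TARGETS = encode_onehot((train_data['target'] - 1).tolist(), 200)

        Cub2011.RETRIEVAL_DATA = train_data['filepath'].to_numpy()
        Cub2011.RETRIEVAL_TARGETS = encode_onehot((train_data['target'] - 1).tolist(), 200)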

    def get_onehot_targets(self):
        return torch.from_numpy(self.targets).float()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.root, self.data[idx])).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, self.targets[idx], idx
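To see what Cub2011.init produces, the sketch below runs the same merge-and-split logic on three made-up images. The rows, file names and the two classes are invented for illustration. encode_onehot is a hypothetical stand-in written here, because the repo's helper is not shown above. As in init, the test split becomes the query set, and the training split serves as both the training set and the retrieval database.

```python
import io

import numpy as np
import pandas as pd


def encode_onehot(labels, num_classes):
    """Hypothetical stand-in for the repo's helper: int labels -> (N, num_classes) one-hot array."""
    onehot = np.zeros((len(labels), num_classes), dtype=np.float32)
    onehot[np.arange(len(labels)), labels] = 1
    return onehot


# Made-up contents of the three CUB metadata files (space-separated, 1-based ids and labels)
images = pd.read_csv(io.StringIO("1 001.A/a.jpg\n2 001.A/b.jpg\n3 002.B/c.jpg\n"),
                     sep=' ', names=['img_id', 'filepath'])
labels = pd.read_csv(io.StringIO("1 1\n2 1\n3 2\n"), sep=' ', names=['img_id', 'target'])
split = pd.read_csv(io.StringIO("1 1\n2 0\n3 1\n"), sep=' ', names=['img_id', 'is_training_img'])

# Same steps as Cub2011.init
all_data = images.merge(labels, on='img_id').merge(split, on='img_id')
all_data['filepath'] = 'images/' + all_data['filepath']
train_data = all_data[all_data['is_training_img'] == 1]
test_data = all_data[all_data['is_training_img'] == 0]

query_targets = encode_onehot((test_data['target'] - 1).tolist(), 2)       # query = test split
retrieval_targets = encode_onehot((train_data['target'] - 1).tolist(), 2)  # database = train split
print(test_data['filepath'].tolist())  # ['images/001.A/b.jpg']
print(retrieval_targets)
```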

3)网络训练

Backbone

Only the first three layers of ResNet-50 (layer1 to layer3) are used here; see the ResNet_Backbone class in models/SEMICON.py for details.

model = ResNet_Backbone(Bottleneck, [3, 4, 6], **kwargs)

Global/local transformation network

class ResNet_Refine(nn.Module):

    def __init__(self, block, layer, is_local=True, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNet_Refine, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 1024
        self.dilation = 1

        self.is_local = is_local
        self.groups = groups
        self.base_width = width_per_group
        self.layer4 = self._make_layer(block, 512, layer, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
            layers.append(ChannelTransformer(planes * block.expansion, max(planes * block.expansion // 64, 16)))
        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x = self.layer4(x)

        pool_x = self.avgpool(x)
        pool_x = torch.flatten(pool_x, 1)
        if self.is_local:
            return x, pool_x
        else:
            return pool_x

    def forward(self, x):
        return self._forward_impl(x)
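The loop in _make_layer is easy to misread. The first Bottleneck (stride 2, with a downsample branch) gets no ChannelTransformer, and every later Bottleneck is followed by one. The pure-Python sketch below simply replays that loop as bookkeeping. Two values are assumptions: blocks=3, since the value passed as layer is not shown here, and Bottleneck.expansion = 4, which is torchvision's value. refine_layer_plan is a name made up for this illustration.

```python
def refine_layer_plan(blocks, planes=512, expansion=4, inplanes=1024):
    """Replay ResNet_Refine._make_layer as a list of (module, in_channels, out_channels)."""
    out_ch = planes * expansion               # 2048 for Bottleneck
    num_heads = max(out_ch // 64, 16)         # same rule as in _make_layer
    plan = [("Bottleneck(stride=2, downsample)", inplanes, out_ch)]
    for _ in range(1, blocks):
        plan.append(("Bottleneck", out_ch, out_ch))
        plan.append((f"ChannelTransformer(num_heads={num_heads})", out_ch, out_ch))
    return plan


for step in refine_layer_plan(3):
    print(step)
```

With 2048 channels and 32 heads, each head covers 64 channels. Under the 224×224 input assumption, the stride-2 stage turns the 14×14 map into 7×7. avgpool then yields the 2048-d global vector, and with is_local=True the 7×7 map is returned alongside it.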

SEM

class SEM(nn.Module):

    def __init__(self, block, layer, att_size=4, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(SEM, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 1024
        self.dilation = 1
        self.att_size = att_size
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group

        self.layer4 = self._make_layer(block, 512, layer, stride=1)

        self.feature1 = nn.Sequential(
            conv1x1(self.inplanes, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )
        self.feature2 = nn.Sequential(
            conv1x1(self.inplanes, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )
        self.feature3 = nn.Sequential(
            conv1x1(self.inplanes, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        att_expansion = 0.25
        layers = []
        layers.append(block(self.inplanes, int(self.inplanes * att_expansion), stride,
                            downsample, self.groups, self.base_width, previous_dilation, norm_layer))
        for _ in range(1, blocks):
            layers.append(nn.Sequential(
                conv1x1(self.inplanes, int(self.inplanes * att_expansion)),
                nn.BatchNorm2d(int(self.inplanes * att_expansion))
            ))
            self.inplanes = int(self.inplanes * att_expansion)
            layers.append(block(self.inplanes, int(self.inplanes * att_expansion), groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer))
        return nn.Sequential(*layers)

    def _mask(self, feature, x):
        with torch.no_grad():
            cam1 = feature.mean(1)  # (B, H, W)
            attn = torch.softmax(cam1.view(x.shape[0], x.shape[2] * x.shape[3]), dim=1)  # (B, H*W): softmax over spatial positions
            std, mean = torch.std_mean(attn)
            attn = (attn - mean) / (std ** 0.3) + 1 #0.15
            attn = (attn.view((x.shape[0], 1, x.shape[2], x.shape[3]))).clamp(0, 2)
        return attn

    def _forward_impl(self, x):
        x = self.layer4(x)#bs*64*14*14
        fea1 = self.feature1(x) #bs*1*14*14
        attn = 2-self._mask(fea1, x)

        x = x.mul(attn.repeat(1, self.inplanes, 1, 1))
        fea2 = self.feature2(x)
        attn = 2-self._mask(fea2, x)

        x = x.mul(attn.repeat(1, self.inplanes, 1, 1))
        fea3 = self.feature3(x)

        x = torch.cat([fea1,fea2,fea3], dim=1)
        return x

    def forward(self, x):
        return self._forward_impl(x)
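_mask is the core of SEM, so it is worth running on its own. The sketch below copies its arithmetic into a standalone function (sem_mask is a name chosen here) and feeds it a made-up 4×4 map with one hot spot. The spatial softmax, followed by (attn - mean) / std**0.3 + 1 and clamp(0, 2), pushes the hot spot up to 2. As a result, the weight 2 - mask used in _forward_impl drops to 0 there. The most salient position is therefore suppressed before the next branch, while the remaining positions are slightly amplified.

```python
import torch


def sem_mask(feature):
    """Standalone copy of SEM._mask for a (B, 1, H, W) map."""
    b, _, h, w = feature.shape
    with torch.no_grad():
        cam = feature.mean(1)                              # (B, H, W)
        attn = torch.softmax(cam.view(b, h * w), dim=1)    # softmax over spatial positions
        std, mean = torch.std_mean(attn)                   # statistics over the whole batch
        attn = (attn - mean) / (std ** 0.3) + 1
        attn = attn.view(b, 1, h, w).clamp(0, 2)
    return attn


feature = torch.zeros(1, 1, 4, 4)
feature[0, 0, 1, 2] = 5.0           # one salient position
weight = 2 - sem_mask(feature)      # as in SEM._forward_impl
print(weight[0, 0])                 # 0 at the hot spot, about 1.09 elsewhere
```

As for the channel counts, if layer = 3, SEM._make_layer shrinks the channels from 1024 to 256 to 64. This agrees with the #bs*64*14*14 comment and explains why feature1 to feature3 are 1×1 convolutions from 64 channels to 1.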

ICON

class ChannelTransformer(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.head_dim = head_dim
        self.norm = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU(inplace=True)

        self.qkv = nn.Conv2d(dim, dim * 3, 1, groups=num_heads)
        self.qkv2 = nn.Conv2d(dim, dim * 3, 1, groups=head_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, H * W).transpose(0, 1)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = torch.sign(attn) * torch.sqrt(torch.abs(attn) + 1e-5)
        attn = attn.softmax(dim=-1)

        x = ((attn @ v).reshape(B, C, H, W) + x).reshape(B, self.num_heads, self.head_dim, H, W).transpose(1, 2).reshape(B, C, H, W)
        y = self.norm(x)
        x = self.relu(y)
        qkv2 = self.qkv2(x).reshape(B, 3, self.head_dim, self.num_heads, H * W).transpose(0, 1)
        q, k, v = qkv2[0], qkv2[1], qkv2[2]

        attn = (q @ k.transpose(-2, -1)) * (self.num_heads ** -0.5)
        attn = torch.sign(attn) * torch.sqrt(torch.abs(attn) + 1e-5)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).reshape(B, self.head_dim, self.num_heads, H, W).transpose(1, 2).reshape(B, C, H, W) + y
        return x
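In the first stage of ChannelTransformer, attention is computed between channels rather than between spatial positions. The product q @ k.transpose(-2, -1) multiplies a (head_dim × H·W) matrix by its (H·W × head_dim) counterpart, so each head gets a head_dim × head_dim channel-affinity map. A signed square root then compresses that map before the softmax. The following standalone sketch reproduces this stage, up to the residual addition and before the channel shuffle, with made-up sizes (C = 64, 16 heads). The second stage does the same with the roles of num_heads and head_dim swapped.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W = 2, 64, 7, 7
num_heads = 16
head_dim = C // num_heads                         # 4 channels per head

qkv = nn.Conv2d(C, C * 3, 1, groups=num_heads)    # grouped 1x1 conv, as in the module
x = torch.randn(B, C, H, W)

qkv_out = qkv(x).reshape(B, 3, num_heads, head_dim, H * W).transpose(0, 1)
q, k, v = qkv_out[0], qkv_out[1], qkv_out[2]      # each (B, heads, head_dim, H*W)

attn = (q @ k.transpose(-2, -1)) * head_dim ** -0.5             # (B, heads, head_dim, head_dim)
attn = torch.sign(attn) * torch.sqrt(torch.abs(attn) + 1e-5)    # signed square root
attn = attn.softmax(dim=-1)

out = (attn @ v).reshape(B, C, H, W) + x          # residual connection, shape preserved
print(attn.shape, out.shape)  # torch.Size([2, 16, 4, 4]) torch.Size([2, 64, 7, 7])
```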

Loss function

class ADSH_Loss(nn.Module):
    def __init__(self, code_length, gamma):
        super(ADSH_Loss, self).__init__()
        self.code_length = code_length
        self.gamma = gamma

    def forward(self, F, B, S, omega):
        hash_loss = ((self.code_length * S - F @ B.t()) ** 2).sum() / (F.shape[0] * B.shape[0]) / self.code_length * 12
        quantization_loss = ((F - B[omega, :]) ** 2).sum() / (F.shape[0] * B.shape[0]) * self.gamma / self.code_length * 12

        loss = hash_loss + quantization_loss
        return loss, hash_loss, quantization_loss
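Written out, forward computes the following, reading the code as-is. Here k is code_length, F is the n×k network output for the batch, B is the m×k database code matrix, S is the n×m batch-to-database similarity matrix, and ω holds the database indices of the n batch samples.

```latex
\mathcal{L} = \mathcal{L}_{\text{hash}} + \mathcal{L}_{\text{quant}}, \qquad
\mathcal{L}_{\text{hash}} = \frac{12}{k\,n\,m} \sum_{i=1}^{n} \sum_{j=1}^{m} \left( k\,S_{ij} - F_i^{\top} B_j \right)^2, \qquad
\mathcal{L}_{\text{quant}} = \frac{12\,\gamma}{k\,n\,m} \sum_{i=1}^{n} \left\lVert F_i - B_{\omega_i} \right\rVert_2^2
```

The quantization term is also normalized by n·m, although it sums only n terms. The constant factor 12 appears as-is in the code.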

4) Network testing

import os

import torch

# SEMICON (models/SEMICON.py) and evaluate are modules from the repo


def valid(query_dataloader, train_dataloader, retrieval_dataloader, code_length, args):
    num_classes, att_size, feat_size = args.num_classes, 1, 2048
    model = SEMICON.semicon(code_length=code_length, num_classes=num_classes, att_size=att_size, feat_size=feat_size,
                            device=args.device, pretrained=True)
    model.to(args.device)

    ckpt_dir = os.path.join('./checkpoints', args.info, str(code_length))
    # map_location lets checkpoints saved on a GPU load on a CPU-only install
    model.load_state_dict(torch.load(os.path.join(ckpt_dir, 'model.pkl'), map_location=args.device), strict=False)
    model.eval()

    # Hash codes and labels saved during training
    query_code = torch.load(os.path.join(ckpt_dir, 'query_code.t'), map_location=args.device)
    # Keep the loaded targets in a variable: assigning them to dataset.get_onehot_targets
    # would replace the method with a tensor, and calling it later raises a TypeError
    query_targets = torch.load(os.path.join(ckpt_dir, 'query_targets.t'), map_location=args.device)
    B = torch.load(os.path.join(ckpt_dir, 'database_code.t'), map_location=args.device)
    retrieval_targets = torch.load(os.path.join(ckpt_dir, 'database_targets.t'), map_location=args.device)

    mAP = evaluate.mean_average_precision(
        query_code,
        B,
        query_targets,
        retrieval_targets,
        args.device,
        args.topk,
    )
    print("Code_Length: " + str(code_length), end="; ")
    print('[mAP:{:.5f}]'.format(mAP))
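evaluate.mean_average_precision itself is not shown here. For reference, a common formulation of mAP@topk for hash retrieval is sketched below. It is only an assumption that the repo's version matches it in details such as tie handling and how queries without any relevant item are averaged. The sketch also assumes codes in {-1, +1}, so the Hamming distance is (k - ⟨q, b⟩) / 2. A database item counts as relevant when it shares at least one label with the query. The function name is hypothetical.

```python
import torch


def mean_average_precision_sketch(query_code, database_code, query_targets, database_targets, topk=None):
    """mAP@topk for {-1, +1} codes; targets are one-hot (or multi-hot) float tensors."""
    num_query, code_length = query_code.shape
    if topk is None:
        topk = database_code.shape[0]
    total_ap = 0.0
    for i in range(num_query):
        # Relevant = shares at least one label with the query
        relevant = (query_targets[i] @ database_targets.t() > 0).float()
        # Hamming distance between {-1, +1} codes
        hamming = 0.5 * (code_length - database_code @ query_code[i])
        relevant = relevant[torch.argsort(hamming)][:topk]
        num_hits = int(relevant.sum().item())
        if num_hits == 0:
            continue  # this query contributes AP = 0
        hit_ranks = torch.nonzero(relevant).squeeze(1).float() + 1        # 1-based ranks of hits
        hit_counts = torch.arange(1, num_hits + 1, dtype=torch.float32)  # hits seen at each of those ranks
        total_ap += (hit_counts / hit_ranks).mean().item()
    return total_ap / num_query
```

For instance, with one query whose relevant items land at ranks 1 and 3, AP is (1/1 + 2/3) / 2 ≈ 0.833.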

5) Network structure (ONNX model)
