Saving some code here.

   
    def generate_span_cutoff_embedding(self, embeds, input_lens, masks=None):
        """Span cutoff: zero out one contiguous span of tokens per sequence."""
        seq, bs, _ = embeds.shape
        cutoff_lengths = (input_lens * self.aug_cutoff_ratio).int()
        # Choose a random start so the cut span stays within the valid (non-padding) tokens.
        max_starts = torch.clamp(input_lens - cutoff_lengths, min=1)
        starts = (torch.rand(bs, device=self.device) * max_starts).int()
        random_masks = torch.ones((seq, bs), dtype=torch.float, device=self.device)
        for i in range(bs):
            start = starts[i]
            cutoff_length = cutoff_lengths[i]
            random_masks[start:start + cutoff_length, i] = 0

        random_masks = random_masks.unsqueeze(-1)
        input_embeds = embeds * random_masks
        input_masks = masks * random_masks if masks is not None else None

        return input_embeds, input_masks

    def generate_token_cutoff_embedding(self, embeds, input_lens, masks=None):
        """Token cutoff: zero out a random subset of token positions per sequence."""
        seq_length, bs, _ = embeds.shape
        cutoff_lengths = torch.clamp((input_lens * self.aug_cutoff_ratio).int(), min=0)
        # For each sequence, sample which valid token positions to drop.
        zero_indices = [torch.randperm(int(input_lens[i]))[:cutoff_lengths[i]].to(self.device) for i in range(bs)]
        random_masks = torch.ones((seq_length, bs), dtype=torch.float, device=self.device)
        for i in range(bs):
            random_masks[zero_indices[i], i] = 0
        random_masks = random_masks.unsqueeze(-1)

        input_embeds = embeds * random_masks
        input_masks = masks * random_masks if masks is not None else None
        return input_embeds, input_masks

    def generate_dim_cutoff_embedding(self, embeds, masks=None):
        """Dimension cutoff: zero out a random subset of embedding dimensions, shared across all time steps."""
        _, bs, embedding_dim = embeds.shape
        cutoff_length = int(embedding_dim * self.aug_cutoff_ratio)
        # For each example, sample which embedding dimensions to drop.
        zero_indices = [torch.randperm(embedding_dim)[:cutoff_length].to(self.device) for _ in range(bs)]
        random_masks = torch.ones((1, bs, embedding_dim), dtype=torch.float, device=self.device)
        for i in range(bs):
            random_masks[:, i, zero_indices[i]] = 0
        input_embeds = embeds * random_masks
        input_masks = masks * random_masks if masks is not None else None
        return input_embeds, input_masks
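
For intuition, here is a minimal standalone sketch of the masking these methods perform. The shapes, the 0.2 ratio, and the example lengths are illustrative assumptions; in the methods above the ratio comes from self.aug_cutoff_ratio and the tensors live on self.device.

    import torch

    seq_len, bs, dim = 8, 2, 4
    aug_cutoff_ratio = 0.2                        # assumed value for illustration
    embeds = torch.randn(seq_len, bs, dim)        # (seq_len, batch, hidden), same layout as above
    input_lens = torch.tensor([8, 6])             # number of valid (non-padding) tokens per sequence

    # Span cutoff: zero out one contiguous span of tokens per sequence.
    cutoff_lengths = (input_lens * aug_cutoff_ratio).int()
    starts = (torch.rand(bs) * torch.clamp(input_lens - cutoff_lengths, min=1)).int()
    random_masks = torch.ones(seq_len, bs)
    for i in range(bs):
        random_masks[starts[i]:starts[i] + cutoff_lengths[i], i] = 0

    augmented = embeds * random_masks.unsqueeze(-1)   # cut positions become all-zero vectors
    print(random_masks.T)                             # one row per sequence; 0s mark the cut span

Note that a masks argument passed to these methods is multiplied by the (seq_len, bs, 1) cutoff mask, so it is presumably stored in a broadcast-compatible layout such as (seq_len, bs, 1).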

    # def contrastive_loss(self, feature1, feature2, temperature=0.07):  
    #     bs = feature1.size(0)  
    #     similarity_matrix = F.cosine_similarity(feature1.unsqueeze(1), feature2.unsqueeze(0), dim=2) / temperature  
    #     if bs == 1:
    #         return -similarity_matrix.mean()
    #     labels = torch.arange(bs, dtype=torch.long).to(self.device)  
    #     positive_mask = torch.eq(labels, labels.view(-1, 1))  
    #     pos = similarity_matrix[positive_mask].view(bs, -1)
    #     negative_mask = ~positive_mask
    #     neg = similarity_matrix[negative_mask].view(bs, -1)
    #     log_neg = torch.logsumexp(neg, dim=1) + 1e-28
    #     loss = -(pos - log_neg).mean()
    #     return loss

    # def contrastive_loss2(self, feature1, feature2, temperature=0.07):
    #     bs = feature1.size(0)
    #     similarity_matrix = F.cosine_similarity(feature1.unsqueeze(1), feature2.unsqueeze(0), dim=2) / temperature
    #     if bs == 1:
    #         return -similarity_matrix.mean()
    #     positive_mask = torch.eye(bs, dtype=torch.bool).to(self.device)
    #     pos = similarity_matrix[positive_mask]
    #     similarity_matrix2 = F.cosine_similarity(feature1.unsqueeze(1), feature1.unsqueeze(0), dim=2) / temperature
    #     neg = torch.cat([similarity_matrix[~positive_mask].view(bs, -1), 
    #                     similarity_matrix2[~positive_mask].view(bs, -1)], dim=1)
    #     log_neg = torch.logsumexp(neg, dim=1) + 1e-28
    #     loss = -(pos - log_neg).mean()
    #     return loss
    
    
    def contrastive_loss(self, z_i, z_j, temperature=0.07):
        """NT-Xent loss over two augmented views of the same batch.

        We do not sample negative examples explicitly. Instead, given a positive pair,
        similar to (Chen et al., 2017), we treat the other 2(N - 1) augmented examples
        within a minibatch as negative examples.
        """
        batch_size = z_i.shape[0]
        N = 2 * batch_size

        # Stack both views and compute the full pairwise similarity matrix.
        z = torch.cat((z_i, z_j), dim=0)
        sim = self.contrastive_similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / temperature

        # Positives sit on the off-diagonals offset by batch_size (view i vs. view j of the same example).
        sim_i_j = torch.diag(sim, batch_size)
        sim_j_i = torch.diag(sim, -batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)

        # Mask out self-similarities and the positive pairs; everything left is a negative.
        mask = torch.ones((N, N), dtype=torch.bool, device=sim.device)
        mask.fill_diagonal_(0)
        for i in range(batch_size):
            mask[i, batch_size + i] = 0
            mask[batch_size + i, i] = 0
        negative_samples = sim[mask].reshape(N, -1)

        # The positive logit is placed in column 0, so the target class is 0 for every anchor.
        labels = torch.zeros(N, dtype=torch.long, device=positive_samples.device)
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.contrastive_criterion(logits, labels)
        loss /= N  # presumably contrastive_criterion sums over anchors, so this yields the mean over the 2N anchors
        return loss
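
The method relies on two attributes that are not defined in this snippet. Below is a minimal sketch of how they might be set up, with a toy call showing the shapes involved. The specific choices of nn.CosineSimilarity(dim=2) and nn.CrossEntropyLoss(reduction="sum") are assumptions (the sum reduction is what makes the final loss /= N a per-anchor mean); the original code does not confirm them.

    import torch
    import torch.nn as nn

    # Presumed definitions for the attributes used above (assumptions):
    contrastive_similarity_f = nn.CosineSimilarity(dim=2)           # pairwise cosine similarity over the feature dim
    contrastive_criterion = nn.CrossEntropyLoss(reduction="sum")    # summed, then divided by N inside the method

    # Toy usage: z_i and z_j would be the pooled representations of two
    # augmented views of the same batch, e.g. two different cutoffs from above.
    batch_size, feat_dim = 4, 16
    z_i = torch.randn(batch_size, feat_dim)
    z_j = torch.randn(batch_size, feat_dim)

    z = torch.cat((z_i, z_j), dim=0)                                 # (2 * batch_size, feat_dim)
    sim = contrastive_similarity_f(z.unsqueeze(1), z.unsqueeze(0))   # (2 * batch_size, 2 * batch_size)
    print(sim.shape)                                                 # torch.Size([8, 8])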


    

    
