def generate_span_cutoff_embedding(self, embeds, input_lens, masks=None):
    """Apply "span cutoff" augmentation: zero one random contiguous span of
    timesteps in each example's embedding sequence.

    Args:
        embeds: embedding tensor of shape (seq, bs, dim).
        input_lens: per-example valid lengths, shape (bs,); the span is
            sampled inside [0, input_len) for each example.
        masks: optional mask tensor multiplied by the same cutoff mask.
            NOTE(review): assumed to broadcast against (seq, bs, 1) —
            confirm against callers.

    Returns:
        Tuple (input_embeds, input_masks); input_masks is None when masks
        is None.
    """
    seq_len, bs, _ = embeds.shape
    # Span length is a fixed fraction of each example's real length.
    cutoff_lengths = (input_lens * self.aug_cutoff_ratio).int()
    # clamp(min=1) guarantees torch.rand() has a positive range to scale
    # into, so `starts` is always a valid index.
    max_starts = torch.clamp(input_lens - cutoff_lengths, min=1)
    starts = (torch.rand(bs).to(self.device) * max_starts).int()
    # Build the keep-mask directly on the target device (consistent with
    # the sibling cutoff methods) instead of creating it on CPU and moving
    # the whole tensor afterwards.
    random_masks = torch.ones((seq_len, bs), dtype=torch.float, device=self.device)
    for i in range(bs):
        random_masks[starts[i]:starts[i] + cutoff_lengths[i], i] = 0
    random_masks = random_masks.unsqueeze(-1)
    input_embeds = embeds * random_masks
    input_masks = masks * random_masks if masks is not None else None
    return input_embeds, input_masks
def generate_token_cutoff_embedding(self, embeds, input_lens, masks=None):
    """Apply "token cutoff" augmentation: zero a random subset of individual
    timesteps (tokens) in each example's embedding sequence.

    Args:
        embeds: embedding tensor of shape (seq, bs, dim).
        input_lens: per-example valid lengths, shape (bs,); dropped tokens
            are sampled only from the valid positions.
        masks: optional mask tensor multiplied by the same cutoff mask.
            NOTE(review): assumed to broadcast against (seq, bs, 1) —
            confirm against callers.

    Returns:
        Tuple (input_embeds, input_masks); input_masks is None when masks
        is None.
    """
    seq_length, bs, _ = embeds.shape
    cutoff_lengths = torch.clamp((input_lens * self.aug_cutoff_ratio).int(), min=0)
    # torch.randperm requires a Python int; input_lens[i] is a 0-dim tensor,
    # so cast explicitly (passing the tensor raises a TypeError).
    zero_indices = [
        torch.randperm(int(input_lens[i]))[: int(cutoff_lengths[i])].to(self.device)
        for i in range(bs)
    ]
    random_masks = torch.ones((seq_length, bs), dtype=torch.float, device=self.device)
    for i in range(bs):
        random_masks[zero_indices[i], i] = 0
    random_masks = random_masks.unsqueeze(-1)
    input_embeds = embeds * random_masks
    input_masks = masks * random_masks if masks is not None else None
    return input_embeds, input_masks
def generate_dim_cutoff_embedding(self, embeds, masks=None):
    """Apply "dim cutoff" augmentation: for every example, zero out a random
    subset of embedding dimensions across all timesteps.

    Args:
        embeds: embedding tensor of shape (seq, bs, dim).
        masks: optional mask tensor multiplied by the same cutoff mask.
            NOTE(review): assumed to broadcast against (1, bs, dim) —
            confirm against callers.

    Returns:
        Tuple (input_embeds, input_masks); input_masks is None when masks
        is None.
    """
    _, batch_size, emb_dim = embeds.shape
    n_cut = int(emb_dim * self.aug_cutoff_ratio)
    # One independent random choice of dropped dimensions per example.
    dropped_dims = [
        torch.randperm(emb_dim)[:n_cut].to(self.device) for _ in range(batch_size)
    ]
    keep = torch.ones((1, batch_size, emb_dim), dtype=torch.float, device=self.device)
    for sample_idx, dims in enumerate(dropped_dims):
        keep[:, sample_idx, dims] = 0
    out_embeds = embeds * keep
    out_masks = masks * keep if masks is not None else None
    return out_embeds, out_masks
# def contrastive_loss(self, feature1, feature2, temperature=0.07):
# bs = feature1.size(0)
# similarity_matrix = F.cosine_similarity(feature1.unsqueeze(1), feature2.unsqueeze(0), dim=2) / temperature
# if bs == 1:
# return -similarity_matrix.mean()
# labels = torch.arange(bs, dtype=torch.long).to(self.device)
# positive_mask = torch.eq(labels, labels.view(-1, 1))
# pos = similarity_matrix[positive_mask].view(bs, -1)
# negative_mask = ~positive_mask
# neg = similarity_matrix[negative_mask].view(bs, -1)
# log_neg = torch.logsumexp(neg, dim=1) + 1e-28
# loss = -(pos - log_neg).mean()
# return loss
# def contrastive_loss2(self, feature1, feature2, temperature=0.07):
# bs = feature1.size(0)
# similarity_matrix = F.cosine_similarity(feature1.unsqueeze(1), feature2.unsqueeze(0), dim=2) / temperature
# if bs == 1:
# return -similarity_matrix.mean()
# positive_mask = torch.eye(bs, dtype=torch.bool).to(self.device)
# pos = similarity_matrix[positive_mask]
# similarity_matrix2 = F.cosine_similarity(feature1.unsqueeze(1), feature1.unsqueeze(0), dim=2) / temperature
# neg = torch.cat([similarity_matrix[~positive_mask].view(bs, -1),
# similarity_matrix2[~positive_mask].view(bs, -1)], dim=1)
# log_neg = torch.logsumexp(neg, dim=1) + 1e-28
# loss = -(pos - log_neg).mean()
# return loss
def contrastive_loss(self, z_i, z_j, temperature=0.07):
    """NT-Xent (SimCLR-style) contrastive loss.

    We do not sample negative examples explicitly. Instead, given a
    positive pair, similar to (Chen et al., 2017), we treat the other
    2(N − 1) augmented examples within a minibatch as negative examples.

    Args:
        z_i, z_j: (batch, dim) projections of the two augmented views.
        temperature: softmax temperature dividing the similarities.

    Returns:
        Scalar loss tensor.
        NOTE(review): the final `loss /= N` is correct only if
        self.contrastive_criterion uses reduction="sum" — confirm where
        the criterion is constructed.
    """
    batch_size = z_i.shape[0]
    N = 2 * batch_size
    z = torch.cat((z_i, z_j), dim=0)
    # Pairwise similarity of all 2N views, scaled by temperature.
    sim = self.contrastive_similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / temperature
    # The positive pairs sit on the +/- batch_size off-diagonals.
    sim_i_j = torch.diag(sim, batch_size)
    sim_j_i = torch.diag(sim, -batch_size)
    positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
    # Negatives: everything except self-similarity and the positive pair.
    # Built vectorized and on sim's device — the original constructed the
    # mask on CPU, which cannot index a CUDA `sim` tensor.
    mask = torch.ones((N, N), dtype=torch.bool, device=sim.device)
    mask.fill_diagonal_(False)
    idx = torch.arange(batch_size, device=sim.device)
    mask[idx, idx + batch_size] = False
    mask[idx + batch_size, idx] = False
    negative_samples = sim[mask].reshape(N, -1)
    # Column 0 of `logits` holds the positive logit, so all labels are 0.
    labels = torch.zeros(N, dtype=torch.long, device=positive_samples.device)
    logits = torch.cat((positive_samples, negative_samples), dim=1)
    loss = self.contrastive_criterion(logits, labels)
    loss /= N
    return loss
# (blog-scrape artifact, not code) original text: "存一些代码" — "saving some code snippets"
# (blog-scrape artifact) original footer: latest recommended article published 2024-09-14 19:55:48