The argmax operation is not differentiable, so the gumbel-softmax trick can be used in its place. The one-hot vector it produces can serve as a mask that achieves the same effect as argmax, while introducing random sampling and still letting gradients flow back through the network.
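For intuition, here is a minimal sketch (not from the original post; variable names are illustrative) contrasting plain argmax with PyTorch's built-in gumbel_softmax:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, requires_grad=True)

# Plain argmax returns an integer index: the computation graph is cut here
# and no gradient can flow back to the logits.
idx = logits.argmax()

# gumbel_softmax with hard=True returns a one-hot float tensor that still
# carries gradients via the straight-through estimator.
one_hot = F.gumbel_softmax(logits, tau=1, hard=True)

values = torch.tensor([1., 2., 3., 4.])
loss = (one_hot * values).sum()   # differentiable stand-in for values[idx]
loss.backward()
print(logits.grad)                # non-zero: gradients reached the logits
```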
torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)
It returns a tensor sampled from the Gumbel-Softmax distribution with the same shape as logits. If hard=True, the returned samples are one-hot; otherwise they are probability distributions that sum to 1 across dim. (Reference: Gumbel-Softmax Trick - 知乎 (zhihu.com))
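A quick sketch (my own check, not from the post) of the two modes described above:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5)

soft = F.gumbel_softmax(logits, tau=1, hard=False)
print(soft.sum(dim=-1))   # each row sums to 1: a probability distribution

hard = F.gumbel_softmax(logits, tau=1, hard=True)
print(hard)               # each row is one-hot: exactly one 1.0, rest 0.0
```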
```python
# code is from GroupViT
import torch


def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = False, dim: int = -1) -> torch.Tensor:
    # _gumbels = (-torch.empty_like(
    #     logits,
    #     memory_format=torch.legacy_contiguous_format).exponential_().log()
    # )  # ~Gumbel(0,1)
    # more stable https://github.com/pytorch/pytorch/issues/41663
    gumbel_dist = torch.distributions.gumbel.Gumbel(
        torch.tensor(0., device=logits.device, dtype=logits.dtype),
        torch.tensor(1., device=logits.device, dtype=logits.dtype))
    gumbels = gumbel_dist.sample(logits.shape)

    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight-through estimator: the forward pass returns the one-hot
        # y_hard, while the backward pass uses the gradients of y_soft.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(
            logits,
            memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick: return the differentiable soft sample.
        ret = y_soft
    return ret
```
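A hedged usage sketch (variable names are illustrative, not from GroupViT) showing how the one-hot output of the function defined above acts as a differentiable argmax mask:

```python
import torch

features = torch.randn(4, 8)                  # 4 candidate feature vectors
logits = torch.randn(4, requires_grad=True)   # scores for the candidates

one_hot = gumbel_softmax(logits, tau=1, hard=True)  # one-hot over dim=-1
selected = one_hot @ features                 # equals features[argmax] in forward

selected.sum().backward()
print(logits.grad)   # gradients reach the logits despite the hard selection
```

In the forward pass the mask is exactly one-hot, so the selection matches argmax; in the backward pass the straight-through term routes gradients through the soft sample.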