Loss
Dual loss
Code:
gumbel_softmax in Gated-SCNN: as shown in Equation (9), Gumbel softmax is used to implement argmax().
g = _gumbel_softmax_sample(input_logits.view(N, C, -1), tau=0.5)
gumbel_noise = _sample_gumbel(logits.size(), eps=eps)
# _sample_gumbel draws Gumbel(0, 1) noise: g = -log(-log(u + eps))
U = torch.rand(shape).cuda()
return - torch.log(eps - torch.log(U + eps))
# add the Gumbel noise to the log-probabilities log(π)
y = logits + gumbel_noise
# softmax: the softmaxed result (call it s_y) is also a probability vector, close to the
# distribution of the variable. Note, however, that s_y is continuous, while π is a discrete
# variable. E.g. for π = {0: 0.1, 1: 0.5, 2: 0.4}, a sample from π is one of the discrete
# values 0-2, whereas s_y is a soft vector over those three classes.
return F.softmax(y / tau, 1)
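Putting the fragments together, a minimal runnable sketch of the GSCNN sampler (device handling is generalized here; the original hard-codes .cuda()):

import torch
import torch.nn.functional as F

def _sample_gumbel(shape, eps=1e-10, device='cpu'):
    # Inverse-CDF trick: if U ~ Uniform(0, 1), then -log(-log(U)) ~ Gumbel(0, 1).
    # eps guards both logs against log(0); eps - log(U + eps) == -log(U + eps) + eps.
    U = torch.rand(shape, device=device)
    return -torch.log(eps - torch.log(U + eps))

def _gumbel_softmax_sample(logits, tau=0.5, eps=1e-10):
    # Perturb the log-probabilities with Gumbel noise, then relax the argmax
    # with a temperature-scaled softmax over the class dimension (dim=1).
    gumbel_noise = _sample_gumbel(logits.size(), eps=eps, device=logits.device)
    y = logits + gumbel_noise
    return F.softmax(y / tau, dim=1)

# Usage as in GSCNN: flatten the spatial dims so logits are (N, C, H*W).
N, C, H, W = 2, 3, 4, 4
input_logits = torch.randn(N, C, H, W)
g = _gumbel_softmax_sample(input_logits.view(N, C, -1), tau=0.5)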
# Dual-task loss: compare the gradient magnitudes of prediction and ground truth.
g = compute_grad_mag(g, cuda=self._cuda)
# Inside compute_grad_mag:
E_ = convTri(E, 4, cuda)                   # triangle-filter smoothing of the input map E
Ox, Oy = numerical_gradients_2d(E_, cuda)  # numerical x/y image gradients
mag = torch.sqrt(torch.mul(Ox, Ox) + torch.mul(Oy, Oy) + 1e-6)  # gradient magnitude
mag = mag / mag.max()                      # normalize to [0, 1]
# Element-wise L1 between the two magnitude maps (the deprecated reduce=False
# is redundant with reduction='none'):
loss_ewise = F.l1_loss(g, g_hat, reduction='none')
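convTri and numerical_gradients_2d are GSCNN-specific helpers, not shown here. A rough stand-in sketch, assuming an average-pooling blur in place of the triangle filter and central finite differences in place of the original gradient routine:

import torch
import torch.nn.functional as F

def compute_grad_mag_sketch(E):
    # E: (B, C, H, W) soft mask. Smooth, differentiate, normalize.
    E_ = F.avg_pool2d(E, kernel_size=3, stride=1, padding=1)  # crude stand-in for convTri
    # Central finite differences via replicate padding
    # (stand-in for numerical_gradients_2d).
    p = F.pad(E_, (1, 1, 1, 1), mode='replicate')
    Ox = (p[..., 1:-1, 2:] - p[..., 1:-1, :-2]) / 2.0
    Oy = (p[..., 2:, 1:-1] - p[..., :-2, 1:-1]) / 2.0
    mag = torch.sqrt(Ox * Ox + Oy * Oy + 1e-6)
    return mag / mag.max()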
gumbel_softmax in PyTorch
It lives in torch.nn.functional:
def gumbel_softmax(logits: Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1) -> Tensor:
    if has_torch_function_unary(logits):
        return handle_torch_function(gumbel_softmax, (logits,), logits, tau=tau, hard=hard, eps=eps, dim=dim)
    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    gumbels = (
        -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()
    )  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret
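The hard branch is the straight-through estimator: the forward value is the one-hot y_hard, but y_hard - y_soft.detach() contributes no gradient, so the backward pass behaves as if y_soft had been returned. A tiny check (illustrative tensors):

import torch
import torch.nn.functional as F

logits = torch.randn(4, 3, requires_grad=True)
out = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)
print(out)  # exactly one-hot rows in the forward pass
(out * torch.arange(3.0)).sum().backward()  # any downstream computation
print(logits.grad)  # non-zero: gradients flowed through the soft sample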
The gumbel_softmax in Gated-SCNN is equivalent to torch.nn.functional.gumbel_softmax(logits, tau=0.5, hard=False, eps=1e-10, dim=1).
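A quick Monte Carlo sanity check of this equivalence, reusing the _gumbel_softmax_sample sketch above: the argmax of both samplers should follow softmax(logits), independent of tau.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 20000
logits = torch.log(torch.tensor([[0.1, 0.5, 0.4]])).expand(n, 3)
a = _gumbel_softmax_sample(logits, tau=0.5).argmax(dim=1)
b = F.gumbel_softmax(logits, tau=0.5, dim=1).argmax(dim=1)
print(torch.bincount(a, minlength=3) / n)  # ~tensor([0.10, 0.50, 0.40])
print(torch.bincount(b, minlength=3) / n)  # ~tensor([0.10, 0.50, 0.40])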
Code for my own use:
pred1 is the raw network output; no softmax or log has been applied.
netout_back_roof = torch.log(pred1.softmax(dim=1))  # B, C, H, W
netout_back_roof_gumbel = F.gumbel_softmax(netout_back_roof, tau=0.1, hard=True, eps=1e-10, dim=1)  # B, C, H, W
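Two notes on this usage: torch.log(pred1.softmax(dim=1)) can underflow to -inf for near-zero probabilities, and F.log_softmax(pred1, dim=1) is the numerically stable equivalent; also, because softmax is invariant to a per-pixel constant shift of its inputs, passing the raw pred1 would sample from the same distribution. A sketch with illustrative shapes and names:

import torch
import torch.nn.functional as F

pred1 = torch.randn(2, 4, 8, 8, requires_grad=True)  # B, C, H, W (illustrative)
log_probs = F.log_softmax(pred1, dim=1)              # stable form of log(softmax(.))
hard_sample = F.gumbel_softmax(log_probs, tau=0.1, hard=True, dim=1)
print(hard_sample[0, :, 0, 0])  # one-hot over the C classes in the forward pass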