c1_greedy_policy_processing

import torch


def get_policy_probs(model, recons, mask):  # mask: (3, 1, 1, 256, 1), recons: (3, 1, 256, 256)
    channel_size = mask.shape[1] #1
    res = mask.size(-2) #256
    recons = recons.view(mask.size(0) * channel_size, 1, res, res) #(3,1,256,256)


    # Obtain policy model logits
    output = model(recons)  # calls the policy model's forward() from policy_model_def.py (shown below); output: (3, 256)

    # --- excerpt from policy_model_def.py: the policy model's forward(), reproduced here for reference ---
    def forward(self, image):  # image: (3, 1, 256, 256)
        """
        Args:
            image (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
            mask (torch.Tensor): Input tensor of shape [resolution], containing 0s and 1s

        Returns:
            (torch.Tensor): Output tensor of shape [batch_size, self.out_chans, height, width]
        """

        # Image embedding
        # Initial block: lift the single input channel to 8 feature channels
        image_emb = self.channel_layer(image)  # (3, 8, 256, 256)
        # Apply down-sampling layers. Each ConvBlock(in_chans, out_chans, drop_prob=0, max_pool_size=2)
        # doubles the channel count and halves the spatial size:
        # (3, 16, 128, 128) -> (3, 32, 64, 64) -> (3, 64, 32, 32) -> (3, 128, 16, 16) -> (3, 256, 8, 8)
        for layer in self.down_sample_layers:
            image_emb = layer(image_emb)
        image_emb = self.fc_out(image_emb.flatten(start_dim=1))  # flatten all but batch dimension -> (3, 256)
        assert len(image_emb.shape) == 2
        return image_emb
    # --- back in get_policy_probs ---
    # Reshape trajectories back into their own dimension
    output = output.view(mask.size(0), channel_size, res)#(3,1,256)
    # Mask already acquired rows by setting logits to very negative numbers
    loss_mask = (mask == 0).squeeze(-1).squeeze(-2).float()#(3,1,256)
    logits = torch.where(loss_mask.bool(), output, -1e7 * torch.ones_like(output))  # (3, 1, 256)
    # Softmax over 'logits' representing row scores
    probs = torch.nn.functional.softmax(logits - logits.max(dim=-1, keepdim=True)[0], dim=-1) #(3,1,256)
    # Also need this for sampling the next row at the end of this loop
    policy = torch.distributions.Categorical(probs)
    return policy, probs
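
For context, the shape annotations above imply a policy network built from an initial 1-to-8 channel convolution, five ConvBlocks that each double the channels and halve the resolution, and a final linear layer producing one score per k-space row. The sketch below is only an assumption reconstructed from those annotations: the class name PolicyModel, the ConvBlock internals, and the kernel sizes are illustrative, not the repo's exact definition.

import torch
from torch import nn

class ConvBlock(nn.Module):
    """Illustrative stand-in for the repo's ConvBlock: conv -> norm -> ReLU -> dropout, then 2x max-pool."""
    def __init__(self, in_chans, out_chans, drop_prob=0.0, max_pool_size=2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1),
            nn.InstanceNorm2d(out_chans),
            nn.ReLU(inplace=True),
            nn.Dropout2d(drop_prob),
            nn.MaxPool2d(max_pool_size),
        )

    def forward(self, x):
        return self.layers(x)

class PolicyModel(nn.Module):
    """Maps a (N, 1, 256, 256) reconstruction to one score per k-space row, shape (N, 256)."""
    def __init__(self, resolution=256, in_chans=1, chans=8, num_pool_layers=5):
        super().__init__()
        self.channel_layer = nn.Conv2d(in_chans, chans, kernel_size=3, padding=1)  # 1 -> 8 channels
        self.down_sample_layers = nn.ModuleList(
            ConvBlock(chans * 2 ** i, chans * 2 ** (i + 1)) for i in range(num_pool_layers)
        )  # 8 -> 16 -> 32 -> 64 -> 128 -> 256 channels, 256 -> 8 spatial
        flat_size = chans * 2 ** num_pool_layers * (resolution // 2 ** num_pool_layers) ** 2  # 256 * 8 * 8
        self.fc_out = nn.Linear(flat_size, resolution)  # one logit per k-space row

    def forward(self, image):
        image_emb = self.channel_layer(image)
        for layer in self.down_sample_layers:
            image_emb = layer(image_emb)
        return self.fc_out(image_emb.flatten(start_dim=1))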

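Putting the pieces together, a hypothetical greedy acquisition step might look like the following. This is only a sketch under assumed names: recon_model, masked_kspace, and greedy_acquisition_step are not part of the original code; the mask layout (batch, channels, 1, resolution, 1) follows the shape comments above.

# Hypothetical usage sketch: greedily acquire the next k-space row per trajectory.
# `recon_model`, `masked_kspace`, and this helper's name are assumptions for illustration.
import torch

def greedy_acquisition_step(policy_model, recon_model, masked_kspace, mask):
    # Reconstruct images from the currently acquired k-space rows: (3, 1, 256, 256)
    recons = recon_model(masked_kspace, mask)
    policy, probs = get_policy_probs(policy_model, recons, mask)
    # Greedy choice: highest-scoring unacquired row per trajectory, shape (3, 1).
    # During training one would typically sample instead: next_rows = policy.sample()
    next_rows = probs.argmax(dim=-1)
    # Mark the chosen rows as acquired in the (batch, chans, 1, res, 1) mask
    batch_idx = torch.arange(mask.size(0)).unsqueeze(1)  # (3, 1)
    chan_idx = torch.arange(mask.size(1)).unsqueeze(0)   # (1, 1)
    mask[batch_idx, chan_idx, 0, next_rows, 0] = 1.0
    return next_rows, mask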