import torch

def get_policy_probs(model, recons, mask):  # mask: (3,1,1,256,1), recons: (3,1,256,256)
    channel_size = mask.shape[1]  # 1
    res = mask.size(-2)  # 256
    # Merge batch and channel dimensions so the model sees single-channel images
    recons = recons.view(mask.size(0) * channel_size, 1, res, res)  # (3,1,256,256)
    # Obtain policy model logits; model.forward is defined in policy_model_def.py (next block)
    output = model(recons)  # (3,256)
# policy_model_def.py -- the forward pass invoked by model(recons) above
def forward(self, image):  # image: (3,1,256,256)
    """
    Args:
        image (torch.Tensor): Input tensor of shape [batch_size, self.in_chans, height, width]
    Returns:
        (torch.Tensor): Row scores of shape [batch_size, resolution]
    """
    # Image embedding
    # Initial block
    image_emb = self.channel_layer(image)  # (3,8,256,256)
    # Apply down-sampling layers: each ConvBlock doubles the channel count and halves the
    # spatial resolution (max_pool_size=2), i.e.
    # (3,8,256,256) -> (3,16,128,128) -> (3,32,64,64) -> (3,64,32,32) -> (3,128,16,16) -> (3,256,8,8)
    for layer in self.down_sample_layers:
        image_emb = layer(image_emb)
    image_emb = self.fc_out(image_emb.flatten(start_dim=1))  # flatten all but batch dimension: (3,256)
    assert len(image_emb.shape) == 2
    return image_emb
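# For reference, a minimal sketch (an assumption, not the repository's actual class) of a policy
# model whose constructor matches the shapes annotated in forward() above: channel_layer maps
# 1 -> 8 channels, five down-sampling blocks double the channels while halving the resolution,
# and fc_out maps the flattened 256*8*8 features to one score per k-space row.
class PolicyModelSketch(torch.nn.Module):  # hypothetical name
    def __init__(self, resolution=256, chans=8, num_pool_layers=5):
        super().__init__()
        self.channel_layer = torch.nn.Conv2d(1, chans, kernel_size=3, padding=1)
        layers, ch = [], chans
        for _ in range(num_pool_layers):  # channels: 8 -> 16 -> 32 -> 64 -> 128 -> 256
            layers.append(torch.nn.Sequential(
                torch.nn.Conv2d(ch, ch * 2, kernel_size=3, padding=1),
                torch.nn.ReLU(inplace=True),
                torch.nn.MaxPool2d(2),  # halves the spatial resolution
            ))
            ch *= 2
        self.down_sample_layers = torch.nn.ModuleList(layers)
        spatial = resolution // (2 ** num_pool_layers)  # 256 / 2^5 = 8
        self.fc_out = torch.nn.Linear(ch * spatial * spatial, resolution)  # 256*8*8 -> 256
    # forward() is the method shown above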
    # Back in get_policy_probs: reshape trajectories back into their own dimension
    output = output.view(mask.size(0), channel_size, res)  # (3,1,256)
    # Mask already-acquired rows by setting their logits to a very negative number
    loss_mask = (mask == 0).squeeze(-1).squeeze(-2).float()  # (3,1,256)
    logits = torch.where(loss_mask.bool(), output, -1e7 * torch.ones_like(output))  # (3,1,256)
    # Softmax over 'logits' representing row scores (max is subtracted for numerical stability)
    probs = torch.nn.functional.softmax(logits - logits.max(dim=-1, keepdim=True)[0], dim=-1)  # (3,1,256)
    # Also needed for sampling the next row at the end of the acquisition loop
    policy = torch.distributions.Categorical(probs)
    return policy, probs
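# A minimal usage sketch (hypothetical helper, not part of the repository): sample the next
# unacquired row per trajectory from the returned policy, keep its log-probability for a
# REINFORCE-style update, and mark the sampled row as acquired in the mask.
def sample_next_row(model, recons, mask):
    policy, probs = get_policy_probs(model, recons, mask)  # probs: (batch, channels, res)
    next_rows = policy.sample()                            # (batch, channels) row indices
    log_probs = policy.log_prob(next_rows)                 # (batch, channels)
    # Assumes the mask layout (batch, channels, 1, res, 1) from the comments above
    for b in range(mask.size(0)):
        for c in range(mask.size(1)):
            mask[b, c, 0, next_rows[b, c], 0] = 1.0
    return next_rows, log_probs, mask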