建议最先阅读 forward 这一段，借此理解整个网络框架（forward 先调用 forward_features 提取特征，再把结果送入分类头）。
def forward_features(self, x):
    """Embed an input batch and run it through the transformer encoder.

    Pipeline: ``self.patch_embed`` -> prepend the class token (and the
    distillation token when ``self.dist_token`` is set) -> add
    ``self.pos_embed`` and apply ``self.pos_drop`` -> ``self.blocks`` ->
    ``self.norm``.

    Example shapes observed in one run (presumably batch 256, 384x384 images,
    16x16 patches, embed dim 384 -- TODO confirm for other configs):
    input [256, 3, 384, 384] -> patch tokens [256, 576, 384]
    -> with class token [256, 577, 384].

    Args:
        x: input batch; its first dimension is treated as the batch size.

    Returns:
        Without a distillation token: ``self.pre_logits`` applied to the
        class-token output ``x[:, 0]``.
        With a distillation token: the tuple ``(x[:, 0], x[:, 1])`` holding
        the class-token and distillation-token outputs (``pre_logits`` is not
        applied on this path).
    """
    x = self.patch_embed(x)  # e.g. [256, 3, 384, 384] -> [256, 576, 384]
    batch = x.shape[0]
    # Broadcast the learned token(s) across the batch dimension.
    cls_token = self.cls_token.expand(batch, -1, -1)
    if self.dist_token is None:
        tokens = (cls_token, x)
    else:
        tokens = (cls_token, self.dist_token.expand(batch, -1, -1), x)
    # Tokens are concatenated along the sequence axis: one extra position for
    # the class token (two with a distillation token), e.g. [256, 577, 384].
    x = torch.cat(tokens, dim=1)
    x = self.pos_drop(x + self.pos_embed)
    x = self.blocks(x)  # self-attention + FFN blocks, e.g. [256, 577, 384]
    x = self.norm(x)
    if self.dist_token is None:
        return self.pre_logits(x[:, 0])
    return x[:, 0], x[:, 1]
def forward(self, x):
    """Classify a batch: extract backbone features, then apply the head(s).

    Args:
        x: model input, passed unchanged to ``self.forward_features``.

    Returns:
        If ``self.head_dist`` is None: ``self.head(features)``.
        Otherwise ``features`` is indexed as a pair and both heads are run.
        In training mode outside TorchScript, the pair
        ``(head(features[0]), head_dist(features[1]))`` is returned.
        In every other case (eval mode, or while scripting), the mean of
        the two predictions is returned.
    """
    feats = self.forward_features(x)
    if self.head_dist is None:
        return self.head(feats)

    # Here ``feats`` must be indexable as a (class-token, distillation-token)
    # pair, which forward_features returns when self.dist_token is set.
    cls_logits = self.head(feats[0])
    dist_logits = self.head_dist(feats[1])

    if not self.training or torch.jit.is_scripting():
        # Inference (or TorchScript): average both classifiers' predictions.
        return (cls_logits + dist_logits) / 2
    # Eager-mode training: keep both outputs separate.
    return cls_logits, dist_logits