import torch
import torch.nn as nn


class net(nn.Module):
    """Small conv stack demonstrating how ``Tensor.detach()`` cuts the autograd graph.

    Four stride-2, 3x3, padding-1 convolutions (3 -> 6 -> 12 -> 24 -> 48
    channels). The activation is detached between ``layer2`` and ``layer3``,
    so a backward pass started from the output only reaches ``layer3``'s
    parameters. ``conv``, ``layer1`` and ``layer2`` receive no gradient, and
    their ``.grad`` stays ``None``.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 6, 3, stride=2, padding=1)
        self.layer1 = nn.Conv2d(6, 12, 3, stride=2, padding=1)
        self.layer2 = nn.Conv2d(12, 24, 3, stride=2, padding=1)
        self.layer3 = nn.Conv2d(24, 48, 3, stride=2, padding=1)

    def forward(self, x):
        """Run the conv stack and return the mean of the final activation.

        Args:
            x: input images. Presumably (N, 3, H, W), since ``conv`` takes
                3 input channels; TODO confirm the expected layout with callers.

        Returns:
            A 0-dim tensor holding the mean over every element of ``layer3``'s
            output. Callers can call ``.backward()`` on it directly.
        """
        x = self.conv(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # detach() yields a tensor with no autograd history, so gradients
        # flowing back from layer3 stop here. Nothing before this point
        # (conv, layer1, layer2) will get a .grad.
        x = x.detach()
        x = self.layer3(x)
        # Reduce to a scalar so .backward() can be called without an explicit
        # gradient argument.
        return torch.mean(x)


def main():
    """Show that only parameters after the detach() point receive gradients."""
    data = torch.randn(1, 3, 224, 224)
    n = net()
    pred = n(data)
    # Backprop lives here rather than inside forward(). forward() should only
    # build the graph; calling backward() there would accumulate into .grad
    # on every forward call.
    pred.backward()
    print(n.layer2.weight.grad)  # None: layer2 sits before the detach()
    print(n.layer3.weight.grad)  # gradient tensor of shape torch.Size([48, 24, 3, 3])


if __name__ == '__main__':
    main()
在 PyTorch 中使用 detach()
最新推荐文章于 2024-02-13 00:33:55 发布