if __name__ == '__main__':
t = torch.ones(32, 3, 64, 64)
model = sa_layer(64)
y = model(t)
print("print(y.shape)", y.shape) # shape
如何测试网络能不能跑通
最新推荐文章于 2024-08-10 15:48:22 发布
if __name__ == '__main__':
t = torch.ones(32, 3, 64, 64)
model = sa_layer(64)
y = model(t)
print("print(y.shape)", y.shape) # shape