# NOTE: nn.ReLU's optional `inplace` argument (default False) controls whether
# the activation overwrites its input tensor; False preserves the input data.
import torch
from torch.nn import ReLU
from torch import nn

# 2x2 example tensor; the negative entries will be zeroed by ReLU later.
# (The name shadows the builtin `input`; kept as-is because later code uses it.)
input = torch.tensor([
    [1, -0.5],
    [-1, 3],
])
# Reshape to (batch=1, channels=1, 2, 2) — the NCHW layout nn modules expect.
output = torch.reshape(input, (-1, 1, 2, 2))
print(output.shape)  # torch.Size([1, 1, 2, 2])
class LR(nn.Module):
    """Minimal module that applies a single ReLU activation."""

    def __init__(self):
        super().__init__()
        # ReLU clamps negatives to 0; positives pass through unchanged.
        self.relu1 = ReLU()

    def forward(self, input):
        # Delegate straight to the activation layer and return its result.
        return self.relu1(input)
# Instantiate the wrapper and push the tensor through it, printing the result.
lrp = LR()
print(output := lrp(input))
# Expected:
# tensor([[1., 0.],
#         [0., 3.]])   -- every entry with x < 0 becomes 0