import torch
def data_normal(orign_data):
d_min = orign_data.min()
if d_min < 0:
orign_data +=torch.abs(d_min)
d_min = orign_data.min()
d_max = orign_data.max()
dst = d_max - d_min
norm_data = (orign_data - d_min).true_divide(dst)
return norm_data
x = torch.randint(20,(2,5))
print(x.shape)
x = data_normal(x)
print(x)