仅作为记录,大佬请跳过。
背景
tensor是[4,3,224,224]
要去除某一行,变成[3,3,224,224]
直接上代码
def del_tensor(arr, index):
if index < arr.shape[0] - 1:
arr1 = arr[0: index]
arr2 = arr[index + 1:]
return torch.cat((arr1, arr2), dim=0)
else:
return arr[0: index]
x_HE_CLS=del_tensor(x_HE_CLS, 2)
参考