squeeze
unsqueeze
先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的数去掉第一个维数为一的维度之后就变成(3)行。squeeze(a)就是将a中所有为1的维度删掉。不为1的维度没有影响。a.squeeze(N) 就是去掉a中指定的维数为一的维度。还有一种形式就是b=torch.squeeze(a,N) a中去掉指定的定的维数为一的维度。
view
相当于 reshape
accuracy
pred = torch.max([-0.1,0.5],1)
意味着 最大的设为1, 小的变0.
(y == pred).sum() 可以用来算最后精度, 加总后 处于总data数量
transform 的 未封包版本
from torchvision.transforms import functional
# 可用函数在下面
to_pil_image
normalize
resize
crop
hflip vflip
adjust_brightness
rotate to_grayscale
https://github.com/pytorch/vision/blob/master/torchvision/transforms/functional.py
transform 用在 Dataset 上
class ImageDataset(Dataset):
def __init__(self, root, transforms_=None, unaligned=False, mode='train'):
transforms_ = transforms.Compose([
transforms.ToTensor(),
transforms.ToPILImage(),
transforms.RandomCrop((300,300)),
]
self.transform = transforms.Compose(transforms_)
def __getitem__(self, index):
item_A = self.transform(Image.open(self.files_A[index % len(self.files_A)]))
)
#return {'A': item_A, 'B': item_B}
normalize
channel=(channel-mean)/std