Knowledge distillation 代码实现(简易版)
-
定义ResNet(教师模型)
-
下载CIFAR10数据集并预处理
# 设置训练集图片预处理 transform_train = transforms.Compose([ # 随机裁剪成32x32并做padding=4的填充 transforms.RandomCrop(32, padding=4), # 以给定概率水平翻转图片,默认概率为0.5 transforms.RandomHorizontalFlip(), # 转换成tensor类型 transforms.ToTensor(), # 归一化 transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) # 设置测试集图片预处理 transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) ]) # 下载训练集并进行预处理 train_set = torchvision.datasets.CIFAR10(root='../data'