本代码针对基于densenet 的 pytorch添加预训练模型的的一个分类方法,由官方教程为基础做的更改。
本实验主要目的是以Imagenet或其他大数据集已经训练好的权重文件,初始化到我们要用到的训练网络中。
本算法基于jupyter noetbook 下载anaconda,安装好需要的环境后 在代码目录打开命令行键入jupyter noetbook即可使用
代码参考:https://github.com/seasealfeng/densnet_transfer_learning
载入数据:
data_transforms = {
'train': transforms.Compose([
transforms.RandomCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.RandomCrop(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
data_dir = 'hymenoptera_data'#文件夹名称
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['