导入gpu工具包
import torch
先进行gpu是否存在的判断
train_on_gpu = torch.cuda.is_available() # 得到gpu运行的情况
if not train_on_gpu: # 根据gpu运行的情况,决定使用gpu还是cpu进行工作
print('CUDA cannot be used, training on CPU...')
else:
print('CUDA can be used, training on GPU...')
使用gpu训练
表示gpu不好使用的时候改用cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
观察模型的全连接层
通过观察全连接层,我们可以知道迁移的网络的输出是多少种类型,我们需要将全连接层改成我们所需要的输出的类型
model_ft = models.resnet152()
print(model_ft)