2020-11-01

代码解读

#定义格式
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

#定义格式
vgg_format = transforms.Compose([
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])

data_dir = './dogscats'
#读取数据
dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
         for x in ['train', 'valid']}
#获取大小
dset_sizes = {x: len(dsets[x]) for x in ['train', 'valid']}
#对象化
dset_classes = dsets['train'].classes

修改的部分代码如下
主要是对数据的提取做了改动

def test_model(model,dataloader,size):
    model.eval()
    #初始化变量
    predictions = np.zeros(size)
    all_classes = np.zeros(size)
    all_proba = np.zeros((size,2))
    i = 0
    running_loss = 0.0
    running_corrects = 0
    meici_shuzu=[0,1,2,3,4]
    total_cishu=0
    for inputs,classes in dataloader:
        print("num:",inputs)#获得inputs格式,进行分解
        #将数据放入GPU
        inputs = inputs.to(device)
        classes = classes.to(device)
        #放入模型,获取结果
        outputs = model(inputs)
        #计算损失值
        loss = criterion(outputs,classes)         
        _,preds = torch.max(outputs.data,1)
        #以下是修改的代码
        print("\npreds:",preds)#输出观察preds结构,发现是张量
        for shushu2 in range(0,5):#构造编号,后期发现其
        #实只需要调用test的图片名就好了
          meici_shuzu[shushu2]=shushu2+total_cishu
          #将preds用numpy转换成数组类型,便于储存
         dataframe=pd.DataFrame({'num':meici_shuzu,'result':preds.numpy()})
         #循环写入
         dataframe.to_csv("test.csv",index=False,mode='a',sep=",")
        total_cishu=total_cishu+5

在进行写入数据的时候,要注意图片在colab的排序方式,并按照对应的格式输出数据

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值