Test,Evaluate_gpu 修改,自动跑完你要的epoch

Test

当训练完的时候,会保存很多代的训练参数,后边做测试会把参数加载进行提取特征操作。加了一个循环,这样不用一个一个自己动手操作了,只需要把你想要的代数填进去就OK 了,

if __name__ == '__main__':
    ######################################################################
    # Load Collected data Trained model
    print('-------test-----------')
    #model_structure = My_model(751)
    model_structure = resnet50_rga()
    #model_structure = nn.DataParallel(model_structure)#######
    list = [194,192,190,189,179,169,159,149,139,129]# 根据列表的代数参数,生成对应参数的特征。方便做测试。
    for epoch_i in list:
        
        model = load_network(model_structure,epoch_i)
    #model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(save_path).items()})
    # Change to test mode
        model = model.eval()
        if use_gpu:
            model = model.cuda()
    # Extract feature
        cudnn.benchmark=True
        gallery_feature = extract_feature(model, dataloaders['gallery'])
        query_feature = extract_feature(model, dataloaders['query'])
    # Save to Matlab for check
        result = {'gallery_f': gallery_feature.numpy(), 'gallery_label': gallery_label, 'gallery_cam': gallery_cam,
              'query_f': query_feature.numpy(), 'query_label': query_label, 'query_cam': query_cam}
    #scipy.io.savemat('pytorch_result_my_net7_duke2.mat', result)#scipy.io.savemat('pytorch_result_my_net7_duke.mat', result)
        scipy.io.savemat('pytorch_result_%s_%s.mat'%(name,epoch_i), result)#opt.which_epoch
        print('第%s代的参数对应的特征已经生成,注意查收^-^^_^^-^^_^'%epoch_i)
    print('Finished test for all epoch in list.......')

evaluate——GPU

测试完了之后就生成了 对应epoch参数的gallery和query特征。需要分别做一个评估,看看效果。
改进:也是加了一个循环,只需要把test里list列表复制过来即可,既可以连续做出评估并且保存。
操作说明:1.list列表复制粘贴过来。 2.name记得和test一样。

if __name__ == '__main__':
    list = [198,196,194,192,190,189,179,169,159,149,139,129]
    for epoch_i in list:
        
        mat_name = 'pytorch_result_My_mode_%s.mat'%epoch_i


        result = scipy.io.loadmat(mat_name)##################################
        query_feature = torch.FloatTensor(result['query_f'])
        query_cam = result['query_cam'][0]
        query_label = result['query_label'][0]
        gallery_feature = torch.FloatTensor(result['gallery_f'])
        gallery_cam = result['gallery_cam'][0]
        gallery_label = result['gallery_label'][0]

        multi = os.path.isfile('multi_query.mat')

        if multi:
            m_result = scipy.io.loadmat('multi_query.mat')
            mquery_feature = torch.FloatTensor(m_result['mquery_f'])
            mquery_cam = m_result['mquery_cam'][0]
            mquery_label = m_result['mquery_label'][0]
            mquery_feature = mquery_feature.cuda()

        query_feature = query_feature.cuda()
        gallery_feature = gallery_feature.cuda()

        print(query_feature.shape)
        CMC = torch.IntTensor(len(gallery_label)).zero_()
        ap = 0.0
        #print(query_label)
        for i in range(len(query_label)):
            ap_tmp, CMC_tmp = evaluate(query_feature[i],query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam)
            if CMC_tmp[0]==-1:
                continue
            CMC = CMC + CMC_tmp
            ap += ap_tmp
            #print(i, CMC_tmp[0])

        CMC = CMC.float()
        CMC = CMC/len(query_label) #average CMC
        print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))

        name = 'My_mode'

        # 在日志文件中记录精度
        #with open('./model/%s/%s.txt' %(name,name+'_`223'),'a') as acc_file:
        with open('./model/%s/%s.txt' %(name,name+'_resnet+cross'),'a') as acc_file:
            acc_file.write('%s, Rank@1: %f, Rank@5: %f, Rank@10: %f,\nmAP:%f\n' % (mat_name, CMC[0], CMC[4], CMC[9], ap/len(query_label)))
        print('Finished evaluate for %s epoch '%epoch_i)
    print('Finished evaluate for all epoch in the list....')
    '''   
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值