【LFW大规模测试的准备1】基于java接口实现LFW数据规范写入写出

java接口定义

package preprocess;

import java.util.List;

public interface PreProcess {

    List<NamePair> getFileNameList(char[] path, int index) ;  
    //输入txt文件路径,及组名,返回图片名称对

    boolean setResult( char[] path, List<ResultPair> resultPair );           
    // 输入ROC点列 ,保存为txt


}

定义了两个类, NamePair是图片对的名称,ResultPair是ROC数据结果的类。

package preprocess;

public class NamePair {
    char[] namePhoto1;
    char[] namePhoto2;
}
package preprocess;

public class ResultPair {
    float truePositiverate;
    float falsePositiverate;
    public ResultPair(){

    }; //构造函数
    public ResultPair(float f1, float f2){
        truePositiverate=f1;
        falsePositiverate=f2;
    };  //构造函数, 可用于直接赋值


}

接口的实现
定义一个Cire的类,来实现接口
1. 首先写了getLineName的一个内部函数, 功能是输入文本内容,输出各图片的全名。 主要用到string的split技术;
2. 具体实现getFileNameList,setResult这个接口的两个方法。 且接口中声明几个方法,需全部实现;
3. 在public static void main (String[] args) 主函数中,定义了编译的入口,并演示了该接口函数的功能效果。

package preprocess;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class Cire implements PreProcess {

    private NamePair getLineName( String LineContex){
        NamePair photoName = new NamePair();
        String[] aa=LineContex.split("\\t");
        if(aa.length==3){
            if(Integer.parseInt(aa[1])<10)
            {
                photoName.namePhoto1= (aa[0]+"_000"+aa[1]+".jpg").toCharArray();
            }else{
                photoName.namePhoto1= (aa[0]+"_00"+aa[1]+".jpg").toCharArray();
            }

            if(Integer.parseInt(aa[2])<10)
            {
                photoName.namePhoto2= (aa[0]+"_000"+aa[2]+".jpg").toCharArray();
            }else{
                photoName.namePhoto2= (aa[0]+"_00"+aa[2]+".jpg").toCharArray();
            }
        }else if(aa.length==4){
            if(Integer.parseInt(aa[1])<10)
            {
                photoName.namePhoto1= (aa[0]+"_000"+aa[1]+".jpg").toCharArray();
            }else{
                photoName.namePhoto1= (aa[0]+"_00"+aa[1]+".jpg").toCharArray();
            }

            if(Integer.parseInt(aa[3])<10)
            {
                photoName.namePhoto2= (aa[2]+"_000"+aa[3]+".jpg").toCharArray();
            }else{
                photoName.namePhoto2= (aa[2]+"_00"+aa[3]+".jpg").toCharArray();
            }
        }


        return photoName;
    } // 输入文本内容,输出规范化的图片名称

    public List<NamePair> getFileNameList(char[] path, int index) {
        //读取文件
        List<NamePair> photoNameList = new ArrayList<NamePair>();

        try{

            String encoding="GBK";
            File file= new File(String.valueOf(path));
            if(file.isFile()&& file.exists()){
                InputStreamReader read = new InputStreamReader(new FileInputStream(file),encoding);
                BufferedReader bufferedReader = new BufferedReader(read);
                String lineTxt =null;
                NamePair photoNamePair =new NamePair();  //对象必须new创建,否则容易报指针null的错误
                int i=1;
                while((lineTxt = bufferedReader.readLine())!= null ){

                    //每一行的名字+数字,作为name的输出
                    if(i<=(index*600)&& i>((index-1)*600))
                    {
                        photoNamePair= getLineName( lineTxt);
                        photoNameList.add(photoNamePair);

                        System.out.println(String.valueOf(photoNamePair.namePhoto1)+"    "+String.valueOf(photoNamePair.namePhoto2));

                    }

                    i++;
                }

                read.close();


            }else{
                System.out.println("找不到指定的文件");
            }


        }catch(Exception e){
            System.out.println("读取txt文件失败");
            e.printStackTrace();
        };

        return photoNameList;

    }  //输入txt文件路径,及组名,输出图片名称对 


    public boolean setResult( char[] path, List<ResultPair> resultPair  ) {

        File file= new File(String.valueOf(path));  //要保存的文件路径

        FileWriter out = null;
        try {
            out = new FileWriter( file );
        } catch (IOException e1) {
            // TODO Auto-generated catch block
            e1.printStackTrace();
        }

        for(int i=0;i<resultPair.size();i++){
            float data1= resultPair.get(i).truePositiverate;                    
            float data2= resultPair.get(i).falsePositiverate;
            String data= data1+"   "+data2;
            try {
                out.write(data);
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }

            try {
                out.write("\r\n");
            } catch (IOException e) {
                // TODO Auto-generated catch block
                e.printStackTrace();
            }
        }

        try {
            out.close();
        } catch (IOException e) {
            // TODO Auto-generated catch block
            e.printStackTrace();
        }

        return true ;  


    }  // 输入ROC点列, size 是数组的长度,保存为txt


    public static void main (String[] args){
        Cire C =new Cire();
        char[] path="D:\\pairs.txt".toCharArray();  //路径名
        List<NamePair> F= C.getFileNameList(path,10);   //得到对应组的图片名称


        List<ResultPair> resultPair= new ArrayList<ResultPair>();
        ResultPair obj1= new ResultPair(0.9f,0.1f);
        ResultPair obj2= new ResultPair(0.8f,0.2f);
        resultPair.add(obj1);
        resultPair.add(obj2);
        char[] path2="D:\\ROC.txt".toCharArray();  //输出ROC.txt的路径名
        C.setResult( path2, resultPair); //ROC保存到txt文件

    }

}
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
以下是基于 PyTorch 环境下使用 ArcFace 测试 LFW 数据集训练模型的完整代码: ```python import os import argparse import torch import torch.nn as nn import torch.optim as optim import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.datasets import ImageFolder from models import ArcFaceModel from losses import ArcFaceLoss # Define arguments parser = argparse.ArgumentParser(description='ArcFace LFW evaluation') parser.add_argument('--data_dir', type=str, default='./data/lfw', help='directory where the LFW dataset is located') parser.add_argument('--model_path', type=str, default='./models/model.pth', help='path to the trained model') parser.add_argument('--batch_size', type=int, default=32, help='batch size for testing') parser.add_argument('--image_size', type=int, default=112, help='image size for testing') parser.add_argument('--num_workers', type=int, default=4, help='number of workers for data loading') args = parser.parse_args() # Define device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Define data transformations transform = transforms.Compose([ transforms.Resize((args.image_size, args.image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ]) # Load LFW dataset lfw_dataset = ImageFolder(args.data_dir, transform=transform) lfw_loader = DataLoader(lfw_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) # Load model model = ArcFaceModel(num_classes=len(lfw_dataset.classes)).to(device) model.load_state_dict(torch.load(args.model_path)) model.eval() # Define loss function criterion = ArcFaceLoss() # Define testing function def test(): correct = 0 total = 0 with torch.no_grad(): for images, labels in lfw_loader: images = images.to(device) labels = labels.to(device) embeddings = model(images) predictions = torch.argmax(embeddings, dim=1) correct += (predictions == labels).sum().item() total += len(labels) accuracy = correct / total return accuracy # Test model on LFW dataset accuracy = test() print('Accuracy on LFW dataset: {:.2%}'.format(accuracy)) ``` 需要注意的是,上述代码中用到了 `models` 和 `losses` 模块中的内容,因此需要提前创建这两个模块。`models` 模块是用来定义 ArcFace 模型的,这里可以使用开源的实现,如 `https://github.com/ronghuaiyang/arcface-pytorch`。`losses` 模块是用来定义 ArcFace 损失函数的,这里需要根据实际需求进行编写。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CVchina_BUAA

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值