ND4j大矩阵做相识度计算实现1:N计算

现在很多做人脸识别的技术都是基于C++,python;但是java也是有自己的深度学习的库,Deeplearning4j就是java自己的深度学习库,ND4j就是dp4j里面专门做向量计算的库,现在就用ND4j实现1:N矩阵计算功能:

pom:

       <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-native</artifactId>
            <version>${dp4j-version}</version>
        </dependency>

1:N计算的核心就是,用矩阵和一个大矩阵做内积乘法

业务功能:就是用一张人脸的特征和数据库里面有的特征做内积乘法,快速在数据库里面找到这张人脸对应的人;

现实方法:

由于每次都在数据库里面去查询特征数据会消耗时间,所以我把特征数据初始化加载到内存里面来维护;做完内积计算以后需要找到相识度最大的对应的那条数据,opencv里面有一个sortIndex函数可以将数组的值排序以后返回对应的索引地址,java可以自己实现一个排序;

将查出来的特征值进行缓存:

@Component
public class NDCache {


    public static List<BasisFeature> basisFeatures = Lists.newArrayList();

    public static INDArray zeros = null;

    public static void load(List<BasisFeature> features){
        basisFeatures = features;
        List<INDArray> INDArrays = features.stream().map(basisFeature -> Nd4j.create(basisFeature.getBasis())).collect(Collectors.toList());
         zeros = Nd4j.vstack(INDArrays);
    }

}

实现矩阵计算,更新矩阵:

@Slf4j
@Service
public class NDService {

    @Autowired
    private MongoTemplate mongoTemplate;

    private static double forecastValues = 0.65D;

    private INDArray creArry(List<Double> features) {
        INDArray array = Nd4j.create(features.size(), 1);
        features.forEach(item -> array.putRow(features.indexOf(item), Nd4j.create(new double[]{item})));
        return array;
    }

    public BasisFeature comparRes(List<Double> features){
        INDArray resArr =  NDCache.zeros.mmul(creArry(features));
        Point[] points = Point.sortPoint(resArr.toDoubleVector());
        Point point = points[0];
        if (point.getValue() < forecastValues){
            return  null;
        }
        log.info("1比N第一位结果:"+point);
        return NDCache.basisFeatures.get(point.getIndex());
    }

    public void updateArr(BasisFeature basisFeature){
        for (int i=0;i<NDCache.basisFeatures.size();i++){
            BasisFeature feature = NDCache.basisFeatures.get(i);
            if (feature.getObjectId().equals(basisFeature.getObjectId())){
                NDCache.basisFeatures.set(i,basisFeature);
                NDCache.zeros.put(i,Nd4j.create(basisFeature.getBasis()));
                log.info("矩阵更新:"+NDCache.zeros.length());
            }
        }
    }

    public void loadNDarr(){
        Query query = new Query();
    
        List<BasisFeature> features = mongoTemplate.find(query,BasisFeature.class);
        NDCache.basisFeatures = features;
        List<INDArray> INDArrays = features.stream().map(basisFeature -> Nd4j.create(basisFeature.getBasis())).collect(Collectors.toList());
        NDCache.zeros = Nd4j.vstack(INDArrays);
        log.info("初始化矩阵大小:"+NDCache.zeros.length()+":数据源大小:"+features.size());
    }

    public void add(BasisFeature ... basisFeatures){
        if (ObjectUtils.isEmpty(NDCache.zeros)){
            loadNDarr();
        }else {
            List<INDArray> indArrays = Lists.newArrayList(NDCache.zeros);
            Arrays.asList(basisFeatures).stream().forEach(basisFeature -> indArrays.add(Nd4j.create(basisFeature.getBasis())));
            NDCache.zeros = Nd4j.vstack(indArrays);
            log.info("矩阵扩张:"+NDCache.zeros.length());
        }
    }

}

自定义sortIndex:

public class SortComprator implements Comparator {

    @Override
    public int compare(Object arg0, Object arg1) {
        Point t1=(Point)arg0;
        Point t2=(Point)arg1;
        return t2.getValue().compareTo(t1.getValue());
    }

}

自定义数据结构:

@Data
public class Point{

    private Double value;

    private int index;

   public  static Point[] sortPoint(double[] doubleVector){
        Point[] points = new Point[doubleVector.length];
        for (int i=0;i<doubleVector.length;i++){
            Point point = new Point();
            point.setValue(doubleVector[i]);
            point.setIndex(i);
            points[i] = point;
        }
        Arrays.sort(points,new SortComprator());
        return points;
    }

}

整个计算流程只是把其他语言的实现方法实现了一遍,但是java在计算方便的速度也是很快的,一百万的矩阵差不多也就141ms,

这个只是基于cpu的计算,当然nd4j也是可以基于cuda的GPU计算的。

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值