流程
- 提取 VGG16 模型 fc2 层的特征向量，保存到 image.db 文件中
- 使用 Canopy 算法（基于欧氏距离）粗略估计聚类数 k
- 使用 k-means 算法对图片进行聚类分类
获取图片特征向量（代码摘自《自制AI图像搜索引擎》）
/**
 * Runs one image through the VGG16 model and returns its "fc2" layer activations.
 *
 * @param imgFile image file to load; it is resized to 224x224 with 3 channels by the loader
 * @return the activations of the layer named "fc2" (null if the model output has no such key)
 * @throws IOException if the image cannot be read
 */
private INDArray getImgFeature(File imgFile) throws IOException {
    // Decode and resize the image into the input geometry the loader is configured with.
    INDArray pixels = new NativeImageLoader(224, 224, 3).asMatrix(imgFile);

    // VGG16-specific preprocessing is applied in place on the pixel array.
    DataNormalization preProcessor = new VGG16ImagePreProcessor();
    preProcessor.transform(pixels);

    // Forward pass with train=false; keep only the fc2 activations from the per-layer map.
    return vgg16Model.feedForward(pixels, false).get("fc2");
}
/**
 * Copies all values of an ND4J array into a new plain {@code double[]}.
 *
 * <p>Elements are returned in row-major (C) order. For example, a {@code [1, 4096]} fc2
 * feature row becomes a 4096-element array.
 *
 * <p>The values are read from the array itself, not parsed from {@link INDArray#toString()}.
 * That string is a display format: depending on the ND4J version it rounds values to a few
 * decimal places and may elide elements of large arrays with "...". Parsing it would either
 * silently lose precision or make {@code Double.parseDouble} fail.
 *
 * @param indArr the array to convert; must not be null
 * @return a new array containing every element of {@code indArr}
 * @throws NullPointerException if {@code indArr} is null
 */
private double[] INDArray2DoubleArray(INDArray indArr) {
    if (indArr == null) {
        throw new NullPointerException("indArr");
    }
    // ravel() flattens to a vector in C order (same order the old string parsing produced);
    // toDoubleVector() then copies that vector out at full precision.
    return indArr.ravel().toDoubleVector();
}
/**
 * Groups images by their stored feature vectors and copies each image into a per-cluster folder.
 *
 * <p>Pipeline used by {@link #kmeans}:
 * <ol>
 *   <li>Load the feature vectors from the "feat_map" map of a MapDB file.</li>
 *   <li>Estimate the number of clusters with {@code Canopy}.</li>
 *   <li>Run {@code Kmeans} with that count.</li>
 *   <li>Copy every image into {@code distPath/<clusterIndex>/}.</li>
 * </ol>
 */
public class Classify {

    /**
     * Loads every feature vector stored in the "feat_map" hash map of a MapDB file.
     *
     * <p>The database is always closed before returning, including when opening or reading
     * the map throws.
     *
     * @param dbPath path of the MapDB file
     * @return one {@code Vector} per stored entry. The vector key is the map key, which
     *     {@link #kmeans} later uses as the image file name. The values are the stored
     *     array passed through {@code Utils.normalizeL2}.
     */
    public static List<Vector> getVectorListFromDB(String dbPath) {
        DB db = DBMaker.fileDB(dbPath).make();
        try {
            ConcurrentMap<String, double[]> map =
                    db.hashMap("feat_map", Serializer.STRING, Serializer.DOUBLE_ARRAY).open();
            List<Vector> vecs = new ArrayList<Vector>();
            // Single pass over the entries: each record is read once instead of
            // iterating the keys and then calling map.get(key) for every one.
            map.forEach((key, val) -> vecs.add(new Vector(key, Utils.normalizeL2(val))));
            return vecs;
        } finally {
            // Release the database file even if opening/reading the map failed.
            db.close();
        }
    }

    /**
     * Clusters the images described by a feature DB and copies them into numbered folders.
     *
     * <p>Image {@code fromPath/<key>} ends up in {@code distPath/<clusterIndex>/<key>}.
     * Progress and the elapsed time (in whole minutes) are printed to stdout.
     *
     * @param dbPath   MapDB file holding the feature vectors
     * @param fromPath directory containing the original images; DB keys are file names in it
     * @param distPath directory that receives one numbered sub-directory per cluster
     * @throws IOException if copying an image fails
     */
    public static void kmeans(String dbPath, String fromPath, String distPath) throws IOException {
        long time = System.currentTimeMillis();
        List<Vector> dataset = Classify.getVectorListFromDB(dbPath);
        if (dataset.isEmpty()) {
            // Nothing to cluster; don't hand an empty data set to Canopy/Kmeans.
            System.out.println("特征库中没有图片特征，跳过分类: " + dbPath);
            return;
        }

        // Canopy gives a rough estimate of k. Each algorithm gets its own copy of the list,
        // presumably because it may reorder or modify it (TODO confirm in Canopy/Kmeans).
        Canopy canopy = new Canopy(new ArrayList<>(dataset));
        int m = canopy.cluster();
        System.out.println("预计 " + m + " 个分类");

        Kmeans cu = new Kmeans(new ArrayList<>(dataset));
        cu.execute(m);
        ArrayList<ArrayList<Vector>> clusters = cu.getCluster();

        // The cluster's position in the result list becomes its folder name.
        int i = 0;
        for (ArrayList<Vector> cluster : clusters) {
            for (Vector vector : cluster) {
                System.out.println("move " + vector.getKey() + " " + i);
                moveFile(vector.getKey(), i, fromPath, distPath);
            }
            i++;
        }
        System.out.println("耗时 " + ((System.currentTimeMillis() - time) / 1000 / 60) + " min");
    }

    /**
     * Copies (despite the name, does NOT move) {@code fromPath/from} to {@code distPath/to/from}.
     *
     * <p>The target folder {@code distPath/to} is created if needed, including any missing
     * parent directories. The source file is left in place.
     *
     * @param from     file name of the image inside {@code fromPath}
     * @param to       cluster index, used as the name of the target sub-directory
     * @param fromPath directory containing the source image
     * @param distPath root directory for the cluster sub-directories
     * @throws IOException if the copy fails
     */
    public static void moveFile(String from, int to, String fromPath, String distPath) throws IOException {
        File toF = new File(distPath, String.valueOf(to));
        if (!toF.isDirectory()) {
            // mkdirs (not mkdir) so a missing distPath is created as well.
            toF.mkdirs();
        }
        FileUtils.copyFile(new File(fromPath, from), new File(toF, from));
    }
}
现成的打包工具
使用说明
// Command-line options of the packaged tool. Only -m, -d, -i and -t take a value;
// -h and -b are plain flags.
Option help = Option.builder("h")
        .desc("显示帮助信息")
        .build();
Option model = Option.builder("m")
        .hasArg()
        .argName("model")
        .desc("模型路径名")
        .build();
Option database = Option.builder("d")
        .hasArg()
        .argName("database")
        .desc("图像特征库路径名")
        .build();
Option img = Option.builder("i")
        .hasArg()
        .argName("imgdir")
        .desc("用于构建特征库的图像文件夹路径全名")
        .build();
Option dist = Option.builder("t")
        .hasArg()
        .argName("dist")
        .desc("分类后的目标文件夹")
        .build();
Option useDb = Option.builder("b")
        .desc("使用现成的db")
        .build();
示例
将 mp 文件夹下的图片分类，并复制到 test 文件夹下按类别编号命名的子文件夹中（原图保留不动）
java -jar .\GenerateImgsFeatDBTool-1.0-SNAPSHOT.jar -m D:\work\vgg16.zip -d D:\pic\image.db -i D:\pic\mp -t D:\pic\test
k-means 算法每次生成的结果都不完全相同。如果对分类结果不满意，可以重新运行；加上 -b 参数会复用上次生成的 db 文件，从而加快分类速度：
java -jar .\GenerateImgsFeatDBTool-1.0-SNAPSHOT.jar -m D:\work\vgg16.zip -d D:\pic\image.db -i D:\pic\mp -t D:\pic\test -b
链接:https://pan.baidu.com/s/1hNf4YarZ5nnZkjmcMnPrng
提取码:yuc5