package com.techwolf.transformer;

import com.alibaba.fastjson.*;
import com.alibaba.fastjson.parser.Feature;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
//import com.alibaba.fastjson.JSONPObject;
//import org.json.JSONObject;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;
publicclassJobPredict {
private static String jsonPath= "src/main/resources/resource.json";
private static String modelPath= "src/main/resources/model.pb";
private static Map positionToFeature = new HashMap();
private static Map jobMapping = new HashMap();
private static Map mergeMapping = new HashMap();
private static Map featureToId = new HashMap();
private static Map idToCode = new HashMap();
private static Map codeToLabel = new HashMap();
public static String readJsonFile(String fileName) throws FileNotFoundException {
String jsonStr= "";try{
File jsonFile=new File(fileName);
FileReader fileReader=new FileReader(jsonFile);
Reader reader= new InputStreamReader(new FileInputStream(jsonFile), "utf-8");
int ch=0;
StringBuffer sb=new StringBuffer();while ((ch = reader.read()) != -1) {
sb.append((char) ch);
}
fileReader.close();
reader.close();
jsonStr=sb.toString();returnjsonStr;
} catch (IOException e) {
e.printStackTrace();returnnull;
}
}
private static MapjsonTOMap(JSONObject jsobj) {
Map data = new HashMap();
Iterator it=jsobj.entrySet().iterator();while(it.hasNext()) {
Map.Entry entry = (Map.Entry) it.next();
data.put(entry.getKey(), entry.getValue());
}returndata;
}
private static void getConfig() throws FileNotFoundException {
String jsonStr=readJsonFile(jsonPath);
JSONObject obj=JSON.parseObject(jsonStr);
positionToFeature= jsonTOMap(obj.getJSONObject("position2feature"));
featureToId= jsonTOMap(obj.getJSONObject("feature2id"));
jobMapping= jsonTOMap(obj.getJSONObject("job_mapping"));
mergeMapping= jsonTOMap(obj.getJSONObject("merge_mapping"));
idToCode= jsonTOMap(obj.getJSONObject("id2position"));
codeToLabel= jsonTOMap(obj.getJSONObject("position_mapping"));
System.out.println("config data loaded!");
}
public static String convert(String utfString) {
StringBuilder sb=new StringBuilder();
int i= -1;
int pos=0;
int iint=0;while ((i = utfString.indexOf("\\u", pos)) != -1) {
String sd=utfString.substring(pos, i);
sb.append(sd);
iint= i + 5;if (iint
pos= i + 6;
sb.append((char) Integer.parseInt(utfString.substring(i+ 2, i + 6), 16));
}
}
String endStr= utfString.substring(iint + 1, utfString.length());return sb + "" +endStr;
}
private static MapgetCodeAndScore(JSONArray jsonArray) throws FileNotFoundException {
List codes = new ArrayList();
List scores = new ArrayList();
Integer codeFlag= -1;
float scoreFlag=(float) .0;for (int i = 0; i < jsonArray.size(); i++) {
JSONObject skillsItem=(JSONObject) jsonArray.get(i);
String code= (skillsItem.get("code")).toString();
Float score= Float.parseFloat((String) skillsItem.get("score"));
boolean isReplace=mergeMapping.containsKey(code);if(isReplace) {
code=(mergeMapping.get(code)).toString();
System.out.println("replace id:" +code);
}
String position=(jobMapping.get(code)).toString();
Integer featSeq=(Integer) positionToFeature.get(position);if (featSeq ==null) {
codes.add((Integer) featureToId.get(codeFlag.toString()));
scores.add(scoreFlag);
}else{
Integer x=(Integer) featureToId.get(featSeq.toString());
codes.add((Integer) featureToId.get(featSeq.toString()));
scores.add(score);
}
}if (jsonArray.size() < 3) {for(int i=0; i< (3-jsonArray.size()); i++) {
codes.add((Integer) featureToId.get(codeFlag.toString()));
scores.add(scoreFlag);
}
}
Map result = new HashMap();
result.put("codes", codes);
result.put("scores", scores);returnresult;
}
private static byte[] readAllByteOrExit(Path path){try{returnFiles.readAllBytes(path);
}catch (IOException e){
System.out.println("Failed to read[" + path + "]:" +e.getMessage());
System.exit(1);
}returnnull;
}
private static MapgetDataContent(String testFile) throws FileNotFoundException {
String jsonStr=readJsonFile(testFile);
JSONObject obj=JSON.parseObject(jsonStr, Feature.OrderedField);
JSONObject objNew=JSON.parseObject(obj.toJSONString(), Feature.OrderedField);
ArrayList sampleCode = new ArrayList();
ArrayList sampleScore = new ArrayList();
Map samples = new HashMap();for(String userId: objNew.keySet()) {
ArrayList codeList = new ArrayList();
ArrayList scoresList = new ArrayList();
JSONObject itemTags= (JSONObject) ((JSONObject)((JSONObject)objNew.get(userId)).get("_source")).get("tags");
JSONArray skills= (JSONArray) itemTags.get("skills");
JSONArray title= (JSONArray) itemTags.get("title");
JSONArray desc= (JSONArray) itemTags.get("desc");
Map skillsResult =getCodeAndScore(skills);
Map titleResult =getCodeAndScore(title);
Map descResult =getCodeAndScore(desc);
codeList.addAll(skillsResult.get("codes"));
codeList.addAll(titleResult.get("codes"));
codeList.addAll(descResult.get("codes"));
scoresList.addAll(skillsResult.get("scores"));
scoresList.addAll(titleResult.get("scores"));
scoresList.addAll(descResult.get("scores"));
sampleCode.add(codeList);
sampleScore.add(scoresList);
}
samples.put("sampleCode", sampleCode);
samples.put("sampleScore", sampleScore);
System.out.println("ok! sample feature created.");returnsamples;
}
public static int[] arraySort(float[] arr, boolean desc) {
float temp;
int index;
int k=arr.length;
int[] Index=new int[k];for (int i = 0; i < k; i++) {
Index[i]=i;
}for (int i = 0; i < arr.length; i++) {for (int j = 0; j < arr.length - i - 1; j++) {if(desc) {if (arr[j] < arr[j + 1]) {
temp=arr[j];
arr[j]= arr[j + 1];
arr[j+ 1] =temp;
index=Index[j];
Index[j]= Index[j + 1];
Index[j+ 1] =index;
}
}else{if (arr[j] > arr[j + 1]) {
temp=arr[j];
arr[j]= arr[j + 1];
arr[j+ 1] =temp;
index=Index[j];
Index[j]= Index[j + 1];
Index[j+ 1] =index;
}
}
}
}returnIndex;
}
private static void featToTensor(float[][][] indexes, int[][] codes, float[][] scores, Mapdata) {
List featCode = data.get("sampleCode");
List featScore = data.get("sampleScore");
int size= 9;for(int i=0; i < featCode.size(); i++) {
Object eachCode=featCode.get(i);
Object eachScore=featScore.get(i);
float [][] positionResult=new float[size][];for(int step=0; step < size; step++) {
float[] positionVector=new float[size];
positionVector[step]= 1;
positionResult[step]=positionVector;
}
indexes[i]=positionResult;
Integer[] targetInter= ((List)eachCode).toArray(new Integer[size]);
int[] codeResult=Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();
Float[] targetFloat= ((List)eachScore).toArray(new Float[size]);
double[] scoreResult=Arrays.stream(targetFloat).mapToDouble(Double::valueOf).toArray();
float[] scoreFloat=new float[size];for(int j=0; j < scoreResult.length; j++) {
scoreFloat[j]=(float) scoreResult[j];
}
System.arraycopy(codeResult,0,codes[i], 0, codeResult.length);
System.arraycopy(scoreFloat,0,scores[i], 0, scoreResult.length);
}
}
private static List> modelInfer(Mapdata) {
int batchSize= data.get("sampleCode").size();
int padLength= 9;
int returnNum= 5;
int classNum= 868;
float[][][] indexes=new float[batchSize][padLength][padLength];
int[][] codes=new int[batchSize][padLength];
float[][] scores=new float[batchSize][padLength];
float transKeepProb= (float) 1.0;
float multiKeepProb= (float) 1.0;
byte[] graphDef=readAllByteOrExit(Paths.get(modelPath));
Graph g=new Graph();
g.importGraphDef(graphDef);
Session sess=new Session(g);
featToTensor(indexes, codes, scores, data);
Tensor tensorIndex=Tensor.create(indexes);
Tensor tensorCode=Tensor.create(codes);
Tensor tensorScore=Tensor.create(scores);
Tensor tensorTransProb=Tensor.create(transKeepProb);
Tensor tensorMultiProb=Tensor.create(multiKeepProb);
Tensor tensorClassResult=sess.runner().
feed("input_x:0", tensorCode).
feed("input_x_score:0", tensorScore).
feed("embed_position:0", tensorIndex).
feed("trans_keep_prob:0", tensorTransProb).
feed("multi_keep_prob:0", tensorMultiProb).
fetch("discriminator/softmax_score:0").run().get(0);
float[][] result=(float[][]) tensorClassResult.copyTo(new float[batchSize][classNum]);
List> predictResult =new ArrayList();for(int i=0; i
float[] resultVec=result[i];
int[] resultIndex=new int[classNum];
HashMap predictSample = new HashMap();
resultIndex=arraySort(resultVec, true);for(int s=0; s < returnNum; s++) {
String sampleCode=Integer.toString(resultIndex[s]);
String label=(String) codeToLabel.get(Integer.toString((Integer) idToCode.get(sampleCode)));
predictSample.put(label, resultVec[s]);
}
predictResult.add(predictSample);
}
tensorClassResult.close();
tensorMultiProb.close();
tensorTransProb.close();
tensorScore.close();
tensorCode.close();
tensorIndex.close();returnpredictResult;
}
public static void main (String[]args) throws IOException {
String testFile= "src/main/data/predict_data.json";
getConfig();
Map samples =getDataContent(testFile);
List> result =modelInfer(samples);
System.out.println(result);
}
}