Java 重写：TensorFlow Java 版模型推理实现（基于 BERT 的命名实体识别、基于 Transformer 的文本分类）

package com.techwolf.transformer;

import com.alibaba.fastjson.*;
import com.alibaba.fastjson.parser.Feature;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
//import com.alibaba.fastjson.JSONPObject;
//import org.json.JSONObject;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.*;

publicclassJobPredict {

private static String jsonPath= "src/main/resources/resource.json";

private static String modelPath= "src/main/resources/model.pb";

private static Map positionToFeature = new HashMap();

private static Map jobMapping = new HashMap();

private static Map mergeMapping = new HashMap();

private static Map featureToId = new HashMap();

private static Map idToCode = new HashMap();

private static Map codeToLabel = new HashMap();

public static String readJsonFile(String fileName) throws FileNotFoundException {

String jsonStr= "";try{

File jsonFile=new File(fileName);

FileReader fileReader=new FileReader(jsonFile);

Reader reader= new InputStreamReader(new FileInputStream(jsonFile), "utf-8");

int ch=0;

StringBuffer sb=new StringBuffer();while ((ch = reader.read()) != -1) {

sb.append((char) ch);

}

fileReader.close();

reader.close();

jsonStr=sb.toString();returnjsonStr;

} catch (IOException e) {

e.printStackTrace();returnnull;

}

}

private static MapjsonTOMap(JSONObject jsobj) {

Map data = new HashMap();

Iterator it=jsobj.entrySet().iterator();while(it.hasNext()) {

Map.Entry entry = (Map.Entry) it.next();

data.put(entry.getKey(), entry.getValue());

}returndata;

}

private static void getConfig() throws FileNotFoundException {

String jsonStr=readJsonFile(jsonPath);

JSONObject obj=JSON.parseObject(jsonStr);

positionToFeature= jsonTOMap(obj.getJSONObject("position2feature"));

featureToId= jsonTOMap(obj.getJSONObject("feature2id"));

jobMapping= jsonTOMap(obj.getJSONObject("job_mapping"));

mergeMapping= jsonTOMap(obj.getJSONObject("merge_mapping"));

idToCode= jsonTOMap(obj.getJSONObject("id2position"));

codeToLabel= jsonTOMap(obj.getJSONObject("position_mapping"));

System.out.println("config data loaded!");

}

public static String convert(String utfString) {

StringBuilder sb=new StringBuilder();

int i= -1;

int pos=0;

int iint=0;while ((i = utfString.indexOf("\\u", pos)) != -1) {

String sd=utfString.substring(pos, i);

sb.append(sd);

iint= i + 5;if (iint

pos= i + 6;

sb.append((char) Integer.parseInt(utfString.substring(i+ 2, i + 6), 16));

}

}

String endStr= utfString.substring(iint + 1, utfString.length());return sb + "" +endStr;

}

private static MapgetCodeAndScore(JSONArray jsonArray) throws FileNotFoundException {

List codes = new ArrayList();

List scores = new ArrayList();

Integer codeFlag= -1;

float scoreFlag=(float) .0;for (int i = 0; i < jsonArray.size(); i++) {

JSONObject skillsItem=(JSONObject) jsonArray.get(i);

String code= (skillsItem.get("code")).toString();

Float score= Float.parseFloat((String) skillsItem.get("score"));

boolean isReplace=mergeMapping.containsKey(code);if(isReplace) {

code=(mergeMapping.get(code)).toString();

System.out.println("replace id:" +code);

}

String position=(jobMapping.get(code)).toString();

Integer featSeq=(Integer) positionToFeature.get(position);if (featSeq ==null) {

codes.add((Integer) featureToId.get(codeFlag.toString()));

scores.add(scoreFlag);

}else{

Integer x=(Integer) featureToId.get(featSeq.toString());

codes.add((Integer) featureToId.get(featSeq.toString()));

scores.add(score);

}

}if (jsonArray.size() < 3) {for(int i=0; i< (3-jsonArray.size()); i++) {

codes.add((Integer) featureToId.get(codeFlag.toString()));

scores.add(scoreFlag);

}

}

Map result = new HashMap();

result.put("codes", codes);

result.put("scores", scores);returnresult;

}

private static byte[] readAllByteOrExit(Path path){try{returnFiles.readAllBytes(path);

}catch (IOException e){

System.out.println("Failed to read[" + path + "]:" +e.getMessage());

System.exit(1);

}returnnull;

}

private static MapgetDataContent(String testFile) throws FileNotFoundException {

String jsonStr=readJsonFile(testFile);

JSONObject obj=JSON.parseObject(jsonStr, Feature.OrderedField);

JSONObject objNew=JSON.parseObject(obj.toJSONString(), Feature.OrderedField);

ArrayList sampleCode = new ArrayList();

ArrayList sampleScore = new ArrayList();

Map samples = new HashMap();for(String userId: objNew.keySet()) {

ArrayList codeList = new ArrayList();

ArrayList scoresList = new ArrayList();

JSONObject itemTags= (JSONObject) ((JSONObject)((JSONObject)objNew.get(userId)).get("_source")).get("tags");

JSONArray skills= (JSONArray) itemTags.get("skills");

JSONArray title= (JSONArray) itemTags.get("title");

JSONArray desc= (JSONArray) itemTags.get("desc");

Map skillsResult =getCodeAndScore(skills);

Map titleResult =getCodeAndScore(title);

Map descResult =getCodeAndScore(desc);

codeList.addAll(skillsResult.get("codes"));

codeList.addAll(titleResult.get("codes"));

codeList.addAll(descResult.get("codes"));

scoresList.addAll(skillsResult.get("scores"));

scoresList.addAll(titleResult.get("scores"));

scoresList.addAll(descResult.get("scores"));

sampleCode.add(codeList);

sampleScore.add(scoresList);

}

samples.put("sampleCode", sampleCode);

samples.put("sampleScore", sampleScore);

System.out.println("ok! sample feature created.");returnsamples;

}

public static int[] arraySort(float[] arr, boolean desc) {

float temp;

int index;

int k=arr.length;

int[] Index=new int[k];for (int i = 0; i < k; i++) {

Index[i]=i;

}for (int i = 0; i < arr.length; i++) {for (int j = 0; j < arr.length - i - 1; j++) {if(desc) {if (arr[j] < arr[j + 1]) {

temp=arr[j];

arr[j]= arr[j + 1];

arr[j+ 1] =temp;

index=Index[j];

Index[j]= Index[j + 1];

Index[j+ 1] =index;

}

}else{if (arr[j] > arr[j + 1]) {

temp=arr[j];

arr[j]= arr[j + 1];

arr[j+ 1] =temp;

index=Index[j];

Index[j]= Index[j + 1];

Index[j+ 1] =index;

}

}

}

}returnIndex;

}

private static void featToTensor(float[][][] indexes, int[][] codes, float[][] scores, Mapdata) {

List featCode = data.get("sampleCode");

List featScore = data.get("sampleScore");

int size= 9;for(int i=0; i < featCode.size(); i++) {

Object eachCode=featCode.get(i);

Object eachScore=featScore.get(i);

float [][] positionResult=new float[size][];for(int step=0; step < size; step++) {

float[] positionVector=new float[size];

positionVector[step]= 1;

positionResult[step]=positionVector;

}

indexes[i]=positionResult;

Integer[] targetInter= ((List)eachCode).toArray(new Integer[size]);

int[] codeResult=Arrays.stream(targetInter).mapToInt(Integer::valueOf).toArray();

Float[] targetFloat= ((List)eachScore).toArray(new Float[size]);

double[] scoreResult=Arrays.stream(targetFloat).mapToDouble(Double::valueOf).toArray();

float[] scoreFloat=new float[size];for(int j=0; j < scoreResult.length; j++) {

scoreFloat[j]=(float) scoreResult[j];

}

System.arraycopy(codeResult,0,codes[i], 0, codeResult.length);

System.arraycopy(scoreFloat,0,scores[i], 0, scoreResult.length);

}

}

private static List> modelInfer(Mapdata) {

int batchSize= data.get("sampleCode").size();

int padLength= 9;

int returnNum= 5;

int classNum= 868;

float[][][] indexes=new float[batchSize][padLength][padLength];

int[][] codes=new int[batchSize][padLength];

float[][] scores=new float[batchSize][padLength];

float transKeepProb= (float) 1.0;

float multiKeepProb= (float) 1.0;

byte[] graphDef=readAllByteOrExit(Paths.get(modelPath));

Graph g=new Graph();

g.importGraphDef(graphDef);

Session sess=new Session(g);

featToTensor(indexes, codes, scores, data);

Tensor tensorIndex=Tensor.create(indexes);

Tensor tensorCode=Tensor.create(codes);

Tensor tensorScore=Tensor.create(scores);

Tensor tensorTransProb=Tensor.create(transKeepProb);

Tensor tensorMultiProb=Tensor.create(multiKeepProb);

Tensor tensorClassResult=sess.runner().

feed("input_x:0", tensorCode).

feed("input_x_score:0", tensorScore).

feed("embed_position:0", tensorIndex).

feed("trans_keep_prob:0", tensorTransProb).

feed("multi_keep_prob:0", tensorMultiProb).

fetch("discriminator/softmax_score:0").run().get(0);

float[][] result=(float[][]) tensorClassResult.copyTo(new float[batchSize][classNum]);

List> predictResult =new ArrayList();for(int i=0; i

float[] resultVec=result[i];

int[] resultIndex=new int[classNum];

HashMap predictSample = new HashMap();

resultIndex=arraySort(resultVec, true);for(int s=0; s < returnNum; s++) {

String sampleCode=Integer.toString(resultIndex[s]);

String label=(String) codeToLabel.get(Integer.toString((Integer) idToCode.get(sampleCode)));

predictSample.put(label, resultVec[s]);

}

predictResult.add(predictSample);

}

tensorClassResult.close();

tensorMultiProb.close();

tensorTransProb.close();

tensorScore.close();

tensorCode.close();

tensorIndex.close();returnpredictResult;

}

public static void main (String[]args) throws IOException {

String testFile= "src/main/data/predict_data.json";

getConfig();

Map samples =getDataContent(testFile);

List> result =modelInfer(samples);

System.out.println(result);

}

}

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值