package com.hito.indai.standalone.service.algorithm;
import cn.hutool.core.util.RuntimeUtil;
import com.hito.indai.standalone.dto.ProcessDTO;
import com.hito.indai.standalone.entity.train.TrainModel;
import com.hito.indai.standalone.enums.model.TrainStatusEnum;
import com.hito.indai.standalone.exception.IndaiException;
import com.hito.indai.standalone.repository.train.TrainModelRepository;
import com.hito.indai.standalone.service.trian.ModelCacheSevice;
import com.hito.indai.standalone.util.ProccessUtil;
import com.sun.jna.Platform;
import io.micronaut.context.annotation.Value;
import lombok.extern.slf4j.Slf4j;
import javax.inject.Inject;
import javax.inject.Singleton;
import javax.transaction.Transactional;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Date;
import java.util.HashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static java.util.regex.Pattern.*;
/**
* @author fuchanghai
*/
@Slf4j
@Singleton
public class ShellService {
@Value("${micronaut.server.port}")
private Integer myServerPort;
@Inject
private TrainModelRepository trainModelRepository;
private static HashMap<String, Process> processHashMap = new HashMap<>();
public String getGPU() throws IOException {
Process process = null;
try {
if (Platform.isWindows()) {
process = Runtime.getRuntime().exec("nvidia-smi.exe");
} else if (Platform.isLinux()) {
String[] shell = {"/bin/bash", "-c", "nvidia-smi"};
process = Runtime.getRuntime().exec(shell);
}
process.getOutputStream().close();
} catch (IOException e) {
e.printStackTrace();
throw new IndaiException("显卡不存在或获取显卡信息失败");
}
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
StringBuffer stringBuffer = new StringBuffer();
String line = "";
while (null != (line = reader.readLine())) {
stringBuffer.append(line + "\n");
}
return stringBuffer.toString();
}
public int startTrain(TrainModel trainModel, String configFilePath, String pictureJsonFilePath, int time) {
String doc = "\"" + getAlgorithmAbsolutePath() + File.separator + "classify_train.exe\" " + " --config_file=\"" + configFilePath + "\" --callback_ip=localhost " + " --json_file=\"" + pictureJsonFilePath + "\" --callback_port=" + myServerPort;
log.info(doc);
Process process = null;
int pid = 0;
try {
process = Runtime.getRuntime().exec(new String[]{"cmd", "/c", doc});
pid = ProccessUtil.getPid(process);
//此处先执行业务 再打印流,因为流是阻塞的,否则只有流打印完才能执行业务
beforeCallAlg(trainModel,pid,time,process);
String str = null;
BufferedReader buffer = new BufferedReader(new InputStreamReader(process.getErrorStream(), "gbk"));
while ((str = (buffer.readLine())) != null) {
log.info("输出流:" + unicodeToString(str));
}
} catch (Exception e) {
e.printStackTrace();
}
return pid;
}
public int startDetectionOrSegmentTrain(TrainModel trainModel, String configFilePath, String trainJsonFile,String validJsonFile, int time,String taskType) {
String doc = "\"" + getAlgorithmAbsolutePath() + File.separator + "instance_train.exe\" " +
" --config_file=\"" + configFilePath +
"\" --callback_ip=localhost " +
" --train_json_file=\"" + trainJsonFile +
"\" --valid_json_file=\"" + validJsonFile +
"\" --task_type=\"" + taskType +
"\" --callback_port=" + myServerPort;
log.info(doc);
Process process = null;
int pid = 0;
try {
process = Runtime.getRuntime().exec(new String[]{"cmd", "/c", doc});
pid = ProccessUtil.getPid(process);
beforeCallAlg(trainModel,pid,time,process);
String str = null;
BufferedReader buffer = new BufferedReader(new InputStreamReader(process.getErrorStream(), "gbk"));
while ((str = (buffer.readLine())) != null) {
log.info("输出流:" + unicodeToString(str));
}
} catch (Exception e) {
e.printStackTrace();
}
return pid;
}
//事务要细粒到这不然会被流阻塞
@Transactional(rollbackOn = Exception.class)
public void beforeCallAlg(TrainModel trainModel,Integer pid,Integer time ,Process process){
trainModel.setProcessId(pid);
trainModel.setTrainStatus(TrainStatusEnum.RUNNING.code());
trainModel.setBeginTrainTime(new Date());
trainModelRepository.update(trainModel);
ProcessDTO processDTO = new ProcessDTO();
processDTO.setBeginTime(trainModel.getBeginTrainTime());
processDTO.setTime(time + 1);
processDTO.setProcess(process);
processDTO.setStart(false);
ModelCacheSevice.processMap.put(trainModel.getId(), processDTO);
}
// 根据pid 杀死算法进程
public void stopTrain(int pid) {
//使用WMIC获取CPU序列号
Process process = null;
try {
process = Runtime.getRuntime().exec("taskkill /pid " + pid + " -t -f");
process.getOutputStream().close();
} catch (IOException e) {
e.printStackTrace();
}
}
//根据pid 查看算法进程是否存活
public boolean processIsAlive(Integer processId) {
String docCommand = "tasklist | findstr " + processId;
String s = RuntimeUtil.execForStr(new String[]{"cmd", "/c", docCommand});
if (s.length() == 0) {
return false;
}
return true;
}
public String getAlgorithmAbsolutePath() {
// jar包运行
File jarFile = new File(System.getProperty("java.class.path"));
String algorithmPath = jarFile.getParent() + File.separator + "algorithm";
// 本地调试
/*File file = new File("src");
String algorithmPath = new File(file.getAbsolutePath()).getParentFile().getAbsolutePath() + File.separator + "algorithm";*/
/* String algorithmPath = "C:\\Program Files\\ZhiTu\\resources\\libs\\algorithm";*/
log.info("算法代码绝对路径:" + algorithmPath);
return algorithmPath;
}
public static String unicodeToString(String str) {
Pattern pattern = compile("(\\\\u(\\p{XDigit}{4}))");
Matcher matcher = pattern.matcher(str);
char ch;
while (matcher.find()) {
ch = (char) Integer.parseInt(matcher.group(2), 16);
str = str.replace(matcher.group(1), ch+"" );
}
return str;
}
}
java 根据pid 查看进程是否存活,杀掉进程,调用python
于 2020-08-19 17:21:11 首次发布