// 定义该类所在的包路径,遵循 Java 的命名规范(通常是倒置的域名)
package com.shop.jieyou.service;
// 导入 Jackson 库的核心类,用于将 JSON 字符串解析为 Java 对象
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
// Spring 框架注解:标识此类为一个服务组件,由 Spring 容器管理生命周期
import com.shop.jieyou.entity.UserItem;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
// Java 标准库导入:用于处理输入输出流、读取外部进程输出
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
* Python服务类,用于执行Python爬虫脚本并获取花卉数据
* 每次调用都会直接执行Python脚本,不进行任何结果缓存
*
* 功能说明:
* - 调用位于 resources/scripts/crawler.py 的 Python 爬虫脚本
* - 获取其标准输出(应为 JSON 格式)
* - 将 JSON 解析为 List<Map<String, Object>> 结构返回给控制器
* - 不使用缓存机制,每次请求均重新运行脚本
*/
@Service // 表示这是一个 Spring Service Bean,可被自动扫描并注入到其他组件中
public class PythonService {
// 使用 Jackson 提供的 ObjectMapper 实例来序列化/反序列化 JSON 数据
private final ObjectMapper objectMapper = new ObjectMapper();
// 定义 Python 脚本在项目中的相对路径
// 注意:此路径是相对于项目根目录的,适用于开发环境
// 生产环境中可能需要改为绝对路径或通过配置文件指定
private static final String SCRIPT_PATH = "src/main/resources/scripts/crawler.py";
/**
* 执行 Python 爬虫脚本以获取花卉信息列表
*
* 此方法会:
* 1. 启动一个新的操作系统进程来运行 Python 脚本
* 2. 捕获脚本的标准输出
* 3. 验证执行状态(退出码)
* 4. 解析输出为 Java 对象
* 5. 返回结构化数据
*
* @return 包含花卉信息的 Map 列表,每个 Map 表示一种花卉的字段(如 name, family 等)
* @throws IOException 当发生 I/O 错误(如无法启动进程、读取输出失败)时抛出
* @throws InterruptedException 当当前线程在等待进程结束时被中断,通常发生在 JVM 关闭期间
*/
public synchronized List<Map<String, Object>> getFlowers() throws IOException, InterruptedException {
// 创建一个 ProcessBuilder 实例,用于构建和启动外部进程
// 参数:"python" 是命令,SCRIPT_PATH 是要执行的脚本路径
ProcessBuilder pb = new ProcessBuilder("python", SCRIPT_PATH);
// 设置合并错误流到标准输出流
// 这样可以通过同一个 BufferedReader 同时读取正常输出和错误信息
// 便于调试问题(例如 Python 报错 ImportError)
pb.redirectErrorStream(true);
// 启动外部进程(即运行 python crawler.py)
// process 对象可用于控制该进程(等待、杀死等)
Process process = pb.start();
// 使用 try-with-resources 确保 BufferedReader 在使用后自动关闭
// InputStreamReader 将字节流转换为字符流,并指定 UTF-8 编码以正确处理中文
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) {
// 用于拼接从 Python 脚本输出的所有文本行
StringBuilder output = new StringBuilder();
String line;
// 循环读取每行输出,直到流结束(EOF)
while ((line = reader.readLine()) != null) {
// 去除每行首尾空白(包括换行符、空格),然后追加到总输出中
// 注意:这里没有添加换行符,意味着所有内容会被压缩成一行
output.append(line.trim());
}
// 等待 Python 进程执行完成,并获取其退出码
// 正常情况下应返回 0;非零值表示异常退出(如语法错误、模块未安装)
int exitCode = process.waitFor();
// 如果退出码不是 0,说明脚本执行失败
if (exitCode != 0) {
throw new RuntimeException("Python script exited with code: " + exitCode);
}
// 检查输出是否为空
// 即使脚本成功退出,也可能未打印任何有效数据
if (output.length() == 0) {
throw new RuntimeException("Python script returned empty output");
}
// 使用 Jackson 反序列化 JSON 字符串为 Java 对象
// TypeReference 是泛型辅助类,告诉 ObjectMapper 我们想要的是 List<Map<String, Object>>
// 每个 Map 对应一条花卉记录,key 是字段名(如 "name"),value 是对应值
List<Map<String, Object>> result = objectMapper.readValue(output.toString(),
new TypeReference<List<Map<String, Object>>>() {});
// 检查返回的数据中是否包含错误信息(假设 Python 脚本约定第一个元素带 error 字段表示失败)
if (!result.isEmpty() && result.get(0).containsKey("error")) {
throw new RuntimeException("爬虫错误: " + result.get(0).get("error"));
}
// 成功解析并验证后,返回花卉数据列表
return result;
} catch (Exception e) {
// 异常处理:确保即使出错也能清理系统资源
// 如果进程仍在运行,则强制终止它,防止僵尸进程或资源泄漏
if (process.isAlive()) {
process.destroyForcibly();
}
// 继续向上抛出异常,让调用者知道发生了什么
throw e;
}
}
@Autowired
JdbcTemplate jdbcTemplate;
public List<UserItem> getUserItemMatrix() {
String sql = "SELECT user_id, product_id, COUNT(*) as count " +
"FROM tb_order WHERE state = 1 " +
"GROUP BY user_id, product_id";
return jdbcTemplate.query(sql, (rs, rowNum) ->
new UserItem(
rs.getLong("user_id"),
rs.getLong("product_id"),
rs.getInt("count")
)
);
}
private static final String PYTHON_SCRIPT_PATH = "src/main/resources/scripts/python-model/infer.py";
private static final String PYTHON_ENV_PATH = "python"; // 虚拟环境
public Map<String, Object> classify(String file)
throws IOException, InterruptedException, Exception {
// 1. 校验文件类型(安全)
// if (!"image/jpeg".equals(file) && !"image/png".equals(file)) {
// throw new IllegalArgumentException("仅支持 JPG/PNG 图片");
// }
// 2. 保存临时文件
File tempFile = File.createTempFile("upload_", ".jpg");
try (FileOutputStream fos = new FileOutputStream(tempFile)) {
fos.write(file.getBytes());
}
// 3. 调用 Python 脚本
ProcessBuilder pb = new ProcessBuilder(
PYTHON_ENV_PATH,
PYTHON_SCRIPT_PATH,
tempFile.getAbsolutePath()
);
pb.redirectErrorStream(true);
Process process = pb.start();
// 4. 获取输出结果
StringBuilder output = new StringBuilder();
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(process.getInputStream()))) {
String line;
while ((line = reader.readLine()) != null) {
output.append(line);
}
}
int exitCode = process.waitFor();
if (exitCode != 0) {
throw new RuntimeException("Python 脚本执行失败,退出码: " + exitCode);
}
// 5. 解析 JSON 结果(假设 Python 返回的是 JSON 字符串)
ObjectMapper mapper = new ObjectMapper();
JsonNode jsonNode = mapper.readTree(output.toString().trim());
Map<String, Object> result = new HashMap<>();
result.put("predicted_class", jsonNode.get("predicted_class").asText());
result.put("confidence", jsonNode.get("confidence").asDouble());
return result;
}
}
# python-model/infer.py
import numpy as np
import sys
import os
import json
import logging
from PIL import Image
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
logging.getLogger("tensorflow").setLevel(logging.FATAL)
import tensorflow as tf
MODEL_PATH = "animal_classifier.h5"
IMG_SIZE = 224
CLASS_NAMES = ['cat', 'dog', 'elephant', 'fish'] # 必须与训练时一致!
def predict(image_path):
if not os.path.exists(image_path):
print(json.dumps({"error": f"图片不存在: {image_path}"}))
return
try:
# 加载模型
if not os.path.exists(MODEL_PATH):
print(json.dumps({"error": f"模型未找到: {MODEL_PATH}"}))
return
model = tf.keras.models.load_model(MODEL_PATH)
# 加载并预处理图像
img = Image.open(image_path).convert("RGB")
img = img.resize((IMG_SIZE, IMG_SIZE))
img_array = np.array(img) / 255.0
img_array = np.expand_dims(img_array, axis=0) # 添加 batch 维度
# 预测
preds = model.predict(img_array, verbose=0)
confidence = float(np.max(preds))
label_idx = np.argmax(preds)
label = CLASS_NAMES[label_idx]
result = {
"success": True,
"predicted_class": label,
"confidence": round(confidence, 4),
"all_probabilities": {
CLASS_NAMES[i]: round(float(preds[0][i]), 4) for i in range(len(CLASS_NAMES))
}
}
print(json.dumps(result))
except Exception as e:
print(json.dumps({"error": f"预测失败: {str(e)}"}))
if __name__ == "__main__":
if len(sys.argv) != 2:
print(json.dumps({"error": "用法: python infer.py <image_path>"}))
else:package com.shop.jieyou.controller;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.shop.jieyou.common.Result;
import com.shop.jieyou.entity.UserItem;
import com.shop.jieyou.service.ItemService;
import com.shop.jieyou.service.PythonService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
/**
* 花卉相关接口控制器
* 提供三大功能:
* 1. 获取中国十大名花数据(来自爬虫或缓存)
* 2. 手动刷新花卉数据(强制重新爬取)
* 3. 基于用户行为的花卉推荐(调用Python协同过滤脚本)
*/
@RestController
@CrossOrigin(origins = "*") // 允许所有域访问,用于前端开发调试(生产环境建议限制域名)
@RequestMapping("/api")
public class FlowerController {
@Autowired
private PythonService pythonService; // 注入业务服务层,处理数据获取与推荐逻辑
@Autowired
private ItemService itemService;
/**
* 接口:GET /api/flowers
* 功能:获取“中国十大名花”数据列表
* 数据来源:可能来自数据库、Redis 缓存 或 调用 Python 爬虫脚本
*
* @return Result<List<Map<String, Object>>> 返回包含花卉信息的成功响应
*/
@GetMapping("/flowers")
public Result<List<Map<String, Object>>> getTopTenFlowers() {
try {
// 调用服务层获取花卉数据(内部可能带缓存机制)
List<Map<String, Object>> flowers = pythonService.getFlowers();
return Result.success(flowers); // 成功返回数据
} catch (Exception e) {
// 捕获异常并统一返回错误码和消息,避免暴露堆栈给前端
return Result.error("500", "获取花卉数据失败:" + e.getMessage());
}
}
/**
* 接口:POST /api/flowers/refresh
* 功能:强制刷新花卉数据缓存,触发重新爬取
* 使用场景:管理员手动更新数据时调用
*
* @return Result<Map<String, Object>> 返回刷新结果信息
*/
@PostMapping("/flowers/refresh")
public Result<Map<String, Object>> refreshData() {
try {
// TODO: 如果实现了 clearCache 方法,请取消注释并调用
// pythonService.clearCache(); // 清除旧缓存,下次 getFlowers 将重新爬取
// 重新获取最新数据(假设此时会触发爬虫)
List<Map<String, Object>> flowers = pythonService.getFlowers();
// 构造返回信息
Map<String, Object> data = new HashMap<>();
data.put("message", "数据已刷新");
data.put("count", flowers.size());
return Result.success(data);
} catch (Exception e) {
return Result.error("500", "刷新失败:" + e.getMessage());
}
}
// ========== 推荐系统相关常量定义 ==========
/**
* 输入文件路径:Java 将用户-商品行为数据写入此 JSON 文件供 Python 脚本读取
* 注意:src/main/resources 是编译后打包进 jar 的资源目录,不适合运行时写入!
* 建议改为外部路径如 "./data/input.json"
*/
private static final String INPUT_PATH = "src/main/resources/scripts/input.json";
/**
* 输出文件路径:Python 脚本将推荐结果写入此文件,Java 再读取返回给前端
*/
private static final String OUTPUT_PATH = "src/main/resources/scripts/output.json";
/**
* Python 协同过滤脚本路径
* 注意:resources 目录下的 .py 文件在打包后无法直接作为可执行脚本运行
* 更佳做法是将脚本放在项目外部或使用 ProcessBuilder 启动独立服务
*/
private static final String PYTHON_SCRIPT = "src/main/resources/scripts/collaborative.py";
/**
* 接口:GET /api/recommend?userId=123
* 功能:为指定用户生成个性化花卉推荐列表
* 实现方式:Java 查询数据库 → 写入 JSON 文件 → 调用 Python 脚本计算 → 读取结果返回
*
* @param userId 用户ID,必填参数
* @return Result<JsonNode> 推荐的商品ID数组(如 [101, 105, 108])
*/
@GetMapping("/recommend")
public Result recommendFlowers(@RequestParam("userId") Long userId) {
try {
// 1. 获取用户行为数据
List<UserItem> matrix = pythonService.getUserItemMatrix();
// 2. 调用 Python 脚本(通过 stdin/stdout 通信)
ProcessBuilder pb = new ProcessBuilder("python", PYTHON_SCRIPT, String.valueOf(userId));
pb.redirectErrorStream(true); // 合并错误流
Process process = pb.start();
// 3. 将数据写入脚本的标准输入
ObjectMapper mapper = new ObjectMapper();
mapper.writeValue(process.getOutputStream(), matrix);
process.getOutputStream().close(); // 关闭输入,通知Python结束读取
// 4. 读取Python脚本的输出(推荐结果)
JsonNode result = mapper.readTree(process.getInputStream());
// 5. 等待脚本执行完毕
int exitCode = process.waitFor();
if (exitCode != 0) {
return Result.error("500", "Python script failed with exit code: " + exitCode);
}
System.out.println(result);
return Result.success(result);
} catch (Exception e) {
e.printStackTrace();
return Result.error("500", "推荐生成失败:" + e.getMessage());
}
}
@PostMapping("/predict")
public ResponseEntity<String> predict(@RequestParam("file") MultipartFile file) {
try {
// 保存上传的文件到临时路径
String tempDir = System.getProperty("java.io.tmpdir");
File tempFile = new File(tempDir, file.getOriginalFilename());
file.transferTo(tempFile);
// 调用 Python 脚本执行预测
ProcessBuilder pb = new ProcessBuilder(
"D:\\Python\\python.exe",
"D:/DevCode/商城/Shop-master/shop-springboot/src/main/resources/scripts/image_classifier.py",
"predict",
tempFile.getAbsolutePath()
);
pb.redirectErrorStream(true); // 合并 stdout 和 stderr
Process process = pb.start();
BufferedReader reader = new BufferedReader(new InputStreamReader(process.getInputStream()));
StringBuilder output = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
output.append(line);
}
int exitCode = process.waitFor();
if (exitCode == 0) {
return ResponseEntity.ok(output.toString().trim());
} else {
return ResponseEntity.status(500).body("{\"error\": \"Prediction failed\"}");
}
} catch (Exception e) {
return ResponseEntity.status(500).body("{\"error\": \"" + e.getMessage() + "\"}");
}
}
private static final String PYTHON_EXECUTABLE = "D:\\Python\\python.exe"; // 或 "python3"
private static final String INFER_SCRIPT_PATH = "D:/DevCode/商城/Shop-master/shop-springboot/src/main/resources/scripts/python-model/infer.py";
@PostMapping("/uploadPython")
public Result<Map<String, Object>> classifyImage(@RequestParam String file) {
if (file.isEmpty()) {
return Result.error("400", "文件为空");
}
try {
// 只做参数传递,逻辑交给 Service
Map<String, Object> result = pythonService.classify(file);
return Result.success(result);
} catch (IOException e) {
return Result.error("500", "文件处理失败:" + e.getMessage());
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return Result.error("500", "处理被中断");
} catch (Exception e) {
return Result.error("500", "识别出错:" + e.getMessage());
}
}
}
predict(sys.argv[1])
最新发布