阳煦山立智慧医疗系统开发团队——项目记录(4)

主要记录一下最后一周的进度。

1.后端更新

Controller类:

BaseController:控制器通过泛型支持不同类型的实体对象,利用依赖注入和通用的服务层方法,提供了标准的保存、删除等操作,同时处理了请求、响应和会话管理等常见问题。

主要实现:

save 方法用于保存传入的对象 obj,并调用通用的服务层方法 service.save(obj) 完成保存操作。在保存成功或失败后,返回相应的响应结果;delete 方法用于根据传入的主键 ID 删除对象。在删除前会检查 ID 的有效性,如果删除成功则返回成功的响应,否则返回失败或数据不存在的响应;setReqAndRes 方法用于在每个子类方法调用之前设置请求、响应和会话对象,以及当前登录用户和疾病种类列表等常用信息,方便后续方法调用时使用。

UserController :是一个基于 Spring MVC 的 RESTful 风格控制器,负责处理与用户相关的操作,包括修改用户资料和修改用户密码。它继承了通用的基础控制器 BaseController<User>,从而获得了保存和删除等通用操作的实现,同时可以自定义处理用户特定的业务逻辑。

主要实现:修改资料 (saveProfile 方法)、修改密码 (savePassword 方法)

MedicineController:被设计为一个基于Spring MVC框架的RESTful控制器,使用@RestController注解,表明其直接返回数据而非视图。它继承自BaseController<Medicine>,通过依赖注入和泛型支持实现常见的CRUD(创建、读取、更新、删除)操作。

IllnessController:是一个基于Spring MVC框架的RESTful控制器,使用@RestController注解标识其直接返回数据而非视图。它继承自BaseController<Illness>,利用泛型支持和依赖注入的方式实现对疾病实体的常规操作。

Service服务类:

BaseService:在我们的项目中,有些服务类需要用到不同类型的数据和业务逻辑,为了提高代码的复用性和可维护性,我们设计了一个基础服务抽象类BaseService,,它包含了一些通用的功能和数据访问操作。

UserService: 在智慧医疗系统中,用户管理模块需要提供高效且安全的用户数据操作服务,包括用户注册、登录、身份验证及权限管理。同时,还需要管理疾病信息、药物信息及病历信息,并通过智能问答系统提供精准的医疗建议。

主要实现:根据条件查询用户、查询所有用户、保存或更新用户、获取用户、删除用户

代码如下:

package com.SmartMed_Connect.service;
 
import com.SmartMed_Connect.dao.UserDao;
import com.SmartMed_Connect.entity.User;
import com.SmartMed_Connect.utils.Assert;
import com.SmartMed_Connect.utils.BeanUtil;
import com.SmartMed_Connect.utils.VariableNameUtils;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
 
import java.io.Serializable;
import java.util.List;
import java.util.Map;
 
@Service
public class UserService extends BaseService<User> {
 
    @Autowired
    protected UserDao userDao;
 
    /**
     * 根据条件查询 User 对象列表
     *
     * @param o 查询条件封装的 User 对象
     * @return 符合条件的 User 对象列表
     */
    @Override
    public List<User> query(User o) {
        QueryWrapper<User> wrapper = new QueryWrapper();
        if (Assert.notEmpty(o)) {
            // 将对象转换为 Map 形式
            Map<String, Object> bean2Map = BeanUtil.bean2Map(o);
            for (String key : bean2Map.keySet()) {
                if (Assert.isEmpty(bean2Map.get(key))) {
                    continue;// 跳过空值字段
                }
                // 使用下划线形式的字段名查询
                //VariableNameUtils.humpToLine(key) 是一个工具方法,用于将驼峰命名法的属性名转换为下划线分隔的数据库字段名。
                // 例如,如果 key 是 "userName",则转换后变为 "user_name"。
                wrapper.eq(VariableNameUtils.humpToLine(key), bean2Map.get(key));
            }
        }
        return userDao.selectList(wrapper);
    }
 
    /**
     * 查询所有的 User 对象
     *
     * @return 所有的 User 对象列表
     */
    @Override
    public List<User> all() {
        return query(null);
    }
 
    /**
     * 保存或更新一个 User 对象
     *
     * @param o 要保存的 User 对象
     * @return 保存后的 User 对象
     */
    @Override
    public User save(User o) {
        // 如果 ID 为空,则插入新的对象
        if (Assert.isEmpty(o.getId())) {
            userDao.insert(o);
        } else {
            // 如果 ID 不为空,则更新已有对象
            userDao.updateById(o);
        }
        return userDao.selectById(o.getId());
    }
 
    /**
     * 根据 ID 获取一个 User 对象
     *
     * @param id 对象的主键 ID
     * @return 获取到的 User 对象
     */
    @Override
    public User get(Serializable id) {
        return userDao.selectById(id);
    }
 
    /**
     * 根据 ID 删除一个 User 对象
     *
     * @param id 要删除的对象的主键 ID
     * @return 删除操作影响的行数(一般为1表示成功删除,0表示未找到)
     */
    @Override
    public int delete(Serializable id) {
        return userDao.deleteById(id);
    }
}

IllnessService :智慧医疗系统中,疾病管理模块需要提供高效且安全的疾病数据操作服务,包括疾病信息的查询、保存、更新及删除等操作。同时,还需要管理与疾病相关的药物信息及页面浏览量信息。

主要实现:根据条件查询疾病、查询所有疾病、保存或更新疾病、获取疾病、删除疾病、查找疾病列表的方法、查找单个疾病的方法、获取单个疾病的方法

代码如下:

package com.SmartMed_Connect.service;
 
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.map.MapUtil;
import cn.hutool.core.util.ObjectUtil;
import com.SmartMed_Connect.dao.IllnessDao;
import com.SmartMed_Connect.entity.*;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import com.SmartMed_Connect.entity.*;
import com.SmartMed_Connect.utils.Assert;
import com.SmartMed_Connect.utils.BeanUtil;
import com.SmartMed_Connect.utils.VariableNameUtils;
 
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
 
@Service
public class IllnessService extends BaseService<Illness> {
 
    @Autowired
    protected IllnessDao illnessDao;
 
    @Override
    public List<Illness> query(Illness o) {
        QueryWrapper<Illness> wrapper = new QueryWrapper();
        if (Assert.notEmpty(o)) {
            Map<String, Object> bean2Map = BeanUtil.bean2Map(o);
            for (String key : bean2Map.keySet()) {
                if (Assert.isEmpty(bean2Map.get(key))) {
                    continue;
                }
                wrapper.eq(VariableNameUtils.humpToLine(key), bean2Map.get(key));
            }
        }
        return illnessDao.selectList(wrapper);
    }
 
    @Override
    public List<Illness> all() {
        return query(null);
    }
 
    @Override
    public Illness save(Illness o) {
        if (Assert.isEmpty(o.getId())) {
            illnessDao.insert(o);
        } else {
            illnessDao.updateById(o);
        }
        return illnessDao.selectById(o.getId());
    }
 
    @Override
    public Illness get(Serializable id) {
        return illnessDao.selectById(id);
    }
 
    @Override
    public int delete(Serializable id) {
        return illnessDao.deleteById(id);
    }
 
    // 查找疾病列表的方法,根据疾病类别、名称和分页参数进行查询
    public Map<String, Object> findIllness(Integer kind, String illnessName, Integer page) {
        // 创建一个初始容量为 4 的 HashMap,用于存储最终结果
        Map<String, Object> map = new HashMap<>(4);
        // 创建一个 QueryWrapper 对象,用于构建查询条件
        QueryWrapper<Illness> illnessQueryWrapper = new QueryWrapper<>();
        // 如果 illnessName 非空,则添加模糊查询条件
        if (Assert.notEmpty(illnessName)) {
            illnessQueryWrapper
                    .like("illness_name", illnessName)// 疾病名称包含 illnessName
                    .or()
                    .like("include_reason", illnessName)// 包含原因包含 illnessName
                    .or()
                    .like("illness_symptom", illnessName)// 疾病症状包含 illnessName
                    .or()
                    .like("special_symptom", illnessName);// 特殊症状包含 illnessName
        }
 
        // 如果 kind 非空,则添加疾病种类的查询条件
        if (kind != null) {
            if (Assert.notEmpty(illnessName)) {
                // 如果 illnessName 非空,且 kind 非空,则添加种类 ID 条件,并按创建时间降序排序,同时进行分页
                illnessQueryWrapper.last("and (kind_id = " + kind + ") ORDER BY create_time DESC limit " + (page - 1) * 9 + "," + page * 9);
            } else {
                // 如果 illnessName 为空,但 kind 非空,则仅按种类 ID 查询,并按创建时间降序排序,同时进行分页
                illnessQueryWrapper.eq("kind_id", kind);
                illnessQueryWrapper.orderByDesc("create_time");
                illnessQueryWrapper.last("limit " + (page - 1) * 9 + "," + page * 9);
            }
        } else {
            illnessQueryWrapper.orderByDesc("create_time");
            illnessQueryWrapper.last("limit " + (page - 1) * 9 + "," + page * 9);
 
        }
        // 获取符合条件的疾病记录的总数
        int size = illnessDao.selectMaps(illnessQueryWrapper).size();
        // 查询符合条件的疾病记录
        List<Map<String, Object>> list = illnessDao.selectMaps(illnessQueryWrapper);
        list.forEach(l -> {
            Integer id = MapUtil.getInt(l, "id");// 获取疾病 ID
            Pageview pageInfo = pageviewDao.selectOne(new QueryWrapper<Pageview>().eq("illness_id", id));// 获取对应的页面浏览信息
            l.put("kindName", "暂无归属类");// 初始化 kindName 为 "暂无归属类"
            l.put("create_time", MapUtil.getDate(l, "create_time"));// 格式化创建时间
            l.put("pageview", pageInfo == null ? 0 : pageInfo.getPageviews());// 添加页面浏览量信息
            Integer kindId = MapUtil.getInt(l, "kind_id");// 获取疾病种类 ID
            if (Assert.notEmpty(kindId)) {
                IllnessKind illnessKind = illnessKindDao.selectById(kindId);// 根据种类 ID 查询对应的疾病种类
                if (Assert.notEmpty(illnessKind)) {
                    l.put("kindName", illnessKind.getName());// 如果疾病种类存在,则更新 kindName
                }
            }
        });
        map.put("illness", list);
        map.put("size", size < 9 ? 1 : size / 9 + 1);
        return map;
    }
 
    public Map<String, Object> findIllnessOne(Integer id) {
        // 查询 Illness 表,根据 ID 查找疾病
        Illness illness = illnessDao.selectOne(new QueryWrapper<Illness>().eq("id", id));
        // 查询 IllnessMedicine 表,找到所有与该疾病 ID 相关的药物
        List<IllnessMedicine> illnessMedicines = illnessMedicineDao.selectList(new QueryWrapper<IllnessMedicine>().eq("illness_id", id));
        // 初始化一个列表来存储药物
        List<Medicine> list = new ArrayList<>(4);
        // 初始化一个 Map 来存储返回结果
        Map<String, Object> map = new HashMap<>(4);
        // 查询 Pageview 表,根据疾病 ID 查找页面浏览记录
        Pageview illness_id = pageviewDao.selectOne(new QueryWrapper<Pageview>().eq("illness_id", id));
        // 如果没有找到对应的页面浏览记录,创建一个新的记录并插入数据库
        if (Assert.isEmpty(illness_id)) {
            illness_id = new Pageview();
            illness_id.setIllnessId(id);
            illness_id.setPageviews(1);
            pageviewDao.insert(illness_id);
        } else {
            // 如果找到了对应的页面浏览记录,更新浏览次数并保存
            illness_id.setPageviews(illness_id.getPageviews() + 1);
            pageviewDao.updateById(illness_id);
        }
        // 将疾病信息放入返回结果的 Map 中
        map.put("illness", illness);
 
        // 如果找到了相关的药物,将每个药物的信息放入列表中
        if (CollUtil.isNotEmpty(illnessMedicines)) {
            illnessMedicines.forEach(illnessMedicine -> {
                // 查询 Medicine 表,根据药物 ID 查找药物
                Medicine medicine = medicineDao.selectOne(new QueryWrapper<Medicine>().eq("id", illnessMedicine.getMedicineId()));
                // 如果药物不为空,将其添加到列表中
                if (ObjectUtil.isNotNull(medicine)) {
                    list.add(medicine);
                }
            });
            // 将药物列表放入返回结果的 Map 中
            map.put("medicine", list);
 
        }
        // 返回包含疾病和药物信息的 Map
        return map;
    }
 
    public Illness getOne(QueryWrapper<Illness> queryWrapper) {
        return illnessDao.selectOne(queryWrapper);
    }
}

MedicineService: 是一个用于管理药品信息的服务类。它提供了增、删、改、查等基本操作以及根据药品名称或关键字模糊查询和分页功能。

主要实现:根据条件查询药品、查询所有药品、保存或更新药品、获取药品、删除药品、查找药品列表并分页返回结果、查找单个药品的方法、获取单个药品的方法

代码如下:

package com.SmartMed_Connect.service;
 
import com.SmartMed_Connect.dao.MedicineDao;
import com.SmartMed_Connect.entity.Medicine;
import com.SmartMed_Connect.utils.Assert;
import com.SmartMed_Connect.utils.BeanUtil;
import com.SmartMed_Connect.utils.VariableNameUtils;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
 
import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
 
@Service
public class MedicineService extends BaseService<Medicine> {
 
    @Autowired
    protected MedicineDao medicineDao;
 
    /**
     * 根据条件查询 Medicine 对象列表
     *
     * @param o 查询条件封装的 Medicine 对象
     * @return 符合条件的 Medicine 对象列表
     */
    @Override
    public List<Medicine> query(Medicine o) {
        QueryWrapper<Medicine> wrapper = new QueryWrapper();
        if (Assert.notEmpty(o)) {
            // 将对象转换为 Map 形式
            Map<String, Object> bean2Map = BeanUtil.bean2Map(o);
            for (String key : bean2Map.keySet()) {
                // 跳过空值字段
                if (Assert.isEmpty(bean2Map.get(key))) {
                    continue;
                }
                // 使用下划线形式的字段名查询
                wrapper.eq(VariableNameUtils.humpToLine(key), bean2Map.get(key));
            }
        }
        return medicineDao.selectList(wrapper);
    }
 
    /**
     * 查询所有的 Medicine 对象
     *
     * @return 所有的 Medicine 对象列表
     */
    @Override
    public List<Medicine> all() {
        return query(null);
    }
 
    /**
     * 保存或更新一个 Medicine 对象
     *
     * @param o 要保存的 Medicine 对象
     * @return 保存后的 Medicine 对象
     */
    @Override
    public Medicine save(Medicine o) {
        // 如果 ID 为空,则插入新的对象
        if (Assert.isEmpty(o.getId())) {
            medicineDao.insert(o);
        } else {
            // 如果 ID 不为空,则更新已有对象
            medicineDao.updateById(o);
        }
        return medicineDao.selectById(o.getId());
    }
 
    /**
     * 根据 ID 获取一个 Medicine 对象
     *
     * @param id 对象的主键 ID
     * @return 获取到的 Medicine 对象
     */
    @Override
    public Medicine get(Serializable id) {
        return medicineDao.selectById(id);
    }
 
    /**
     * 根据 ID 删除一个 Medicine 对象
     *
     * @param id 要删除的对象的主键 ID
     * @return 删除操作影响的行数(一般为1表示成功删除,0表示未找到)
     */
    @Override
    public int delete(Serializable id) {
        return medicineDao.deleteById(id);
    }
 
    /**
     * 根据药品名称或关键字模糊查询药品列表,并分页返回结果
     *
     * @param nameValue 药品名称或关键字
     * @param page      分页页码
     * @return 包含药品列表和分页信息的 Map
     */
    public Map<String, Object> getMedicineList(String nameValue, Integer page) {
 
        List<Medicine> medicineList = null;
        Map<String, Object> map = new HashMap<>(4);
        // 根据传入的名称或关键字进行模糊查询,支持分页
        if (Assert.notEmpty(nameValue)) {
            medicineList = medicineDao.selectList(new QueryWrapper<Medicine>().
                    like("medicine_name", nameValue)
                    .or().like("keyword", nameValue)
                    .or().like("medicine_effect", nameValue)
                    .last("limit " + (page - 1) * 9 + "," + page * 9));
        } else {
            medicineList = medicineDao.selectList(new QueryWrapper<Medicine>()
                    .last("limit " + (page - 1) * 9 + "," + page * 9));
        }
        // 将查询结果放入返回的 Map 中
        map.put("medicineList", medicineList);
        // 计算分页总数,并放入返回的 Map 中
        map.put("size", medicineList.size() < 9 ? 1 : medicineList.size() / 9 + 1);
        return map;
    }
}

其他 

  1. 新增IpUtil工具类,提供获取用户Ip地址的功能

  2. 新增MapUtil工具类,提供使用高德地图WebAPI服务进行的IP定位和POI搜索功能,可以根据用户所在城市,搜索当前城市内的所有综合医院

  3. 新增Controller GeographyController,该Controller包含获取地理信息的接口GeoInfo,通过访问该接口可以实现上述功能,获取市内所有的综合医院。

  4. 大改了ApiService.java服务类,触发智能问诊模式之后会自动保存相应的病史信息,同时给出问诊结果的时候会自动整合病史和当前的症状,从而给出更加合理的诊断结果

//查找病史
List<PatientHistory> PatientHistoryList = patientHistoryService.findByUserId(userController.loginUser.getId());
// 将病史列表中的元素转换成字符串并连接起来
String historyString = PatientHistoryList.stream()
                            .map(PatientHistory::toString)
                            .collect(Collectors.joining("。      "));
                    //打印查看病史列表
System.out.println(historyString);
patientInfo.toString();
//整合病史和当前病情
messages.add(createMessage(Role.USER,"我的的以往病史是:"+historyString+
                            "            " +
                            "我的当前的病症是:"+patientInfo.toString()+
                            "            " +
                            "请你针对我的的当前症状给出相应的建议,如果当前病症和我的病史相关那就需要考虑一下病史"));

 5.智慧轮询的问题更加人格化,亲切自然,并且为用户提供了更加明确的回答方向,这样收集的信息回更加准确,同时也能提高用户的使用体验。

    private String getNextQueryMessage() {
        queryStep++;
        switch (queryStep) {
            case 1:
                return "提供一下你最近的身高和体重信息";
            case 2:
                return "你目前有哪些症状?能详细和我说一下";
            case 3:
                return "病情发作的细节你有留意吗?,比如发作时间、持续时间和发作频率。";
            case 4:
                return "你的平时生活习惯怎样?比如饮食、运动、睡眠等方面。";
            case 5:
                return "请大致描述一下你的病史,包括曾经患过的疾病、手术史等。";
            case 6:
                return "你是否对某些药物过敏?如果有,请说一下你过敏的药物信息";
            case 7:
                return "你目前有那些正在使用的药物?请提供正在使用的药物清单。";
            default:
                break;
        }
        return null;
    }

 autoDL每个实例都留了端口6006暴露在外以提供服务,这就给了我们通过这个端口访问部署在AutoDL上的MING的可能性。

不过想要访问该接口,需要以SSH隧道方式进行访问。

通过autodl一键安装Langchain-Chatchat_langchain-chatglm3一键部署安装包-CSDN博客提供的ssh端口转发的工具,可以将远程服务的端口转发到本地上,并访问本地端口来访问autodl上的接口,便可以与AutoDL上的ming_api_server进行通信。

基于MING模型的简单智能对话服务MingAPIService的实现

搜集病人的病历信息加入对话信息里,以增强对话质量。实现了对话功能、消息的存储与删除和信息增强。

代码如下:

package com.SmartMed_Connect.service;
 
import com.SmartMed_Connect.controller.UserController;
import com.SmartMed_Connect.dao.ChatMessageDao;
import com.SmartMed_Connect.dto.MessageWrapper;
import com.SmartMed_Connect.entity.ChatMessage;
import com.SmartMed_Connect.entity.PatientHistory;
import com.SmartMed_Connect.utils.ClientUtil;
import com.SmartMed_Connect.utils.IpUtil;
import com.SmartMed_Connect.utils.MapUtil;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
 
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
 
import javax.servlet.http.HttpServletRequest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
 
@Service
public class MingAPIService {
 
 
    @Autowired
    protected PatientHistoryService patientHistoryService;
 
    @Autowired
    private UserController userController;
 
    @Autowired
    protected UserService userService;
 
    @Autowired
    private ChatMessageDao chatMessageMapper;
    @Autowired
    private HttpServletRequest request;
    @Value("${ming.url}")
    String mingURL;
    private final List<MessageWrapper> messageCacheList = new ArrayList<>();
 
    private String createQueryContent(Integer userId,String queryMessage){
 
        PatientHistory patientHistory = patientHistoryService.findByUserId(userId).get(0);
 
        String historyString = patientHistory.toString();
        String patientHistoryInfo = "病人的以往病史是:"+historyString+"请你结合这些信息给出相应的建议";
 
        return queryMessage+patientHistoryInfo;
 
 
    }
    public String query(String queryMessage){
        Map<String,Object> params = new HashMap<>();
 
        params.put("model","ming-moe-4B");
        params.put("n",1);
        params.put("temperature",0.5);
        params.put("max_tokens",3072);
 
        Integer userId = userController.loginUser.getId();
        String queryContent = queryMessage;
//        List<MessageWrapper> messages;
        if(messageCacheList.isEmpty()){
            fetchUserChatHistory(userId);
            if(messageCacheList.isEmpty()){
                queryContent = createQueryContent(userId,queryMessage);
            }
        }
 
 
        if(queryMessage.equals("/newchat")){
            messageCacheList.clear();
            clearUserChatHistory(userId);
            return "已清空聊天记录";
        }
 
        MessageWrapper userMessage = new MessageWrapper("user",queryContent);
        messageCacheList.add(userMessage);
 
        params.put("messages",messageCacheList);
        for(MessageWrapper messageWrapper:messageCacheList){
            System.out.println(messageWrapper);
        }
 
        JSONObject result = ClientUtil.sendPostRequest(mingURL,params);
 
        System.out.println(result);
        if(result==null){
            return "错误:连接失败";
        }
 
        JSONArray choices = result.getJSONArray("choices");
 
        if(choices==null){
 
            return "choice == null"+result.toJSONString();
        }
 
 
        MessageWrapper responseMessage = new MessageWrapper();
        for(int i =0;i<choices.size();++i){
            JSONObject messageJson = choices.getJSONObject(i).getJSONObject("message");
 
            String role  = messageJson.getString("role");
            String content = messageJson.getString("content");
 
            responseMessage.setRole(role);
            responseMessage.setContent(content);
        }
        messageCacheList.add(responseMessage);
        saveUserChatHistory(userId,messageCacheList);
 
        String resultContent = responseMessage.getContent();
        if(checkIfNeedHospitalLocationInfo(queryMessage)){
            String userIP = IpUtil.getIpAddress(request);
            String mapInfo = MapUtil.getAddressByIP(userIP);
            resultContent += mapInfo;
        }
        return resultContent;
    }
    private boolean checkIfNeedHospitalLocationInfo(String content){
        String keyword = "医院";
        // 编译正则表达式
        Pattern pattern = Pattern.compile(keyword);
        // 创建匹配器对象
        Matcher matcher = pattern.matcher(content);
        // 使用 find() 方法查找是否包含关键词
        boolean containsKeyword = matcher.find();
 
        if (containsKeyword) {
            return true;
        }
        else{
            return false;
        }
    }
 
 
    private void clearUserChatHistory(int userId){
        chatMessageMapper.clearChatMessageByUserId(userId);
    }
    private void saveUserChatHistory(Integer userId,List<MessageWrapper> messageList){
        for(MessageWrapper message: messageList){
            ChatMessage chatMessage = new ChatMessage(null, userId, message.getRole(), message.getContent());
            chatMessageMapper.insert(chatMessage);
        }
 
    }
    private void fetchUserChatHistory(int userId){
        List<ChatMessage> userChatHistory = chatMessageMapper.findChatMessageByUserId(userId);
//        List<MessageWrapper> messageCache = new ArrayList<>();
        if(userChatHistory==null){
            System.out.println("fetch chat history failed");
            return ;
        }
        for(ChatMessage chatMessage:userChatHistory){
            MessageWrapper messageWrapper = new MessageWrapper(chatMessage.getRole(),chatMessage.getContent());
            if(messageWrapper.getRole() !=null&& messageWrapper.getContent()!=null){
                messageCacheList.add(messageWrapper);
            }
 
        }
    }
}

2.前端

反馈在前端显示一直有乱码的问题,调查后发现是前端代码标签不匹配导致的,经调整后前端核心部分如下。

<tbody>
<tr th:each="feedback:${feedbackList}">
    <td th:text="${feedback.id}"></td>
    <td th:text="${feedback.name}"></td>
    <td th:text="${feedback.email}"></td>
    <td th:text="${feedback.title}"></td>
    <td th:text="${feedback.content}"></td>
    <td th:text="${#dates.format(feedback.createTime, 'yyyy-MM-dd HH:mm:ss')}"></td>
    <td><a th:onclick="deleteFeedback([[${feedback.id}]])" href="#">
    <span>class="label text-success">删除</span></a>
	</td>
</tr>
</tbody>

现在反馈功能可以正常显示。

修改导航栏,底栏部分的前端界面

实现医院输出换行。 

显示病人对应病历

制作了新的对话页面smart_doctor.html,用来存放微调大模型、或者结合对话和百科信息。修改了Base Controller,Message Controler和System Controller,用来存放新的页面。

更新:

index.html——首页,
medicine.html——药品,
profile.html——个人信息,
search-illness.html——搜索,
smart-doctor.html——智慧医生,
400.html——错误
401.html——错误
404.html——错误
500.html——错误

前端效果一览:

3.大模型尝试

千问与Ming的融合:

基座融合代码如下。

import torch
from transformers import AutoModel, AutoTokenizer

# 本地模型目录
qwen_model_path = "./qwen/Qwen1.5-1.8B-Chat"
ming_model_path = "./ming"

# 加载模型和分词器
try:
    tokenizer = AutoTokenizer.from_pretrained(qwen_model_path)
    qwen_model = AutoModel.from_pretrained(qwen_model_path)
    ming_model = AutoModel.from_pretrained(ming_model_path)
except Exception as e:
    print(f"Error loading models: {e}")
    exit(1)

# 获取模型的权重字典
qwen_state_dict = qwen_model.state_dict()
ming_state_dict = ming_model.state_dict()

# 将MING-MOE模型的权重复制到Qwen基座模型中,只复制匹配的权重
missing_keys = []
for key in ming_state_dict.keys():
    if key in qwen_state_dict:
        qwen_state_dict[key] = ming_state_dict[key]
    else:
        missing_keys.append(key)

if missing_keys:
    print(f"The following keys were not found in the Qwen base model state_dict and were skipped: {missing_keys}")

# 加载更新后的权重到Qwen基座模型
try:
    qwen_model.load_state_dict(qwen_state_dict)
    print("Weights loaded successfully.")
except Exception as e:
    print(f"Error loading state_dict: {e}")
    exit(1)

# 将更新后的模型保存到本地
new_model_path = "./Ming_YX/"
try:
    qwen_model.save_pretrained(new_model_path)
    tokenizer.save_pretrained(new_model_path)
    print(f"Model and tokenizer saved to {new_model_path}")
except Exception as e:
    print(f"Error saving model or tokenizer: {e}")
    exit(1)

这段代加载两个预训练模型(Qwen 模型和 Ming 模型),并检查是否加载成功。如果加载失败会打印错误信息并退出程序。

然后,这段代码试图获取 Qwen 模型和 Ming 模型的权重字典。将 Ming 模型的权重复制到 Qwen 模型的权重字典中,只复制那些在 Qwen 模型中存在的权重(匹配的权重)。在这之后,程序会打印出在 Qwen 模型中没有找到的权重键。将更新后的权重加载到 Qwen 模型中。如果加载失败会打印错误信息并退出程序。将更新后的模型和分词器保存到本地目录。如果保存失败会打印错误信息并退出程序。

 尝试将融合后的新模型导入到阿里魔搭中。

ming_api_server参照openAI的RESTful API和fastchat,实现了提供springboot后端访问的对话生成接口。其通过与model_worker通信来获取和解析模型生成的文本内容,并能够接收和返回json格式化信息。

实现了:

文本生成:仿照OpenAI的RESTful API风格,访问地址/v1/chat/completions时可以获得模型聊天的输出
多轮对话:把模型过往的对话记录缓存起来,生成的时候将这些聊天记录全部放在请求中

完整代码如下:

import asyncio
import argparse
import json
import os
from typing import Generator, Optional, Union, Dict, List, Any
 
import aiohttp
import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
import httpx
 
from ming.conversations import conv_templates, get_default_conv_template, SeparatorStyle
from pydantic_settings import BaseSettings
import shortuuid
import tiktoken
import uvicorn
 
from fastchat.constants import (
    WORKER_API_TIMEOUT,
    ErrorCode,
)
 
from fastchat.protocol.openai_api_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    ErrorResponse,
    UsageInfo,
)
from fastchat.utils import build_logger
 
logger = build_logger("ming_api_server", "ming_api_server.log")
 
conv_template_map = {}
 
fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600)
 
 
async def fetch_remote(url, pload=None, name=None):
    async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
        async with session.post(url, json=pload) as response:
            chunks = []
            if response.status != 200:
                ret = {
                    "text": f"{response.reason}",
                    "error_code": ErrorCode.INTERNAL_ERROR,
                }
                return json.dumps(ret)
 
            async for chunk, _ in response.content.iter_chunks():
                chunks.append(chunk)
        output = b"".join(chunks).replace(b'\x00', b'')
    if name is not None:
        res = json.loads(output)
        if name != "":
            res = res[name]
        return res
    print(output)
    return output
 
 
class AppSettings(BaseSettings):
    controller_address: str = "http://localhost:21001"
    api_keys: Optional[List[str]] = None
 
 
app_settings = AppSettings()
app = fastapi.FastAPI()
headers = {"User-Agent": "FastChat API Server"}
 
 
def create_error_response(code: int, message: str) -> JSONResponse:
    return JSONResponse(
        ErrorResponse(message=message, code=code).model_dump(), status_code=400
    )
 
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
 
def check_requests(request) -> Optional[JSONResponse]:
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.top_k is not None and (request.top_k > -1 and request.top_k < 1):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
        )
    if request.stop is not None and (
            not isinstance(request.stop, str) and not isinstance(request.stop, list)
    ):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )
    return None
def _add_to_set(s, new_stop):
    if not s:
        return
    if isinstance(s, str):
        new_stop.add(s)
    else:
        new_stop.update(s)
 
 
async def get_gen_params(
        model_name: str,
        messages: Union[str, List[Dict[str, str]]],
        *,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
        presence_penalty: Optional[float],
        frequency_penalty: Optional[float],
        max_tokens: Optional[int],
        echo: Optional[bool],
        logprobs: Optional[int] = None,
        stop: Optional[Union[str, List[str]]],
) -> Dict[str, Any]:
    conv = conv_templates["qwen"].copy()
 
    if isinstance(messages, str):
        prompt = messages
    else:
        for message in messages:
            msg_role = message["role"]
            if msg_role == "system":
                conv.system = message['content']
            elif msg_role == "user":
                if type(message["content"]) == list:
                    text_list = [
                        item["text"]
                        for item in message["content"]
                        if item["type"] == "text"
                    ]
                    text = "\n".join(text_list)
                    conv.append_message(conv.roles[0], text)
                else:
                    conv.append_message(conv.roles[0], message["content"])
            elif msg_role == "assistant":
                conv.append_message(conv.roles[1], message["content"])
            else:
                raise ValueError(f"Unknown role: {msg_role}")
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
 
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "logprobs": logprobs,
        "top_p": top_p,
        "top_k": top_k,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "max_new_tokens": max_tokens,
        "echo": echo,
        "stop_token_ids": stop
    }
 
 
    new_stop = set()
    _add_to_set(stop, new_stop)
    # _add_to_set(conv.stop_str, new_stop)
    gen_params["stop"] = list(new_stop)
 
    logger.debug(f"==== request ====\n{gen_params}")
    return gen_params
 
async def get_worker_address(model_name: str) -> str:
    controller_address = app_settings.controller_address
    worker_addr = await fetch_remote(
        controller_address + "/get_worker_address", {"model": model_name}, "address"
    )
    if worker_addr == "":
        raise ValueError(f"No available worker for {model_name}")
    logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}")
    return worker_addr
 
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    worker_addr = await get_worker_address(request.model)
    # 参数解析
    gen_params = await get_gen_params(
        request.model,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )
    choices = []
    chat_completions = []
    # 发送消息
    for i in range(request.n):
        content = asyncio.create_task(generate_completion(gen_params, worker_addr))
 
        chat_completions.append(content)
    try:
        all_tasks = await asyncio.gather(*chat_completions)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
    usage = UsageInfo()
    for i, content in enumerate(all_tasks):
        if isinstance(content, str):
            content = json.loads(content)
        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=i,
                message=ChatMessage(role="assistant", content=content["text"]),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
 
 
async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", worker_addr + "/worker_generate_stream",
                                 headers=headers, json=payload, timeout=WORKER_API_TIMEOUT) as response:
            content = await response.aread()
            return content.replace(b'\0', b'').decode()
 
 
def create_api_server():
    parser = argparse.ArgumentParser(
        description="Simple RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    parser.add_argument(
        "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
    )
    parser.add_argument(
        "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
    )
    parser.add_argument(
        "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
    )
    parser.add_argument(
        "--api-keys",
        type=lambda s: s.split(","),
        help="Optional list of comma separated API keys",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )
    app_settings.controller_address = args.controller_address
    app_settings.api_keys = args.api_keys
 
    logger.info(f"args: {args}")
    return args
if __name__ == "__main__":
    args = create_api_server()
    if args.ssl:
        uvicorn.run(
            app,
            host=args.host,
            port=args.port,
            log_level="info",
            ssl_keyfile=os.environ["SSL_KEYFILE"],
            ssl_certfile=os.environ["SSL_CERTFILE"],
        )
    else:
        uvicorn.run(app, host=args.host, port=args.port, log_level="info")

  • 32
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值