Spring 中解决 RocketMQ 发消息与 MySQL 事务的一致性问题

场景

  1. 用户下单并支付
  2. 发送消息,为用户开通查看文章的权限
// 伪代码
@Transactional(rollbackFor=Exception.class)
public void pay(long uid, String orderNO) {
  Order order = orderService.selectOrder(uid, orderNO)
  if (order != null) {
    String status = "paid";
    orderDao.updateStatus(uid, orderNo, status);
    
    rocketMQTemplate.send("order:status", message(uid, orderNo, order.itemId, status));
  }
}

public class OrderStatusArticleListener implements RocketMQListener {
    public void onMessage(message) {
        Order order = orderService.selectOrder(message.uid, message.orderNo)
        if (order == null) {
           throw new RuntimeException("order not found. " + message.orderNo)
        }
        if (order.status != "paid") {
            throw new RuntimeException("order not paid. " + message.orderNo)
        }
        // 授权
        articleService.authorize(message.uid, message.itemId)
    }
}

上面的例子中,消费者查询订单时可能会发现订单仍处于未支付状态。

为什么会这样呢?

这是因为我们在 Spring 事务中同步发送了消息:事务还没有提交,消息就已经到达消费者端并开始被消费,消费者自然读不到尚未提交的“已支付”状态。

解决:

  1. 增加消息表,在业务事务提交前把消息写入该表(与业务数据在同一个事务中落库),状态标记为待处理(NEW)
  2. 事务提交后再发送 MQ 消息
  3. 发送成功后删除消息表中对应的记录;发送失败则保留记录,由定时任务重试

由于是在 Spring 环境中,我们可以使用 Spring 的 TransactionSynchronizationManager#registerSynchronization 注册一个事务同步回调:

if (TransactionSynchronizationManager.isSynchronizationActive()) {
    // Inside a transaction: persist the message before commit and send it after commit.
    TransactionSynchronizationManager.registerSynchronization(new MQTransactionSynchronization(
            rocketMQTemplate, destination, message, timeout, delayLevel
    ));
} else {
    // No transaction to wait for: send right away instead of silently dropping the message.
    rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
}

我们自定义一个名为 MQTransactionSynchronization 的 TransactionSynchronization 实现:

@Slf4j
public class MQTransactionSynchronization implements TransactionSynchronization {
    private DataSource dataSource;
    private ConnectionHolder connectionHolder;
    private String id;
    private RocketMQTemplate rocketMQTemplate;
    private String destination;
    private Message message;
    private long timeout;
    private int delayLevel;
    
    public MQTransactionSynchronization(RocketMQTemplate  rocketMQTemplate, String destination, Message  message, long  timeout, int delayLevel) {
        this.rocketMQTemplate = rocketMQTemplate;
        this.destination = destination;
        this.message = message;
        this.timeout = timeout;
        this.delayLevel = delayLevel;
    }
    
    @Override
    public void beforeCompletion() {}
    
    @Override
    public void beforeCommit(boolean readOnly) {
        Map<Object, Object> resourceMap = TransactionSynchronizationManager.getResourceMap();
        for (Map.Entry<Object, Object> entry : resourceMap.entrySet()) {
            Object key = entry.getKey();
            Object value = entry.getValue();
            if (value instanceof ConnectionHolder) {
                this.dataSource = (DataSource) key;
                this.connectionHolder = (ConnectionHolder) value;
                break;
            }
        }
         if (connectionHolder == null) {
            log.warn("connectionHolder is null");
            return;
        }
		this.id = UUID.randomUUID().toString();
		final String mqTemplateName = ApplicationContextUtils.findBeanName(rocketMQTemplate.getClass(), rocketMQTemplate);
		MqMsgDao.insertMsg(connectionHolder, id, mqTemplateName, destination, message, timeout, delayLevel);
    }
    @Override
    public void afterCommit() {
        log.debug("afterCommit {}", TransactionSynchronizationManager.getCurrentTransactionName());
        try {
            rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
            MqMsgDao.deleteMsgById(dataSource, this.id);
        } catch (Exception e) {
            log.error("mq send message failed. topic:[{}], message:[{}]", destination, message, e);
        }
    }
    @Override
    public void afterCompletion(int status) {
        log.debug("afterCompletion {} : {}", TransactionSynchronizationManager.getCurrentTransactionName(), status);
        rocketMQTemplate = null;
        destination = null;
        message = null;
        connectionHolder = null;
        dataSource = null;
        id = null;
    }
}

/**
 * Plain-JDBC access to the {@code tb_mq_msg} table.
 *
 * <p>The table holds MQ messages written inside a business transaction until they have been
 * sent. All methods wrap {@link SQLException} in a {@link RuntimeException}.
 */
@Slf4j
public class MqMsgDao {
    public static final String STATUS_NEW = "NEW";
    public static final Integer MAX_RETRY_TIMES = 5;
    /**
     * Rows younger than this are skipped by {@link #listMsg}.
     *
     * <p>A freshly committed row is normally still being sent and deleted by the after-commit
     * callback. Retrying it at the same time would send the message twice.
     */
    public static final long RETRY_DELAY_MILLIS = 30_000L;
    /** Maximum number of rows returned by one {@link #listMsg} call. */
    private static final int BATCH_SIZE = 100;
    /** Explicit column list so the insert does not depend on the table's column order. */
    private static final String INSERT_SQL = "insert into tb_mq_msg "
            + "(id, status, mq_template_name, mq_destination, mq_timeout, mq_delay_level, payload, retry_times, create_time, update_time) "
            + "values (?,?,?,?,?,?,?,?,?,?)";
    private static final JsonMapper MAPPER = JsonMapper.builder()
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
            .enable(MapperFeature.PROPAGATE_TRANSIENT_MARKER)
            .build();

    private MqMsgDao() {
    }

    /**
     * Loads up to {@value #BATCH_SIZE} pending messages.
     *
     * <p>Only rows with status NEW, {@code retry_times} below {@link #MAX_RETRY_TIMES} and older
     * than {@link #RETRY_DELAY_MILLIS} are returned. The query uses its own connection from
     * {@code dataSource}.
     */
    public static List<MqMsg> listMsg(DataSource dataSource) {
        String sql = "select * from tb_mq_msg where status = ? and retry_times < ? and create_time < ? limit " + BATCH_SIZE;
        try (Connection conn = dataSource.getConnection();
             PreparedStatement ps = conn.prepareStatement(sql)) {
            int i = 0;
            ps.setObject(++i, STATUS_NEW);
            ps.setObject(++i, MAX_RETRY_TIMES);
            ps.setObject(++i, new java.sql.Timestamp(System.currentTimeMillis() - RETRY_DELAY_MILLIS));
            try (ResultSet rs = ps.executeQuery()) {
                List<MqMsg> list = new ArrayList<>(BATCH_SIZE);
                while (rs.next()) {
                    list.add(toMqMsg(rs));
                }
                return list;
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /** Maps the current row of {@code rs} to an MqMsg, rebuilding the message from its JSON column. */
    private static MqMsg toMqMsg(ResultSet rs) throws SQLException {
        MqMsg mqMsg = new MqMsg();
        mqMsg.setId(rs.getString("id"));
        mqMsg.setStatus(rs.getString("status"));
        mqMsg.setMqTemplateName(rs.getString("mq_template_name"));
        mqMsg.setMqDestination(rs.getString("mq_destination"));
        mqMsg.setMqTimeout(rs.getLong("mq_timeout"));
        mqMsg.setMqDelayLevel(rs.getInt("mq_delay_level"));
        Map<String, Object> map = fromJson(rs.getString("payload"));
        // insertMsg stores the headers as a JSON object, which Jackson reads back as a Map.
        @SuppressWarnings("unchecked")
        Map<String, Object> headers = (Map<String, Object>) map.get("headers");
        GenericMessage<Object> message = new GenericMessage<>(map.get("payload"), headers);
        mqMsg.setMessage(message);
        mqMsg.setRetryTimes(rs.getInt("retry_times"));
        mqMsg.setCreateTime(rs.getTimestamp("create_time"));
        mqMsg.setUpdateTime(rs.getTimestamp("update_time"));
        return mqMsg;
    }

    /**
     * Inserts a NEW message row on the transaction's own connection.
     *
     * <p>The connection is intentionally not closed: it belongs to the surrounding transaction.
     *
     * @throws RuntimeException if the insert fails or affects no row
     */
    public static void insertMsg(ConnectionHolder connectionHolder,
                                 String id,
                                 String mqTemplateName,
                                 String mqDestination,
                                 Message message,
                                 long mqTimeout,
                                 int mqDelayLevel) {
        Connection connection = connectionHolder.getConnection();
        Map<String, Object> payload = new HashMap<>();
        payload.put("payload", message.getPayload());
        payload.put("headers", message.getHeaders());

        try (PreparedStatement ps = connection.prepareStatement(INSERT_SQL)) {
            java.sql.Timestamp now = new java.sql.Timestamp(System.currentTimeMillis());
            int i = 0;
            ps.setObject(++i, id);
            ps.setObject(++i, STATUS_NEW);
            ps.setObject(++i, mqTemplateName);
            ps.setObject(++i, mqDestination);
            ps.setObject(++i, mqTimeout);
            ps.setObject(++i, mqDelayLevel);
            ps.setObject(++i, toJson(payload));
            ps.setObject(++i, 0);
            ps.setObject(++i, now);
            ps.setObject(++i, now);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                throw new RuntimeException("insert mq msg affect : " + affect);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Increments {@code retry_times} for the given row, using a connection of its own.
     *
     * @throws RuntimeException if no row was updated
     */
    public static void updateMsgRetryTimes(DataSource dataSource, String id) {
        try (Connection conn = dataSource.getConnection();
             PreparedStatement ps = conn.prepareStatement("update tb_mq_msg set retry_times = retry_times + 1, update_time = ? where id = ?")) {
            int i = 0;
            ps.setObject(++i, new java.sql.Timestamp(System.currentTimeMillis()));
            ps.setObject(++i, id);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                log.error("update mq msg retry_times failed. id:[{}]", id);
                throw new RuntimeException("update mq msg retry_times failed. id:" + id);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Deletes the given row, using a connection of its own.
     *
     * @throws RuntimeException if no row was deleted
     */
    public static void deleteMsgById(DataSource dataSource, String id) {
        try (Connection conn = dataSource.getConnection();
             PreparedStatement ps = conn.prepareStatement("delete from tb_mq_msg where id = ?")) {
            ps.setObject(1, id);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                log.error("delete mq msg failed. id:[{}]", id);
                throw new RuntimeException("delete mq msg failed. id:" + id);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    private static String toJson(Object payload) {
        try {
            return MAPPER.writeValueAsString(payload);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    private static Map<String, Object> fromJson(String payload) {
        try {
            return MAPPER.readValue(payload, new TypeReference<Map<String, Object>>() {
            });
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }
}
/**
 * Static access to the Spring ApplicationContext for code that is not managed by Spring.
 *
 * <p>The context is written during container startup and read from other threads (e.g. the
 * retry scheduler), hence the {@code volatile} field. Accessors may see {@code null} before
 * {@link #setApplicationContext} has run.
 */
@Slf4j
@Component
public class ApplicationContextUtils implements ApplicationContextAware {

    private static volatile ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContextUtils.applicationContext = applicationContext;
        log.info("=== ApplicationContextUtils init ===");
    }

    /** Returns the context, or {@code null} if this bean has not been initialised yet. */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    public static Object getBean(String name) {
        return getApplicationContext().getBean(name);
    }

    public static <T> T getBean(Class<T> clazz) {
        return getApplicationContext().getBean(clazz);
    }

    public static <T> T getBean(String name, Class<T> clazz) {
        return getApplicationContext().getBean(name, clazz);
    }

    /**
     * Finds the bean name under which {@code obj} itself (by identity) is registered.
     *
     * @param clazz type used to look up candidate beans
     * @param obj   the instance to look for
     * @return the bean name, or {@code null} if {@code obj} is not a bean of type {@code clazz}
     */
    public static String findBeanName(Class<?> clazz, Object obj) {
        Map<String, ?> beans = getApplicationContext().getBeansOfType(clazz);
        for (Map.Entry<String, ?> entry : beans.entrySet()) {
            if (entry.getValue() == obj) {
                return entry.getKey();
            }
        }
        return null;
    }

}

针对消息发送失败的情况,使用定时任务进行重试:

@Slf4j
@Component
pubilc class MqMsgSchedule implements InitializingBean {
    private static final ScheduledThreadPoolExecutor EXECUTOR =
            new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
                AtomicInteger threadCount = new AtomicInteger(0);
                @Override
                public Thread newThread(Runnable r) {
                    return new Thread(r, "mq-msg-" + threadCount.getAndIncrement() + "-" + r.hashCode());
                }
            }, new ThreadPoolExecutor.DiscardPolicy());

    @Override
    public void afterPropertiesSet() throws Exception {
        EXECUTOR.scheduleAtFixedRate(new Runnable() {
            @Override
            public void run() {
                retrySendTask();
            }
        }, 0, 5000, TimeUnit.MILLISECONDS);
    }

    public void retrySendTask() {     
        try {
            Map<String, DataSource> beans = ApplicationContextUtils.getApplicationContext().getBeansOfType(DataSource.class);
            for (Map.Entry<String, DataSource> entry : beans.entrySet()) {
                List<MqMsg> mqMsgList = MqMsgDao.listMsg(entry.getValue());
                for (MqMsg mqMsg : mqMsgList) {
                    if (mqMsg.getRetryTimes() >= MqMsgDao.MAX_RETRY_TIMES) { 
                        log.error("mqMsg retry times reach {}, id:[{}]", MqMsgDao.MAX_RETRY_TIMES, mqMsg.getId());
                    } else {
                        RocketMQTemplate rocketMQTemplate = (RocketMQTemplate) ApplicationContextUtils.getBean(mqMsg.getMqTemplateName());
                        try {
                            rocketMQTemplate.syncSend(mqMsg.getMqDestination(),
                                    mqMsg.getMessage(),
                                    mqMsg.getMqTimeout(),
                                    mqMsg.getMqDelayLevel());
                            MqMsgDao.deleteMsgById(entry.getValue(), mqMsg.getId());
                        } catch (Exception e) {
                            MqMsgDao.updateMsgRetryTimes(entry.getValue(), mqMsg.getId());
                            log.error("[task] mq send failed. mqMsg:[{}]", mqMsg, e);
                        }
                    }
                }
            }
        } catch (Exception e) {
            log.error("task error.", e);
        }
    }
}

提供给业务代码调用的工具类:

/**
 * Entry point for business code to send a RocketMQ message consistently with the current transaction.
 *
 * <p>Inside an active transaction synchronization the send is deferred to an
 * MQTransactionSynchronization (persisted before commit, sent after commit). Without one, the
 * message is sent immediately.
 */
@Slf4j
public final class MQTransactionHelper {

    private MQTransactionHelper() {
    }

    /**
     * Sends {@code message} using the producer's configured send timeout and delay level 0.
     *
     * @see #syncSend(RocketMQTemplate, String, Message, long, int)
     */
    public static void syncSend(final RocketMQTemplate rocketMQTemplate,
                                final String destination,
                                final Message message) {
        syncSend(rocketMQTemplate, destination, message,
                rocketMQTemplate.getProducer().getSendMsgTimeout(), 0);
    }

    /**
     * Sends {@code message} to {@code destination}, after commit when a transaction is active.
     *
     * <p>Outside a transaction the send is synchronous and any send exception propagates to
     * the caller. Previously such messages were silently dropped.
     */
    public static void syncSend(final RocketMQTemplate rocketMQTemplate,
                                final String destination,
                                final Message message,
                                final long timeout,
                                final int delayLevel) {
        if (TransactionSynchronizationManager.isSynchronizationActive()) {
            TransactionSynchronizationManager.registerSynchronization(new MQTransactionSynchronization(
                    rocketMQTemplate, destination, message, timeout, delayLevel
            ));
        } else {
            log.debug("no active transaction synchronization, send immediately. destination:[{}]", destination);
            rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
        }
    }

}

数据库

-- Message table used by MqMsgDao: a row is inserted in the same transaction as the
-- business data and deleted after the message has been sent successfully.
-- Rows with status 'NEW' and retry_times below MqMsgDao.MAX_RETRY_TIMES are picked
-- up again by the scheduled retry task.
CREATE TABLE `tb_mq_msg` (
  `id` VARCHAR(64) NOT NULL, -- random UUID generated when the row is inserted
  `status` VARCHAR(20) NOT NULL COMMENT '事件状态(待发布NEW)', -- MqMsgDao only writes 'NEW'
  `mq_template_name` VARCHAR(1000) NOT NULL, -- Spring bean name of the RocketMQTemplate used to (re)send
  `mq_destination` VARCHAR(1000) NOT NULL, -- destination passed to RocketMQTemplate.syncSend
  `mq_timeout` BIGINT NOT NULL, -- timeout argument passed to syncSend
  `mq_delay_level` INT NOT NULL, -- delayLevel argument passed to syncSend
  `payload` TEXT NOT NULL, -- JSON object {"payload": ..., "headers": ...} of the Spring Message
  `retry_times` INT NOT NULL, -- starts at 0, incremented after each failed retry
  `create_time` DATETIME NOT NULL,
  `update_time` DATETIME NOT NULL,
  PRIMARY KEY (`id`),
  KEY `idx_status` (`status`)
) ENGINE=INNODB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;

源码:https://github.com/jsbxyyx/rmq-transaction

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值