场景
- 用户下单并支付
- 支付成功后发送消息,开通文章查看权限
// 伪代码
@Transactional(rollbackFor=Exception.class)
public void pay(long uid, String orderNO) {
Order order = orderService.selectOrder(uid, orderNO)
if (order != null) {
String status = "paid";
orderDao.updateStatus(uid, orderNo, status);
rocketMQTemplate.send("order:status", message(uid, orderNo, order.itemId, status));
}
}
public class OrderStatusArticleListener implements RocketMQListener {
public void onMessage(message) {
Order order = orderService.selectOrder(message.uid, message.orderNo)
if (order == null) {
throw new RuntimeException("order not found. " + message.orderNo)
}
if (order.status != "paid") {
throw new RuntimeException("order not paid. " + message.orderNo)
}
// 授权
articleService.authorize(message.uid, message.itemId)
}
}
上面的例子中,消费者查询订单时,可能会查到订单仍处于未支付状态。
为什么会这样呢?
这是因为我们在 Spring 事务内同步发送了消息:事务还没有提交,消息就已经到达消费者端并开始被消费了。由于事务尚未提交,消费者查询到的仍是更新前的订单状态。
解决思路:
- 增加消息表,在业务事务中同步将消息落库,标记为待处理
- 事务提交后再发送 MQ 消息
- MQ 发送成功后处理落库的消息记录,标记为处理完成(下面的实现中直接删除该记录)
由于是在 Spring 环境中,我们使用 Spring 的 TransactionSynchronizationManager#registerSynchronization 注册事务同步回调:
// Defer the send until the surrounding transaction commits. Without an active synchronization
// there is no transaction to wait for, so send immediately instead of silently dropping it.
if (TransactionSynchronizationManager.isSynchronizationActive()) {
    TransactionSynchronizationManager.registerSynchronization(new MQTransactionSynchronization(
            rocketMQTemplate, destination, message, timeout, delayLevel
    ));
} else {
    rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
}
我们自定义一个名为 MQTransactionSynchronization 的 TransactionSynchronization 实现:在 beforeCommit 中把消息写入消息表(与业务数据处于同一事务),在 afterCommit 中发送消息,发送成功后删除该消息记录。
@Slf4j
public class MQTransactionSynchronization implements TransactionSynchronization {
private DataSource dataSource;
private ConnectionHolder connectionHolder;
private String id;
private RocketMQTemplate rocketMQTemplate;
private String destination;
private Message message;
private long timeout;
private int delayLevel;
public MQTransactionSynchronization(RocketMQTemplate rocketMQTemplate, String destination, Message message, long timeout, int delayLevel) {
this.rocketMQTemplate = rocketMQTemplate;
this.destination = destination;
this.message = message;
this.timeout = timeout;
this.delayLevel = delayLevel;
}
@Override
public void beforeCompletion() {}
@Override
public void beforeCommit(boolean readOnly) {
Map<Object, Object> resourceMap = TransactionSynchronizationManager.getResourceMap();
for (Map.Entry<Object, Object> entry : resourceMap.entrySet()) {
Object key = entry.getKey();
Object value = entry.getValue();
if (value instanceof ConnectionHolder) {
this.dataSource = (DataSource) key;
this.connectionHolder = (ConnectionHolder) value;
break;
}
}
if (connectionHolder == null) {
log.warn("connectionHolder is null");
return;
}
this.id = UUID.randomUUID().toString();
final String mqTemplateName = ApplicationContextUtils.findBeanName(rocketMQTemplate.getClass(), rocketMQTemplate);
MqMsgDao.insertMsg(connectionHolder, id, mqTemplateName, destination, message, timeout, delayLevel);
}
@Override
public void afterCommit() {
log.debug("afterCommit {}", TransactionSynchronizationManager.getCurrentTransactionName());
try {
rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
MqMsgDao.deleteMsgById(dataSource, this.id);
} catch (Exception e) {
log.error("mq send message failed. topic:[{}], message:[{}]", destination, message, e);
}
}
@Override
public void afterCompletion(int status) {
log.debug("afterCompletion {} : {}", TransactionSynchronizationManager.getCurrentTransactionName(), status);
rocketMQTemplate = null;
destination = null;
message = null;
connectionHolder = null;
dataSource = null;
id = null;
}
}
/**
 * Plain-JDBC access to the {@code tb_mq_msg} outbox table.
 *
 * <p>{@link #insertMsg} runs on the caller's transactional connection and leaves it open; the
 * other methods borrow their own connection from the given {@link DataSource} and close it.
 */
@Slf4j
public class MqMsgDao {
    /** Status of a message that has not been sent successfully yet. */
    public static final String STATUS_NEW = "NEW";
    /** Rows whose retry_times reached this value are no longer returned by {@link #listMsg}. */
    public static final Integer MAX_RETRY_TIMES = 5;
    private static final String INSERT_SQL = "insert into tb_mq_msg "
            + "(id, status, mq_template_name, mq_destination, mq_timeout, mq_delay_level, "
            + "payload, retry_times, create_time, update_time) "
            + "values (?,?,?,?,?,?,?,?,?,?)";
    private static final JsonMapper MAPPER = JsonMapper.builder()
            .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
            .enable(MapperFeature.PROPAGATE_TRANSIENT_MARKER)
            .build();

    private MqMsgDao() {
    }

    /**
     * Returns up to 100 pending (status NEW, retry_times below the limit) messages.
     *
     * <p>NOTE(review): rows whose after-commit send is still in flight are also returned, so the
     * retry task may deliver a message twice; consumers should be idempotent.
     *
     * @throws RuntimeException wrapping any {@link SQLException}
     */
    public static List<MqMsg> listMsg(DataSource dataSource) {
        try (Connection conn = dataSource.getConnection();
             PreparedStatement ps = conn.prepareStatement(
                     "select * from tb_mq_msg where status = ? and retry_times < ? limit 100")) {
            ps.setObject(1, STATUS_NEW);
            ps.setObject(2, MAX_RETRY_TIMES);
            try (ResultSet rs = ps.executeQuery()) {
                List<MqMsg> list = new ArrayList<>(100);
                while (rs.next()) {
                    list.add(toMqMsg(rs));
                }
                return list;
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Inserts a NEW outbox row using the transaction's connection.
     *
     * <p>The connection is intentionally NOT closed: it belongs to the surrounding transaction.
     *
     * @throws RuntimeException when the insert fails or affects no row
     */
    public static void insertMsg(ConnectionHolder connectionHolder,
                                 String id,
                                 String mqTemplateName,
                                 String mqDestination,
                                 Message message,
                                 long mqTimeout,
                                 int mqDelayLevel) {
        Connection connection = connectionHolder.getConnection();
        Map<String, Object> payload = new HashMap<>();
        payload.put("payload", message.getPayload());
        payload.put("headers", message.getHeaders());
        try (PreparedStatement ps = connection.prepareStatement(INSERT_SQL)) {
            // java.sql.Timestamp is the JDBC-standard type for DATETIME; java.util.Date is not.
            java.sql.Timestamp now = new java.sql.Timestamp(System.currentTimeMillis());
            int i = 0;
            ps.setObject(++i, id);
            ps.setObject(++i, STATUS_NEW);
            ps.setObject(++i, mqTemplateName);
            ps.setObject(++i, mqDestination);
            ps.setObject(++i, mqTimeout);
            ps.setObject(++i, mqDelayLevel);
            ps.setObject(++i, toJson(payload));
            ps.setObject(++i, 0);
            ps.setTimestamp(++i, now);
            ps.setTimestamp(++i, now);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                throw new RuntimeException("insert mq msg affect : " + affect);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /** Increments retry_times of one row; throws if the row does not exist. */
    public static void updateMsgRetryTimes(DataSource dataSource, String id) {
        try (Connection conn = dataSource.getConnection();
             PreparedStatement ps = conn.prepareStatement(
                     "update tb_mq_msg set retry_times = retry_times + 1, update_time = ? where id = ?")) {
            ps.setTimestamp(1, new java.sql.Timestamp(System.currentTimeMillis()));
            ps.setObject(2, id);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                log.error("update mq msg retry_times failed. id:[{}]", id);
                throw new RuntimeException("update mq msg retry_times failed. id:" + id);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /** Deletes one row by id; throws if no row was deleted. */
    public static void deleteMsgById(DataSource dataSource, String id) {
        try (Connection conn = dataSource.getConnection();
             PreparedStatement ps = conn.prepareStatement("delete from tb_mq_msg where id = ?")) {
            ps.setObject(1, id);
            int affect = ps.executeUpdate();
            if (affect <= 0) {
                log.error("delete mq msg failed. id:[{}]", id);
                throw new RuntimeException("delete mq msg failed. id:" + id);
            }
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    /** Maps the current result-set row to an {@link MqMsg}, rebuilding the message from JSON. */
    @SuppressWarnings("unchecked") // headers were stored as a JSON object, so Jackson yields a Map
    private static MqMsg toMqMsg(ResultSet rs) throws SQLException {
        MqMsg mqMsg = new MqMsg();
        mqMsg.setId(rs.getString("id"));
        mqMsg.setStatus(rs.getString("status"));
        mqMsg.setMqTemplateName(rs.getString("mq_template_name"));
        mqMsg.setMqDestination(rs.getString("mq_destination"));
        mqMsg.setMqTimeout(rs.getLong("mq_timeout"));
        mqMsg.setMqDelayLevel(rs.getInt("mq_delay_level"));
        Map<String, Object> map = fromJson(rs.getString("payload"));
        GenericMessage<Object> message = new GenericMessage<>(map.get("payload"), (Map<String, Object>) map.get("headers"));
        mqMsg.setMessage(message);
        mqMsg.setRetryTimes(rs.getInt("retry_times"));
        mqMsg.setCreateTime(rs.getTimestamp("create_time"));
        mqMsg.setUpdateTime(rs.getTimestamp("update_time"));
        return mqMsg;
    }

    private static String toJson(Object payload) {
        try {
            return MAPPER.writeValueAsString(payload);
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    private static Map<String, Object> fromJson(String payload) {
        try {
            return MAPPER.readValue(payload, new TypeReference<Map<String, Object>>() {
            });
        } catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }
}
/**
 * Static access to the Spring {@link ApplicationContext}, used from code that is not a bean
 * (transaction synchronizations, the retry scheduler thread).
 */
@Slf4j
@Component
public class ApplicationContextUtils implements ApplicationContextAware {
    // volatile: written by the Spring startup thread, read from the mq-msg scheduler thread.
    private static volatile ApplicationContext applicationContext;

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContextUtils.applicationContext = applicationContext;
        log.info("=== ApplicationContextUtils init ===");
    }

    /** Returns the context, or null if this bean has not been initialized yet. */
    public static ApplicationContext getApplicationContext() {
        return applicationContext;
    }

    public static Object getBean(String name) {
        return requireContext().getBean(name);
    }

    public static <T> T getBean(Class<T> clazz) {
        return requireContext().getBean(clazz);
    }

    public static <T> T getBean(String name, Class<T> clazz) {
        return requireContext().getBean(name, clazz);
    }

    /**
     * Returns the name under which {@code obj} (compared by identity) is registered among the
     * beans of type {@code clazz}, or null if it is not a registered bean.
     */
    public static String findBeanName(Class<?> clazz, Object obj) {
        Map<String, ?> beans = requireContext().getBeansOfType(clazz);
        for (Map.Entry<String, ?> entry : beans.entrySet()) {
            if (entry.getValue() == obj) {
                return entry.getKey();
            }
        }
        return null;
    }

    /** Fails with a clear message instead of an NPE when used before context initialization. */
    private static ApplicationContext requireContext() {
        ApplicationContext context = applicationContext;
        if (context == null) {
            throw new IllegalStateException("ApplicationContext is not initialized yet");
        }
        return context;
    }
}
针对消息发送失败的情况,使用定时任务扫描消息表并重试发送
@Slf4j
@Component
pubilc class MqMsgSchedule implements InitializingBean {
private static final ScheduledThreadPoolExecutor EXECUTOR =
new ScheduledThreadPoolExecutor(1, new ThreadFactory() {
AtomicInteger threadCount = new AtomicInteger(0);
@Override
public Thread newThread(Runnable r) {
return new Thread(r, "mq-msg-" + threadCount.getAndIncrement() + "-" + r.hashCode());
}
}, new ThreadPoolExecutor.DiscardPolicy());
@Override
public void afterPropertiesSet() throws Exception {
EXECUTOR.scheduleAtFixedRate(new Runnable() {
@Override
public void run() {
retrySendTask();
}
}, 0, 5000, TimeUnit.MILLISECONDS);
}
public void retrySendTask() {
try {
Map<String, DataSource> beans = ApplicationContextUtils.getApplicationContext().getBeansOfType(DataSource.class);
for (Map.Entry<String, DataSource> entry : beans.entrySet()) {
List<MqMsg> mqMsgList = MqMsgDao.listMsg(entry.getValue());
for (MqMsg mqMsg : mqMsgList) {
if (mqMsg.getRetryTimes() >= MqMsgDao.MAX_RETRY_TIMES) {
log.error("mqMsg retry times reach {}, id:[{}]", MqMsgDao.MAX_RETRY_TIMES, mqMsg.getId());
} else {
RocketMQTemplate rocketMQTemplate = (RocketMQTemplate) ApplicationContextUtils.getBean(mqMsg.getMqTemplateName());
try {
rocketMQTemplate.syncSend(mqMsg.getMqDestination(),
mqMsg.getMessage(),
mqMsg.getMqTimeout(),
mqMsg.getMqDelayLevel());
MqMsgDao.deleteMsgById(entry.getValue(), mqMsg.getId());
} catch (Exception e) {
MqMsgDao.updateMsgRetryTimes(entry.getValue(), mqMsg.getId());
log.error("[task] mq send failed. mqMsg:[{}]", mqMsg, e);
}
}
}
}
} catch (Exception e) {
log.error("task error.", e);
}
}
}
提供给业务代码使用的调用工具类
/**
 * Entry point for business code: sends a RocketMQ message after the current transaction commits.
 */
@Slf4j
public final class MQTransactionHelper {
    private MQTransactionHelper() {
    }

    /**
     * Same as {@link #syncSend(RocketMQTemplate, String, Message, long, int)}, using the
     * producer's configured send timeout and delay level 0.
     */
    public static void syncSend(final RocketMQTemplate rocketMQTemplate,
                                final String destination,
                                final Message message) {
        syncSend(rocketMQTemplate, destination, message,
                rocketMQTemplate.getProducer().getSendMsgTimeout(), 0);
    }

    /**
     * Registers an after-commit send when transaction synchronization is active.
     *
     * <p>Otherwise there is no transaction to wait for, so the message is sent immediately and
     * any send exception propagates to the caller (previously the message was silently dropped).
     */
    public static void syncSend(final RocketMQTemplate rocketMQTemplate,
                                final String destination,
                                final Message message,
                                final long timeout,
                                final int delayLevel) {
        if (TransactionSynchronizationManager.isSynchronizationActive()) {
            TransactionSynchronizationManager.registerSynchronization(new MQTransactionSynchronization(
                    rocketMQTemplate, destination, message, timeout, delayLevel
            ));
        } else {
            log.debug("no active transaction synchronization, sending immediately. topic:[{}]", destination);
            rocketMQTemplate.syncSend(destination, message, timeout, delayLevel);
        }
    }
}
数据库表结构
-- Outbox table: a row is written in the business transaction and deleted once the message is
-- sent; rows left with status NEW are retried by the scheduler until retry_times hits the limit.
CREATE TABLE `tb_mq_msg` (
`id` VARCHAR(64) NOT NULL,
`status` VARCHAR(20) NOT NULL COMMENT '事件状态(待发布NEW)',
`mq_template_name` VARCHAR(1000) NOT NULL,
`mq_destination` VARCHAR(1000) NOT NULL,
`mq_timeout` BIGINT NOT NULL,
`mq_delay_level` INT NOT NULL,
`payload` TEXT NOT NULL,
`retry_times` INT NOT NULL,
`create_time` DATETIME NOT NULL,
`update_time` DATETIME NOT NULL,
PRIMARY KEY (`id`),
-- Covers the pending-row scan: WHERE status = ? AND retry_times < ?
KEY `idx_status_retry_times` (`status`, `retry_times`)
) ENGINE=INNODB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci;