import com.rabbitmq.client.AMQP;
import com.rabbitmq.client.Channel;
import com.rabbitmq.client.DefaultConsumer;
import com.rabbitmq.client.Envelope;
import com.rabbitmq.client.GetResponse;
import lombok.extern.slf4j.Slf4j;
import org.springframework.amqp.rabbit.connection.Connection;
import org.springframework.amqp.rabbit.connection.ConnectionFactory;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.InitializingBean;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@Slf4j
public abstract class RabbitBatchReceiver implements InitializingBean, DisposableBean {
private ConnectionFactory connectionFactory;
private String queueName;
private int batch;
private Connection connection;
private Channel channel;
private AtomicBoolean status = new AtomicBoolean();
private Lock lock = new ReentrantLock();
private Condition condition = lock.newCondition();
private Thread thread;
public RabbitBatchReceiver(ConnectionFactory connectionFactory, String queueName, int batch) {
this.connectionFactory = connectionFactory;
this.queueName = queueName;
this.batch = batch;
}
@Override
public void afterPropertiesSet() throws Exception {
connection = connectionFactory.createConnection();
channel = connection.createChannel(true);
channel.basicQos(1);
status.set(true);
thread = new Thread(() -> {
while (status.get()) {
try {
take();
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
});
thread.start();
}
@Override
public void destroy() throws Exception {
status.set(false);
try {
channel.close();
} catch (Exception e) {
log.error(e.getMessage(), e);
}
try {
connection.close();
} catch (Exception e) {
log.error(e.getMessage(), e);
}
thread.interrupt();
}
private void take() throws Exception {
try {
//声明事务
channel.txSelect();
List<Message> list = new ArrayList<>();
GetResponse response = channel.basicGet(queueName, false);
if (null == response) {
//监听新数据
listenNewData();
} else {
list.add(new Message(response.getEnvelope(), response.getProps(), response.getBody()));
}
while (list.size() < batch && null != (response = channel.basicGet(queueName, false))) {
list.add(new Message(response.getEnvelope(), response.getProps(), response.getBody()));
}
if (!list.isEmpty()) {
try {
//process
callBatch(list);
//ack
for (Message message : list) {
channel.basicAck(message.getEnvelope().getDeliveryTag(), false);
}
} catch (Exception e) {
//unack
for (Message message : list) {
channel.basicNack(message.getEnvelope().getDeliveryTag(), false, true);
}
throw e;
}
}
//提交事务
channel.txCommit();
} catch (Exception e) {
log.error(e.getMessage(), e);
//回滚事务
channel.txRollback();
}
}
private void listenNewData() throws Exception {
String consumerTag = null;
try {
consumerTag = channel.basicConsume(queueName, false, new DefaultConsumer(channel) {
@Override
public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException {
//unack 重新走basicGet读取消息
channel.basicNack(envelope.getDeliveryTag(), false, true);
signalAll();
}
});
await();
} finally {
if (null != consumerTag) {
channel.basicCancel(consumerTag);
}
}
}
protected abstract void callBatch(List<Message> list) throws Exception;
private void await() throws Exception {
lock.lock();
try {
condition.await();
} finally {
lock.unlock();
}
}
private void signalAll() {
lock.lock();
try {
condition.signalAll();
} finally {
lock.unlock();
}
}
public static class Message {
private Envelope envelope;
private AMQP.BasicProperties properties;
private byte[] data;
public Message(Envelope envelope, AMQP.BasicProperties properties, byte[] data) {
this.envelope = envelope;
this.properties = properties;
this.data = data;
}
public Envelope getEnvelope() {
return envelope;
}
public AMQP.BasicProperties getProperties() {
return properties;
}
public byte[] getData() {
return data;
}
}
}
// 使用例子 (usage example):
/**
 * Usage example: creates a receiver that hands batches of up to 1000 messages from queue
 * {@code "test"} to {@code callBatch}, starts it, and returns it.
 *
 * <p>Outside a Spring container nothing calls {@code afterPropertiesSet()} / {@code destroy()}
 * automatically, so the receiver is started here explicitly. The caller must call
 * {@code destroy()} on the returned instance to stop it. When the receiver is declared as a
 * Spring bean instead, the container invokes both methods.
 *
 * @param factory connection factory used by the receiver
 * @return the started receiver
 * @throws IllegalStateException if the receiver could not be started
 */
public RabbitBatchReceiver test(ConnectionFactory factory) {
    RabbitBatchReceiver receiver = new RabbitBatchReceiver(factory, "test", 1000) {
        @Override
        protected void callBatch(List<Message> list) throws Exception {
            // Process the batch here (at most 1000 messages per call).
        }
    };
    try {
        receiver.afterPropertiesSet();
    } catch (Exception e) {
        // Release whatever was opened before the failure.
        try {
            receiver.destroy();
        } catch (Exception suppressed) {
            e.addSuppressed(suppressed);
        }
        throw new IllegalStateException("Failed to start RabbitBatchReceiver for queue test", e);
    }
    return receiver;
}