MQTT创建安全的连接

 MQTT 是一个轻量级的发布/订阅消息协议。公司项目中的 MQTT 服务器使用了免费的 EMQ X Broker(一个云原生的分布式物联网接入平台),无需任何开发,安装即可使用。Android 端使用的 MQTT 客户端库是 org.eclipse.paho.client.mqttv3-1.2.0.jar。

使用库创建mqtt连接的代码如下:

MyMqttClient.java


import java.io.InputStream;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;

import javax.net.ssl.SSLSocketFactory;

import org.eclipse.paho.client.mqttv3.MqttCallback;
import org.eclipse.paho.client.mqttv3.MqttClient;
import org.eclipse.paho.client.mqttv3.MqttConnectOptions;
import org.eclipse.paho.client.mqttv3.MqttException;
import org.eclipse.paho.client.mqttv3.MqttMessage;
import org.eclipse.paho.client.mqttv3.MqttSecurityException;
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence;



class MyMqttClient{
    public static class LoginInfo{
        private String hostAddr;
        private String clientID;
        private String userName;
        private String password;
        private int connctTimeout = 30;
        private int keepAliveInterval = 60;
        private CertificatesInfo certificatesInfo;
        public LoginInfo(String hostAddr, String clientID, String userName,
                String password) {
            super();
            this.hostAddr = hostAddr;
            this.clientID = clientID;
            this.userName = userName;
            this.password = password;
        }
        public void setConnctTimeout(int connctTimeout) {
            this.connctTimeout = connctTimeout;
        }
        public void setKeepAliveInterval(int keepAliveInterval) {
            this.keepAliveInterval = keepAliveInterval;
        }
        public void setCertificatesInfo(CertificatesInfo certificatesInfo) {
            this.certificatesInfo = certificatesInfo;
        }
    }
    public static class CertificatesInfo{
        private X509Certificate caCert;
        private X509Certificate clientCert;
        private PrivateKey privateKey;
        private String privateKeyPassword;
        private CertificatesInfo(X509Certificate caCert, X509Certificate clientCert,
                PrivateKey privateKey, String privateKeyPassword) {
            super();
            this.caCert = caCert;
            this.clientCert = clientCert;
            this.privateKey = privateKey;
            this.privateKeyPassword = privateKeyPassword;
        }
        
        public static CertificatesInfo creatCertificates(InputStream isCaCert, InputStream isClientCert, InputStream isClientPrivateKey, String privateKeyPwd){
            X509Certificate caCert = CertificateUtils.getX509Certificate(isCaCert);
            X509Certificate clientCert = CertificateUtils.getX509Certificate(isClientCert);
            PrivateKey privateKey = CertificateUtils.getPkcs8PrivateKey(isClientPrivateKey);
            if(caCert != null && clientCert != null && privateKey != null){
                return new CertificatesInfo(caCert, clientCert, privateKey, privateKeyPwd);
            }else{
                return null;
            }
        }
    }

    private MqttClient mClient = null;
    public void connect(LoginInfo loginInfo, MqttCallback cb) throws CreateSSLSocketException, MqttSecurityException, MqttException{
        MqttConnectOptions options = new MqttConnectOptions();
        //是否保存离线消息
        //false保存离线消息,下次上线可以接收离线时接收到的消息;
        //true不保存离线消息,下次上线不接收离线时接收到的消息;
        options.setCleanSession(false);
        //设置用户名密码
        options.setUserName(loginInfo.userName);
        options.setPassword(loginInfo.password.toCharArray());
        // 设置连接超时时间30s
        options.setConnectionTimeout(loginInfo.connctTimeout);
        // 设置会话心跳时间60s
        options.setKeepAliveInterval(loginInfo.keepAliveInterval);
        
        CertificatesInfo certificatesInfo = loginInfo.certificatesInfo;
        //读取ssl加密连接需要的ca证书、客户端证书、客户端秘钥
        SSLSocketFactory sslSF = CertificateUtils.createSSLSocketFactory(certificatesInfo.caCert, 
                certificatesInfo.clientCert, 
                certificatesInfo.privateKey, 
                certificatesInfo.privateKeyPassword);
        options.setSocketFactory(sslSF);
 
        MemoryPersistence memoryPersistence = new MemoryPersistence();
        mClient = new MqttClient(loginInfo.hostAddr, loginInfo.clientID, memoryPersistence);
        mClient.setCallback(cb);
        mClient.connect(options);
    }

    public void disconnect() {
        if(mClient != null){
            try {
                mClient.disconnect();
            } catch (MqttException e) {
                e.printStackTrace();
            }
            mClient = null;
        }
    }
    public boolean isConnected(){
        return (null != mClient)&&mClient.isConnected();
    }
    
    public void subscribe(String topic, int qos) throws MqttClientCallException {
        if(null == mClient){
            throw new MqttClientCallException("mqttClientIsNull");
        }
        try {
            mClient.subscribe(topic, qos);
        } catch (MqttException e) {
            throw new MqttClientCallException(e);
        }
    }
    public void subscribe(String[] topics, int qos) throws MqttClientCallException {
        if(null == mClient){
            throw new MqttClientCallException("mqttClientIsNull");
        }
        try {
            int[] qoss = new int[topics.length];
            for (int i = 0; i < qoss.length; i++) {
                qoss[i] = qos;
            }
            mClient.subscribe(topics, qoss);
        } catch (MqttException e) {
            throw new MqttClientCallException(e);
        }
    }
    public void unsubscribe(String topic) throws MqttClientCallException {
        if(null == mClient){
            throw new MqttClientCallException("mqttClientIsNull");
        }
        try {
            mClient.unsubscribe(topic);
        } catch (MqttException e) {
            throw new MqttClientCallException(e);
        }
    }
    public void unsubscribe(String[] topics) throws MqttClientCallException {
        if(null == mClient){
            throw new MqttClientCallException("mqttClientIsNull");
        }
        try {
            mClient.unsubscribe(topics);
        } catch (MqttException e) {
            throw new MqttClientCallException(e);
        }
    }
    
    public void publish(String topic, String message) throws MqttClientCallException {
        if(null == mClient){
            throw new MqttClientCallException("mqttClientIsNull");
        }

        MqttMessage mqttMessage = new MqttMessage();
        //setRetained设置保留消息
        //false不保留消息,发布一个主题后,只有当前有订阅者存在的情况下才接收的到消息
        //true 保留消息,发布一个主题后,在发送给当前订阅者后,还会存到服务器,如果有
        //新的订阅者上线也会把该消息发给新的订阅者
        mqttMessage.setRetained(false);
        mqttMessage.setQos(2);
        mqttMessage.setPayload(message.getBytes());

        try {
            mClient.publish(topic, mqttMessage);
        } catch (Exception e) {
            throw new MqttClientCallException(e);
        }
    }
    
    @SuppressWarnings("serial")
    public class MqttClientCallException extends Exception{
        public MqttClientCallException(String msg){
            super(msg);
        }
        public MqttClientCallException(Throwable throwable){
            super(throwable);
        }
    }
}

CertificateUtils.java


import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.security.KeyFactory;
import java.security.KeyStore;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.security.interfaces.RSAPrivateKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

import android.os.Build;
import android.text.TextUtils;
import android.util.Base64;

/** Helpers for loading TLS certificates and keys and building an {@link SSLSocketFactory}. */
class CertificateUtils {
    /**
     * Parses one X.509 certificate from {@code inStream}.
     *
     * <p>The data may be binary DER or Base64 text between the
     * {@code -----BEGIN CERTIFICATE-----} and {@code -----END CERTIFICATE-----} markers.
     * This method does not close the stream.
     *
     * @param inStream certificate data; may be {@code null}
     * @return the certificate, or {@code null} if the stream is {@code null} or cannot be parsed
     */
    public static X509Certificate getX509Certificate(InputStream inStream){
        if (inStream == null) {
            return null;
        }
        try {
            CertificateFactory cf = CertificateFactory.getInstance("X.509");
            return (X509Certificate) cf.generateCertificate(inStream);
        } catch (CertificateException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Reads an unencrypted PEM "PRIVATE KEY" (PKCS#8) containing an RSA key.
     *
     * <p>Blank lines and lines starting with {@code "--"} (the BEGIN/END armour) are skipped.
     * Reading stops at the first {@code -----END} line. The stream is closed after reading.
     * PKCS#1 ("BEGIN RSA PRIVATE KEY") or encrypted keys are not supported; they yield
     * {@code null}.
     *
     * @param inStream PEM text; may be {@code null}
     * @return the key, or {@code null} if the stream is {@code null}, unreadable or not a valid key
     */
    public static PrivateKey getPkcs8PrivateKey(InputStream inStream){
        if (inStream == null) {
            return null;
        }
        StringBuilder base64 = new StringBuilder();
        // PEM is ASCII; try-with-resources closes the reader (and the stream) even on failure.
        try (BufferedReader br = new BufferedReader(
                new InputStreamReader(inStream, java.nio.charset.StandardCharsets.US_ASCII))) {
            String line;
            while ((line = br.readLine()) != null) {
                line = line.trim();
                if (line.startsWith("-----END")) {
                    // Ignore anything appended after the key block.
                    break;
                }
                if (!line.isEmpty() && !line.startsWith("--")) {
                    base64.append(line);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
        if (base64.length() == 0) {
            return null;
        }
        try {
            byte[] der = Base64.decode(base64.toString(), Base64.NO_WRAP);
            KeyFactory keyFactory = KeyFactory.getInstance("RSA");
            return keyFactory.generatePrivate(new PKCS8EncodedKeySpec(der));
        } catch (IllegalArgumentException e) {
            // Thrown by Base64.decode for malformed input.
            e.printStackTrace();
        } catch (NoSuchAlgorithmException e) {
            e.printStackTrace();
        } catch (InvalidKeySpecException e) {
            e.printStackTrace();
        }
        return null;
    }

    /** Wraps any failure that occurs while building the SSL socket factory. */
    @SuppressWarnings("serial")
    public static class CreateSSLSocketException extends Exception{
        public CreateSSLSocketException(String detailMessage, Throwable throwable) {
            super(detailMessage, throwable);
        }
    }

    /**
     * Builds a TLSv1.2 socket factory for mutual TLS.
     *
     * <p>The client presents {@code crtFile} and {@code privateKey}. The server's certificate
     * chain must validate against {@code caCert}. Hostname verification is not part of this
     * factory and is left to the caller or transport.
     *
     * @param caCert CA that issued the server certificate (may be self-signed)
     * @param crtFile client certificate
     * @param privateKey private key matching {@code crtFile}
     * @param pwd password protecting the key entry in the in-memory key store; {@code null} is
     *     treated as an empty password
     * @return the socket factory
     * @throws CreateSSLSocketException if any key-store or SSL step fails
     */
    public static SSLSocketFactory createSSLSocketFactory(final X509Certificate caCert, final X509Certificate crtFile, PrivateKey privateKey, final String pwd ) throws CreateSSLSocketException  {
        // The password only guards the entry inside the in-memory KeyStore below, so a missing
        // password is treated as empty instead of failing with a NullPointerException.
        final char[] keyPassword = (pwd == null) ? new char[0] : pwd.toCharArray();
        try {
            // Client key and certificate, sent to the server so it can authenticate us.
            KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
            ks.load(null, null);
            ks.setCertificateEntry("certificate", crtFile);
            ks.setKeyEntry("private-key", privateKey, keyPassword,
                    new java.security.cert.Certificate[] { crtFile });
            KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
            kmf.init(ks, keyPassword);

            // Trust store holding only caCert; the server's chain must verify against it.
            KeyStore caKs = KeyStore.getInstance(KeyStore.getDefaultType());
            caKs.load(null, null);
            caKs.setCertificateEntry("ca-certificate", caCert);
            TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
            tmf.init(caKs);
            TrustManager[] trustManagers = tmf.getTrustManagers();

            // Always validate the server against the supplied CA, on every API level.
            // An explicit CA-backed TrustManager does not depend on the system/user trust store
            // that Android 7+ restricts. An accept-everything X509TrustManager would let any
            // man-in-the-middle impersonate the broker, so do not substitute one.
            SSLContext context = SSLContext.getInstance("TLSv1.2");
            context.init(kmf.getKeyManagers(), trustManagers, null);
            return context.getSocketFactory();
        } catch (Exception e) {
            throw new CreateSSLSocketException("CreateSSLSocketException", e);
        }
    }
}

注意:

1. connect() 是阻塞调用。连接失败后不应在同一个客户端上反复调用 connect(),而应调用 reconnect();否则线程数会持续增长,最终导致 OOM。

2. Android 7.0(API 24)及以上版本收紧了证书验证,自签名证书在 App 中可能无法通过验证。上述代码在这些版本上使用了自定义的 TrustManager,它不做任何校验,相当于无条件信任服务器证书。这会让连接暴露于中间人攻击,不应用于生产环境。另外也可以参考 Android 官网关于网络配置的方式(我没有验证过):网络安全配置  |  Android 开发者  |  Android Developers

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值