mTLS: Netty单向/双向TLS Demo完整代码

NettyHelper.java: 主要用于创建EventLoopGroup和判断是否支持Epoll。

package org.example.netty;

import io.netty.channel.EventLoopGroup;
import io.netty.channel.epoll.Epoll;
import io.netty.channel.epoll.EpollEventLoopGroup;
import io.netty.channel.epoll.EpollSocketChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.http2.Http2SecurityUtil;
import io.netty.handler.ssl.*;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.concurrent.DefaultThreadFactory;

import javax.net.ssl.SSLException;
import java.security.cert.CertificateException;
import java.util.concurrent.ThreadFactory;


public class NettyHelper {
    /** System property that opts in to the native epoll transport; epoll is off unless it is set to "true". */
    static final String NETTY_EPOLL_ENABLE_KEY = "netty.epoll.enable";

    static final String OS_NAME_KEY = "os.name";

    /** Substring searched for (case-insensitively) in {@code os.name} to detect Linux. */
    static final String OS_LINUX_PREFIX = "linux";

    /**
     * Creates an event loop group backed by daemon threads, using epoll when {@link #shouldEpoll()} allows it
     * and NIO otherwise.
     *
     * @param threads           number of event loop threads
     * @param threadFactoryName name prefix for the created threads
     * @return an {@link EpollEventLoopGroup} or a {@link NioEventLoopGroup}
     */
    public static EventLoopGroup eventLoopGroup(int threads, String threadFactoryName) {
        // daemon = true: these threads never keep the JVM alive on their own
        ThreadFactory threadFactory = new DefaultThreadFactory(threadFactoryName, true);
        if (shouldEpoll()) {
            return new EpollEventLoopGroup(threads, threadFactory);
        }
        return new NioEventLoopGroup(threads, threadFactory);
    }

    /**
     * Decides whether the native epoll transport should be used.
     *
     * @return {@code true} only if the {@value #NETTY_EPOLL_ENABLE_KEY} property is "true", the OS name
     *         contains "linux", and Netty reports epoll as available
     */
    public static boolean shouldEpoll() {
        if (!Boolean.getBoolean(NETTY_EPOLL_ENABLE_KEY)) {
            return false;
        }
        // Locale.ROOT keeps the case-folding independent of the default locale.
        String osName = System.getProperty(OS_NAME_KEY, "");
        return osName.toLowerCase(java.util.Locale.ROOT).contains(OS_LINUX_PREFIX) && Epoll.isAvailable();
    }

    /**
     * Returns the client socket channel class that matches {@link #eventLoopGroup(int, String)}.
     *
     * @return {@link EpollSocketChannel} when epoll is used, otherwise {@link NioSocketChannel}
     */
    public static Class<? extends SocketChannel> socketChannelClass() {
        return shouldEpoll() ? EpollSocketChannel.class : NioSocketChannel.class;
    }
}

SslContexts: 创建SslContext对象的工具类

package org.example.netty.tls;

import io.netty.handler.codec.http2.Http2SecurityUtil;
import io.netty.handler.ssl.*;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import lombok.extern.slf4j.Slf4j;

import javax.net.ssl.SSLException;
import java.io.*;
import java.net.MalformedURLException;
import java.security.Provider;
import java.security.Security;
import java.security.cert.CertificateException;

@Slf4j
public class SslContexts {

    /** Protocol versions enabled on every context built by this class, most preferred first. */
    private static final String[] TLS_PROTOCOLS = {"TLSv1.3", "TLSv1.2"};

    /**
     * Creates a one-way TLS client context that accepts ANY server certificate.
     * <p>
     * Uses {@link InsecureTrustManagerFactory}, so the server is not authenticated. Only use this against
     * test servers such as the one from {@link #createTlsServerSslContext()}.
     *
     * @return a client {@link SslContext}
     * @throws SSLException if Netty fails to build the context
     */
    public static SslContext createTlsClientSslContext() throws SSLException {
        return SslContextBuilder.forClient()
                .sslProvider(findSslProvider())
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .protocols(TLS_PROTOCOLS)
                .build();
    }

    /**
     * Creates a server context backed by a freshly generated, temporary self-signed certificate.
     *
     * @return a server {@link SslContext} that does not request client certificates
     * @throws CertificateException if the self-signed certificate cannot be generated
     * @throws SSLException         if Netty fails to build the context
     */
    public static SslContext createTlsServerSslContext() throws CertificateException, SSLException {
        SelfSignedCertificate cert = new SelfSignedCertificate();
        try {
            // forServer(File, File) parses both PEM files immediately, so the temp files are not needed afterwards.
            return SslContextBuilder.forServer(cert.certificate(), cert.privateKey())
                    .sslProvider(findSslProvider())
                    .protocols(TLS_PROTOCOLS)
                    .build();
        } finally {
            cert.delete();
        }
    }

    /** One-way TLS server context from PEM files (no client authentication). */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile) {
        return createServerSslContext(keyCertChainFile, keyFile, null, null);
    }

    /** mTLS server context: client certificates are required and verified against {@code trustCertCollection}. */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile, File trustCertCollection) {
        return createServerSslContext(keyCertChainFile, keyFile, null, trustCertCollection);
    }

    /** One-way TLS server context from PEM files with a password-protected private key. */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile, String keyPassword) {
        return createServerSslContext(keyCertChainFile, keyFile, keyPassword, null);
    }

    /**
     * Server context from PEM files. Client authentication is required exactly when
     * {@code trustCertCollection} is non-null.
     */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile, String keyPassword, File trustCertCollection) {
        return createServerSslContext(keyCertChainFile, keyFile, keyPassword, trustCertCollection, trustCertCollection != null);
    }

    /**
     * Server context from PEM files.
     *
     * @param keyCertChainFile    server certificate chain (PEM)
     * @param keyFile             server private key (PKCS#8 PEM)
     * @param keyPassword         private key password, or {@code null} if the key is not encrypted
     * @param trustCertCollection CA certificates used to verify client certificates, or {@code null}
     * @param requireClientAuth   whether to demand a client certificate ({@link ClientAuth#REQUIRE})
     * @return the built context
     * @throws IllegalArgumentException if a file cannot be read
     * @throws IllegalStateException    if Netty fails to build the context
     */
    public static SslContext createServerSslContext(File keyCertChainFile, File keyFile, String keyPassword, File trustCertCollection, boolean requireClientAuth) {
        try (InputStream keyCertChainInputStream = openInputStream(keyCertChainFile);
             InputStream keyInputStream = openInputStream(keyFile);
             InputStream trustCertCollectionInputStream = openInputStream(trustCertCollection)) {
            return createServerSslContext(keyCertChainInputStream, keyInputStream, keyPassword,
                    trustCertCollectionInputStream, requireClientAuth);
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not find certificate file or the certificate is invalid.", e);
        }
    }

    /** One-way TLS server context from PEM streams (no client authentication). */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream) {
        return createServerSslContext(keyCertChainInputStream, keyInputStream, null, null);
    }

    /** mTLS server context: client certificates are required and verified against {@code trustCertCollection}. */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, InputStream trustCertCollection) {
        return createServerSslContext(keyCertChainInputStream, keyInputStream, null, trustCertCollection);
    }

    /** One-way TLS server context from PEM streams with a password-protected private key. */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword) {
        return createServerSslContext(keyCertChainInputStream, keyInputStream, keyPassword, null);
    }

    /**
     * Server context from PEM streams. Client authentication is required exactly when
     * {@code trustCertCollection} is non-null.
     */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword, InputStream trustCertCollection) {
        return createServerSslContext(keyCertChainInputStream, keyInputStream, keyPassword, trustCertCollection, trustCertCollection != null);
    }

    /**
     * Server context from PEM streams. The streams are read but not closed here.
     *
     * @param keyCertChainInputStream server certificate chain (PEM)
     * @param keyInputStream          server private key (PKCS#8 PEM)
     * @param keyPassword             private key password, or {@code null}
     * @param trustCertCollection     CA certificates used to verify client certificates, or {@code null}
     * @param requireClientAuth       whether to demand a client certificate ({@link ClientAuth#REQUIRE})
     * @return the built context
     * @throws IllegalStateException if Netty fails to build the context
     */
    public static SslContext createServerSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword, InputStream trustCertCollection, boolean requireClientAuth) {
        SslContextBuilder builder = keyPassword != null
                ? SslContextBuilder.forServer(keyCertChainInputStream, keyInputStream, keyPassword)
                : SslContextBuilder.forServer(keyCertChainInputStream, keyInputStream);
        if (trustCertCollection != null) {
            builder.trustManager(trustCertCollection);
        }
        if (requireClientAuth) {
            if (trustCertCollection == null) {
                log.warn("Client authentication is required but no trust collection was given; "
                        + "client certificates will be checked against the provider's default trust store.");
            }
            builder.clientAuth(ClientAuth.REQUIRE);
        }
        return buildWithDefaults(builder);
    }

    /**
     * Client context that accepts ANY server certificate (see {@link #createTlsClientSslContext()}).
     *
     * @throws IllegalStateException if Netty fails to build the context
     */
    public static SslContext createClientSslContext() {
        return buildWithDefaults(SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE));
    }

    /** One-way TLS client context that verifies the server against {@code trustCertCollection}. */
    public static SslContext createClientSslContext(File trustCertCollection) {
        return createClientSslContext(null, null, null, trustCertCollection);
    }

    /** mTLS client context: presents the given certificate/key and verifies the server against the trust collection. */
    public static SslContext createClientSslContext(File keyCertChainFile, File keyFile, File trustCertCollection) {
        return createClientSslContext(keyCertChainFile, keyFile, null, trustCertCollection);
    }

    /**
     * Client context from PEM files; any argument may be {@code null}.
     *
     * @throws IllegalArgumentException if a file cannot be read, or only one of cert chain / key is given
     * @throws IllegalStateException    if Netty fails to build the context
     */
    public static SslContext createClientSslContext(File keyCertChainFile, File keyFile, String keyPassword, File trustCertCollectionFile) {
        try (InputStream keyCertChainInputStream = openInputStream(keyCertChainFile);
             InputStream keyInputStream = openInputStream(keyFile);
             InputStream trustCertCollectionInputStream = openInputStream(trustCertCollectionFile)) {
            return createClientSslContext(keyCertChainInputStream, keyInputStream, keyPassword, trustCertCollectionInputStream);
        } catch (IOException e) {
            throw new IllegalArgumentException("Could not find certificate file or the certificate is invalid.", e);
        }
    }

    /** One-way TLS client context that verifies the server against {@code trustCertCollection}. */
    public static SslContext createClientSslContext(InputStream trustCertCollection) {
        return createClientSslContext(null, null, null, trustCertCollection);
    }

    /** mTLS client context from PEM streams. */
    public static SslContext createClientSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, InputStream trustCertCollection) {
        return createClientSslContext(keyCertChainInputStream, keyInputStream, null, trustCertCollection);
    }

    /**
     * Client context from PEM streams. The streams are read but not closed here.
     *
     * @param keyCertChainInputStream client certificate chain, or {@code null} for no client certificate
     * @param keyInputStream          client private key, or {@code null} for no client certificate
     * @param keyPassword             private key password, or {@code null}
     * @param trustCertCollection     CA certificates used to verify the server, or {@code null} for the default
     * @return the built context
     * @throws IllegalArgumentException if exactly one of the cert chain and the key is given
     * @throws IllegalStateException    if Netty fails to build the context
     */
    public static SslContext createClientSslContext(InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword, InputStream trustCertCollection) {
        SslContextBuilder builder = SslContextBuilder.forClient();
        if (trustCertCollection != null) {
            // clientAuth(...) only affects server contexts, so it is deliberately not set here.
            builder.trustManager(trustCertCollection);
        }
        boolean hasCertChain = keyCertChainInputStream != null;
        boolean hasKey = keyInputStream != null;
        if (hasCertChain != hasKey) {
            throw new IllegalArgumentException("Client key cert chain and private key must be provided together.");
        }
        if (hasCertChain) {
            if (keyPassword != null) {
                builder.keyManager(keyCertChainInputStream, keyInputStream, keyPassword);
            } else {
                builder.keyManager(keyCertChainInputStream, keyInputStream);
            }
        }
        return buildWithDefaults(builder);
    }

    /**
     * Creates an HTTPS server context (HTTP/2 ciphers, ALPN h2 / http/1.1) backed by a temporary
     * self-signed certificate.
     *
     * @return a server {@link SslContext}
     * @throws CertificateException if the self-signed certificate cannot be generated
     * @throws SSLException         if Netty fails to build the context
     */
    public static SslContext createHttpsServerSslContext() throws CertificateException, SSLException {
        SelfSignedCertificate cert = new SelfSignedCertificate();
        try {
            return SslContextBuilder.forServer(cert.certificate(), cert.privateKey())
                    .sslProvider(findSslProvider())
                    .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE)
                    .protocols(TLS_PROTOCOLS)
                    .applicationProtocolConfig(new ApplicationProtocolConfig(
                            ApplicationProtocolConfig.Protocol.ALPN,
                            ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE,
                            ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT,
                            ApplicationProtocolNames.HTTP_2, ApplicationProtocolNames.HTTP_1_1))
                    .build();
        } finally {
            cert.delete();
        }
    }

    /**
     * Opens {@code file} for reading. A leading {@code ~} is expanded to {@code user.home}, because the JVM
     * does not perform shell tilde expansion.
     *
     * @param file the file to open, or {@code null}
     * @return an open stream (caller must close it), or {@code null} when {@code file} is {@code null}
     * @throws IOException if the file cannot be opened
     */
    public static InputStream openInputStream(File file) throws IOException {
        return file == null ? null : new FileInputStream(expandHome(file));
    }

    /** Replaces a leading "~" path segment with the user's home directory; other paths are returned unchanged. */
    private static File expandHome(File file) {
        String path = file.getPath();
        String home = System.getProperty("user.home");
        if (path.equals("~")) {
            return new File(home);
        }
        if (path.startsWith("~/") || path.startsWith("~" + File.separator)) {
            return new File(home, path.substring(2));
        }
        return file;
    }

    /** Applies the shared provider and protocol settings and builds the context. */
    private static SslContext buildWithDefaults(SslContextBuilder builder) {
        try {
            return builder
                    .sslProvider(findSslProvider())
                    .protocols(TLS_PROTOCOLS)
                    .build();
        } catch (SSLException e) {
            throw new IllegalStateException("Build SslSession failed.", e);
        }
    }

    /**
     * Returns {@link SslProvider#OPENSSL} when the OpenSSL provider is available with ALPN support,
     * otherwise {@link SslProvider#JDK}.
     */
    private static SslProvider findSslProvider() {
        return SslProvider.isAlpnSupported(SslProvider.OPENSSL) ? SslProvider.OPENSSL : SslProvider.JDK;
    }
}

NettySslCxt: 对SslContext进行了包装,同时增加了一个标识是否为mTLS的boolean成员变量

package org.example.netty.tls;

import io.netty.handler.ssl.SslContext;
import lombok.Data;

@Data
public class NettySslCxt {
    private final boolean mtls;
    private final SslContext sslContext;
}

NettyTLSServer: Netty Server的代码

package org.example.netty.tls;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.epoll.EpollServerSocketChannel;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LineBasedFrameDecoder;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.handler.ssl.SslContext;
import lombok.extern.slf4j.Slf4j;
import org.example.netty.NettyHelper;

import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.security.cert.CertificateException;

@Slf4j
public class NettyTLSServer {

    /** Maximum length of one newline-delimited frame accepted by the default framing decoder. */
    private static final int MAX_FRAME_LENGTH = Integer.MAX_VALUE;

    private final InetSocketAddress bindAddress;
    private ServerBootstrap bootstrap;
    private EventLoopGroup bossGroup;
    private EventLoopGroup workerGroup;
    /**
     * Supplies the framing decoder installed on each accepted channel. Netty's ByteToMessageDecoder keeps
     * per-connection state and may not be shared between pipelines. The default therefore creates a new
     * decoder per channel instead of reusing one instance.
     */
    private java.util.function.Supplier<? extends ByteToMessageDecoder> messageDecoderFactory =
            () -> new LineBasedFrameDecoder(MAX_FRAME_LENGTH);
    private final MessageToByteEncoder encodeHandler;
    private final ByteToMessageDecoder decodeHandler;
    private final ChannelHandler handler;
    private final NettySslCxt sslCxt;

    /** Server bound to localhost:8080. */
    public NettyTLSServer(NettySslCxt sslCxt, MessageToByteEncoder encodeHandler, ByteToMessageDecoder decodeHandler, ChannelHandler handler) {
        this(8080, sslCxt, encodeHandler, decodeHandler, handler);
    }

    /** Server bound to localhost on {@code bindPort}. */
    public NettyTLSServer(int bindPort, NettySslCxt sslCxt, MessageToByteEncoder encodeHandler, ByteToMessageDecoder decodeHandler, ChannelHandler handler) {
        this("localhost", bindPort, sslCxt, encodeHandler, decodeHandler, handler);
    }

    /**
     * Creates the server; nothing is bound until {@link #bind(boolean)}.
     * <p>
     * NOTE(review): {@code encodeHandler}, {@code decodeHandler} and {@code handler} are added as the SAME
     * instances to every accepted channel. Netty only allows that for {@code @Sharable} handlers, and
     * ByteToMessageDecoder subclasses cannot be {@code @Sharable}. A non-null {@code decodeHandler}
     * therefore limits the server to one successfully initialised connection.
     */
    public NettyTLSServer(String bindIp, int bindPort, NettySslCxt sslCxt, MessageToByteEncoder encodeHandler, ByteToMessageDecoder decodeHandler, ChannelHandler handler) {
        this.bindAddress = new InetSocketAddress(bindIp, bindPort);
        this.handler = handler;
        this.encodeHandler = encodeHandler;
        this.decodeHandler = decodeHandler;
        this.sslCxt = sslCxt;
    }

    /**
     * Creates fresh event loop groups and a bootstrap, and installs the per-channel pipeline:
     * SslHandler → TlsHandler → frame decoder → decodeHandler → encodeHandler → handler.
     * Calling this again replaces the groups without shutting down the previous ones.
     */
    public void init() throws CertificateException, SSLException {
        initServerBootstrap();

        bootstrap.childHandler(new ChannelInitializer<SocketChannel>() {

            @Override
            protected void initChannel(SocketChannel ch) throws Exception {
                InetSocketAddress remote = ch.remoteAddress();
                // getHostString() avoids a blocking reverse-DNS lookup on the event loop.
                log.info("accept client: {} {}", remote.getHostString(), remote.getPort());
                ChannelPipeline pipeline = ch.pipeline();
                pipeline.addLast(sslCxt.getSslContext().newHandler(ch.alloc()));
                pipeline.addLast(new TlsHandler(true, isEnableMTls()));
                // Chained single-handler calls, as before: a null handler is skipped without affecting the rest.
                pipeline.addLast(messageDecoderFactory.get())
                        .addLast(decodeHandler)
                        .addLast(encodeHandler)
                        .addLast(handler);
            }
        });
    }

    /**
     * Initialises and binds the server.
     *
     * @param sync if {@code true}, blocks until the server channel closes and then shuts the groups down;
     *             if {@code false}, returns once bound. In both modes the groups are shut down if binding fails.
     */
    public void bind(boolean sync) throws CertificateException, SSLException {
        init();
        boolean started = false;
        try {
            Channel channel = bootstrap.bind(bindAddress).sync().channel();
            started = true;
            log.info("netty server started at {}", bindAddress);
            if (sync) {
                channel.closeFuture().sync();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            log.error("netty server start exception,", e);
        } catch (Exception e) {
            log.error("netty server start exception,", e);
        } finally {
            // Release the groups when blocking mode ends, or whenever the bind itself failed.
            if (sync || !started) {
                shutdown();
            }
        }
    }

    /** Gracefully shuts down the boss and worker groups; safe to call before {@link #init()}. */
    public void shutdown() {
        log.info("netty server shutdown");
        log.info("netty server shutdown bossEventLoopGroup&workerEventLoopGroup gracefully");
        if (bossGroup != null) {
            bossGroup.shutdownGracefully();
        }
        if (workerGroup != null) {
            workerGroup.shutdownGracefully();
        }
    }

    /** Creates the boss (1 thread) and worker groups and configures transport and socket options. */
    private void initServerBootstrap() {
        bootstrap = new ServerBootstrap();

        bossGroup = NettyHelper.eventLoopGroup(1, "NettyServerBoss");
        workerGroup = NettyHelper.eventLoopGroup(Math.min(Runtime.getRuntime().availableProcessors() + 1, 32), "NettyServerWorker");

        bootstrap.group(bossGroup, workerGroup)
                .channel(NettyHelper.shouldEpoll() ? EpollServerSocketChannel.class : NioServerSocketChannel.class)
                .option(ChannelOption.SO_REUSEADDR, Boolean.TRUE)
                .childOption(ChannelOption.TCP_NODELAY, Boolean.TRUE)
                .childOption(ChannelOption.SO_KEEPALIVE, Boolean.TRUE)
                .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                .childOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10000);
    }

    /**
     * Returns the framing decoder that would be installed on the next channel. With the default
     * configuration this is a new {@link LineBasedFrameDecoder} on every call.
     */
    public ByteToMessageDecoder getMessageDecoder() {
        return messageDecoderFactory.get();
    }

    /**
     * Uses {@code messageDecoder} as the framing decoder of every channel. Because the same instance is reused,
     * prefer {@link #setMessageDecoderFactory} when more than one connection is expected. A {@code null}
     * value disables the framing decoder.
     */
    public void setMessageDecoder(ByteToMessageDecoder messageDecoder) {
        this.messageDecoderFactory = () -> messageDecoder;
    }

    /** Sets a factory that creates a new framing decoder for each accepted channel. */
    public void setMessageDecoderFactory(java.util.function.Supplier<? extends ByteToMessageDecoder> messageDecoderFactory) {
        this.messageDecoderFactory = java.util.Objects.requireNonNull(messageDecoderFactory, "messageDecoderFactory");
    }

    /** Whether the wrapped SSL context is flagged as mutual TLS. */
    boolean isEnableMTls() {
        return sslCxt.isMtls();
    }
}

NettyTLSClient:Netty Client的代码

package org.example.netty.tls;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.channel.*;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LineBasedFrameDecoder;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.handler.ssl.SslContext;
import lombok.extern.slf4j.Slf4j;
import org.example.netty.NettyHelper;

import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicReference;

@Slf4j
public class NettyTLSClient {

    /** Maximum length of one newline-delimited frame accepted by the default framing decoder. */
    private static final int MAX_FRAME_LENGTH = Integer.MAX_VALUE;

    private final InetSocketAddress serverAddress;

    private Bootstrap bootstrap;
    private EventLoopGroup workerGroup;

    private Channel channel;

    /**
     * Supplies the framing decoder for each connection. ByteToMessageDecoder instances cannot be shared
     * between pipelines, so the default creates a new decoder per channel (which makes reconnects work).
     */
    private java.util.function.Supplier<? extends ByteToMessageDecoder> messageDecoderFactory =
            () -> new LineBasedFrameDecoder(MAX_FRAME_LENGTH);
    private final MessageToByteEncoder encodeHandler;
    private final ByteToMessageDecoder decodeHandler;

    private final ChannelHandler handler;

    private final NettySslCxt sslCxt;

    /**
     * Creates the client; nothing connects until {@link #connect()}.
     * <p>
     * NOTE(review): {@code encodeHandler}, {@code decodeHandler} and {@code handler} are reused as the same
     * instances on every connection. Non-{@code @Sharable} handlers therefore only work for the first connect.
     */
    public NettyTLSClient(String serverHost, int serverPort, NettySslCxt sslCxt, MessageToByteEncoder encodeHandler, ByteToMessageDecoder decodeHandler, ChannelHandler handler) {
        this.serverAddress = new InetSocketAddress(serverHost, serverPort);
        this.encodeHandler = encodeHandler;
        this.decodeHandler = decodeHandler;
        this.handler = handler;
        this.sslCxt = sslCxt;
    }

    /**
     * Initialises a new bootstrap (with a new event loop group) and starts connecting to the server.
     *
     * @return the connect future; its channel is also available from {@link #getChannel()}
     */
    public ChannelFuture connect() throws SSLException {
        init();
        ChannelFuture promise = bootstrap.connect(serverAddress.getHostString(), serverAddress.getPort());
        promise.addListener(future -> {
            log.info("client connect to server: {}", future.isSuccess());
            if (!future.isSuccess()) {
                log.warn("client connect to server {} failed", serverAddress, future.cause());
            }
        });
        channel = promise.channel();
        return promise;
    }

    /**
     * Creates the bootstrap and installs the per-connection pipeline:
     * SslHandler → TlsHandler → frame decoder → decodeHandler → encodeHandler → handler.
     */
    public void init() throws SSLException {
        initBootstrap();
        bootstrap.handler(new ChannelInitializer<SocketChannel>() {

            @Override
            protected void initChannel(SocketChannel ch) {
                ChannelPipeline pipeline = ch.pipeline();
                // Chained single-handler calls: a null handler is skipped without affecting the rest.
                pipeline
                        .addLast(sslCxt.getSslContext().newHandler(ch.alloc()))
                        .addLast(new TlsHandler(false, isEnableMTls()))
                        .addLast(messageDecoderFactory.get())
                        .addLast(decodeHandler)
                        .addLast(encodeHandler)
                        .addLast(handler);
            }
        });
    }

    /** Creates a single-threaded worker group and configures transport and socket options. */
    private void initBootstrap() {
        bootstrap = new Bootstrap();
        workerGroup = NettyHelper.eventLoopGroup(1, "NettyClientWorker");
        bootstrap.group(workerGroup)
                .option(ChannelOption.SO_KEEPALIVE, true)
                .option(ChannelOption.TCP_NODELAY, true)
                .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
                .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 10000)
                .remoteAddress(serverAddress)
                .channel(NettyHelper.socketChannelClass());
    }

    /**
     * Schedules the current worker group to shut down once the current channel closes. It does not close
     * the channel itself. If no channel was ever created, the group (if any) is shut down immediately.
     */
    public void shutdown() {
        log.info("netty client shutdown");
        // Capture the current group: a later connect() replaces the field with a new group.
        EventLoopGroup group = workerGroup;
        if (group == null) {
            return;
        }
        Runnable stopGroup = () -> {
            log.info("netty client shutdown workerEventLoopGroup gracefully");
            group.shutdownGracefully();
        };
        if (channel == null) {
            stopGroup.run();
        } else {
            channel.closeFuture().addListener(future -> stopGroup.run());
        }
    }

    /** The channel from the most recent {@link #connect()}, or {@code null} if never connected. */
    public Channel getChannel() {
        return channel;
    }

    /**
     * Returns the framing decoder that would be installed on the next connection. With the default
     * configuration this is a new {@link LineBasedFrameDecoder} on every call.
     */
    public ByteToMessageDecoder getMessageDecoder() {
        return messageDecoderFactory.get();
    }

    /**
     * Uses {@code messageDecoder} as the framing decoder of every connection. The same instance is reused,
     * so prefer {@link #setMessageDecoderFactory} when reconnecting. A {@code null} value disables framing.
     */
    public void setMessageDecoder(ByteToMessageDecoder messageDecoder) {
        this.messageDecoderFactory = () -> messageDecoder;
    }

    /** Sets a factory that creates a new framing decoder for each connection. */
    public void setMessageDecoderFactory(java.util.function.Supplier<? extends ByteToMessageDecoder> messageDecoderFactory) {
        this.messageDecoderFactory = java.util.Objects.requireNonNull(messageDecoderFactory, "messageDecoderFactory");
    }

    /** Whether the wrapped SSL context is flagged as mutual TLS. */
    boolean isEnableMTls() {
        return sslCxt.isMtls();
    }
}

TlsHandler:检查ssl handshake,在mtls场景下打印对端的证书信息,否则在client打印服务端的证书信息。

package org.example.netty.tls;

import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import lombok.extern.slf4j.Slf4j;

import javax.net.ssl.SSLSession;
import javax.security.cert.X509Certificate;
import java.text.SimpleDateFormat;
import java.util.Date;



@Slf4j
public class TlsHandler extends ChannelDuplexHandler {

    /** true when installed in a server pipeline, false in a client pipeline. */
    private final boolean serverSide;
    /** true when the connection is flagged as mutual TLS; the server then also logs the client's certificate. */
    private final boolean mtls;

    public TlsHandler(boolean serverSide, boolean mtls) {
        this.serverSide = serverSide;
        this.mtls = mtls;
    }

    /**
     * Registers a handshake listener on the pipeline's {@link SslHandler}, logs the local address, and then
     * forwards {@code channelActive} to the next handler.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        SslHandler sslHandler = ctx.pipeline().get(SslHandler.class);
        if (sslHandler == null) {
            log.warn("[{}] {} no SslHandler in pipeline, skipping handshake check", getSideType(), ctx.channel().remoteAddress());
        } else {
            sslHandler.handshakeFuture().addListener(future -> onHandshakeDone(ctx, sslHandler, future));
        }

        SocketChannel channel = (SocketChannel) ctx.channel();
        log.info("[{}] conn: IP:{} Port:{}", getSideType(),
                channel.localAddress().getHostString(), channel.localAddress().getPort());

        // Propagate the event; otherwise the decoders and the business handler never see channelActive.
        super.channelActive(ctx);
    }

    /** Logs the handshake outcome and, on success, the negotiated cipher suite and (when relevant) the peer certificate. */
    private void onHandshakeDone(ChannelHandlerContext ctx, SslHandler sslHandler, Future<?> future) {
        if (!future.isSuccess()) {
            log.warn("[{}] {} 握手失败,关闭连接", getSideType(), ctx.channel().remoteAddress(), future.cause());
            ctx.channel().closeFuture().addListener(closeFuture ->
                    log.info("[{}] {} 关闭连接:{}", getSideType(), ctx.channel().remoteAddress(), closeFuture.isSuccess()));
            ctx.close();
            return;
        }
        log.info("[{}] {} 握手成功", getSideType(), ctx.channel().remoteAddress());
        SSLSession session = sslHandler.engine().getSession();
        log.info("[{}] {} cipherSuite: {}", getSideType(), ctx.channel().remoteAddress(), session.getCipherSuite());
        // A client always sees the server certificate; a server only sees a client certificate under mTLS.
        if (mtls || !serverSide) {
            logPeerCertificate(session);
        }
    }

    /**
     * Logs the details of the peer's leaf certificate. Uses getPeerCertificates(), not the deprecated
     * getPeerCertificateChain(), which newer JDK providers no longer support.
     */
    private void logPeerCertificate(SSLSession session) {
        java.security.cert.Certificate[] chain;
        try {
            chain = session.getPeerCertificates();
        } catch (javax.net.ssl.SSLPeerUnverifiedException e) {
            log.warn("[{}] peer did not present a verified certificate", getSideType(), e);
            return;
        }
        if (chain.length == 0 || !(chain[0] instanceof java.security.cert.X509Certificate)) {
            log.warn("[{}] peer certificate is missing or not X.509", getSideType());
            return;
        }
        java.security.cert.X509Certificate cert = (java.security.cert.X509Certificate) chain[0];
        // Local, thread-confined formatter: SimpleDateFormat is not thread-safe.
        SimpleDateFormat dateFormat = new SimpleDateFormat("yyyy/MM/dd");
        log.info("证书版本:{}", cert.getVersion());
        log.info("证书序列号:{}", cert.getSerialNumber().toString(16));
        log.info("证书生效日期:{}", dateFormat.format(cert.getNotBefore()));
        log.info("证书失效日期:{}", dateFormat.format(cert.getNotAfter()));
        log.info("证书拥有者:{}", cert.getSubjectX500Principal().getName());
        log.info("证书颁发者:{}", cert.getIssuerX500Principal().getName());
        log.info("证书签名算法:{}", cert.getSigAlgName());
    }

    private String getSideType() {
        return serverSide ? "SERVER" : "CLIENT";
    }
}

NettyTLSMain.java: 测试Main Class

package org.example.netty.tls;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.codec.LineBasedFrameDecoder;
import io.netty.handler.codec.MessageToByteEncoder;
import lombok.extern.slf4j.Slf4j;

import javax.net.ssl.SSLException;
import java.io.File;
import java.nio.charset.StandardCharsets;
import java.security.cert.CertificateException;
import java.util.List;
import java.util.Scanner;

public class NettyTLSMain {

    private static String serverHost = "localhost";
    private static int serverPort = 10001;

    public static void main(String[] args) throws CertificateException, SSLException {

        // 自签名的tls
//        nettyTlsLauncher(new NettySslCxt(false, SslContexts.createTlsServerSslContext()), new NettySslCxt(false, SslContexts.createTlsClientSslContext()));

        String certDir = "~/Desktop/test/cert/uncrypted4";
        // 指定证书的tls
//        nettyTlsLauncher(
//                new NettySslCxt(false, SslContexts.createServerSslContext(
//                        new File(certDir + "/server.crt"),
//                        new File(certDir + "/pkcs8_server.key"),
//                        null, null, false)),
//                new NettySslCxt(false, SslContexts.createClientSslContext()));

        //指定证书的tls -Djdk.security.allowNonCaAnchor=true
//        nettyTlsLauncher(
//                new NettySslCxt(false, SslContexts.createServerSslContext(
//                        new File(certDir + "/server.crt"),
//                        new File(certDir + "/pkcs8_server.key"),
//                        null,
//                        new File(certDir + "/ca.crt"),
//                        false)),
//                new NettySslCxt(false, SslContexts.createClientSslContext(
//                        new File(certDir + "/ca.crt"))));


        //
//        nettyTlsLauncher(
//                new NettySslCxt(true, SslContexts.createServerSslContext(
//                        new File(certDir + "/server.crt"),
//                        new File(certDir + "/server.key"),
//                        new File(certDir + "/ca.crt"))),
//                new NettySslCxt(true, SslContexts.createClientSslContext(
//                        new File(certDir + "/client.crt"),
//                        new File(certDir + "/client.key"),
//                        new File(certDir + "/ca.crt"))));

        nettyTlsLauncher(
                new NettySslCxt(true, SslContexts.createServerSslContext(
                        new File(certDir + "/server.crt"),
                        new File(certDir + "/server.key"),
                        new File(certDir + "/ca4.crt"))),
                new NettySslCxt(true, SslContexts.createClientSslContext(
                        new File(certDir + "/client.crt"),
                        new File(certDir + "/client.key"),
                        new File(certDir + "/ca.crt"))));


//        String serCertDir = "~/Desktop/test/cert/uncrypted_merge_2_3";
        String serCertDir = "~/Desktop/test/cert/uncrypted3";
//        String certDir2 = "~/Desktop/test/cert/uncrypted3";
//        String certDir3 = "~/Desktop/test/cert/uncrypted3";
//        nettyTlsLauncher(
//                new NettySslCxt(true, SslContexts.createServerSslContext(
//                        new File(serCertDir + "/server.crt"),
//                        new File(serCertDir + "/server.key"),
//                        new File(serCertDir + "/ca.crt"))),
//                new NettySslCxt(true, SslContexts.createClientSslContext(
//                        new File(certDir2 + "/client.crt"),
//                        new File(certDir2 + "/client.key"),
//                        new File(certDir2 + "/ca.crt")))
                , new NettySslCxt(false, SslContexts.createClientSslContext(
                        new File(certDir3 + "/client.crt"),
                        new File(certDir3 + "/client.key"),
                        new File(certDir3 + "/ca.crt")))
//        );

    }

    private static void nettyTlsLauncher(NettySslCxt serverSslCxt, NettySslCxt clientSslCxt) throws CertificateException, SSLException {
        nettyTlsLauncher(serverSslCxt, clientSslCxt, null);
    }

    private static void nettyTlsLauncher(NettySslCxt serverSslCxt, NettySslCxt clientSslCxt, NettySslCxt clientSslCxt2) throws CertificateException, SSLException {
        NettyTLSServer server = new NettyTLSServer(
                serverHost, serverPort, serverSslCxt, new Encoder(), new Decoder(), new StringServerChannelHandler());
        server.setMessageDecoder(new LineBasedFrameDecoder(Integer.MAX_VALUE));
        server.bind(false);


        NettyTLSClient client = new NettyTLSClient(
                serverHost, serverPort, clientSslCxt, new Encoder(), new Decoder(), new StringClientChannelHandler());
        client.setMessageDecoder(new LineBasedFrameDecoder(Integer.MAX_VALUE));
        client.connect().addListener(future -> {
            if (future.isSuccess()) {
                client.getChannel().writeAndFlush("--test--");
            }
        });

        NettyTLSClient client2;
        if (clientSslCxt2 != null) {
            client2 = new NettyTLSClient(
                    serverHost, serverPort, clientSslCxt2, new Encoder(), new Decoder(), new StringClientChannelHandler());
            client2.setMessageDecoder(new LineBasedFrameDecoder(Integer.MAX_VALUE));
            client2.connect().addListener(future -> {
                if (future.isSuccess()) {
                    client2.getChannel().writeAndFlush("--test2--");
                }
            });
        } else {
            client2 = null;
        }

        Scanner scanner = new Scanner(System.in);

        while (true) {
            System.out.println("waiting input");
            String line = scanner.nextLine();
            if ("exit".equals(line) || "eq".equals(line) || "quit".equals(line)) {
                client.shutdown();
                server.shutdown();
                return;
            }
            client.getChannel().writeAndFlush(line);
            if (client2 != null) {
                client2.getChannel().writeAndFlush("c2:" + line);
            }

        }
    }


    public static class Decoder extends ByteToMessageDecoder {

        @Override
        protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
            byte[] b = new byte[in.readableBytes()];
            in.readBytes(b);
            out.add(new String(b, StandardCharsets.UTF_8));
        }
    }

    @ChannelHandler.Sharable
    public static class Encoder extends MessageToByteEncoder<String> {

        @Override
        protected void encode(ChannelHandlerContext ctx, String msg, ByteBuf out) throws Exception {
            out.writeBytes((msg + "\n").getBytes(StandardCharsets.UTF_8));
        }
    }


    @Slf4j
    public static class StringClientChannelHandler extends ChannelDuplexHandler {

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            log.info("received message from server: {}", msg);
            super.channelRead(ctx, msg);
        }
    }

    @Slf4j
    public static class StringServerChannelHandler extends ChannelDuplexHandler {

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            log.info("received message from client: {}", msg);
            ctx.writeAndFlush("server response: " + msg);
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
            log.info("occur exception, close channel:{}.", ctx.channel().remoteAddress(), cause);
            ctx.channel().closeFuture()
                    .addListener(future -> {
                        log.info("close client channel {}: {}",
                                ctx.channel().remoteAddress(),
                                future.isSuccess());
                    });
        }
    }
}

参考

netty实现TLS/SSL双向加密认证
Netty+OpenSSL TCP双向认证证书配置
基于Netty的MQTT Server实现并支持SSL
Netty tls验证
netty使用ssl双向认证
netty中实现双向认证的SSL连接
记一次TrustAnchor with subject异常解决
SpringBoot (WebFlux Netty) 支持动态更换https证书
手动实现CA数字认证(java)
java编程方式生成CA证书
netty https有什么方式根据域名设置证书?

  • 16
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
以下是使用Java代码实现采用HTTPS和TLS 1.2版本建立连接,并完成双向TLS认证(mTLS)的示例:

```java
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.URL;
import java.security.KeyStore;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

public class SSLExample {
    public static void main(String[] args) throws Exception {
        // Load client certificate and private key for mTLS
        String keyStorePath = "client.jks";
        String keyStorePassword = "password";
        String keyPassword = "password";
        KeyStore keyStore = KeyStore.getInstance("JKS");
        keyStore.load(SSLExample.class.getClassLoader().getResourceAsStream(keyStorePath), keyStorePassword.toCharArray());
        KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
        keyManagerFactory.init(keyStore, keyPassword.toCharArray());

        // Load server truststore for server certificate verification
        String trustStorePath = "server.jks";
        String trustStorePassword = "password";
        KeyStore trustStore = KeyStore.getInstance("JKS");
        trustStore.load(SSLExample.class.getClassLoader().getResourceAsStream(trustStorePath), trustStorePassword.toCharArray());
        TrustManagerFactory trustManagerFactory = TrustManagerFactory
                .getInstance(TrustManagerFactory.getDefaultAlgorithm());
        trustManagerFactory.init(trustStore);

        // Create SSL context with TLS 1.2 protocol and mTLS configuration
        SSLContext sslContext = SSLContext.getInstance("TLSv1.2");
        sslContext.init(keyManagerFactory.getKeyManagers(), trustManagerFactory.getTrustManagers(), null);

        // Set default SSL context for HTTPS connection
        HttpsURLConnection.setDefaultSSLSocketFactory(sslContext.getSocketFactory());

        // Create URL object for HTTPS endpoint
        URL url = new URL("https://example.com/api");

        // Open HTTPS connection
        HttpsURLConnection connection = (HttpsURLConnection) url.openConnection();

        // Set request method and headers
        connection.setRequestMethod("GET");
        connection.setRequestProperty("User-Agent", "Mozilla/5.0");

        // Get response from HTTPS endpoint
        try (InputStream inputStream = connection.getInputStream();
             BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream))) {
            String line;
            while ((line = reader.readLine()) != null) {
                System.out.println(line);
            }
        } catch (SSLException e) {
            e.printStackTrace();
        } finally {
            connection.disconnect();
        }
    }

    // Trust manager to accept all server certificates
    private static final TrustManager[] TRUST_ALL_CERTIFICATES = new TrustManager[] {
        new X509TrustManager() {
            public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException {
                // Do nothing
            }

            public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
                // Do nothing
            }

            public X509Certificate[] getAcceptedIssuers() {
                return new X509Certificate[0];
            }
        }
    };
}
```

在上面的示例中,我们使用了Java的`HttpsURLConnection`类来建立连接。
1. 首先从classpath加载客户端证书和私钥(`client.jks`),并用它们创建`KeyManager`。
2. 然后加载用于校验服务端证书的信任库(`server.jks`),并用它创建`TrustManager`。
3. 接着用这些`KeyManager`和`TrustManager`创建一个使用TLS 1.2协议、完成mTLS配置的`SSLContext`。
4. 最后,将该`SSLContext`生成的`SSLSocketFactory`设为`HttpsURLConnection`的默认工厂,再打开连接向服务器发送请求。

示例中还定义了一个接受所有服务器证书的`TrustManager`(`TRUST_ALL_CERTIFICATES`),但它并没有传给`SSLContext.init`,实际上未被使用。不要在真实代码中使用这种实现,因为它会使应用程序容易受到中间人攻击。在实际生产环境中,应使用会验证服务器证书、拒绝不受信任证书的`TrustManager`,例如上面由`TrustManagerFactory`基于信任库创建的那个。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值