此处观看更加
本文的完整代码可以在此查看
这段时间不是一直想要使用Netty模仿一下dubbo,自己写一个RPC框架嘛,然后在学习的过程中还是学到了不少新东西。我们知道在网络传输中,数据都是以二进制流来传输,但是在Java中数据都是以对象的形式来存储,所以我们想要传输数据,这就涉及到对象的序列化以及反序列化了。而我们知道,不同的序列化协议适用不同的应用场景,jdk原生的序列化方式因为其性能原因绝大多数的人都不会考虑使用它,而我们想要写一个高性能的RPC框架,一个合适的序列化协议自然也是重中之重,因为目前所构思的RPC框架只是 Java to Java,所以我选择Kryo作为序列/反序列化的方式。
其他的序列化方式还有很多,他们都各自不同的优缺点,和不同的使用场景。想要深入了解的同学不妨参考一下下面的的文章:
序列化协议是应用层的协议
kryo的使用方法可以参靠下面这篇文章:
2.核心代码目录结构 C:. ├─client 客户端代码 ├─codec 编解码器 ├─entity 实体类 └─server 服务端代码2.1.导入依赖
首先,先导入Netty的依赖,我是直接导入的Netty的所有模块
io.netty >netty-all4.1.68.Final
然后导入Kryo的相关依赖,因为我使用的KryoUtil,还需要导入commons-codec的依赖,因为spring的相关依赖会和Kryo的依赖冲突,所以直接导入的Kryo-shaded,具体的原因读者可以自行百度,我不再赘述
com.esotericsoftware
kryo-shaded
4.0.0
commons-codec
commons-codec
1.10
再导入一些其他的依赖
org.junit.jupiter
junit-jupiter-api
5.8.2
test
org.apache.logging.log4j
log4j-slf4j-impl
2.6.2
org.apache.logging.log4j
log4j-core
2.6.2
org.projectlombok
lombok
1.16.16
2.2.KryoUtil
编写Kryo工具类,用于后面的序列反序列化对象
package cuit.epoch.pymjl.util;
import com.esotericsoftware.kryo.Kryo;
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;
import org.apache.commons.codec.binary.Base64;
import org.objenesis.strategy.StdInstantiatorStrategy;
import java.io.*;
public class KryoUtil {
private static final String DEFAULT_ENCODING = "UTF-8";
//每个线程的 Kryo 实例
private static final ThreadLocal kryoLocal = new ThreadLocal() {
@Override
protected Kryo initialValue() {
Kryo kryo = new Kryo();
//支持对象循环引用(否则会栈溢出)
kryo.setReferences(true); //默认值就是 true,添加此行的目的是为了提醒维护者,不要改变这个配置
//不强制要求注册类(注册行为无法保证多个 JVM 内同一个类的注册编号相同;而且业务系统中大量的 Class 也难以一一注册)
kryo.setRegistrationRequired(false); //默认值就是 false,添加此行的目的是为了提醒维护者,不要改变这个配置
//Fix the NPE bug when deserializing Collections.
((Kryo.DefaultInstantiatorStrategy) kryo.getInstantiatorStrategy())
.setFallbackInstantiatorStrategy(new StdInstantiatorStrategy());
return kryo;
}
};
public static Kryo getInstance() {
return kryoLocal.get();
}
//-----------------------------------------------
// 序列化/反序列化对象,及类型信息
// 序列化的结果里,包含类型的信息
// 反序列化时不再需要提供类型
//-----------------------------------------------
public static byte[] writeToByteArray(T obj) {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
Output output = new Output(byteArrayOutputStream);
Kryo kryo = getInstance();
kryo.writeClassAndObject(output, obj);
output.flush();
return byteArrayOutputStream.toByteArray();
}
public static String writeToString(T obj) {
try {
return new String(Base64.encodeBase64(writeToByteArray(obj)), DEFAULT_ENCODING);
} catch (UnsupportedEncodingException e) {
throw new IllegalStateException(e);
}
}
@SuppressWarnings("unchecked")
public static T readFromByteArray(byte[] byteArray) {
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(byteArray);
Input input = new Input(byteArrayInputStream);
Kryo kryo = getInstance();
return (T) kryo.readClassAndObject(input);
}
public static T readFromString(String str) {
try {
return readFromByteArray(Base64.decodeBase64(str.getBytes(DEFAULT_ENCODING)));
} catch (UnsupportedEncodingException e) {
throw new IllegalStateException(e);
}
}
//-----------------------------------------------
// 只序列化/反序列化对象
// 序列化的结果里,不包含类型的信息
//-----------------------------------------------
public static byte[] writeObjectToByteArray(T obj) {
ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
Output output = new Output(byteArrayOutputStream);
Kryo kryo = getInstance();
kryo.writeObject(output, obj);
output.flush();
return byteArrayOutputStream.toByteArray();
}
public static String writeObjectToString(T obj) {
try {
return new String(Base64.encodeBase64(writeObjectToByteArray(obj)), DEFAULT_ENCODING);
} catch (UnsupportedEncodingException e) {
throw new IllegalStateException(e);
}
}
@SuppressWarnings("unchecked")
public static T readObjectFromByteArray(byte[] byteArray, Class clazz) {
ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(byteArray);
Input input = new Input(byteArrayInputStream);
Kryo kryo = getInstance();
return kryo.readObject(input, clazz);
}
public static T readObjectFromString(String str, Class clazz) {
try {
return readObjectFromByteArray(Base64.decodeBase64(str.getBytes(DEFAULT_ENCODING)), clazz);
} catch (UnsupportedEncodingException e) {
throw new IllegalStateException(e);
}
}
}
2.3.编写实体类
- 编写RPC请求类
package cuit.epoch.pymjl.nettydemo.entity;
import lombok.*;
@AllArgsConstructor
@NoArgsConstructor
@Data
@Builder
@ToString
public class RpcRequest {
private String interfaceName;
private String methodName;
}
- 编写RPC响应类
package cuit.epoch.pymjl.nettydemo.entity;
import lombok.*;
@NoArgsConstructor
@AllArgsConstructor
@Data
@ToString
@Builder
public class RpcResponse {
private String message;
}
2.4.编写编解码器
注意,对数据进行编码时因为TCP粘包/拆包的原因,我们这里需要自定义传输协议,然后我这里是:把传输数据的长度写在字节数组的前面四个字节中,服务端在读取数据时会先从前四个字节获取到这次传输数据的长度,在对数据进行都写操作
另外,我们需要在编解码器中将对象序列化成字节数组或者将字节数组反序列化成原对象
- 编码器
package cuit.epoch.pymjl.nettydemo.codec; import cuit.epoch.pymjl.util.KryoUtil; import io.netty.buffer.ByteBuf; import io.netty.channel.ChannelHandlerContext; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.MessageToByteEncoder; import lombok.AllArgsConstructor; @AllArgsConstructor public class NettyKryoEncoder extends MessageToByteEncoder
- 解码器
对数据进行解码的时候需要注意此次接收到的数据是否齐全
package cuit.epoch.pymjl.nettydemo.codec;
import cuit.epoch.pymjl.util.KryoUtil;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import java.util.List;
@AllArgsConstructor
@Log4j2
public class NettyKryoDecoder extends ByteToMessageDecoder {
private final Class> clazz;
private static final int BODY_LENGTH = 4;
@Override
protected void decode(ChannelHandlerContext ctx, ByteBuf in, List
2.5.服务端
- 先使用Netty初始化服务,让服务端循环监听客户端的请求
package cuit.epoch.pymjl.nettydemo.server;
import cuit.epoch.pymjl.nettydemo.codec.NettyKryoDecoder;
import cuit.epoch.pymjl.nettydemo.codec.NettyKryoEncoder;
import cuit.epoch.pymjl.nettydemo.entity.RpcRequest;
import cuit.epoch.pymjl.nettydemo.entity.RpcResponse;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import lombok.extern.log4j.Log4j2;
@Log4j2
public class NettyServer {
private final int port;
public NettyServer(int port) {
this.port = port;
}
public void run() {
EventLoopGroup bossGroup = new NioEventLoopGroup();
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class)
// TCP默认开启了 Nagle 算法,该算法的作用是尽可能的发送大数据快,减少网络传输。TCP_NODELAY 参数的作用就是控制是否启用 Nagle 算法。
.childOption(ChannelOption.TCP_NODELAY, true)
// 是否开启 TCP 底层心跳机制
.childOption(ChannelOption.SO_KEEPALIVE, true)
//表示系统用于临时存放已完成三次握手的请求的队列的最大长度,如果连接建立频繁,服务器处理创建新连接较慢,可以适当调大这个参数
.option(ChannelOption.SO_BACKLOG, 128)
.handler(new LoggingHandler(LogLevel.INFO))
.childHandler(new ChannelInitializer() {
@Override
protected void initChannel(SocketChannel ch) {
ch.pipeline().addLast(new NettyKryoDecoder(RpcRequest.class));
ch.pipeline().addLast(new NettyKryoEncoder(RpcResponse.class));
ch.pipeline().addLast(new NettyServerHandler());
}
});
// 绑定端口,同步等待绑定成功
ChannelFuture f = b.bind(port).sync();
log.info("Netty server start success, port: {}", port);
// 等待服务端监听端口关闭
f.channel().closeFuture().sync();
} catch (InterruptedException e) {
log.error("occur exception when start server:", e);
} finally {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}
}
- 自定义服务端的Handler,处理业务
package cuit.epoch.pymjl.nettydemo.server;
import cuit.epoch.pymjl.nettydemo.entity.RpcRequest;
import cuit.epoch.pymjl.nettydemo.entity.RpcResponse;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.log4j.Log4j2;
import java.util.concurrent.atomic.AtomicInteger;
@Log4j2
public class NettyServerHandler extends ChannelInboundHandlerAdapter {
private static final AtomicInteger atomicInteger = new AtomicInteger(1);
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
try {
RpcRequest rpcRequest = (RpcRequest) msg;
log.info("server receive msg: [{}] ,times:[{}]", rpcRequest, atomicInteger.getAndIncrement());
RpcResponse messageFromServer = RpcResponse.builder().message("message from server").build();
ChannelFuture f = ctx.writeAndFlush(messageFromServer);
f.addListener(ChannelFutureListener.CLOSE);
} finally {
ReferenceCountUtil.release(msg);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
log.error("server catch exception", cause);
ctx.close();
}
}
2.6.客户端
- 初始化客户端
package cuit.epoch.pymjl.nettydemo.client;
import cuit.epoch.pymjl.nettydemo.codec.NettyKryoDecoder;
import cuit.epoch.pymjl.nettydemo.codec.NettyKryoEncoder;
import cuit.epoch.pymjl.nettydemo.entity.RpcRequest;
import cuit.epoch.pymjl.nettydemo.entity.RpcResponse;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.AttributeKey;
import lombok.extern.log4j.Log4j2;
@Log4j2
public class NettyClient {
private final String host;
private final int port;
private static final Bootstrap b;
public NettyClient(String host, int port) {
this.host = host;
this.port = port;
}
static {
EventLoopGroup group = new NioEventLoopGroup();
b = new Bootstrap();
b.group(group)
.channel(NioSocketChannel.class)
.handler(new LoggingHandler(LogLevel.INFO))
//设置连接的超时时间,超过这个时间则代表连接失败
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 5000)
.handler(new ChannelInitializer() {
@Override
protected void initChannel(SocketChannel socketChannel) throws Exception {
socketChannel.pipeline().addLast(new NettyKryoDecoder(RpcResponse.class));
socketChannel.pipeline().addLast(new NettyKryoEncoder(RpcRequest.class));
socketChannel.pipeline().addLast(new NettyClientHandler());
}
});
}
public RpcResponse sendMessage(RpcRequest rpcRequest) {
try {
ChannelFuture f = b.connect(host, port).sync();
log.info("client connect server success ==> {}:{}", host, port);
Channel futureChannel = f.channel();
log.info("client start send message");
if (futureChannel != null) {
futureChannel.writeAndFlush(rpcRequest).addListener(channelFuture -> {
if (channelFuture.isSuccess()) {
log.info("client send message success ==> [{}]", rpcRequest);
} else {
log.error("send failed cause: ", channelFuture.cause());
}
});
}
//阻塞等待服务器返回结果
f.channel().closeFuture().sync();
//获取返回结果
AttributeKey key = AttributeKey.valueOf("response");
RpcResponse rpcResponse = futureChannel.attr(key).get();
if (rpcResponse != null) {
log.info("RpcResponse is [{}]", rpcResponse);
return rpcResponse;
} else {
log.error("RpcResponse is Null");
}
} catch (InterruptedException e) {
e.printStackTrace();
}
return null;
}
}
- 自定义客户端Handler
package cuit.epoch.pymjl.nettydemo.client;
import cuit.epoch.pymjl.nettydemo.entity.RpcResponse;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.AttributeKey;
import io.netty.util.ReferenceCountUtil;
import lombok.extern.log4j.Log4j2;
@Log4j2
public class NettyClientHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
try {
RpcResponse response = (RpcResponse) msg;
log.info("handler client receive response from server, response={}", response.toString());
//声明一个AttributeKey对象
AttributeKey key = AttributeKey.valueOf("response");
ctx.channel().attr(key).set(response);
ctx.channel().close();
} finally {
ReferenceCountUtil.release(msg);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
log.error("client caught exception", cause);
ctx.close();
}
}
3.测试
3.1.服务端启动类
package cuit.epoch.pymjl.nettydemo.server;
public class ServerMain {
public static void main(String[] args) {
new NettyServer(8080).run();
}
}
3.2.客户端启动类
package cuit.epoch.pymjl.nettydemo.client;
import cuit.epoch.pymjl.nettydemo.entity.RpcRequest;
import cuit.epoch.pymjl.nettydemo.entity.RpcResponse;
public class ClientMain {
public static void main(String[] args) {
RpcRequest rpcRequest = RpcRequest.builder()
.interfaceName("interface")
.methodName("hello").build();
NettyClient nettyClient = new NettyClient("127.0.0.1", 8080);
for (int i = 0; i < 3; i++) {
nettyClient.sendMessage(rpcRequest);
}
RpcResponse rpcResponse = nettyClient.sendMessage(rpcRequest);
System.out.println(rpcResponse.toString());
}
}
3.3.运行项目
- 启动服务端
- 启动客户端
至此,整个项目就成功运行了
4.小结至此,Netty使用Kryo序列化对象传输数据的Demo到此为止了,这中间其实还有很多细节的地方我没有多说。比如Kryo线程不安全,需要使用ThreadLocal来保证线程安全,这进而又引出一个问题,ThreadLocal是什么?它为什么能保证线程安全?这些问题我会在后面更新一篇文章详细解释。
除此之外,Netty传输的异步机制Listenner的相关知识点我也没有做详细讲解,以及AttrbuteKey,AttributeKeyMap等,这些知识点目前笔者也是一知半解,在没有熟练掌握这些知识点前我也不敢细说,误人子弟。等我后面详细研究之后,理解通透后再更新相关的讲解文章.
最后,附上我的Log4j2的配置文件
%d %highlight{%-5level}{ERROR=Bright RED, WARN=Bright Yellow, INFO=Bright Green, DEBUG=Bright Cyan, TRACE=Bright White} %style{[%t]}{bright,magenta} %style{%c{1.}.%M(%L)}{cyan}: %msg%n /logs
append="true">
```



