Netty 手写 RPC
详见:https://github.com/yury757/nettystudy
一、服务端模块
1、服务端启动类 RpcServer
package net.yury.netty.Test10Rpc.server;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import lombok.extern.slf4j.Slf4j;
import net.yury.netty.Test10Rpc.service.RpcRegisterProcessor;
@Slf4j
public class RpcServer {
    /** TCP port the RPC server listens on. */
    private static final int PORT = 8080;

    /**
     * Starts the RPC server.
     *
     * <p>Registers all service implementations, binds {@link #PORT}, then blocks until the
     * server channel is closed. Both event loop groups are shut down on the way out.
     *
     * @param args ignored
     * @throws InterruptedException if interrupted while binding or waiting for the channel to close
     */
    public static void main(String[] args) throws InterruptedException {
        EventLoopGroup acceptorGroup = new NioEventLoopGroup();
        EventLoopGroup ioGroup = new NioEventLoopGroup();
        // Fill the service registry before any request can be handled.
        RpcRegisterProcessor.init();
        try {
            ServerBootstrap bootstrap = new ServerBootstrap()
                    .group(acceptorGroup, ioGroup)
                    .channel(NioServerSocketChannel.class)
                    // NOTE(review): option() configures the listening channel, not accepted connections;
                    // use childOption(...) if this allocator was meant for client connections.
                    .option(ChannelOption.RCVBUF_ALLOCATOR, new AdaptiveRecvByteBufAllocator(128, 1024, 4096))
                    .option(ChannelOption.SO_REUSEADDR, true)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childHandler(new RpcServerHandlerInitializer());
            ChannelFuture bound = bootstrap.bind(PORT).sync();
            bound.channel().closeFuture().sync();
        } finally {
            ioGroup.shutdownGracefully();
            acceptorGroup.shutdownGracefully();
        }
    }
}
2、RpcServerHandlerInitializer
package net.yury.netty.Test10Rpc.server;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import net.yury.netty.Test10Rpc.RpcCodec;
public class RpcServerHandlerInitializer extends ChannelInitializer<Channel> {
    /** Maximum frame length, in bytes, accepted by the frame decoder. */
    private static final int MAX_FRAME_LENGTH = 2048;
    /** Offset of the length field from the start of a frame. */
    private static final int LENGTH_FIELD_OFFSET = 0;
    /** Size in bytes of the length prefix (RpcCodec writes it with writeInt). */
    private static final int LENGTH_FIELD_LENGTH = 4;

    /**
     * Installs the server-side pipeline for a newly accepted connection.
     *
     * <p>The handlers are, in order:
     * <ol>
     *   <li>wire logging at INFO;</li>
     *   <li>frame reassembly using the 4-byte length prefix (handles sticky/partial packets);</li>
     *   <li>message encoding/decoding in server mode;</li>
     *   <li>request execution.</li>
     * </ol>
     *
     * @param ch the freshly accepted channel
     */
    @Override
    protected void initChannel(Channel ch) throws Exception {
        ch.pipeline().addLast(
                new LoggingHandler(LogLevel.INFO),
                new LengthFieldBasedFrameDecoder(MAX_FRAME_LENGTH, LENGTH_FIELD_OFFSET, LENGTH_FIELD_LENGTH),
                new RpcCodec(),
                new RpcServerHandler());
    }
}
3、RpcServerHandler
package net.yury.netty.Test10Rpc.server;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.extern.slf4j.Slf4j;
import net.yury.netty.Test10Rpc.RpcRequestMessage;
import net.yury.netty.Test10Rpc.RpcResponseMessage;
import net.yury.netty.Test10Rpc.service.RpcRegisterProcessor;
import java.lang.reflect.Method;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
public class RpcServerHandler extends SimpleChannelInboundHandler<RpcRequestMessage> {
private static final AtomicInteger COUNT = new AtomicInteger(0);
@Override
protected void channelRead0(ChannelHandlerContext ctx, RpcRequestMessage msg) {
log.debug("receive msg is: {}", msg);
String interfaceName = msg.getInterfaceName();
Object result = null;
RpcResponseMessage resp = new RpcResponseMessage();
try {
Class<?> clazz = Class.forName(interfaceName);
Method method = clazz.getMethod(msg.getMethodName(), msg.getParamsTypes());
result = method.invoke(RpcRegisterProcessor.RPC_SERVICE.get(clazz).get(0), msg.getParamsValues());
resp.setSequenceID(msg.getSequenceID());
resp.setReturnValue(result);
resp.setCause(null);
}catch (Throwable cause) {
// Throwable[] suppressed = cause.getSuppressed();
// for (Throwable throwable : suppressed) {
//
// }
// while ()
Throwable realCause = cause.getCause();
realCause.printStackTrace();
// 返回一个新的exception,避免原始的exception太长导致网络IO异常
Exception exception = new Exception("rpc call error: " + realCause.getMessage());
// 只保留第一个trace
exception.setStackTrace(new StackTraceElement[] {realCause.getStackTrace()[0]});
resp.setCause(exception);
}
// 处理业务逻辑
ctx.channel().writeAndFlush(resp);
}
}
二、客户端模块
1、客户端及测试代码
package net.yury.netty.Test10Rpc.client;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Promise;
import lombok.extern.slf4j.Slf4j;
import net.yury.netty.Test10Rpc.ComplicateClass;
import net.yury.netty.Test10Rpc.RpcRequestMessage;
import net.yury.netty.Test10Rpc.service.HelloService;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;
@Slf4j
public class RpcClient {
    /** Source of request sequence ids, shared by all clients in this JVM. */
    private static final AtomicLong REQUEST_ID = new AtomicLong(0);
    /** Shared, thread-safe mapper used to coerce decoded return values to the declared return type. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** Connection to the RPC server owned by this client instance. */
    private final Channel channel;

    /**
     * Connects to the RPC server at localhost:8080.
     *
     * <p>The private event loop group is shut down when the channel closes. If the connection
     * attempt fails, it is shut down immediately so no event-loop thread is left behind.
     *
     * @throws InterruptedException if interrupted while connecting
     */
    public RpcClient() throws InterruptedException {
        NioEventLoopGroup workerGroup = new NioEventLoopGroup();
        boolean connected = false;
        try {
            channel = new Bootstrap()
                    .group(workerGroup)
                    .channel(NioSocketChannel.class)
                    .handler(new RpcClientHandlerInitializer())
                    .connect(new InetSocketAddress("localhost", 8080))
                    .sync()
                    .channel();
            connected = true;
        } finally {
            if (!connected) {
                workerGroup.shutdownGracefully();
            }
        }
        channel.closeFuture().addListener(future -> workerGroup.shutdownGracefully());
    }

    /**
     * Demo: calls every HelloService method through the RPC proxy and prints the results.
     * The last call is expected to fail remotely.
     */
    public static void main(String[] args) throws InterruptedException {
        RpcClient client = new RpcClient();
        HelloService proxyService = client.getProxyService(HelloService.class);
        try {
            String res1 = proxyService.sayHello("yury757-1");
            String res2 = proxyService.sayHello("yury757-2");
            String res3 = proxyService.sayHello("yury757-3");
            String res4 = proxyService.sayHello("yury757-4");
            String res5 = proxyService.sayHello("yury757-5");
            ComplicateClass testComplicateReturnType = proxyService.testComplicateReturnType();
            System.out.println(res1);
            System.out.println(res2);
            System.out.println(res3);
            System.out.println(res4);
            System.out.println(res5);
            System.out.println(testComplicateReturnType);
            System.out.println(proxyService.addNewMethod());
            proxyService.testError();
        } finally {
            client.close();
        }
    }

    /**
     * Creates a dynamic proxy whose interface calls are executed on the RPC server.
     *
     * <p>{@code toString}, {@code equals} and {@code hashCode} are answered locally, because
     * the server only resolves methods declared on the interface. Each remote call blocks
     * until its response arrives.
     *
     * @param clazz the service interface
     * @param <T>   the service type
     * @return a proxy implementing {@code clazz}
     * @throws RuntimeException wrapping the remote error when a call fails
     */
    public <T> T getProxyService(Class<T> clazz) {
        // NOTE(review): the id is allocated once per proxy, so every call through this proxy reuses it.
        // Concurrent calls through one proxy would therefore overwrite each other's pending promise.
        // Before moving this into the handler (one id per call), make sure the server echoes the
        // sequence id on error responses too.
        long id = REQUEST_ID.getAndIncrement();
        Object proxyInstance = Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[] {clazz},
                (proxy, method, args) -> {
                    if (method.getDeclaringClass() == Object.class) {
                        return invokeObjectMethod(clazz, proxy, method.getName(), args);
                    }
                    RpcRequestMessage request = new RpcRequestMessage(
                            id,
                            clazz.getName(),
                            method.getName(),
                            method.getReturnType(),
                            method.getParameterTypes(),
                            args);
                    Promise<Object> promise = send(request);
                    promise.await();
                    if (!promise.isSuccess()) {
                        throw new RuntimeException(promise.cause());
                    }
                    return convertReturnValue(promise.getNow(), method.getReturnType());
                });
        return clazz.cast(proxyInstance);
    }

    /** Local implementation of the java.lang.Object methods a JDK proxy forwards to its handler. */
    private static Object invokeObjectMethod(Class<?> iface, Object proxy, String name, Object[] args) {
        switch (name) {
            case "equals":
                return proxy == args[0];
            case "hashCode":
                return System.identityHashCode(proxy);
            case "toString":
                return "RpcProxy[" + iface.getName() + "]@" + Integer.toHexString(System.identityHashCode(proxy));
            default:
                throw new UnsupportedOperationException(name);
        }
    }

    /**
     * Converts a JSON-decoded return value into the method's declared return type.
     * Null (including every void method's result) is returned as-is.
     */
    private static Object convertReturnValue(Object o, Class<?> returnType) throws java.io.IOException {
        if (o == null || returnType == void.class) {
            return null;
        }
        if (o instanceof String) {
            return String.class.equals(returnType) ? o : MAPPER.readValue((String) o, returnType);
        }
        if (o instanceof JsonNode) {
            return MAPPER.treeToValue((JsonNode) o, returnType);
        }
        // Maps (POJOs), Lists, Numbers and Booleans as produced by Jackson's untyped decoding.
        return MAPPER.convertValue(o, returnType);
    }

    /**
     * Sends a request and returns a promise completed by the response handler.
     *
     * <p>The promise is registered BEFORE the write, so a response that arrives quickly always
     * finds it. A failed write fails the promise instead of throwing.
     *
     * @param request the request to send
     * @return the pending result
     */
    public Promise<Object> send(RpcRequestMessage request) {
        long id = request.getSequenceID();
        Promise<Object> promise = new DefaultPromise<>(channel.eventLoop());
        RpcClientHandlerInitializer.PROMISE_MAP.put(id, promise);
        channel.writeAndFlush(request).addListener(future -> {
            if (!future.isSuccess()) {
                RpcClientHandlerInitializer.PROMISE_MAP.remove(id);
                promise.tryFailure(future.cause());
            }
        });
        return promise;
    }

    /** Closes the connection and waits for it to finish; the event loop group shuts down afterwards. */
    public void close() {
        try {
            channel.close().sync();
        } catch (InterruptedException e) {
            log.warn("interrupted while closing rpc channel", e);
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
        }
    }
}
2、RpcClientHandlerInitializer
package net.yury.netty.Test10Rpc.client;
import io.netty.channel.*;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.Promise;
import lombok.extern.slf4j.Slf4j;
import net.yury.netty.Test10Rpc.RpcCodec;
import net.yury.netty.Test10Rpc.RpcResponseMessage;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
@Slf4j
public class RpcClientHandlerInitializer extends ChannelInitializer<Channel> {
    /**
     * Pending calls keyed by request sequence id.
     *
     * <p>RpcClient.send puts a promise here. The response handler below removes the entry
     * and completes the promise.
     */
    public static final Map<Long, Promise<Object>> PROMISE_MAP = new ConcurrentHashMap<>();

    /**
     * Installs the client-side pipeline: frame reassembly, message codec in client mode,
     * and the handler that completes pending promises.
     *
     * @param ch the newly connected channel
     */
    @Override
    protected void initChannel(Channel ch) throws Exception {
        ch.pipeline()
                // Wire logging; uncomment to trace traffic.
                // .addLast(new LoggingHandler(LogLevel.INFO))
                // Reassemble frames split or merged by TCP, using the 4-byte length prefix.
                .addLast(new LengthFieldBasedFrameDecoder(2048, 0, 4))
                // Message codec in client mode (decodes RpcResponseMessage).
                .addLast(new RpcCodec(false))
                // Completes the caller's promise with the response.
                .addLast(new SimpleChannelInboundHandler<RpcResponseMessage>() {
                    @Override
                    protected void channelRead0(ChannelHandlerContext ctx, RpcResponseMessage msg) {
                        long id = msg.getSequenceID();
                        // remove() looks up and unregisters in one atomic step, so entries cannot leak.
                        Promise<Object> promise = PROMISE_MAP.remove(id);
                        if (promise == null) {
                            log.warn("no pending promise for response id {}, dropping it", id);
                            return;
                        }
                        // try* variants: do not throw if the promise was already completed/cancelled.
                        if (msg.getCause() != null) {
                            promise.tryFailure(msg.getCause());
                        } else {
                            promise.trySuccess(msg.getReturnValue());
                        }
                    }
                });
    }
}
三、通用类
1、编解码
package net.yury.netty.Test10Rpc;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;
import java.nio.charset.StandardCharsets;
import java.util.List;
/**
 * Codec for {@link Message} frames, laid out as a 4-byte length followed by UTF-8 JSON.
 *
 * <p>On the server side, frames are decoded as {@code RpcRequestMessage}. On the client side,
 * they are decoded as {@code RpcResponseMessage}. A length-field frame decoder is expected to
 * run before this codec, so each decode call sees exactly one complete frame.
 *
 * @author yury757
 */
public class RpcCodec extends ByteToMessageCodec<Message> {
    /** ObjectMapper is thread-safe once configured, so one instance is shared by every channel. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /** {@code true}: decode requests (server side); {@code false}: decode responses (client side). */
    private final boolean isServer;

    /** Creates a server-side codec. */
    public RpcCodec() {
        this(true);
    }

    /**
     * @param isServer whether this codec decodes requests (server) or responses (client)
     */
    public RpcCodec(boolean isServer) {
        this.isServer = isServer;
    }

    /** Writes {@code msg} as a 4-byte length prefix followed by its UTF-8 JSON form. */
    @Override
    protected void encode(ChannelHandlerContext ctx, Message msg, ByteBuf out) throws Exception {
        // writeValueAsBytes emits UTF-8 directly, avoiding an intermediate String.
        byte[] bytes = MAPPER.writeValueAsBytes(msg);
        out.writeInt(bytes.length);
        out.writeBytes(bytes);
    }

    /** Reads one length-prefixed JSON frame and binds it to the request or response type. */
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        int length = in.readInt();
        CharSequence json = in.readCharSequence(length, StandardCharsets.UTF_8);
        Class<?> type = isServer ? RpcRequestMessage.class : RpcResponseMessage.class;
        out.add(MAPPER.readValue(json.toString(), type));
    }
}
2、消息接口 Message
package net.yury.netty.Test10Rpc;
/**
 * Marker type for every payload that {@code RpcCodec} can put on the wire.
 * It declares no members; implementations are plain data carriers.
 */
public interface Message {
    // intentionally empty: marker interface
}
3、请求消息类 RpcRequestMessage
package net.yury.netty.Test10Rpc;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * A remote method call: which interface method to invoke, and with which arguments.
 */
@NoArgsConstructor
@AllArgsConstructor
@Data
public class RpcRequestMessage implements Message {
    /** Client-assigned id; the server copies it into the matching RpcResponseMessage. */
    private long sequenceID;
    /** Fully-qualified name of the remote interface. */
    private String interfaceName;
    /** Name of the method to invoke on that interface. */
    private String methodName;
    /** Declared return type of the method (carried for reference; the visible server code does not read it). */
    private Class<?> returnType;
    /** Declared parameter types, used to resolve the exact (possibly overloaded) method. */
    private Class<?>[] paramsTypes;
    /** Argument values, parallel to {@link #paramsTypes}; null for no-arg calls made through a JDK proxy. */
    private Object[] paramsValues;
}
4、响应消息类 RpcResponseMessage
package net.yury.netty.Test10Rpc;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
/**
 * Result of a remote call, matched to its request through {@link #sequenceID}.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class RpcResponseMessage implements Message {
    /** Id of the request this response answers. */
    private long sequenceID;
    /** Value returned by the remote method; stays null when the call failed. */
    private Object returnValue;
    /** Error raised by the remote call, or null when it succeeded. */
    private Throwable cause;
}
5、自定义实体类
package net.yury.netty.Test10Rpc;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
import java.util.Map;
/**
 * Test payload with nested, array, Class-valued and map fields, used to exercise
 * serialization of non-trivial return types.
 */
@AllArgsConstructor
@NoArgsConstructor
@Data
public class ComplicateClass implements Serializable {
    private static final long serialVersionUID = 1L;

    /**
     * Nested value type. It must itself be Serializable; otherwise Java serialization of a
     * ComplicateClass whose {@code obj} is non-null fails with NotSerializableException.
     */
    @AllArgsConstructor
    @NoArgsConstructor
    @Data
    public static class TestInnerClass implements Serializable {
        private static final long serialVersionUID = 1L;

        private Class<?> clazz = TestInnerClass.class;
    }

    private int id;
    private String[] array;
    private TestInnerClass obj;
    private Map<Long, Long> map;
}
6、类处理工具类
package net.yury;
import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
public class ClassUtils {
    /**
     * Finds every class under {@code packageName} (including sub-packages) that is assignable to the
     * interface {@code c}, excluding {@code c} itself.
     *
     * <p>Sub-interfaces and abstract classes are included too. Only directory (file:) class-path
     * entries are scanned.
     *
     * @param packageName dotted package name; "" means the default package
     * @param c           the interface to look for
     * @param <T>         the interface type
     * @return the matching classes, in discovery order
     * @throws IllegalArgumentException if {@code c} is not an interface
     */
    public static <T> List<Class<? extends T>> getAllClassByInterface(String packageName, Class<T> c) {
        if (!c.isInterface()) {
            throw new IllegalArgumentException("clazz c is not a interface");
        }
        List<Class<? extends T>> res = new ArrayList<>();
        // Visit every class in the package and its sub-packages.
        getClasses(packageName, (pkgName, classSimpleName, clazz) -> {
            if (c.isAssignableFrom(clazz) && !c.equals(clazz)) {
                res.add(clazz.asSubclass(c));
            }
        });
        return res;
    }

    /** Callback invoked once for every class found during a scan. */
    @FunctionalInterface
    public interface ClassVisitor {
        /**
         * @param packageName     package the class belongs to ("" for the default package)
         * @param classSimpleName class file name without ".class" (may contain '$' for nested classes)
         * @param clazz           the loaded, not-yet-initialized class
         */
        void visit(String packageName, String classSimpleName, Class<?> clazz);
    }

    /**
     * Visits every class found in {@code packageName} and its sub-packages.
     *
     * <p>Directories are looked up through the context class loader. Only "file" URLs are
     * scanned; classes packed inside jars are not visited.
     *
     * @param packageName dotted package name; "" means the default package
     * @param visitor     receives each class found
     */
    public static void getClasses(String packageName, ClassVisitor visitor) {
        String packageDirName = packageName.replace('.', '/');
        Enumeration<URL> dirs;
        try {
            dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
        } catch (IOException e) {
            throw new java.io.UncheckedIOException(e);
        }
        while (dirs.hasMoreElements()) {
            URL url = dirs.nextElement();
            if (!"file".equals(url.getProtocol())) {
                continue;
            }
            String filePath;
            try {
                // URLDecoder does form decoding ('+' -> space); protect literal '+' so
                // directories such as "c++" are still found.
                filePath = URLDecoder.decode(url.getFile().replace("+", "%2B"), "UTF-8");
            } catch (UnsupportedEncodingException e) {
                throw new RuntimeException(e);
            }
            findAndAddClassesInPackageByFile(packageName, filePath, true, visitor);
        }
    }

    /**
     * Visits every ".class" file under {@code packagePath}, treating it as package {@code packageName}.
     *
     * <p>Classes are loaded without running their static initializers. Files that cannot be loaded
     * are reported on stderr and skipped. package-info/module-info files are ignored.
     *
     * @param packageName dotted package name matching {@code packagePath}; "" for the default package
     * @param packagePath file-system directory to scan; missing or non-directory paths are ignored
     * @param recursive   whether to descend into sub-directories (as sub-packages)
     * @param visitor     receives each class found
     */
    public static void findAndAddClassesInPackageByFile(String packageName, String packagePath,
                                                        final boolean recursive, ClassVisitor visitor) {
        File dir = new File(packagePath);
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }
        // Keep sub-directories (only when recursing) and compiled class files.
        FileFilter filter = file -> (recursive && file.isDirectory()) || file.getName().endsWith(".class");
        File[] dirFiles = dir.listFiles(filter);
        if (dirFiles == null) {
            // listFiles returns null on an I/O error (e.g. permission denied).
            return;
        }
        for (File file : dirFiles) {
            if (file.isDirectory()) {
                // A directory can also match by a ".class" name; only descend when recursing.
                if (recursive) {
                    findAndAddClassesInPackageByFile(qualify(packageName, file.getName()),
                            file.getAbsolutePath(), recursive, visitor);
                }
                continue;
            }
            String className = file.getName().substring(0, file.getName().length() - ".class".length());
            if (className.indexOf('-') >= 0) {
                // package-info / module-info: '-' is not legal in a class name.
                continue;
            }
            Class<?> clazz;
            try {
                // initialize=false: scanning must not run arbitrary static initializers.
                clazz = Class.forName(qualify(packageName, className), false, ClassUtils.class.getClassLoader());
            } catch (ClassNotFoundException | LinkageError e) {
                e.printStackTrace();
                continue;
            }
            visitor.visit(packageName, className, clazz);
        }
    }

    /** Joins a package name and a simple name, treating "" as the default package. */
    private static String qualify(String packageName, String simpleName) {
        return packageName.isEmpty() ? simpleName : packageName + "." + simpleName;
    }
}
四、业务服务类及业务注册架构
1、业务类注册接口
package net.yury.netty.Test10Rpc.service;
/**
 * Marks a service implementation for automatic discovery.
 *
 * <p>{@code RpcRegisterProcessor} registers the implementation under every other interface it
 * implements; this marker interface itself is skipped.
 */
public interface RpcRegister {
    /**
     * Returns an instance of the implementing service.
     *
     * @return a service instance
     */
    Object getInstance();
}
2、服务注册初始化类
package net.yury.netty.Test10Rpc.service;
import lombok.extern.slf4j.Slf4j;
import net.yury.ClassUtils;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
@Slf4j
public class RpcRegisterProcessor {
    /** Package scanned by {@link #init()}. */
    private static final String DEFAULT_SCAN_PACKAGE = "net.yury.netty.Test10Rpc.service";

    /**
     * Service registry: interface => implementation instances registered for it.
     */
    public static final ConcurrentHashMap<Class<?>, CopyOnWriteArrayList<Object>> RPC_SERVICE = new ConcurrentHashMap<>();

    /**
     * Registers {@code obj} in {@link RpcRegisterProcessor#RPC_SERVICE} under every interface that
     * {@code clazz} directly implements, except {@link RpcRegister}.
     *
     * @param clazz the implementation class whose interfaces are registered
     * @param obj   the instance to register; must be an instance of {@code clazz}
     * @throws NullPointerException if {@code obj} is null
     * @throws ClassCastException   if {@code obj} is not an instance of {@code clazz}
     */
    public static void register(Class<?> clazz, Object obj) {
        if (obj == null) {
            throw new NullPointerException("obj must not be null");
        }
        if (!clazz.isInstance(obj)) {
            throw new ClassCastException(obj.getClass() + " is not an instance of " + clazz);
        }
        for (Class<?> iface : clazz.getInterfaces()) {
            if (RpcRegister.class.equals(iface)) {
                continue;
            }
            RPC_SERVICE.computeIfAbsent(iface, k -> new CopyOnWriteArrayList<>()).add(obj);
            log.debug("register success. interface: {}, impl: {}", iface, clazz);
        }
    }

    /**
     * Instantiates and registers every concrete {@link RpcRegister} implementation found in the
     * default service package.
     */
    public static void init() {
        init(DEFAULT_SCAN_PACKAGE);
    }

    /**
     * Instantiates, via its no-arg constructor, every concrete {@link RpcRegister} implementation
     * found in {@code packageName} (and sub-packages), and registers each one with
     * {@link #register(Class, Object)}.
     *
     * @param packageName dotted package name to scan
     * @throws IllegalStateException if a concrete implementation cannot be instantiated
     */
    public static void init(String packageName) {
        List<Class<? extends RpcRegister>> clazzList =
                ClassUtils.getAllClassByInterface(packageName, RpcRegister.class);
        for (Class<? extends RpcRegister> clazz : clazzList) {
            // The scan also returns sub-interfaces and abstract classes, which cannot be instantiated.
            if (java.lang.reflect.Modifier.isAbstract(clazz.getModifiers())) {
                log.debug("skip non-instantiable class: {}", clazz);
                continue;
            }
            RpcRegister obj;
            try {
                obj = clazz.getDeclaredConstructor().newInstance();
            } catch (ReflectiveOperationException e) {
                throw new IllegalStateException("cannot instantiate rpc service " + clazz, e);
            }
            register(clazz, obj);
        }
    }
}
3、业务接口
package net.yury.netty.Test10Rpc.service;
import net.yury.netty.Test10Rpc.ComplicateClass;
/**
 * Demo service exposed over RPC.
 */
public interface HelloService {
    /**
     * Plain-call test: greets the caller.
     *
     * @param msg text to include in the greeting
     * @return the greeting
     */
    String sayHello(String msg);

    /**
     * Complex-return-type test: the result is a nested object that must survive the wire format.
     *
     * @return a populated {@link ComplicateClass}
     */
    ComplicateClass testComplicateReturnType();

    /**
     * Error test: the implementation is meant to throw, exercising the remote error path.
     */
    void testError();

    /**
     * Shows that adding a method requires no framework changes, only:
     * <ol>
     *   <li>declare the method on this interface;</li>
     *   <li>implement it in the service class;</li>
     *   <li>restart the server;</li>
     *   <li>call the new method from the client and restart the client.</li>
     * </ol>
     *
     * @return a fixed demo value
     */
    String addNewMethod();
}
4、业务类实现
package net.yury.netty.Test10Rpc.service;
import net.yury.netty.Test10Rpc.ComplicateClass;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Demo implementation of {@link HelloService}, discovered through the {@link RpcRegister} marker.
 */
public class HelloServiceImpl implements HelloService, RpcRegister {
    /** Number of sayHello calls served so far, shared across instances. */
    public static final AtomicInteger COUNT = new AtomicInteger(0);

    /** Returns a greeting that includes {@code msg} and the running call counter. */
    @Override
    public String sayHello(String msg) {
        int ticket = COUNT.getAndIncrement();
        return "hello world, " + msg + ", your count is " + ticket;
    }

    /** Builds a sample object with an array, a nested object and an empty map. */
    @Override
    public ComplicateClass testComplicateReturnType() {
        String[] values = {"123"};
        ComplicateClass.TestInnerClass inner = new ComplicateClass.TestInnerClass();
        return new ComplicateClass(1, values, inner, new HashMap<>());
    }

    /** Always throws ArithmeticException ("/ by zero") to exercise the RPC error path. */
    @Override
    public void testError() {
        int divisor = 0;
        int unused = 1 / divisor;
    }

    @Override
    public String addNewMethod() {
        return "123";
    }

    /** Returns a brand-new instance rather than {@code this}. */
    @Override
    public Object getInstance() {
        return new HelloServiceImpl();
    }
}