用 Netty 手写 RPC 框架

yury757 / 2023-05-14 / 原文

详见:https://github.com/yury757/nettystudy

一、服务端模块

1、server服务器

package net.yury.netty.Test10Rpc.server;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import lombok.extern.slf4j.Slf4j;
import net.yury.netty.Test10Rpc.service.RpcRegisterProcessor;

/**
 * Entry point of the RPC server.
 *
 * <p>Registers every {@code RpcRegister} implementation via {@link RpcRegisterProcessor#init()},
 * binds a Netty server, and blocks until the server channel is closed.
 */
@Slf4j
public class RpcServer {
    /** Port used when no port is given on the command line. */
    private static final int DEFAULT_PORT = 8080;

    /**
     * Starts the server.
     *
     * @param args optional; {@code args[0]}, if present, is the TCP port to listen on (default 8080)
     * @throws InterruptedException  if interrupted while binding or waiting for shutdown
     * @throws NumberFormatException if {@code args[0]} is not a valid integer
     */
    public static void main(String[] args) throws InterruptedException {
        int port = args.length > 0 ? Integer.parseInt(args[0]) : DEFAULT_PORT;
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();
        try {
            // Register the business service implementations; inside try so a failure here
            // still shuts the event loop groups down.
            RpcRegisterProcessor.init();

            ServerBootstrap b = new ServerBootstrap();
            b.group(bossGroup, workerGroup)
                    .channel(NioServerSocketChannel.class)
                    .option(ChannelOption.RCVBUF_ALLOCATOR, new AdaptiveRecvByteBufAllocator(128, 1024, 4096))
                    .option(ChannelOption.SO_REUSEADDR, true)
                    .childOption(ChannelOption.SO_KEEPALIVE, true)
                    .childHandler(new RpcServerHandlerInitializer());
            ChannelFuture f = b.bind(port).sync();
            log.info("rpc server listening on port {}", port);
            f.channel().closeFuture().sync();
        } finally {
            workerGroup.shutdownGracefully();
            bossGroup.shutdownGracefully();
        }
    }
}

2、RpcServerHandlerInitializer

package net.yury.netty.Test10Rpc.server;

import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import net.yury.netty.Test10Rpc.RpcCodec;

/**
 * Installs the server-side pipeline on every accepted connection.
 * In order: traffic logging, length-prefixed framing, the JSON codec in server mode,
 * and the request dispatcher.
 */
public class RpcServerHandlerInitializer extends ChannelInitializer<Channel> {
    @Override
    protected void initChannel(Channel ch) throws Exception {
        // logs every inbound/outbound event at INFO
        LoggingHandler logging = new LoggingHandler(LogLevel.INFO);
        // splits the stream on a 4-byte length field at offset 0 (fixes sticky/partial packets);
        // maxFrameLength is 2048 bytes
        LengthFieldBasedFrameDecoder framing = new LengthFieldBasedFrameDecoder(2048, 0, 4);
        // server-mode codec: bytes <-> RpcRequestMessage / Message
        RpcCodec codec = new RpcCodec();
        // executes the decoded request against the registered service
        RpcServerHandler dispatcher = new RpcServerHandler();

        ch.pipeline().addLast(logging, framing, codec, dispatcher);
    }
}

3、RpcServerHandler

package net.yury.netty.Test10Rpc.server;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.extern.slf4j.Slf4j;
import net.yury.netty.Test10Rpc.RpcRequestMessage;
import net.yury.netty.Test10Rpc.RpcResponseMessage;
import net.yury.netty.Test10Rpc.service.RpcRegisterProcessor;

import java.lang.reflect.Method;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
public class RpcServerHandler extends SimpleChannelInboundHandler<RpcRequestMessage> {
    private static final AtomicInteger COUNT = new AtomicInteger(0);

    @Override
    protected void channelRead0(ChannelHandlerContext ctx, RpcRequestMessage msg) {
        log.debug("receive msg is: {}", msg);
        String interfaceName = msg.getInterfaceName();
        Object result = null;
        RpcResponseMessage resp = new RpcResponseMessage();
        try {
            Class<?> clazz = Class.forName(interfaceName);
            Method method = clazz.getMethod(msg.getMethodName(), msg.getParamsTypes());
            result = method.invoke(RpcRegisterProcessor.RPC_SERVICE.get(clazz).get(0), msg.getParamsValues());
            resp.setSequenceID(msg.getSequenceID());
            resp.setReturnValue(result);
            resp.setCause(null);
        }catch (Throwable cause) {
//            Throwable[] suppressed = cause.getSuppressed();
//            for (Throwable throwable : suppressed) {
//
//            }
//            while ()
            Throwable realCause = cause.getCause();
            realCause.printStackTrace();

            // 返回一个新的exception,避免原始的exception太长导致网络IO异常
            Exception exception = new Exception("rpc call error: " + realCause.getMessage());
            // 只保留第一个trace
            exception.setStackTrace(new StackTraceElement[] {realCause.getStackTrace()[0]});
            resp.setCause(exception);
        }
        // 处理业务逻辑
        ctx.channel().writeAndFlush(resp);
    }
}

二、客户端模块

1、客户端及测试代码

package net.yury.netty.Test10Rpc.client;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.*;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Promise;
import lombok.extern.slf4j.Slf4j;
import net.yury.netty.Test10Rpc.ComplicateClass;
import net.yury.netty.Test10Rpc.RpcRequestMessage;
import net.yury.netty.Test10Rpc.service.HelloService;

import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Minimal RPC client.
 *
 * <p>Opens one Netty connection to {@code localhost:8080} and hands out dynamic proxies. Each
 * method call on a proxy is sent as an {@link RpcRequestMessage} and blocks until the response with
 * the same sequence id arrives.
 */
@Slf4j
public class RpcClient {
    /** Connection owned by this client instance (previously static and shared by all instances). */
    private final Channel channel;
    /** Source of per-call sequence ids, unique across all clients in this JVM. */
    private static final AtomicLong REQUEST_ID = new AtomicLong(0);
    private static final ObjectMapper MAPPER = new ObjectMapper();

    /**
     * Connects to {@code localhost:8080}.
     *
     * @throws InterruptedException if interrupted while connecting
     */
    public RpcClient() throws InterruptedException {
        NioEventLoopGroup workerGroup = new NioEventLoopGroup();
        boolean connected = false;
        try {
            channel = new Bootstrap()
                    .group(workerGroup)
                    .channel(NioSocketChannel.class)
                    .handler(new RpcClientHandlerInitializer())
                    .connect(new InetSocketAddress("localhost", 8080))
                    .sync()
                    .channel();
            connected = true;
        } finally {
            // a failed connect would otherwise leak the event loop group
            if (!connected) {
                workerGroup.shutdownGracefully();
            }
        }
        // release the event loop threads once the connection closes
        channel.closeFuture().addListener(future -> workerGroup.shutdownGracefully());
    }

    public static void main(String[] args) throws InterruptedException {
        RpcClient client = new RpcClient();
        HelloService proxyService = client.getProxyService(HelloService.class);
        try {
            String res1 = proxyService.sayHello("yury757-1");
            String res2 = proxyService.sayHello("yury757-2");
            String res3 = proxyService.sayHello("yury757-3");
            String res4 = proxyService.sayHello("yury757-4");
            String res5 = proxyService.sayHello("yury757-5");
            ComplicateClass testComplicateReturnType = proxyService.testComplicateReturnType();
            System.out.println(res1);
            System.out.println(res2);
            System.out.println(res3);
            System.out.println(res4);
            System.out.println(res5);
            System.out.println(testComplicateReturnType);
            System.out.println(proxyService.addNewMethod());
            proxyService.testError();
        } finally {
            client.close();
        }
    }

    /**
     * Creates a proxy for the service interface {@code clazz}.
     *
     * <p>Every call on the proxy is sent with a fresh sequence id and blocks until its response
     * arrives. A remote failure is rethrown as a {@link RuntimeException} wrapping the remote cause.
     *
     * @param clazz the service interface
     * @param <T>   service type
     * @return a proxy implementing {@code clazz}
     */
    public <T> T getProxyService(Class<T> clazz) {
        Object instance = Proxy.newProxyInstance(clazz.getClassLoader(), new Class<?>[]{clazz}, (proxy, method, args) -> {
            // One id per CALL: a per-proxy id would make every call overwrite the same
            // PROMISE_MAP entry, so concurrent calls would receive each other's results.
            RpcRequestMessage request = new RpcRequestMessage(
                    REQUEST_ID.getAndIncrement(),
                    clazz.getName(),
                    method.getName(),
                    method.getReturnType(),
                    method.getParameterTypes(),
                    args);
            Promise<Object> promise = send(request);
            promise.await();
            if (!promise.isSuccess()) {
                throw new RuntimeException(promise.cause());
            }
            return convertResult(promise.getNow(), method.getReturnType());
        });
        return clazz.cast(instance);
    }

    /**
     * Converts a decoded {@code returnValue} into the invoked method's declared return type.
     *
     * <p>{@code RpcResponseMessage.returnValue} is typed {@code Object}, so Jackson decodes complex
     * results as generic maps, lists, numbers and so on. They are re-bound here.
     *
     * @param o          decoded return value; {@code null} for void methods or null results
     * @param returnType declared return type of the invoked method
     * @return the converted value
     * @throws java.io.IOException if a JSON string cannot be parsed into {@code returnType}
     */
    private static Object convertResult(Object o, Class<?> returnType) throws java.io.IOException {
        if (o == null) {
            return null;
        }
        if (o instanceof String) {
            return String.class.equals(returnType) ? o : MAPPER.readValue((String) o, returnType);
        }
        if (o instanceof JsonNode) {
            return MAPPER.treeToValue((JsonNode) o, returnType);
        }
        if (o instanceof Map) {
            ObjectNode node = MAPPER.valueToTree(o);
            return MAPPER.treeToValue(node, returnType);
        }
        // numbers, booleans, lists, ...: let Jackson coerce to the declared type
        return MAPPER.convertValue(o, returnType);
    }

    /**
     * Sends {@code request} without blocking.
     *
     * @param request the request to send; its sequence id keys the pending promise
     * @return a promise completed by the response handler when the matching response arrives,
     *         or failed if the write fails
     */
    public Promise<Object> send(RpcRequestMessage request) {
        long id = request.getSequenceID();
        Promise<Object> promise = new DefaultPromise<>(channel.eventLoop());
        // Register BEFORE writing. Registering after the write leaves a window where the
        // response arrives and finds no promise.
        RpcClientHandlerInitializer.PROMISE_MAP.put(id, promise);
        channel.writeAndFlush(request).addListener(future -> {
            if (!future.isSuccess()) {
                RpcClientHandlerInitializer.PROMISE_MAP.remove(id);
                promise.tryFailure(future.cause());
            }
        });
        return promise;
    }

    /** Closes the connection and waits for the close to complete. */
    public void close() {
        try {
            channel.close().sync();
        } catch (InterruptedException e) {
            // preserve the interrupt status for the caller
            Thread.currentThread().interrupt();
            log.warn("interrupted while closing rpc channel", e);
        }
    }
}

2、RpcClientHandlerInitializer

package net.yury.netty.Test10Rpc.client;

import io.netty.channel.*;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.util.concurrent.Promise;
import lombok.extern.slf4j.Slf4j;
import net.yury.netty.Test10Rpc.RpcCodec;
import net.yury.netty.Test10Rpc.RpcResponseMessage;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Builds the client-side pipeline.
 * Incoming {@link RpcResponseMessage}s complete the pending call promises.
 */
@Slf4j
public class RpcClientHandlerInitializer extends ChannelInitializer<Channel> {
    /** Pending calls keyed by request sequence id; filled by RpcClient#send, drained below. */
    public static final Map<Long, Promise<Object>> PROMISE_MAP = new ConcurrentHashMap<>();

    @Override
    protected void initChannel(Channel ch) throws Exception {
        ch.pipeline()
                // logging (disabled; uncomment to trace traffic)
//                .addLast(new LoggingHandler(LogLevel.INFO))
                // length-prefixed framing handles sticky/partial packets
                .addLast(new LengthFieldBasedFrameDecoder(2048, 0, 4))
                // codec in client mode: decodes RpcResponseMessage
                .addLast(new RpcCodec(false))
                // completes the pending promise whose id matches the response
                .addLast(new SimpleChannelInboundHandler<RpcResponseMessage>() {
                    @Override
                    protected void channelRead0(ChannelHandlerContext ctx, RpcResponseMessage msg) throws Exception {
                        long id = msg.getSequenceID();
                        // remove() returns and removes atomically, so each promise is claimed at most once
                        Promise<Object> promise = PROMISE_MAP.remove(id);
                        if (promise == null) {
                            // e.g. the write already failed, or a stray/duplicate response
                            log.warn("no pending call for response id {}, dropping it", id);
                            return;
                        }
                        if (msg.getCause() != null) {
                            promise.tryFailure(msg.getCause());
                        } else {
                            promise.trySuccess(msg.getReturnValue());
                        }
                    }
                });
    }
}

三、通用类

1、编解码

package net.yury.netty.Test10Rpc;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageCodec;

import java.nio.charset.StandardCharsets;
import java.util.List;

/**
 * JSON codec for {@link Message}s.
 *
 * <p>Wire format: a 4-byte int length (Netty's default big-endian {@code writeInt}) followed by
 * that many bytes of UTF-8 JSON. {@link #decode} reads without checking readable bytes, so it
 * relies on a frame decoder placed before it in the pipeline.
 *
 * @author yury757
 */
public class RpcCodec extends ByteToMessageCodec<Message> {
    /** true: decode RpcRequestMessage (server side); false: decode RpcResponseMessage (client side). */
    private final boolean isServer;
    private final ObjectMapper mapper = new ObjectMapper();

    /** Creates a server-side codec. */
    public RpcCodec() {
        this(true);
    }

    /**
     * @param isServer {@code true} to decode requests, {@code false} to decode responses
     */
    public RpcCodec(boolean isServer) {
        this.isServer = isServer;
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, Message msg, ByteBuf out) throws Exception {
        String json = mapper.writeValueAsString(msg);
        byte[] payload = json.getBytes(StandardCharsets.UTF_8);
        out.writeInt(payload.length).writeBytes(payload);
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        int bodyLength = in.readInt();
        String json = in.readCharSequence(bodyLength, StandardCharsets.UTF_8).toString();
        Class<?> target;
        if (isServer) {
            target = RpcRequestMessage.class;
        } else {
            target = RpcResponseMessage.class;
        }
        out.add(mapper.readValue(json, target));
    }
}

2、消息接口 Message

package net.yury.netty.Test10Rpc;

/**
 * Marker interface for every message type handled by {@link RpcCodec}.
 */
public interface Message {
}

3、请求消息类 RpcRequestMessage

package net.yury.netty.Test10Rpc;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * Request sent from client to server describing one remote method call.
 */
@NoArgsConstructor
@AllArgsConstructor
@Data
public class RpcRequestMessage implements Message {
    /** Call id; the server copies it into the matching response. */
    private long sequenceID;
    /** Service interface name as returned by {@link Class#getName()}. */
    private String interfaceName;
    /** Name of the method to invoke. */
    private String methodName;
    /** Declared return type of the method. */
    private Class<?> returnType;
    /** Declared parameter types, used by the server to resolve the exact (possibly overloaded) method. */
    private Class<?>[] paramsTypes;
    /** Argument values; {@code null} when a proxied method takes no arguments. */
    private Object[] paramsValues;
}

4、响应消息类 RpcResponseMessage

package net.yury.netty.Test10Rpc;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * Response sent from server to client for one {@link RpcRequestMessage}.
 */
@Data
@AllArgsConstructor
@NoArgsConstructor
public class RpcResponseMessage implements Message {
    /** Sequence id of the request this response answers. */
    private long sequenceID;
    /** Value returned by the remote method; left {@code null} when the call failed. */
    private Object returnValue;
    /** Failure reported by the server, or {@code null} on success. */
    private Throwable cause;
}

5、自定义实体类

package net.yury.netty.Test10Rpc;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;
import java.util.Map;

/**
 * Sample value type combining a primitive, an array, a nested object and a map.
 * It is returned by {@code HelloService.testComplicateReturnType()} to exercise serialization of
 * complex return values.
 */
@Data
@NoArgsConstructor
@AllArgsConstructor
public class ComplicateClass implements Serializable {
    private int id;
    private String[] array;
    private TestInnerClass obj;
    private Map<Long, Long> map;

    /** Nested static type whose single field holds a {@code Class} reference. */
    @Data
    @NoArgsConstructor
    @AllArgsConstructor
    public static class TestInnerClass {
        private Class<?> clazz = TestInnerClass.class;
    }
}

6、类处理工具类

package net.yury;

import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URL;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;

/**
 * Class-path scanning helpers.
 */
public class ClassUtils {

    /**
     * Finds all classes in {@code packageName} and its sub-packages that implement interface {@code c}.
     * The interface {@code c} itself is excluded. Only directory ({@code file:}) class-path locations
     * are scanned.
     *
     * @param packageName dotted package name, e.g. {@code "net.yury"}
     * @param c           the interface to look for
     * @param <T>         interface type
     * @return matching classes; may include abstract classes and sub-interfaces. Empty if none match.
     * @throws IllegalArgumentException if {@code c} is not an interface
     */
    public static <T> List<Class<? extends T>> getAllClassByInterface(String packageName, Class<T> c) {
        if (!c.isInterface()) {
            throw new IllegalArgumentException("clazz c is not an interface: " + c.getName());
        }
        List<Class<? extends T>> res = new ArrayList<>();
        // visit every class in this package and its sub-packages
        getClasses(packageName, (pkgName, classSimpleName, clazz) -> {
            if (c.isAssignableFrom(clazz) && !c.equals(clazz)) {
                // checked conversion; replaces the former unchecked cast
                res.add(clazz.asSubclass(c));
            }
        });
        return res;
    }

    /** Callback invoked for each class found by a scan. */
    @FunctionalInterface
    public static interface ClassVisitor {
        /**
         * @param packageName     package name
         * @param classSimpleName class name without package
         * @param clazz           the loaded (not initialized) class
         */
        public void visit(String packageName, String classSimpleName, Class<?> clazz);
    }

    /**
     * Visits every class under {@code packageName} found through the context class loader.
     * Jar entries are not scanned; only {@code file:} locations are.
     *
     * @param packageName dotted package name
     * @param visitor     callback for each class found
     * @throws RuntimeException if the class path cannot be read or a resource URL is malformed
     */
    public static void getClasses(String packageName, ClassVisitor visitor) {
        String packageDirName = packageName.replace('.', '/');
        Enumeration<URL> dirs;
        try {
            dirs = Thread.currentThread().getContextClassLoader().getResources(packageDirName);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        while (dirs.hasMoreElements()) {
            URL url = dirs.nextElement();
            if (!"file".equals(url.getProtocol())) {
                continue;
            }
            String filePath;
            try {
                // URL -> URI -> File decodes %-escapes correctly; URLDecoder would also turn
                // a literal '+' in the path into a space.
                filePath = new File(url.toURI()).getPath();
            } catch (java.net.URISyntaxException e) {
                throw new RuntimeException(e);
            }
            findAndAddClassesInPackageByFile(packageName, filePath, true, visitor);
        }
    }

    /**
     * Visits each {@code .class} file in directory {@code packagePath} as a class of {@code packageName}.
     * Sub-directories are descended into as sub-packages when {@code recursive} is true.
     *
     * @param packageName dotted package name of {@code packagePath}; {@code ""} for the default package
     * @param packagePath directory to scan; a missing path or non-directory is ignored
     * @param recursive   whether to descend into sub-directories
     * @param visitor     callback for each class found
     */
    public static void findAndAddClassesInPackageByFile(String packageName, String packagePath, final boolean recursive, ClassVisitor visitor) {
        File dir = new File(packagePath);
        if (!dir.exists() || !dir.isDirectory()) {
            return;
        }
        // keep sub-directories (when recursing) and compiled class files
        File[] dirFiles = dir.listFiles((FileFilter) file ->
                (recursive && file.isDirectory()) || file.getName().endsWith(".class"));
        if (dirFiles == null) {
            // listFiles returns null on an I/O error
            return;
        }
        // the default package has no "pkg." prefix
        String prefix = packageName.isEmpty() ? "" : packageName + ".";
        for (File file : dirFiles) {
            if (file.isDirectory()) {
                findAndAddClassesInPackageByFile(prefix + file.getName(),
                        file.getAbsolutePath(),
                        recursive,
                        visitor);
            } else {
                String fileName = file.getName();
                String className = fileName.substring(0, fileName.length() - ".class".length());
                try {
                    // initialize=false: scanning must not run the static initializer of every class it sees
                    Class<?> clazz = Class.forName(prefix + className, false, ClassUtils.class.getClassLoader());
                    visitor.visit(packageName, className, clazz);
                } catch (ClassNotFoundException e) {
                    e.printStackTrace();
                }
            }
        }
    }
}

四、业务服务类及业务注册架构

1、业务类注册接口

package net.yury.netty.Test10Rpc.service;

/**
 * Marker for service implementations that RpcRegisterProcessor#init() discovers, instantiates
 * through their no-arg constructor, and registers.
 */
public interface RpcRegister {
    /**
     * Returns an instance of the implementing service.
     * NOTE(review): RpcRegisterProcessor does not call this; it instantiates classes through their
     * no-arg constructor instead.
     */
    public Object getInstance();
}

2、服务注册初始化类

package net.yury.netty.Test10Rpc.service;

import lombok.extern.slf4j.Slf4j;
import net.yury.ClassUtils;

import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;

/**
 * Registry of RPC service implementations, keyed by the service interfaces they implement.
 */
@Slf4j
public class RpcRegisterProcessor {
    /**
     * Service interface => registered implementation objects. The server handler invokes the first entry.
     */
    public static final ConcurrentHashMap<Class<?>, CopyOnWriteArrayList<Object>> RPC_SERVICE = new ConcurrentHashMap<>();

    /**
     * Registers {@code obj} in {@link RpcRegisterProcessor#RPC_SERVICE}.
     * It goes under every interface directly implemented by {@code clazz}, except {@link RpcRegister}.
     *
     * @param clazz implementation class whose directly implemented interfaces are registered
     * @param obj   an instance of {@code clazz} or of a subclass
     * @throws ClassCastException   if {@code obj} is not an instance of {@code clazz}
     * @throws NullPointerException if {@code obj} is {@code null}
     */
    public static void register(Class<?> clazz, Object obj) {
        Class<?> objClass = obj.getClass();
        if (!clazz.isAssignableFrom(objClass)) {
            throw new ClassCastException(objClass.getName() + " is not a " + clazz.getName());
        }
        for (Class<?> iface : clazz.getInterfaces()) {
            if (RpcRegister.class.equals(iface)) {
                // the marker interface itself is not a callable service
                continue;
            }
            RPC_SERVICE.computeIfAbsent(iface, k -> new CopyOnWriteArrayList<>()).add(obj);
            log.debug("register success. interface: {}, impl: {}", iface, clazz);
        }
    }

    /**
     * Instantiates and registers every concrete class under {@code net.yury.netty.Test10Rpc.service}
     * that implements {@link RpcRegister}. Each is created through its no-arg constructor.
     * Abstract classes and sub-interfaces found by the scan are skipped, since they cannot be instantiated.
     *
     * @throws RuntimeException if a concrete class cannot be instantiated
     */
    public static void init() {
        List<Class<? extends RpcRegister>> clazzList = ClassUtils.getAllClassByInterface("net.yury.netty.Test10Rpc.service", RpcRegister.class);
        for (Class<? extends RpcRegister> clazz : clazzList) {
            if (clazz.isInterface() || java.lang.reflect.Modifier.isAbstract(clazz.getModifiers())) {
                continue;
            }
            RpcRegister obj;
            try {
                obj = clazz.getDeclaredConstructor().newInstance();
            } catch (Exception e) {
                throw new RuntimeException("cannot instantiate rpc service " + clazz.getName(), e);
            }
            register(clazz, obj);
        }
    }
}

3、业务接口

package net.yury.netty.Test10Rpc.service;

import net.yury.netty.Test10Rpc.ComplicateClass;

/**
 * Sample RPC service used by the demo client and server.
 */
public interface HelloService {
    /**
     * Tests a plain call.
     *
     * @param msg text to include in the greeting
     * @return a greeting built by the implementation
     */
    String sayHello(String msg);

    /**
     * Tests a method with a complex return type.
     *
     * @return a populated {@link ComplicateClass}
     */
    ComplicateClass testComplicateReturnType();

    /**
     * Tests exception propagation; the implementation is expected to throw.
     */
    void testError();

    /**
     * Shows that adding a method needs no framework changes, only these steps:
     * <ol>
     *   <li>add the method to this interface;</li>
     *   <li>implement it in the service implementation class;</li>
     *   <li>restart the server, so it loads the new implementation;</li>
     *   <li>call the new method from the client, then restart the client.</li>
     * </ol>
     *
     * @return a fixed value from the implementation
     */
    String addNewMethod();
}

4、业务类实现

package net.yury.netty.Test10Rpc.service;

import net.yury.netty.Test10Rpc.ComplicateClass;

import java.util.HashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Demo implementation of {@link HelloService}.
 * It implements {@link RpcRegister} so the server's registry scan picks it up.
 */
public class HelloServiceImpl implements HelloService, RpcRegister {
    /** Call counter shared by all instances; each sayHello call consumes one value. */
    public static final AtomicInteger COUNT = new AtomicInteger(0);

    @Override
    public String sayHello(String msg) {
        int count = COUNT.getAndIncrement();
        return "hello world, " + msg + ", your count is " + count;
    }

    @Override
    public ComplicateClass testComplicateReturnType() {
        String[] array = {"123"};
        ComplicateClass.TestInnerClass inner = new ComplicateClass.TestInnerClass();
        return new ComplicateClass(1, array, inner, new HashMap<>());
    }

    @Override
    public void testError() {
        // deliberately throws ArithmeticException("/ by zero") to exercise remote error propagation
        int zero = 0;
        int unused = 1 / zero;
    }

    @Override
    public String addNewMethod() {
        return "123";
    }

    @Override
    public Object getInstance() {
        return new HelloServiceImpl();
    }
}