手写一个RPC框架

时间:2023-12-06 11:10:26

一、前言

前段时间看到一篇不错的文章《看了这篇你就会手写RPC框架了》,于是便来了兴趣对着实现了一遍,后面觉得还有很多优化的地方便对其进行了改进。

主要改动点如下:

  1. 除了Java序列化协议,增加了protobuf和kryo序列化协议,配置即用。
  2. 增加多种负载均衡算法(随机、轮询、加权轮询、平滑加权轮询),配置即用。
  3. 客户端增加本地服务列表缓存,提高性能。
  4. 修复高并发情况下,netty导致的内存泄漏问题
  5. 由原来的每个请求建立一次连接,改为建立TCP长连接,并多次复用。
  6. 服务端增加线程池提高消息处理能力

二、介绍

RPC,即 Remote Procedure Call(远程过程调用),调用远程计算机上的服务,就像调用本地服务一样。RPC可以很好的解耦系统,如WebService就是一种基于Http协议的RPC。

手写一个RPC框架
调用示意图

总的来说,就如下几个步骤:

  1. 客户端(ServerA)执行远程方法时就调用client stub传递类名、方法名和参数等信息。
  2. client stub会将参数等信息序列化为二进制流的形式,然后通过Sockect发送给服务端(ServerB)
  3. 服务端收到数据包后,server stub 需要进行解析反序列化为类名、方法名和参数等信息。
  4. server stub调用对应的本地方法,并把执行结果返回给客户端

所以一个RPC框架有如下角色:

服务消费者

远程方法的调用方,即客户端。一个服务既可以是消费者也可以是提供者。

服务提供者

远程服务的提供方,即服务端。一个服务既可以是消费者也可以是提供者。

注册中心

保存服务提供者的服务地址等信息,一般由zookeeper、redis等实现。

监控运维(可选)

监控接口的响应时间、统计请求数量等,及时发现系统问题并发出告警通知。

三、实现

本RPC框架rpc-spring-boot-starter涉及技术栈如下:

  • 使用zookeeper作为注册中心
  • 使用netty作为通信框架
  • 消息编解码:protostuff、kryo、java
  • spring
  • 使用SPI来根据配置动态选择负载均衡算法等

由于代码过多,这里只讲几处改动点。

3.1动态负载均衡算法

1.编写LoadBalance的实现类

手写一个RPC框架
负载均衡算法实现类

2.自定义注解 @LoadBalanceAno

  1. /** 


  2. * 负载均衡注解 


  3. */ 


  4. @Target(ElementType.TYPE) 


  5. @Retention(RetentionPolicy.RUNTIME) 


  6. @Documented 


  7. public @interface LoadBalanceAno { 



  8. String value() default ""; 






  9. /** 


  10. * 轮询算法 


  11. */ 


  12. @LoadBalanceAno(RpcConstant.BALANCE_ROUND) 


  13. public class FullRoundBalance implements LoadBalance { 



  14. private static Logger logger = LoggerFactory.getLogger(FullRoundBalance.class); 



  15. private volatile int index; 



  16. @Override 


  17. public synchronized Service chooseOne(List<Service> services) { 


  18. // 加锁防止多线程情况下,index超出services.size() 


  19. if (index == services.size()) { 


  20. index = 0; 





  21. return services.get(index++); 








3.新建在resource目录下META-INF/servers文件夹并创建文件

手写一个RPC框架
enter description here

4.RpcConfig增加配置项loadBalance

  1. /** 


  2. * @author 2YSP 


  3. * @date 2020/7/26 15:13 


  4. */ 


  5. @ConfigurationProperties(prefix = "sp.rpc") 


  6. public class RpcConfig { 



  7. /** 


  8. * 服务注册中心地址 


  9. */ 


  10. private String registerAddress = "127.0.0.1:2181"; 



  11. /** 


  12. * 服务暴露端口 


  13. */ 


  14. private Integer serverPort = 9999; 


  15. /** 


  16. * 服务协议 


  17. */ 


  18. private String protocol = "java"; 


  19. /** 


  20. * 负载均衡算法 


  21. */ 


  22. private String loadBalance = "random"; 


  23. /** 


  24. * 权重,默认为1 


  25. */ 


  26. private Integer weight = 1; 



  27. // 省略getter setter 





5.在自动配置类RpcAutoConfiguration根据配置选择对应的算法实现类

  1. /** 


  2. * 使用spi匹配符合配置的负载均衡算法 





  3. * @param name 


  4. * @return 


  5. */ 


  6. private LoadBalance getLoadBalance(String name) { 


  7. ServiceLoader<LoadBalance> loader = ServiceLoader.load(LoadBalance.class); 


  8. Iterator<LoadBalance> iterator = loader.iterator(); 


  9. while (iterator.hasNext()) { 


  10. LoadBalance loadBalance = iterator.next(); 


  11. LoadBalanceAno ano = loadBalance.getClass().getAnnotation(LoadBalanceAno.class); 


  12. Assert.notNull(ano, "load balance name can not be empty!"); 


  13. if (name.equals(ano.value())) { 


  14. return loadBalance; 








  15. throw new RpcException("invalid load balance config"); 






  16. @Bean 


  17. public ClientProxyFactory proxyFactory(@Autowired RpcConfig rpcConfig) { 


  18. ClientProxyFactory clientProxyFactory = new ClientProxyFactory(); 


  19. // 设置服务发现着 


  20. clientProxyFactory.setServerDiscovery(new ZookeeperServerDiscovery(rpcConfig.getRegisterAddress())); 



  21. // 设置支持的协议 


  22. Map<String, MessageProtocol> supportMessageProtocols = buildSupportMessageProtocols(); 


  23. clientProxyFactory.setSupportMessageProtocols(supportMessageProtocols); 


  24. // 设置负载均衡算法 


  25. LoadBalance loadBalance = getLoadBalance(rpcConfig.getLoadBalance()); 


  26. clientProxyFactory.setLoadBalance(loadBalance); 


  27. // 设置网络层实现 


  28. clientProxyFactory.setNetClient(new NettyNetClient()); 



  29. return clientProxyFactory; 





3.2本地服务列表缓存

使用Map来缓存数据

  1. /** 


  2. * 服务发现本地缓存 


  3. */ 


  4. public class ServerDiscoveryCache { 


  5. /** 


  6. * key: serviceName 


  7. */ 


  8. private static final Map<String, List<Service>> SERVER_MAP = new ConcurrentHashMap<>(); 


  9. /** 


  10. * 客户端注入的远程服务service class 


  11. */ 


  12. public static final List<String> SERVICE_CLASS_NAMES = new ArrayList<>(); 



  13. public static void put(String serviceName, List<Service> serviceList) { 


  14. SERVER_MAP.put(serviceName, serviceList); 






  15. /** 


  16. * 去除指定的值 


  17. * @param serviceName 


  18. * @param service 


  19. */ 


  20. public static void remove(String serviceName, Service service) { 


  21. SERVER_MAP.computeIfPresent(serviceName, (key, value) -> 


  22. value.stream().filter(o -> !o.toString().equals(service.toString())).collect(Collectors.toList()) 


  23. ); 






  24. public static void removeAll(String serviceName) { 


  25. SERVER_MAP.remove(serviceName); 







  26. public static boolean isEmpty(String serviceName) { 


  27. return SERVER_MAP.get(serviceName) == null || SERVER_MAP.get(serviceName).size() == 0; 






  28. public static List<Service> get(String serviceName) { 


  29. return SERVER_MAP.get(serviceName); 








ClientProxyFactory,先查本地缓存,缓存没有再查询zookeeper。

  1. /** 


  2. * 根据服务名获取可用的服务地址列表 


  3. * @param serviceName 


  4. * @return 


  5. */ 


  6. private List<Service> getServiceList(String serviceName) { 


  7. List<Service> services; 


  8. synchronized (serviceName){ 


  9. if (ServerDiscoveryCache.isEmpty(serviceName)) { 


  10. services = serverDiscovery.findServiceList(serviceName); 


  11. if (services == null || services.size() == 0) { 


  12. throw new RpcException("No provider available!"); 





  13. ServerDiscoveryCache.put(serviceName, services); 


  14. } else { 


  15. services = ServerDiscoveryCache.get(serviceName); 








  16. return services; 





问题: 如果服务端因为宕机或网络问题下线了,缓存却还在就会导致客户端请求已经不可用的服务端,增加请求失败率。

解决方案:由于服务端注册的是临时节点,所以如果服务端下线节点会被移除。只要监听zookeeper的子节点,如果新增或删除子节点就直接清空本地缓存即可。

DefaultRpcProcessor

  1. /** 


  2. * Rpc处理者,支持服务启动暴露,自动注入Service 


  3. * @author 2YSP 


  4. * @date 2020/7/26 14:46 


  5. */ 


  6. public class DefaultRpcProcessor implements ApplicationListener<ContextRefreshedEvent> { 





  7. @Override 


  8. public void onApplicationEvent(ContextRefreshedEvent event) { 


  9. // Spring启动完毕过后会收到一个事件通知 


  10. if (Objects.isNull(event.getApplicationContext().getParent())){ 


  11. ApplicationContext context = event.getApplicationContext(); 


  12. // 开启服务 


  13. startServer(context); 


  14. // 注入Service 


  15. injectService(context); 









  16. private void injectService(ApplicationContext context) { 


  17. String[] names = context.getBeanDefinitionNames(); 


  18. for(String name : names){ 


  19. Class<?> clazz = context.getType(name); 


  20. if (Objects.isNull(clazz)){ 


  21. continue; 






  22. Field[] declaredFields = clazz.getDeclaredFields(); 


  23. for(Field field : declaredFields){ 


  24. // 找出标记了InjectService注解的属性 


  25. InjectService injectService = field.getAnnotation(InjectService.class); 


  26. if (injectService == null){ 


  27. continue; 






  28. Class<?> fieldClass = field.getType(); 


  29. Object object = context.getBean(name); 


  30. field.setAccessible(true); 


  31. try { 


  32. field.set(object,clientProxyFactory.getProxy(fieldClass)); 


  33. } catch (IllegalAccessException e) { 


  34. e.printStackTrace(); 





  35. // 添加本地服务缓存 


  36. ServerDiscoveryCache.SERVICE_CLASS_NAMES.add(fieldClass.getName()); 








  37. // 注册子节点监听 


  38. if (clientProxyFactory.getServerDiscovery() instanceof ZookeeperServerDiscovery){ 


  39. ZookeeperServerDiscovery serverDiscovery = (ZookeeperServerDiscovery) clientProxyFactory.getServerDiscovery(); 


  40. ZkClient zkClient = serverDiscovery.getZkClient(); 


  41. ServerDiscoveryCache.SERVICE_CLASS_NAMES.forEach(name ->{ 


  42. String servicePath = RpcConstant.ZK_SERVICE_PATH + RpcConstant.PATH_DELIMITER + name + "/service"; 


  43. zkClient.subscribeChildChanges(servicePath, new ZkChildListenerImpl()); 


  44. }); 


  45. logger.info("subscribe service zk node successfully"); 










  46. private void startServer(ApplicationContext context) { 


  47. ... 










ZkChildListenerImpl

  1. /** 


  2. * 子节点事件监听处理类 


  3. */ 


  4. public class ZkChildListenerImpl implements IZkChildListener { 



  5. private static Logger logger = LoggerFactory.getLogger(ZkChildListenerImpl.class); 



  6. /** 


  7. * 监听子节点的删除和新增事件 


  8. * @param parentPath /rpc/serviceName/service 


  9. * @param childList 


  10. * @throws Exception 


  11. */ 


  12. @Override 


  13. public void handleChildChange(String parentPath, List<String> childList) throws Exception { 


  14. logger.debug("Child change parentPath:[{}] -- childList:[{}]", parentPath, childList); 


  15. // 只要子节点有改动就清空缓存 


  16. String[] arr = parentPath.split("/"); 


  17. ServerDiscoveryCache.removeAll(arr[2]); 








3.3nettyClient支持TCP长连接

这部分的改动最多,先增加新的sendRequest接口。

手写一个RPC框架
添加接口

实现类NettyNetClient

  1. /** 


  2. * @author 2YSP 


  3. * @date 2020/7/25 20:12 


  4. */ 


  5. public class NettyNetClient implements NetClient { 



  6. private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class); 



  7. private static ExecutorService threadPool = new ThreadPoolExecutor(4, 10, 200, 


  8. TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000), new ThreadFactoryBuilder() 


  9. .setNameFormat("rpcClient-%d") 


  10. .build()); 



  11. private EventLoopGroup loopGroup = new NioEventLoopGroup(4); 



  12. /** 


  13. * 已连接的服务缓存 


  14. * key: 服务地址,格式:ip:port 


  15. */ 


  16. public static Map<String, SendHandlerV2> connectedServerNodes = new ConcurrentHashMap<>(); 



  17. @Override 


  18. public byte[] sendRequest(byte[] data, Service service) throws InterruptedException { 


  19. .... 


  20. return respData; 






  21. @Override 


  22. public RpcResponse sendRequest(RpcRequest rpcRequest, Service service, MessageProtocol messageProtocol) { 



  23. String address = service.getAddress(); 


  24. synchronized (address) { 


  25. if (connectedServerNodes.containsKey(address)) { 


  26. SendHandlerV2 handler = connectedServerNodes.get(address); 


  27. logger.info("使用现有的连接"); 


  28. return handler.sendRequest(rpcRequest); 






  29. String[] addrInfo = address.split(":"); 


  30. final String serverAddress = addrInfo[0]; 


  31. final String serverPort = addrInfo[1]; 


  32. final SendHandlerV2 handler = new SendHandlerV2(messageProtocol, address); 


  33. threadPool.submit(() -> { 


  34. // 配置客户端 


  35. Bootstrap b = new Bootstrap(); 


  36. b.group(loopGroup).channel(NioSocketChannel.class) 


  37. .option(ChannelOption.TCP_NODELAY, true) 


  38. .handler(new ChannelInitializer<SocketChannel>() { 


  39. @Override 


  40. protected void initChannel(SocketChannel socketChannel) throws Exception { 


  41. ChannelPipeline pipeline = socketChannel.pipeline(); 


  42. pipeline 


  43. .addLast(handler); 





  44. }); 


  45. // 启用客户端连接 


  46. ChannelFuture channelFuture = b.connect(serverAddress, Integer.parseInt(serverPort)); 


  47. channelFuture.addListener(new ChannelFutureListener() { 


  48. @Override 


  49. public void operationComplete(ChannelFuture channelFuture) throws Exception { 


  50. connectedServerNodes.put(address, handler); 





  51. }); 





  52. ); 


  53. logger.info("使用新的连接。。。"); 


  54. return handler.sendRequest(rpcRequest); 












每次请求都会调用sendRequest()方法,用线程池异步和服务端创建TCP长连接,连接成功后将SendHandlerV2缓存到ConcurrentHashMap中方便复用,后续请求的请求地址(ip+port)如果在connectedServerNodes中存在则使用connectedServerNodes中的handler处理不再重新建立连接。

SendHandlerV2

  1. /** 


  2. * @author 2YSP 


  3. * @date 2020/8/19 20:06 


  4. */ 


  5. public class SendHandlerV2 extends ChannelInboundHandlerAdapter { 



  6. private static Logger logger = LoggerFactory.getLogger(SendHandlerV2.class); 



  7. /** 


  8. * 等待通道建立最大时间 


  9. */ 


  10. static final int CHANNEL_WAIT_TIME = 4; 


  11. /** 


  12. * 等待响应最大时间 


  13. */ 


  14. static final int RESPONSE_WAIT_TIME = 8; 



  15. private volatile Channel channel; 



  16. private String remoteAddress; 



  17. private static Map<String, RpcFuture<RpcResponse>> requestMap = new ConcurrentHashMap<>(); 



  18. private MessageProtocol messageProtocol; 



  19. private CountDownLatch latch = new CountDownLatch(1); 



  20. public SendHandlerV2(MessageProtocol messageProtocol,String remoteAddress) { 


  21. this.messageProtocol = messageProtocol; 


  22. this.remoteAddress = remoteAddress; 






  23. @Override 


  24. public void channelRegistered(ChannelHandlerContext ctx) throws Exception { 


  25. this.channel = ctx.channel(); 


  26. latch.countDown(); 






  27. @Override 


  28. public void channelActive(ChannelHandlerContext ctx) throws Exception { 


  29. logger.debug("Connect to server successfully:{}", ctx); 






  30. @Override 


  31. public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { 


  32. logger.debug("Client reads message:{}", msg); 


  33. ByteBuf byteBuf = (ByteBuf) msg; 


  34. byte[] resp = new byte[byteBuf.readableBytes()]; 


  35. byteBuf.readBytes(resp); 


  36. // 手动回收 


  37. ReferenceCountUtil.release(byteBuf); 


  38. RpcResponse response = messageProtocol.unmarshallingResponse(resp); 


  39. RpcFuture<RpcResponse> future = requestMap.get(response.getRequestId()); 


  40. future.setResponse(response); 






  41. @Override 


  42. public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { 


  43. cause.printStackTrace(); 


  44. logger.error("Exception occurred:{}", cause.getMessage()); 


  45. ctx.close(); 






  46. @Override 


  47. public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { 


  48. ctx.flush(); 






  49. @Override 


  50. public void channelInactive(ChannelHandlerContext ctx) throws Exception { 


  51. super.channelInactive(ctx); 


  52. logger.error("channel inactive with remoteAddress:[{}]",remoteAddress); 


  53. NettyNetClient.connectedServerNodes.remove(remoteAddress); 







  54. @Override 


  55. public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { 


  56. super.userEventTriggered(ctx, evt); 






  57. public RpcResponse sendRequest(RpcRequest request) { 


  58. RpcResponse response; 


  59. RpcFuture<RpcResponse> future = new RpcFuture<>(); 


  60. requestMap.put(request.getRequestId(), future); 


  61. try { 


  62. byte[] data = messageProtocol.marshallingRequest(request); 


  63. ByteBuf reqBuf = Unpooled.buffer(data.length); 


  64. reqBuf.writeBytes(data); 


  65. if (latch.await(CHANNEL_WAIT_TIME,TimeUnit.SECONDS)){ 


  66. channel.writeAndFlush(reqBuf); 


  67. // 等待响应 


  68. response = future.get(RESPONSE_WAIT_TIME, TimeUnit.SECONDS); 


  69. }else { 


  70. throw new RpcException("establish channel time out"); 





  71. } catch (Exception e) { 


  72. throw new RpcException(e.getMessage()); 


  73. } finally { 


  74. requestMap.remove(request.getRequestId()); 





  75. return response; 









RpcFuture

  1. package cn.sp.rpc.client.net; 



  2. import java.util.concurrent.*; 



  3. /** 


  4. * @author 2YSP 


  5. * @date 2020/8/19 22:31 


  6. */ 


  7. public class RpcFuture<T> implements Future<T> { 



  8. private T response; 


  9. /** 


  10. * 因为请求和响应是一一对应的,所以这里是1 


  11. */ 


  12. private CountDownLatch countDownLatch = new CountDownLatch(1); 


  13. /** 


  14. * Future的请求时间,用于计算Future是否超时 


  15. */ 


  16. private long beginTime = System.currentTimeMillis(); 



  17. @Override 


  18. public boolean cancel(boolean mayInterruptIfRunning) { 


  19. return false; 






  20. @Override 


  21. public boolean isCancelled() { 


  22. return false; 






  23. @Override 


  24. public boolean isDone() { 


  25. if (response != null) { 


  26. return true; 





  27. return false; 






  28. /** 


  29. * 获取响应,直到有结果才返回 


  30. * @return 


  31. * @throws InterruptedException 


  32. * @throws ExecutionException 


  33. */ 


  34. @Override 


  35. public T get() throws InterruptedException, ExecutionException { 


  36. countDownLatch.await(); 


  37. return response; 






  38. @Override 


  39. public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { 


  40. if (countDownLatch.await(timeout,unit)){ 


  41. return response; 





  42. return null; 






  43. public void setResponse(T response) { 


  44. this.response = response; 


  45. countDownLatch.countDown(); 






  46. public long getBeginTime() { 


  47. return beginTime; 









此处逻辑,第一次执行 SendHandlerV2#sendRequest() 时channel需要等待通道建立好之后才能发送请求,所以用CountDownLatch来控制,等待通道建立。

自定义Future+requestMap缓存来实现netty的请求和阻塞等待响应,RpcRequest对象在创建时会生成一个请求的唯一标识requestId,发送请求前先将RpcFuture缓存到requestMap中,key为requestId,读取到服务端的响应信息后(channelRead方法),将响应结果放入对应的RpcFuture中。

SendHandlerV2#channelInactive() 方法中,如果连接的服务端异常断开连接了,则及时清理缓存中对应的serverNode。

四、压力测试

测试环境:

(英特尔)Intel(R) Core(TM) i5-6300HQ CPU @ 2.30GHz

4核

windows10家庭版(64位)

16G内存

1.本地启动zookeeper

2.本地启动一个消费者,两个服务端,轮询算法

3.使用ab进行压力测试,4个线程发送10000个请求

ab -c 4 -n 10000 http://localhost:8080/test/user?id=1

测试结果

手写一个RPC框架
测试结果

从图片可以看出,10000个请求只用了11s,比之前的130+秒耗时减少了10倍以上。

代码地址:

https://github.com/2YSP/rpc-spring-boot-starter

https://github.com/2YSP/rpc-example

参考:

看了这篇你就会手写RPC框架了