手寫一個RPC框架


一、前言

前段時間看到一篇不錯的文章《看了這篇你就會手寫RPC框架了》,於是便來了興趣對着實現了一遍,后面覺得還有很多優化的地方便對其進行了改進。

主要改動點如下:

  1. 除了Java序列化協議,增加了protobuf和kryo序列化協議,配置即用。
  2. 增加多種負載均衡算法(隨機、輪詢、加權輪詢、平滑加權輪詢),配置即用。
  3. 客戶端增加本地服務列表緩存,提高性能。
  4. 修復高並發情況下,netty導致的內存泄漏問題
  5. 由原來的每個請求建立一次連接,改為建立TCP長連接,並多次復用。
  6. 服務端增加線程池提高消息處理能力

二、介紹

RPC,即 Remote Procedure Call(遠程過程調用),調用遠程計算機上的服務,就像調用本地服務一樣。RPC可以很好的解耦系統,如WebService就是一種基於Http協議的RPC。

調用示意圖
調用示意圖

總的來說,就如下幾個步驟:

  1. 客戶端(ServerA)執行遠程方法時就調用client stub傳遞類名、方法名和參數等信息。
  2. client stub會將參數等信息序列化為二進制流的形式,然后通過Sockect發送給服務端(ServerB)
  3. 服務端收到數據包后,server stub 需要進行解析反序列化為類名、方法名和參數等信息。
  4. server stub調用對應的本地方法,並把執行結果返回給客戶端

所以一個RPC框架有如下角色:

服務消費者

遠程方法的調用方,即客戶端。一個服務既可以是消費者也可以是提供者。

服務提供者

遠程服務的提供方,即服務端。一個服務既可以是消費者也可以是提供者。

注冊中心

保存服務提供者的服務地址等信息,一般由zookeeper、redis等實現。

監控運維(可選)

監控接口的響應時間、統計請求數量等,及時發現系統問題並發出告警通知。

三、實現

本RPC框架rpc-spring-boot-starter涉及技術棧如下:

  • 使用zookeeper作為注冊中心
  • 使用netty作為通信框架
  • 消息編解碼:protostuff、kryo、java
  • spring
  • 使用SPI來根據配置動態選擇負載均衡算法等

由於代碼過多,這里只講幾處改動點。

3.1動態負載均衡算法

1.編寫LoadBalance的實現類

負載均衡算法實現類
負載均衡算法實現類

2.自定義注解 @LoadBalanceAno

  1. /** 
  2. * 負載均衡注解 
  3. */ 
  4. @Target(ElementType.TYPE) 
  5. @Retention(RetentionPolicy.RUNTIME) 
  6. @Documented 
  7. public @interface LoadBalanceAno { 
  8.  
  9. String value() default ""; 
  10. } 
  11.  
  12. /** 
  13. * 輪詢算法 
  14. */ 
  15. @LoadBalanceAno(RpcConstant.BALANCE_ROUND) 
  16. public class FullRoundBalance implements LoadBalance { 
  17.  
  18. private static Logger logger = LoggerFactory.getLogger(FullRoundBalance.class); 
  19.  
  20. private volatile int index; 
  21.  
  22. @Override 
  23. public synchronized Service chooseOne(List<Service> services) { 
  24. // 加鎖防止多線程情況下,index超出services.size() 
  25. if (index == services.size()) { 
  26. index = 0; 
  27. } 
  28. return services.get(index++); 
  29. } 
  30. } 

3.新建在resource目錄下META-INF/servers文件夾並創建文件

enter description here
enter description here

4.RpcConfig增加配置項loadBalance

  1. /** 
  2. * @author 2YSP 
  3. * @date 2020/7/26 15:13 
  4. */ 
  5. @ConfigurationProperties(prefix = "sp.rpc") 
  6. public class RpcConfig { 
  7.  
  8. /** 
  9. * 服務注冊中心地址 
  10. */ 
  11. private String registerAddress = "127.0.0.1:2181"; 
  12.  
  13. /** 
  14. * 服務暴露端口 
  15. */ 
  16. private Integer serverPort = 9999; 
  17. /** 
  18. * 服務協議 
  19. */ 
  20. private String protocol = "java"; 
  21. /** 
  22. * 負載均衡算法 
  23. */ 
  24. private String loadBalance = "random"; 
  25. /** 
  26. * 權重,默認為1 
  27. */ 
  28. private Integer weight = 1; 
  29.  
  30. // 省略getter setter 
  31. } 

5.在自動配置類RpcAutoConfiguration根據配置選擇對應的算法實現類

  1. /** 
  2. * 使用spi匹配符合配置的負載均衡算法 
  3. * 
  4. * @param name 
  5. * @return 
  6. */ 
  7. private LoadBalance getLoadBalance(String name) { 
  8. ServiceLoader<LoadBalance> loader = ServiceLoader.load(LoadBalance.class); 
  9. Iterator<LoadBalance> iterator = loader.iterator(); 
  10. while (iterator.hasNext()) { 
  11. LoadBalance loadBalance = iterator.next(); 
  12. LoadBalanceAno ano = loadBalance.getClass().getAnnotation(LoadBalanceAno.class); 
  13. Assert.notNull(ano, "load balance name can not be empty!"); 
  14. if (name.equals(ano.value())) { 
  15. return loadBalance; 
  16. } 
  17. } 
  18. throw new RpcException("invalid load balance config"); 
  19. } 
  20.  
  21. @Bean 
  22. public ClientProxyFactory proxyFactory(@Autowired RpcConfig rpcConfig) { 
  23. ClientProxyFactory clientProxyFactory = new ClientProxyFactory(); 
  24. // 設置服務發現着 
  25. clientProxyFactory.setServerDiscovery(new ZookeeperServerDiscovery(rpcConfig.getRegisterAddress())); 
  26.  
  27. // 設置支持的協議 
  28. Map<String, MessageProtocol> supportMessageProtocols = buildSupportMessageProtocols(); 
  29. clientProxyFactory.setSupportMessageProtocols(supportMessageProtocols); 
  30. // 設置負載均衡算法 
  31. LoadBalance loadBalance = getLoadBalance(rpcConfig.getLoadBalance()); 
  32. clientProxyFactory.setLoadBalance(loadBalance); 
  33. // 設置網絡層實現 
  34. clientProxyFactory.setNetClient(new NettyNetClient()); 
  35.  
  36. return clientProxyFactory; 
  37. } 

3.2本地服務列表緩存

使用Map來緩存數據

  1. /** 
  2. * 服務發現本地緩存 
  3. */ 
  4. public class ServerDiscoveryCache { 
  5. /** 
  6. * key: serviceName 
  7. */ 
  8. private static final Map<String, List<Service>> SERVER_MAP = new ConcurrentHashMap<>(); 
  9. /** 
  10. * 客戶端注入的遠程服務service class 
  11. */ 
  12. public static final List<String> SERVICE_CLASS_NAMES = new ArrayList<>(); 
  13.  
  14. public static void put(String serviceName, List<Service> serviceList) { 
  15. SERVER_MAP.put(serviceName, serviceList); 
  16. } 
  17.  
  18. /** 
  19. * 去除指定的值 
  20. * @param serviceName 
  21. * @param service 
  22. */ 
  23. public static void remove(String serviceName, Service service) { 
  24. SERVER_MAP.computeIfPresent(serviceName, (key, value) -> 
  25. value.stream().filter(o -> !o.toString().equals(service.toString())).collect(Collectors.toList()) 
  26. ); 
  27. } 
  28.  
  29. public static void removeAll(String serviceName) { 
  30. SERVER_MAP.remove(serviceName); 
  31. } 
  32.  
  33.  
  34. public static boolean isEmpty(String serviceName) { 
  35. return SERVER_MAP.get(serviceName) == null || SERVER_MAP.get(serviceName).size() == 0; 
  36. } 
  37.  
  38. public static List<Service> get(String serviceName) { 
  39. return SERVER_MAP.get(serviceName); 
  40. } 
  41. } 

ClientProxyFactory,先查本地緩存,緩存沒有再查詢zookeeper。

  1. /** 
  2. * 根據服務名獲取可用的服務地址列表 
  3. * @param serviceName 
  4. * @return 
  5. */ 
  6. private List<Service> getServiceList(String serviceName) { 
  7. List<Service> services; 
  8. synchronized (serviceName){ 
  9. if (ServerDiscoveryCache.isEmpty(serviceName)) { 
  10. services = serverDiscovery.findServiceList(serviceName); 
  11. if (services == null || services.size() == 0) { 
  12. throw new RpcException("No provider available!"); 
  13. } 
  14. ServerDiscoveryCache.put(serviceName, services); 
  15. } else { 
  16. services = ServerDiscoveryCache.get(serviceName); 
  17. } 
  18. } 
  19. return services; 
  20. } 

問題: 如果服務端因為宕機或網絡問題下線了,緩存卻還在就會導致客戶端請求已經不可用的服務端,增加請求失敗率。
解決方案:由於服務端注冊的是臨時節點,所以如果服務端下線節點會被移除。只要監聽zookeeper的子節點,如果新增或刪除子節點就直接清空本地緩存即可。
DefaultRpcProcessor

  1. /** 
  2. * Rpc處理者,支持服務啟動暴露,自動注入Service 
  3. * @author 2YSP 
  4. * @date 2020/7/26 14:46 
  5. */ 
  6. public class DefaultRpcProcessor implements ApplicationListener<ContextRefreshedEvent> { 
  7.  
  8.  
  9.  
  10. @Override 
  11. public void onApplicationEvent(ContextRefreshedEvent event) { 
  12. // Spring啟動完畢過后會收到一個事件通知 
  13. if (Objects.isNull(event.getApplicationContext().getParent())){ 
  14. ApplicationContext context = event.getApplicationContext(); 
  15. // 開啟服務 
  16. startServer(context); 
  17. // 注入Service 
  18. injectService(context); 
  19. } 
  20. } 
  21.  
  22. private void injectService(ApplicationContext context) { 
  23. String[] names = context.getBeanDefinitionNames(); 
  24. for(String name : names){ 
  25. Class<?> clazz = context.getType(name); 
  26. if (Objects.isNull(clazz)){ 
  27. continue; 
  28. } 
  29.  
  30. Field[] declaredFields = clazz.getDeclaredFields(); 
  31. for(Field field : declaredFields){ 
  32. // 找出標記了InjectService注解的屬性 
  33. InjectService injectService = field.getAnnotation(InjectService.class); 
  34. if (injectService == null){ 
  35. continue; 
  36. } 
  37.  
  38. Class<?> fieldClass = field.getType(); 
  39. Object object = context.getBean(name); 
  40. field.setAccessible(true); 
  41. try { 
  42. field.set(object,clientProxyFactory.getProxy(fieldClass)); 
  43. } catch (IllegalAccessException e) { 
  44. e.printStackTrace(); 
  45. } 
  46. // 添加本地服務緩存 
  47. ServerDiscoveryCache.SERVICE_CLASS_NAMES.add(fieldClass.getName()); 
  48. } 
  49. } 
  50. // 注冊子節點監聽 
  51. if (clientProxyFactory.getServerDiscovery() instanceof ZookeeperServerDiscovery){ 
  52. ZookeeperServerDiscovery serverDiscovery = (ZookeeperServerDiscovery) clientProxyFactory.getServerDiscovery(); 
  53. ZkClient zkClient = serverDiscovery.getZkClient(); 
  54. ServerDiscoveryCache.SERVICE_CLASS_NAMES.forEach(name ->{ 
  55. String servicePath = RpcConstant.ZK_SERVICE_PATH + RpcConstant.PATH_DELIMITER + name + "/service"; 
  56. zkClient.subscribeChildChanges(servicePath, new ZkChildListenerImpl()); 
  57. }); 
  58. logger.info("subscribe service zk node successfully"); 
  59. } 
  60.  
  61. } 
  62.  
  63. private void startServer(ApplicationContext context) { 
  64. ... 
  65.  
  66. } 
  67. } 
  68.  

ZkChildListenerImpl

  1. /** 
  2. * 子節點事件監聽處理類 
  3. */ 
  4. public class ZkChildListenerImpl implements IZkChildListener { 
  5.  
  6. private static Logger logger = LoggerFactory.getLogger(ZkChildListenerImpl.class); 
  7.  
  8. /** 
  9. * 監聽子節點的刪除和新增事件 
  10. * @param parentPath /rpc/serviceName/service 
  11. * @param childList 
  12. * @throws Exception 
  13. */ 
  14. @Override 
  15. public void handleChildChange(String parentPath, List<String> childList) throws Exception { 
  16. logger.debug("Child change parentPath:[{}] -- childList:[{}]", parentPath, childList); 
  17. // 只要子節點有改動就清空緩存 
  18. String[] arr = parentPath.split("/"); 
  19. ServerDiscoveryCache.removeAll(arr[2]); 
  20. } 
  21. } 

3.3nettyClient支持TCP長連接

這部分的改動最多,先增加新的sendRequest接口。

添加接口
添加接口

實現類NettyNetClient

  1. /** 
  2. * @author 2YSP 
  3. * @date 2020/7/25 20:12 
  4. */ 
  5. public class NettyNetClient implements NetClient { 
  6.  
  7. private static Logger logger = LoggerFactory.getLogger(NettyNetClient.class); 
  8.  
  9. private static ExecutorService threadPool = new ThreadPoolExecutor(4, 10, 200, 
  10. TimeUnit.SECONDS, new LinkedBlockingQueue<>(1000), new ThreadFactoryBuilder() 
  11. .setNameFormat("rpcClient-%d") 
  12. .build()); 
  13.  
  14. private EventLoopGroup loopGroup = new NioEventLoopGroup(4); 
  15.  
  16. /** 
  17. * 已連接的服務緩存 
  18. * key: 服務地址,格式:ip:port 
  19. */ 
  20. public static Map<String, SendHandlerV2> connectedServerNodes = new ConcurrentHashMap<>(); 
  21.  
  22. @Override 
  23. public byte[] sendRequest(byte[] data, Service service) throws InterruptedException { 
  24. .... 
  25. return respData; 
  26. } 
  27.  
  28. @Override 
  29. public RpcResponse sendRequest(RpcRequest rpcRequest, Service service, MessageProtocol messageProtocol) { 
  30.  
  31. String address = service.getAddress(); 
  32. synchronized (address) { 
  33. if (connectedServerNodes.containsKey(address)) { 
  34. SendHandlerV2 handler = connectedServerNodes.get(address); 
  35. logger.info("使用現有的連接"); 
  36. return handler.sendRequest(rpcRequest); 
  37. } 
  38.  
  39. String[] addrInfo = address.split(":"); 
  40. final String serverAddress = addrInfo[0]; 
  41. final String serverPort = addrInfo[1]; 
  42. final SendHandlerV2 handler = new SendHandlerV2(messageProtocol, address); 
  43. threadPool.submit(() -> { 
  44. // 配置客戶端 
  45. Bootstrap b = new Bootstrap(); 
  46. b.group(loopGroup).channel(NioSocketChannel.class) 
  47. .option(ChannelOption.TCP_NODELAY, true) 
  48. .handler(new ChannelInitializer<SocketChannel>() { 
  49. @Override 
  50. protected void initChannel(SocketChannel socketChannel) throws Exception { 
  51. ChannelPipeline pipeline = socketChannel.pipeline(); 
  52. pipeline 
  53. .addLast(handler); 
  54. } 
  55. }); 
  56. // 啟用客戶端連接 
  57. ChannelFuture channelFuture = b.connect(serverAddress, Integer.parseInt(serverPort)); 
  58. channelFuture.addListener(new ChannelFutureListener() { 
  59. @Override 
  60. public void operationComplete(ChannelFuture channelFuture) throws Exception { 
  61. connectedServerNodes.put(address, handler); 
  62. } 
  63. }); 
  64. } 
  65. ); 
  66. logger.info("使用新的連接。。。"); 
  67. return handler.sendRequest(rpcRequest); 
  68. } 
  69. } 
  70. } 
  71.  

每次請求都會調用sendRequest()方法,用線程池異步和服務端創建TCP長連接,連接成功后將SendHandlerV2緩存到ConcurrentHashMap中方便復用,后續請求的請求地址(ip+port)如果在connectedServerNodes中存在則使用connectedServerNodes中的handler處理不再重新建立連接。

SendHandlerV2

  1. /** 
  2. * @author 2YSP 
  3. * @date 2020/8/19 20:06 
  4. */ 
  5. public class SendHandlerV2 extends ChannelInboundHandlerAdapter { 
  6.  
  7. private static Logger logger = LoggerFactory.getLogger(SendHandlerV2.class); 
  8.  
  9. /** 
  10. * 等待通道建立最大時間 
  11. */ 
  12. static final int CHANNEL_WAIT_TIME = 4; 
  13. /** 
  14. * 等待響應最大時間 
  15. */ 
  16. static final int RESPONSE_WAIT_TIME = 8; 
  17.  
  18. private volatile Channel channel; 
  19.  
  20. private String remoteAddress; 
  21.  
  22. private static Map<String, RpcFuture<RpcResponse>> requestMap = new ConcurrentHashMap<>(); 
  23.  
  24. private MessageProtocol messageProtocol; 
  25.  
  26. private CountDownLatch latch = new CountDownLatch(1); 
  27.  
  28. public SendHandlerV2(MessageProtocol messageProtocol,String remoteAddress) { 
  29. this.messageProtocol = messageProtocol; 
  30. this.remoteAddress = remoteAddress; 
  31. } 
  32.  
  33. @Override 
  34. public void channelRegistered(ChannelHandlerContext ctx) throws Exception { 
  35. this.channel = ctx.channel(); 
  36. latch.countDown(); 
  37. } 
  38.  
  39. @Override 
  40. public void channelActive(ChannelHandlerContext ctx) throws Exception { 
  41. logger.debug("Connect to server successfully:{}", ctx); 
  42. } 
  43.  
  44. @Override 
  45. public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { 
  46. logger.debug("Client reads message:{}", msg); 
  47. ByteBuf byteBuf = (ByteBuf) msg; 
  48. byte[] resp = new byte[byteBuf.readableBytes()]; 
  49. byteBuf.readBytes(resp); 
  50. // 手動回收 
  51. ReferenceCountUtil.release(byteBuf); 
  52. RpcResponse response = messageProtocol.unmarshallingResponse(resp); 
  53. RpcFuture<RpcResponse> future = requestMap.get(response.getRequestId()); 
  54. future.setResponse(response); 
  55. } 
  56.  
  57. @Override 
  58. public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { 
  59. cause.printStackTrace(); 
  60. logger.error("Exception occurred:{}", cause.getMessage()); 
  61. ctx.close(); 
  62. } 
  63.  
  64. @Override 
  65. public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { 
  66. ctx.flush(); 
  67. } 
  68.  
  69. @Override 
  70. public void channelInactive(ChannelHandlerContext ctx) throws Exception { 
  71. super.channelInactive(ctx); 
  72. logger.error("channel inactive with remoteAddress:[{}]",remoteAddress); 
  73. NettyNetClient.connectedServerNodes.remove(remoteAddress); 
  74.  
  75. } 
  76.  
  77. @Override 
  78. public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { 
  79. super.userEventTriggered(ctx, evt); 
  80. } 
  81.  
  82. public RpcResponse sendRequest(RpcRequest request) { 
  83. RpcResponse response; 
  84. RpcFuture<RpcResponse> future = new RpcFuture<>(); 
  85. requestMap.put(request.getRequestId(), future); 
  86. try { 
  87. byte[] data = messageProtocol.marshallingRequest(request); 
  88. ByteBuf reqBuf = Unpooled.buffer(data.length); 
  89. reqBuf.writeBytes(data); 
  90. if (latch.await(CHANNEL_WAIT_TIME,TimeUnit.SECONDS)){ 
  91. channel.writeAndFlush(reqBuf); 
  92. // 等待響應 
  93. response = future.get(RESPONSE_WAIT_TIME, TimeUnit.SECONDS); 
  94. }else { 
  95. throw new RpcException("establish channel time out"); 
  96. } 
  97. } catch (Exception e) { 
  98. throw new RpcException(e.getMessage()); 
  99. } finally { 
  100. requestMap.remove(request.getRequestId()); 
  101. } 
  102. return response; 
  103. } 
  104. } 
  105.  

RpcFuture

  1. package cn.sp.rpc.client.net; 
  2.  
  3. import java.util.concurrent.*; 
  4.  
  5. /** 
  6. * @author 2YSP 
  7. * @date 2020/8/19 22:31 
  8. */ 
  9. public class RpcFuture<T> implements Future<T> { 
  10.  
  11. private T response; 
  12. /** 
  13. * 因為請求和響應是一一對應的,所以這里是1 
  14. */ 
  15. private CountDownLatch countDownLatch = new CountDownLatch(1); 
  16. /** 
  17. * Future的請求時間,用於計算Future是否超時 
  18. */ 
  19. private long beginTime = System.currentTimeMillis(); 
  20.  
  21. @Override 
  22. public boolean cancel(boolean mayInterruptIfRunning) { 
  23. return false; 
  24. } 
  25.  
  26. @Override 
  27. public boolean isCancelled() { 
  28. return false; 
  29. } 
  30.  
  31. @Override 
  32. public boolean isDone() { 
  33. if (response != null) { 
  34. return true; 
  35. } 
  36. return false; 
  37. } 
  38.  
  39. /** 
  40. * 獲取響應,直到有結果才返回 
  41. * @return 
  42. * @throws InterruptedException 
  43. * @throws ExecutionException 
  44. */ 
  45. @Override 
  46. public T get() throws InterruptedException, ExecutionException { 
  47. countDownLatch.await(); 
  48. return response; 
  49. } 
  50.  
  51. @Override 
  52. public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { 
  53. if (countDownLatch.await(timeout,unit)){ 
  54. return response; 
  55. } 
  56. return null; 
  57. } 
  58.  
  59. public void setResponse(T response) { 
  60. this.response = response; 
  61. countDownLatch.countDown(); 
  62. } 
  63.  
  64. public long getBeginTime() { 
  65. return beginTime; 
  66. } 
  67. } 
  68.  

此處邏輯,第一次執行 SendHandlerV2#sendRequest() 時channel需要等待通道建立好之后才能發送請求,所以用CountDownLatch來控制,等待通道建立。
自定義Future+requestMap緩存來實現netty的請求和阻塞等待響應,RpcRequest對象在創建時會生成一個請求的唯一標識requestId,發送請求前先將RpcFuture緩存到requestMap中,key為requestId,讀取到服務端的響應信息后(channelRead方法),將響應結果放入對應的RpcFuture中。
SendHandlerV2#channelInactive() 方法中,如果連接的服務端異常斷開連接了,則及時清理緩存中對應的serverNode。

四、壓力測試

測試環境:
(英特爾)Intel(R) Core(TM) i5-6300HQ CPU @ 2.30GHz
4核
windows10家庭版(64位)
16G內存

1.本地啟動zookeeper
2.本地啟動一個消費者,兩個服務端,輪詢算法
3.使用ab進行壓力測試,4個線程發送10000個請求

ab -c 4 -n 10000 http://localhost:8080/test/user?id=1

測試結果

測試結果
測試結果

從圖片可以看出,10000個請求只用了11s,比之前的130+秒耗時減少了10倍以上。

代碼地址:
https://github.com/2YSP/rpc-spring-boot-starter
https://github.com/2YSP/rpc-example

參考:
看了這篇你就會手寫RPC框架了


免責聲明!

本站轉載的文章為個人學習借鑒使用,本站對版權不負任何法律責任。如果侵犯了您的隱私權益,請聯系本站郵箱yoyou2525@163.com刪除。



 
粵ICP備18138465號   © 2018-2025 CODEPRJ.COM