作为一个分布式系统,Flink 内部不同组件之间通信依赖于 RPC 机制。这篇文章将对 Flink 的 RPC 框架加以分析。

例子
先来看一个简单的例子,了解 Flink 内部的 RPC 框架是如何使用的。

public class RpcTest {
    private static final Time TIMEOUT = Time.seconds(10L);
    private static ActorSystem actorSystem = null;
    private static RpcService rpcService = null;

    // 定义通信协议
    public interface HelloGateway extends RpcGateway {
        String hello();
    }

    public interface HiGateway extends RpcGateway {
        String hi();
    }

    // 具体实现
    public static class HelloRpcEndpoint extends RpcEndpoint implements HelloGateway {
        protected HelloRpcEndpoint(RpcService rpcService) {
            super(rpcService);
        }

        @Override
        public String hello() {
            return "hello";
        }
    }

    public static class HiRpcEndpoint extends RpcEndpoint implements HiGateway {
        protected HiRpcEndpoint(RpcService rpcService) {
            super(rpcService);
        }

        @Override
        public String hi() {
            return "hi";
        }
    }

    @BeforeClass
    public static void setup() {
        actorSystem = AkkaUtils.createDefaultActorSystem();
        // 创建 RpcService, 基于 AKKA 的实现
        rpcService = new AkkaRpcService(actorSystem, AkkaRpcServiceConfiguration.defaultConfiguration());
    }

    @AfterClass
    public static void teardown() throws Exception {

        final CompletableFuture<Void> rpcTerminationFuture = rpcService.stopService();
        final CompletableFuture<Terminated> actorSystemTerminationFuture = FutureUtils.toJava(actorSystem.terminate());

        FutureUtils
            .waitForAll(Arrays.asList(rpcTerminationFuture, actorSystemTerminationFuture))
            .get(TIMEOUT.toMilliseconds(), TimeUnit.MILLISECONDS);
    }

    @Test
    public void test() throws Exception {
        HelloRpcEndpoint helloEndpoint = new HelloRpcEndpoint(rpcService);
        HiRpcEndpoint hiEndpoint = new HiRpcEndpoint(rpcService);

        helloEndpoint.start();
        //获取 endpoint 的 self gateway
        HelloGateway helloGateway = helloEndpoint.getSelfGateway(HelloGateway.class);
        String hello = helloGateway.hello();
        assertEquals("hello", hello);

        hiEndpoint.start();
        // 通过 endpoint 的地址获得代理
        HiGateway hiGateway = rpcService.connect(hiEndpoint.getAddress(),HiGateway.class).get();
        String hi = hiGateway.hi();
        assertEquals("hi", hi);
    }
}

基本的使用流程就是1)定义协议,提供 RPC 方法的实现;2)获得服务对象的代理对象,调用 RPC 方法。

主要抽象
RpcEndpoint 是对 RPC 框架中提供具体服务的实体的抽象,所有提供远程调用方法的组件都需要继承该抽象类。另外,对于同一个 RpcEndpoint 的所有 RPC 调用都会在同一个线程(RpcEndpoint 的“主线程”)中执行,因此无需担心并发执行的线程安全问题。

RpcGateway 接口是用于远程调用的代理接口。 RpcGateway 提供了获取其所代理的 RpcEndpoint 的地址的方法。在实现一个提供 RPC 调用的组件时,通常需要先定一个接口,该接口继承 RpcGateway 并约定好提供的远程调用的方法。

RpcService 是 RpcEndpoint 的运行时环境, RpcService 提供了启动 RpcEndpoint, 连接到远端 RpcEndpoint 并返回远端 RpcEndpoint 的代理对象等方法。此外, RpcService 还提供了某些异步任务或者周期性调度任务的方法。

RpcServer 相当于 RpcEndpoint 自身的的代理对象(self gateway)。RpcServer 是 RpcService 在启动了 RpcEndpoint 之后返回的对象,每一个 RpcEndpoint 对象内部都有一个 RpcServer 的成员变量,通过 getSelfGateway 方法就可以获得自身的代理,然后调用该Endpoint 提供的服务。

FencedRpcEndpoint 和 FencedRpcGate 要求在调用 RPC 方法时携带 token 信息,只有当调用方提供了 token 和 endpoint 的 token 一致时才允许调用。

基于 Akka 的 RPC 实现
前面介绍了 Flink 内部 RPC 框架的基本抽象,主要就是 RpcService, RpcEndpoint, RpcGateway, RpcServer 等接口。至于具体的实现,则可以有多种不同的方式,如 Akka, Netty 等。Flink 目前提供了一套基于 Akka 的实现。

启动 RpcEndpoint
AkkaRpcService 实现了 RpcService 接口, AkkaRpcService 会启动 Akka actor 来接收来自 RpcGateway 的 RPC 调用。

首先,在 RpcEndpoint 的构造函数中,会调用 AkkaRpcService#startServer 方法来初始化服务,AkkaRpcService#startServer 的主要工作包括: - 创建一个 Akka actor (AkkaRpcActor 或 FencedAkkaRpcActor) - 通过动态代理创建代理对象

class AkkaRpcService {
    @Override
    public <C extends RpcEndpoint & RpcGateway> RpcServer startServer(C rpcEndpoint) {
        checkNotNull(rpcEndpoint, "rpc endpoint");

        CompletableFuture<Void> terminationFuture = new CompletableFuture<>();
        final Props akkaRpcActorProps;

        if (rpcEndpoint instanceof FencedRpcEndpoint) {
            akkaRpcActorProps = Props.create(
                FencedAkkaRpcActor.class,
                rpcEndpoint,
                terminationFuture,
                getVersion(),
                configuration.getMaximumFramesize());
        } else {
            akkaRpcActorProps = Props.create(
                AkkaRpcActor.class,
                rpcEndpoint,
                terminationFuture,
                getVersion(),
                configuration.getMaximumFramesize());
        }

        ActorRef actorRef;

        // 创建 Akka actor
        synchronized (lock) {
            checkState(!stopped, "RpcService is stopped");
            actorRef = actorSystem.actorOf(akkaRpcActorProps, rpcEndpoint.getEndpointId());
            actors.put(actorRef, rpcEndpoint);
        }

        LOG.info("Starting RPC endpoint for {} at {} .", rpcEndpoint.getClass().getName(), actorRef.path());

        final String akkaAddress = AkkaUtils.getAkkaURL(actorSystem, actorRef);
        final String hostname;
        Option<String> host = actorRef.path().address().host();
        if (host.isEmpty()) {
            hostname = "localhost";
        } else {
            hostname = host.get();
        }

        // 代理的接口
        Set<Class<?>> implementedRpcGateways = new HashSet<>(RpcUtils.extractImplementedRpcGateways(rpcEndpoint.getClass()));

        implementedRpcGateways.add(RpcServer.class);
        implementedRpcGateways.add(AkkaBasedEndpoint.class);

        final InvocationHandler akkaInvocationHandler;

        //创建 InvocationHandler
        if (rpcEndpoint instanceof FencedRpcEndpoint) {
            // a FencedRpcEndpoint needs a FencedAkkaInvocationHandler
            akkaInvocationHandler = new FencedAkkaInvocationHandler<>(
                akkaAddress,
                hostname,
                actorRef,
                configuration.getTimeout(),
                configuration.getMaximumFramesize(),
                terminationFuture,
                ((FencedRpcEndpoint<?>) rpcEndpoint)::getFencingToken);

            implementedRpcGateways.add(FencedMainThreadExecutable.class);
        } else {
            akkaInvocationHandler = new AkkaInvocationHandler(
                akkaAddress,
                hostname,
                actorRef,
                configuration.getTimeout(),
                configuration.getMaximumFramesize(),
                terminationFuture);
        }

        // Rather than using the System ClassLoader directly, we derive the ClassLoader
        // from this class . That works better in cases where Flink runs embedded and all Flink
        // code is loaded dynamically (for example from an OSGI bundle) through a custom ClassLoader
        ClassLoader classLoader = getClass().getClassLoader();

        //通过动态代理创建代理对象
        @SuppressWarnings("unchecked")
        RpcServer server = (RpcServer) Proxy.newProxyInstance(
            classLoader,
            implementedRpcGateways.toArray(new Class<?>[implementedRpcGateways.size()]),
            akkaInvocationHandler);

        return server;
    }
}

在 RpcEndpoint 对象创建后,下一步操作是启动它,实际上调用的是 RpcServer.start() 方法。RpcServer 是通过 AkkaInvocationHandler 创建的动态代理对象:

class AkkaInvocationHandler {
    private final ActorRef rpcEndpoint;

    public void start() {
        //向 Akka actor 发送 START 消息
        rpcEndpoint.tell(ControlMessages.START, ActorRef.noSender());
    }
}

所以启动 RpcEndpoint 实际上就是向当前 endpoint 绑定的 Actor 发送一条 START 消息,通知服务启动。

获取 RpcEndpoint 的代理对象
在 RpcEndpoint 创建的过程中,实际上已经通过动态代理生成了一个可供本地使用的代理对象,通过 RpcEndpoint#getSelfGateway 方法可以直接获取。

class RpcEndpoint {
    public <C extends RpcGateway> C getSelfGateway(Class<C> selfGatewayType) {
        //rpcServer 是通过动态代理创建的
        if (selfGatewayType.isInstance(rpcServer)) {
            @SuppressWarnings("unchecked")
            C selfGateway = ((C) rpcServer);

            return selfGateway;
        } else {
            throw new RuntimeException("RpcEndpoint does not implement the RpcGateway interface of type " + selfGatewayType + '.');
        }
    }
}

如果需要获取一个远程 RpcEndpoint 的代理,就需要通过 RpcService#connect 方法,需要提供远程 endpoint 的地址:

class AkkaRpcService {
    private <C extends RpcGateway> CompletableFuture<C> connectInternal(
            final String address,
            final Class<C> clazz,
            Function<ActorRef, InvocationHandler> invocationHandlerFactory) {
        checkState(!stopped, "RpcService is stopped");

        LOG.debug("Try to connect to remote RPC endpoint with address {}. Returning a {} gateway.",
            address, clazz.getName());

        final ActorSelection actorSel = actorSystem.actorSelection(address);

        final Future<ActorIdentity> identify = Patterns
            .ask(actorSel, new Identify(42), configuration.getTimeout().toMilliseconds())
            .<ActorIdentity>mapTo(ClassTag$.MODULE$.<ActorIdentity>apply(ActorIdentity.class));

        final CompletableFuture<ActorIdentity> identifyFuture = FutureUtils.toJava(identify);

        //获取 actor 的引用 ActorRef
        final CompletableFuture<ActorRef> actorRefFuture = identifyFuture.thenApply(
            (ActorIdentity actorIdentity) -> {
                if (actorIdentity.getRef() == null) {
                    throw new CompletionException(new RpcConnectionException("Could not connect to rpc endpoint under address " + address + '.'));
                } else {
                    return actorIdentity.getRef();
                }
            });

        //发送握手消息
        final CompletableFuture<HandshakeSuccessMessage> handshakeFuture = actorRefFuture.thenCompose(
            (ActorRef actorRef) -> FutureUtils.toJava(
                Patterns
                    .ask(actorRef, new RemoteHandshakeMessage(clazz, getVersion()), configuration.getTimeout().toMilliseconds())
                    .<HandshakeSuccessMessage>mapTo(ClassTag$.MODULE$.<HandshakeSuccessMessage>apply(HandshakeSuccessMessage.class))));

        // 创建 InvocationHandler,并通过动态代理生成代理对象
        return actorRefFuture.thenCombineAsync(
            handshakeFuture,
            (ActorRef actorRef, HandshakeSuccessMessage ignored) -> {
                InvocationHandler invocationHandler = invocationHandlerFactory.apply(actorRef);

                // Rather than using the System ClassLoader directly, we derive the ClassLoader
                // from this class . That works better in cases where Flink runs embedded and all Flink
                // code is loaded dynamically (for example from an OSGI bundle) through a custom ClassLoader
                ClassLoader classLoader = getClass().getClassLoader();

                @SuppressWarnings("unchecked")
                C proxy = (C) Proxy.newProxyInstance(
                    classLoader,
                    new Class<?>[]{clazz},
                    invocationHandler);

                return proxy;
            },
            actorSystem.dispatcher());
    }
}

上述方法主要的功能包括:

通过地址获取 RpcEndpoint 绑定的 actor 的引用 ActorRef
向对应的 AkkaRpcActor 发送握手消息
握手成功之后,创建 AkkaInvocationHandler 对象,并通过动态代理生成代理对象

Rpc 调用
在获取了本地或者远端 RpcEndpoint 的代理对象后,就可以通过代理对象发起 RPC 调用了。由于代理对象是通过动态代理创建的,因而所以的方法都会转化为 AkkaInvocationHandler#invoke 方法,并传入 RPC 调用的方法以及参数信息。

class AkkaInvocationHandler {
    @Override
    public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
        Class<?> declaringClass = method.getDeclaringClass();

        Object result;

        if (declaringClass.equals(AkkaBasedEndpoint.class) ||
            declaringClass.equals(Object.class) ||
            declaringClass.equals(RpcGateway.class) ||
            declaringClass.equals(StartStoppable.class) ||
            declaringClass.equals(MainThreadExecutable.class) ||
            declaringClass.equals(RpcServer.class)) {
            result = method.invoke(this, args);
        } else if (declaringClass.equals(FencedRpcGateway.class)) {
            throw new UnsupportedOperationException("AkkaInvocationHandler does not support the call FencedRpcGateway#" +
                method.getName() + ". This indicates that you retrieved a FencedRpcGateway without specifying a " +
                "fencing token. Please use RpcService#connect(RpcService, F, Time) with F being the fencing token to " +
                "retrieve a properly FencedRpcGateway.");
        } else {
            result = invokeRpc(method, args);
        }

        return result;
    }

    private Object invokeRpc(Method method, Object[] args) throws Exception {
        String methodName = method.getName();
        Class<?>[] parameterTypes = method.getParameterTypes();
        Annotation[][] parameterAnnotations = method.getParameterAnnotations();
        Time futureTimeout = extractRpcTimeout(parameterAnnotations, args, timeout);

        //将 RPC 调用封装为 RpcInvocation(会根据RpcEndpoint是本地还是远程的)
        final RpcInvocation rpcInvocation = createRpcInvocationMessage(methodName, parameterTypes, args);

        Class<?> returnType = method.getReturnType();

        final Object result;

        //根据RPC方法是否有返回值决定调用 tell 还是 ask
        if (Objects.equals(returnType, Void.TYPE)) {
            //akka actor tell
            tell(rpcInvocation);

            result = null;
        } else {
            // execute an asynchronous call
            //akka actor ask
            CompletableFuture<?> resultFuture = ask(rpcInvocation, futureTimeout);

            CompletableFuture<?> completableFuture = resultFuture.thenApply((Object o) -> {
                if (o instanceof SerializedValue) {
                    try {
                        return  ((SerializedValue<?>) o).deserializeValue(getClass().getClassLoader());
                    } catch (IOException | ClassNotFoundException e) {
                        throw new CompletionException(
                            new RpcException("Could not deserialize the serialized payload of RPC method : "
                                + methodName, e));
                    }
                } else {
                    return o;
                }
            });

            if (Objects.equals(returnType, CompletableFuture.class)) {
                result = completableFuture;
            } else {
                try {
                    result = completableFuture.get(futureTimeout.getSize(), futureTimeout.getUnit());
                } catch (ExecutionException ee) {
                    throw new RpcException("Failure while obtaining synchronous RPC result.", ExceptionUtils.stripExecutionException(ee));
                }
            }
        }

        return result;
    }
}

对于 RPC 调用,需要将 RPC 调用的方法名、参数类型和参数值封装为一个 RpcInvocation 对象,根据 RpcEndpoint 是本地的还是远端,具体的 有 LocalRpcInvocation 和 RemoteRpcInvocation 两类,它们的区别在于是否需要序列化。

然后根据 RPC 方法是否有返回值,决定调用 tell 或 ask 方法,然后通过 Akka 的 ActorRef 向对应的 AkkaRpcActor 发送请求,如果带有返回值,则等待 actor 的响应。

AkkaRpcActor
AkkaRpcActor 负责接受 RPC 调用的请求,并通过反射调用 RpcEndpoint 的对应方法来完成 RPC 调用。

class AkkaRpcActor<T extends RpcEndpoint & RpcGateway> extends AbstractActor {
    protected final T rpcEndpoint;

    @Override
    public Receive createReceive() {
        //不同类型消息的处理方法
        return ReceiveBuilder.create()
            .match(RemoteHandshakeMessage.class, this::handleHandshakeMessage)
            .match(ControlMessages.class, this::handleControlMessage)
            .matchAny(this::handleMessage)
            .build();
    }

    //处理 RPC 调用
    private void handleMessage(final Object message) {
        if (state.isRunning()) {
            mainThreadValidator.enterMainThread();

            try {
                handleRpcMessage(message);
            } finally {
                mainThreadValidator.exitMainThread();
            }
        } else {
            log.info("The rpc endpoint {} has not been started yet. Discarding message {} until processing is started.",
                rpcEndpoint.getClass().getName(),
                message.getClass().getName());

            sendErrorIfSender(new AkkaRpcException(
                String.format("Discard message, because the rpc endpoint %s has not been started yet.", rpcEndpoint.getAddress())));
        }
    }

  private void handleRpcInvocation(RpcInvocation rpcInvocation) {
        Method rpcMethod = null;

        try {
            String methodName = rpcInvocation.getMethodName();
            Class<?>[] parameterTypes = rpcInvocation.getParameterTypes();

            //获去需要调用的方法
            rpcMethod = lookupRpcMethod(methodName, parameterTypes);
        } catch (ClassNotFoundException e) {
            log.error("Could not load method arguments.", e);

            RpcConnectionException rpcException = new RpcConnectionException("Could not load method arguments.", e);
            getSender().tell(new Status.Failure(rpcException), getSelf());
        } catch (IOException e) {
            log.error("Could not deserialize rpc invocation message.", e);

            RpcConnectionException rpcException = new RpcConnectionException("Could not deserialize rpc invocation message.", e);
            getSender().tell(new Status.Failure(rpcException), getSelf());
        } catch (final NoSuchMethodException e) {
            log.error("Could not find rpc method for rpc invocation.", e);

            RpcConnectionException rpcException = new RpcConnectionException("Could not find rpc method for rpc invocation.", e);
            getSender().tell(new Status.Failure(rpcException), getSelf());
        }

        //通过反射执行
        if (rpcMethod != null) {
            try {
                // this supports declaration of anonymous classes
                rpcMethod.setAccessible(true);

                if (rpcMethod.getReturnType().equals(Void.TYPE)) {
                    // No return value to send back
                    rpcMethod.invoke(rpcEndpoint, rpcInvocation.getArgs());
                }
                else {
                    final Object result;
                    try {
                        result = rpcMethod.invoke(rpcEndpoint, rpcInvocation.getArgs());
                    }
                    catch (InvocationTargetException e) {
                        log.debug("Reporting back error thrown in remote procedure {}", rpcMethod, e);

                        // tell the sender about the failure
                        getSender().tell(new Status.Failure(e.getTargetException()), getSelf());
                        return;
                    }

                    final String methodName = rpcMethod.getName();

                    //向调用方发送执行结果
                    if (result instanceof CompletableFuture) {
                        final CompletableFuture<?> responseFuture = (CompletableFuture<?>) result;
                        sendAsyncResponse(responseFuture, methodName);
                    } else {
                        sendSyncResponse(result, methodName);
                    }
                }
            } catch (Throwable e) {
                log.error("Error while executing remote procedure call {}.", rpcMethod, e);
                // tell the sender about the failure
                getSender().tell(new Status.Failure(e), getSelf());
            }
        }
    }

}

小结
这篇文章简单地分析了 Flink 内部的 RPC 框架。首先,通过 RpcService, RpcEndpoint, RpcGateway, RpcServer 等接口和抽象类,确定了 RPC 服务的基本框架;在这套框架的基础上, Flink 借助 Akka 和动态代理等技术提供了 RPC 调用的具体实现。

转发自:https://blog.jrwang.me/2019/flink-source-code-rpc/