/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.statefun.flink.core.reqreply;

import java.time.Duration;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.statefun.flink.core.backpressure.InternalContext;
import org.apache.flink.statefun.flink.core.common.PolyglotUtil;
import org.apache.flink.statefun.flink.core.metrics.FunctionTypeMetrics;
import org.apache.flink.statefun.flink.core.reqreply.PersistedRemoteFunctionValues;
import org.apache.flink.statefun.flink.core.reqreply.RequestReplyClient;
import org.apache.flink.statefun.flink.core.reqreply.ToFunctionRequestSummary;
import org.apache.flink.statefun.sdk.Address;
import org.apache.flink.statefun.sdk.AsyncOperationResult;
import org.apache.flink.statefun.sdk.Context;
import org.apache.flink.statefun.sdk.FunctionType;
import org.apache.flink.statefun.sdk.StatefulFunction;
import org.apache.flink.statefun.sdk.annotations.Persisted;
import org.apache.flink.statefun.sdk.io.EgressIdentifier;
import org.apache.flink.statefun.sdk.reqreply.generated.FromFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.ToFunction;
import org.apache.flink.statefun.sdk.reqreply.generated.TypedValue;
import org.apache.flink.statefun.sdk.state.PersistedAppendingBuffer;
import org.apache.flink.statefun.sdk.state.PersistedValue;
import org.apache.flink.types.Either;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public final class RequestReplyFunction
implements StatefulFunction {
    public static final Logger LOG = LoggerFactory.getLogger(RequestReplyFunction.class);
    private final FunctionType functionType;
    private final RequestReplyClient client;
    private final int maxNumBatchRequests;
    private boolean isFirstRequestSent;
    @Persisted
    private final PersistedValue<Integer> requestState = PersistedValue.of((String)"request-state", Integer.class);
    @Persisted
    private final PersistedAppendingBuffer<ToFunction.Invocation> batch = PersistedAppendingBuffer.of((String)"batch", ToFunction.Invocation.class);
    @Persisted
    private final PersistedRemoteFunctionValues managedStates;

    public RequestReplyFunction(FunctionType functionType, int maxNumBatchRequests, RequestReplyClient client) {
        this(functionType, new PersistedRemoteFunctionValues(), maxNumBatchRequests, client, false);
    }

    @VisibleForTesting
    RequestReplyFunction(FunctionType functionType, PersistedRemoteFunctionValues states, int maxNumBatchRequests, RequestReplyClient client, boolean isFirstRequestSent) {
        this.functionType = Objects.requireNonNull(functionType);
        this.managedStates = Objects.requireNonNull(states);
        this.maxNumBatchRequests = maxNumBatchRequests;
        this.client = Objects.requireNonNull(client);
        this.isFirstRequestSent = isFirstRequestSent;
    }

    public void invoke(Context context, Object input) {
        InternalContext castedContext = (InternalContext)context;
        if (!(input instanceof AsyncOperationResult)) {
            this.onRequest(castedContext, (TypedValue)input);
            return;
        }
        AsyncOperationResult result = (AsyncOperationResult)input;
        this.onAsyncResult(castedContext, (AsyncOperationResult<ToFunction, FromFunction>)result);
    }

    private void onRequest(InternalContext context, TypedValue message) {
        ToFunction.Invocation.Builder invocationBuilder = RequestReplyFunction.singeInvocationBuilder(context, message);
        int inflightOrBatched = (Integer)this.requestState.getOrDefault((Object)-1);
        if (inflightOrBatched < 0) {
            this.requestState.set((Object)0);
            this.sendToFunction(context, invocationBuilder);
            return;
        }
        this.batch.append((Object)invocationBuilder.build());
        this.requestState.set((Object)(++inflightOrBatched));
        context.functionTypeMetrics().appendBacklogMessages(1);
        if (this.isMaxNumBatchRequestsExceeded(inflightOrBatched)) {
            context.awaitAsyncOperationComplete();
        }
    }

    private void onAsyncResult(InternalContext context, AsyncOperationResult<ToFunction, FromFunction> asyncResult) {
        if (asyncResult.unknown()) {
            ToFunction batch = (ToFunction)asyncResult.metadata();
            this.sendToFunction(context, this.createRetryBatch(batch));
            return;
        }
        if (asyncResult.failure()) {
            throw new IllegalStateException("Failure forwarding a message to a remote function " + context.self(), asyncResult.throwable());
        }
        Either<FromFunction.InvocationResponse, FromFunction.IncompleteInvocationContext> response = RequestReplyFunction.unpackResponse((FromFunction)asyncResult.value());
        if (response.isRight()) {
            this.handleIncompleteInvocationContextResponse(context, (FromFunction.IncompleteInvocationContext)response.right(), (ToFunction)asyncResult.metadata());
        } else {
            this.handleInvocationResultResponse(context, (FromFunction.InvocationResponse)response.left());
        }
    }

    private static Either<FromFunction.InvocationResponse, FromFunction.IncompleteInvocationContext> unpackResponse(FromFunction fromFunction) {
        if (fromFunction.hasIncompleteInvocationContext()) {
            return Either.Right((Object)fromFunction.getIncompleteInvocationContext());
        }
        if (fromFunction.hasInvocationResult()) {
            return Either.Left((Object)fromFunction.getInvocationResult());
        }
        return Either.Left((Object)FromFunction.InvocationResponse.getDefaultInstance());
    }

    private void handleIncompleteInvocationContextResponse(InternalContext context, FromFunction.IncompleteInvocationContext incompleteContext, ToFunction originalBatch) {
        this.managedStates.registerStates(incompleteContext.getMissingValuesList());
        ToFunction.InvocationBatchRequest.Builder retryBatch = this.createRetryBatch(originalBatch);
        this.sendToFunction(context, retryBatch);
    }

    private void handleInvocationResultResponse(InternalContext context, FromFunction.InvocationResponse result) {
        this.handleOutgoingMessages(context, result);
        this.handleOutgoingDelayedMessages(context, result);
        this.handleEgressMessages(context, result);
        this.managedStates.updateStateValues(result.getStateMutationsList());
        int numBatched = (Integer)this.requestState.getOrDefault((Object)-1);
        if (numBatched < 0) {
            throw new IllegalStateException("Got an unexpected async result");
        }
        if (numBatched == 0) {
            this.requestState.clear();
        } else {
            ToFunction.InvocationBatchRequest.Builder nextBatch = this.getNextBatch();
            this.requestState.set((Object)0);
            this.batch.clear();
            context.functionTypeMetrics().consumeBacklogMessages(numBatched);
            this.sendToFunction(context, nextBatch);
        }
    }

    private ToFunction.InvocationBatchRequest.Builder getNextBatch() {
        ToFunction.InvocationBatchRequest.Builder builder = ToFunction.InvocationBatchRequest.newBuilder();
        Iterable view = this.batch.view();
        builder.addAllInvocations(view);
        return builder;
    }

    private ToFunction.InvocationBatchRequest.Builder createRetryBatch(ToFunction toFunction) {
        ToFunction.InvocationBatchRequest.Builder builder = ToFunction.InvocationBatchRequest.newBuilder();
        builder.addAllInvocations(toFunction.getInvocation().getInvocationsList());
        return builder;
    }

    private void handleEgressMessages(Context context, FromFunction.InvocationResponse invocationResult) {
        for (FromFunction.EgressMessage egressMessage : invocationResult.getOutgoingEgressesList()) {
            EgressIdentifier id = new EgressIdentifier(egressMessage.getEgressNamespace(), egressMessage.getEgressType(), TypedValue.class);
            context.send(id, (Object)egressMessage.getArgument());
        }
    }

    private void handleOutgoingMessages(Context context, FromFunction.InvocationResponse invocationResult) {
        for (FromFunction.Invocation invokeCommand : invocationResult.getOutgoingMessagesList()) {
            Address to = PolyglotUtil.polyglotAddressToSdkAddress(invokeCommand.getTarget());
            TypedValue message = invokeCommand.getArgument();
            context.send(to, (Object)message);
        }
    }

    private void handleOutgoingDelayedMessages(Context context, FromFunction.InvocationResponse invocationResult) {
        for (FromFunction.DelayedInvocation delayedInvokeCommand : invocationResult.getDelayedInvocationsList()) {
            if (delayedInvokeCommand.getIsCancellationRequest()) {
                this.handleDelayedMessageCancellation(context, delayedInvokeCommand);
                continue;
            }
            this.handleDelayedMessageSending(context, delayedInvokeCommand);
        }
    }

    private void handleDelayedMessageSending(Context context, FromFunction.DelayedInvocation delayedInvokeCommand) {
        Address to = PolyglotUtil.polyglotAddressToSdkAddress(delayedInvokeCommand.getTarget());
        TypedValue message = delayedInvokeCommand.getArgument();
        long delay = delayedInvokeCommand.getDelayInMs();
        String token = delayedInvokeCommand.getCancellationToken();
        Duration duration = Duration.ofMillis(delay);
        if (token.isEmpty()) {
            context.sendAfter(duration, to, (Object)message);
        } else {
            context.sendAfter(duration, to, (Object)message, token);
        }
    }

    private void handleDelayedMessageCancellation(Context context, FromFunction.DelayedInvocation delayedInvokeCommand) {
        String token = delayedInvokeCommand.getCancellationToken();
        if (token.isEmpty()) {
            throw new IllegalArgumentException("Can not handle a cancellation request without a cancellation token.");
        }
        context.cancelDelayedMessage(token);
    }

    private static ToFunction.Invocation.Builder singeInvocationBuilder(Context context, TypedValue message) {
        ToFunction.Invocation.Builder invocationBuilder = ToFunction.Invocation.newBuilder();
        if (context.caller() != null) {
            invocationBuilder.setCaller(PolyglotUtil.sdkAddressToPolyglotAddress(context.caller()));
        }
        invocationBuilder.setArgument(message);
        return invocationBuilder;
    }

    private void sendToFunction(InternalContext context, ToFunction.Invocation.Builder invocationBuilder) {
        ToFunction.InvocationBatchRequest.Builder batchBuilder = ToFunction.InvocationBatchRequest.newBuilder();
        batchBuilder.addInvocations(invocationBuilder);
        this.sendToFunction(context, batchBuilder);
    }

    private void sendToFunction(InternalContext context, ToFunction.InvocationBatchRequest.Builder batchBuilder) {
        batchBuilder.setTarget(PolyglotUtil.sdkAddressToPolyglotAddress(context.self()));
        this.managedStates.attachStateValues(batchBuilder);
        ToFunction toFunction = ToFunction.newBuilder().setInvocation(batchBuilder).build();
        this.sendToFunction(context, toFunction);
    }

    private void sendToFunction(InternalContext context, ToFunction toFunction) {
        ToFunctionRequestSummary requestSummary = new ToFunctionRequestSummary(context.self(), toFunction.getSerializedSize(), toFunction.getInvocation().getStateCount(), toFunction.getInvocation().getInvocationsCount());
        FunctionTypeMetrics metrics = context.functionTypeMetrics();
        CompletableFuture<FromFunction> responseFuture = this.client.call(requestSummary, metrics, toFunction);
        if (this.isFirstRequestSent) {
            context.registerAsyncOperation(toFunction, responseFuture);
        } else {
            LOG.info("Bootstrapping function {}. Blocking processing until first request is completed. Successive requests will be performed asynchronously.", (Object)this.functionType);
            this.isFirstRequestSent = true;
            this.onAsyncResult(context, this.joinResponse(responseFuture, toFunction));
        }
    }

    private boolean isMaxNumBatchRequestsExceeded(int currentNumBatchRequests) {
        return this.maxNumBatchRequests > 0 && currentNumBatchRequests >= this.maxNumBatchRequests;
    }

    private AsyncOperationResult<ToFunction, FromFunction> joinResponse(CompletableFuture<FromFunction> responseFuture, ToFunction originalRequest) {
        FromFunction response;
        try {
            response = responseFuture.join();
        }
        catch (Exception e) {
            return new AsyncOperationResult((Object)originalRequest, AsyncOperationResult.Status.FAILURE, null, e.getCause());
        }
        return new AsyncOperationResult((Object)originalRequest, AsyncOperationResult.Status.SUCCESS, (Object)response, null);
    }
}

