/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.common;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.tez.common.IdUtils;
import org.apache.tez.common.RssTezConfig;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.library.common.InputAttemptIdentifier;
import org.apache.tez.runtime.library.input.ConcatenatedMergedKeyValueInput;
import org.apache.tez.runtime.library.input.ConcatenatedMergedKeyValuesInput;
import org.apache.tez.runtime.library.input.OrderedGroupedInputLegacy;
import org.apache.tez.runtime.library.input.OrderedGroupedKVInput;
import org.apache.tez.runtime.library.input.OrderedGroupedMergedKVInput;
import org.apache.tez.runtime.library.input.RssConcatenatedMergedKeyValueInput;
import org.apache.tez.runtime.library.input.RssConcatenatedMergedKeyValuesInput;
import org.apache.tez.runtime.library.input.RssOrderedGroupedInputLegacy;
import org.apache.tez.runtime.library.input.RssOrderedGroupedKVInput;
import org.apache.tez.runtime.library.input.RssOrderedGroupedMergedKVInput;
import org.apache.tez.runtime.library.input.RssUnorderedKVInput;
import org.apache.tez.runtime.library.input.UnorderedKVInput;
import org.apache.tez.runtime.library.output.OrderedPartitionedKVOutput;
import org.apache.tez.runtime.library.output.RssOrderedPartitionedKVOutput;
import org.apache.tez.runtime.library.output.RssUnorderedKVOutput;
import org.apache.tez.runtime.library.output.RssUnorderedPartitionedKVOutput;
import org.apache.tez.runtime.library.output.UnorderedKVOutput;
import org.apache.tez.runtime.library.output.UnorderedPartitionedKVOutput;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.ShuffleClientFactory;
import org.apache.uniffle.com.google.common.base.Preconditions;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for the Tez remote-shuffle (RSS) client: shuffle write-client construction,
 * block-id / task-attempt-id encoding, shuffle-id derivation, parsing of shuffle-server
 * assignments, and translation of stock Tez input/output class names to their RSS counterparts.
 */
public final class RssTezUtils {
    private static final Logger LOG = LoggerFactory.getLogger(RssTezUtils.class);
    private static final BlockIdLayout LAYOUT = BlockIdLayout.DEFAULT;
    /** Number of low bits of the block sequence field reserved for the task attempt number. */
    private static final int MAX_ATTEMPT_LENGTH = 6;
    /** Largest attempt number that fits in {@link #MAX_ATTEMPT_LENGTH} bits (63). */
    private static final int MAX_ATTEMPT_ID = (1 << MAX_ATTEMPT_LENGTH) - 1;
    /** Largest per-task sequence number that fits next to the attempt number in the sequence field. */
    private static final int MAX_SEQUENCE_NO = (1 << (LAYOUT.sequenceNoBits - MAX_ATTEMPT_LENGTH)) - 1;
    public static final String HOST_NAME = "hostname";
    public static final String UNDERLINE_DELIMITER = "_";
    private static final int VERTEX_ID_MAPPING_MAX_ID = 500;
    private static final String VERTEX_ID_MAPPING_MAP = "Map";
    private static final String VERTEX_ID_MAPPING_REDUCER = "Reducer";
    private static final int VERTEX_ID_MAPPING_MAGIC = 600;
    /** Radix used to pack (dagId, upVertexId, downVertexId) into a single shuffle id. */
    private static final int SHUFFLE_ID_MAGIC = 1000;
    /** Number of leading '_'-separated fields of a unique identifier that form the attempt id. */
    private static final int ATTEMPT_ID_FIELD_COUNT = 7;

    private RssTezUtils() {
    }

    /**
     * Builds a shuffle write client from the {@code tez.rss.client.*} and
     * {@code tez.rss.data.replica*} settings in {@code conf}.
     *
     * @param conf job configuration to read client settings from
     * @return a client created through {@link ShuffleClientFactory}
     */
    public static ShuffleWriteClient createShuffleClient(Configuration conf) {
        int heartBeatThreadNum = conf.getInt("tez.rss.client.heartBeat.threadNum", 4);
        int retryMax = conf.getInt("tez.rss.client.retry.max", 50);
        long retryIntervalMax = conf.getLong("tez.rss.client.retry.interval.max", 10000L);
        String clientType = conf.get("tez.rss.client.type", "GRPC");
        int replicaWrite = conf.getInt("tez.rss.data.replica.write", 1);
        int replicaRead = conf.getInt("tez.rss.data.replica.read", 1);
        int replica = conf.getInt("tez.rss.data.replica", 1);
        boolean replicaSkipEnabled = conf.getBoolean("tez.rss.data.replica.skip.enabled", true);
        int dataTransferPoolSize = conf.getInt("tez.rss.client.data.transfer.pool.size", RssTezConfig.RSS_DATA_TRANSFER_POOL_SIZE_DEFAULT_VALUE);
        int dataCommitPoolSize = conf.getInt("tez.rss.client.data.commit.pool.size", -1);
        return ShuffleClientFactory.getInstance().createShuffleWriteClient(
            ShuffleClientFactory.newWriteBuilder()
                .clientType(clientType)
                .retryMax(retryMax)
                .retryIntervalMax(retryIntervalMax)
                .heartBeatThreadNum(heartBeatThreadNum)
                .replica(replica)
                .replicaWrite(replicaWrite)
                .replicaRead(replicaRead)
                .replicaSkipEnabled(replicaSkipEnabled)
                .dataTransferPoolSize(dataTransferPoolSize)
                .dataCommitPoolSize(dataCommitPoolSize));
    }

    /**
     * Returns the initial buffer memory request in bytes, read from
     * {@code tez.rss.runtime.io.sort.mb} (default 100 MB).
     *
     * @param conf job configuration
     * @param maxAvailableTaskMemory memory available to the task, in bytes
     * @return the requested size in bytes
     * @throws IllegalArgumentException if the configured size is not positive, overflows when
     *     converted to bytes, or is not strictly less than {@code maxAvailableTaskMemory}
     */
    public static long getInitialMemoryRequirement(Configuration conf, long maxAvailableTaskMemory) {
        long initialMemRequestMb = conf.getLong("tez.rss.runtime.io.sort.mb", 100L);
        LOG.info("InitialMemRequestMb is {}", initialMemRequestMb);
        LOG.info("MaxAvailableTaskMemory is {}", maxAvailableTaskMemory);
        long reqBytes = initialMemRequestMb << 20;
        // The upper-bound check guards the MB->bytes shift against wrapping for absurd values.
        boolean valid = initialMemRequestMb > 0L
            && initialMemRequestMb <= (Long.MAX_VALUE >> 20)
            && reqBytes < maxAvailableTaskMemory;
        Preconditions.checkArgument(valid, "tez.rss.runtime.io.sort.mb " + initialMemRequestMb
            + " should be larger than 0 and should be less than the available task memory (MB):"
            + (maxAvailableTaskMemory >> 20));
        LOG.info("Requested BufferSize (tez.runtime.unordered.output.buffer.size-mb) : {}", initialMemRequestMb);
        return reqBytes;
    }

    /**
     * Keeps the first seven '_'-separated fields of {@code uniqueIdentifier} and re-joins them
     * with '_'.
     *
     * @param uniqueIdentifier identifier to truncate; must not be null
     * @return the first seven fields joined by '_'
     * @throws RssException if the identifier is null or has fewer than seven fields
     */
    public static String uniqueIdentifierToAttemptId(String uniqueIdentifier) {
        if (uniqueIdentifier == null) {
            throw new RssException("uniqueIdentifier should not be null");
        }
        String[] ids = uniqueIdentifier.split(UNDERLINE_DELIMITER);
        if (ids.length < ATTEMPT_ID_FIELD_COUNT) {
            throw new RssException("uniqueIdentifier " + uniqueIdentifier + " has fewer than "
                + ATTEMPT_ID_FIELD_COUNT + " '" + UNDERLINE_DELIMITER + "'-separated fields");
        }
        return StringUtils.join(ids, UNDERLINE_DELIMITER, 0, ATTEMPT_ID_FIELD_COUNT);
    }

    /**
     * Builds a block id for the given partition, RSS task-attempt id and per-task sequence number.
     *
     * <p>The attempt number is taken from the bits of {@code taskAttemptId} above
     * {@code partitionIdBits + taskAttemptIdBits}, which is where
     * {@link #convertTaskAttemptIdToLong} places it. It is then packed into the low
     * {@link #MAX_ATTEMPT_LENGTH} bits of the sequence field, with {@code nextSeqNo} above it.
     *
     * @throws RssException if the attempt number or sequence number is out of range
     */
    public static long getBlockId(int partitionId, long taskAttemptId, int nextSeqNo) {
        // Called once per block; avoid INFO-level logging and boxing on this hot path.
        if (LOG.isDebugEnabled()) {
            LOG.debug("GetBlockId, partitionId:{}, taskAttemptId:{}, nextSeqNo:{}", partitionId, taskAttemptId, nextSeqNo);
        }
        int attemptShift = LAYOUT.partitionIdBits + LAYOUT.taskAttemptIdBits;
        long attemptId = taskAttemptId >> attemptShift;
        if (attemptId < 0L || attemptId > MAX_ATTEMPT_ID) {
            throw new RssException("Can't support attemptId [" + attemptId + "], the max value should be " + MAX_ATTEMPT_ID);
        }
        if (nextSeqNo < 0 || nextSeqNo > MAX_SEQUENCE_NO) {
            throw new RssException("Can't support sequence [" + nextSeqNo + "], the max value should be " + MAX_SEQUENCE_NO);
        }
        int atomicInt = (int) ((nextSeqNo << MAX_ATTEMPT_LENGTH) + attemptId);
        long taskId = taskAttemptId - (attemptId << attemptShift);
        return LAYOUT.getBlockId(atomicInt, partitionId, taskId);
    }

    /**
     * Recovers the RSS task-attempt id (as produced by {@link #convertTaskAttemptIdToLong})
     * from a block id built by {@link #getBlockId}.
     */
    public static long getTaskAttemptId(long blockId) {
        int mapId = LAYOUT.getTaskAttemptId(blockId);
        int attemptId = LAYOUT.getSequenceNo(blockId) & MAX_ATTEMPT_ID;
        return LAYOUT.getBlockId(attemptId, 0, mapId);
    }

    /**
     * Estimates how many tasks run concurrently, from map/reduce counts, the
     * {@code mapreduce.job.running.*.limit} caps, the reduce slow-start fraction and
     * {@code tez.rss.estimate.task.concurrency.dynamic.factor}.
     */
    public static int estimateTaskConcurrency(Configuration jobConf, int mapNum, int reduceNum) {
        double dynamicFactor = jobConf.getDouble("tez.rss.estimate.task.concurrency.dynamic.factor", 1.0);
        double slowStart = jobConf.getDouble("mapreduce.job.reduce.slowstart.completedmaps", 0.05);
        int mapLimit = jobConf.getInt("mapreduce.job.running.map.limit", 0);
        int reduceLimit = jobConf.getInt("mapreduce.job.running.reduce.limit", 0);
        int estimateMapNum = mapLimit > 0 ? Math.min(mapNum, mapLimit) : mapNum;
        int estimateReduceNum = reduceLimit > 0 ? Math.min(reduceNum, reduceLimit) : reduceNum;
        if (slowStart == 1.0) {
            // Reducers start only after all maps finish, so the two phases never overlap.
            return (int) (Math.max(estimateMapNum, estimateReduceNum) * dynamicFactor);
        }
        return (int) (((1.0 - slowStart) * estimateMapNum + estimateReduceNum) * dynamicFactor);
    }

    /**
     * Returns the number of shuffle servers to request. An explicit positive
     * {@code rss.client.assignment.shuffle.nodes.max} wins; otherwise, when
     * {@code tez.rss.estimate.server.assignment.enabled} is set, the number is estimated as
     * {@code ceil(taskConcurrency / tez.rss.estimate.task.concurrency.per.server)}.
     */
    public static int getRequiredShuffleServerNumber(Configuration jobConf, int mapNum, int reduceNum) {
        int requiredShuffleServerNumber = jobConf.getInt("rss.client.assignment.shuffle.nodes.max", -1);
        boolean enabledEstimateServer = jobConf.getBoolean("tez.rss.estimate.server.assignment.enabled", false);
        if (!enabledEstimateServer || requiredShuffleServerNumber > 0) {
            return requiredShuffleServerNumber;
        }
        int taskConcurrency = estimateTaskConcurrency(jobConf, mapNum, reduceNum);
        int taskConcurrencyPerServer = jobConf.getInt("tez.rss.estimate.task.concurrency.per.server", 80);
        return (int) Math.ceil((double) taskConcurrency / taskConcurrencyPerServer);
    }

    /**
     * Packs a DAG id and an edge's (up, down) vertex ids into one shuffle id:
     * {@code dagId * 1_000_000 + upVertexId * 1_000 + downVertexId}.
     *
     * @throws RssException if the packed value does not fit in an {@code int}
     */
    public static int computeShuffleId(int tezDagID, int upVertexId, int downVertexId) {
        // Compute in long so a large DAG id cannot silently wrap into a negative or colliding id.
        long packed = (long) tezDagID * SHUFFLE_ID_MAGIC * SHUFFLE_ID_MAGIC
            + (long) upVertexId * SHUFFLE_ID_MAGIC
            + downVertexId;
        if (packed > Integer.MAX_VALUE || packed < Integer.MIN_VALUE) {
            throw new RssException("Shuffle id overflows int, tezDagID:" + tezDagID
                + ", up vertex id:" + upVertexId + ", down vertex id:" + downVertexId);
        }
        int shuffleId = (int) packed;
        LOG.info("Compute Shuffle Id:{}, up vertex id:{}, down vertex id:{}", shuffleId, upVertexId, downVertexId);
        return shuffleId;
    }

    /**
     * Extracts the DAG id from a shuffle id produced by {@link #computeShuffleId}.
     *
     * @throws IllegalArgumentException if {@code shuffleId} is not positive
     * @throws RssException if the DAG-id portion is zero
     */
    public static int parseDagId(int shuffleId) {
        Preconditions.checkArgument(shuffleId > 0, "shuffleId should be positive.");
        int dagId = shuffleId / (SHUFFLE_ID_MAGIC * SHUFFLE_ID_MAGIC);
        if (dagId == 0) {
            throw new RssException("Illegal shuffleId: " + shuffleId);
        }
        return dagId;
    }

    /**
     * Maps a vertex name of the form {@code "Map N"} or {@code "Reducer N"} to an integer id:
     * {@code N} for maps and {@code 600 + N} for reducers.
     *
     * @throws RssException if the name is malformed, has an unknown prefix, or {@code N > 500}
     */
    private static int mapVertexId(String vertexName) {
        String[] ss = vertexName.split("\\s+");
        if (ss.length < 2) {
            throw new RssException("Wrong vertex name to id mapping, vertexName:" + vertexName);
        }
        int index = Integer.parseInt(ss[1]);
        if (index > VERTEX_ID_MAPPING_MAX_ID) {
            throw new RssException("Too large vertex name to id mapping, vertexName:" + vertexName);
        }
        if (VERTEX_ID_MAPPING_MAP.equals(ss[0])) {
            return index;
        }
        if (VERTEX_ID_MAPPING_REDUCER.equals(ss[0])) {
            return VERTEX_ID_MAPPING_MAGIC + index;
        }
        throw new RssException("Wrong vertex name to id mapping, vertexName:" + vertexName);
    }

    /**
     * Encodes a Tez task attempt as an RSS task-attempt id: the task index goes into the
     * task-attempt field and the attempt number into the sequence field of {@link BlockIdLayout}.
     *
     * @throws RssException if the task index or attempt number does not fit its field
     */
    public static long convertTaskAttemptIdToLong(TezTaskAttemptID taskAttemptID) {
        int lowBytes = taskAttemptID.getTaskID().getId();
        if (lowBytes > LAYOUT.maxTaskAttemptId) {
            throw new RssException("TaskAttempt " + taskAttemptID + " low bytes " + lowBytes + " exceed");
        }
        int highBytes = taskAttemptID.getId();
        if (highBytes > MAX_ATTEMPT_ID || highBytes < 0) {
            throw new RssException("TaskAttempt " + taskAttemptID + " high bytes " + highBytes + " exceed.");
        }
        long id = LAYOUT.getBlockId(highBytes, 0, lowBytes);
        LOG.info("ConvertTaskAttemptIdToLong taskAttemptID:{}, id is {}, .", taskAttemptID, id);
        return id;
    }

    /**
     * Collects the RSS task-attempt ids of successful map attempts, keeping only the first
     * attempt seen per map task index. Duplicates and indexes {@code >= totalMapsCount} are
     * logged as warnings.
     *
     * @throws IllegalStateException if the per-task and per-attempt bitmaps end up with
     *     different cardinalities
     */
    public static Roaring64NavigableMap fetchAllRssTaskIds(Set<InputAttemptIdentifier> successMapTaskAttempts, int totalMapsCount, int appAttemptId) {
        String errMsg = "TaskAttemptIDs are inconsistent with map tasks";
        Roaring64NavigableMap rssTaskIdBitmap = Roaring64NavigableMap.bitmapOf();
        Roaring64NavigableMap mapTaskIdBitmap = Roaring64NavigableMap.bitmapOf();
        LOG.info("FetchAllRssTaskIds successMapTaskAttempts size:{}", successMapTaskAttempts.size());
        LOG.info("FetchAllRssTaskIds totalMapsCount:{}, appAttemptId:{}", totalMapsCount, appAttemptId);
        for (InputAttemptIdentifier inputAttemptIdentifier : successMapTaskAttempts) {
            String pathComponent = inputAttemptIdentifier.getPathComponent();
            TezTaskAttemptID mapTaskAttemptID = IdUtils.convertTezTaskAttemptID(pathComponent);
            long rssTaskId = convertTaskAttemptIdToLong(mapTaskAttemptID);
            long mapTaskId = mapTaskAttemptID.getTaskID().getId();
            boolean alreadySeen = mapTaskIdBitmap.contains(mapTaskId);
            LOG.info("FetchAllRssTaskIds, pathComponent: {}, mapTaskId:{}, rssTaskId:{}, is contains:{}", pathComponent, mapTaskId, rssTaskId, alreadySeen);
            if (alreadySeen) {
                LOG.warn(inputAttemptIdentifier + " is redundant on index: " + mapTaskId);
                continue;
            }
            rssTaskIdBitmap.addLong(rssTaskId);
            mapTaskIdBitmap.addLong(mapTaskId);
            if (mapTaskId >= totalMapsCount) {
                LOG.warn(inputAttemptIdentifier + " has overflowed mapIndex, pathComponent: " + pathComponent + ",totalMapsCount: " + totalMapsCount);
            }
        }
        if (mapTaskIdBitmap.getLongCardinality() != rssTaskIdBitmap.getLongCardinality()) {
            throw new IllegalStateException(errMsg);
        }
        return rssTaskIdBitmap;
    }

    /**
     * Parses the integer between the 5th and 6th '_' of {@code taskIdStr}. For example,
     * {@code "attempt_1681717153064_3601637_1_00_000005_0"} yields {@code 5}.
     *
     * @throws RuntimeException (logged first) if the string is null, has too few '_'
     *     separators, or the field is not an integer
     */
    public static int taskIdStrToTaskId(String taskIdStr) {
        try {
            int start = -1;
            for (int i = 0; i < 5; i++) {
                start = taskIdStr.indexOf(UNDERLINE_DELIMITER, start + 1);
            }
            int end = taskIdStr.indexOf(UNDERLINE_DELIMITER, start + 1);
            return Integer.parseInt(taskIdStr.substring(start + 1, end));
        } catch (RuntimeException e) {
            LOG.error("Failed to get TaskId, taskId:{}.", taskIdStr, e);
            throw e;
        }
    }

    /**
     * Parses {@code host:port+p1_p2,host:port+p3} into {@code rssWorker}, adding each server
     * to the set of every partition listed after its '+'.
     */
    private static void parseRssWorkerFromHostInfo(Map<Integer, Set<ShuffleServerInfo>> rssWorker, String multiHostInfo) {
        for (String hostInfo : multiHostInfo.split(",")) {
            String[] info = hostInfo.split("\\+");
            String[] hostAndPort = info[0].split(":");
            ShuffleServerInfo serverInfo = new ShuffleServerInfo(hostAndPort[0], Integer.parseInt(hostAndPort[1]));
            String[] partitions = info[1].split(UNDERLINE_DELIMITER);
            assert (partitions.length > 0);
            for (String partitionId : partitions) {
                rssWorker.computeIfAbsent(Integer.parseInt(partitionId), k -> new HashSet<>()).add(serverInfo);
            }
        }
    }

    /**
     * Fills {@code rssWorker} with the partition-to-server assignment for {@code shuffleId}.
     *
     * @param rssWorker destination map, partition id to servers; mutated in place
     * @param shuffleId shuffle whose assignment should be extracted
     * @param hostnameInfo entries of the form {@code shuffleId=host:port+p1_p2,...} separated by
     *     ';'; entries for other shuffle ids or without exactly one '=' are ignored
     */
    public static void parseRssWorker(Map<Integer, Set<ShuffleServerInfo>> rssWorker, int shuffleId, String hostnameInfo) {
        LOG.info("ParseRssWorker, hostnameInfo length:{}", hostnameInfo.length());
        String shuffleIdStr = String.valueOf(shuffleId);
        for (String toVertex : hostnameInfo.split(";")) {
            String[] splits = toVertex.split("=");
            if (splits.length == 2 && shuffleIdStr.equals(splits[0])) {
                parseRssWorkerFromHostInfo(rssWorker, splits[1]);
            }
        }
    }

    /** Logs and returns the name of {@code replacement}; shared by the class-name swappers. */
    private static String replaceWith(String kind, String className, Class<?> replacement) {
        String target = replacement.getName();
        LOG.info(kind + " class name will transient from {} to {}", className, target);
        return target;
    }

    /**
     * Maps a stock Tez output class name to its RSS counterpart, or returns {@code className}
     * unchanged (with an info log) when there is none.
     */
    public static String replaceRssOutputClassName(String className) {
        if (className.equals(OrderedPartitionedKVOutput.class.getName())) {
            return replaceWith("Output", className, RssOrderedPartitionedKVOutput.class);
        }
        if (className.equals(UnorderedKVOutput.class.getName())) {
            return replaceWith("Output", className, RssUnorderedKVOutput.class);
        }
        if (className.equals(UnorderedPartitionedKVOutput.class.getName())) {
            return replaceWith("Output", className, RssUnorderedPartitionedKVOutput.class);
        }
        LOG.info("Unexpected kv output class name {}.", className);
        return className;
    }

    /**
     * Maps a stock Tez input class name to its RSS counterpart, or returns {@code className}
     * unchanged (with an info log) when there is none.
     */
    public static String replaceRssInputClassName(String className) {
        if (className.equals(OrderedGroupedKVInput.class.getName())) {
            return replaceWith("Input", className, RssOrderedGroupedKVInput.class);
        }
        if (className.equals(OrderedGroupedMergedKVInput.class.getName())) {
            return replaceWith("Input", className, RssOrderedGroupedMergedKVInput.class);
        }
        if (className.equals(OrderedGroupedInputLegacy.class.getName())) {
            return replaceWith("Input", className, RssOrderedGroupedInputLegacy.class);
        }
        if (className.equals(UnorderedKVInput.class.getName())) {
            return replaceWith("Input", className, RssUnorderedKVInput.class);
        }
        if (className.equals(ConcatenatedMergedKeyValueInput.class.getName())) {
            return replaceWith("Input", className, RssConcatenatedMergedKeyValueInput.class);
        }
        if (className.equals(ConcatenatedMergedKeyValuesInput.class.getName())) {
            return replaceWith("Input", className, RssConcatenatedMergedKeyValuesInput.class);
        }
        LOG.info("Unexpected kv input class name {}.", className);
        return className;
    }

    /**
     * Applies dynamic client settings to {@code conf}. Keys lacking a {@code "tez."} prefix get
     * one. An item is written only if the key is unset/empty in {@code conf} or is listed in
     * {@code RssTezConfig.RSS_MANDATORY_CLUSTER_CONF}. Null/empty inputs are logged and ignored.
     */
    public static void applyDynamicClientConf(Configuration conf, Map<String, String> confItems) {
        if (conf == null) {
            LOG.warn("Tez conf is null");
            return;
        }
        if (confItems == null || confItems.isEmpty()) {
            LOG.warn("Empty conf items");
            return;
        }
        for (Map.Entry<String, String> kv : confItems.entrySet()) {
            String rawKey = kv.getKey();
            String tezConfKey = rawKey.startsWith("tez.") ? rawKey : "tez." + rawKey;
            String tezConfVal = kv.getValue();
            boolean alreadySet = !StringUtils.isEmpty(conf.get(tezConfKey, ""));
            if (alreadySet && !RssTezConfig.RSS_MANDATORY_CLUSTER_CONF.contains(tezConfKey)) {
                continue;
            }
            LOG.warn("Use conf dynamic conf {} = {}", tezConfKey, tezConfVal);
            conf.set(tezConfKey, tezConfVal);
        }
    }

    /**
     * Returns a new {@code Configuration(false)} containing only the {@code "tez."}-prefixed
     * entries of {@code extraConf}.
     */
    public static Configuration filterRssConf(Configuration extraConf) {
        Configuration conf = new Configuration(false);
        for (Map.Entry<String, String> entry : extraConf) {
            String key = entry.getKey();
            if (key.startsWith("tez.")) {
                conf.set(key, entry.getValue());
            }
        }
        return conf;
    }
}

