/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tez.runtime.library.common.sort.impl;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.yarn.api.records.ApplicationAttemptId;
import org.apache.tez.common.RssTezUtils;
import org.apache.tez.dag.records.TezTaskAttemptID;
import org.apache.tez.runtime.api.OutputContext;
import org.apache.tez.runtime.library.common.sort.buffer.WriteBufferManager;
import org.apache.tez.runtime.library.common.sort.impl.ExternalSorter;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ByteUnit;
import org.apache.uniffle.shaded.com.google.common.collect.Sets;
import org.apache.uniffle.storage.util.StorageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link ExternalSorter} that forwards every record to a {@link WriteBufferManager} instead of
 * sorting and spilling locally. The buffer manager is built with a shuffle client obtained from
 * {@code RssTezUtils.createShuffleClient(conf)}; presumably it pushes buffered data to remote
 * shuffle servers (NOTE(review): confirm against WriteBufferManager).
 *
 * <p>The number of records accepted per partition is counted locally and exposed through
 * {@link #getNumRecordsPerPartition()}.
 */
public class RssSorter extends ExternalSorter {
    private static final Logger LOG = LoggerFactory.getLogger(RssSorter.class);

    private final WriteBufferManager bufferManager;
    // Concurrent sets handed to the WriteBufferManager; by name they collect the ids of blocks
    // that were sent successfully / failed (NOTE(review): confirm against WriteBufferManager).
    private final Set<Long> successBlockIds = Sets.newConcurrentHashSet();
    private final Set<Long> failedBlockIds = Sets.newConcurrentHashSet();
    private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
    // Count of records passed to the buffer manager, indexed by partition.
    private final int[] numRecordsPerPartition;

    /**
     * Creates the sorter and its backing {@link WriteBufferManager}.
     *
     * @param tezTaskAttemptID Tez task attempt id, passed to the WriteBufferManager
     * @param outputContext output context, passed to {@link ExternalSorter}
     * @param conf configuration the {@code tez.rss.*} client settings are read from
     * @param numMaps passed to the WriteBufferManager
     * @param numOutputs number of partitions; sizes the per-partition record counters
     * @param initialMemoryAvailable memory grant, passed to {@link ExternalSorter}
     * @param shuffleId shuffle id, passed to the WriteBufferManager
     * @param applicationAttemptId application attempt; its string form is passed to the
     *     WriteBufferManager
     * @param partitionToServers partition-to-shuffle-server assignment, passed to the
     *     WriteBufferManager
     * @param taskAttemptId numeric task attempt id, passed to the WriteBufferManager
     * @throws IOException if the superclass constructor fails, or if {@code availableMemoryMb}
     *     is outside [0, 2047]
     * @throws RssException if {@code tez.rss.storage.type} is set to an empty string
     */
    public RssSorter(TezTaskAttemptID tezTaskAttemptID, OutputContext outputContext, Configuration conf, int numMaps, int numOutputs, long initialMemoryAvailable, int shuffleId, ApplicationAttemptId applicationAttemptId, Map<Integer, List<ShuffleServerInfo>> partitionToServers, long taskAttemptId) throws IOException {
        super(outputContext, conf, numOutputs, initialMemoryAvailable);
        this.partitionToServers = partitionToServers;
        this.numRecordsPerPartition = new int[numOutputs];

        // The configured value is only logged. The effective size is the memory granted to the
        // superclass (availableMemoryMb).
        long sortmb = conf.getLong("tez.rss.runtime.io.sort.mb", 100L);
        LOG.info("conf.sortmb is {}", sortmb);
        sortmb = this.availableMemoryMb;
        LOG.info("sortmb, availableMemoryMb is {}, {}", sortmb, this.availableMemoryMb);
        // Masking with 0x7FF keeps the low 11 bits. Equality therefore holds only for 0..2047 MiB.
        if ((sortmb & 0x7FFL) != sortmb) {
            throw new IOException("Invalid \"tez.rss.runtime.io.sort.mb\": " + sortmb);
        }

        // Float defaults are widened to double exactly as before (e.g. 0.9f != 0.9d).
        double sortThreshold = conf.getDouble("tez.rss.client.sort.memory.use.threshold", 0.9f);
        long maxSegmentSize = conf.getLong("tez.rss.client.max.buffer.size", 3072L);
        long maxBufferSize = conf.getLong("tez.rss.writer.buffer.size", 0xE00000L); // 14 MiB
        double memoryThreshold = conf.getDouble("tez.rss.client.memory.threshold", 0.8f);
        int sendThreadNum = conf.getInt("tez.rss.client.send.thread.num", 5);
        double sendThreshold = conf.getDouble("tez.rss.client.send.threshold", 0.2f);
        int batch = conf.getInt("tez.rss.client.batch.trigger.num", 50);
        String storageType = conf.get("tez.rss.storage.type", "MEMORY");
        if (StringUtils.isEmpty(storageType)) {
            throw new RssException("storage type mustn't be empty");
        }
        long sendCheckInterval = conf.getLong("tez.rss.client.send.check.interval.ms", 500L);
        long sendCheckTimeout = conf.getLong("tez.rss.client.send.check.timeout.ms", 600000L);
        int bitmapSplitNum = conf.getInt("tez.rss.client.bitmap.num", 1);

        // Verbose settings dump, gated on Hive's Tez log level rather than on LOG.isDebugEnabled().
        if (conf.get("hive.tez.log.level", "INFO").equalsIgnoreCase("debug")) {
            LOG.info("sortmb is {}", sortmb);
            LOG.info("sortThreshold is {}", sortThreshold);
            LOG.info("taskAttemptId is {}", taskAttemptId);
            LOG.info("maxSegmentSize is {}", maxSegmentSize);
            LOG.info("maxBufferSize is {}", maxBufferSize);
            LOG.info("memoryThreshold is {}", memoryThreshold);
            LOG.info("sendThreadNum is {}", sendThreadNum);
            LOG.info("sendThreshold is {}", sendThreshold);
            LOG.info("batch is {}", batch);
            LOG.info("storageType is {}", storageType);
            LOG.info("sendCheckInterval is {}", sendCheckInterval);
            LOG.info("sendCheckTimeout is {}", sendCheckTimeout);
            LOG.info("bitmapSplitNum is {}", bitmapSplitNum);
        }
        LOG.info("applicationAttemptId is {}", applicationAttemptId.toString());

        // Buffer memory limit = sortmb MiB scaled by the sort-memory-use threshold.
        long bufferMemoryLimit = (long) (ByteUnit.MiB.toBytes(sortmb) * sortThreshold);
        this.bufferManager = new WriteBufferManager(tezTaskAttemptID, bufferMemoryLimit, applicationAttemptId.toString(), taskAttemptId, this.successBlockIds, this.failedBlockIds, RssTezUtils.createShuffleClient(conf), this.comparator, maxSegmentSize, this.keySerializer, this.valSerializer, maxBufferSize, memoryThreshold, sendThreadNum, sendThreshold, batch, new RssConf(), partitionToServers, numMaps, this.isMemoryShuffleEnabled(storageType), sendCheckInterval, sendCheckTimeout, bitmapSplitNum, shuffleId, true, this.mapOutputByteCounter, this.mapOutputRecordCounter);
        LOG.info("Initialized WriteBufferManager.");
    }

    /**
     * Delegates to {@code bufferManager.waitSendFinished()}. Presumably this blocks until all
     * buffered records have been sent (NOTE(review): confirm against WriteBufferManager).
     *
     * @throws IOException if the buffer manager reports a failure
     */
    public void flush() throws IOException {
        this.bufferManager.waitSendFinished();
    }

    /**
     * Closes the superclass and then frees the buffer manager's resources. The resources are
     * freed even when {@code super.close()} throws, so a failing close does not leak them.
     *
     * @throws IOException if {@code super.close()} fails
     */
    public final void close() throws IOException {
        try {
            super.close();
        } finally {
            this.bufferManager.freeAllResources();
        }
    }

    /**
     * Partitions the record with the configured partitioner and hands it to {@link #collect}.
     *
     * @throws IOException if the key/value types or the computed partition are invalid
     * @throws RssException if the calling thread is interrupted while the record is added; the
     *     thread's interrupt status is restored before the exception is thrown
     */
    public void write(Object key, Object value) throws IOException {
        try {
            this.collect(key, value, this.partitioner.getPartition(key, value, this.partitions));
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new RssException(e);
        }
    }

    /**
     * Validates a record, adds it to the buffer manager and counts it for its partition.
     *
     * @param key record key; its class must be exactly {@code keyClass}
     * @param value record value; its class must be exactly {@code valClass}
     * @param partition target partition, in [0, partitions)
     * @throws IOException on a key/value class mismatch or an out-of-range partition
     * @throws InterruptedException if {@code bufferManager.addRecord} is interrupted
     */
    synchronized void collect(Object key, Object value, int partition) throws IOException, InterruptedException {
        if (key.getClass() != this.keyClass) {
            throw new IOException("Type mismatch in key from map: expected " + this.keyClass.getName() + ", received " + key.getClass().getName());
        }
        if (value.getClass() != this.valClass) {
            throw new IOException("Type mismatch in value from map: expected " + this.valClass.getName() + ", received " + value.getClass().getName());
        }
        if (partition < 0 || partition >= this.partitions) {
            throw new IOException("Illegal partition for " + key + " (" + partition + ")");
        }
        this.bufferManager.addRecord(partition, key, value);
        // Counted only after addRecord succeeds.
        this.numRecordsPerPartition[partition]++;
    }

    /**
     * Returns the per-partition record counts. This is the live internal array, not a copy.
     */
    public int[] getNumRecordsPerPartition() {
        return this.numRecordsPerPartition;
    }

    /**
     * Reports whether the named storage type includes memory, per {@code StorageType.withMemory}.
     * The name must match a {@code StorageType} constant exactly, because {@code valueOf} is
     * case-sensitive and throws {@link IllegalArgumentException} for unknown names.
     */
    private boolean isMemoryShuffleEnabled(String storageType) {
        return StorageType.withMemory(StorageType.valueOf(storageType));
    }
}

