/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.util;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.util.CommonThreadPool;
import org.apache.sysds.runtime.util.DependencyTask;
import org.apache.sysds.runtime.util.DependencyWrapperTask;
import org.apache.sysds.utils.Explain;

/**
 * Executes {@link DependencyTask}s on a shared {@link ExecutorService}: tasks that report
 * themselves ready are submitted immediately, all others receive a future that is expected to be
 * completed once the task is eventually submitted.
 *
 * <p>Besides the submission methods, the class offers static helpers to wrap plain
 * {@link Callable}s into dependency tasks and to print the resulting task graph for debugging.
 */
public class DependencyThreadPool {
    protected static final Log LOG = LogFactory.getLog(DependencyThreadPool.class.getName());
    private final ExecutorService _pool;

    /**
     * Creates a pool backed by the executor returned from {@code CommonThreadPool.get(k)}.
     *
     * @param k value forwarded to {@code CommonThreadPool.get} (presumably the degree of
     *          parallelism -- confirm against CommonThreadPool)
     */
    public DependencyThreadPool(int k) {
        this._pool = CommonThreadPool.get(k);
    }

    /** Shuts down the underlying executor service. */
    public void shutdown() {
        this._pool.shutdown();
    }

    /**
     * Registers all tasks with the pool and submits those that are already ready.
     *
     * <p>NOTE: {@code dtasks} is sorted in place (natural ordering of the tasks); the returned
     * list is index-aligned with the <em>sorted</em> list.
     *
     * @param dtasks tasks to schedule; reordered by this call
     * @return one outer future per task, completed with the executor's future once the task has
     *         been submitted
     */
    public List<Future<Future<?>>> submitAll(List<DependencyTask<?>> dtasks) {
        List<Future<Future<?>>> futures = new ArrayList<>(dtasks.size());
        // Same objects as in 'futures', kept with their concrete type so they can be completed
        // below without an unchecked cast.
        List<CompletableFuture<Future<?>>> handles = new ArrayList<>(dtasks.size());
        List<Integer> rdyTasks = new ArrayList<>();
        Collections.sort(dtasks);

        int i = 0;
        for (DependencyTask<?> t : dtasks) {
            CompletableFuture<Future<?>> f = new CompletableFuture<>();
            t.addPool(this._pool);
            if (!t.isReady()) {
                t.assignFuture(f);
            } else {
                // Only remember the index; the actual submit happens in the second phase.
                rdyTasks.add(i);
            }
            futures.add(f);
            handles.add(f);
            ++i;
        }

        // Building the message walks the ready list, so skip it unless it will be printed.
        if (LOG.isDebugEnabled()) {
            LOG.debug("Initial Starting tasks: \n\t" + rdyTasks.stream()
                .map(index -> dtasks.get(index).toString())
                .collect(Collectors.joining("\n\t")));
        }

        // Second phase: submit ready tasks only after every task has its pool/future assigned
        // (presumably so a quickly finishing task cannot reach an unprepared dependent).
        for (int index : rdyTasks) {
            synchronized (this._pool) {
                handles.get(index).complete(this._pool.submit(dtasks.get(index)));
            }
        }
        return futures;
    }

    /**
     * Wraps the given callables into dependency tasks and submits them.
     *
     * @param tasks        callables to execute
     * @param dependencies per task (same index), the callables it depends on; may be null
     * @return the futures produced by {@link #submitAll(List)}
     */
    public List<Future<Future<?>>> submitAll(List<? extends Callable<?>> tasks, List<List<? extends Callable<?>>> dependencies) {
        List<DependencyTask<?>> dtasks = DependencyThreadPool.createDependencyTasks(tasks, dependencies);
        return this.submitAll(dtasks);
    }

    /**
     * Submits all tasks and blocks until every result is available.
     *
     * <p>Because {@link #submitAll(List)} sorts {@code dtasks} in place, results are returned in
     * the sorted task order. For a {@link DependencyWrapperTask} the results of its wrapped task
     * futures are collected instead of the wrapper's own result.
     *
     * @param dtasks tasks to execute; reordered by this call
     * @return the collected task results
     * @throws ExecutionException   if a task failed
     * @throws InterruptedException if the calling thread was interrupted while waiting
     */
    public List<Object> submitAllAndWait(List<DependencyTask<?>> dtasks) throws ExecutionException, InterruptedException {
        List<Object> res = new ArrayList<>();
        if (LOG.isDebugEnabled() && dtasks != null && !dtasks.isEmpty()) {
            DependencyThreadPool.explainTaskGraph(dtasks);
        }
        List<Future<Future<?>>> futures = this.submitAll(dtasks);
        int i = 0;
        for (Future<Future<?>> ff : futures) {
            DependencyTask<?> task = dtasks.get(i);
            if (task instanceof DependencyWrapperTask) {
                for (Future<Future<?>> f : ((DependencyWrapperTask<?>) task).getWrappedTaskFuture()) {
                    res.add(f.get().get());
                }
            } else {
                res.add(ff.get().get());
            }
            ++i;
        }
        return res;
    }

    /**
     * Wraps a single callable into a dependency task without dependents.
     *
     * @param task the callable to wrap
     * @return a new dependency task around {@code task}
     */
    public static DependencyTask<?> createDependencyTask(Callable<?> task) {
        return new DependencyTask<>(task, new ArrayList<>());
    }

    /**
     * Fills a per-task dependency list from a compact range mapping.
     *
     * <p>Every map entry {@code (ti, di)} means: each task row in {@code [ti[0], ti[1])} depends on
     * the tasks {@code tasks[di[0] .. di[1])}. Negative bounds are resolved relative to the end
     * ({@code -1} maps to the size, i.e. an exclusive end bound). The resolved bounds are written
     * back into the given arrays. Rows that already hold dependencies get the new ones appended.
     *
     * @param tasks  all tasks
     * @param depMap range mapping as described above; may be null (then {@code dep} is untouched)
     * @param dep    dependency list to fill, one (possibly null) entry per task row; modified
     * @return {@code dep}
     */
    public static List<List<? extends Callable<?>>> createDependencyList(List<? extends Callable<?>> tasks, Map<Integer[], Integer[]> depMap, List<List<? extends Callable<?>>> dep) {
        if (depMap != null) {
            depMap.forEach((ti, di) -> {
                ti[0] = resolveBound(ti[0], dep.size());
                ti[1] = resolveBound(ti[1], dep.size());
                di[0] = resolveBound(di[0], tasks.size());
                di[1] = resolveBound(di[1], tasks.size());
                for (int r = ti[0]; r < ti[1]; ++r) {
                    List<? extends Callable<?>> existing = dep.get(r);
                    List<? extends Callable<?>> required = tasks.subList(di[0], di[1]);
                    if (existing == null) {
                        dep.set(r, required);
                    } else {
                        dep.set(r, Stream.<Callable<?>>concat(existing.stream(), required.stream())
                            .collect(Collectors.toList()));
                    }
                }
            });
        }
        return dep;
    }

    /** Maps a negative bound to {@code size + bound + 1}; non-negative bounds are returned as is. */
    private static int resolveBound(int bound, int size) {
        return bound < 0 ? size + bound + 1 : bound;
    }

    /**
     * Wraps callables into dependency tasks and wires up their dependents.
     *
     * <p>Callables that already are {@link DependencyTask}s are reused. Dependencies that are not
     * part of {@code tasks} are ignored. When debug logging is enabled, the resolved dependencies
     * are additionally recorded on each task for {@link #explainTaskGraph(List)}.
     *
     * @param tasks        callables to wrap
     * @param dependencies per task (same index), the callables it depends on; may be null, and
     *                     individual entries may be null
     * @return the dependency tasks, index-aligned with {@code tasks}
     * @throws DMLRuntimeException if {@code dependencies} is non-null and differs in size from
     *                             {@code tasks}
     */
    public static List<DependencyTask<?>> createDependencyTasks(List<? extends Callable<?>> tasks, List<List<? extends Callable<?>>> dependencies) {
        if (dependencies != null && tasks.size() != dependencies.size()) {
            throw new DMLRuntimeException("Could not create DependencyTasks since the input array sizes are mismatching");
        }
        List<DependencyTask<?>> ret = new ArrayList<>(tasks.size());
        Map<Callable<?>, DependencyTask<?>> map = new HashMap<>();
        for (Callable<?> task : tasks) {
            DependencyTask<?> dt = task instanceof DependencyTask
                ? (DependencyTask<?>) task
                : new DependencyTask<>(task, new ArrayList<>());
            ret.add(dt);
            map.put(task, dt);
        }
        if (dependencies == null) {
            return ret;
        }
        final boolean recordDebugData = LOG.isDebugEnabled();
        for (int i = 0; i < tasks.size(); ++i) {
            List<? extends Callable<?>> deps = dependencies.get(i);
            if (deps == null) {
                continue;
            }
            DependencyTask<?> t = ret.get(i);
            for (Callable<?> dep : deps) {
                DependencyTask<?> dt = map.get(dep);
                if (dt == null) {
                    // Unknown dependency: skip it *before* recording debug data, otherwise a
                    // null parent would end up in the graph used by explainTaskGraph.
                    continue;
                }
                if (recordDebugData) {
                    if (t._dependencyTasks == null) {
                        t._dependencyTasks = new ArrayList<>();
                    }
                    t._dependencyTasks.add(dt);
                }
                dt.addDependent(t);
            }
        }
        return ret;
    }

    /**
     * Logs (at debug level) the task graph grouped by dependency level.
     *
     * <p>A task without recorded dependencies is on level 0; every other task is one level below
     * its deepest dependency. Tasks whose level cannot be determined (dependency cycle, or a
     * dependency that is not contained in {@code tasks}) are listed as unresolved instead of
     * blocking the computation. A null or empty task list is ignored.
     *
     * @param tasks the tasks to explain
     */
    public static void explainTaskGraph(List<DependencyTask<?>> tasks) {
        if (tasks == null || tasks.isEmpty()) {
            return;
        }
        Map<DependencyTask<?>, Integer> levelMap = new HashMap<>();
        int depth = 1;
        // Fixed-point iteration: repeat until all tasks are placed or a full pass places none.
        boolean progress = true;
        while (progress && levelMap.size() < tasks.size()) {
            progress = false;
            for (DependencyTask<?> dt : tasks) {
                if (levelMap.containsKey(dt)) {
                    continue;
                }
                List<DependencyTask<?>> parents = dt._dependencyTasks;
                int level = 0;
                boolean missing = false;
                if (parents != null) {
                    for (DependencyTask<?> parent : parents) {
                        Integer parentLevel = levelMap.get(parent);
                        if (parentLevel == null) {
                            missing = true;
                            break;
                        }
                        level = Math.max(level, parentLevel + 1);
                    }
                }
                if (missing) {
                    continue;
                }
                levelMap.put(dt, level);
                depth = Math.max(depth, level + 1);
                progress = true;
            }
        }

        StringBuilder[] sbs = new StringBuilder[depth];
        String[] offsets = new String[depth];
        for (Map.Entry<DependencyTask<?>, Integer> entry : levelMap.entrySet()) {
            int level = entry.getValue();
            if (sbs[level] == null) {
                sbs[level] = new StringBuilder();
                offsets[level] = Explain.createOffset(level);
            }
            sbs[level].append(offsets[level]);
            sbs[level].append(entry.getKey().toString()).append("\n");
        }

        StringBuilder sb = new StringBuilder("\n");
        sb.append("EXPlAIN (TASK-GRAPH):");
        for (StringBuilder levelText : sbs) {
            // A level can be empty if no task could be placed at all.
            if (levelText != null) {
                sb.append(levelText);
            }
        }
        for (DependencyTask<?> dt : tasks) {
            if (!levelMap.containsKey(dt)) {
                sb.append("(unresolved) ").append(dt.toString()).append("\n");
            }
        }
        LOG.debug(sb.toString());
    }
}

