/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import java.lang.ref.SoftReference;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.IntConsumer;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.util.FastMath;
import org.apache.sysds.runtime.compress.utils.Util;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Radix-2 fast Fourier transform (FFT) and inverse FFT over complex matrices whose real and
 * imaginary parts are stored as separate row-major dense arrays.
 *
 * <p>The 2-D variants ({@code fft}, {@code ifft}) transform every row and every column and require
 * both extents to be powers of two. The linearized variants ({@code fft_linearized},
 * {@code ifft_linearized}) transform each row independently and only require the column count to be
 * a power of two.
 */
public class LibMatrixFourier {
    protected static final Log LOG = LogFactory.getLog(LibMatrixFourier.class.getName());

    /**
     * Not used by this class any more. The former sine/cosine lookup caches were plain
     * {@link HashMap}s that the worker threads mutated concurrently, which is not thread-safe.
     * Twiddle factors are now precomputed per transform length instead. This field is kept only so
     * existing package-level references still compile.
     */
    @Deprecated
    static SoftReference<HashMap<Double, Double>> sinCacheRef = new SoftReference<>(new HashMap<>());

    /** Not used by this class any more; see {@link #sinCacheRef}. */
    @Deprecated
    static SoftReference<HashMap<Double, Double>> cosCacheRef = new SoftReference<>(new HashMap<>());

    /**
     * Computes the 2-D FFT of the complex matrix {@code re + i*im}.
     *
     * @param re      real part; its row and column counts must both be powers of two
     * @param im      imaginary part; must have the same dimensions as {@code re}
     * @param threads degree of parallelism
     * @return {@code {real, imaginary}} of the transform, written to newly created blocks
     * @throws RuntimeException if a dimension is not a power of two or the two parts differ in shape
     */
    public static MatrixBlock[] fft(MatrixBlock re, MatrixBlock im, int threads) {
        return transform(re, im, threads, false, true);
    }

    /**
     * Computes the 2-D inverse FFT of the complex matrix {@code re + i*im}.
     *
     * @param re      real part; its row and column counts must both be powers of two
     * @param im      imaginary part; must have the same dimensions as {@code re}
     * @param threads degree of parallelism
     * @return {@code {real, imaginary}} of the inverse transform, written to newly created blocks
     * @throws RuntimeException if a dimension is not a power of two or the two parts differ in shape
     */
    public static MatrixBlock[] ifft(MatrixBlock re, MatrixBlock im, int threads) {
        return transform(re, im, threads, true, true);
    }

    /**
     * Computes the 1-D FFT of every row of the complex matrix {@code re + i*im} independently.
     *
     * @param re      real part; its column count must be a power of two
     * @param im      imaginary part; must have the same dimensions as {@code re}
     * @param threads degree of parallelism
     * @return {@code {real, imaginary}} of the row-wise transform, written to newly created blocks
     * @throws RuntimeException if the column count is not a power of two or the parts differ in shape
     */
    public static MatrixBlock[] fft_linearized(MatrixBlock re, MatrixBlock im, int threads) {
        return transform(re, im, threads, false, false);
    }

    /**
     * Computes the 1-D inverse FFT of every row of the complex matrix {@code re + i*im} independently.
     *
     * @param re      real part; its column count must be a power of two
     * @param im      imaginary part; must have the same dimensions as {@code re}
     * @param threads degree of parallelism
     * @return {@code {real, imaginary}} of the row-wise inverse transform, in newly created blocks
     * @throws RuntimeException if the column count is not a power of two or the parts differ in shape
     */
    public static MatrixBlock[] ifft_linearized(MatrixBlock re, MatrixBlock im, int threads) {
        return transform(re, im, threads, true, false);
    }

    /** Shared implementation of the four MatrixBlock entry points: validate, copy, transform in place. */
    private static MatrixBlock[] transform(MatrixBlock re, MatrixBlock im, int threads, boolean inverse,
        boolean inclColCalc) {
        final int rows = re.getNumRows();
        final int cols = re.getNumColumns();
        // The row pass needs a power-of-two column count; the column pass (2-D only) needs power-of-two rows.
        if ((inclColCalc && !isPowerOfTwo(rows)) || !isPowerOfTwo(cols)) {
            throw new RuntimeException("false dimensions");
        }
        if (im.getNumRows() != rows || im.getNumColumns() != cols) {
            throw new RuntimeException("false dimensions: imaginary part is " + im.getNumRows() + "x"
                + im.getNumColumns() + " but real part is " + rows + "x" + cols);
        }
        MatrixBlock re_out = new MatrixBlock();
        re_out.copy(re, false);
        MatrixBlock im_out = new MatrixBlock();
        im_out.copy(im, false);
        // NOTE(review): assumes copy(..., false) always materializes a dense array; if an all-zero input
        // leaves the dense block unallocated, getDenseBlockValues() would return null - confirm.
        double[] reValues = re_out.getDenseBlockValues();
        double[] imValues = im_out.getDenseBlockValues();
        if (inverse) {
            ifft(reValues, imValues, rows, cols, threads, inclColCalc);
        } else {
            fft(reValues, imValues, rows, cols, threads, inclColCalc);
        }
        re_out.recomputeNonZeros(threads);
        im_out.recomputeNonZeros(threads);
        return new MatrixBlock[] {re_out, im_out};
    }

    /**
     * In-place FFT of a row-major {@code rows x cols} complex matrix.
     *
     * <p>Every row is transformed first. If {@code inclColCalc} is set and there is more than one row,
     * every column is then transformed too, which yields the 2-D FFT.
     *
     * @param re          real parts, overwritten with the result
     * @param im          imaginary parts, overwritten with the result
     * @param rows        number of rows
     * @param cols        number of columns; expected to be a power of two
     * @param threads     degree of parallelism
     * @param inclColCalc whether to also transform along columns ({@code rows} should then be a power of two)
     */
    public static void fft(double[] re, double[] im, int rows, int cols, int threads, boolean inclColCalc) {
        final double[] re_inter = new double[rows * cols];
        final double[] im_inter = new double[rows * cols];
        final ExecutorService pool = CommonThreadPool.get(threads);
        try {
            // Each task handles whole rows (or whole columns), so concurrent tasks touch disjoint
            // indices of re/im and of the shared scratch arrays. Twiddle tables are read-only.
            final Twiddles rowTwiddles = new Twiddles(cols);
            runParallel(pool, rows, blockSize(rows, threads),
                j -> fft_one_dim(re, im, re_inter, im_inter, j * cols, (j + 1) * cols, cols, 1, rowTwiddles));
            if (inclColCalc && rows > 1) {
                final Twiddles colTwiddles = new Twiddles(rows);
                runParallel(pool, cols, blockSize(cols, threads),
                    i -> fft_one_dim(re, im, re_inter, im_inter, i, i + rows * cols, rows, cols, colTwiddles));
            }
        }
        finally {
            pool.shutdown();
        }
    }

    /**
     * In-place inverse FFT of a row-major {@code rows x cols} complex matrix.
     *
     * <p>If {@code inclColCalc} is set and there is more than one row, every column is transformed
     * first. Every row is then transformed.
     *
     * @param re          real parts, overwritten with the result
     * @param im          imaginary parts, overwritten with the result
     * @param rows        number of rows
     * @param cols        number of columns; expected to be a power of two
     * @param threads     degree of parallelism
     * @param inclColCalc whether to also transform along columns ({@code rows} should then be a power of two)
     */
    public static void ifft(double[] re, double[] im, int rows, int cols, int threads, boolean inclColCalc) {
        final double[] re_inter = new double[rows * cols];
        final double[] im_inter = new double[rows * cols];
        final ExecutorService pool = CommonThreadPool.get(threads);
        try {
            if (inclColCalc && rows > 1) {
                final Twiddles colTwiddles = new Twiddles(rows);
                runParallel(pool, cols, blockSize(cols, threads),
                    i -> ifft_one_dim(re, im, re_inter, im_inter, i, i + rows * cols, rows, cols, colTwiddles));
            }
            final Twiddles rowTwiddles = new Twiddles(cols);
            runParallel(pool, rows, blockSize(rows, threads),
                j -> ifft_one_dim(re, im, re_inter, im_inter, j * cols, (j + 1) * cols, cols, 1, rowTwiddles));
        }
        finally {
            pool.shutdown();
        }
    }

    /** Number of consecutive 1-D transforms per task: an even share per thread, but at least 32. */
    private static int blockSize(int n, int threads) {
        return Math.max(n / Math.max(threads, 1), 32);
    }

    /**
     * Splits the indices {@code [0, n)} into chunks of {@code blockSize}, runs {@code body} for every
     * index on the pool, and waits for all chunks to finish.
     *
     * @throws RuntimeException wrapping the first task failure or an interruption while waiting
     */
    private static void runParallel(ExecutorService pool, int n, int blockSize, IntConsumer body) {
        final List<Future<?>> tasks = new ArrayList<>();
        for (int i = 0; i < n; i += blockSize) {
            final int start = i;
            final int end = Math.min(i + blockSize, n);
            tasks.add(pool.submit(() -> {
                for (int k = start; k < end; k++) {
                    body.accept(k);
                }
            }));
        }
        try {
            for (Future<?> task : tasks) {
                task.get();
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt for callers
            throw new RuntimeException(e);
        }
        catch (ExecutionException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * In-place FFT of one strided vector of {@code num} complex values stored at indices
     * {@code start, start + minStep, ..., start + (num - 1) * minStep}.
     *
     * @param re       real parts, overwritten
     * @param im       imaginary parts, overwritten
     * @param re_inter scratch array covering the touched indices; those indices are left at 0.0
     * @param im_inter scratch array covering the touched indices; those indices are left at 0.0
     * @param start    index of the first element
     * @param stop     expected to be {@code start + num * minStep}
     * @param num      number of elements; expected to be a power of two
     * @param minStep  stride between consecutive elements
     */
    public static void fft_one_dim(double[] re, double[] im, double[] re_inter, double[] im_inter, int start,
        int stop, int num, int minStep) {
        fft_one_dim(re, im, re_inter, im_inter, start, stop, num, minStep, new Twiddles(num));
    }

    /** Same as the public overload, but uses twiddle factors precomputed for length {@code num}. */
    private static void fft_one_dim(double[] re, double[] im, double[] re_inter, double[] im_inter, int start,
        int stop, int num, int minStep, Twiddles twiddles) {
        if (num < 2) {
            return; // nothing to transform
        }
        int step = minStep * (num / 2);
        for (int stage = 0; stage < twiddles.cos.length; stage++) {
            int subNum = 2 << stage; // sub-transform length of this stage: 2, 4, ..., num
            fft_one_dim_iter_sub(re, im, re_inter, im_inter, start, stop, num, minStep, subNum, step,
                twiddles.cos[stage], twiddles.sin[stage]);
            step /= 2;
        }
    }

    /** Runs one butterfly stage, i.e. all sub-transforms of length {@code subNum}, over the vector. */
    private static void fft_one_dim_iter_sub(double[] re, double[] im, double[] re_inter, double[] im_inter,
        int start, int stop, int num, int minStep, int subNum, int step, double[] cos, double[] sin) {
        int pairs = (int) FastMath.ceil((double) num / (2.0 * (double) subNum));
        for (int sub = 0; sub < pairs; sub++) {
            int startSub = start + sub * minStep;
            fft_one_dim_sub(re, im, re_inter, im_inter, start, stop, startSub, subNum, step, cos, sin);
            if (subNum == num) {
                // last stage: the single sub-sequence (stride minStep) already spans the whole vector
                return;
            }
            fft_one_dim_sub(re, im, re_inter, im_inter, start, stop, startSub + step / 2, subNum, step, cos,
                sin);
        }
    }

    /**
     * Combines the even/odd halves of the sub-sequence starting at {@code startSub} with stride
     * {@code step} into a length-{@code subNum} transform, using the stage's twiddle factors.
     * The results are staged in the scratch arrays and then copied back, and the scratch is cleared.
     */
    private static void fft_one_dim_sub(double[] re, double[] im, double[] re_inter, double[] im_inter, int start,
        int stop, int startSub, int subNum, int step, double[] cos, double[] sin) {
        final int span = stop - start;
        final int half = span / 2;
        int j = startSub;
        for (int cnt = 0; cnt < subNum / 2; cnt++) {
            double omega_pow_re = cos[cnt];
            double omega_pow_im = sin[cnt];
            double m_re = omega_pow_re * re[j + step] - omega_pow_im * im[j + step];
            double m_im = omega_pow_re * im[j + step] + omega_pow_im * re[j + step];
            int index = startSub + cnt * step;
            re_inter[index] = re[j] + m_re;
            re_inter[index + half] = re[j] - m_re;
            im_inter[index] = im[j] + m_im;
            im_inter[index + half] = im[j] - m_im;
            j += 2 * step;
        }
        for (j = startSub; j < startSub + span; j += step) {
            re[j] = re_inter[j];
            im[j] = im_inter[j];
            re_inter[j] = 0.0;
            im_inter[j] = 0.0;
        }
    }

    /**
     * In-place inverse FFT of one strided vector, computed as {@code conj(FFT(conj(x))) / num}.
     * Parameters are as for {@link #fft_one_dim(double[], double[], double[], double[], int, int, int, int)}.
     */
    private static void ifft_one_dim(double[] re, double[] im, double[] re_inter, double[] im_inter, int start,
        int stop, int num, int minStep, Twiddles twiddles) {
        final int end = start + num * minStep;
        for (int i = start; i < end; i += minStep) {
            im[i] = -im[i];
        }
        fft_one_dim(re, im, re_inter, im_inter, start, stop, num, minStep, twiddles);
        for (int i = start; i < end; i += minStep) {
            re[i] = re[i] / (double) num;
            im[i] = -im[i] / (double) num;
        }
    }

    /**
     * Returns whether {@code n} is a positive power of two.
     *
     * @param n value to test
     * @return {@code true} for 1, 2, 4, ...; {@code false} for zero and all negative values
     */
    public static boolean isPowerOfTwo(int n) {
        // n > 0 excludes Integer.MIN_VALUE, whose single set bit would otherwise pass the mask test
        return n > 0 && (n & (n - 1)) == 0;
    }

    /** 2-D FFT of a purely real matrix; see {@link #fft(MatrixBlock, MatrixBlock, int)}. */
    public static MatrixBlock[] fft(MatrixBlock re, int threads) {
        return fft(re, zeroImaginary(re), threads);
    }

    /** 2-D inverse FFT of a purely real matrix; see {@link #ifft(MatrixBlock, MatrixBlock, int)}. */
    public static MatrixBlock[] ifft(MatrixBlock re, int threads) {
        return ifft(re, zeroImaginary(re), threads);
    }

    /** Row-wise FFT of a purely real matrix; see {@link #fft_linearized(MatrixBlock, MatrixBlock, int)}. */
    public static MatrixBlock[] fft_linearized(MatrixBlock re, int threads) {
        return fft_linearized(re, zeroImaginary(re), threads);
    }

    /** Row-wise inverse FFT of a real matrix; see {@link #ifft_linearized(MatrixBlock, MatrixBlock, int)}. */
    public static MatrixBlock[] ifft_linearized(MatrixBlock re, int threads) {
        return ifft_linearized(re, zeroImaginary(re), threads);
    }

    /** Creates an all-zero dense block with the same dimensions as {@code re}. */
    private static MatrixBlock zeroImaginary(MatrixBlock re) {
        return new MatrixBlock(re.getNumRows(), re.getNumColumns(),
            new double[re.getNumRows() * re.getNumColumns()]);
    }

    /**
     * Read-only twiddle factors for every butterfly stage of a length-{@code num} transform.
     * {@code cos[s][k]} and {@code sin[s][k]} hold {@code FastMath.cos/sin(k * (-2 * PI / subNum))}
     * for stage {@code s}, where {@code subNum = 2 << s} and {@code 0 <= k < subNum / 2}.
     * Instances are never mutated after construction, so worker threads can share them safely.
     */
    private static final class Twiddles {
        final double[][] cos;
        final double[][] sin;

        Twiddles(int num) {
            // Stages run for subNum = 2, 4, ... while subNum <= num, i.e. floor(log2(num)) of them.
            int stages = num < 2 ? 0 : 31 - Integer.numberOfLeadingZeros(num);
            cos = new double[stages][];
            sin = new double[stages][];
            for (int s = 0; s < stages; s++) {
                int subNum = 2 << s;
                double angle = Math.PI * -2 / (double) subNum;
                int half = subNum / 2;
                cos[s] = new double[half];
                sin[s] = new double[half];
                for (int k = 0; k < half; k++) {
                    cos[s][k] = FastMath.cos((double) k * angle);
                    sin[s][k] = FastMath.sin((double) k * angle);
                }
            }
        }
    }
}

