/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.ml.feature.hashingtf;

import java.io.IOException;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.feature.hashingtf.HashingTFParams;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.shaded.guava30.com.google.common.hash.HashFunction;
import org.apache.flink.shaded.guava30.com.google.common.hash.Hashing;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;

/**
 * A Transformer that maps each row's collection of terms to a sparse term-frequency vector.
 *
 * <p>Each term in the input column is hashed with 32-bit murmur3 (seed 0), and the hash is
 * reduced to an index in {@code [0, numFeatures)} with a non-negative modulo. The output vector
 * has size {@code numFeatures}. In non-binary mode each index holds the number of terms that
 * hashed to it. In binary mode each index that was hit holds 1.0. The vector is appended to the
 * input row as a new column named {@code getOutputCol()}.
 */
public class HashingTF implements Transformer<HashingTF>, HashingTFParams<HashingTF> {
    private final Map<Param<?>, Object> paramMap = new HashMap<>();

    /** Hash function shared by all terms; the fixed seed keeps indices stable across runs. */
    private static final HashFunction HASH_FUNC = Hashing.murmur3_32(0);

    /** Creates a HashingTF whose parameters are initialized to their declared defaults. */
    public HashingTF() {
        ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
    }

    /**
     * Appends the hashed term-frequency vector column to the single input table.
     *
     * @param inputs exactly one table containing the column named by {@code getInputCol()}
     * @return a one-element array holding the input columns plus a sparse-vector column named
     *     {@code getOutputCol()}
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one table
     */
    @Override
    public Table[] transform(Table... inputs) {
        Preconditions.checkArgument(inputs.length == 1, "HashingTF expects exactly one input table.");
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        ResolvedSchema tableSchema = inputs[0].getResolvedSchema();
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(tableSchema);

        // Output schema = input schema + one trailing sparse-vector column.
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
                        ArrayUtils.addAll(
                                inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));

        DataStream<Row> output =
                tEnv.toDataStream(inputs[0])
                        .map(
                                new HashTFFunction(getInputCol(), getBinary(), getNumFeatures()),
                                outputTypeInfo);
        return new Table[] {tEnv.fromDataStream(output)};
    }

    /** Persists this transformer's metadata to {@code path} via {@link ReadWriteUtils#saveMetadata}. */
    @Override
    public void save(String path) throws IOException {
        ReadWriteUtils.saveMetadata(this, path);
    }

    @Override
    public Map<Param<?>, Object> getParamMap() {
        return this.paramMap;
    }

    /**
     * Loads a HashingTF previously written by {@link #save(String)}.
     *
     * @param tEnv not used by this implementation
     * @param path directory the stage was saved to
     */
    public static HashingTF load(StreamTableEnvironment tEnv, String path) throws IOException {
        return (HashingTF) ReadWriteUtils.loadStageParam(path);
    }

    /**
     * Hashes a single term with murmur3_32.
     *
     * <p>{@code null} hashes to 0. Booleans hash as the int 1 or 0. Byte, Short and Integer hash
     * their int value. Long hashes its long value. Float and Double hash their IEEE-754 bit
     * patterns. Strings hash their UTF-16 code units.
     *
     * @throws UnsupportedOperationException for any other element type
     */
    private static int hash(Object obj) {
        if (obj == null) {
            return 0;
        }
        if (obj instanceof Boolean) {
            return HASH_FUNC.hashInt((Boolean) obj ? 1 : 0).asInt();
        }
        if (obj instanceof Byte || obj instanceof Short || obj instanceof Integer) {
            // Widening to int gives the same hash input as the per-type branches would.
            return HASH_FUNC.hashInt(((Number) obj).intValue()).asInt();
        }
        if (obj instanceof Long) {
            return HASH_FUNC.hashLong((Long) obj).asInt();
        }
        if (obj instanceof Float) {
            return HASH_FUNC.hashInt(Float.floatToIntBits((Float) obj)).asInt();
        }
        if (obj instanceof Double) {
            return HASH_FUNC.hashLong(Double.doubleToLongBits((Double) obj)).asInt();
        }
        if (obj instanceof String) {
            return HASH_FUNC.hashUnencodedChars((String) obj).asInt();
        }
        throw new UnsupportedOperationException(
                "HashingTF does not support type "
                        + obj.getClass().getCanonicalName()
                        + " of input data.");
    }

    /**
     * Returns {@code x mod mod} shifted into {@code [0, mod)}.
     *
     * <p>Java's {@code %} keeps the sign of {@code x}. This assumes {@code mod > 0}.
     */
    private static int nonNegativeMod(int x, int mod) {
        int rawMod = x % mod;
        return rawMod < 0 ? rawMod + mod : rawMod;
    }

    /** Row mapper that appends the hashed term-frequency vector of {@code inputCol}. */
    public static class HashTFFunction implements MapFunction<Row, Row> {
        private final String inputCol;
        private final boolean binary;
        private final int numFeatures;

        public HashTFFunction(String inputCol, boolean binary, int numFeatures) {
            this.inputCol = inputCol;
            this.binary = binary;
            this.numFeatures = numFeatures;
        }

        /**
         * Builds the sparse term-frequency vector for one row and appends it as a new field.
         *
         * @throws NullPointerException if the input column is null for this row
         * @throws IllegalArgumentException if the input column is neither an array nor an Iterable
         * @throws UnsupportedOperationException if an element has a type that cannot be hashed
         */
        @Override
        public Row map(Row row) throws Exception {
            Object inputObj = row.getField(this.inputCol);
            if (inputObj == null) {
                throw new NullPointerException(
                        "Input column " + this.inputCol + " contains a null value.");
            }

            Map<Integer, Integer> counts = new HashMap<>();
            if (inputObj instanceof Object[]) {
                for (Object term : (Object[]) inputObj) {
                    accumulate(counts, term);
                }
            } else if (inputObj.getClass().isArray()) {
                // Primitive arrays (int[], double[], ...) cannot be cast to Object[]. Read them
                // reflectively; Array.get boxes each element into its wrapper type.
                int length = Array.getLength(inputObj);
                for (int i = 0; i < length; i++) {
                    accumulate(counts, Array.get(inputObj, i));
                }
            } else if (inputObj instanceof Iterable) {
                for (Object term : (Iterable<?>) inputObj) {
                    accumulate(counts, term);
                }
            } else {
                throw new IllegalArgumentException(
                        "Input format "
                                + inputObj.getClass().getCanonicalName()
                                + " is not supported for input column "
                                + this.inputCol
                                + ". Supported options are Array and Iterable.");
            }

            int[] indices = new int[counts.size()];
            double[] values = new double[counts.size()];
            int idx = 0;
            for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
                indices[idx] = entry.getKey();
                values[idx] = entry.getValue();
                idx++;
            }
            return Row.join(row, Row.of(Vectors.sparse(this.numFeatures, indices, values)));
        }

        /** Hashes {@code term} to a bucket and records it in {@code counts}. */
        private void accumulate(Map<Integer, Integer> counts, Object term) {
            int index = HashingTF.nonNegativeMod(HashingTF.hash(term), this.numFeatures);
            if (this.binary) {
                // Binary mode only records presence, so repeated hits leave the value at 1.
                counts.put(index, 1);
            } else {
                counts.merge(index, 1, Integer::sum);
            }
        }
    }
}

