/*
 * Decompiled with CFR 0.152.
 */
package org.apache.shiro.crypto.support.hashes.argon2;

import java.security.SecureRandom;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashSet;
import java.util.Locale;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;
import java.util.regex.Pattern;
import org.apache.shiro.crypto.hash.AbstractCryptHash;
import org.apache.shiro.lang.util.ByteSource;
import org.apache.shiro.lang.util.SimpleByteSource;
import org.bouncycastle.crypto.generators.Argon2BytesGenerator;
import org.bouncycastle.crypto.params.Argon2Parameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Argon2 hash backed by BouncyCastle's {@link Argon2BytesGenerator}.
 *
 * <p>Besides the hash bytes and salt held by {@link AbstractCryptHash}, each instance keeps the
 * Argon2 version, iteration count, memory cost (KiB) and parallelism that produced the hash.
 * {@link #matchesPassword(ByteSource)} uses them to recompute the hash, and
 * {@link #formatToCryptString()} writes them out in the
 * {@code $<algorithm>$v=<version>$t=<iterations>,m=<memoryKiB>,p=<parallelism>$<salt>$<hash>}
 * form that {@link #fromString(String)} parses.
 */
class Argon2Hash
extends AbstractCryptHash {
    /** Parallelism used by {@link #generate(String, ByteSource, ByteSource, int)}. */
    public static final int DEFAULT_PARALLELISM = 4;
    /** Output length, in bits, used by {@link #generate(String, ByteSource, ByteSource, int)}. */
    public static final int DEFAULT_OUTPUT_LENGTH_BITS = 256;
    /** Algorithm used by {@link #generate(ByteSource, ByteSource, int)}. */
    public static final String DEFAULT_ALGORITHM_NAME = "argon2id";
    /** Argon2 version used by {@link #generate(String, ByteSource, ByteSource, int)} (19 == 0x13). */
    public static final int DEFAULT_ALGORITHM_VERSION = 19;
    /** Iteration count used by {@link #generate(char[])}. */
    public static final int DEFAULT_ITERATIONS = 1;
    /** Memory cost in KiB used by {@link #generate(String, ByteSource, ByteSource, int)}. */
    public static final int DEFAULT_MEMORY_KIB = 65536;
    private static final long serialVersionUID = 2647354947284558921L;
    private static final Logger LOG = LoggerFactory.getLogger(Argon2Hash.class);
    private static final Set<String> ALGORITHMS_ARGON2 = new HashSet<String>(Arrays.asList("argon2id", "argon2i", "argon2d"));
    private static final Pattern DELIMITER_COMMA = Pattern.compile(",");
    private static final int SALT_LENGTH_BITS = 128;
    private static final int SALT_LENGTH_BYTES = SALT_LENGTH_BITS / 8;
    /** Number of fields in a crypt string: algorithm, version, parameters, salt, hash. */
    private static final int CRYPT_STRING_FIELD_COUNT = 5;
    private final int argonVersion;
    private final int iterations;
    private final int memoryKiB;
    private final int parallelism;

    /**
     * Wraps an already computed Argon2 hash together with the parameters that produced it.
     *
     * @param algorithmName expected to be one of {@link #getAlgorithmsArgon2()}
     * @param argonVersion  Argon2 version number (e.g. 19)
     * @param hashedData    the raw Argon2 output
     * @param salt          the salt that was fed to Argon2
     * @param iterations    Argon2 iteration (time cost) parameter; must be {@code >= 1}
     * @param memoryAsKB    Argon2 memory cost in KiB
     * @param parallelism   Argon2 parallelism parameter
     * @throws IllegalArgumentException if {@code iterations < 1}
     */
    Argon2Hash(String algorithmName, int argonVersion, byte[] hashedData, ByteSource salt, int iterations, int memoryAsKB, int parallelism) {
        super(algorithmName, hashedData, salt);
        this.argonVersion = argonVersion;
        this.iterations = iterations;
        this.memoryKiB = memoryAsKB;
        this.parallelism = parallelism;
        // Validate only after the fields are assigned, so getIterations() sees the real value.
        this.checkValidIterations();
    }

    /**
     * Returns the algorithm names this class accepts.
     *
     * @return an unmodifiable view of {@code argon2id}, {@code argon2i} and {@code argon2d}
     */
    public static Set<String> getAlgorithmsArgon2() {
        return Collections.unmodifiableSet(ALGORITHMS_ARGON2);
    }

    /**
     * Creates a random 128-bit salt using a fresh {@link SecureRandom}.
     *
     * @return a new salt of {@link #getSaltLength()} bytes
     */
    protected static ByteSource createSalt() {
        return Argon2Hash.createSalt(new SecureRandom());
    }

    /**
     * Creates a random 128-bit salt from the given generator.
     *
     * @param random the source of randomness
     * @return a new salt of 16 bytes
     * @throws NullPointerException if {@code random} is null
     */
    public static ByteSource createSalt(SecureRandom random) {
        byte[] salt = new byte[SALT_LENGTH_BYTES];
        // Use nextBytes rather than generateSeed. generateSeed is meant for seeding other
        // generators and may block while waiting for entropy; a salt only needs random bytes.
        Objects.requireNonNull(random, "random").nextBytes(salt);
        return new SimpleByteSource(salt);
    }

    /**
     * Parses a crypt string of the form
     * {@code $<algorithm>$v=<version>$t=<iterations>,m=<memoryKiB>,p=<parallelism>$<salt>$<hash>}.
     * The salt and hash fields are decoded with Shiro's Base64 codec.
     *
     * @param input the crypt string
     * @return the parsed hash
     * @throws NullPointerException          if {@code input} is null
     * @throws UnsupportedOperationException if {@code input} does not start with {@code $} or
     *                                       names an unsupported algorithm
     * @throws IllegalArgumentException      if fields are missing or a parameter is missing/malformed
     */
    public static Argon2Hash fromString(String input) {
        if (!Objects.requireNonNull(input, "input").startsWith("$")) {
            throw new UnsupportedOperationException("Unsupported input: " + input);
        }
        String[] parts = AbstractCryptHash.DELIMITER.split(input.substring(1));
        String algorithmName = parts[0].trim();
        if (!ALGORITHMS_ARGON2.contains(algorithmName)) {
            throw new UnsupportedOperationException("Unsupported algorithm: " + algorithmName + ". Expected one of: " + String.valueOf(ALGORITHMS_ARGON2));
        }
        // Without this check a truncated string fails later with an ArrayIndexOutOfBoundsException.
        if (parts.length < CRYPT_STRING_FIELD_COUNT) {
            throw new IllegalArgumentException(String.format(Locale.ENGLISH,
                    "Expected %d fields (algorithm, version, parameters, salt, hash), but got %d.",
                    CRYPT_STRING_FIELD_COUNT, parts.length));
        }
        int version = Argon2Hash.parseVersion(parts[1]);
        String parameters = parts[2];
        int memoryKiB = Argon2Hash.parseMemory(parameters);
        int iterations = Argon2Hash.parseIterations(parameters);
        int parallelism = Argon2Hash.parseParallelism(parameters);
        SimpleByteSource salt = new SimpleByteSource(org.apache.shiro.lang.codec.Base64.decode(parts[3]));
        byte[] hashedData = org.apache.shiro.lang.codec.Base64.decode(parts[4]);
        return new Argon2Hash(algorithmName, version, hashedData, salt, iterations, memoryKiB, parallelism);
    }

    /** Parses the {@code p=} (parallelism) entry from a comma-separated parameter list. */
    private static int parseParallelism(String parameters) {
        return Argon2Hash.parseParameter(parameters, "p=", "parallelism");
    }

    /** Parses the {@code t=} (iterations) entry from a comma-separated parameter list. */
    private static int parseIterations(String parameters) {
        return Argon2Hash.parseParameter(parameters, "t=", "iterations");
    }

    /** Parses the {@code m=} (memory in KiB) entry from a comma-separated parameter list. */
    private static int parseMemory(String parameters) {
        return Argon2Hash.parseParameter(parameters, "m=", "memory");
    }

    /**
     * Returns the integer value of the first comma-separated entry in {@code parameters} that
     * starts with {@code prefix}.
     *
     * @throws IllegalArgumentException if no such entry exists or its value is not an integer
     */
    private static int parseParameter(String parameters, String prefix, String description) {
        String parameter = DELIMITER_COMMA.splitAsStream(parameters)
                .filter(parm -> parm.startsWith(prefix))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException(
                        "Did not find " + description + " parameter '" + prefix + "'. Got: [" + parameters + "]."));
        return Integer.parseInt(parameter.substring(prefix.length()));
    }

    /**
     * Parses a {@code v=<int>} version field.
     *
     * @throws IllegalArgumentException if the field lacks the {@code v=} prefix or is not an integer
     */
    private static int parseVersion(String part) {
        if (!part.startsWith("v=")) {
            throw new IllegalArgumentException("Did not find version parameter 'v='. Got: [" + part + "].");
        }
        return Integer.parseInt(part.substring(2));
    }

    /**
     * Hashes {@code source} with {@link #DEFAULT_ALGORITHM_NAME}, a fresh random salt and
     * {@link #DEFAULT_ITERATIONS}.
     *
     * @param source the password characters
     * @return the resulting hash
     */
    public static Argon2Hash generate(char[] source) {
        return Argon2Hash.generate(new SimpleByteSource(source), Argon2Hash.createSalt(), DEFAULT_ITERATIONS);
    }

    /**
     * Hashes {@code source} with {@link #DEFAULT_ALGORITHM_NAME} and default Argon2 parameters.
     *
     * @param source     the data to hash
     * @param salt       the salt; must not be null
     * @param iterations Argon2 iteration (time cost) parameter
     * @return the resulting hash
     * @throws NullPointerException if {@code salt} is null
     */
    public static Argon2Hash generate(ByteSource source, ByteSource salt, int iterations) {
        return Argon2Hash.generate(DEFAULT_ALGORITHM_NAME, source, Objects.requireNonNull(salt, "salt"), iterations);
    }

    /**
     * Hashes {@code source} with the given algorithm, {@link #DEFAULT_ALGORITHM_VERSION},
     * {@link #DEFAULT_MEMORY_KIB}, {@link #DEFAULT_PARALLELISM} and
     * {@link #DEFAULT_OUTPUT_LENGTH_BITS}.
     *
     * @param algorithmName {@code argon2id}, {@code argon2i}, {@code argon2d} or {@code argon2}
     * @param source        the data to hash
     * @param salt          the salt
     * @param iterations    Argon2 iteration (time cost) parameter
     * @return the resulting hash
     */
    public static Argon2Hash generate(String algorithmName, ByteSource source, ByteSource salt, int iterations) {
        return Argon2Hash.generate(algorithmName, DEFAULT_ALGORITHM_VERSION, source, salt, iterations,
                DEFAULT_MEMORY_KIB, DEFAULT_PARALLELISM, DEFAULT_OUTPUT_LENGTH_BITS);
    }

    /**
     * Computes an Argon2 hash with fully explicit parameters.
     *
     * <p>NOTE(review): {@code "argon2"} is accepted here as an alias for argon2id, but the name is
     * stored unchanged and {@link #checkValidAlgorithm()} rejects it. Confirm whether the superclass
     * constructor invokes that check; if it does, this alias always fails.
     *
     * @param algorithmName    {@code argon2id}, {@code argon2i}, {@code argon2d} or {@code argon2}
     * @param argonVersion     Argon2 version number
     * @param source           the data to hash
     * @param salt             the salt
     * @param iterations       Argon2 iteration (time cost) parameter
     * @param memoryAsKB       Argon2 memory cost in KiB
     * @param parallelism      Argon2 parallelism parameter
     * @param outputLengthBits hash length in bits; integer-divided by 8 to get the byte count
     * @return the resulting hash
     * @throws NullPointerException     if {@code algorithmName}, {@code source} or {@code salt} is null
     * @throws IllegalArgumentException if {@code algorithmName} is not a known Argon2 variant
     */
    public static Argon2Hash generate(String algorithmName, int argonVersion, ByteSource source, ByteSource salt, int iterations, int memoryAsKB, int parallelism, int outputLengthBits) {
        int type;
        switch (Objects.requireNonNull(algorithmName, "algorithmName")) {
            case "argon2i": {
                type = Argon2Parameters.ARGON2_i;
                break;
            }
            case "argon2d": {
                type = Argon2Parameters.ARGON2_d;
                break;
            }
            case "argon2":
            case "argon2id": {
                type = Argon2Parameters.ARGON2_id;
                break;
            }
            default: {
                throw new IllegalArgumentException("Unknown argon2 algorithm: " + algorithmName);
            }
        }
        Argon2Parameters parameters = new Argon2Parameters.Builder(type)
                .withVersion(argonVersion)
                .withIterations(iterations)
                .withParallelism(parallelism)
                .withSalt(Objects.requireNonNull(salt, "salt").getBytes())
                .withMemoryAsKB(memoryAsKB)
                .build();
        byte[] password = Objects.requireNonNull(source, "source").getBytes();
        Argon2BytesGenerator gen = new Argon2BytesGenerator();
        gen.init(parameters);
        byte[] hash = new byte[outputLengthBits / 8];
        gen.generateBytes(password, hash);
        return new Argon2Hash(algorithmName, argonVersion, hash, new SimpleByteSource(salt), iterations, memoryAsKB, parallelism);
    }

    /**
     * Rejects algorithm names that are not in {@link #getAlgorithmsArgon2()}.
     *
     * @throws IllegalArgumentException if the algorithm name is not supported
     */
    @Override
    protected void checkValidAlgorithm() {
        if (!ALGORITHMS_ARGON2.contains(this.getAlgorithmName())) {
            String message = String.format(Locale.ENGLISH, "Given algorithm name [%s] not valid for argon2. Valid algorithms: [%s].", this.getAlgorithmName(), ALGORITHMS_ARGON2);
            throw new IllegalArgumentException(message);
        }
    }

    /**
     * Rejects an iteration count below 1.
     *
     * @throws IllegalArgumentException if {@link #getIterations()} is less than 1
     */
    protected void checkValidIterations() {
        int iterations = this.getIterations();
        if (iterations < 1) {
            String message = String.format(Locale.ENGLISH, "Expected argon2 iterations >= 1, but was [%d].", iterations);
            throw new IllegalArgumentException(message);
        }
    }

    /** @return the Argon2 iteration (time cost) parameter */
    @Override
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Re-hashes {@code plaintextBytes} with this hash's algorithm, version, salt, iterations,
     * memory, parallelism and output length, then compares the result with {@link #equals(Object)}.
     *
     * <p>NOTE(review): whether the byte comparison inside {@code super.equals} is constant-time is
     * not visible here; confirm it uses {@code MessageDigest.isEqual} or similar.
     *
     * @param plaintextBytes the candidate password
     * @return true if the recomputed hash equals this one; false if it differs or the parameters
     *         cannot be reused (the {@link IllegalArgumentException} is logged at WARN)
     */
    @Override
    public boolean matchesPassword(ByteSource plaintextBytes) {
        try {
            Argon2Hash compare = Argon2Hash.generate(this.getAlgorithmName(), this.argonVersion, plaintextBytes, this.getSalt(), this.getIterations(), this.memoryKiB, this.parallelism, this.getBytes().length * 8);
            return this.equals(compare);
        }
        catch (IllegalArgumentException illegalArgumentException) {
            LOG.warn("Cannot recreate a hash using the same parameters.", illegalArgumentException);
            return false;
        }
    }

    /** @return the salt length in bytes (16) */
    @Override
    public int getSaltLength() {
        return SALT_LENGTH_BYTES;
    }

    /**
     * Serialises this hash as
     * {@code $<algorithm>$v=<version>$t=<iterations>,m=<memoryKiB>,p=<parallelism>$<salt>$<hash>}.
     * Salt and hash are encoded as standard Base64 without padding.
     *
     * @return the crypt string
     */
    @Override
    public String formatToCryptString() {
        Base64.Encoder encoder = Base64.getEncoder().withoutPadding();
        String saltBase64 = encoder.encodeToString(this.getSalt().getBytes());
        String dataBase64 = encoder.encodeToString(this.getBytes());
        return new StringJoiner("$", "$", "")
                .add(this.getAlgorithmName())
                .add("v=" + this.argonVersion)
                .add(this.formatParameters())
                .add(saltBase64)
                .add(dataBase64)
                .toString();
    }

    /** Formats the cost parameters as {@code t=<iterations>,m=<memoryKiB>,p=<parallelism>}. */
    private CharSequence formatParameters() {
        return String.format(Locale.ENGLISH, "t=%d,m=%d,p=%d", this.getIterations(), this.getMemoryKiB(), this.getParallelism());
    }

    /** @return the Argon2 memory cost in KiB */
    public int getMemoryKiB() {
        return this.memoryKiB;
    }

    /** @return the Argon2 parallelism parameter */
    public int getParallelism() {
        return this.parallelism;
    }

    /** @return the Argon2 version number */
    public int getArgonVersion() {
        return this.argonVersion;
    }

    /**
     * Two hashes are equal when {@code super.equals} holds and all Argon2 parameters
     * (version, iterations, memory, parallelism) match. Requires the exact same runtime class.
     */
    @Override
    public boolean equals(Object other) {
        if (this == other) {
            return true;
        }
        if (other == null || this.getClass() != other.getClass()) {
            return false;
        }
        if (!super.equals(other)) {
            return false;
        }
        Argon2Hash that = (Argon2Hash)other;
        return this.argonVersion == that.argonVersion && this.iterations == that.iterations && this.memoryKiB == that.memoryKiB && this.parallelism == that.parallelism;
    }

    /** Combines {@code super.hashCode()} with the Argon2 parameters compared in {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), this.argonVersion, this.iterations, this.memoryKiB, this.parallelism);
    }

    @Override
    public String toString() {
        return new StringJoiner(", ", Argon2Hash.class.getSimpleName() + "[", "]").add("super=" + super.toString()).add("version=" + this.argonVersion).add("iterations=" + this.iterations).add("memoryKiB=" + this.memoryKiB).add("parallelism=" + this.parallelism).toString();
    }
}

