/*
 * Copyright 2024 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.seppiko.commons.utils.crypto;

import java.nio.ByteBuffer;
import java.security.InvalidKeyException;
import java.security.NoSuchAlgorithmException;
import java.security.Provider;
import java.util.Objects;
import javax.crypto.Mac;
import javax.crypto.SecretKey;
import org.seppiko.commons.utils.crypto.spec.KeySpecUtil;

/**
 * A standards-compliant implementation of RFC 5869
 * for HMAC-based Key Derivation Function.
 *
 * <p>
 * HKDF follows the "extract-then-expand" paradigm, where the KDF
 * logically consists of two modules.  The first stage takes the input
 * keying material and "extracts" from it a fixed-length pseudorandom
 * key K.  The second stage "expands" the key K into several additional
 * pseudorandom keys (the output of the KDF).
 * </p>
 *
 * HKDF was first described by Hugo Krawczyk.
 *
 * @see <a href="https://tools.ietf.org/html/rfc5869">RFC 5869</a>
 * @see <a href="https://en.wikipedia.org/wiki/HKDF">Wikipedia: HKDF</a>
 * @see <a href="https://github.com/patrickfav/hkdf/blob/main/src/main/java/at/favre/lib/hkdf/HKDF.java">HKDF.java</a>
 * @author Leonard Woo
 */
public class HKDF {

  private final String algorithm;
  private final Provider provider;

  private HKDF(String algorithm, Provider provider) {
    this.algorithm = algorithm;
    this.provider = provider;
  }

  /**
   * Create a new HKDF instance for the given MAC algorithm, using
   * {@code CryptoUtil.NONPROVIDER} as the provider.
   *
   * @param algorithm name of the MAC algorithm used for HKDF (e.g. {@code HmacSHA256}).
   * @return a new instance of HKDF.
   */
  public static HKDF from(String algorithm) {
    return from(algorithm, CryptoUtil.NONPROVIDER);
  }

  /**
   * Create a new HKDF instance for the given MAC algorithm and security provider.
   *
   * @param algorithm name of the MAC algorithm used for HKDF (e.g. {@code HmacSHA256}).
   * @param provider provider the MAC implementation is requested from.
   * @return a new instance of HKDF.
   */
  public static HKDF from(String algorithm, Provider provider) {
    return new HKDF(algorithm, provider);
  }

  /**
   * <strong>Step 1 of RFC 5869 (Section 2.2)</strong>
   * <p>
   * The first stage takes the input keying material and "extracts" from it a fixed-length
   * pseudorandom key K. The goal of the "extract" stage is to "concentrate" and provide a
   * more uniformly unbiased and higher entropy but smaller output.
   * This is done by utilising the diffusion properties of cryptographic MACs.
   * <p>
   * <strong>About Salts (from RFC 5869):</strong>
   * <blockquote>
   * HKDF is defined to operate with and without random salt.  This is
   * done to accommodate applications where a salt value is not available.
   * We stress, however, that the use of salt adds significantly to the
   * strength of HKDF, ensuring independence between different uses of the
   * hash function, supporting "source-independent" extraction, and
   * strengthening the analytical results that back the HKDF design.
   * </blockquote>
   *
   * @param salt optional salt value (a non-secret random value); may be {@code null} or empty,
   *   in which case it is set to an array of hash length of zeros.
   * @param ikm data to be extracted (Input Keying Material)
   *
   * @return a new byte array pseudo random key (of hash length in bytes) (PRK) which can be used to expand
   * @throws IllegalStateException if the MAC algorithm is not available
   * @throws IllegalArgumentException if the MAC rejects the salt as key
   * @see <a href="https://tools.ietf.org/html/rfc5869#section-2.2">RFC 5869 Section 2.2</a>
   */
  public byte[] extract(byte[] salt, byte[] ikm) {
    // A null/empty salt becomes a null key here so that the extractor substitutes the zero salt.
    return extract(createSecretKey(salt), ikm);
  }

  /**
   * Use this if you require {@link SecretKey} types by your security framework.
   * <p>
   * See {@link #extract(byte[], byte[])} for description.
   *
   * @param salt optional salt value (a non-secret random value); may be {@code null},
   *   in which case it is set to a key of hash length of zeros.
   * @param ikm data to be extracted (Input Keying Material)
   * @return a new byte array pseudo random key (of hash length in bytes) (PRK) which can be used to expand
   * @throws IllegalStateException if the MAC algorithm is not available
   * @throws IllegalArgumentException if the MAC rejects the salt as key
   */
  public byte[] extract(SecretKey salt, byte[] ikm) {
    return new Extractor(algorithm, provider).execute(salt, ikm);
  }

  /**
   * <strong>Step 2 of RFC 5869 (Section 2.3)</strong>
   * <p>
   * To "expand" the generated output of an already reasonably random input such as an existing
   * shared key into a larger cryptographically independent output, thereby producing multiple keys
   * deterministically from that initial shared key, so that the same process may produce those same
   * secret keys safely on multiple devices, as long as the same inputs are used.
   * <p>
   * <strong>About Info (from RFC 5869):</strong>
   * <blockquote>
   * While the 'info' value is optional in the definition of HKDF, it is
   * often of great importance in applications.  Its main objective is to
   * bind the derived key material to application- and context-specific
   * information.  For example, 'info' may contain a protocol number,
   * algorithm identifiers, user identities, etc.  In particular, it may
   * prevent the derivation of the same keying material for different
   * contexts (when the same input key material (IKM) is used in such
   * different contexts).
   * </blockquote>
   *
   * @param pseudoRandomKey a pseudo random key of at least hmac hash length in bytes (usually, the output from the extract step);
   *                        must not be {@code null} or empty
   * @param info            optional context and application specific information; may be null
   * @param outLengthBytes  length of output keying material in bytes
   *                        (must be positive and {@code <= 255 * HashLen})
   * @return new byte array of output keying material (OKM)
   * @throws IllegalArgumentException if {@code pseudoRandomKey} is null or empty, or
   *                                  {@code outLengthBytes} is out of range
   * @throws IllegalStateException if the MAC algorithm is not available or rejects the key
   * @see <a href="https://tools.ietf.org/html/rfc5869#section-2.3">RFC 5869 Section 2.3</a>
   */
  public byte[] expand(byte[] pseudoRandomKey, byte[] info, int outLengthBytes) {
    // Fail with a clear message instead of relying on how the key-spec helper treats bad input.
    if (pseudoRandomKey == null || pseudoRandomKey.length == 0) {
      throw new IllegalArgumentException("provided pseudoRandomKey must not be null or empty");
    }
    return expand(createSecretKey(pseudoRandomKey), info, outLengthBytes);
  }

  /**
   * Use this if you require {@link SecretKey} types by your security framework.
   * <p>
   * See {@link #expand(byte[], byte[], int)} for description.
   *
   * @param pseudoRandomKey a pseudo random key of at least hmac hash length in bytes (usually, the output from the extract step)
   * @param info            optional context and application specific information; may be null
   * @param outLengthBytes  length of output keying material in bytes
   *                        (must be positive and {@code <= 255 * HashLen})
   * @return new byte array of output keying material (OKM)
   * @throws IllegalArgumentException if {@code pseudoRandomKey} is null or
   *                                  {@code outLengthBytes} is out of range
   * @throws IllegalStateException if the MAC algorithm is not available or rejects the key
   */
  public byte[] expand(SecretKey pseudoRandomKey, byte[] info, int outLengthBytes) {
    return new Expander(algorithm, provider).execute(pseudoRandomKey, info, outLengthBytes);
  }

  /**
   * Wraps raw key material into a {@link SecretKey} for this instance's algorithm.
   *
   * @param rawKeyMaterial key bytes; may be {@code null} or empty
   * @return the wrapped key, or {@code null} when no key material was supplied so that
   *   callers can apply their own default (the extractor falls back to the zero salt)
   */
  private SecretKey createSecretKey(byte[] rawKeyMaterial) {
    // NOTE(review): KeySpecUtil.getSecret presumably delegates to SecretKeySpec, which rejects
    // null/empty keys -- guard here so the documented "salt may be null" contract holds.
    if (rawKeyMaterial == null || rawKeyMaterial.length == 0) {
      return null;
    }
    return KeySpecUtil.getSecret(rawKeyMaterial, algorithm);
  }

  /** Implements the "extract" stage: PRK = HMAC-Hash(salt, IKM). */
  static final class Extractor {
    private final String algorithmName;
    private final Provider provider;
    private final Mac mac;

    /**
     * Creates the extractor and eagerly obtains the MAC instance.
     *
     * @param algorithmName name of the MAC algorithm
     * @param provider provider the MAC implementation is requested from
     * @throws IllegalStateException if the MAC algorithm is not available
     */
    Extractor(String algorithmName, Provider provider) {
      this.algorithmName = algorithmName;
      this.provider = provider;
      try {
        this.mac = CryptoUtil.mac(algorithmName, provider);
      } catch (NoSuchAlgorithmException e) {
        throw new IllegalStateException(e);
      }
    }

    /**
     * Step 1 of RFC 5869
     *
     * @param salt optional salt value (a non-secret random value);
     *   if not provided, it is set to an array of hash length of zeros.
     * @param ikm data to be extracted (Input Keying Material)
     * @return a new byte array pseudorandom key (of hash length in bytes) (PRK) which can be used to expand.
     * @throws IllegalArgumentException if the MAC rejects the salt as key
     */
    byte[] execute(SecretKey salt, byte[] ikm) {
      SecretKey key = (salt != null) ? salt : createEmptySecretKey();
      try {
        mac.init(key);
      } catch (InvalidKeyException e) {
        throw new IllegalArgumentException(e);
      }
      return mac.doFinal(ikm);
    }

    /**
     * Builds the default salt used when none is supplied.
     *
     * @return a key consisting of MAC-length zero bytes
     */
    SecretKey createEmptySecretKey() {
      int macLength = mac.getMacLength();
      return KeySpecUtil.getSecret(new byte[macLength], algorithmName);
    }
  }

  /** Implements the "expand" stage: stretches a PRK into the requested amount of OKM. */
  static final class Expander {
    private final String algorithmName;
    private final Provider provider;

    Expander(String algorithmName, Provider provider) {
      this.algorithmName = algorithmName;
      this.provider = provider;
    }

    /**
     * Step 2 of RFC 5869.
     *
     * @param pseudoRandomKey a pseudorandom key of at least hmac hash length in bytes (usually, the output from the extract step)
     * @param info            optional context and application specific information; may be null
     * @param outLengthBytes  length of output keying material in bytes
     *                        (must be positive and {@code <= 255 * HashLen})
     * @return new byte array of output keying material (OKM)
     * @throws IllegalArgumentException if {@code pseudoRandomKey} is null or
     *                                  {@code outLengthBytes} is out of range
     * @throws IllegalStateException if the MAC algorithm is not available or rejects the key
     */
    byte[] execute(SecretKey pseudoRandomKey, byte[] info, int outLengthBytes) {
      if (Objects.isNull(pseudoRandomKey)) {
        throw new IllegalArgumentException("provided pseudoRandomKey must not be null");
      }
      if (null == info) {
        info = new byte[0];
      }
      if (outLengthBytes < 1) {
        throw new IllegalArgumentException("out length bytes must be positive");
      }

      Mac mac;
      try {
        mac = CryptoUtil.mac(algorithmName, provider);
        mac.init(pseudoRandomKey);
      } catch (NoSuchAlgorithmException | InvalidKeyException e) {
        // Kept as IllegalStateException for both causes; callers may rely on this type.
        throw new IllegalStateException(e);
      }

      int hashLen = mac.getMacLength();
      if (outLengthBytes > (255 * hashLen)) {
        throw new IllegalArgumentException(
            "outLengthBytes must be less than or equal to 255*HashLen");
      }

      /*
      The output OKM is calculated as follows:
        N = ceil(L/HashLen)
        T = T(1) | T(2) | T(3) | ... | T(N)
        OKM = first L bytes of T
      where:
        T(0) = empty string (zero length)
        T(1) = HMAC-Hash(PRK, T(0) | info | 0x01)
        T(2) = HMAC-Hash(PRK, T(1) | info | 0x02)
        T(3) = HMAC-Hash(PRK, T(2) | info | 0x03)
        ...
       */

      // N = ceil(L / HashLen); cannot overflow because L <= 255 * HashLen was checked above.
      int rounds = (outLengthBytes + hashLen - 1) / hashLen;

      ByteBuffer okm = ByteBuffer.allocate(outLengthBytes);
      byte[] block = new byte[0]; // T(0)
      for (int round = 1; round <= rounds; round++) {
        // Stream T(round-1) | info | counter into the MAC; doFinal() leaves it re-initialised
        // with the same key, so no explicit reset is needed between rounds.
        mac.update(block);
        mac.update(info);
        mac.update((byte) round);
        block = mac.doFinal();
        // The last block may only be needed partially ("first L bytes of T").
        okm.put(block, 0, Math.min(block.length, okm.remaining()));
      }

      return okm.array();
    }
  }
}
