package net.jpountz.lz4;

/*
 * Copyright 2020 Adrien Grand and the lz4-java contributors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import java.nio.ByteBuffer;
import java.util.Arrays;

import static net.jpountz.lz4.LZ4Constants.HASH_LOG;
import static net.jpountz.lz4.LZ4Constants.HASH_LOG_64K;
import static net.jpountz.lz4.LZ4Constants.HASH_LOG_HC;
import static net.jpountz.lz4.LZ4Constants.MIN_MATCH;
import static net.jpountz.lz4.LZ4Constants.ML_MASK;
import static net.jpountz.lz4.LZ4Constants.RUN_MASK;

/**
 * Static helpers used by the LZ4 implementation: output size bounds, sequence length arithmetic,
 * hashing, buffer zeroing and a small mutable {@link Match} holder.
 * <p>
 * Declared as an enum without constants so that it can never be instantiated.
 */
enum LZ4Utils {
  ;

  /** Exclusive upper limit on the input length accepted by {@link #maxCompressedLength(int)}. */
  private static final int MAX_INPUT_SIZE = 0x7E000000;

  /**
   * Multiplier shared by the hash functions; {@code 0x9E3779B1} is 2654435761, i.e. the same bit
   * pattern as the signed int {@code -1640531535}.
   */
  private static final int HASH_MULTIPLIER = 0x9E3779B1;

  static {
    // Run lengths and match lengths go through the same length writer, and
    // lengthOfEncodedInteger only consults RUN_MASK; that is only sound while both masks agree.
    assert RUN_MASK == ML_MASK;
  }

  /**
   * Computes the worst-case output size for compressing {@code length} bytes, namely
   * {@code length + length / 255 + 16}.
   *
   * @param length number of uncompressed input bytes
   * @return the size bound described above
   * @throws IllegalArgumentException if {@code length} is negative, or is not strictly smaller
   *     than {@code 0x7E000000}
   */
  static int maxCompressedLength(int length) {
    if (length < 0) {
      throw new IllegalArgumentException("length must be >= 0, got " + length);
    }
    if (length >= MAX_INPUT_SIZE) {
      throw new IllegalArgumentException("length must be < " + MAX_INPUT_SIZE);
    }
    final int perBlockOverhead = length / 255;
    return length + perBlockOverhead + 16;
  }

  /**
   * Tells whether {@code required} bytes do not fit into {@code available} bytes.
   * <p>
   * A negative {@code required} is interpreted as an overflowed size computation and is always
   * reported as "not enough space". Typical call site:
   * <pre>
   * if (notEnoughSpace(end - off, <i>required</i>)) ...
   * </pre>
   *
   * @param available number of bytes that are available
   * @param required  number of bytes that are needed
   * @return {@code true} if {@code required} is negative or larger than {@code available}
   */
  static boolean notEnoughSpace(int available, int required) {
    return required < 0 || available < required;
  }

  /**
   * Counts the bytes needed to store {@code value} with LZ4's variable-length integer scheme,
   * not counting the 4 bits that live in the shared token byte.
   * <p>
   * Values below {@code RUN_MASK} fit entirely in the token and need no extra byte. Larger values
   * need one byte per full 255 above the mask, plus one final byte. This mirrors exactly what
   * {@link LZ4SafeUtils#writeLen} and its equivalents emit.
   *
   * @param value a run length or match length
   * @return number of bytes following the token that encode {@code value}
   */
  static int lengthOfEncodedInteger(int value) {
    return value < RUN_MASK ? 0 : 1 + (value - RUN_MASK) / 0xff;
  }

  /**
   * Computes the encoded size of one LZ4 sequence: a <i>run</i> of literal bytes copied verbatim
   * from the compressed input, followed by a <i>match</i> that refers back to previously
   * decompressed bytes.
   * <p>
   * Layout of a sequence:
   *
   * <ul>
   *   <li>1 byte: token holding 4 bits each of the run length and the match length</li>
   *   <li>optional extra bytes for the run length</li>
   *   <li>the run bytes themselves</li>
   *   <li>2 bytes: match offset</li>
   *   <li>optional extra bytes for the match length</li>
   * </ul>
   *
   * @param runLen   number of literal bytes in the run
   * @param matchLen match length value to be encoded
   * @return total number of bytes of the encoded sequence
   * @throws LZ4Exception if the total does not fit in an {@code int}
   */
  static int sequenceLength(int runLen, int matchLen) {
    // Summed in long so that the range check below sees the true value.
    final long tokenAndOffsetBytes = 1L + 2L;
    final long total = tokenAndOffsetBytes
        + lengthOfEncodedInteger(runLen)
        + runLen
        + lengthOfEncodedInteger(matchLen);
    if (total <= Integer.MAX_VALUE) {
      return (int) total;
    }
    throw new LZ4Exception("Sequence length too large");
  }

  /** Hashes a 4-byte word into a table of {@code 2^HASH_LOG} slots. */
  static int hash(int i) {
    return multiplicativeHash(i, HASH_LOG);
  }

  /** Hashes a 4-byte word into a table of {@code 2^HASH_LOG_64K} slots. */
  static int hash64k(int i) {
    return multiplicativeHash(i, HASH_LOG_64K);
  }

  /** Hashes a 4-byte word into a table of {@code 2^HASH_LOG_HC} slots. */
  static int hashHC(int i) {
    return multiplicativeHash(i, HASH_LOG_HC);
  }

  /**
   * Multiplies {@code word} by {@link #HASH_MULTIPLIER} and keeps the top {@code hashLog} bits of
   * the {@code MIN_MATCH * 8}-bit product.
   */
  private static int multiplicativeHash(int word, int hashLog) {
    final int shift = (MIN_MATCH * 8) - hashLog;
    return (word * HASH_MULTIPLIER) >>> shift;
  }

  /**
   * Sets every byte of {@code array} in {@code [start, end)} to zero.
   *
   * @param array the array to clear
   * @param start first index to clear (inclusive)
   * @param end   last index to clear (exclusive)
   */
  static void zero(byte[] array, int start, int end) {
    Arrays.fill(array, start, end, (byte) 0);
  }

  /**
   * Sets every byte of {@code bb} in {@code [start, end)} to zero using absolute puts, so the
   * buffer's position is left untouched.
   *
   * @param bb    the buffer to clear
   * @param start first index to clear (inclusive)
   * @param end   last index to clear (exclusive)
   */
  static void zero(ByteBuffer bb, int start, int end) {
    int index = start;
    while (index < end) {
      bb.put(index++, (byte) 0);
    }
  }

  /** Mutable description of a match: where it starts, what it refers to, and how long it is. */
  static class Match {
    /** Position at which the match begins. */
    int start;
    /** Position the match refers to. */
    int ref;
    /** Length of the match. */
    int len;

    /**
     * Advances {@link #start} and {@link #ref} by {@code correction} and shrinks {@link #len} by
     * the same amount, so {@link #end()} stays unchanged.
     */
    void fix(int correction) {
      len -= correction;
      ref += correction;
      start += correction;
    }

    /** Returns the position just past the last matched byte, i.e. {@code start + len}. */
    int end() {
      return len + start;
    }
  }

  /**
   * Copies {@code start}, {@code ref} and {@code len} from {@code src} into {@code dst}.
   *
   * @param src match to read from
   * @param dst match to overwrite
   */
  static void copyTo(Match src, Match dst) {
    dst.start = src.start;
    dst.ref = src.ref;
    dst.len = src.len;
  }

}
