/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.cpu.nativecpu.compression;

import org.apache.commons.math3.util.FastMath;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.compression.impl.AbstractCompressor;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.DataTypeEx;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.compression.CompressedDataBuffer;
import org.nd4j.linalg.compression.CompressionDescriptor;
import org.nd4j.linalg.compression.CompressionType;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * CPU implementation of the ND4J "THRESHOLD" compressor.
 *
 * <p>Only elements whose absolute value is greater than or equal to {@link #threshold} are
 * counted for encoding. The compressed buffer is a block of {@code int}s: a {@value #HEADER_LENGTH}-int
 * header followed by one {@code int} slot per encoded element. The element encoding itself is
 * delegated to {@code NDArrayFactory#convertDataEx(..., DataTypeEx.THRESHOLD, ...)}.
 */
public class CpuThreshold
extends AbstractCompressor {
    private static final Logger log = LoggerFactory.getLogger(CpuThreshold.class);

    /** Number of {@code int} slots reserved for the header of a compressed buffer. */
    private static final int HEADER_LENGTH = 4;

    /**
     * Magnitude at or above which an element is encoded. {@link #configure} and
     * {@link #setThreshold} both store the absolute value of their input.
     */
    protected float threshold = 0.001f;

    /**
     * @return the algorithm name recorded in every {@link CompressionDescriptor} this class creates
     */
    public String getDescriptor() {
        return "THRESHOLD";
    }

    /**
     * Configures the threshold from the first configuration argument.
     *
     * @param vars configuration values; {@code vars[0]} must be a {@link Number}, whose absolute
     *             value becomes the new threshold
     * @throws ND4JIllegalStateException if no argument is supplied or the first one is not a Number
     */
    public void configure(Object ... vars) {
        // Guard the empty/null case too, so callers always get the same exception type.
        if (vars == null || vars.length == 0 || !(vars[0] instanceof Number)) {
            throw new ND4JIllegalStateException("Threshold value should be Number");
        }
        this.threshold = FastMath.abs(((Number) vars[0]).floatValue());
        log.info("Setting threshold to [{}]", this.threshold);
    }

    /**
     * Compresses the contents of {@code array}.
     *
     * <p>Pending ops are committed and the array is moved to the host before its data buffer is
     * encoded via {@link #compress(DataBuffer)}.
     *
     * @param array array to compress
     * @return a new array marked as compressed that reuses {@code array}'s shape information, or
     *         {@code null} when {@link #compress(DataBuffer)} declines to compress the data
     */
    public INDArray compress(INDArray array) {
        Nd4j.getExecutioner().commit();
        Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.HOST);
        DataBuffer compressed = this.compress(array.data());
        if (compressed == null) {
            return null;
        }
        INDArray dup = Nd4j.createArrayFromShapeBuffer(compressed, array.shapeInfoDataBuffer());
        dup.markAsCompressed(true);
        return dup;
    }

    /**
     * @return the compression type recorded in every descriptor this class creates
     */
    public CompressionType getCompressionType() {
        // NOTE(review): elements below the threshold are not represented in the output
        // (the compressed size is HEADER_LENGTH + count-above-threshold ints), which looks lossy.
        // LOSSLESS is kept unchanged because it is written into every CompressionDescriptor.
        return CompressionType.LOSSLESS;
    }

    /**
     * Decodes a THRESHOLD-encoded buffer.
     *
     * @param buffer   buffer previously produced by {@link #compress(DataBuffer)}
     * @param dataType requested result type; NOTE(review): currently ignored, the result type is
     *                 taken from {@code getGlobalTypeEx()} instead. Confirm this is intended.
     * @return the decoded buffer
     */
    public DataBuffer decompress(DataBuffer buffer, DataType dataType) {
        return Nd4j.getNDArrayFactory().convertDataEx(DataTypeEx.THRESHOLD, buffer, this.getGlobalTypeEx());
    }

    /**
     * Threshold-encodes {@code buffer}.
     *
     * <p>The buffer is viewed as a {@code [1, length]} row to count the elements whose magnitude
     * is at least {@link #threshold}. If fewer than two qualify, nothing is encoded. Otherwise an
     * {@code int} buffer is allocated with this header:
     * <ul>
     *   <li>[0] number of encoded elements</li>
     *   <li>[1] original number of elements</li>
     *   <li>[2] threshold as raw float bits</li>
     *   <li>[3] written as 0 here (meaning defined by the encoder; TODO confirm)</li>
     * </ul>
     * The encoder then fills the buffer, and {@code buffer} is tagged as living on the host.
     *
     * @param buffer data to compress
     * @return the compressed buffer, or {@code null} if fewer than two elements reach the threshold
     * @throws ND4JIllegalStateException if {@code buffer} has more than {@link Integer#MAX_VALUE}
     *         elements, because the header stores the length as an {@code int}
     */
    public DataBuffer compress(DataBuffer buffer) {
        long numElements = buffer.length();
        DataBuffer rowShape = (DataBuffer) Nd4j.getShapeInfoProvider()
                .createShapeInformation(new long[]{1L, numElements}, buffer.dataType()).getFirst();
        INDArray temp = Nd4j.createArrayFromShapeBuffer(buffer, rowShape);
        MatchCondition condition = new MatchCondition(temp, Conditions.absGreaterThanOrEqual((Number) Float.valueOf(this.threshold)), new int[0]);
        int cntAbs = Nd4j.getExecutioner().exec((ReduceOp) condition).getInt(new int[]{0});
        if (cntAbs < 2) {
            // Presumably too sparse to be worth encoding; callers receive null.
            return null;
        }
        // Header slot [1] is an int: refuse instead of silently truncating the length.
        if (numElements > Integer.MAX_VALUE) {
            throw new ND4JIllegalStateException("THRESHOLD compression supports at most "
                    + Integer.MAX_VALUE + " elements, got " + numElements);
        }
        long elementSize = Nd4j.sizeOfDataType(buffer.dataType());
        // long arithmetic: HEADER_LENGTH + cntAbs ints, times 4 bytes, can exceed int range.
        long compressedInts = HEADER_LENGTH + (long) cntAbs;
        IntPointer pointer = new IntPointer(compressedInts);
        pointer.put(0L, cntAbs);
        pointer.put(1L, (int) numElements);
        pointer.put(2L, Float.floatToIntBits(this.threshold));
        pointer.put(3L, 0);
        CompressionDescriptor descriptor = new CompressionDescriptor();
        descriptor.setCompressedLength(compressedInts * Integer.BYTES);
        descriptor.setOriginalLength(numElements * elementSize);
        descriptor.setOriginalElementSize(elementSize);
        descriptor.setNumberOfElements(numElements);
        descriptor.setCompressionAlgorithm(this.getDescriptor());
        descriptor.setCompressionType(this.getCompressionType());
        CompressedDataBuffer cbuff = new CompressedDataBuffer((Pointer) pointer, descriptor);
        Nd4j.getNDArrayFactory().convertDataEx(CpuThreshold.getBufferTypeEx(buffer), buffer.addressPointer(), DataTypeEx.THRESHOLD, (Pointer) pointer, numElements);
        Nd4j.getAffinityManager().tagLocation(buffer, AffinityManager.Location.HOST);
        return cbuff;
    }

    /**
     * Not supported by this compressor; use {@link #compress(DataBuffer)} instead.
     *
     * @throws UnsupportedOperationException always
     */
    protected CompressedDataBuffer compressPointer(DataTypeEx srcType, Pointer srcPointer, int length, int elementSize) {
        throw new UnsupportedOperationException("Pointer-based compression is not supported by " + this.getDescriptor());
    }

    /**
     * @return the current threshold
     */
    public float getThreshold() {
        return this.threshold;
    }

    /**
     * Sets the threshold. Like {@link #configure}, only the magnitude is kept. A negative value
     * would otherwise make the {@code |x| >= threshold} test match every element and would be
     * written verbatim into the compressed header.
     *
     * @param threshold new threshold; its absolute value is stored
     */
    public void setThreshold(float threshold) {
        this.threshold = FastMath.abs(threshold);
    }
}

