/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.nn.Activation;
import ai.djl.training.loss.Loss;

/**
 * Binary cross-entropy loss that accepts either raw logits or sigmoid probabilities.
 *
 * <p>When {@code fromSigmoid} is {@code false} (the default), predictions {@code x} are treated
 * as logits and the per-element loss is computed with the numerically stable form
 * {@code max(x, 0) - x * y + softplus(-|x|)}.
 *
 * <p>When {@code fromSigmoid} is {@code true}, predictions {@code p} are treated as
 * probabilities and the per-element loss is
 * {@code -(y * log(p + eps) + (1 - y) * log(1 - p + eps))}.
 *
 * <p>In both cases the per-element loss is multiplied by {@code weight} (when it is not 1) and
 * then averaged with {@code mean()}.
 */
public class SigmoidBinaryCrossEntropyLoss extends Loss {
    /** Small constant added inside the logarithm so that {@code log(0)} is never evaluated. */
    private static final double EPSILON = 1.0E-12;

    private final float weight;
    private final boolean fromSigmoid;

    /** Creates the loss with the default name, weight {@code 1}, and logit inputs. */
    public SigmoidBinaryCrossEntropyLoss() {
        this("SigmoidBinaryCrossEntropyLoss");
    }

    /**
     * Creates the loss with the given name, weight {@code 1}, and logit inputs.
     *
     * @param name the display name of the loss
     */
    public SigmoidBinaryCrossEntropyLoss(String name) {
        this(name, 1.0f, false);
    }

    /**
     * Creates the loss.
     *
     * @param name the display name of the loss
     * @param weight scalar multiplier applied to the per-element loss before averaging
     * @param fromSigmoid {@code true} if predictions are already sigmoid probabilities,
     *     {@code false} if they are raw logits
     */
    public SigmoidBinaryCrossEntropyLoss(String name, float weight, boolean fromSigmoid) {
        super(name);
        this.weight = weight;
        this.fromSigmoid = fromSigmoid;
    }

    /**
     * Computes the mean binary cross-entropy between labels and predictions.
     *
     * <p>The label array is reshaped to the prediction's shape before the element-wise
     * computation. This assumes both arrays hold the same number of elements; verify this
     * against callers.
     *
     * @param label list holding exactly one label array (values presumably in {@code [0, 1]})
     * @param prediction list holding exactly one prediction array (logits or probabilities,
     *     depending on {@code fromSigmoid})
     * @return the (optionally weighted) loss averaged via {@code mean()}
     */
    @Override
    public NDArray evaluate(NDList label, NDList prediction) {
        NDArray pred = prediction.singletonOrThrow();
        // Align labels with predictions so the element-wise operations below line up.
        NDArray lab = label.singletonOrThrow().reshape(pred.getShape());
        NDArray loss;
        if (fromSigmoid) {
            // Cross-entropy is the NEGATIVE log-likelihood. Without neg() this returns a
            // non-positive value whose minimization drives predictions away from the labels.
            loss = epsLog(pred)
                    .mul(lab)
                    .add(epsLog(NDArrays.sub(1.0, pred)).mul(NDArrays.sub(1.0, lab)))
                    .neg();
        } else {
            // Stable logits form: max(x, 0) - x * y + log(1 + exp(-|x|)).
            loss = Activation.relu(pred)
                    .sub(pred.mul(lab))
                    .add(Activation.softPlus(pred.abs().neg()));
        }
        if (weight != 1.0f) {
            loss = loss.mul(weight);
        }
        return loss.mean();
    }

    /**
     * Returns {@code log(a + EPSILON)} element-wise, guarding against {@code log(0)}.
     *
     * @param a input array
     * @return the element-wise logarithm of {@code a} shifted by {@link #EPSILON}
     */
    private NDArray epsLog(NDArray a) {
        return a.add(EPSILON).log();
    }
}

