/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow.op.nn;

import java.util.Arrays;
import java.util.List;
import org.tensorflow.Operand;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Scope;
import org.tensorflow.op.core.Concat;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.core.Range;
import org.tensorflow.op.core.Rank;
import org.tensorflow.op.core.Reshape;
import org.tensorflow.op.core.Slice;
import org.tensorflow.op.dtypes.Cast;
import org.tensorflow.op.linalg.Transpose;
import org.tensorflow.op.math.Sub;
import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TFloat16;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TNumber;

/**
 * Builds softmax cross entropy between logits and labels along an arbitrary axis. The actual
 * computation is delegated to {@link org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits},
 * which is fed rank-2 operands.
 */
public class SoftmaxCrossEntropyWithLogits {
    /**
     * Computes softmax cross entropy between {@code logits} and {@code labels} along {@code axis}.
     *
     * <p>Handling of inputs before the raw op is called:
     * <ul>
     *   <li>{@code TFloat16}/{@code TBfloat16} logits are computed in {@code TFloat32}, and the
     *       loss is cast back to the logits type.</li>
     *   <li>Labels of a different type are cast to the logits type.</li>
     *   <li>If {@code axis} is not the last dimension, both operands are transposed so that it
     *       becomes the last one.</li>
     *   <li>Both operands are then flattened to {@code [outer, last]}.</li>
     * </ul>
     *
     * @param scope current scope; work is placed in a {@code SoftmaxCrossEntropyWithLogits}
     *     sub-scope
     * @param labels labels; presumably must have the same shape as {@code logits} (not checked
     *     here)
     * @param logits logits over which the softmax is taken along {@code axis}
     * @param axis dimension to reduce over; negative values count from the end
     * @param <T> type of the logits and of the returned loss
     * @param <U> type of the labels
     * @return the loss, shaped like {@code logits} with {@code axis} removed
     * @throws IllegalArgumentException if {@code logits} is a scalar, or if its rank is unknown
     *     and {@code axis} is negative but not {@code -1}
     */
    public static <T extends TNumber, U extends TNumber> Operand<T> softmaxCrossEntropyWithLogits(Scope scope, Operand<U> labels, Operand<T> logits, int axis) {
        scope = scope.withSubScope("SoftmaxCrossEntropyWithLogits");
        int ndims = logits.shape().numDimensions();
        if (ndims == 0) {
            throw new IllegalArgumentException("logits must have rank >= 1, but got a scalar");
        }
        if (ndims > 0) {
            // Normalize into [0, ndims); Java's % keeps the sign of the dividend.
            axis %= ndims;
            if (axis < 0) {
                axis += ndims;
            }
        } else if (axis < -1) {
            // Unknown rank: only -1 (already last) or a non-negative index can be resolved.
            throw new IllegalArgumentException("axis " + axis
                    + " cannot be resolved because the rank of logits is unknown; use -1 or a non-negative axis");
        }
        if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) {
            // Compute in single precision, then cast the loss back to the logits type.
            Operand<TFloat32> result = softmaxCrossEntropyWithLogits(scope,
                    Cast.create(scope, labels, TFloat32.class), Cast.create(scope, logits, TFloat32.class), axis);
            return Cast.create(scope, result, logits.asOutput().type());
        }
        if (logits.asOutput().type() != labels.asOutput().type()) {
            return softmaxCrossEntropyWithLogits(scope, Cast.create(scope, labels, logits.asOutput().type()), logits, axis);
        }
        Operand<TInt64> inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class);
        Shape shape = logits.shape();
        // The operands are flattened to [-1, lastDim] below, so the reduced axis must be last.
        if (axis != -1 && axis != ndims - 1) {
            logits = moveDimToEnd(scope, logits, axis, inputRank);
            labels = moveDimToEnd(scope, labels, axis, inputRank);
        }
        // Run-time shape, so the final reshape also works when dimensions are statically unknown.
        Operand<TInt64> inputShape = org.tensorflow.op.core.Shape.create(scope, logits, TInt64.class);
        logits = flattenOuterDims(scope, logits);
        labels = flattenOuterDims(scope, labels);
        // Safe: labels of a different type were cast to T by the recursive call above.
        @SuppressWarnings("unchecked")
        Operand<T> typedLabels = (Operand<T>) labels;
        org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits<T> smax =
                org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create(scope, logits, typedLabels);
        Operand<T> cost = smax.loss();
        // outputShape = inputShape[0 : rank - 1]; Slice needs 1-D begin/size operands.
        Operand<TInt64> outerRank = Reshape.create(scope,
                Sub.create(scope, inputRank, Constant.scalarOf(scope, 1L)), Constant.arrayOf(scope, 1L));
        Operand<TInt64> outputShape = Slice.create(scope, inputShape, Constant.arrayOf(scope, 0L), outerRank);
        cost = Reshape.create(scope, cost, outputShape);
        if (scope.env().isGraph() && !shape.hasUnknownDimension()) {
            // Pin the statically known result shape (the original shape minus `axis`) in graph
            // mode, where the dynamic reshape/transpose above may not carry it.
            // Here the rank is known and > 0, so `axis` was normalized into [0, rank).
            long[] dims = shape.asArray();
            long[] reduced = new long[dims.length - 1];
            System.arraycopy(dims, 0, reduced, 0, axis);
            System.arraycopy(dims, axis + 1, reduced, axis, dims.length - axis - 1);
            cost = Reshape.create(scope, cost, Constant.vectorOf(scope, reduced));
        }
        return cost;
    }

    /**
     * Reshapes {@code logits} to rank 2, collapsing every dimension except the last into one.
     *
     * <p>If the static shape is fully known, a constant target shape is used. Otherwise the
     * target shape {@code [-1, lastDimSize]} is computed at run time.
     *
     * @param scope current scope
     * @param logits operand to flatten; assumed to have rank >= 1 (not checked here)
     * @return {@code logits} reshaped to {@code [product of outer dims, last dim]}
     */
    private static <T extends TNumber> Operand<T> flattenOuterDims(Scope scope, Operand<T> logits) {
        Shape shape = logits.shape();
        if (!shape.hasUnknownDimension()) {
            long outer = 1L;
            for (int i = 0; i < shape.numDimensions() - 1; ++i) {
                outer *= shape.size(i);
            }
            return Reshape.create(scope, logits, Constant.arrayOf(scope, outer, shape.size(-1)));
        }
        Operand<TInt64> one = Constant.arrayOf(scope, 1L);
        Operand<TInt64> rank = Cast.create(scope, Rank.create(scope, logits), TInt64.class);
        // Slice requires 1-D begin/size, so rank - 1 is turned into the vector [rank - 1].
        Operand<TInt64> lastDimIndex = Reshape.create(scope,
                Sub.create(scope, rank, Constant.scalarOf(scope, 1L)), one);
        Operand<TInt64> lastDimSize = Slice.create(scope,
                org.tensorflow.op.core.Shape.create(scope, logits, TInt64.class), lastDimIndex, one);
        Operand<TInt64> newShape = Concat.create(scope,
                Arrays.asList(Constant.arrayOf(scope, -1L), lastDimSize), Constant.scalarOf(scope, 0));
        return Reshape.create(scope, logits, newShape);
    }

    /**
     * Transposes {@code input} so that dimension {@code dimIndex} becomes the last one.
     *
     * <p>The permutation is {@code range(0, dimIndex) ++ range(dimIndex + 1, rank) ++ [dimIndex]}.
     * All other dimensions keep their relative order.
     *
     * @param scope current scope
     * @param input operand to transpose
     * @param dimIndex non-negative index of the dimension to move
     * @param rank rank of {@code input} as a scalar operand; its type is used for the permutation
     * @return the transposed operand
     */
    private static <T extends TNumber, U extends TNumber> Operand<T> moveDimToEnd(Scope scope, Operand<T> input, int dimIndex, Operand<U> rank) {
        Class<U> rankType = rank.asOutput().type();
        Operand<U> zero = Cast.create(scope, Constant.scalarOf(scope, 0), rankType);
        Operand<U> one = Cast.create(scope, Constant.scalarOf(scope, 1), rankType);
        Operand<U> dim = Cast.create(scope, Constant.scalarOf(scope, dimIndex), rankType);
        Operand<U> afterDim = Cast.create(scope, Constant.scalarOf(scope, dimIndex + 1), rankType);
        List<Operand<U>> permParts = Arrays.asList(
                Range.create(scope, zero, dim, one),                        // [0, dimIndex)
                Range.create(scope, afterDim, rank, one),                   // [dimIndex + 1, rank)
                Reshape.create(scope, dim, Constant.arrayOf(scope, 1L)));   // [dimIndex]
        return Transpose.create(scope, input, Concat.create(scope, permParts, Constant.scalarOf(scope, 0)));
    }
}

