/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import org.tensorflow.Operand;
import org.tensorflow.Output;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.SignatureDef;
import org.tensorflow.proto.framework.TensorInfo;
import org.tensorflow.proto.framework.TensorShapeProto;

/**
 * A named description of a set of inputs and outputs, backed by a {@link SignatureDef} protocol
 * buffer and identified by a string key.
 */
public class Signature {
    /** Key assigned by {@link Builder} when no explicit key is provided. */
    public static final String DEFAULT_KEY = "serving_default";
    private final String key;
    private final SignatureDef signatureDef;

    /**
     * Returns a new builder for creating a signature.
     *
     * <p>The builder's key starts as {@link #DEFAULT_KEY}.
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Returns the key identifying this signature. */
    public String key() {
        return this.key;
    }

    /** Returns the method name of this signature, or {@code null} if it is empty/unset. */
    public String methodName() {
        String methodName = this.signatureDef.getMethodName();
        return methodName.isEmpty() ? null : methodName;
    }

    /** Returns the names of this signature's inputs (the key set of the underlying inputs map). */
    public Set<String> inputNames() {
        return this.signatureDef.getInputsMap().keySet();
    }

    /** Returns the names of this signature's outputs (the key set of the underlying outputs map). */
    public Set<String> outputNames() {
        return this.signatureDef.getOutputsMap().keySet();
    }

    /**
     * Returns a human-readable, multi-line description of this signature: its key, its method
     * name (if any), and the dtype/shape of each input and output.
     */
    @Override
    public String toString() {
        StringBuilder strBuilder = new StringBuilder("Signature for \"" + this.key + "\":\n");
        // methodName() returns null (never "") when no method name is set, so null-check it
        // instead of calling isEmpty() on it.
        String methodName = this.methodName();
        if (methodName != null) {
            strBuilder.append("\tMethod: \"").append(methodName).append("\"\n");
        }
        if (this.signatureDef.getInputsCount() > 0) {
            strBuilder.append("\tInputs:\n");
            printTensorInfo(this.signatureDef.getInputsMap(), strBuilder);
        }
        if (this.signatureDef.getOutputsCount() > 0) {
            strBuilder.append("\tOutputs:\n");
            printTensorInfo(this.signatureDef.getOutputsMap(), strBuilder);
        }
        return strBuilder.toString();
    }

    /**
     * Converts a map of tensor infos into a freshly allocated map of {@link TensorDescription}s.
     *
     * @param dataMapIn tensor infos keyed by input/output name
     * @return a new mutable map with the same keys, holding each tensor's dtype and shape
     */
    private static Map<String, TensorDescription> buildTensorDescriptionMap(Map<String, TensorInfo> dataMapIn) {
        Map<String, TensorDescription> dataTypeMap = new HashMap<>();
        dataMapIn.forEach((name, info) ->
                dataTypeMap.put(name, new TensorDescription(info.getDtype(), toShape(info.getTensorShape()))));
        return dataTypeMap;
    }

    /**
     * Converts a shape proto to a {@link Shape}.
     *
     * <p>A proto flagged with {@code unknown_rank} becomes {@link Shape#unknown()} rather than a
     * scalar shape.
     */
    private static Shape toShape(TensorShapeProto shapeProto) {
        if (shapeProto.getUnknownRank()) {
            return Shape.unknown();
        }
        long[] tensorDims = shapeProto.getDimList().stream().mapToLong(TensorShapeProto.Dim::getSize).toArray();
        return Shape.of(tensorDims);
    }

    /** Returns a description (dtype and shape) of each input, keyed by input name. */
    public Map<String, TensorDescription> getInputs() {
        return buildTensorDescriptionMap(this.signatureDef.getInputsMap());
    }

    /** Returns a description (dtype and shape) of each output, keyed by output name. */
    public Map<String, TensorDescription> getOutputs() {
        return buildTensorDescriptionMap(this.signatureDef.getOutputsMap());
    }

    Signature(String key, SignatureDef signatureDef) {
        this.key = key;
        this.signatureDef = signatureDef;
    }

    /** Returns the underlying {@link SignatureDef} proto (not a copy). */
    SignatureDef asSignatureDef() {
        return this.signatureDef;
    }

    /**
     * Appends one line per tensor to {@code strBuilder}.
     *
     * <p>Each line has the form {@code "name": dtype=DT, shape=(d0, d1, ...)}. Tensors whose
     * shape proto has unknown rank print {@code shape=<unknown>}.
     */
    private static void printTensorInfo(Map<String, TensorInfo> tensorMap, StringBuilder strBuilder) {
        tensorMap.forEach((name, tensorInfo) -> {
            strBuilder.append("\t\t\"").append(name).append("\": dtype=").append(tensorInfo.getDtype().name()).append(", shape=");
            TensorShapeProto shapeProto = tensorInfo.getTensorShape();
            if (shapeProto.getUnknownRank()) {
                // An empty "()" would wrongly suggest a scalar.
                strBuilder.append("<unknown>\n");
                return;
            }
            strBuilder.append('(');
            for (int i = 0; i < shapeProto.getDimCount(); ++i) {
                if (i > 0) {
                    strBuilder.append(", ");
                }
                strBuilder.append(shapeProto.getDim(i).getSize());
            }
            strBuilder.append(")\n");
        });
    }

    /** Builds a {@link Signature} from named input and output operands. */
    public static class Builder {
        private String key = DEFAULT_KEY;
        private final SignatureDef.Builder signatureBuilder = SignatureDef.newBuilder();

        /**
         * Sets the key of the signature.
         *
         * @param key signature key; must be non-null and non-empty
         * @return this builder
         * @throws IllegalArgumentException if {@code key} is null or empty
         */
        public Builder key(String key) {
            if (key == null || key.isEmpty()) {
                throw new IllegalArgumentException("Invalid key: " + key);
            }
            this.key = key;
            return this;
        }

        /**
         * Registers an input under the given name.
         *
         * @param inputName name of the input in the signature
         * @param input operand to expose as that input
         * @return this builder
         * @throws IllegalArgumentException if {@code inputName} is already mapped
         */
        public Builder input(String inputName, Operand<?> input) {
            if (this.signatureBuilder.containsInputs(inputName)) {
                throw new IllegalArgumentException("\"" + inputName + "\" is already being mapped to another input");
            }
            this.signatureBuilder.putInputs(inputName, toTensorInfo(input.asOutput()));
            return this;
        }

        /**
         * Registers an output under the given name.
         *
         * @param outputName name of the output in the signature
         * @param output operand to expose as that output
         * @return this builder
         * @throws IllegalArgumentException if {@code outputName} is already mapped
         */
        public Builder output(String outputName, Operand<?> output) {
            if (this.signatureBuilder.containsOutputs(outputName)) {
                throw new IllegalArgumentException("\"" + outputName + "\" is already being mapped to another output");
            }
            this.signatureBuilder.putOutputs(outputName, toTensorInfo(output.asOutput()));
            return this;
        }

        /**
         * Sets the method name of the signature.
         *
         * <p>{@code null} is stored as the empty string.
         *
         * @return this builder
         */
        public Builder methodName(String methodName) {
            this.signatureBuilder.setMethodName(methodName == null ? "" : methodName);
            return this;
        }

        /** Builds the signature from the current key, inputs, outputs and method name. */
        public Signature build() {
            return new Signature(this.key, this.signatureBuilder.build());
        }

        /**
         * Describes an operand output as a {@link TensorInfo}.
         *
         * <p>The result carries the operand's dtype, its shape, and the name
         * {@code "<opName>:<index>"}.
         */
        private static TensorInfo toTensorInfo(Output<?> operand) {
            Shape shape = operand.shape();
            TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
            if (shape.isUnknown()) {
                // An unknown-rank shape has no dimensions to copy. Without this flag, the empty
                // dim list would read back as a scalar shape.
                tensorShapeBuilder.setUnknownRank(true);
            } else {
                for (int i = 0; i < shape.numDimensions(); ++i) {
                    tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(shape.size(i)));
                }
            }
            return TensorInfo.newBuilder()
                    .setDtype(operand.dataType())
                    .setTensorShape(tensorShapeBuilder)
                    .setName(operand.op().name() + ":" + operand.index())
                    .build();
        }
    }

    /** Data type and shape of a signature input or output. */
    public static class TensorDescription {
        /** Data type of the tensor. */
        public final DataType dataType;
        /** Shape of the tensor. */
        public final Shape shape;

        public TensorDescription(DataType dataType, Shape shape) {
            this.dataType = dataType;
            this.shape = shape;
        }
    }
}

