/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.transforms.custom;

import java.lang.reflect.Field;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.descriptors.properties.adapters.IntArrayIntIndexAdapter;
import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Dilation2D custom op.
 *
 * <p>Native op name {@code dilation2d}; TensorFlow op {@code Dilation2D}; ONNX name
 * {@code Dilation_2D}. The op is configured through nine integer arguments, appended in this
 * order: {@code isSameMode} (1 or 0), the rates {@code r0..r3}, then the strides {@code s0..s3}.
 */
public class Dilation2D
extends DynamicCustomOp {
    /** {@code true} for SAME padding; on TF import, set when the {@code padding} attribute equals "SAME". */
    protected boolean isSameMode;
    /** Rate values; on TF import, taken from indices 0..3 of the {@code rates} attribute. */
    protected int r0;
    protected int r1;
    protected int r2;
    protected int r3;
    /** Stride values; on TF import, taken from indices 0..3 of the {@code strides} attribute. */
    protected int s0;
    protected int s1;
    protected int s2;
    protected int s3;

    /** Creates an unconfigured op; fields are presumably populated later (e.g. by graph import). */
    public Dilation2D() {
    }

    /**
     * Creates the op in a SameDiff graph from separate input and weights variables.
     *
     * @param sameDiff   SameDiff instance the op is created in
     * @param df         input variable
     * @param weights    weights variable
     * @param strides    strides; must be non-null with exactly 4 elements
     * @param rates      rates; must be non-null with exactly 4 elements
     * @param isSameMode {@code true} for SAME padding
     * @throws IllegalArgumentException if {@code strides} or {@code rates} is null or not of length 4
     */
    public Dilation2D(SameDiff sameDiff, SDVariable df, SDVariable weights, int[] strides, int[] rates, boolean isSameMode) {
        this(sameDiff, new SDVariable[]{df, weights}, strides, rates, isSameMode, false);
    }

    /**
     * Creates the op in a SameDiff graph.
     *
     * @param sameDiff        SameDiff instance the op is created in
     * @param inputAndWeights input variable followed by the weights variable
     * @param strides         strides; must be non-null with exactly 4 elements
     * @param rates           rates; must be non-null with exactly 4 elements
     * @param isSameMode      {@code true} for SAME padding
     * @param inPlace         passed through to {@link DynamicCustomOp}
     * @throws IllegalArgumentException if {@code strides} or {@code rates} is null or not of length 4
     */
    public Dilation2D(SameDiff sameDiff, SDVariable[] inputAndWeights, int[] strides, int[] rates, boolean isSameMode, boolean inPlace) {
        super(null, sameDiff, inputAndWeights, inPlace);
        this.applyStridesAndRates(strides, rates, isSameMode);
    }

    /**
     * Creates the op from explicit input and output arrays without setting any integer arguments.
     *
     * @param inputArrays input arrays
     * @param outputs     output arrays
     */
    public Dilation2D(INDArray[] inputArrays, INDArray[] outputs) {
        super(null, inputArrays, outputs);
    }

    /**
     * Creates the op for direct execution on arrays.
     *
     * @param df         input array
     * @param weights    weights array
     * @param strides    strides; must be non-null with exactly 4 elements
     * @param rates      rates; must be non-null with exactly 4 elements
     * @param isSameMode {@code true} for SAME padding
     * @throws IllegalArgumentException if {@code strides} or {@code rates} is null or not of length 4
     */
    public Dilation2D(INDArray df, INDArray weights, int[] strides, int[] rates, boolean isSameMode) {
        this.addInputArgument(df, weights);
        this.applyStridesAndRates(strides, rates, isSameMode);
    }

    /**
     * Validates and stores the rates/strides configuration, then appends the integer arguments.
     * Shared by both configuring constructors so they enforce identical rules.
     */
    private void applyStridesAndRates(int[] strides, int[] rates, boolean isSameMode) {
        // Exactly 4 elements are required: each array is unpacked into four scalar fields below,
        // so longer arrays would otherwise be silently truncated.
        Preconditions.checkArgument(rates != null, "Dilation rates must not be null");
        Preconditions.checkArgument(rates.length == 4, "Dilation rate length must be 4, got an array with length %s with values %s", rates.length, rates);
        Preconditions.checkArgument(strides != null, "Dilation strides must not be null");
        Preconditions.checkArgument(strides.length == 4, "Dilation strides length must be 4, got an array with length %s with values %s", strides.length, strides);
        this.r0 = rates[0];
        this.r1 = rates[1];
        this.r2 = rates[2];
        this.r3 = rates[3];
        this.s0 = strides[0];
        this.s1 = strides[1];
        this.s2 = strides[2];
        this.s3 = strides[3];
        this.isSameMode = isSameMode;
        this.addArgs();
    }

    /**
     * Appends the integer arguments in the order
     * {@code [isSameMode ? 1 : 0, r0, r1, r2, r3, s0, s1, s2, s3]}.
     */
    protected void addArgs() {
        this.addIArgument(this.isSameMode ? 1 : 0, this.r0, this.r1, this.r2, this.r3, this.s0, this.s1, this.s2, this.s3);
    }

    /**
     * Populates the fields from the TF node's attributes via
     * {@link TFGraphMapper#initFunctionFromProperties}, then appends the integer arguments.
     * NOTE(review): arguments are appended, not replaced — confirm this is only invoked on
     * freshly created (unconfigured) instances.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        this.addArgs();
    }

    /**
     * Maps each field to the framework attribute it is read from: {@code isSameMode} from
     * {@code padding}, {@code r0..r3} from {@code rates}, {@code s0..s3} from {@code strides}.
     * The same property map is registered under both the ONNX and TensorFlow op names.
     *
     * @return property mappings keyed by framework op name
     */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
        Map<String, PropertyMapping> map = new HashMap<>();
        PropertyMapping sameMode = PropertyMapping.builder().tfAttrName("padding").propertyNames(new String[]{"isSameMode"}).build();
        PropertyMapping ratesMapping = PropertyMapping.builder().tfAttrName("rates").propertyNames(new String[]{"r0", "r1", "r2", "r3"}).build();
        PropertyMapping stridesMapping = PropertyMapping.builder().tfAttrName("strides").propertyNames(new String[]{"s0", "s1", "s2", "s3"}).build();
        map.put("isSameMode", sameMode);
        map.put("r0", ratesMapping);
        map.put("r1", ratesMapping);
        map.put("r2", ratesMapping);
        map.put("r3", ratesMapping);
        map.put("s0", stridesMapping);
        map.put("s1", stridesMapping);
        map.put("s2", stridesMapping);
        map.put("s3", stridesMapping);
        try {
            ret.put(this.onnxName(), map);
        } catch (NoOpNameFoundException ignored) {
            // Best-effort: the ONNX registration is optional, so a missing ONNX name is skipped.
        }
        try {
            ret.put(this.tensorflowName(), map);
        } catch (NoOpNameFoundException e) {
            // The TensorFlow mapping is required for import; surface the failure with its cause.
            throw new RuntimeException(e);
        }
        return ret;
    }

    /**
     * Returns the adapters that convert raw attribute values into field values. For TF, each
     * element of the {@code rates}/{@code strides} lists is extracted by index, and
     * {@code isSameMode} is {@code true} when {@code padding} equals "SAME". For ONNX only
     * {@code isSameMode} is adapted.
     *
     * @return attribute adapters keyed by framework op name
     */
    @Override
    public Map<String, Map<String, AttributeAdapter>> attributeAdaptersForFunction() {
        Map<String, Map<String, AttributeAdapter>> ret = new HashMap<>();

        Map<String, AttributeAdapter> tfMappings = new LinkedHashMap<>();
        tfMappings.put("r0", new IntArrayIntIndexAdapter(0));
        tfMappings.put("r1", new IntArrayIntIndexAdapter(1));
        tfMappings.put("r2", new IntArrayIntIndexAdapter(2));
        tfMappings.put("r3", new IntArrayIntIndexAdapter(3));
        tfMappings.put("s0", new IntArrayIntIndexAdapter(0));
        tfMappings.put("s1", new IntArrayIntIndexAdapter(1));
        tfMappings.put("s2", new IntArrayIntIndexAdapter(2));
        tfMappings.put("s3", new IntArrayIntIndexAdapter(3));
        tfMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
        ret.put(this.tensorflowName(), tfMappings);

        // Declared with the AttributeAdapter value type so it is assignable to ret's value type
        // (generic types are invariant; HashMap<String, StringEqualsAdapter> would not compile).
        Map<String, AttributeAdapter> onnxMappings = new HashMap<>();
        onnxMappings.put("isSameMode", new StringEqualsAdapter("SAME"));
        try {
            ret.put(this.onnxName(), onnxMappings);
        } catch (NoOpNameFoundException ignored) {
            // Best-effort, matching mappingsForFunction: the ONNX registration is optional.
        }
        return ret;
    }

    /** @return the native op name, {@code "dilation2d"} */
    @Override
    public String opName() {
        return "dilation2d";
    }

    /** @return the ONNX op name, {@code "Dilation_2D"} */
    @Override
    public String onnxName() {
        return "Dilation_2D";
    }

    /** @return the TensorFlow op name, {@code "Dilation2D"} */
    @Override
    public String tensorflowName() {
        return "Dilation2D";
    }

    /**
     * The single output has the same data type as the first input.
     *
     * @param inputDataTypes input data types; 2 to 4 entries are accepted
     * @return a singleton list holding the first input's data type
     * @throws IllegalStateException (via {@link Preconditions#checkState}) if the list is null or
     *                               its size is outside [2, 4]
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() >= 2 && inputDataTypes.size() <= 4, "Expected 2 to 4 input datatypes for %s, got %s", this.getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }
}

