/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;

/**
 * Builds the MobileNet V1 classification network out of depthwise-separable convolutions.
 *
 * <p>The network is a stride-2 3x3 stem convolution, thirteen depthwise-separable convolution
 * stages, a global average pool and a final {@link Linear} layer with {@code outSize} units.
 * Every convolution is created without a bias term and is followed by {@link BatchNorm} and a
 * ReLU activation. Channel counts come from {@link #FILTERS}, scaled by the builder's width
 * multiplier.
 *
 * <p>Use {@link #builder()} to configure and create the network.
 */
public final class MobileNetV1 {

    /** Base channel counts, before the width multiplier is applied. */
    static final int[] FILTERS = new int[] {32, 64, 128, 128, 256, 256, 512, 512, 1024, 1024};

    /**
     * The depthwise-separable stages that follow the stem, in order. Each row is {@code
     * {inputFilterIndex, outputFilterIndex, stride}}; the indices point into {@link #FILTERS}.
     */
    private static final int[][] SEPARABLE_STAGES = {
        {0, 1, 1},
        {1, 2, 2},
        {2, 3, 1},
        {3, 4, 2},
        {4, 5, 1},
        {5, 6, 2},
        // Five identical stride-1 stages; FILTERS[6] and FILTERS[7] are both 512.
        {6, 7, 1},
        {6, 7, 1},
        {6, 7, 1},
        {6, 7, 1},
        {6, 7, 1},
        {7, 8, 2},
        {8, 9, 1}
    };

    /** Epsilon used by every batch-normalization layer in the network. */
    private static final float BATCH_NORM_EPSILON = 2.0E-5f;

    private MobileNetV1() {}

    /**
     * Creates one depthwise-separable convolution stage.
     *
     * <p>The stage is a 3x3 depthwise convolution (groups equal to {@code inputChannels}, padding
     * 1, the given stride) followed by a 1x1 pointwise convolution producing {@code
     * outputChannels}. Each convolution is bias-free and followed by batch normalization and ReLU.
     *
     * @param inputChannels number of channels entering the stage; also used as the depthwise
     *     filter count and group count
     * @param outputChannels number of channels produced by the pointwise convolution
     * @param stride stride of the depthwise convolution in both spatial dimensions
     * @param builder supplies the batch-normalization momentum
     * @return the depthwise block with the pointwise block appended to it
     * @throws IllegalArgumentException if a channel count or the stride is not positive
     */
    public static Block depthSeparableConv2d(
            int inputChannels, int outputChannels, int stride, Builder builder) {
        if (inputChannels <= 0 || outputChannels <= 0) {
            throw new IllegalArgumentException(
                    "Channel counts must be positive, got input="
                            + inputChannels
                            + ", output="
                            + outputChannels);
        }
        if (stride <= 0) {
            throw new IllegalArgumentException("Stride must be positive, got " + stride);
        }

        // Depthwise: one 3x3 filter per input channel (groups == channels), no bias.
        SequentialBlock depthWise = new SequentialBlock();
        depthWise
                .add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(3, 3))
                                .optBias(false)
                                .optPadding(new Shape(1, 1))
                                .optStride(new Shape(stride, stride))
                                .optGroups(inputChannels)
                                .setFilters(inputChannels)
                                .build())
                .add(batchNorm(builder))
                .add(Activation.reluBlock());

        // Pointwise: 1x1 convolution that mixes channels into outputChannels, no bias.
        SequentialBlock pointWise = new SequentialBlock();
        pointWise
                .add(
                        Conv2d.builder()
                                .setKernelShape(new Shape(1, 1))
                                .setFilters(outputChannels)
                                .optBias(false)
                                .build())
                .add(batchNorm(builder))
                .add(Activation.reluBlock());

        return depthWise.add(pointWise);
    }

    /**
     * Assembles the full MobileNet V1 network described by {@code builder}.
     *
     * @param builder the width multiplier, batch-normalization momentum and output size to use
     * @return the network as a {@link SequentialBlock}
     * @throws IllegalArgumentException if the width multiplier scales any layer below one channel
     */
    public static Block mobilenet(Builder builder) {
        float width = builder.widthMultiplier;
        SequentialBlock mobileNet = new SequentialBlock();

        // Stem: a full (non-grouped) stride-2 3x3 convolution, no bias.
        mobileNet.add(
                new SequentialBlock()
                        .add(
                                Conv2d.builder()
                                        .setKernelShape(new Shape(3, 3))
                                        .optBias(false)
                                        .optStride(new Shape(2, 2))
                                        .optPadding(new Shape(1, 1))
                                        .setFilters(scaledFilters(0, width))
                                        .build())
                        .add(batchNorm(builder))
                        .add(Activation.reluBlock()));

        for (int[] stage : SEPARABLE_STAGES) {
            mobileNet.add(
                    depthSeparableConv2d(
                            scaledFilters(stage[0], width),
                            scaledFilters(stage[1], width),
                            stage[2],
                            builder));
        }

        mobileNet.add(Pool.globalAvgPool2dBlock())
                .add(Linear.builder().setUnits(builder.outSize).build());
        return mobileNet;
    }

    /**
     * Creates a new builder with default settings.
     *
     * @return a builder with momentum 0.9, width multiplier 1.0 and output size 10
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Returns {@code FILTERS[index]} multiplied by {@code widthMultiplier} and truncated toward
     * zero, using the same float arithmetic as the layer definitions.
     *
     * @throws IllegalArgumentException if the scaled channel count is less than one
     */
    private static int scaledFilters(int index, float widthMultiplier) {
        int channels = (int) ((float) FILTERS[index] * widthMultiplier);
        if (channels < 1) {
            throw new IllegalArgumentException(
                    "widthMultiplier "
                            + widthMultiplier
                            + " scales "
                            + FILTERS[index]
                            + " filters down to "
                            + channels
                            + "; use a larger multiplier");
        }
        return channels;
    }

    /** Creates the batch-normalization layer that follows every convolution. */
    private static Block batchNorm(Builder builder) {
        return BatchNorm.builder()
                .optEpsilon(BATCH_NORM_EPSILON)
                .optMomentum(builder.batchNormMomentum)
                .build();
    }

    /** Configures and builds a MobileNet V1 network. */
    public static final class Builder {

        // Momentum passed to every batch-normalization layer.
        float batchNormMomentum = 0.9f;
        // Factor applied to every entry of MobileNetV1.FILTERS.
        float widthMultiplier = 1.0f;
        // Number of units in the final linear layer.
        long outSize = 10L;

        Builder() {}

        /**
         * Sets the factor by which every channel count in {@link MobileNetV1#FILTERS} is scaled.
         *
         * @param widthMultiplier a positive, finite multiplier
         * @return this builder
         * @throws IllegalArgumentException if {@code widthMultiplier} is not positive and finite
         */
        public Builder optWidthMultiplier(float widthMultiplier) {
            // !(x > 0) also rejects NaN.
            if (!(widthMultiplier > 0f) || Float.isInfinite(widthMultiplier)) {
                throw new IllegalArgumentException(
                        "widthMultiplier must be positive and finite, got " + widthMultiplier);
            }
            this.widthMultiplier = widthMultiplier;
            return this;
        }

        /**
         * Sets the momentum used by every batch-normalization layer.
         *
         * @param batchNormMomentum the momentum value
         * @return this builder
         */
        public Builder optBatchNormMomentum(float batchNormMomentum) {
            this.batchNormMomentum = batchNormMomentum;
            return this;
        }

        /**
         * Sets the number of units in the final linear layer (the number of classes).
         *
         * @param outSize a positive unit count
         * @return this builder
         * @throws IllegalArgumentException if {@code outSize} is not positive
         */
        public Builder setOutSize(long outSize) {
            if (outSize <= 0) {
                throw new IllegalArgumentException("outSize must be positive, got " + outSize);
            }
            this.outSize = outSize;
            return this;
        }

        /**
         * Builds the network with the current settings.
         *
         * @return the MobileNet V1 block
         */
        public Block build() {
            return MobileNetV1.mobilenet(this);
        }
    }
}

