/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.ParallelBlock;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2D;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;
import ai.djl.nn.pooling.PoolingConvention;
import java.util.Arrays;

/**
 * Factory for a SqueezeNet image-classification {@link Block}.
 *
 * <p>NOTE(review): the channel widths and pooling placement look like the SqueezeNet v1.1
 * configuration — confirm before relying on that.
 */
public final class SqueezeNet {

    /** Not instantiable; use the static factory methods. */
    private SqueezeNet() {
    }

    /**
     * Builds a "fire" module: a 1x1 squeeze convolution with ReLU, followed by two parallel
     * expand branches (a 1x1 convolution and a padded 3x3 convolution, each with its own ReLU)
     * whose outputs are concatenated along axis 1.
     *
     * @param squeezePlanes number of filters in the 1x1 squeeze convolution
     * @param expand1x1Planes number of filters in the 1x1 expand branch
     * @param expand3x3Planes number of filters in the 3x3 expand branch
     * @return the assembled fire module
     */
    static Block fire(int squeezePlanes, int expand1x1Planes, int expand3x3Planes) {
        Block squeeze = withRelu(convolution(squeezePlanes, 1L).build());
        Block expandPointwise = withRelu(convolution(expand1x1Planes, 1L).build());
        // Padding of 1 on the 3x3 kernel is presumably there to keep this branch's spatial size
        // equal to the 1x1 branch's so the two can be concatenated (assumes default stride 1).
        Block expandSpatial =
                withRelu(convolution(expand3x3Planes, 3L).optPad(new Shape(1L, 1L)).build());

        Block expand =
                new ParallelBlock(
                        branchOutputs -> {
                            // Appends the 3x3 branch's arrays onto the 1x1 branch's NDList
                            // (mutating that list in place), then joins them along axis 1 —
                            // presumably the channel axis for NCHW input; confirm.
                            NDList joined = branchOutputs.get(0).addAll(branchOutputs.get(1));
                            NDArray merged = NDArrays.concat(joined, 1);
                            return new NDList(merged);
                        },
                        Arrays.asList(expandPointwise, expandSpatial));

        return new SequentialBlock().add(squeeze).add(expand);
    }

    /**
     * Builds the full SqueezeNet classifier.
     *
     * <p>Layout: a stride-2 3x3 stem convolution with ReLU and max pooling; eight fire modules
     * with max pooling after the 2nd and 4th; then dropout (p = 0.5), a 1x1 convolution with
     * {@code outSize} filters, ReLU, global average pooling and a batch flatten.
     *
     * @param outSize number of filters in the final 1x1 convolution, i.e. the number of output
     *     classes
     * @return the assembled network block
     */
    public static Block squeezenet(int outSize) {
        SequentialBlock network = new SequentialBlock();

        // Stem: conv and ReLU are added directly (not wrapped) to keep the block tree flat here.
        network.add(convolution(64, 3L).optStride(new Shape(2L, 2L)).build())
                .add(Activation::relu)
                .add(maxPoolDownsample());

        // Fire stages, with downsampling after the first two pairs.
        network.add(fire(16, 64, 64)).add(fire(16, 64, 64)).add(maxPoolDownsample());
        network.add(fire(32, 128, 128)).add(fire(32, 128, 128)).add(maxPoolDownsample());
        network.add(fire(48, 192, 192)).add(fire(48, 192, 192));
        network.add(fire(64, 256, 256)).add(fire(64, 256, 256));

        // Classifier head: a 1x1 conv maps features to outSize maps, which are globally
        // average-pooled and flattened per batch item.
        network.add(Dropout.builder().optProbability(0.5f).build())
                .add(convolution(outSize, 1L).build())
                .add(Activation::relu)
                .add(Pool.globalAvgPool2DBlock())
                .add(Blocks.batchFlattenBlock());

        return network;
    }

    /**
     * Starts a 2-D convolution builder with {@code filters} output maps and a square kernel of
     * side {@code kernelSize}; callers add any stride/padding and call {@code build()}.
     */
    private static Conv2D.Builder convolution(int filters, long kernelSize) {
        return Conv2D.builder()
                .setNumFilters(filters)
                .setKernel(new Shape(kernelSize, kernelSize));
    }

    /** Wraps {@code layer} in its own {@link SequentialBlock} followed by a ReLU activation. */
    private static Block withRelu(Block layer) {
        return new SequentialBlock().add(layer).add(Activation::relu);
    }

    /**
     * Creates a fresh 3x3 max-pooling block with stride 2, zero padding and the {@code FULL}
     * pooling convention (presumably ceil-style output sizing — confirm against the Pool docs).
     */
    private static Block maxPoolDownsample() {
        return Pool.maxPool2DBlock(
                new Shape(3L, 3L), new Shape(2L, 2L), new Shape(0L, 0L), PoolingConvention.FULL);
    }
}

