/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.Blocks;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.Dropout;
import ai.djl.nn.pooling.Pool;

/**
 * Factory for VGG-style image classification networks.
 *
 * <p>The network is a sequence of convolutional stages (see {@link #vggBlock(int, int)})
 * followed by a classifier head made of a batch-flatten block and three {@link Linear}
 * layers. The first two linear layers are each followed by a ReLU activation and a
 * {@link Dropout} block.
 *
 * <p>Instances cannot be created from outside this class; use {@link #builder()}.
 */
public final class VGG {
    /** Number of units in each of the two hidden linear layers of the classifier head. */
    private static final long HIDDEN_UNITS = 4096L;

    /** Rate passed to the dropout blocks of the classifier head. */
    private static final float DROPOUT_RATE = 0.5f;

    /** Number of linear layers appended by {@link #vgg(Builder)} after the convolutional stages. */
    private static final int NUM_LINEAR_LAYERS = 3;

    private VGG() {
    }

    /**
     * Assembles the VGG network described by the given builder.
     *
     * @param builder the configuration; each row of its convolution architecture is read as
     *     {@code {numConvs, numChannels}} and turned into one {@link #vggBlock(int, int)} stage
     * @return a {@link SequentialBlock} holding the convolutional stages followed by the
     *     classifier head, whose final linear layer has {@code builder.outSize} units
     */
    public static Block vgg(Builder builder) {
        SequentialBlock block = new SequentialBlock();
        // vggBlock is an instance method, so a private instance is needed to call it.
        VGG vgg = new VGG();
        for (int[] stage : builder.convArch) {
            block.add(vgg.vggBlock(stage[0], stage[1]));
        }
        block.add(Blocks.batchFlattenBlock())
                .add(Linear.builder().setUnits(HIDDEN_UNITS).build())
                .add(Activation::relu)
                .add(Dropout.builder().optRate(DROPOUT_RATE).build())
                .add(Linear.builder().setUnits(HIDDEN_UNITS).build())
                .add(Activation::relu)
                .add(Dropout.builder().optRate(DROPOUT_RATE).build())
                .add(Linear.builder().setUnits(builder.outSize).build());
        return block;
    }

    /**
     * Creates one convolutional stage: {@code numConvs} repetitions of a 3x3 convolution
     * (padding 1) followed by ReLU, then a single 2D max-pooling block built from two
     * {@code (2, 2)} shapes.
     *
     * @param numConvs the number of convolution + ReLU pairs; if not positive, the stage
     *     contains only the pooling block
     * @param numChannels the number of filters of every convolution in the stage
     * @return the stage as a {@link SequentialBlock}
     */
    public SequentialBlock vggBlock(int numConvs, int numChannels) {
        SequentialBlock tempBlock = new SequentialBlock();
        for (int i = 0; i < numConvs; ++i) {
            tempBlock
                    .add(
                            Conv2d.builder()
                                    .setFilters(numChannels)
                                    .setKernelShape(new Shape(3, 3))
                                    .optPadding(new Shape(1, 1))
                                    .build())
                    .add(Activation::relu);
        }
        tempBlock.add(Pool.maxPool2dBlock(new Shape(2, 2), new Shape(2, 2)));
        return tempBlock;
    }

    /**
     * Creates a builder preconfigured with the defaults declared in {@link Builder}.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link VGG} networks.
     *
     * <p>Defaults: 11 layers, the convolution architecture
     * {@code {{1, 64}, {1, 128}, {2, 256}, {2, 512}, {2, 512}}} and an output size of 10.
     */
    public static final class Builder {
        int numLayers = 11;
        int[][] convArch = new int[][]{{1, 64}, {1, 128}, {2, 256}, {2, 512}, {2, 512}};
        long outSize = 10L;

        /**
         * Sets the total layer count that {@link #setConvArch(int[][])} validates against.
         *
         * <p>This value does not change the architecture by itself; it is only consulted by
         * {@link #setConvArch(int[][])} at the moment that method is called. Call this method
         * <em>before</em> {@link #setConvArch(int[][])}.
         *
         * @param numLayers the number of layers (convolutions plus the three linear layers)
         * @return this builder
         */
        public Builder setNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return this;
        }

        /**
         * Sets the convolution architecture.
         *
         * <p>Each row is read as {@code {numConvs, numChannels}}. The sum of all
         * {@code numConvs} values must equal {@code numLayers - 3}, because the classifier head
         * always contributes three linear layers. The array is copied, so later changes to the
         * argument do not affect this builder.
         *
         * @param convArch the stages of the network, one {@code {numConvs, numChannels}} row each
         * @return this builder
         * @throws IllegalArgumentException if {@code convArch} or one of its rows is
         *     {@code null}, if a row has fewer than two elements, or if the total number of
         *     convolutions differs from {@code numLayers - 3}
         */
        public Builder setConvArch(int[][] convArch) {
            if (convArch == null) {
                throw new IllegalArgumentException("convArch must not be null");
            }
            int[][] copy = new int[convArch.length][];
            int numConvs = 0;
            for (int i = 0; i < convArch.length; ++i) {
                int[] layer = convArch[i];
                if (layer == null || layer.length < 2) {
                    throw new IllegalArgumentException(
                            "convArch[" + i + "] must be a {numConvs, numChannels} pair");
                }
                numConvs += layer[0];
                // Defensive copy so the validated architecture cannot be altered by the caller.
                copy[i] = layer.clone();
            }
            int expected = this.numLayers - NUM_LINEAR_LAYERS;
            if (numConvs != expected) {
                throw new IllegalArgumentException(
                        "total number of convolutions in convArch (" + numConvs
                                + ") should be equal to numLayers - " + NUM_LINEAR_LAYERS
                                + " (" + expected + ")");
            }
            this.convArch = copy;
            return this;
        }

        /**
         * Sets the number of units of the final linear layer.
         *
         * @param outSize the output size of the network
         * @return this builder
         */
        public Builder setOutSize(long outSize) {
            this.outSize = outSize;
            return this;
        }

        /**
         * Builds the network by delegating to {@link VGG#vgg(Builder)}.
         *
         * @return the assembled VGG {@link Block}
         */
        public Block build() {
            return VGG.vgg(this);
        }
    }
}

