/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.pytorch.engine.PtNDArray;
import ai.djl.pytorch.engine.PtNDManager;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.GradientCollector;
import java.util.concurrent.atomic.AtomicBoolean;

public final class PtGradientCollector
implements GradientCollector {
    private boolean gradModel = JniUtils.isGradMode();
    private static AtomicBoolean isCollecting = new AtomicBoolean();

    public PtGradientCollector() {
        JniUtils.setGradMode(true);
        boolean wasCollecting = isCollecting.getAndSet(true);
        if (wasCollecting) {
            throw new IllegalStateException("A PtGradientCollector is already collecting. Only one can be collecting at a time");
        }
    }

    public void backward(NDArray target) {
        NDArray grad = target.getManager().ones(target.getShape(), target.getDataType()).toDevice(target.getDevice(), false);
        this.backward(target, grad, false, false);
    }

    private void backward(NDArray target, NDArray grad, boolean keepGraph, boolean createGraph) {
        JniUtils.backward((PtNDArray)target, (PtNDArray)grad, keepGraph, createGraph);
    }

    public void zeroGradients() {
        PtNDManager systemManager = PtNDManager.getSystemManager();
        for (NDArray array : systemManager.getManagedArrays()) {
            if (!array.hasGradient()) continue;
            array.getGradient().subi(array.getGradient());
        }
    }

    public void close() {
        if (!this.gradModel) {
            JniUtils.setGradMode(false);
        }
        isCollecting.set(false);
    }
}

