/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;
import java.util.Objects;

/**
 * A {@link Loss} that delegates to another loss, optionally restricting each input list to a
 * single entry selected by index.
 *
 * <p>For the predictions and the labels independently: if the configured index is {@code null},
 * the whole {@link NDList} is forwarded unchanged to the wrapped loss; otherwise a new
 * single-element {@link NDList} holding only the entry at that index is forwarded. This loss
 * reports the same name as the wrapped loss.
 */
public class IndexLoss
extends Loss {
    /** The loss actually evaluated on the (possibly narrowed) labels and predictions. */
    private final Loss loss;
    /** Index into the predictions list, or {@code null} to forward the whole list. */
    private final Integer predictionsIndex;
    /** Index into the labels list, or {@code null} to forward the whole list. */
    private final Integer labelsIndex;

    /**
     * Creates an index loss that selects the same index from both predictions and labels.
     *
     * @param loss the loss to evaluate on the selected entries
     * @param index the index used for both the predictions list and the labels list
     * @throws NullPointerException if {@code loss} is {@code null}
     */
    public IndexLoss(Loss loss, int index) {
        this(loss, index, index);
    }

    /**
     * Creates an index loss with separate indices for predictions and labels.
     *
     * @param loss the loss to evaluate on the selected entries
     * @param predictionsIndex the index to select from the predictions list, or {@code null} to
     *     pass the whole predictions list through
     * @param labelsIndex the index to select from the labels list, or {@code null} to pass the
     *     whole labels list through
     * @throws NullPointerException if {@code loss} is {@code null}
     */
    public IndexLoss(Loss loss, Integer predictionsIndex, Integer labelsIndex) {
        // Fail fast with a named message instead of an anonymous NPE from loss.getName().
        super(Objects.requireNonNull(loss, "loss").getName());
        this.loss = loss;
        this.predictionsIndex = predictionsIndex;
        this.labelsIndex = labelsIndex;
    }

    /**
     * Evaluates the wrapped loss on the selected labels and predictions.
     *
     * @param labels the labels list; narrowed to one entry if a labels index is configured
     * @param predictions the predictions list; narrowed to one entry if a predictions index is
     *     configured
     * @return the result of the wrapped loss's {@code evaluate}
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        return this.loss.evaluate(
                IndexLoss.select(labels, this.labelsIndex),
                IndexLoss.select(predictions, this.predictionsIndex));
    }

    /**
     * Returns {@code list} itself when {@code index} is {@code null}; otherwise returns a new
     * single-element list that holds only the entry at {@code index}.
     */
    private static NDList select(NDList list, Integer index) {
        if (index == null) {
            return list;
        }
        return new NDList((NDArray)list.get(index));
    }
}

