/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.math.optimisers;

import com.oracle.labs.mlrg.olcut.config.Config;
import com.oracle.labs.mlrg.olcut.config.Configurable;
import com.oracle.labs.mlrg.olcut.provenance.ConfiguredObjectProvenance;
import com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl;
import java.util.Arrays;
import java.util.NoSuchElementException;
import java.util.function.DoubleUnaryOperator;
import java.util.logging.Logger;
import org.tribuo.math.Parameters;
import org.tribuo.math.StochasticGradientOptimiser;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Matrix;
import org.tribuo.math.la.MatrixIterator;
import org.tribuo.math.la.MatrixTuple;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorIterator;
import org.tribuo.math.la.VectorTuple;

public class AdaGradRDA
implements StochasticGradientOptimiser {
    private static final Logger logger = Logger.getLogger(AdaGradRDA.class.getName());
    @Config(mandatory=true, description="Initial learning rate used to scale the gradients.")
    private double initialLearningRate;
    @Config(description="Epsilon for numerical stability around zero.")
    private double epsilon = 1.0E-6;
    @Config(description="l1 regularization penalty.")
    private double l1 = 0.0;
    @Config(description="l2 regularization penalty.")
    private double l2 = 0.0;
    @Config(description="Number of examples to scale the l1 and l2 penalties by.")
    private int numExamples = 1;
    private Parameters parameters = null;

    public AdaGradRDA(double initialLearningRate, double epsilon, double l1, double l2, int numExamples) {
        this.initialLearningRate = initialLearningRate;
        this.epsilon = epsilon;
        this.l1 = l1;
        this.l2 = l2;
        this.numExamples = numExamples;
    }

    public AdaGradRDA(double initialLearningRate, double epsilon) {
        this(initialLearningRate, epsilon, 0.0, 0.0, 1);
    }

    private AdaGradRDA() {
    }

    @Override
    public void initialise(Parameters parameters) {
        this.parameters = parameters;
        Tensor[] curParams = parameters.get();
        Tensor[] newParams = new Tensor[curParams.length];
        for (int i = 0; i < newParams.length; ++i) {
            if (curParams[i] instanceof DenseVector) {
                newParams[i] = new AdaGradRDAVector((DenseVector)curParams[i], this.initialLearningRate, this.epsilon, this.l1 / (double)this.numExamples, this.l2 / (double)this.numExamples);
                continue;
            }
            if (curParams[i] instanceof DenseMatrix) {
                newParams[i] = new AdaGradRDAMatrix((DenseMatrix)curParams[i], this.initialLearningRate, this.epsilon, this.l1 / (double)this.numExamples, this.l2 / (double)this.numExamples);
                continue;
            }
            throw new IllegalStateException("Unknown Tensor subclass");
        }
        parameters.set(newParams);
    }

    @Override
    public Tensor[] step(Tensor[] updates, double weight) {
        for (Tensor update : updates) {
            update.scaleInPlace(weight);
        }
        return updates;
    }

    @Override
    public void finalise() {
        Tensor[] curParams = this.parameters.get();
        Tensor[] newParams = new Tensor[curParams.length];
        for (int i = 0; i < newParams.length; ++i) {
            if (!(curParams[i] instanceof AdaGradRDATensor)) {
                throw new IllegalStateException("Finalising a Parameters which wasn't initialised with AdaGradRDA");
            }
            newParams[i] = ((AdaGradRDATensor)((Object)curParams[i])).convertToDense();
        }
        this.parameters.set(newParams);
    }

    public String toString() {
        return "AdaGradRDA(initialLearningRate=" + this.initialLearningRate + ",epsilon=" + this.epsilon + ",l1=" + this.l1 + ",l2=" + this.l2 + ")";
    }

    @Override
    public void reset() {
        this.parameters = null;
    }

    @Override
    public AdaGradRDA copy() {
        return new AdaGradRDA(this.initialLearningRate, this.epsilon, this.l1, this.l2, this.numExamples);
    }

    public ConfiguredObjectProvenance getProvenance() {
        return new ConfiguredObjectProvenanceImpl((Configurable)this, "StochasticGradientOptimiser");
    }

    private static class AdaGradRDAMatrix
    extends DenseMatrix
    implements AdaGradRDATensor {
        private final double learningRate;
        private final double epsilon;
        private final double l1;
        private final double l2;
        private final double[][] gradSquares;
        private int iteration;

        public AdaGradRDAMatrix(DenseMatrix v, double learningRate, double epsilon, double l1, double l2) {
            super(v);
            this.learningRate = learningRate;
            this.epsilon = epsilon;
            this.l1 = l1;
            this.l2 = l2;
            this.gradSquares = new double[v.getDimension1Size()][v.getDimension2Size()];
            this.iteration = 0;
        }

        @Override
        public DenseMatrix convertToDense() {
            return new DenseMatrix(this);
        }

        @Override
        public DenseVector leftMultiply(SGDVector input) {
            if (input.size() == this.dim2) {
                double[] output = new double[this.dim1];
                for (VectorTuple tuple : input) {
                    for (int i = 0; i < output.length; ++i) {
                        int n = i;
                        output[n] = output[n] + this.get(i, tuple.index) * tuple.value;
                    }
                }
                return DenseVector.createDenseVector(output);
            }
            throw new IllegalArgumentException("input.size() != dim2");
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        @Override
        public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
            if (!(other instanceof Matrix)) throw new IllegalStateException("Adding a non-Matrix to a Matrix");
            Matrix otherMat = (Matrix)other;
            if (this.dim1 != otherMat.getDimension1Size() || this.dim2 != otherMat.getDimension2Size()) throw new IllegalStateException("Matrices are not the same size, this(" + this.dim1 + "," + this.dim2 + "), other(" + otherMat.getDimension1Size() + "," + otherMat.getDimension2Size() + ")");
            for (MatrixTuple tuple : otherMat) {
                double update = f.applyAsDouble(tuple.value);
                double[] dArray = this.values[tuple.i];
                int n = tuple.j;
                dArray[n] = dArray[n] + update;
                double[] dArray2 = this.gradSquares[tuple.i];
                int n2 = tuple.j;
                dArray2[n2] = dArray2[n2] + update * update;
            }
        }

        @Override
        public double get(int i, int j) {
            if (this.gradSquares[i][j] == 0.0) {
                return this.values[i][j];
            }
            double h = (Math.sqrt(this.gradSquares[i][j]) + this.epsilon) / this.learningRate + (double)this.iteration * this.l2;
            double rate = 1.0 / h;
            return rate * AdaGradRDATensor.truncate(this.values[i][j], (double)this.iteration * this.l1);
        }

        @Override
        public MatrixIterator iterator() {
            return new RDAMatrixIterator(this);
        }

        private static class RDAMatrixIterator
        implements MatrixIterator {
            private final AdaGradRDAMatrix matrix;
            private final MatrixTuple tuple;
            private final int dim2;
            private int i;
            private int j;

            public RDAMatrixIterator(AdaGradRDAMatrix matrix) {
                this.matrix = matrix;
                this.tuple = new MatrixTuple();
                this.dim2 = matrix.dim2;
                this.i = 0;
                this.j = 0;
            }

            @Override
            public MatrixTuple getReference() {
                return this.tuple;
            }

            @Override
            public boolean hasNext() {
                return this.i < this.matrix.dim1 && this.j < this.matrix.dim2;
            }

            @Override
            public MatrixTuple next() {
                this.tuple.i = this.i;
                this.tuple.j = this.j;
                this.tuple.value = this.matrix.get(this.i, this.j);
                if (this.j < this.dim2 - 1) {
                    ++this.j;
                } else {
                    ++this.i;
                    this.j = 0;
                }
                return this.tuple;
            }
        }
    }

    private static class AdaGradRDAVector
    extends DenseVector
    implements AdaGradRDATensor {
        private final double learningRate;
        private final double epsilon;
        private final double l1;
        private final double l2;
        private final double[] gradSquares;
        private int iteration;

        public AdaGradRDAVector(DenseVector v, double learningRate, double epsilon, double l1, double l2) {
            super(v);
            this.learningRate = learningRate;
            this.epsilon = epsilon;
            this.l1 = l1;
            this.l2 = l2;
            this.gradSquares = new double[v.size()];
            this.iteration = 0;
        }

        private AdaGradRDAVector(double[] values, double learningRate, double epsilon, double l1, double l2, double[] gradSquares, int iteration) {
            super(values);
            this.learningRate = learningRate;
            this.epsilon = epsilon;
            this.l1 = l1;
            this.l2 = l2;
            this.gradSquares = gradSquares;
            this.iteration = iteration;
        }

        @Override
        public DenseVector convertToDense() {
            return DenseVector.createDenseVector(this.toArray());
        }

        @Override
        public AdaGradRDAVector copy() {
            return new AdaGradRDAVector(Arrays.copyOf(this.elements, this.elements.length), this.learningRate, this.epsilon, this.l1, this.l2, Arrays.copyOf(this.gradSquares, this.gradSquares.length), this.iteration);
        }

        @Override
        public double[] toArray() {
            double[] newValues = new double[this.elements.length];
            for (int i = 0; i < newValues.length; ++i) {
                newValues[i] = this.get(i);
            }
            return newValues;
        }

        @Override
        public double get(int index) {
            if (this.gradSquares[index] == 0.0) {
                return this.elements[index];
            }
            double h = (Math.sqrt(this.gradSquares[index]) + this.epsilon) / this.learningRate + (double)this.iteration * this.l2;
            double rate = 1.0 / h;
            return rate * AdaGradRDATensor.truncate(this.elements[index], (double)this.iteration * this.l1);
        }

        @Override
        public double sum() {
            double sum = 0.0;
            for (int i = 0; i < this.elements.length; ++i) {
                sum += this.get(i);
            }
            return sum;
        }

        @Override
        public void intersectAndAddInPlace(Tensor other, DoubleUnaryOperator f) {
            ++this.iteration;
            SGDVector otherVec = (SGDVector)other;
            for (VectorTuple tuple : otherVec) {
                double update = f.applyAsDouble(tuple.value);
                int n = tuple.index;
                this.elements[n] = this.elements[n] + update;
                int n2 = tuple.index;
                this.gradSquares[n2] = this.gradSquares[n2] + update * update;
            }
        }

        @Override
        public int indexOfMax() {
            int index = 0;
            double value = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < this.elements.length; ++i) {
                double tmp = this.get(i);
                if (!(tmp > value)) continue;
                index = i;
                value = tmp;
            }
            return index;
        }

        @Override
        public double maxValue() {
            double value = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < this.elements.length; ++i) {
                double tmp = this.get(i);
                if (!(tmp > value)) continue;
                value = tmp;
            }
            return value;
        }

        @Override
        public double minValue() {
            double value = Double.POSITIVE_INFINITY;
            for (int i = 0; i < this.elements.length; ++i) {
                double tmp = this.get(i);
                if (!(tmp < value)) continue;
                value = tmp;
            }
            return value;
        }

        @Override
        public double dot(SGDVector other) {
            double score = 0.0;
            for (VectorTuple tuple : other) {
                score += this.get(tuple.index) * tuple.value;
            }
            return score;
        }

        @Override
        public VectorIterator iterator() {
            return new RDAVectorIterator(this);
        }

        private static class RDAVectorIterator
        implements VectorIterator {
            private final AdaGradRDAVector vector;
            private final VectorTuple tuple;
            private int index;

            public RDAVectorIterator(AdaGradRDAVector vector) {
                this.vector = vector;
                this.tuple = new VectorTuple();
                this.index = 0;
            }

            @Override
            public boolean hasNext() {
                return this.index < this.vector.size();
            }

            @Override
            public VectorTuple next() {
                this.tuple.index = this.index;
                this.tuple.value = this.vector.get(this.index);
                ++this.index;
                return this.tuple;
            }

            @Override
            public VectorTuple getReference() {
                return this.tuple;
            }
        }
    }

    private static interface AdaGradRDATensor {
        public Tensor convertToDense();

        public static double truncate(double input, double threshold) {
            if (input > threshold) {
                return input - threshold;
            }
            if (input < -threshold) {
                return input + threshold;
            }
            return 0.0;
        }
    }
}

