package de.jstacs.sequenceScores.statisticalModels.differentiable.continuous;

import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.DifferentiableStatisticalModelWrapperTrainSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Random;
import projects.dispom.DispomParameterSet;
import umontreal.iro.lecuyer.util.Num;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/continuous/DirichletDiffSM.class */
/**
 * Differentiable statistical model of a Dirichlet density over the continuous values of a sequence.
 * <p>
 * A (sub-)sequence of {@code length} continuous values {@code x_0 ... x_{length-1}} is interpreted as the
 * first {@code length} coordinates of a point on the probability simplex; the last coordinate is implied as
 * {@code x_length = 1 - sum(x_0 ... x_{length-1})}. The log-score is the Dirichlet log-density
 * <pre>
 *   lnGamma(sum_i alpha_i) - sum_i lnGamma(alpha_i) + sum_i (alpha_i - 1) * log(x_i)
 * </pre>
 * with {@code length + 1} concentration parameters {@code alpha_i = exp(pars[i])}. Optimizing over
 * {@code pars} rather than {@code alpha} keeps the concentrations positive.
 */
public class DirichletDiffSM extends AbstractDifferentiableStatisticalModel {
    /** Random source used by {@link #initializeFunctionRandomly(boolean)}. */
    private static final Random RANDOM = new Random();
    /** {@code lnGamma(sum(alpha)) - sum_i lnGamma(alpha_i)} for the current {@code alpha = exp(pars)}. */
    private double logNorm;
    /** {@code partNorm[i]} is the derivative of {@link #logNorm} with respect to {@code pars[i]}. */
    private double[] partNorm;
    /** Equivalent sample size; scales the prior in {@link #getLogPriorTerm()}. */
    private double ess;
    /** The free parameters, i.e. the logarithms of the Dirichlet concentration parameters. */
    private double[] pars;
    /** Cached {@code exp(pars[i])}, i.e. the Dirichlet concentration parameters. */
    private double[] expPars;
    /** Hyper-parameters of the prior; one entry per parameter. */
    private double[] chi;
    private boolean isInitialized;
    /** Value returned by {@link #getNumberOfRecommendedStarts()}. */
    private int numStarts;

    /**
     * Small demonstration run: samples 10000 points from a 3-component Dirichlet using
     * {@link DirichletMRG} (presumably with concentration parameters 2.0, 0.1, 4.0 — TODO confirm the
     * meaning of the {@link DirichletMRGParams} arguments), keeps the first two coordinates of each point,
     * trains a length-2 {@code DirichletDiffSM} through a
     * {@link DifferentiableStatisticalModelWrapperTrainSM}, and prints the trained model.
     *
     * @param strArr ignored
     * @throws Exception if sampling or training fails
     */
    public static void main(String[] strArr) throws Exception {
        DirichletMRGParams dirichletMRGParams = new DirichletMRGParams(2.0d, 0.1d, 4.0d);
        LinkedList linkedList = new LinkedList();
        AlphabetContainer alphabetContainer = new AlphabetContainer(new ContinuousAlphabet());
        for (int i = 0; i < 10000; i++) {
            // The last coordinate is implied by the model, so only the first two are stored.
            linkedList.add(new ArbitrarySequence(alphabetContainer, DirichletMRG.DEFAULT_INSTANCE.generate(3, dirichletMRGParams)).getSubSequence(0, 2));
        }
        DataSet dataSet = new DataSet("", linkedList);
        DifferentiableStatisticalModelWrapperTrainSM differentiableStatisticalModelWrapperTrainSM = new DifferentiableStatisticalModelWrapperTrainSM(new DirichletDiffSM(alphabetContainer, 2, new double[]{10.0d, 0.001d, 0.001d}, 1.0d, 10), 1, (byte) 20, new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-6d), 1.0E-6d, 1.0E-4d);
        differentiableStatisticalModelWrapperTrainSM.train(dataSet);
        System.out.println(differentiableStatisticalModelWrapperTrainSM);
    }

    /**
     * Creates a new, not yet initialized Dirichlet model with all concentration parameters equal to 1
     * ({@code pars} all 0).
     *
     * @param alphabetContainer the alphabet container (continuous values are read from the sequences)
     * @param length the number of explicitly stored coordinates; the model has {@code length + 1} parameters
     * @param chi the prior hyper-parameters, exactly {@code length + 1} values (copied)
     * @param ess the equivalent sample size scaling the prior
     * @param numStarts the number of recommended optimization starts
     * @throws IllegalArgumentException if {@code chi.length != length + 1}
     */
    public DirichletDiffSM(AlphabetContainer alphabetContainer, int length, double[] chi, double ess, int numStarts) throws IllegalArgumentException {
        super(alphabetContainer, length);
        if (chi.length != length + 1) {
            throw new IllegalArgumentException("chi must contain length + 1 = " + (length + 1) + " values, but contains " + chi.length);
        }
        this.pars = new double[length + 1];
        this.ess = ess;
        this.expPars = new double[length + 1];
        this.chi = chi.clone();
        this.numStarts = numStarts;
        // Derive expPars, logNorm and partNorm from pars (all 0, i.e. alpha = 1) so that scores and
        // gradients are consistent and do not hit a null partNorm before the model is initialized.
        precompute();
    }

    /**
     * Returns a deep copy of this model; all parameter arrays are cloned.
     *
     * @return the copy
     * @throws CloneNotSupportedException if the super class cannot be cloned
     */
    @Override
    /* renamed from: clone */
    public DirichletDiffSM mo106clone() throws CloneNotSupportedException {
        DirichletDiffSM copy = (DirichletDiffSM) super.mo106clone();
        copy.chi = this.chi.clone();
        copy.expPars = this.expPars.clone();
        copy.pars = this.pars.clone();
        if (this.partNorm != null) {
            copy.partNorm = this.partNorm.clone();
        }
        return copy;
    }

    /**
     * Restores a model from its XML representation as produced by {@link #toXML()}.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public DirichletDiffSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /** Returns 1 for every parameter. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int i) {
        return 1;
    }

    /** Returns the number of starts passed to the constructor. */
    @Override
    public int getNumberOfRecommendedStarts() {
        return this.numStarts;
    }

    /**
     * Returns 0: the log-score already includes the Dirichlet normalization ({@link #logNorm}), so no
     * further normalization is needed.
     */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0d;
    }

    /**
     * Returns negative infinity, the logarithm of the (zero) derivative of the constant
     * normalization returned by {@link #getLogNormalizationConstant()}.
     */
    @Override
    public double getLogPartialNormalizationConstant(int i) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /** Returns the equivalent sample size passed to the constructor. */
    @Override
    public double getESS() {
        return this.ess;
    }

    /**
     * Returns the log-prior of the current parameters,
     * {@code ess * logNorm + sum_i (ess * chi[i] * alpha_i + pars[i])}. The {@code pars[i]} summand equals the
     * log-Jacobian {@code log(d alpha_i / d pars[i])} of the reparametrization {@code alpha_i = exp(pars[i])}.
     */
    @Override
    public double getLogPriorTerm() {
        double logPrior = this.logNorm * this.ess;
        for (int i = 0; i < this.pars.length; i++) {
            logPrior += (this.ess * this.chi[i] * this.expPars[i]) + this.pars[i];
        }
        return logPrior;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to {@code pars} to {@code grad}, starting
     * at index {@code start}.
     *
     * @param grad the array the gradient is added to
     * @param start the index in {@code grad} of the first parameter of this model
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        for (int i = 0; i < this.pars.length; i++) {
            grad[start + i] += (this.ess * (this.partNorm[i] + (this.chi[i] * this.expPars[i]))) + 1.0d;
        }
    }

    /**
     * Initializes the parameters from the weighted data of class {@code index} with a heuristic.
     * <p>
     * The weighted mean of each simplex coordinate is computed, with the last coordinate taken as one minus the
     * sum of the stored values. The means are normalized to sum to one, and the concentrations are set to
     * {@code alpha_i = 2 * mean_i + 2}.
     *
     * @param index the index of the data set (and weights) to use
     * @param freeParams ignored by this model
     * @param data the data sets
     * @param weights the sequence weights; {@code null} (or a {@code null} entry for {@code index}) means
     *        every sequence has weight 1
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        double[] seqWeights = (weights == null) ? null : weights[index];
        DataSet current = data[index];
        double[] mean = new double[this.length + 1];
        for (int n = 0; n < current.getNumberOfElements(); n++) {
            Sequence seq = current.getElementAt(n);
            double w = (seqWeights == null) ? 1.0d : seqWeights[n];
            double remainder = 1.0d;
            for (int j = 0; j < this.length; j++) {
                double value = seq.continuousVal(j);
                remainder -= value;
                mean[j] += value * w;
            }
            // The last simplex coordinate is implied by the stored ones.
            mean[this.length] += remainder * w;
        }
        Normalisation.sumNormalisation(mean);
        for (int j = 0; j < mean.length; j++) {
            mean[j] = Math.log((mean[j] * 2.0d) + 2.0d);
        }
        setParameters(mean, 0);
        this.isInitialized = true;
    }

    /**
     * Sets every parameter to {@code 1 + |N(0,1)|}, i.e. every concentration parameter to at least {@code e}.
     *
     * @param freeParams ignored by this model
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        for (int i = 0; i < this.pars.length; i++) {
            this.pars[i] = 1.0d + Math.abs(RANDOM.nextGaussian());
        }
        precompute();
        this.isInitialized = true;
    }

    /** Returns the simple class name. */
    @Override
    public String getInstanceName() {
        return getClass().getSimpleName();
    }

    /**
     * Returns the Dirichlet log-density of the point given by the {@code length} values of {@code sequence}
     * starting at {@code start}, plus the implied last coordinate. An infinite result is also reported on
     * standard output (debug output).
     *
     * @param sequence the sequence providing continuous values
     * @param start the first position to read
     * @return the log-density
     */
    @Override
    public double getLogScoreFor(Sequence sequence, int start) {
        int last = this.pars.length - 1;
        double logScore = this.logNorm;
        double remainder = 1.0d;
        for (int j = 0; j < last; j++) {
            double value = sequence.continuousVal(start + j);
            logScore += Math.log(value) * (this.expPars[j] - 1.0d);
            remainder -= value;
        }
        double log = logScore + (Math.log(remainder) * (this.expPars[last] - 1.0d));
        if (Double.isInfinite(log)) {
            System.out.println(String.valueOf(Arrays.toString(this.pars)) + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + Arrays.toString(this.expPars) + AbstractFormatter.DEFAULT_COLUMN_SEPARATOR + sequence);
        }
        return log;
    }

    /**
     * Returns the same log-density as {@link #getLogScoreFor(Sequence, int)} (without debug output) and
     * appends, for every parameter {@code i}, the index {@code i} to {@code indices} and
     * {@code partNorm[i] + alpha_i * log(x_i)} (the derivative with respect to {@code pars[i]}) to
     * {@code partialDer}.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int start, IntList indices, DoubleList partialDer) {
        int last = this.pars.length - 1;
        double logScore = this.logNorm;
        double remainder = 1.0d;
        for (int j = 0; j <= last; j++) {
            double logValue;
            if (j < last) {
                double value = sequence.continuousVal(start + j);
                logValue = Math.log(value);
                remainder -= value;
            } else {
                logValue = Math.log(remainder);
            }
            logScore += logValue * (this.expPars[j] - 1.0d);
            indices.add(j);
            partialDer.add(this.partNorm[j] + (logValue * this.expPars[j]));
        }
        return logScore;
    }

    /** Returns {@code length + 1}, the number of concentration parameters. */
    @Override
    public int getNumberOfParameters() {
        return this.pars.length;
    }

    /** Returns a copy of the current (log-scale) parameters. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.pars.clone();
    }

    /**
     * Copies {@code getNumberOfParameters()} values from {@code params}, starting at {@code start}, into the
     * parameters and updates all cached quantities.
     */
    @Override
    public void setParameters(double[] params, int start) {
        System.arraycopy(params, start, this.pars, 0, this.pars.length);
        precompute();
    }

    /**
     * Recomputes {@link #expPars}, {@link #logNorm} and {@link #partNorm} from {@link #pars}.
     *
     * @throws RuntimeException (with the failing parameters in the message and the original exception as
     *         cause) if any of the computations fails
     */
    private void precompute() {
        try {
            if (this.partNorm == null) {
                this.partNorm = new double[this.pars.length];
            }
            for (int i = 0; i < this.pars.length; i++) {
                this.expPars[i] = Math.exp(this.pars[i]);
            }
            // Loop-invariant: sum of concentrations and its digamma.
            double alphaSum = ToolBox.sum(this.expPars);
            double digammaOfSum = Num.digamma(alphaSum);
            double norm = Num.lnGamma(alphaSum);
            for (int i = 0; i < this.pars.length; i++) {
                norm -= Num.lnGamma(this.expPars[i]);
                // d/d pars[i] of logNorm = alpha_i * (digamma(sum alpha) - digamma(alpha_i))
                this.partNorm[i] = (this.expPars[i] * digammaOfSum) - (this.expPars[i] * Num.digamma(this.expPars[i]));
            }
            this.logNorm = norm;
        } catch (Exception e) {
            throw new RuntimeException("Could not compute Dirichlet normalization for pars=" + Arrays.toString(this.pars) + ", expPars=" + Arrays.toString(this.expPars), e);
        }
    }

    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Serializes alphabets, length, ess, pars, chi, isInitialized and numStarts inside a {@code Dirichlet}
     * tag. Cached values are not stored; {@link #fromXML(StringBuffer)} recomputes them.
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.length), DispomParameterSet.LENGTH);
        XMLParser.appendObjectWithTags(xml, Double.valueOf(this.ess), "ess");
        XMLParser.appendObjectWithTags(xml, this.pars, "pars");
        XMLParser.appendObjectWithTags(xml, this.chi, "chi");
        XMLParser.appendObjectWithTags(xml, Boolean.valueOf(this.isInitialized), "isInitialized");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.numStarts), "numStarts");
        XMLParser.addTags(xml, "Dirichlet");
        return xml;
    }

    /**
     * Restores the fields written by {@link #toXML()} and recomputes the cached quantities.
     */
    @Override
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(stringBuffer, "Dirichlet");
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(xml, "alphabets");
        this.length = ((Integer) XMLParser.extractObjectForTags(xml, DispomParameterSet.LENGTH, Integer.TYPE)).intValue();
        this.ess = ((Double) XMLParser.extractObjectForTags(xml, "ess", Double.TYPE)).doubleValue();
        this.pars = (double[]) XMLParser.extractObjectForTags(xml, "pars");
        this.chi = (double[]) XMLParser.extractObjectForTags(xml, "chi");
        this.isInitialized = ((Boolean) XMLParser.extractObjectForTags(xml, "isInitialized", Boolean.TYPE)).booleanValue();
        this.numStarts = ((Integer) XMLParser.extractObjectForTags(xml, "numStarts", Integer.TYPE)).intValue();
        this.expPars = new double[this.pars.length];
        // Always size partNorm to the restored parameter vector.
        this.partNorm = new double[this.pars.length];
        precompute();
    }

    /** Returns the current concentration parameters {@code exp(pars)}; {@code numberFormat} is not used. */
    @Override
    public String toString(NumberFormat numberFormat) {
        return Arrays.toString(this.expPars);
    }
}
