package de.jstacs.sequenceScores.statisticalModels.differentiable.continuous;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.ArbitrarySequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.SparseStringExtractor;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.text.NumberFormat;
import org.apache.batik.svggen.SVGSyntax;
import umontreal.iro.lecuyer.util.Num;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/continuous/NegativeBinomial.class */
/**
 * Negative binomial model for a single continuous-valued position (sequence length 1).
 *
 * <p>The log-probability of an observed value {@code x} is
 * {@code lnGamma(x + r) - lnGamma(x + 1) - lnGamma(r) + x * ln(p) + r * ln(1 - p)}.
 *
 * <p>The free parameters handed to the numerical optimizer are {@code lambda = ln(r)} and
 * {@code beta = logit(p) = ln(p / (1 - p))}. Any real parameter vector therefore maps to
 * {@code r > 0} and {@code 0 <= p <= 1}.
 *
 * <p>NOTE(review): observed values are read with {@code Sequence.continuousVal(int)} and are
 * presumably non-negative integer counts; this is not validated.
 */
public class NegativeBinomial extends AbstractDifferentiableStatisticalModel {
    /** Default value of {@code p} used by the constructor and by random initialization. */
    private static final double DEFAULT_P = 0.9d;
    /** Default value of {@code r} used by the constructor and by random initialization. */
    private static final double DEFAULT_R = 1.0d;

    /** Shape parameter r; always {@code exp(lambda)}. */
    private double r;
    /** Optimizer parameter {@code ln(r)}. */
    private double lambda;
    /** Probability parameter p; always {@code logistic(beta)}. */
    private double p;
    /** Optimizer parameter {@code logit(p)}. */
    private double beta;

    /**
     * Creates a negative binomial model over a continuous alphabet with sequence length 1.
     *
     * @throws IllegalArgumentException if the superclass rejects the alphabet/length
     */
    public NegativeBinomial() throws IllegalArgumentException {
        super(new AlphabetContainer(new ContinuousAlphabet()), 1);
        // isInitialized() always reports true, so the parameters must never be left at
        // their zero field defaults (r = 0, p = 0), where the log-score is not finite.
        setDefaultParameters();
    }

    /** Resets all parameters to the defaults {@code p = DEFAULT_P}, {@code r = DEFAULT_R}. */
    private void setDefaultParameters() {
        this.p = DEFAULT_P;
        this.beta = Math.log(this.p / (1.0d - this.p));
        this.r = DEFAULT_R;
        this.lambda = Math.log(this.r);
    }

    @Override
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public NegativeBinomial mo112clone() throws CloneNotSupportedException {
        // All fields declared in this class are primitives; nothing needs deep-copying here.
        return (NegativeBinomial) super.mo112clone();
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int i) {
        return 1;
    }

    /**
     * Returns the log normalization constant. It is 0 because
     * {@link #getLogScoreFor(Sequence, int)} already returns a normalized log-probability.
     */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0d;
    }

    /**
     * Returns the log of the partial derivative of the normalization constant. The constant does
     * not depend on any parameter, so every partial derivative is 0, i.e. {@code ln(0) = -inf}.
     */
    @Override
    public double getLogPartialNormalizationConstant(int i) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /** No prior is used: the log prior term is always 0. */
    @Override
    public double getLogPriorTerm() {
        return 0.0d;
    }

    /** No prior is used, so no gradient contribution is added. */
    @Override
    public void addGradientOfLogPriorTerm(double[] dArr, int i) throws Exception {
    }

    /** No prior is used, so the equivalent sample size is 0. */
    @Override
    public double getESS() {
        return 0.0d;
    }

    /**
     * Initializes the parameters. The data and weights are ignored; this delegates to
     * {@link #initializeFunctionRandomly(boolean)}.
     */
    @Override
    public void initializeFunction(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        initializeFunctionRandomly(z);
    }

    /**
     * Resets the parameters to fixed defaults ({@code p = 0.9}, {@code r = 1}). Despite the
     * name, no randomness is involved.
     */
    @Override
    public void initializeFunctionRandomly(boolean z) throws Exception {
        setDefaultParameters();
    }

    /**
     * Returns the log-score of position {@code i} and appends the gradient with respect to
     * {@code lambda} (index 0) and {@code beta} (index 1).
     *
     * @param sequence   the sequence to evaluate
     * @param i          the position of the observed value
     * @param intList    receives the parameter indices of the partial derivatives
     * @param doubleList receives the partial derivatives, aligned with {@code intList}
     * @return the log-score, as computed by {@link #getLogScoreFor(Sequence, int)}
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        double x = sequence.continuousVal(i);
        // lambda = ln r, so d/dlambda = r * d/dr = r * (psi(x + r) - psi(r) + ln(1 - p)).
        intList.add(0);
        doubleList.add(this.r * (Num.digamma(x + this.r) - Num.digamma(this.r) + Math.log1p(-this.p)));
        // p = logistic(beta), so dp/dbeta = p(1 - p); combined with d/dp = x/p - r/(1 - p)
        // this simplifies to x(1 - p) - r p.
        intList.add(1);
        doubleList.add(x * (1.0d - this.p) - this.r * this.p);
        return getLogScoreFor(sequence, i);
    }

    /** Two optimizer parameters: {@code lambda} and {@code beta}. */
    @Override
    public int getNumberOfParameters() {
        return 2;
    }

    /** Returns {@code {lambda, beta}} in the same order used by {@link #setParameters}. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return new double[]{this.lambda, this.beta};
    }

    /**
     * Reads {@code lambda = dArr[i]} and {@code beta = dArr[i + 1]} and updates the derived
     * values {@code r = exp(lambda)} and {@code p = logistic(beta)}.
     */
    @Override
    public void setParameters(double[] dArr, int i) {
        this.lambda = dArr[i];
        this.r = Math.exp(this.lambda);
        this.beta = dArr[i + 1];
        this.p = logistic(this.beta);
    }

    /**
     * Numerically stable logistic function {@code 1 / (1 + exp(-x))}. The naive form
     * {@code exp(x) / (1 + exp(x))} yields {@code inf / inf = NaN} for large {@code x}.
     */
    private static double logistic(double x) {
        if (x >= 0.0d) {
            return 1.0d / (1.0d + Math.exp(-x));
        }
        double e = Math.exp(x);
        return e / (1.0d + e);
    }

    @Override
    public String getInstanceName() {
        return "NB";
    }

    /**
     * Returns the negative binomial log-probability of the value at position {@code i}:
     * {@code lnGamma(x + r) - lnGamma(x + 1) - lnGamma(r) + x * ln(p) + r * ln(1 - p)}.
     */
    @Override
    public double getLogScoreFor(Sequence sequence, int i) {
        double x = sequence.continuousVal(i);
        return Num.lnGamma(x + this.r) - Num.lnGamma(x + 1.0d) - Num.lnGamma(this.r)
                + x * Math.log(this.p)
                + this.r * Math.log1p(-this.p);
    }

    /** Always true; the constructor assigns valid default parameters. */
    @Override
    public boolean isInitialized() {
        return true;
    }

    /**
     * Returns {@code "NB(r,p)"}, formatting both numbers with {@code numberFormat}. If
     * {@code numberFormat} is null, falls back to {@link String#valueOf(double)}.
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        return "NB(" + format(numberFormat, this.r) + "," + format(numberFormat, this.p) + ")";
    }

    /** Formats {@code value} with {@code nf}, or with {@link String#valueOf(double)} if {@code nf} is null. */
    private static String format(NumberFormat nf, double value) {
        return nf == null ? String.valueOf(value) : nf.format(value);
    }

    /**
     * NOTE(review): XML serialization is not implemented; this returns null, so storing this
     * model via {@code Storable} will not work.
     */
    @Override
    public StringBuffer toXML() {
        return null;
    }

    /** NOTE(review): XML deserialization is not implemented; this is a no-op. */
    @Override
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
    }

    /**
     * Demo driver: fits the model by maximum likelihood to the data in a sparse text file and
     * prints the model before and after training.
     *
     * @param strArr optional; {@code strArr[0]} is the data file path ('#' starts a comment line).
     *               Defaults to the original hard-coded path if absent.
     * @throws Exception if reading the data or training fails
     */
    public static void main(String[] strArr) throws Exception {
        String path = strArr.length > 0 ? strArr[0] : "/Users/dev/Downloads/allnb.txt";
        AlphabetContainer continuous = new AlphabetContainer(new ContinuousAlphabet());
        DataSet dataSet = new DataSet(continuous, new SparseStringExtractor(path, '#'));

        NegativeBinomial negativeBinomial = new NegativeBinomial();
        negativeBinomial.initializeFunctionRandomly(false);
        System.out.println(negativeBinomial);
        System.out.println(negativeBinomial.getLogScoreFor(new ArbitrarySequence(continuous, 2.0d)));

        // NOTE(review): (byte) 18 selects the optimization algorithm; the remaining numeric
        // arguments are tolerances/step settings — confirm against GenDisMixClassifierParameterSet.
        GenDisMixClassifierParameterSet params = new GenDisMixClassifierParameterSet(
                continuous, 1, (byte) 18, 1.0E-10d, 1.0E-10d, 1.0E-4d, false,
                OptimizableFunction.KindOfParameter.LAST, true, 1);
        GenDisMixClassifier genDisMixClassifier = new GenDisMixClassifier(
                params, DoesNothingLogPrior.defaultInstance, LearningPrinciple.ML,
                negativeBinomial, new ConstantDiffSM(1));
        genDisMixClassifier.train(dataSet, dataSet);
        System.out.println(genDisMixClassifier);

        NegativeBinomial trained = (NegativeBinomial) genDisMixClassifier.getDifferentiableSequenceScore(0);
        System.out.println(trained);
        System.out.println(trained.getLogScoreFor(new ArbitrarySequence(continuous, 2.0d)));
    }
}
