package projects.tals;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;
import projects.dispom.DispomParameterSet;

/* loaded from: input_file:projects/tals/TALgetterMixture.class */
public class TALgetterMixture extends AbstractDifferentiableStatisticalModel {
    /** Equivalent sample size, as passed to the constructor and returned by {@link #getESS()}. */
    private double ess;
    /** Whether the parameters have been set (by random initialization or by deserialization). */
    private boolean isInitialized;
    /**
     * Logit-scale parameters. {@code probs[k]} is the logistic transform of {@code params[k]}
     * (see {@link #setParameters(double[], int)}). The last entry belongs to the extra,
     * position-dependent term; the other entries are indexed by symbol code of alphabet 0.
     */
    private double[] params;
    /** Probabilities in [0, 1]; {@code probs[k] = logistic(params[k])}. */
    private double[] probs;
    /** First Beta hyper-parameter {@code a_k} per entry. */
    private double[] HyperParams;
    /** Sum of both Beta hyper-parameters {@code a_k + b_k} per entry. */
    private double[] HyperSum;
    /**
     * ("Anzahl" = count.) When {@link #p_gesamte_seq} is false, the extra term is applied at
     * positions {@code i >= sequence.getLength() - p_anz}.
     * NOTE(review): only set via {@link #fromXML(StringBuffer)}; no constructor assigns it.
     */
    private int p_anz;
    /**
     * ("gesamte Sequenz" = whole sequence.) If true, the extra term is applied at every position.
     * NOTE(review): only set via {@link #fromXML(StringBuffer)}; no constructor assigns it.
     */
    private boolean p_gesamte_seq;

    /**
     * Creates a new, uninitialized model over the first alphabet of {@code alphabetContainer}.
     *
     * <p>Every entry gets the Beta hyper-parameter sum {@code d * i / |alphabet|}. Entry
     * {@code k < dArr.length} gets first hyper-parameter {@code HyperSum[k] * dArr[k]}. The last
     * (extra) entry gets {@code 0.5 * HyperSum[last]}. Entries not covered by {@code dArr} keep a
     * first hyper-parameter of 0.
     *
     * @param alphabetContainer the alphabets; only position 0 is used
     * @param i integer factor of the hyper-parameter sum
     * @param d the equivalent sample size
     * @param dArr per-symbol fractions of the hyper-parameter sum; at most {@code |alphabet| + 1} long
     * @throws Exception if the alphabet at position 0 is empty
     * @throws IllegalArgumentException if {@code dArr} is longer than {@code |alphabet| + 1}
     */
    public TALgetterMixture(AlphabetContainer alphabetContainer, int i, double d, double[] dArr) throws Exception {
        super(alphabetContainer, 1);
        this.isInitialized = false;
        double alphabetLength = alphabetContainer.getAlphabetLengthAt(0);
        if (alphabetLength <= 0.0d) {
            throw new Exception("Alphabet wrong");
        }
        int size = ((int) alphabetLength) + 1;
        if (dArr.length > size) {
            throw new IllegalArgumentException(
                    "Expected at most " + size + " hyper-parameter fractions, got " + dArr.length);
        }
        this.ess = d;
        this.HyperParams = new double[size];
        this.HyperSum = new double[size];
        Arrays.fill(this.HyperSum, (d * i) / alphabetLength);
        for (int k = 0; k < dArr.length; k++) {
            this.HyperParams[k] = this.HyperSum[k] * dArr[k];
        }
        this.HyperParams[size - 1] = 0.5d * this.HyperSum[size - 1];
        this.params = new double[size];
        this.probs = new double[size];
    }

    /**
     * Restores a model from its XML representation (see {@link #toXML()}).
     *
     * <p>The super constructor calls {@link #fromXML(StringBuffer)}. This constructor therefore
     * must not reassign any field afterwards. In particular, {@code isInitialized} is not reset
     * here; resetting it would discard the stored initialization state.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public TALgetterMixture(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns a deep copy of this model. All parameter and hyper-parameter arrays are cloned.
     *
     * <p>Decompiler artifact: this is the covariant {@code clone()} override. The original source
     * comment says it was renamed after being merged with its bridge method.
     */
    @Override
    public TALgetterMixture mo130clone() throws CloneNotSupportedException {
        TALgetterMixture copy = (TALgetterMixture) super.mo130clone();
        copy.params = this.params.clone();
        copy.probs = this.probs.clone();
        copy.HyperSum = this.HyperSum.clone();
        copy.HyperParams = this.HyperParams.clone();
        return copy;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to the logit parameters to
     * {@code dArr}, starting at offset {@code i}.
     *
     * <p>For entry {@code k} this is {@code HyperParams[k] - HyperSum[k] * probs[k]}. The last
     * (extra) parameter is skipped, matching the prior term.
     *
     * @param dArr the gradient vector to add to
     * @param i the offset of this model's first parameter in {@code dArr}
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] dArr, int i) throws Exception {
        int priorEntries = getNumberOfParameters() - 1;
        for (int k = 0; k < priorEntries; k++) {
            dArr[i + k] += this.HyperParams[k] - this.HyperSum[k] * this.probs[k];
        }
    }

    /** Returns the equivalent sample size given at construction. */
    @Override
    public double getESS() {
        return this.ess;
    }

    /** Returns 0, i.e. no log normalization constant is applied. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0d;
    }

    /** Returns negative infinity (log of 0) for every parameter index. */
    @Override
    public double getLogPartialNormalizationConstant(int i) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Returns the log prior. For every entry except the last (extra) one, the summand is
     * {@code a*log(p) + b*log(1-p) - logB(a, b)}, where {@code a = HyperParams[k]},
     * {@code b = HyperSum[k] - a} and {@code p = probs[k]}. This is the log density of a
     * Beta(a, b) prior on {@code probs[k]} expressed in the logit parameter {@code params[k]}.
     *
     * <p>NOTE(review): the constructor assigns a hyper-parameter to the last entry, but the last
     * entry never enters the prior or its gradient. Confirm that this is intended.
     */
    @Override
    public double getLogPriorTerm() {
        double logPrior = 0.0d;
        for (int k = 0; k < this.params.length - 1; k++) {
            double a = this.HyperParams[k];
            double b = this.HyperSum[k] - a;
            double logBeta = (Gamma.logOfGamma(a) + Gamma.logOfGamma(b)) - Gamma.logOfGamma(this.HyperSum[k]);
            logPrior = (logPrior + (a * Math.log(this.probs[k]) + b * Math.log1p(-this.probs[k]))) - logBeta;
        }
        return logPrior;
    }

    /** Returns 0 for every parameter. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int i) {
        return 0;
    }

    /** Returns a copy of the current logit parameters. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.params.clone();
    }

    @Override
    public String getInstanceName() {
        return "TALgetterMixture";
    }

    /**
     * Returns the probability for the symbol at position {@code i} of {@code sequence} itself.
     * The reference-sequence annotation is not used here.
     */
    public double getImportance(Sequence sequence, int i) {
        return this.probs[sequence.discreteVal(i)];
    }

    /**
     * Returns the log score for position {@code i}.
     *
     * <p>The score is {@code log(probs[s])}, where {@code s} is the symbol at position
     * {@code i - 1} of the reference sequence annotated on {@code sequence}. Where the extra term
     * applies, {@code n * log(probs[last])} is added, with {@code n} from
     * {@link #extraTermCount(Sequence, int)}.
     */
    @Override
    public double getLogScoreFor(Sequence sequence, int i) {
        int symbol = referenceSymbolBefore(sequence, i);
        double logScore = Math.log(this.probs[symbol]);
        if (!extraTermApplies(sequence, i)) {
            return logScore;
        }
        return logScore + (extraTermCount(sequence, i) * Math.log(this.probs[this.probs.length - 1]));
    }

    /**
     * Same score as {@link #getLogScoreFor(Sequence, int)}. In addition, it appends the partial
     * derivatives with respect to the logit parameters to {@code intList} and
     * {@code doubleList}.
     *
     * <p>The derivative of {@code log(logistic(x))} is {@code 1 - p}. If the extra term applies,
     * its entry (the last index) is appended first, then the entry of the reference symbol.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        int symbol = referenceSymbolBefore(sequence, i);
        double logScore = Math.log(this.probs[symbol]);
        if (extraTermApplies(sequence, i)) {
            int last = this.probs.length - 1;
            int count = extraTermCount(sequence, i);
            logScore += count * Math.log(this.probs[last]);
            intList.add(last);
            doubleList.add(count * (1.0d - this.probs[last]));
        }
        intList.add(symbol);
        doubleList.add(1.0d - this.probs[symbol]);
        return logScore;
    }

    /** Returns the symbol code at position {@code i - 1} of the annotated reference sequence. */
    private static int referenceSymbolBefore(Sequence sequence, int i) {
        ReferenceSequenceAnnotation annotation = (ReferenceSequenceAnnotation)
                sequence.getSequenceAnnotationByType(ReferenceSequenceAnnotation.TYPE, 0);
        return annotation.getReferenceSequence().discreteVal(i - 1);
    }

    /** Returns whether the extra term {@code log(probs[last])} contributes at position {@code i}. */
    private boolean extraTermApplies(Sequence sequence, int i) {
        return this.p_gesamte_seq || i >= sequence.getLength() - this.p_anz;
    }

    /** Returns how many times {@code log(probs[last])} is added at position {@code i}. */
    private int extraTermCount(Sequence sequence, int i) {
        return this.p_gesamte_seq ? i - 1 : i - ((sequence.getLength() - 1) - this.p_anz);
    }

    @Override
    public int getNumberOfParameters() {
        return this.params.length;
    }

    /** Ignores all arguments except {@code z} and delegates to {@link #initializeFunctionRandomly(boolean)}. */
    @Override
    public void initializeFunction(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        initializeFunctionRandomly(z);
    }

    /**
     * Draws {@code r} from {@link Math#random()} for every entry. It sets {@code params = log(r)}
     * and {@code probs = r / (r + 1)}, the logistic transform of {@code log(r)}.
     *
     * <p>NOTE(review): {@code Math.random()} may return 0, which yields a parameter of
     * {@code -Infinity} and a probability of 0.
     */
    @Override
    public void initializeFunctionRandomly(boolean z) throws Exception {
        for (int k = 0; k < this.params.length; k++) {
            double r = Math.random();
            this.params[k] = Math.log(r);
            this.probs[k] = r / (r + 1.0d);
        }
        this.isInitialized = true;
    }

    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Copies {@link #getNumberOfParameters()} logit parameters from {@code dArr}, starting at
     * offset {@code i}, and recomputes the probabilities. The logistic is evaluated stably, so
     * large or infinite parameters (as stored by {@code addAndSet}) map to 1 instead of NaN.
     */
    @Override
    public void setParameters(double[] dArr, int i) {
        for (int k = 0; k < getNumberOfParameters(); k++) {
            this.params[k] = dArr[i + k];
            this.probs[k] = logistic(this.params[k]);
        }
    }

    /**
     * Logistic function {@code 1 / (1 + exp(-x))}. For {@code x >= 0} it evaluates the form whose
     * {@code exp} cannot overflow; for {@code x < 0} it uses {@code exp(x) / (1 + exp(x))}.
     * Results: {@code +Infinity -> 1}, {@code -Infinity -> 0}, {@code NaN -> NaN}.
     */
    private static double logistic(double x) {
        if (x >= 0.0d) {
            return 1.0d / (1.0d + Math.exp(-x));
        }
        double e = Math.exp(x);
        return e / (1.0d + e);
    }

    /** Serializes all fields, wrapped in a {@code TALMSF} tag. */
    @Override
    public StringBuffer toXML() {
        StringBuffer stringBuffer = new StringBuffer();
        XMLParser.appendObjectWithTags(stringBuffer, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(stringBuffer, Integer.valueOf(this.length), DispomParameterSet.LENGTH);
        XMLParser.appendObjectWithTags(stringBuffer, Double.valueOf(this.ess), "ess");
        XMLParser.appendObjectWithTags(stringBuffer, this.HyperParams, "hyperParams");
        XMLParser.appendObjectWithTags(stringBuffer, this.HyperSum, "hyperSum");
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.isInitialized), "isInitialized");
        XMLParser.appendObjectWithTags(stringBuffer, Integer.valueOf(this.p_anz), "panz");
        XMLParser.appendObjectWithTags(stringBuffer, Boolean.valueOf(this.p_gesamte_seq), "pgesseq");
        XMLParser.appendObjectWithTags(stringBuffer, this.params, "params");
        XMLParser.appendObjectWithTags(stringBuffer, this.probs, "probs");
        XMLParser.addTags(stringBuffer, "TALMSF");
        return stringBuffer;
    }

    /** Restores all fields written by {@link #toXML()}. */
    @Override
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer content = XMLParser.extractForTag(stringBuffer, "TALMSF");
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(content, "alphabets");
        this.length = ((Integer) XMLParser.extractObjectForTags(content, DispomParameterSet.LENGTH, Integer.TYPE)).intValue();
        this.ess = ((Double) XMLParser.extractObjectForTags(content, "ess", Double.TYPE)).doubleValue();
        this.HyperParams = (double[]) XMLParser.extractObjectForTags(content, "hyperParams");
        this.HyperSum = (double[]) XMLParser.extractObjectForTags(content, "hyperSum");
        this.isInitialized = ((Boolean) XMLParser.extractObjectForTags(content, "isInitialized", Boolean.TYPE)).booleanValue();
        this.p_anz = ((Integer) XMLParser.extractObjectForTags(content, "panz", Integer.TYPE)).intValue();
        this.p_gesamte_seq = ((Boolean) XMLParser.extractObjectForTags(content, "pgesseq", Boolean.TYPE)).booleanValue();
        this.params = (double[]) XMLParser.extractObjectForTags(content, "params");
        this.probs = (double[]) XMLParser.extractObjectForTags(content, "probs");
    }

    /** Lists one {@code symbol<TAB>probability} line per symbol; the extra entry is omitted. */
    @Override
    public String toString(NumberFormat numberFormat) {
        StringBuilder out = new StringBuilder();
        for (int k = 0; k < this.probs.length - 1; k++) {
            out.append(String.valueOf(this.alphabets.getSymbol(0, k)))
               .append('\t')
               .append(numberFormat.format(this.probs[k]))
               .append('\n');
        }
        return out.toString();
    }

    /**
     * Switches the model to {@code alphabetContainer} and keeps the existing per-symbol values at
     * their old codes.
     *
     * <p>Each array is resized to the new alphabet length + 1. The old non-extra entries are
     * copied to the front and the old extra entry moves to the new last slot. Any other slot gets
     * probability 1, parameter {@code +Infinity} and hyper-parameters 1. Every symbol in
     * {@code strArr} is then forced to probability 1 and parameter {@code +Infinity}.
     *
     * <p>NOTE(review): assumes the new alphabet is at least as large as the old one. Otherwise
     * the copy throws an index exception before any field is modified.
     *
     * @throws WrongAlphabetException if a symbol in {@code strArr} is not in the new alphabet
     */
    public void addAndSet(AlphabetContainer alphabetContainer, String[] strArr) throws WrongAlphabetException {
        DiscreteAlphabet discreteAlphabet = (DiscreteAlphabet) alphabetContainer.getAlphabetAt(0);
        int size = ((int) discreteAlphabet.length()) + 1;
        double[] newProbs = extendKeepingLast(this.probs, size, 1.0d);
        double[] newParams = extendKeepingLast(this.params, size, Double.POSITIVE_INFINITY);
        int hyperSize = ((int) alphabetContainer.getAlphabetLengthAt(0)) + 1;
        double[] newHyperParams = extendKeepingLast(this.HyperParams, hyperSize, 1.0d);
        double[] newHyperSum = extendKeepingLast(this.HyperSum, hyperSize, 1.0d);
        for (String symbol : strArr) {
            int code = discreteAlphabet.getCode(symbol);
            newProbs[code] = 1.0d;
            newParams[code] = Double.POSITIVE_INFINITY;
        }
        this.params = newParams;
        this.probs = newProbs;
        this.HyperParams = newHyperParams;
        this.HyperSum = newHyperSum;
        this.alphabets = alphabetContainer;
    }

    /**
     * Returns a new array of length {@code size} filled with {@code fill}. It then copies all but
     * the last element of {@code source} to the front and stores the last element of
     * {@code source} in the last slot.
     */
    private static double[] extendKeepingLast(double[] source, int size, double fill) {
        double[] result = new double[size];
        Arrays.fill(result, fill);
        System.arraycopy(source, 0, result, 0, source.length - 1);
        result[size - 1] = source[source.length - 1];
        return result;
    }
}
