package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models;

import de.jstacs.NotTrainedException;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sampling.BurnInTest;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SamplingState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SimpleSamplingState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.SamplingEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.SamplingHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.SamplingTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.TransitionWithSufficientStatistic;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import java.io.IOException;
import java.util.Arrays;
import org.biojavax.bio.seq.io.RichSequenceBuilderFactory;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/models/SamplingHigherOrderHMM.class */
/**
 * Higher-order hidden Markov model whose parameters are learned by Gibbs sampling.
 *
 * <p>{@link #train(DataSet, double[])} runs {@code numberOfStarts} independent sampling chains,
 * first through a burn-in phase and then through a stationary phase. The evaluation methods
 * ({@link #logProb(int, int, Sequence)}, {@link #getLogProbForPath(IntList, int, Sequence)},
 * {@link #getLogStatePosteriorMatrixFor(int, int, Sequence)} and {@link #getViterbiPath}) visit
 * the stored parameter sets of every chain, starting from the burn-in length reported by
 * {@link #burnInTest}, and combine the per-parameter-set results.
 *
 * <p>NOTE(review): {@link #path} is a scratch buffer that sampling and Viterbi calls mutate
 * without synchronization, so instances should not be shared between threads.
 */
public class SamplingHigherOrderHMM extends HigherOrderHMM {
    /** Burn-in test used during training; its burn-in length selects the parameter sets used for evaluation. */
    protected BurnInTest burnInTest;
    /** {@code true} once {@link #train(DataSet, double[])} has finished sampling. */
    protected boolean hasSampled;
    /** Number of independent sampling chains, taken from the training parameters. */
    private int numberOfStarts;
    /** Reusable buffer for sampled / Viterbi state paths. */
    private IntList path;
    private static final String XML_TAG = "SamplingHigherOrderHMM";

    /** Strategies {@link SamplingHigherOrderHMM#getViterbiPath} can use to pick a state path. */
    public enum ViterbiComputation {
        /** Viterbi path of each parameter set, ranked by its Viterbi score. */
        MAX,
        /** Viterbi path of each parameter set, ranked by the log-gamma score of its statistics. */
        MAX_GAMMA,
        /** Sampled path of each parameter set, ranked by its log probability. */
        SAMPLING,
        /** Sampled path of each parameter set, ranked by the log-gamma score of its statistics. */
        SAMPLING_GAMMA,
        /** Candidates of both {@link #MAX} and {@link #SAMPLING}. */
        MAX_AND_SAMPLING,
        /** Candidates of both {@link #MAX_GAMMA} and {@link #SAMPLING_GAMMA}. */
        MAX_AND_SAMPLING_GAMMA;

        /**
         * Returns a fresh array containing all constants. Decompiler-generated alias of
         * {@code values()}, kept for callers that use this name.
         *
         * @return a new array of all {@code ViterbiComputation} constants
         */
        public static ViterbiComputation[] valuesCustom() {
            // values() already returns a newly allocated array on every call.
            return values();
        }
    }

    /**
     * Creates a new, untrained sampling HMM.
     *
     * @param trainingParameterSet the sampling-specific training parameters
     * @param names the state names
     * @param emissionIdx for each state, the index of its emission in {@code emissions}
     * @param forward per-state flags forwarded to the superclass / state constructor
     * @param emissions the (sampling-capable) emissions
     * @param transitionElements the transition elements defining the topology
     * @throws Exception if the superclass rejects the configuration
     */
    public SamplingHigherOrderHMM(SamplingHMMTrainingParameterSet trainingParameterSet, String[] names, int[] emissionIdx, boolean[] forward, SamplingEmission[] emissions, TransitionElement... transitionElements) throws Exception {
        super(trainingParameterSet, names, emissionIdx, forward, emissions, transitionElements);
        this.hasSampled = false;
        this.path = new IntList();
    }

    /**
     * Restores a model from its XML representation.
     *
     * <p>NOTE(review): the fields of this class are presumably restored by
     * {@link #extractFurtherInformation(StringBuffer)} invoked from the super constructor — confirm.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public SamplingHigherOrderHMM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Clones this model (decompiler-renamed {@code clone()}). The path buffer and the burn-in
     * test (if present) are deep-copied; everything else is handled by the superclass.
     *
     * @return the clone
     * @throws CloneNotSupportedException if the superclass clone fails
     */
    @Override
    public SamplingHigherOrderHMM mo112clone() throws CloneNotSupportedException {
        SamplingHigherOrderHMM copy = (SamplingHigherOrderHMM) super.mo112clone();
        copy.path = this.path.m161clone();
        if (this.burnInTest != null) {
            copy.burnInTest = this.burnInTest.m109clone();
        }
        return copy;
    }

    /** Creates one {@link SimpleSamplingState} per state, wired to its emission, name and forward flag. */
    @Override
    protected void createStates() {
        this.states = new SimpleSamplingState[this.emissionIdx.length];
        for (int s = 0; s < this.emissionIdx.length; s++) {
            SamplingEmission stateEmission = (SamplingEmission) this.emission[this.emissionIdx[s]];
            this.states[s] = new SimpleSamplingState(stateEmission, this.name[s], this.forward[s]);
        }
    }

    /**
     * Calls {@code acceptParameters()} on the transition and on every emission.
     *
     * @throws IOException if a component fails to accept its parameters
     */
    public void acceptParameters() throws IOException {
        ((SamplingTransition) this.transition).acceptParameters();
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).acceptParameters();
        }
    }

    /**
     * Draws new parameters for the transition and every emission from their current statistics.
     *
     * @throws Exception if drawing fails
     */
    public void drawFromStatistics() throws Exception {
        ((SamplingTransition) this.transition).drawParametersFromStatistic();
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).drawParametersFromStatistic();
        }
    }

    /**
     * Samples a state path for {@code sequence[startPos..endPos]} into {@link #path} and adds the
     * resulting counts to the statistics.
     *
     * @param startPos first sequence position
     * @param endPos last sequence position
     * @param weight weight of the sequence
     * @param sequence the sequence
     * @return {@code bwdMatrix[0][0]} after sampling (presumably the sequence log-likelihood from
     *     the backward pass — TODO confirm against {@code samplePath})
     * @throws Exception if sampling or updating the statistics fails
     */
    protected double gibbsSampling(int startPos, int endPos, double weight, Sequence sequence) throws Exception {
        samplePath(this.path, startPos, endPos, sequence);
        addToStatistics(startPos, weight, sequence, this.path);
        return this.bwdMatrix[0][0];
    }

    /**
     * Adds the emission and transition counts implied by {@code statePath} to the statistics of
     * the visited states and the transition.
     *
     * @throws IllegalArgumentException if the transition has no edge for some step of the path
     */
    private void addToStatistics(int startPos, double weight, Sequence sequence, IntList statePath) throws Exception {
        int position = startPos;
        int layer = 0;
        this.container[1] = 0;
        for (int k = 0; k < statePath.length(); k++) {
            int state = statePath.get(k);
            int childIdx = this.transition.getChildIdx(layer, this.container[1], state);
            if (childIdx < 0) {
                throw new IllegalArgumentException("Impossible path");
            }
            ((SamplingState) this.states[state]).addToStatistic(position, position, weight, sequence);
            ((SamplingTransition) this.transition).addToStatistic(layer, this.container[1], childIdx, weight, sequence, position);
            this.transition.fillTransitionInformation(layer, this.container[1], childIdx, this.container);
            // NOTE(review): container[2] == 1 presumably marks a state that consumed a symbol.
            if (this.container[2] == 1) {
                position++;
                layer++;
            }
        }
    }

    /** Sums the log-gamma scores of the transition and all states for the current statistics. */
    private double getLogGammaScoreForCurrentStatistics() {
        double score = ((TransitionWithSufficientStatistic) this.transition).getLogGammaScoreFromStatistic();
        for (int s = 0; s < this.states.length; s++) {
            score += ((SamplingState) this.states[s]).getLogGammaScoreForCurrentStatistic();
        }
        return score;
    }

    /**
     * Runs {@code steps} Gibbs-sampling iterations for chain {@code sampling}. Each iteration
     * resets the statistics, samples a path for every sequence, reports
     * {@code logPriorTerm + sum(gibbsSampling(...))} to the burn-in test, then draws and accepts
     * new parameters.
     *
     * @param sampling index of the sampling chain
     * @param steps number of iterations
     * @param append whether sampled parameters are appended to previously stored ones
     * @param data the training sequences
     * @param weights per-sequence weights, or {@code null} for weight 1
     * @throws Exception if sampling fails
     */
    protected void gibbsSamplingStep(int sampling, int steps, boolean append, DataSet data, double[] weights) throws Exception {
        double weight = 1.0d;
        int numberOfSequences = data.getNumberOfElements();
        this.burnInTest.setCurrentSamplingIndex(sampling);
        ((SamplingTransition) this.transition).extendSampling(sampling, append);
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).extendSampling(sampling, append);
        }
        this.sostream.writeln(String.valueOf(sampling) + " ----------------------------------------");
        for (int step = 0; step < steps; step++) {
            resetStatistics();
            double score = getLogPriorTerm();
            for (int n = 0; n < numberOfSequences; n++) {
                Sequence sequence = data.getElementAt(n);
                if (weights != null) {
                    weight = weights[n];
                }
                score += gibbsSampling(0, sequence.getLength() - 1, weight, sequence);
            }
            this.sostream.writeln(String.valueOf(step) + "\t" + score);
            this.burnInTest.setValue(score);
            getNewParameters();
            acceptParameters();
        }
    }

    /**
     * Draws the parameters for the next sampling iteration; by default from the statistics.
     *
     * @throws Exception if drawing fails
     */
    protected void getNewParameters() throws Exception {
        drawFromStatistics();
    }

    /**
     * Trains the model by Gibbs sampling.
     *
     * <p>Burn-in: in rounds of {@code numberOfStepsPerIteration} steps per chain (chains are
     * initialized randomly in the first round), until the number of steps performed exceeds the
     * burn-in length. Then {@code numberOfStepsInStationaryPhase} further steps are appended to
     * every chain and sampling is stopped.
     *
     * @param data the training sequences
     * @param weights per-sequence weights, or {@code null}
     * @throws Exception if training fails
     */
    @Override
    public void train(DataSet data, double[] weights) throws Exception {
        SamplingHMMTrainingParameterSet params = (SamplingHMMTrainingParameterSet) this.trainingParameter;
        int stationarySteps = params.getNumberOfStepsInStationaryPhase();
        int stepsPerIteration = params.getNumberOfStepsPerIteration();
        int stepsDone = 0;
        boolean append = false;
        initTraining(data, weights);
        this.sostream.writeln("GIBBS-SAMPLING - Burn-In ==============================");
        do {
            for (int start = 0; start < this.numberOfStarts; start++) {
                if (stepsDone == 0) {
                    initializeRandomly();
                }
                gibbsSamplingStep(start, stepsPerIteration, append, data, weights);
            }
            append = true;
            stepsDone += stepsPerIteration;
            // Burn-in length is re-queried each round because the test adapts to reported values.
        } while (stepsDone <= this.burnInTest.getLengthOfBurnIn());
        this.sostream.writeln("GIBBS-SAMPLING - Final Sampling =======================");
        for (int start = 0; start < this.numberOfStarts; start++) {
            gibbsSamplingStep(start, stationarySteps, true, data, weights);
        }
        ((SamplingTransition) this.transition).samplingStopped();
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).samplingStopped();
        }
        this.hasSampled = true;
    }

    @Override
    public String getInstanceName() {
        return "Sampling HMM(" + this.transition.getMaximalMarkovOrder() + ")";
    }

    /** @return {@code true} iff sampling has completed */
    @Override
    public boolean isInitialized() {
        return this.hasSampled;
    }

    /**
     * Returns the log of the mean probability over all stored post-burn-in parameter sets.
     *
     * <p>NOTE(review): if no parameter set is available the result is NaN.
     *
     * @throws NotTrainedException if the model has not been sampled
     */
    @Override
    public double logProb(int startPos, int endPos, Sequence sequence) throws Exception {
        if (!this.hasSampled) {
            throw new NotTrainedException();
        }
        double logSum = Double.NEGATIVE_INFINITY;
        int count = 0;
        for (int start = 0; start < this.numberOfStarts; start++) {
            boolean hasParameters = parseParameterSet(start, this.burnInTest.getLengthOfBurnIn());
            while (hasParameters) {
                logSum = Normalisation.getLogSum(logSum, super.logProb(startPos, endPos, sequence));
                hasParameters = parseNextParameterSet();
                count++;
            }
        }
        return logSum - Math.log(count);
    }

    /**
     * Loads parameter set {@code n} of chain {@code sampling} into the transition and all emissions.
     * Every component is asked even if an earlier one fails (non-short-circuit {@code &=}).
     *
     * @return {@code true} iff every component could load the parameter set
     */
    protected boolean parseParameterSet(int sampling, int n) throws Exception {
        boolean ok = ((SamplingTransition) this.transition).parseParameterSet(sampling, n);
        for (int e = 0; e < this.emission.length; e++) {
            ok &= ((SamplingEmission) this.emission[e]).parseParameterSet(sampling, n);
        }
        return ok;
    }

    /**
     * Advances the transition and all emissions to their next stored parameter set.
     * Every component is advanced even if an earlier one fails.
     *
     * @return {@code true} iff every component has a next parameter set
     */
    protected boolean parseNextParameterSet() throws Exception {
        boolean ok = ((SamplingTransition) this.transition).parseNextParameterSet();
        for (int e = 0; e < this.emission.length; e++) {
            ok &= ((SamplingEmission) this.emission[e]).parseNextParameterSet();
        }
        return ok;
    }

    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    /** Appends the burn-in test and the sampled flag to the superclass XML. */
    @Override
    public void appendFurtherInformation(StringBuffer xml) {
        super.appendFurtherInformation(xml);
        XMLParser.appendObjectWithTags(xml, this.burnInTest, "burnInTest");
        XMLParser.appendObjectWithTags(xml, Boolean.valueOf(this.hasSampled), "hasSampled");
    }

    /** Restores the fields written by {@link #appendFurtherInformation(StringBuffer)} and a fresh path buffer. */
    @Override
    public void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        super.extractFurtherInformation(xml);
        this.numberOfStarts = this.trainingParameter.getNumberOfStarts();
        this.burnInTest = (BurnInTest) XMLParser.extractObjectForTags(xml, "burnInTest", BurnInTest.class);
        this.hasSampled = ((Boolean) XMLParser.extractObjectForTags(xml, "hasSampled", Boolean.TYPE)).booleanValue();
        this.path = new IntList();
    }

    /**
     * Prepares sampling: reads the number of chains and the burn-in test, resets the test,
     * initializes transition and emissions for sampling and calls {@link #furtherInits}.
     */
    protected void initTraining(DataSet data, double[] weights) throws Exception {
        this.numberOfStarts = this.trainingParameter.getNumberOfStarts();
        this.burnInTest = ((SamplingHMMTrainingParameterSet) this.trainingParameter).getBurnInTest();
        this.burnInTest.resetAllValues();
        ((SamplingTransition) this.transition).initForSampling(this.numberOfStarts);
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).initForSampling(this.numberOfStarts);
        }
        furtherInits(data, weights);
    }

    /** Hook for subclasses to perform additional initialization before sampling; does nothing here. */
    protected void furtherInits(DataSet data, double[] weights) throws Exception {
    }

    /**
     * Returns the state posteriors averaged (in log space) over all stored post-burn-in parameter
     * sets. The trained state is checked before any matrix is allocated.
     *
     * @throws NotTrainedException if the model has not been sampled
     */
    @Override
    public double[][] getLogStatePosteriorMatrixFor(int startPos, int endPos, Sequence sequence) throws Exception {
        if (!this.hasSampled) {
            throw new NotTrainedException();
        }
        double[][] accumulated = createMatrixForStatePosterior(startPos, endPos);
        double[][] current = createMatrixForStatePosterior(startPos, endPos);
        for (int s = 0; s < this.states.length; s++) {
            Arrays.fill(accumulated[s], Double.NEGATIVE_INFINITY);
        }
        int count = 0;
        for (int start = 0; start < this.numberOfStarts; start++) {
            boolean hasParameters = parseParameterSet(start, this.burnInTest.getLengthOfBurnIn());
            while (hasParameters) {
                fillLogStatePosteriorMatrix(current, startPos, endPos, sequence, true);
                for (int s = 0; s < this.states.length; s++) {
                    for (int p = 0; p < accumulated[s].length; p++) {
                        accumulated[s][p] = Normalisation.getLogSum(accumulated[s][p], current[s][p]);
                    }
                }
                hasParameters = parseNextParameterSet();
                count++;
            }
        }
        double logCount = Math.log(count);
        for (int s = 0; s < this.states.length; s++) {
            for (int p = 0; p < accumulated[s].length; p++) {
                accumulated[s][p] -= logCount;
            }
        }
        return getFinalStatePosterioriMatrix(accumulated);
    }

    /**
     * Returns the log of the mean path probability over all stored post-burn-in parameter sets.
     * Uses a running log-sum, like {@link #logProb(int, int, Sequence)}, instead of buffering
     * all values.
     *
     * @throws NotTrainedException if the model has not been sampled
     */
    @Override
    public double getLogProbForPath(IntList statePath, int startPos, Sequence sequence) throws Exception {
        if (!this.hasSampled) {
            throw new NotTrainedException();
        }
        double logSum = Double.NEGATIVE_INFINITY;
        int count = 0;
        for (int start = 0; start < this.numberOfStarts; start++) {
            boolean hasParameters = parseParameterSet(start, this.burnInTest.getLengthOfBurnIn());
            while (hasParameters) {
                logSum = Normalisation.getLogSum(logSum, super.getLogProbForPath(statePath, startPos, sequence));
                hasParameters = parseNextParameterSet();
                count++;
            }
        }
        return logSum - Math.log(count);
    }

    /** Equivalent to {@code getViterbiPath(startPos, endPos, sequence, ViterbiComputation.MAX)}. */
    @Override
    public Pair<IntList, Double> getViterbiPathFor(int startPos, int endPos, Sequence sequence) throws Exception {
        return getViterbiPath(startPos, endPos, sequence, ViterbiComputation.MAX);
    }

    /**
     * Returns the best-scoring candidate path over all stored post-burn-in parameter sets, where
     * candidates and scores are chosen by {@code mode}. Resets and overwrites the model's
     * statistics as a side effect.
     *
     * @param mode how candidates are produced and scored; {@code null} yields an empty path
     *     with score {@code -Infinity}
     * @return the best path and its score
     * @throws NotTrainedException if the model has not been sampled
     */
    public Pair<IntList, Double> getViterbiPath(int startPos, int endPos, Sequence sequence, ViterbiComputation mode) throws Exception {
        if (!this.hasSampled) {
            throw new NotTrainedException();
        }
        boolean useSampledPaths = mode == ViterbiComputation.SAMPLING
                || mode == ViterbiComputation.SAMPLING_GAMMA
                || mode == ViterbiComputation.MAX_AND_SAMPLING
                || mode == ViterbiComputation.MAX_AND_SAMPLING_GAMMA;
        boolean useMaxPaths = mode == ViterbiComputation.MAX
                || mode == ViterbiComputation.MAX_GAMMA
                || mode == ViterbiComputation.MAX_AND_SAMPLING
                || mode == ViterbiComputation.MAX_AND_SAMPLING_GAMMA;
        // Modes without the _GAMMA suffix score by likelihood instead of the log-gamma score.
        boolean scoreByLikelihood = mode == ViterbiComputation.MAX
                || mode == ViterbiComputation.SAMPLING
                || mode == ViterbiComputation.MAX_AND_SAMPLING;

        IntList bestPath = new IntList();
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int start = 0; start < this.numberOfStarts; start++) {
            boolean hasParameters = parseParameterSet(start, this.burnInTest.getLengthOfBurnIn());
            while (hasParameters) {
                if (useSampledPaths) {
                    resetStatistics();
                    this.path.clear();
                    gibbsSampling(startPos, endPos, 1.0d, sequence);
                    double score = scoreByLikelihood
                            ? super.getLogProbForPath(this.path, startPos, sequence)
                            : getLogGammaScoreForCurrentStatistics();
                    if (score > bestScore) {
                        bestScore = score;
                        copyPath(this.path, bestPath);
                    }
                }
                if (useMaxPaths) {
                    resetStatistics();
                    this.path.clear();
                    double viterbiScore = viterbi(this.path, startPos, endPos, 0.0d, sequence);
                    addToStatistics(startPos, 1.0d, sequence, this.path);
                    double score = scoreByLikelihood ? viterbiScore : getLogGammaScoreForCurrentStatistics();
                    if (score > bestScore) {
                        bestScore = score;
                        copyPath(this.path, bestPath);
                    }
                }
                hasParameters = parseNextParameterSet();
            }
        }
        return new Pair<>(bestPath, Double.valueOf(bestScore));
    }

    /** Replaces the contents of {@code target} with the elements of {@code source}. */
    private static void copyPath(IntList source, IntList target) {
        target.clear();
        for (int k = 0; k < source.length(); k++) {
            target.add(source.get(k));
        }
    }

    /** Sums the log posteriors of the transition and all states for the current statistics. */
    public double getLogPosteriorFromStatistic() {
        double logPosterior = ((SamplingTransition) this.transition).getLogPosteriorFromStatistic();
        for (int s = 0; s < this.states.length; s++) {
            logPosterior += ((SimpleSamplingState) this.states[s]).getLogPosteriorFromStatistic();
        }
        return logPosterior;
    }
}
