package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models;

import de.jstacs.algorithms.optimization.termination.AbstractTerminationCondition;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.WrongLengthException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.results.StorableResult;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SimpleState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.TrainableState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.Emission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.BaumWelchParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.HMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.MaxHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.ViterbiParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.HigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.TrainableTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.Transition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.sequenceScores.statisticalModels.trainable.mixture.AbstractMixtureTrainSM;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import de.jstacs.utils.Time;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import javax.naming.OperationNotSupportedException;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/models/HigherOrderHMM.class */
public class HigherOrderHMM extends AbstractHMM {
    /**
     * Three-element scratch array filled by {@code transition.fillTransitionInformation(...)}.
     * Within this class index 0 is used as a state index, index 1 as a context index of the
     * target layer, and index 2 as the layer increment that is added to the current layer.
     */
    protected int[] container;
    /** Per-state buffer (length {@code states.length}) holding the log scores computed for the current position. */
    protected double[] logEmission;
    /** Two alternating buffers (selected by layer parity) collecting the summands of each forward cell before they are log-summed. */
    private double[][][] forwardIntermediate;
    /** Buffer (length {@code states.length + 1}) collecting the summands of a single backward/Viterbi cell. */
    protected double[] backwardIntermediate;
    /** Number of summands currently stored per context in {@link #forwardIntermediate}; entry {@code [0][0]} is also reused as a counter by the backward pass. */
    protected int[][] numberOfSummands;
    /** Helper list created in {@link #createHelperVariables()}; not used by the code visible in this part of the file. */
    protected IntList stateList;
    /** If {@code true}, {@code train} does not call {@code initialize} before each start; stored in the XML representation. */
    protected boolean skipInit;
    /** Tag returned by {@link #getXMLTag()}. */
    private static final String XML_TAG = "HigherOrderHMM";

    /* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/models/HigherOrderHMM$Compute.class */
    private static class Compute {
        /** The worker threads; worker 0 operates on the HMM being trained, the others on clones of it. */
        WorkerThread[] workers;
        /** Scratch array (one slot per worker) filled with the workers' transitions before their statistics are joined. */
        TrainableTransition[] transition;
        /** Scratch array (one slot per worker) filled with the workers' emissions of one index before their statistics are joined. */
        Emission[] emission;

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/models/HigherOrderHMM$Compute$WorkerState.class */
        /** Life-cycle states of a worker thread. */
        public enum WorkerState {
            /** The worker should process its slice of the data. */
            TRAIN,
            /** The worker is idle and blocks until its state changes. */
            WAIT,
            /** The worker should leave its run loop. */
            STOP;

            /**
             * Decompiler-generated companion of {@link #values()}.
             *
             * @return a fresh array containing all constants in declaration order
             */
            public static WorkerState[] valuesCustom() {
                return values().clone();
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/models/HigherOrderHMM$Compute$WorkerThread.class */
        public class WorkerThread extends Thread {
            private boolean exception;
            private WorkerState state;
            private int idx;
            private int start;
            private int end;
            private double score;
            private DataSet data;
            private double[] weights;
            private HigherOrderHMM hmm;

            public WorkerThread(int i, HigherOrderHMM higherOrderHMM) {
                this.idx = i;
                this.hmm = higherOrderHMM;
                setDaemon(true);
                this.state = WorkerState.WAIT;
                start();
            }

            /* JADX INFO: Access modifiers changed from: private */
            public void set(int i, int i2, DataSet dataSet, double[] dArr) {
                this.start = i;
                this.end = i2;
                this.score = 0.0d;
                this.data = dataSet;
                this.weights = dArr;
                this.state = WorkerState.WAIT;
            }

            public double getScore() {
                return this.score;
            }

            public synchronized void setState(WorkerState workerState) {
                this.state = workerState;
                notify();
            }

            /* JADX WARN: Multi-variable type inference failed */
            /* JADX WARN: Type inference failed for: r0v10, types: [java.lang.Throwable] */
            /* JADX WARN: Type inference failed for: r0v14 */
            /* JADX WARN: Type inference failed for: r0v9, types: [de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM$Compute] */
            @Override // java.lang.Thread, java.lang.Runnable
            public synchronized void run() {
                this.exception = false;
                while (this.state != WorkerState.STOP) {
                    if (this.state == WorkerState.WAIT) {
                        try {
                            wait();
                        } catch (InterruptedException e) {
                        }
                    } else {
                        try {
                            this.score = this.hmm.doOneStep(this.data, this.weights, this.start, this.end);
                        } catch (Exception e2) {
                            this.exception = true;
                            e2.printStackTrace();
                        }
                        ?? r0 = Compute.this;
                        synchronized (r0) {
                            this.state = WorkerState.WAIT;
                            Compute.this.notify();
                            r0 = r0;
                        }
                    }
                }
            }

            public boolean isWaiting() {
                return this.state == WorkerState.WAIT;
            }
        }

        public Compute(int i, HigherOrderHMM higherOrderHMM) throws CloneNotSupportedException {
            this.workers = new WorkerThread[i];
            this.workers[0] = new WorkerThread(0, higherOrderHMM);
            for (int i2 = 1; i2 < i; i2++) {
                this.workers[i2] = new WorkerThread(i2, higherOrderHMM.mo116clone());
            }
            this.transition = new TrainableTransition[i];
            this.emission = new Emission[i];
        }

        /**
         * Blocks until every worker is in the waiting state. If at least one worker has
         * recorded an exception, all workers are interrupted and stopped and a
         * {@link RuntimeException} is thrown.
         */
        private synchronized void waitUntilWorkersFinished() {
            int failedIndex = -1;
            boolean failed = false;
            boolean allIdle = false;
            while (!allIdle) {
                int checked = 0;
                for (; checked < this.workers.length; checked++) {
                    WorkerThread worker = this.workers[checked];
                    if (!worker.isWaiting()) {
                        break;
                    }
                    if (worker.exception) {
                        failedIndex = checked;
                        failed = true;
                    }
                }
                allIdle = checked == this.workers.length;
                if (!allIdle) {
                    try {
                        wait();
                    } catch (InterruptedException e) {
                    }
                }
            }
            if (!failed) {
                return;
            }
            for (WorkerThread worker : this.workers) {
                worker.interrupt();
            }
            stopThreads();
            throw new RuntimeException("Terminate program, since at least thread " + failedIndex + " throws an exception.");
        }

        /* JADX INFO: Access modifiers changed from: private */
        /** Requests every worker to stop and drops the reference to it. */
        public void stopThreads() {
            int pos = 0;
            while (pos < this.workers.length) {
                WorkerThread worker = this.workers[pos];
                worker.setState(WorkerState.STOP);
                this.workers[pos] = null;
                pos++;
            }
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Splits the data set into consecutive, nearly equally sized slices and assigns one
         * slice to each worker; the last worker receives everything up to the end.
         *
         * @param dataSet the data
         * @param dArr the weights of the data, may be {@code null}
         */
        public void setDataSet(DataSet dataSet, double[] dArr) {
            final int total = dataSet.getNumberOfElements();
            int from = 0;
            for (int w = 0; w < this.workers.length - 1; w++) {
                int to = ((w + 1) * total) / this.workers.length;
                this.workers[w].set(from, to, dataSet, dArr);
                from = to;
            }
            WorkerThread last = this.workers[this.workers.length - 1];
            last.set(from, total, dataSet, dArr);
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Resets the statistics of all worker HMMs, lets every worker process its slice and
         * waits for completion.
         *
         * @return the sum of the scores reported by the workers
         */
        public double oneIteration() {
            for (WorkerThread worker : this.workers) {
                worker.hmm.resetStatistics();
                worker.setState(WorkerState.TRAIN);
            }
            waitUntilWorkersFinished();
            double total = 0.0d;
            for (WorkerThread worker : this.workers) {
                total += worker.getScore();
            }
            return total;
        }

        /* JADX INFO: Access modifiers changed from: private */
        /**
         * Joins the statistics of all workers (if there is more than one), re-estimates the
         * parameters of the first worker's HMM and copies them to the other workers.
         */
        public void estimateFromStatistics() {
            final int workerCount = this.workers.length;
            if (workerCount > 1) {
                for (int w = 0; w < workerCount; w++) {
                    this.transition[w] = (TrainableTransition) this.workers[w].hmm.transition;
                }
                TrainableTransition leadTransition = (TrainableTransition) this.workers[0].hmm.transition;
                leadTransition.joinStatistics(this.transition);
                for (int e = 0; e < this.workers[0].hmm.emission.length; e++) {
                    for (int w = 0; w < workerCount; w++) {
                        this.emission[w] = this.workers[w].hmm.emission[e];
                    }
                    Emission leadEmission = this.workers[0].hmm.emission[e];
                    leadEmission.joinStatistics(this.emission);
                }
            }
            this.workers[0].hmm.estimateFromStatistics();
            setParameters();
        }

        /* JADX INFO: Access modifiers changed from: private */
        /** Copies the transition and emission parameters of the first worker's HMM to all other workers. */
        public void setParameters() {
            for (int w = 1; w < this.workers.length; w++) {
                HigherOrderHMM target = this.workers[w].hmm;
                target.transition.setParameters(this.workers[0].hmm.transition);
                int e = 0;
                while (e < this.workers[0].hmm.emission.length) {
                    target.emission[e].setParameters(this.workers[0].hmm.emission[e]);
                    e++;
                }
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/trainable/hmm/models/HigherOrderHMM$Type.class */
    /** Selects what {@code fillBwdOrViterbiMatrix} computes. */
    public enum Type {
        /** Log-sum over all continuations. */
        LIKELIHOOD,
        /** Maximum over all continuations. */
        VITERBI,
        /** Log-sum over all continuations, additionally accumulating training statistics. */
        BAUM_WELCH;

        /**
         * Decompiler-generated companion of {@link #values()}.
         *
         * @return a fresh array containing all constants in declaration order
         */
        public static Type[] valuesCustom() {
            return values().clone();
        }
    }

    /**
     * Convenience constructor that delegates to the full constructor, passing
     * {@code null} for both the {@code int[]} and the {@code boolean[]} argument.
     *
     * @param hMMTrainingParameterSet the training parameters
     * @param strArr the state names
     * @param emissionArr the emissions
     * @param abstractTransitionElementArr the transition elements
     * @throws Exception if the delegate constructor fails
     */
    public HigherOrderHMM(HMMTrainingParameterSet hMMTrainingParameterSet, String[] strArr, Emission[] emissionArr, BasicHigherOrderTransition.AbstractTransitionElement... abstractTransitionElementArr) throws Exception {
        this(hMMTrainingParameterSet, strArr, (int[]) null, (boolean[]) null, emissionArr, abstractTransitionElementArr);
    }

    /**
     * Main constructor: initialises the superclass, then creates the states, the
     * transition and the final-state information, in that order.
     *
     * @param hMMTrainingParameterSet the training parameters
     * @param strArr the state names
     * @param iArr passed through to the superclass constructor
     * @param zArr passed through to the superclass constructor
     * @param emissionArr the emissions
     * @param abstractTransitionElementArr the transition elements
     * @throws Exception if the superclass constructor or one of the initialisation steps fails
     */
    public HigherOrderHMM(HMMTrainingParameterSet hMMTrainingParameterSet, String[] strArr, int[] iArr, boolean[] zArr, Emission[] emissionArr, BasicHigherOrderTransition.AbstractTransitionElement... abstractTransitionElementArr) throws Exception {
        super(hMMTrainingParameterSet, strArr, iArr, zArr, emissionArr);
        this.createStates();
        this.initTransition(abstractTransitionElementArr);
        this.determineFinalStates();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Allocates the scratch buffers used by the forward, backward and Viterbi computations.
     * Does nothing if {@code container} has already been allocated.
     */
    @Override
    public void createHelperVariables() {
        if (this.container != null) {
            return;
        }
        this.container = new int[3];
        this.logEmission = new double[this.states.length];

        // the widest layer determines the size of the forward buffers
        final int order = this.transition.getMaximalMarkovOrder();
        int widest = 0;
        for (int layer = 0; layer <= order; layer++) {
            int indexes = this.transition.getNumberOfIndexes(layer);
            widest = Math.max(widest, indexes);
        }

        this.forwardIntermediate = new double[2][widest][this.transition.getMaximalInDegree() + 1];
        this.backwardIntermediate = new double[this.states.length + 1];
        this.numberOfSummands = new int[2][widest];
        this.stateList = new IntList();
    }

    /**
     * Restores an instance from its XML representation and allocates the helper buffers.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public HigherOrderHMM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
        this.createHelperVariables();
    }

    /** @return the XML tag of this class, {@code "HigherOrderHMM"} */
    @Override
    protected String getXMLTag() {
        return HigherOrderHMM.XML_TAG;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Appends the {@code skipInit} flag to the XML representation.
     *
     * @param stringBuffer the buffer to append to
     */
    @Override
    public void appendFurtherInformation(StringBuffer stringBuffer) {
        final Boolean flag = Boolean.valueOf(this.skipInit);
        XMLParser.appendObjectWithTags(stringBuffer, flag, "skipInit");
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Reads the {@code skipInit} flag from the XML representation.
     *
     * @param stringBuffer the buffer to read from
     * @throws NonParsableException if the flag cannot be parsed
     */
    @Override
    public void extractFurtherInformation(StringBuffer stringBuffer) throws NonParsableException {
        final Boolean flag = (Boolean) XMLParser.extractObjectForTags(stringBuffer, "skipInit", boolean.class);
        this.skipInit = flag.booleanValue();
    }

    /**
     * Clones this HMM (method renamed from {@code clone} by the decompiler). The copy gets
     * its own, freshly allocated helper buffers.
     *
     * @return the clone
     * @throws CloneNotSupportedException if the superclass cannot be cloned
     */
    @Override
    public HigherOrderHMM mo116clone() throws CloneNotSupportedException {
        final HigherOrderHMM copy = (HigherOrderHMM) super.mo116clone();
        // clearing the container forces createHelperVariables() to allocate new buffers
        copy.container = null;
        copy.createHelperVariables();
        return copy;
    }

    /** Creates one {@link SimpleState} per entry of {@code emissionIdx}, wiring it to the referenced emission. */
    @Override
    protected void createStates() {
        this.states = new SimpleState[this.emissionIdx.length];
        int s = 0;
        while (s < this.emissionIdx.length) {
            Emission stateEmission = this.emission[this.emissionIdx[s]];
            this.states[s] = new SimpleState(stateEmission, this.name[s], this.forward[s]);
            s++;
        }
    }

    /** @return the sum of the log prior terms of the transition and of all emissions */
    @Override
    public double getLogPriorTerm() {
        double sum = this.transition.getLogPriorTerm();
        for (Emission e : this.emission) {
            sum += e.getLogPriorTerm();
        }
        return sum;
    }

    /**
     * Computes the log score of a sequence for a fixed state path.
     *
     * @param intList the state path; its last state must be a final state
     * @param i the start position in the sequence
     * @param sequence the sequence
     * @return the sum of all transition and state log scores along the path
     * @throws IllegalArgumentException if the path does not end in a final state or
     *         contains a step for which the transition has no child
     * @throws Exception if a score cannot be computed
     */
    @Override
    public double getLogProbForPath(IntList intList, int i, Sequence sequence) throws Exception {
        final int lastState = intList.get(intList.length() - 1);
        if (!this.finalState[lastState]) {
            throw new IllegalArgumentException("The last state of the path is no final state. Hence the path is not valid.");
        }
        double logProb = 0.0d;
        int layer = 0;
        int position = i;
        this.container[1] = 0;
        int step = 0;
        while (step < intList.length()) {
            final int state = intList.get(step);
            final int context = this.container[1];
            final int child = this.transition.getChildIdx(layer, context, state);
            if (child < 0) {
                throw new IllegalArgumentException("Impossible path");
            }
            final double transitionScore = this.transition.getLogScoreFor(layer, context, child, sequence, position);
            final double stateScore = this.states[state].getLogScoreFor(position, position, sequence);
            logProb += transitionScore + stateScore;
            this.transition.fillTransitionInformation(layer, context, child, this.container);
            // only a layer increment of 1 advances in the sequence
            if (this.container[2] == 1) {
                position++;
                layer++;
            }
            step++;
        }
        return logProb;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Fills {@code dArr[state][layer]} with log state posteriors for the sub-sequence
     * {@code [i, i2]} of {@code sequence}.
     *
     * <p>For a transition of maximal Markov order 0 the positions are treated independently:
     * each column {@code i4 + 1} is the per-position score normalised by the log-sum over all
     * children. Otherwise the forward and backward matrices are combined and normalised by
     * {@code bwdMatrix[0][0]}.
     *
     * @param dArr the matrix to be filled, indexed {@code [state][layer]}
     * @param i the start position (inclusive)
     * @param i2 the end position (inclusive)
     * @param sequence the sequence
     * @param z if {@code true}, contexts whose last state is silent are left out
     * @throws Exception if a score cannot be computed
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM
    public void fillLogStatePosteriorMatrix(double[][] dArr, int i, int i2, Sequence sequence, boolean z) throws Exception {
        // number of sequence positions
        int i3 = (i2 - i) + 1;
        if (this.transition.getMaximalMarkovOrder() == 0) {
            for (int i4 = 0; i4 < i3; i4++) {
                for (int i5 = 0; i5 < this.states.length; i5++) {
                    // container[0] is the state reached via child i5
                    this.transition.fillTransitionInformation(0, 0, i5, this.container);
                    double logScoreFor = this.transition.getLogScoreFor(0, 0, i5, sequence, i + i4) + this.states[this.container[0]].getLogScoreFor(i + i4, i + i4, sequence);
                    dArr[this.container[0]][i4 + 1] = logScoreFor;
                    this.logEmission[i5] = logScoreFor;
                }
                // normalise the column so that it holds log posteriors
                double logSum = Normalisation.getLogSum(this.logEmission);
                for (int i6 = 0; i6 < this.states.length; i6++) {
                    double[] dArr2 = dArr[i6];
                    int i7 = i4 + 1;
                    dArr2[i7] = dArr2[i7] - logSum;
                }
            }
            return;
        }
        fillFwdMatrix(i, i2, sequence);
        fillBwdMatrix(i, i2, sequence);
        // normalisation constant
        double d = this.bwdMatrix[0][0];
        for (int i8 = 0; i8 <= i3; i8++) {
            for (int i9 = 0; i9 < this.states.length; i9++) {
                dArr[i9][i8] = Double.NEGATIVE_INFINITY;
            }
            // accumulate forward*backward of every context into the cell of its last state
            for (int i10 = 0; i10 < this.fwdMatrix[i8].length; i10++) {
                int lastContextState = this.transition.getLastContextState(i8, i10);
                if (lastContextState >= 0 && (!z || !this.states[lastContextState].isSilent())) {
                    dArr[lastContextState][i8] = Normalisation.getLogSum(dArr[lastContextState][i8], this.fwdMatrix[i8][i10] + this.bwdMatrix[i8][i10]);
                }
            }
            for (int i11 = 0; i11 < this.states.length; i11++) {
                double[] dArr3 = dArr[i11];
                int i12 = i8;
                dArr3[i12] = dArr3[i12] - d;
            }
            // NOTE(review): i is incremented here but not read again in this method
            i++;
        }
    }

    /**
     * Fills {@code fwdMatrix} for the sub-sequence {@code [i, i2]} of {@code sequence}.
     *
     * <p>The summands of each cell are collected in {@code forwardIntermediate} (two buffers
     * alternating by layer parity) together with their count in {@code numberOfSummands};
     * a cell is finalised by a log-sum over its collected summands. Cells without any
     * summand are set to {@code Double.NEGATIVE_INFINITY}.
     *
     * @param i the start position (inclusive)
     * @param i2 the end position (inclusive)
     * @param sequence the sequence
     * @throws OperationNotSupportedException declared by the overridden method
     * @throws WrongLengthException declared by the overridden method
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM
    protected void fillFwdMatrix(int i, int i2, Sequence sequence) throws OperationNotSupportedException, WrongLengthException {
        // current layer
        int i3 = 0;
        provideMatrix(0, (i2 - i) + 1);
        // start: a single summand with log score 0 for context 0 of layer 0
        Arrays.fill(this.numberOfSummands[0], 0);
        this.numberOfSummands[0][0] = 1;
        this.forwardIntermediate[0][0][0] = 0.0d;
        while (i <= i2) {
            // state log scores for the current position
            for (int i4 = 0; i4 < this.states.length; i4++) {
                this.logEmission[i4] = this.states[i4].getLogScoreFor(i, i, sequence);
            }
            // buffer of the current layer; the other buffer is cleared for the next layer
            int i5 = i3 % 2;
            Arrays.fill(this.numberOfSummands[1 - i5], 0);
            for (int i6 = 0; i6 < this.fwdMatrix[i3].length; i6++) {
                int numberOfChildren = this.transition.getNumberOfChildren(i3, i6);
                if (this.numberOfSummands[i5][i6] > 0) {
                    this.fwdMatrix[i3][i6] = Normalisation.getLogSum(0, this.numberOfSummands[i5][i6], this.forwardIntermediate[i5][i6]);
                    for (int i7 = 0; i7 < numberOfChildren; i7++) {
                        this.transition.fillTransitionInformation(i3, i6, i7, this.container);
                        double logScoreFor = this.transition.getLogScoreFor(i3, i6, i7, sequence, i);
                        // target buffer: same layer if container[2] == 0, next layer if container[2] == 1
                        int i8 = (i5 + this.container[2]) % 2;
                        this.forwardIntermediate[i8][this.container[1]][this.numberOfSummands[i8][this.container[1]]] = this.fwdMatrix[i3][i6] + this.logEmission[this.container[0]] + logScoreFor;
                        int[] iArr = this.numberOfSummands[i8];
                        int i9 = this.container[1];
                        iArr[i9] = iArr[i9] + 1;
                    }
                } else {
                    this.fwdMatrix[i3][i6] = Double.NEGATIVE_INFINITY;
                }
            }
            i3++;
            i++;
        }
        // last layer: only transitions into silent states are propagated, without a state score
        int i10 = i3 % 2;
        for (int i11 = 0; i11 < this.fwdMatrix[i3].length; i11++) {
            int numberOfChildren2 = this.transition.getNumberOfChildren(i3, i11);
            if (this.numberOfSummands[i10][i11] > 0) {
                this.fwdMatrix[i3][i11] = Normalisation.getLogSum(0, this.numberOfSummands[i10][i11], this.forwardIntermediate[i10][i11]);
                for (int i12 = 0; i12 < numberOfChildren2; i12++) {
                    this.transition.fillTransitionInformation(i3, i11, i12, this.container);
                    if (this.states[this.container[0]].isSilent()) {
                        double logScoreFor2 = this.transition.getLogScoreFor(i3, i11, i12, sequence, i);
                        double[] dArr = this.forwardIntermediate[i10][this.container[1]];
                        int[] iArr2 = this.numberOfSummands[i10];
                        int i13 = this.container[1];
                        int i14 = iArr2[i13];
                        iArr2[i13] = i14 + 1;
                        dArr[i14] = this.fwdMatrix[i3][i11] + logScoreFor2;
                    }
                }
            } else {
                this.fwdMatrix[i3][i11] = Double.NEGATIVE_INFINITY;
            }
        }
    }

    /**
     * Fills the backward matrix for the sub-sequence {@code [i, i2]} by running the shared
     * backward/Viterbi routine in likelihood mode with weight 1.
     *
     * @param i the start position (inclusive)
     * @param i2 the end position (inclusive)
     * @param sequence the sequence
     * @throws Exception if a score cannot be computed
     */
    @Override
    protected void fillBwdMatrix(int i, int i2, Sequence sequence) throws Exception {
        final double unitWeight = 1.0d;
        this.fillBwdOrViterbiMatrix(Type.LIKELIHOOD, i, i2, unitWeight, sequence);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Fills {@code bwdMatrix} from the last layer down to layer 0 for the sub-sequence
     * {@code [i, i2]} of {@code sequence}.
     *
     * <p>Depending on {@code type} each cell is the maximum ({@link Type#VITERBI}) or the
     * log-sum (otherwise) over its continuations. For {@link Type#BAUM_WELCH} the weighted
     * posterior of every used transition and state is additionally added to the statistics
     * of the transition and the states; this requires {@code fwdMatrix} to be filled already.
     *
     * @param type the kind of computation
     * @param i the start position (inclusive)
     * @param i2 the end position (inclusive)
     * @param d the weight of the sequence; only used for {@link Type#BAUM_WELCH}
     * @param sequence the sequence
     * @throws Exception if a score cannot be computed
     */
    public void fillBwdOrViterbiMatrix(Type type, int i, int i2, double d, Sequence sequence) throws Exception {
        // index of the last layer
        int i3 = (i2 - i) + 1;
        boolean z = this.transition.getMaximalMarkovOrder() == 0;
        provideMatrix(1, (i2 - i) + 1);
        // normalisation constant for the posteriors; only needed when collecting statistics
        double computeLogScoreFromForward = type != Type.BAUM_WELCH ? Double.NaN : computeLogScoreFromForward(i3);
        // last layer: a context may end here (if its last state is final, or always for order 0)
        // or continue via transitions into silent states
        for (int length = this.bwdMatrix[i3].length - 1; length >= 0; length--) {
            int numberOfChildren = this.transition.getNumberOfChildren(i3, length);
            // numberOfSummands[0][0] serves as the counter of collected summands
            this.numberOfSummands[0][0] = 0;
            double d2 = (z || this.finalState[this.transition.getLastContextState(i3, length)]) ? 0.0d : Double.NEGATIVE_INFINITY;
            for (int i4 = 0; i4 < numberOfChildren; i4++) {
                this.transition.fillTransitionInformation(i3, length, i4, this.container);
                if (this.states[this.container[0]].isSilent()) {
                    this.backwardIntermediate[this.numberOfSummands[0][0]] = this.bwdMatrix[i3][this.container[1]] + this.transition.getLogScoreFor(i3, length, i4, sequence, i2);
                    if (type == Type.BAUM_WELCH) {
                        ((TrainableTransition) this.transition).addToStatistic(i3, length, i4, d * Math.exp((this.fwdMatrix[i3][length] + this.backwardIntermediate[this.numberOfSummands[0][0]]) - computeLogScoreFromForward), sequence, i2);
                    }
                    int[] iArr = this.numberOfSummands[0];
                    iArr[0] = iArr[0] + 1;
                }
            }
            if (this.numberOfSummands[0][0] == 0) {
                this.bwdMatrix[i3][length] = d2;
            } else {
                this.bwdMatrix[i3][length] = type == Type.VITERBI ? Math.max(d2, ToolBox.max(0, this.numberOfSummands[0][0], this.backwardIntermediate)) : Normalisation.getLogSum(d2, Normalisation.getLogSum(0, this.numberOfSummands[0][0], this.backwardIntermediate));
            }
        }
        // remaining layers, from the second to last down to layer 0; i2 is the matching position
        while (true) {
            i3--;
            if (i3 < 0) {
                return;
            }
            for (int i5 = 0; i5 < this.states.length; i5++) {
                this.logEmission[i5] = this.states[i5].getLogScoreFor(i2, i2, sequence);
            }
            for (int length2 = this.bwdMatrix[i3].length - 1; length2 >= 0; length2--) {
                int numberOfChildren2 = this.transition.getNumberOfChildren(i3, length2);
                for (int i6 = 0; i6 < numberOfChildren2; i6++) {
                    this.transition.fillTransitionInformation(i3, length2, i6, this.container);
                    // continuation via child i6: backward value of the target cell + state score + transition score
                    this.backwardIntermediate[i6] = this.bwdMatrix[i3 + this.container[2]][this.container[1]] + this.logEmission[this.container[0]] + this.transition.getLogScoreFor(i3, length2, i6, sequence, i2);
                    if (type == Type.BAUM_WELCH) {
                        double exp = d * Math.exp((this.fwdMatrix[i3][length2] + this.backwardIntermediate[i6]) - computeLogScoreFromForward);
                        ((TrainableState) this.states[this.container[0]]).addToStatistic(i2, i2, exp, sequence);
                        ((TrainableTransition) this.transition).addToStatistic(i3, length2, i6, exp, sequence, i2);
                    }
                }
                if (numberOfChildren2 > 0) {
                    this.bwdMatrix[i3][length2] = type == Type.VITERBI ? ToolBox.max(0, numberOfChildren2, this.backwardIntermediate) : Normalisation.getLogSum(0, numberOfChildren2, this.backwardIntermediate);
                } else {
                    this.bwdMatrix[i3][length2] = Double.NEGATIVE_INFINITY;
                }
            }
            i2--;
        }
    }

    /**
     * Determines the Viterbi path for the sub-sequence {@code [i, i2]}.
     *
     * @param i the start position (inclusive)
     * @param i2 the end position (inclusive)
     * @param sequence the sequence
     * @return the state path together with its log score
     * @throws Exception if a score cannot be computed
     */
    @Override
    public Pair<IntList, Double> getViterbiPathFor(int i, int i2, Sequence sequence) throws Exception {
        final IntList path = new IntList((i2 - i) + 1);
        final double score = viterbi(path, i, i2, 0.0d, sequence);
        return new Pair<>(path, Double.valueOf(score));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Computes the Viterbi score of the sub-sequence {@code [i, i2]} and traces the best path.
     *
     * <p>The Viterbi matrix is filled backwards first; the path is then traced from layer 0
     * by choosing, in every step, the child whose continuation score deviates least from the
     * value of the current cell. If {@code intList} is not {@code null}, the states of the
     * path are stored in it; otherwise the weight {@code d} is added to the statistics of
     * every transition and state on the path.
     *
     * @param intList receives the path, or {@code null} to collect training statistics instead
     * @param i the start position (inclusive)
     * @param i2 the end position (inclusive)
     * @param d the weight added to the statistics when {@code intList} is {@code null}
     * @param sequence the sequence
     * @return the Viterbi score, i.e. {@code bwdMatrix[0][0]}
     * @throws Exception if a score cannot be computed
     */
    public double viterbi(IntList intList, int i, int i2, double d, Sequence sequence) throws Exception {
        fillBwdOrViterbiMatrix(Type.VITERBI, i, i2, 0.0d, sequence);
        // number of sequence positions
        int i3 = (i2 - i) + 1;
        // current layer and current context
        int i4 = 0;
        int i5 = 0;
        if (intList != null) {
            intList.clear();
        }
        // trace the path through all layers that still have a sequence position
        while (i4 < i3) {
            int numberOfChildren = this.transition.getNumberOfChildren(i4, i5);
            // smallest squared deviation found so far and the matching
            // layer increment (i6), context (i7), state (i8) and child (i9)
            double d2 = Double.POSITIVE_INFINITY;
            int i6 = -1000;
            int i7 = -1000;
            int i8 = -1000;
            int i9 = -1000;
            for (int i10 = 0; i10 < numberOfChildren; i10++) {
                this.transition.fillTransitionInformation(i4, i5, i10, this.container);
                // difference between the continuation via this child and the value of the current cell
                double logScoreFor = ((this.bwdMatrix[i4 + this.container[2]][this.container[1]] + this.states[this.container[0]].getLogScoreFor(i, i, sequence)) + this.transition.getLogScoreFor(i4, i5, i10, sequence, i)) - this.bwdMatrix[i4][i5];
                double d3 = logScoreFor * logScoreFor;
                if (d3 < d2) {
                    i9 = i10;
                    i8 = this.container[0];
                    i7 = this.container[1];
                    i6 = this.container[2];
                    d2 = d3;
                }
            }
            if (intList == null) {
                ((TrainableTransition) this.transition).addToStatistic(i4, i5, i9, d, sequence, i);
                ((TrainableState) this.states[i8]).addToStatistic(i, i, d, sequence);
            } else {
                intList.add(i8);
            }
            // advance position and layer by the layer increment of the chosen child
            i += i6;
            i4 += i6;
            i5 = i7;
        }
        // last layer: follow transitions that stay in this layer for as long as one of them
        // matches the current cell better than ending the path here
        while (true) {
            int numberOfChildren2 = this.transition.getNumberOfChildren(i4, i5);
            double d4 = this.finalState[this.transition.getLastContextState(i4, i5)] ? 0.0d - this.bwdMatrix[i4][i5] : Double.NEGATIVE_INFINITY;
            // squared deviation of ending here; infinite if the last state is not final
            double d5 = d4 * d4;
            int i11 = -1000;
            int i12 = -1000;
            int i13 = -1000;
            for (int i14 = 0; i14 < numberOfChildren2; i14++) {
                this.transition.fillTransitionInformation(i4, i5, i14, this.container);
                if (this.container[2] == 0) {
                    double logScoreFor2 = ((this.bwdMatrix[i4][this.container[1]] + this.states[this.container[0]].getLogScoreFor(i, i, sequence)) + this.transition.getLogScoreFor(i4, i5, i14, sequence, i)) - this.bwdMatrix[i4][i5];
                    double d6 = logScoreFor2 * logScoreFor2;
                    if (d6 < d5) {
                        i13 = i14;
                        i12 = this.container[0];
                        i11 = this.container[1];
                        d5 = d6;
                    }
                }
            }
            // no better continuation found: the path ends here
            if (i12 < 0) {
                return this.bwdMatrix[0][0];
            }
            if (intList == null) {
                ((TrainableTransition) this.transition).addToStatistic(i4, i5, i13, d, sequence, i);
                ((TrainableState) this.states[i12]).addToStatistic(i, i, d, sequence);
            } else {
                intList.add(i12);
            }
            i5 = i11;
        }
    }

    /**
     * Runs the forward pass and the statistics-collecting backward pass for the
     * sub-sequence {@code [i, i2]}.
     *
     * @param i the start position (inclusive)
     * @param i2 the end position (inclusive)
     * @param d the weight of the sequence
     * @param sequence the sequence
     * @return the value of {@code bwdMatrix[0][0]} after the backward pass
     * @throws Exception if a score cannot be computed
     */
    protected double baumWelch(int i, int i2, double d, Sequence sequence) throws Exception {
        this.fillFwdMatrix(i, i2, sequence);
        this.fillBwdOrViterbiMatrix(Type.BAUM_WELCH, i, i2, d, sequence);
        final double logScore = this.bwdMatrix[0][0];
        return logScore;
    }

    /**
     * Trains the HMM on the given weighted data using multiple starts and multiple threads.
     *
     * <p>For every start the parameters are initialised (unless {@code skipInit} is set) and
     * then iteratively re-estimated until the termination condition stops the iteration.
     * With more than one start, the parameters of the best start are kept.
     *
     * <p>The worker threads are stopped in a {@code finally} block so that they do not
     * linger when initialisation, scoring or cloning throws.
     *
     * @param dataSet the training data
     * @param dArr the weights of the data, may be {@code null}
     * @throws IllegalArgumentException if the training parameters are no
     *         {@link MaxHMMTrainingParameterSet}
     * @throws Exception if initialisation or training fails
     */
    public void train(DataSet dataSet, double[] dArr) throws Exception {
        if (!(this.trainingParameter instanceof MaxHMMTrainingParameterSet)) {
            throw new IllegalArgumentException("This kind of training is currently not supported.");
        }
        Transition bestTransition = null;
        Emission[] bestEmission = null;
        double bestScore = Double.NEGATIVE_INFINITY;
        int numberOfStarts = this.trainingParameter.getNumberOfStarts();
        AbstractTerminationCondition terminationCondition = ((MaxHMMTrainingParameterSet) this.trainingParameter).getTerminationCondition();
        Compute compute = new Compute(this.threads, this);
        try {
            compute.setDataSet(dataSet, dArr);
            Time timeInstance = Time.getTimeInstance(this.sostream);
            for (int start = 0; start < numberOfStarts; start++) {
                this.sostream.writeln("start " + start + " ============================");
                if (!this.skipInit) {
                    initialize(dataSet, dArr);
                    compute.setParameters();
                }
                double currentScore = Double.NEGATIVE_INFINITY;
                int iteration = 0;
                timeInstance.reset();
                while (true) {
                    double previousScore = currentScore;
                    currentScore = getLogPriorTerm() + compute.oneIteration();
                    int finishedIteration = iteration;
                    iteration++;
                    this.sostream.writeln(String.valueOf(finishedIteration) + "\t" + timeInstance.getElapsedTime() + "\t" + currentScore + "\t" + (currentScore - previousScore));
                    if (!terminationCondition.doNextIteration(iteration, previousScore, currentScore, null, null, Double.NaN, timeInstance)) {
                        break;
                    } else {
                        compute.estimateFromStatistics();
                    }
                }
                if (currentScore > bestScore) {
                    bestScore = currentScore;
                    // a copy is only needed if a later start could overwrite the parameters
                    if (numberOfStarts > 1) {
                        bestEmission = (Emission[]) ArrayHandler.clone(this.emission);
                        bestTransition = this.transition.m164clone();
                    }
                }
            }
            this.sostream.writeln("best result: " + bestScore);
            if (bestEmission != null) {
                this.emission = bestEmission;
                this.transition = bestTransition;
                createStates();
            }
        } finally {
            // If a worker failed, waitUntilWorkersFinished() has already stopped the workers
            // and cleared the array; stopping again would dereference null entries.
            if (compute.workers[0] != null) {
                compute.stopThreads();
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Runs one training pass over the sequences with indices {@code i}
     * (inclusive) to {@code i2} (exclusive) and sums the values returned by the
     * per-sequence training routine.
     *
     * <p>The routine is chosen from the runtime type of the training parameters:
     * {@link ViterbiParameterSet} selects {@code viterbi}, {@link BaumWelchParameterSet}
     * selects {@code baumWelch}.</p>
     *
     * @param dataSet the data set the sequences are read from
     * @param dArr per-sequence weights, or {@code null}; with {@code null} every
     *        sequence is passed the weight 1
     * @param i index of the first sequence to process
     * @param i2 index one past the last sequence to process
     * @return the sum of the per-sequence return values; 0 if the range is empty
     * @throws IllegalArgumentException if a sequence is processed while the
     *         training parameters are neither Viterbi nor Baum-Welch parameters
     * @throws Exception if one of the invoked routines fails
     */
    public double doOneStep(DataSet dataSet, double[] dArr, int i, int i2) throws Exception {
        double total = 0.0d;
        double weight = 1.0d;
        for (int index = i; index < i2; index++) {
            final Sequence current = dataSet.getElementAt(index);
            if (dArr != null) {
                weight = dArr[index];
            }
            // The mode is re-checked per sequence so that an empty range never throws.
            final double contribution;
            if (this.trainingParameter instanceof ViterbiParameterSet) {
                contribution = viterbi(null, 0, current.getLength() - 1, weight, current);
            } else if (this.trainingParameter instanceof BaumWelchParameterSet) {
                contribution = baumWelch(0, current.getLength() - 1, weight, current);
            } else {
                throw new IllegalArgumentException("Training mode not available.");
            }
            total += contribution;
        }
        return total;
    }

    /**
     * Initializes the model before a training start.
     *
     * <p>This implementation ignores both arguments and simply delegates to
     * {@link #initializeRandomly()}.</p>
     *
     * @param dataSet the training data (unused here)
     * @param dArr the weights of the training data (unused here)
     * @throws Exception declared for overriding implementations; nothing in this
     *         body throws a checked exception
     */
    protected void initialize(DataSet dataSet, double[] dArr) throws Exception {
        initializeRandomly();
    }

    /**
     * Sets the {@code skipInit} flag.
     *
     * <p>The training loop in this class calls {@code initialize(...)} at the
     * beginning of a start only while this flag is {@code false}.</p>
     *
     * @param z the new value of the flag
     */
    public void setSkiptInit(boolean z) {
        this.skipInit = z;
    }

    /**
     * Randomly initializes the transition first and afterwards every emission,
     * in array order.
     */
    public void initializeRandomly() {
        this.transition.initializeRandomly();
        for (Emission current : this.emission) {
            current.initializeFunctionRandomly();
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Resets the sufficient statistics of the transition and of all emissions.
     *
     * @throws ClassCastException if the transition is not a {@link TrainableTransition}
     */
    public void resetStatistics() {
        final TrainableTransition trainable = (TrainableTransition) this.transition;
        trainable.resetStatistic();
        for (Emission current : this.emission) {
            current.resetStatistic();
        }
    }

    /**
     * Re-estimates the parameters of the transition and of all emissions from
     * their current statistics.
     *
     * @throws ClassCastException if the transition is not a {@link TrainableTransition}
     */
    protected void estimateFromStatistics() {
        final TrainableTransition trainable = (TrainableTransition) this.transition;
        trainable.estimateFromStatistic();
        for (Emission current : this.emission) {
            current.estimateFromStatistic();
        }
    }

    /**
     * Returns the maximal Markov order reported for this model.
     *
     * <p>This implementation always returns {@link Byte#MAX_VALUE} and never
     * throws, although the signature declares
     * {@link UnsupportedOperationException}.</p>
     *
     * @return {@code Byte.MAX_VALUE}
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.statisticalModels.StatisticalModel
    public final byte getMaximalMarkovOrder() throws UnsupportedOperationException {
        return Byte.MAX_VALUE;
    }

    /**
     * Returns the characteristics of this model: the numerical characteristics
     * (if any) followed by a {@link StorableResult} named {@code "model"} that
     * wraps this instance.
     *
     * <p>{@link #getNumericalCharacteristics()} may return {@code null}; in that
     * case only the {@code "model"} result is reported instead of failing with a
     * {@link NullPointerException}.</p>
     *
     * @return the result set describing this model
     * @throws Exception if the numerical characteristics or the result set
     *         cannot be created
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.SequenceScore
    public ResultSet getCharacteristics() throws Exception {
        // Query the numerical part first, as the original did, so a failure there
        // still happens before the model is wrapped.
        final NumericalResultSet numerical = getNumericalCharacteristics();
        final Result[] numericalResults = numerical == null ? null : numerical.getResults();
        final Result[] modelResult = new Result[]{new StorableResult("model", "the xml representation of the model", this)};
        if (numericalResults == null) {
            return new ResultSet(new Result[][]{modelResult});
        }
        return new ResultSet(new Result[][]{numericalResults, modelResult});
    }

    /**
     * Builds a short name of the form {@code "HMM(<order>) <ParameterSetClass>"},
     * where the order is taken from the transition and the class name is the
     * simple name of the training parameter set's class.
     *
     * @return the instance name
     */
    public String getInstanceName() {
        final StringBuilder label = new StringBuilder("HMM(");
        label.append(this.transition.getMaximalMarkovOrder());
        label.append(") ");
        label.append(this.trainingParameter.getClass().getSimpleName());
        return label.toString();
    }

    /**
     * Computes the log score of every sequence of the data set.
     *
     * @param dataSet the data set to score
     * @return a new array with one log score per element of {@code dataSet}
     * @throws Exception if the scoring fails, e.g. because alphabet or length do
     *         not match the model
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.SequenceScore
    public double[] getLogScoreFor(DataSet dataSet) throws Exception {
        final double[] scores = new double[dataSet.getNumberOfElements()];
        this.getLogScoreFor(dataSet, scores);
        return scores;
    }

    /**
     * Computes the log score of every sequence of the data set and stores it at
     * the corresponding index of {@code dArr}.
     *
     * @param dataSet the data set to score
     * @param dArr receives the log scores; it is indexed by element position
     * @throws WrongAlphabetException if the alphabet container of the data set
     *         is not consistent with the one of the model
     * @throws WrongLengthException if the model length is non-zero and differs
     *         from the element length of the data set
     * @throws Exception if scoring a sequence fails
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.AbstractTrainableStatisticalModel, de.jstacs.sequenceScores.SequenceScore
    public void getLogScoreFor(DataSet dataSet, double[] dArr) throws Exception {
        final boolean alphabetsMatch = dataSet.getAlphabetContainer().checkConsistency(getAlphabetContainer());
        if (!alphabetsMatch) {
            throw new WrongAlphabetException("The AlphabetContainer of the data set and the model do not match.");
        }
        final int modelLength = getLength();
        final int dataLength = dataSet.getElementLength();
        // A model length of 0 disables the length check.
        final boolean lengthMismatch = modelLength != 0 && dataLength != modelLength;
        if (lengthMismatch) {
            throw new WrongLengthException("The length of the data set and the model do not match.");
        }
        int index = 0;
        while (index < dataSet.getNumberOfElements()) {
            final Sequence current = dataSet.getElementAt(index);
            dArr[index] = logProb(0, current.getLength() - 1, current);
            index++;
        }
    }

    /**
     * Returns the numerical characteristics of this model.
     *
     * <p>This implementation provides none and always returns {@code null}, so
     * callers must be prepared for a {@code null} result.</p>
     *
     * @return always {@code null}
     * @throws Exception never thrown by this implementation
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return null;
    }

    /**
     * Reports whether the model is initialized.
     *
     * @return always {@code true} in this implementation
     */
    public boolean isInitialized() {
        return true;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Drops the references held by this instance and then delegates to the
     * finalizer of the super class.
     *
     * @throws Throwable if the super class finalizer fails
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM
    public void finalize() throws Throwable {
        try {
            // Model structure.
            this.emission = null;
            this.emissionIdx = null;
            this.forward = null;
            this.name = null;
            // Scratch buffers.
            this.container = null;
            this.numberOfSummands = null;
            this.logEmission = null;
            this.forwardIntermediate = null;
            this.backwardIntermediate = null;
        } finally {
            super.finalize();
        }
    }

    /**
     * Samples a state path for the positions {@code i} to {@code i2} (both
     * inclusive) of {@code sequence} and writes the indices of the visited
     * states to {@code intList}.
     *
     * <p>The backward matrix is filled first. Afterwards the path is drawn one
     * transition at a time: for every child of the current context the sum of the
     * backward value of the target cell, the log score of the target state at the
     * current position and the log score of the transition is collected in
     * {@code backwardIntermediate}, the values are passed to
     * {@code Normalisation.logSumNormalisation} and one child is chosen with
     * {@code AbstractMixtureTrainSM.draw} (presumably a draw from the normalised
     * distribution - confirm against those helpers).</p>
     *
     * <p>Once the last position has been consumed, a second loop keeps drawing
     * among children that lead to silent states until the extra "stop"
     * alternative, which is only offered when the most recent state is marked in
     * {@code finalState}, is drawn.</p>
     *
     * <p>The method overwrites the shared buffers {@code container},
     * {@code backwardIntermediate} and {@code stateList}.</p>
     *
     * @param intList receives the sampled state indices; it is cleared first
     * @param i the first position of the sequence to be used; this parameter is
     *        advanced inside the method while the path is drawn
     * @param i2 the last position of the sequence to be used (inclusive)
     * @param sequence the sequence a path is sampled for
     * @throws Exception if one of the invoked scoring or matrix routines fails
     */
    public void samplePath(IntList intList, int i, int i2, Sequence sequence) throws Exception {
        int i3;
        fillBwdMatrix(i, i2, sequence);
        // i4: row offset into bwdMatrix relative to the start position; i5: current context index.
        int i4 = 0;
        int i5 = 0;
        provideMatrix(0, (i2 - i) + 1);
        intList.clear();
        // Phase 1: draw transitions until the position has moved past i2.
        while (i <= i2) {
            int numberOfChildren = this.transition.getNumberOfChildren(i4, i5);
            for (int i6 = 0; i6 < numberOfChildren; i6++) {
                // After this call container[0] indexes the target state, container[1] is the
                // target context and container[2] is the offset added to the row/position.
                this.transition.fillTransitionInformation(i4, i5, i6, this.container);
                this.backwardIntermediate[i6] = this.bwdMatrix[i4 + this.container[2]][this.container[1]] + this.states[this.container[0]].getLogScoreFor(i, i, sequence) + this.transition.getLogScoreFor(i4, i5, i6, sequence, i);
            }
            Normalisation.logSumNormalisation(this.backwardIntermediate, 0, numberOfChildren);
            // Re-fill container for the drawn child so that its target can be recorded.
            this.transition.fillTransitionInformation(i4, i5, AbstractMixtureTrainSM.draw(this.backwardIntermediate, 0), this.container);
            intList.add(this.container[0]);
            i5 = this.container[1];
            i4 += this.container[2];
            i += this.container[2];
        }
        // NOTE(review): if the loop above never ran (i > i2 on entry) the list is empty and the
        // index passed to get(...) is -1 - confirm how IntList handles that.
        int i7 = intList.get(intList.length() - 1);
        // Phase 2: optionally append silent states after the last position.
        while (true) {
            // i8 is the state that was appended most recently.
            int i8 = i7;
            int numberOfChildren2 = this.transition.getNumberOfChildren(i4, i5);
            this.stateList.clear();
            for (int i9 = 0; i9 < numberOfChildren2; i9++) {
                this.transition.fillTransitionInformation(i4, i5, i9, this.container);
                // Only children whose target state is silent are candidates here; stateList
                // maps the candidate slot back to the child index.
                if (this.states[this.container[0]].isSilent()) {
                    this.backwardIntermediate[this.stateList.length()] = this.bwdMatrix[i4][this.container[1]] + this.transition.getLogScoreFor(i4, i5, i9, sequence, i);
                    this.stateList.add(i9);
                }
            }
            // i3 is 1 if an additional "stop here" alternative (log score 0) is offered.
            if (this.finalState[i8]) {
                this.backwardIntermediate[this.stateList.length()] = 0.0d;
                i3 = 1;
            } else {
                i3 = 0;
            }
            // NOTE(review): with no silent child and a non-final state the range is empty;
            // the outcome then depends on logSumNormalisation/draw - confirm.
            Normalisation.logSumNormalisation(this.backwardIntermediate, 0, this.stateList.length() + i3);
            int draw = AbstractMixtureTrainSM.draw(this.backwardIntermediate, 0);
            // The stop alternative occupies the slot right after the silent candidates.
            if (i3 == 1 && draw == this.stateList.length()) {
                return;
            }
            this.transition.fillTransitionInformation(i4, i5, this.stateList.get(draw), this.container);
            intList.add(this.container[0]);
            i5 = this.container[1];
            i4 += this.container[2];
            i += this.container[2];
            i7 = this.container[0];
        }
    }

    /**
     * Combines the entries of row {@code i} of the forward matrix into one log
     * score.
     *
     * <p>If the maximal Markov order of the transition is not positive, all
     * entries of the row are log-summed. Otherwise only those entries are
     * log-summed whose last context state is marked in {@code finalState}.</p>
     *
     * @param i the row of the forward matrix to be used
     * @return the log-sum of the selected entries;
     *         {@link Double#NEGATIVE_INFINITY} if the order is positive and no
     *         entry is selected
     */
    private double computeLogScoreFromForward(int i) {
        if (this.transition.getMaximalMarkovOrder() <= 0) {
            return Normalisation.getLogSum(this.fwdMatrix[i]);
        }
        double logScore = Double.NEGATIVE_INFINITY;
        for (int context = 0; context < this.fwdMatrix[i].length; context++) {
            final int lastState = this.transition.getLastContextState(i, context);
            if (this.finalState[lastState]) {
                logScore = Normalisation.getLogSum(logScore, this.fwdMatrix[i][context]);
            }
        }
        return logScore;
    }

    /**
     * Returns a copy of the emissions of this model, created with
     * {@code ArrayHandler.clone}.
     *
     * @return the cloned emission array
     * @throws CloneNotSupportedException if the emissions cannot be cloned
     */
    public Emission[] getEmissions() throws CloneNotSupportedException {
        final Emission[] copy = (Emission[]) ArrayHandler.clone(this.emission);
        return copy;
    }

    /**
     * Returns the transition elements of the transition of this model.
     *
     * @return the transition elements as provided by the transition
     * @throws CloneNotSupportedException if the transition cannot provide them
     * @throws ClassCastException if the transition is not a
     *         {@link HigherOrderTransition}
     */
    public TransitionElement[] getTransisionElements() throws CloneNotSupportedException {
        final HigherOrderTransition higherOrder = (HigherOrderTransition) this.transition;
        return higherOrder.getTransisionElements();
    }

    /**
     * Returns a copy of the {@code emissionIdx} array, so that changes to the
     * returned array do not affect the model.
     *
     * @return a clone of the emission index array
     */
    public int[] getEmissionIndexes() {
        final int[] copy = this.emissionIdx.clone();
        return copy;
    }

    /**
     * Returns a copy of the {@code name} array, so that changes to the returned
     * array do not affect the model.
     *
     * @return a clone of the name array
     */
    public String[] getNames() {
        final String[] copy = this.name.clone();
        return copy;
    }

    /**
     * Returns a clone of the training parameters of this model.
     *
     * @return the cloned training parameter set
     * @throws CloneNotSupportedException if the parameter set cannot be cloned
     */
    public HMMTrainingParameterSet getTrainingParams() throws CloneNotSupportedException {
        final HMMTrainingParameterSet copy = (HMMTrainingParameterSet) this.trainingParameter.m106clone();
        return copy;
    }
}
