/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models;

import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LogGenDisMixFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongLengthException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.SequenceAnnotation;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.StorableResult;
import de.jstacs.sequenceScores.statisticalModels.differentiable.SamplingDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.DifferentiableStatisticalModelWrapperTrainSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.AbstractHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.FastDifferentiableHigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.DifferentiableState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SimpleDifferentiableState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableSMWrapperEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.UniformEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.filter.Filter;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.BaumWelchParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.HMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.MaxHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.NumericalHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.ViterbiParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.DifferentiableTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.Locale;
import java.util.Random;
import javax.naming.OperationNotSupportedException;

public class DifferentiableHigherOrderHMM
extends HigherOrderHMM
implements SamplingDifferentiableStatisticalModel {
    /** Total number of parameters over all emissions and the transition; -1 if an offset could not be assigned (see getOffsets()). */
    protected int numberOfParameters;
    /** Equivalent sample size; checked to be non-negative in the full constructor and stored in XML under the tag "ess". */
    protected double ess;
    /** Per-child bookkeeping of transition targets (rows 0..2 receive container[0..2]); allocated as int[4][maxIndexes] in createHelperVariables(). */
    protected int[][] index;
    /** Gradient accumulators allocated as double[2][maxIndexes][numberOfParameters]; gradient[0][0] is read as the final gradient in logScoreAndPartialDerivation. */
    protected double[][][] gradient;
    /** Second buffer with the same shape as gradient; cleared in fill(...). NOTE(review): its exact role is not visible in this chunk. */
    protected double[][][] gradient2;
    /** Two-element scratch array reallocated in createHelperVariables(); not used in the visible part of the file. */
    protected double[] logScore;
    /** Two-element scratch array reallocated in createHelperVariables(); not used in the visible part of the file. */
    protected double[] prop;
    /** One list of parameter indices per state; cleared before each gradient computation. */
    protected IntList[] indicesState;
    /** One list of transition parameter indices per child, sized by transition.getMaximalNumberOfChildren(). */
    protected IntList[] indicesTransition;
    /** One list of partial derivatives per state, parallel to indicesState. */
    protected DoubleList[] partDerState;
    /** One list of transition partial derivatives per child, parallel to indicesTransition. */
    protected DoubleList[] partDerTransition;
    /** Training type derived from the current training parameter set in setTrainingParameter(...). */
    private NumericalHMMTrainingParameterSet.TrainingType training;
    /** Scoring mode passed to logProb / logScoreAndPartialDerivation; initialised to LIKELIHOOD. */
    private HigherOrderHMM.Type score = HigherOrderHMM.Type.LIKELIHOOD;
    /** True while train(...) runs the numerical optimisation (presumably toggled via setTrain). */
    private boolean train;
    /** Scratch array with one entry per state. */
    private double[] forwardIntermediate;
    /** Helper list copied in clone(); its use lies outside the visible chunk. */
    private IntList childrenBW;
    /** Helper list copied in clone(); its use lies outside the visible chunk. */
    private IntList childrenFW;

    /**
     * Creates a differentiable higher-order HMM without a state-group annotation type, without state groups,
     * without filters and without explicit transition indices.
     *
     * @param trainingParameterSet the training parameters
     * @param name the state names
     * @param emissionIdx the index of the emission used by each state
     * @param forward the strand flag of each state
     * @param emission the emissions
     * @param ess the equivalent sample size, must be non-negative
     * @param te the transition elements
     * @throws Exception if the model cannot be created
     */
    public DifferentiableHigherOrderHMM(MaxHMMTrainingParameterSet trainingParameterSet, String[] name, int[] emissionIdx, boolean[] forward, DifferentiableEmission[] emission, double ess, TransitionElement ... te) throws Exception {
        this((String) null, (int[][]) null, trainingParameterSet, name, (Filter[]) null, emissionIdx, forward, emission, ess, (int[]) null, te);
    }

    /**
     * Creates a differentiable higher-order HMM.
     *
     * @param type the sequence-annotation type holding allowed state groups, may be {@code null}
     * @param statesGroups the state groups, may be {@code null}
     * @param trainingParameterSet the training parameters
     * @param name the state names
     * @param filter the filters, may be {@code null}
     * @param emissionIdx the index of the emission used by each state
     * @param forward the strand flag of each state
     * @param emission the emissions
     * @param ess the equivalent sample size; must be a non-negative number
     * @param transIndex explicit transition indices, may be {@code null}
     * @param te the transition elements
     * @throws IllegalArgumentException if {@code ess} is negative or NaN
     * @throws Exception if the model cannot be created
     */
    public DifferentiableHigherOrderHMM(String type, int[][] statesGroups, MaxHMMTrainingParameterSet trainingParameterSet, String[] name, Filter[] filter, int[] emissionIdx, boolean[] forward, DifferentiableEmission[] emission, double ess, int[] transIndex, TransitionElement ... te) throws Exception {
        super(type, statesGroups, trainingParameterSet, name, filter, emissionIdx, forward, emission, transIndex, (BasicHigherOrderTransition.AbstractTransitionElement[])te);
        this.getOffsets();
        // The negated comparison also rejects NaN, which "ess < 0.0" silently accepted.
        if (!(ess >= 0.0)) {
            throw new IllegalArgumentException("The equivalent sample size (ess) must be non-negative, but was " + ess);
        }
        this.ess = ess;
        this.childrenBW = new IntList();
        this.childrenFW = new IntList();
        this.forwardIntermediate = new double[this.getNumberOfStates()];
    }

    /**
     * Stores the training parameters and derives the internal training type from them.
     * <p>
     * A numerical parameter set supplies its own training type. A Viterbi set maps to VITERBI and a Baum-Welch
     * set maps to LIKELIHOOD. For any other set the training type is left unchanged. The training flag is
     * always reset.
     *
     * @param trainingParameterSet the new training parameters
     * @throws CloneNotSupportedException if the superclass cannot clone the parameters
     */
    @Override
    protected void setTrainingParameter(HMMTrainingParameterSet trainingParameterSet) throws CloneNotSupportedException {
        super.setTrainingParameter(trainingParameterSet);
        if (trainingParameter instanceof NumericalHMMTrainingParameterSet) {
            NumericalHMMTrainingParameterSet numerical = (NumericalHMMTrainingParameterSet) trainingParameter;
            training = numerical.getTrainingType();
        } else if (trainingParameter instanceof ViterbiParameterSet) {
            training = NumericalHMMTrainingParameterSet.TrainingType.VITERBI;
        } else if (trainingParameter instanceof BaumWelchParameterSet) {
            training = NumericalHMMTrainingParameterSet.TrainingType.LIKELIHOOD;
        }
        setTrain(false);
    }

    /**
     * Restores an instance from its XML representation.
     * <p>
     * The {@code ess} value is read by extractFurtherInformation, which is presumably invoked by the super
     * constructor. The parameter offsets and helper buffers are recomputed afterwards.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public DifferentiableHigherOrderHMM(StringBuffer xml) throws NonParsableException {
        super(xml);
        getOffsets();
        childrenFW = new IntList();
        childrenBW = new IntList();
        forwardIntermediate = new double[getNumberOfStates()];
    }

    /**
     * Appends the superclass information and then the equivalent sample size (tag {@code "ess"}).
     *
     * @param xml the buffer to append to
     */
    @Override
    protected void appendFurtherInformation(StringBuffer xml) {
        super.appendFurtherInformation(xml);
        double equivalentSampleSize = ess;
        XMLParser.appendObjectWithTags(xml, equivalentSampleSize, "ess");
    }

    /**
     * Reads the superclass information and then the equivalent sample size (tag {@code "ess"}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    @Override
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        super.extractFurtherInformation(xml);
        double parsed = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        ess = parsed;
    }

    /**
     * Allocates the buffers needed for gradient computation, then delegates to the superclass.
     * <p>
     * Buffers are only touched while {@code container} is still {@code null}. The gradient buffers and
     * {@code index} are reallocated when they are missing or their sizes no longer match two quantities: the
     * largest number of transition indexes over all Markov orders, and the current number of parameters. The
     * per-state and per-child lists are created once.
     */
    @Override
    protected void createHelperVariables() {
        if (this.container == null) {
            int maxOrder = this.transition.getMaximalMarkovOrder();
            int maxIndexes = 0;
            for (int order = 0; order <= maxOrder; order++) {
                maxIndexes = Math.max(maxIndexes, this.transition.getNumberOfIndexes(order));
            }
            boolean stale = this.gradient == null
                    || this.gradient[0].length != maxIndexes
                    || this.gradient[0][0].length != this.numberOfParameters;
            if (stale) {
                this.gradient = new double[2][maxIndexes][this.numberOfParameters];
                this.gradient2 = new double[2][maxIndexes][this.numberOfParameters];
                this.index = new int[4][maxIndexes];
            }
            if (this.indicesState == null) {
                int maxChildren = this.transition.getMaximalNumberOfChildren();
                int numStates = this.states.length;
                try {
                    this.indicesState = (IntList[]) ArrayHandler.createArrayOf((Cloneable) new IntList(), numStates);
                    this.partDerState = (DoubleList[]) ArrayHandler.createArrayOf((Cloneable) new DoubleList(), numStates);
                    this.indicesTransition = (IntList[]) ArrayHandler.createArrayOf((Cloneable) new IntList(), maxChildren);
                    this.partDerTransition = (DoubleList[]) ArrayHandler.createArrayOf((Cloneable) new DoubleList(), maxChildren);
                } catch (CloneNotSupportedException cloneFailure) {
                    throw DifferentiableHigherOrderHMM.getRunTimeException(cloneFailure);
                }
            }
            this.logScore = new double[2];
            this.prop = new double[2];
        }
        super.createHelperVariables();
    }

    /**
     * Builds one {@link SimpleDifferentiableState} per entry of {@code emissionIdx}. Each state uses the emission
     * selected by that index, together with the corresponding name and strand flag.
     */
    @Override
    protected void createStates() {
        int count = this.emissionIdx.length;
        this.states = new SimpleDifferentiableState[count];
        for (int s = 0; s < count; s++) {
            DifferentiableEmission em = (DifferentiableEmission) this.emission[this.emissionIdx[s]];
            this.states[s] = new SimpleDifferentiableState(em, this.name[s], this.forward[s]);
        }
    }

    /**
     * Returns a deep-enough copy of this model.
     * <p>
     * {@code gradient} and {@code indicesState} are hidden (set to {@code null}) while the superclass clones.
     * The copy therefore does not share them, and createHelperVariables() (presumably run during
     * {@code super.clone()}) allocates fresh buffers for it. They are restored in a {@code finally} block, so
     * this instance stays intact even if cloning fails.
     *
     * @return the clone
     * @throws CloneNotSupportedException if some part cannot be cloned
     */
    @Override
    public DifferentiableHigherOrderHMM clone() throws CloneNotSupportedException {
        double[][][] grad = this.gradient;
        IntList[] ind = this.indicesState;
        this.gradient = null;
        this.indicesState = null;
        try {
            DifferentiableHigherOrderHMM clone = (DifferentiableHigherOrderHMM) super.clone();
            clone.forwardIntermediate = this.forwardIntermediate.clone();
            clone.childrenBW = this.childrenBW.clone();
            clone.childrenFW = this.childrenFW.clone();
            return clone;
        } finally {
            // Restore on every path; previously a failing super.clone() left these fields null.
            this.gradient = grad;
            this.indicesState = ind;
        }
    }

    /**
     * @return the equivalent sample size of this model
     */
    @Override
    public double getESS() {
        return ess;
    }

    /**
     * Adds the gradient of the log-prior term of every emission, and then of the transition, to {@code grad}.
     *
     * @param grad the gradient vector to add to
     * @param start the offset of this model's parameters within {@code grad}
     * @throws Exception if an emission or the transition fails
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        for (int i = 0; i < this.emission.length; i++) {
            DifferentiableEmission em = (DifferentiableEmission) this.emission[i];
            em.addGradientOfLogPriorTerm(grad, start);
        }
        DifferentiableTransition diffTransition = (DifferentiableTransition) this.transition;
        diffTransition.addGradientForLogPriorTerm(grad, start);
    }

    /**
     * Assigns consecutive parameter offsets, first to all emissions and then to the transition.
     * The resulting total is stored in {@code numberOfParameters}.
     * <p>
     * As soon as one component reports -1, the procedure stops and {@code numberOfParameters} stays -1.
     * Otherwise the helper buffers are (re)created for the new parameter count.
     */
    private void getOffsets() {
        this.numberOfParameters = 0;
        for (int e = 0; e < this.emission.length && this.numberOfParameters != -1; e++) {
            this.numberOfParameters = ((DifferentiableEmission) this.emission[e]).setParameterOffset(this.numberOfParameters);
        }
        if (this.numberOfParameters != -1) {
            this.numberOfParameters = ((DifferentiableTransition) this.transition).setParameterOffset(this.numberOfParameters);
        }
        if (this.numberOfParameters != -1) {
            this.createHelperVariables();
        }
    }

    /**
     * @return the number of parameters determined by getOffsets(), or -1 if it could not be determined
     */
    @Override
    public int getNumberOfParameters() {
        return numberOfParameters;
    }

    /**
     * @return the number of starts configured in the current training parameter set
     */
    @Override
    public int getNumberOfRecommendedStarts() {
        int starts = trainingParameter.getNumberOfStarts();
        return starts;
    }

    /**
     * Collects the current parameters of all emissions and of the transition into a new array.
     *
     * @return the current parameter vector of length getNumberOfParameters()
     * @throws IllegalArgumentException if the number of parameters is undefined (-1)
     * @throws Exception if an emission or the transition cannot provide its parameters
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        int n = this.getNumberOfParameters();
        if (n == -1) {
            // Kept as IllegalArgumentException for compatibility, but now with a diagnostic message.
            throw new IllegalArgumentException("The number of parameters is undefined (-1); cannot provide current parameter values.");
        }
        double[] params = new double[n];
        for (int e = 0; e < this.emission.length; e++) {
            ((DifferentiableEmission) this.emission[e]).fillCurrentParameter(params);
        }
        ((DifferentiableTransition) this.transition).fillParameters(params);
        return params;
    }

    /**
     * @return always {@code true}; this model is considered initialized at all times
     */
    @Override
    public boolean isInitialized() {
        final boolean alwaysInitialized = true;
        return alwaysInitialized;
    }

    /**
     * Passes the parameter vector to every emission and then to the transition.
     *
     * @param params the full parameter vector
     * @param start the offset of this model's parameters within {@code params}
     */
    @Override
    public void setParameters(double[] params, int start) {
        for (int i = 0; i < this.emission.length; i++) {
            DifferentiableEmission em = (DifferentiableEmission) this.emission[i];
            em.setParameter(params, start);
        }
        DifferentiableTransition diffTransition = (DifferentiableTransition) this.transition;
        diffTransition.setParameters(params, start);
    }

    /**
     * Initializes the model randomly and recomputes the parameter offsets, unless {@code skipInit} is set.
     *
     * @param freeParams not used by this implementation
     * @throws Exception if the random initialization fails
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        if (!this.skipInit) {
            initializeRandomly();
            getOffsets();
        }
    }

    /**
     * Randomly initializes only the transition; the emissions are left untouched.
     *
     * @throws Exception if the transition cannot be initialized
     */
    public void initializeTransitionRandomly() throws Exception {
        transition.initializeRandomly();
    }

    /**
     * Initializes the parameters, either from data or randomly. Nothing happens if {@code skipInit} is set.
     * <p>
     * With a {@link NumericalHMMTrainingParameterSet}, the model is first trained on {@code data[index]}. The
     * trainer is EM-style: a {@link ViterbiParameterSet} for Viterbi-like training types, otherwise a
     * {@link BaumWelchParameterSet}. Emissions of type {@link DifferentiableSMWrapperEmission} are replaced by
     * {@link UniformEmission}s in a temporary model for that run and are initialized uniformly afterwards. If the
     * data-driven initialization fails, the problem is reported on {@code sostream} and a random initialization
     * is used instead. For any other parameter set the initialization is random.
     * <p>
     * The original training parameter set and training type are restored on every path, including failures.
     *
     * @param index the index of the data set in {@code data} to use
     * @param freeParams forwarded to {@link #initializeFunctionRandomly(boolean)}
     * @param data the data sets
     * @param weights the weights per data set, may be {@code null}
     * @throws Exception if the (fallback) random initialization fails
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (this.skipInit) {
            return;
        }
        if (this.trainingParameter instanceof NumericalHMMTrainingParameterSet) {
            boolean[] diffEM = new boolean[this.emission.length];
            boolean sub = false;
            for (int i = 0; i < this.emission.length; i++) {
                diffEM[i] = this.emission[i] instanceof DifferentiableSMWrapperEmission;
                sub |= diffEM[i];
            }
            NumericalHMMTrainingParameterSet trainingParameterSet = (NumericalHMMTrainingParameterSet) this.trainingParameter;
            NumericalHMMTrainingParameterSet.TrainingType tType = trainingParameterSet.getTrainingType();
            Exception failure = null;
            try {
                // Temporarily swap in an EM-style trainer; restored in the finally block below.
                if (tType.isViterbiLike()) {
                    this.trainingParameter = new ViterbiParameterSet(1, ((MaxHMMTrainingParameterSet)this.trainingParameter).getTerminationCondition(), ((NumericalHMMTrainingParameterSet)this.trainingParameter).getNumberOfThreads());
                    this.training = NumericalHMMTrainingParameterSet.TrainingType.VITERBI;
                } else {
                    this.trainingParameter = new BaumWelchParameterSet(1, ((MaxHMMTrainingParameterSet)this.trainingParameter).getTerminationCondition(), ((NumericalHMMTrainingParameterSet)this.trainingParameter).getNumberOfThreads());
                    this.training = NumericalHMMTrainingParameterSet.TrainingType.LIKELIHOOD;
                }
                if (!sub) {
                    super.train(data[index], weights == null ? null : weights[index]);
                } else {
                    DifferentiableEmission[] dEmission = new DifferentiableEmission[this.emission.length];
                    ArrayList[] seqs = new ArrayList[this.emission.length];
                    DoubleList[] subW = new DoubleList[this.emission.length];
                    int[] l = new int[this.emission.length];
                    int[] offset = new int[l.length];
                    for (int i = 0; i < this.emission.length; i++) {
                        if (diffEM[i]) {
                            dEmission[i] = new UniformEmission(this.getAlphabetContainer());
                            seqs[i] = new ArrayList();
                            subW[i] = new DoubleList();
                            DifferentiableSMWrapperEmission help = (DifferentiableSMWrapperEmission) this.emission[i];
                            l[i] = help.getLength();
                            offset[i] = help.getOffset();
                        } else {
                            dEmission[i] = (DifferentiableEmission) this.emission[i];
                        }
                    }
                    DifferentiableHigherOrderHMM simple = this instanceof FastDifferentiableHigherOrderHMM ? new FastDifferentiableHigherOrderHMM(this.type, null, (MaxHMMTrainingParameterSet)this.trainingParameter, this.name, this.filter, this.emissionIdx, dEmission, this.ess, this.transIndex, this.getTransitionElements()) : new DifferentiableHigherOrderHMM(this.type, null, (MaxHMMTrainingParameterSet)this.trainingParameter, this.name, this.filter, this.emissionIdx, this.forward, dEmission, this.ess, this.transIndex, this.getTransitionElements());
                    simple.defContext = this.defContext;
                    simple.preComputedContext = this.preComputedContext;
                    simple.train(data[index], weights == null ? null : weights[index]);
                    if (this.type != null) {
                        // NOTE(review): seqs/subW collected here are not used afterwards in this method;
                        // the wrapped emissions are initialized uniformly below instead.
                        Random random = new Random();
                        for (int n = 0; n < data[index].getNumberOfElements(); n++) {
                            Sequence seq = data[index].getElementAt(n);
                            double w = weights == null || weights[index] == null ? 1.0 : weights[index][n];
                            SequenceAnnotation sa = seq.getSequenceAnnotationByType(this.type, 0);
                            if (sa != null) {
                                StorableResult res = (StorableResult) sa.getResultAt(random.nextInt(sa.getNumberOfResults()));
                                int[] allowedStatesGroup = ((AbstractHMM.AllowedStatesGroups) res.getResultInstance()).groups;
                                for (int j = 0; j < allowedStatesGroup.length; j++) {
                                    int[] groupStates = this.statesGroups[allowedStatesGroup[j]];
                                    if (groupStates.length == 1) {
                                        int em = this.emissionIdx[groupStates[0]];
                                        int st = j + offset[em];
                                        // A window [st, st + l) fits iff st + l <= length; "<" dropped the last valid one.
                                        if (seqs[em] != null && st >= 0 && st + l[em] <= seq.getLength()) {
                                            seqs[em].add(seq.getSubSequence(st, l[em]));
                                            subW[em].add(w);
                                        }
                                    }
                                }
                            }
                        }
                    }
                    this.transition.setParameters(simple.transition);
                    for (int i = 0; i < this.emission.length; i++) {
                        if (diffEM[i]) {
                            ((DifferentiableSMWrapperEmission) this.emission[i]).initializeUniformly();
                        } else {
                            this.emission[i].setParameters(simple.emission[i]);
                        }
                    }
                }
            } catch (Exception e) {
                // Best effort: data-driven initialization failed, fall back to random initialization below.
                failure = e;
            } finally {
                // Previously only restored on success, leaving the temporary EM trainer in place after a failure.
                this.trainingParameter = trainingParameterSet;
                this.training = tType;
            }
            if (failure != null) {
                this.sostream.writeln("Problem while initialization from data. " + failure.getClass().getSimpleName() + ": " + failure.getMessage());
                this.initializeFunctionRandomly(freeParams);
            }
        } else {
            this.initializeFunctionRandomly(freeParams);
        }
        this.sostream.writeln("initialization:\n" + this);
    }

    /**
     * Trains the model on {@code data}.
     * <p>
     * With a {@link NumericalHMMTrainingParameterSet}, the model is optimized numerically through a
     * {@link DifferentiableStatisticalModelWrapperTrainSM} without a prior. The trained emissions and transition
     * are then adopted. The training flag is set for the duration of the optimization and is reset in a
     * {@code finally} block, so a failing optimization does not leave the model in training mode. Otherwise the
     * superclass training is used.
     *
     * @param data the training data
     * @param weights the sequence weights, may be {@code null}
     * @throws Exception if training fails
     */
    @Override
    public void train(DataSet data, double[] weights) throws Exception {
        if (this.trainingParameter instanceof NumericalHMMTrainingParameterSet) {
            this.setTrain(true);
            try {
                NumericalHMMTrainingParameterSet params = (NumericalHMMTrainingParameterSet) this.trainingParameter;
                DoesNothingLogPrior p = DoesNothingLogPrior.defaultInstance;
                DifferentiableStatisticalModelWrapperTrainSM model = new DifferentiableStatisticalModelWrapperTrainSM(this, params.getNumberOfThreads(), params.getAlgorithm(), params.getTerminationCondition(), params.getLineEps(), params.getStartDistance(), params.randomInitialization(), p);
                model.setOutputStream(this.sostream);
                model.train(data, weights);
                DifferentiableHigherOrderHMM hmm = (DifferentiableHigherOrderHMM) model.getFunction();
                this.emission = hmm.emission;
                this.createStates();
                this.transition = hmm.transition;
            } finally {
                this.setTrain(false);
            }
        } else {
            super.train(data, weights);
        }
    }

    /**
     * Returns whether the transition and all emissions are normalized.
     * <p>
     * The transition is checked first, then the emissions in order; the check stops at the first component that
     * is not normalized. (The former debug output to {@code System.out} has been removed.)
     *
     * @return {@code true} if every component is normalized
     */
    @Override
    public boolean isNormalized() {
        if (!((DifferentiableTransition) this.transition).isNormalized()) {
            return false;
        }
        for (int i = 0; i < this.emission.length; i++) {
            if (!((DifferentiableEmission) this.emission[i]).isNormalized()) {
                return false;
            }
        }
        return true;
    }

    /**
     * @return {@code 0.0}, i.e. a normalization constant of 1 on the log scale
     */
    @Override
    public double getLogNormalizationConstant() {
        final double logOfOne = 0.0;
        return logOfOne;
    }

    /**
     * @param parameterIndex ignored
     * @return {@link Double#NEGATIVE_INFINITY} (log of 0) for every parameter
     * @throws Exception never thrown by this implementation
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        final double logOfZero = Double.NEGATIVE_INFINITY;
        return logOfZero;
    }

    /**
     * @param classProb the class probability
     * @return the natural logarithm of {@code classProb}
     */
    @Override
    public double getInitialClassParam(double classProb) {
        double logProbability = Math.log(classProb);
        return logProbability;
    }

    /**
     * Scores the complete sequence, starting at position 0.
     *
     * @param seq the sequence
     * @return the log score
     */
    @Override
    public double getLogScoreFor(Sequence seq) {
        final int firstPosition = 0;
        return getLogScoreFor(seq, firstPosition);
    }

    /**
     * Scores {@code seq} from {@code start} up to and including its last position.
     *
     * @param seq the sequence
     * @param start the first position (inclusive)
     * @return the log score
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        int lastPosition = seq.getLength() - 1;
        return getLogScoreFor(seq, start, lastPosition);
    }

    /**
     * Computes the log score of {@code seq[start..end]}.
     * <p>
     * While training ({@code train}) with an annotation {@code type}, an annotation of that type on the sequence
     * restricts the admissible state paths:
     * <ul>
     * <li>VITERBI / DISCRIMINATIVE_VITERBI: the maximum over all annotated alternatives.</li>
     * <li>DISCRIMINATIVE_VITERBI2: the maximum of {@code fill(...)} over all alternatives, returned directly.</li>
     * <li>LIKELIHOOD / DISCRIMINATIVE_LIKELIHOOD: the first annotated alternative.</li>
     * </ul>
     * If the sequence carries no such annotation, it is scored without restriction. This matches
     * getLogScoreAndPartialDerivation, so that objective and gradient agree; previously this case yielded -inf.
     * For discriminative training types, the unrestricted likelihood is subtracted.
     *
     * @param seq the sequence
     * @param start first position (inclusive)
     * @param end last position (inclusive)
     * @return the log score
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start, int end) {
        SequenceAnnotation sa = null;
        if (this.train && this.type != null) {
            sa = seq.getSequenceAnnotationByType(this.type, 0);
        }
        double s;
        if (sa == null) {
            s = this.logProb(start, end, seq, null, this.score);
        } else {
            switch (this.training) {
                case VITERBI:
                case DISCRIMINATIVE_VITERBI: {
                    s = Double.NEGATIVE_INFINITY;
                    for (int i = 0; i < sa.getNumberOfResults(); i++) {
                        StorableResult res = (StorableResult) sa.getResultAt(i);
                        int[] groups = ((AbstractHMM.AllowedStatesGroups) res.getResultInstance()).groups;
                        double current = this.logProb(start, end, seq, groups, this.score);
                        if (current > s) {
                            s = current;
                        }
                    }
                    break;
                }
                case DISCRIMINATIVE_VITERBI2: {
                    double best = Double.NEGATIVE_INFINITY;
                    for (int i = 0; i < sa.getNumberOfResults(); i++) {
                        StorableResult res = (StorableResult) sa.getResultAt(i);
                        int[] groups = ((AbstractHMM.AllowedStatesGroups) res.getResultInstance()).groups;
                        double current = this.fill(seq, groups, start, end, null, null);
                        if (current > best) {
                            best = current;
                        }
                    }
                    // Returned without the discriminative correction, as in getLogScoreAndPartialDerivation.
                    return best;
                }
                case LIKELIHOOD:
                case DISCRIMINATIVE_LIKELIHOOD: {
                    StorableResult res = (StorableResult) sa.getResultAt(0);
                    int[] groups = ((AbstractHMM.AllowedStatesGroups) res.getResultInstance()).groups;
                    s = this.logProb(start, end, seq, groups, this.score);
                    break;
                }
                default: {
                    s = Double.NEGATIVE_INFINITY;
                    break;
                }
            }
        }
        if (this.train && this.training.isDiscrimnative()) {
            s -= this.logProb(start, end, seq, null, HigherOrderHMM.Type.LIKELIHOOD);
        }
        return s;
    }

    /**
     * Scores {@code seq[start..end]} restricted to the given allowed state groups, using the current scoring mode.
     *
     * @param seq the sequence
     * @param start first position (inclusive)
     * @param end last position (inclusive)
     * @param allowedStatesGroup the allowed state group per position, or {@code null} for no restriction
     * @return the log score
     */
    public double getLogScoreFor(Sequence seq, int start, int end, int[] allowedStatesGroup) {
        double result = logProb(start, end, seq, allowedStatesGroup);
        return result;
    }

    /**
     * Unrestricted log score of {@code sequence[startpos..endpos]}.
     *
     * @param startpos first position (inclusive)
     * @param endpos last position (inclusive)
     * @param sequence the sequence
     * @return the log score
     */
    @Override
    protected double logProb(int startpos, int endpos, Sequence sequence) {
        int[] unrestricted = null;
        return logProb(startpos, endpos, sequence, unrestricted);
    }

    /**
     * Log score of {@code sequence[startpos..endpos]} using the model's current scoring mode.
     *
     * @param startpos first position (inclusive)
     * @param endpos last position (inclusive)
     * @param sequence the sequence
     * @param allowedStatesGroups allowed state group per position, or {@code null}
     * @return the log score
     */
    @Override
    protected double logProb(int startpos, int endpos, Sequence sequence, int[] allowedStatesGroups) {
        HigherOrderHMM.Type currentMode = this.score;
        return logProb(startpos, endpos, sequence, allowedStatesGroups, currentMode);
    }

    /**
     * Fills the backward (or Viterbi) matrix for {@code sequence[startpos..endpos]} and returns its entry
     * {@code bwdMatrix[0][0]}.
     *
     * @param startpos first position (inclusive)
     * @param endpos last position (inclusive)
     * @param sequence the sequence
     * @param allowedStatesGroups allowed state group per position, or {@code null}
     * @param score the scoring mode passed to fillBwdOrViterbiMatrix
     * @return the log score
     * @throws RuntimeException wrapping any exception thrown while filling the matrix
     */
    protected double logProb(int startpos, int endpos, Sequence sequence, int[] allowedStatesGroups, HigherOrderHMM.Type score) {
        final double initialLogValue = 0.0;
        try {
            fillBwdOrViterbiMatrix(score, startpos, endpos, initialLogValue, sequence, allowedStatesGroups, false);
        } catch (Exception cause) {
            throw getRunTimeException(cause);
        }
        return bwdMatrix[0][0];
    }

    /**
     * Log score and partial derivatives for the complete sequence, starting at position 0.
     *
     * @param seq the sequence
     * @param indices receives parameter indices
     * @param partialDer receives partial derivatives
     * @return the log score
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, IntList indices, DoubleList partialDer) {
        final int firstPosition = 0;
        return getLogScoreAndPartialDerivation(seq, firstPosition, indices, partialDer);
    }

    /**
     * Log score and partial derivatives from {@code startPos} up to and including the last position.
     *
     * @param seq the sequence
     * @param startPos first position (inclusive)
     * @param indices receives parameter indices
     * @param partialDer receives partial derivatives
     * @return the log score
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int startPos, IntList indices, DoubleList partialDer) {
        int lastPosition = seq.getLength() - 1;
        return getLogScoreAndPartialDerivation(seq, startPos, lastPosition, indices, partialDer);
    }

    /**
     * Computes the log score of {@code seq[startPos..endPos]} and appends its partial derivatives.
     * <p>
     * While training with an annotation {@code type} present on the sequence, one annotated alternative is
     * chosen. For Viterbi-like training it is the alternative with the highest score; otherwise it is the first
     * one. For DISCRIMINATIVE_VITERBI2 the result of {@code fill(...)} is returned directly. For discriminative
     * training types, the unrestricted likelihood and its derivatives (factor -1) are subtracted.
     *
     * @param seq the sequence
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param indices receives parameter indices
     * @param partialDer receives partial derivatives
     * @return the log score
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int startPos, int endPos, IntList indices, DoubleList partialDer) {
        int[] allowedStatesGroup = null;
        SequenceAnnotation sa = null;
        if (this.train && this.type != null) {
            sa = seq.getSequenceAnnotationByType(this.type, 0);
        }
        if (sa != null) {
            boolean viterbi2 = this.training == NumericalHMMTrainingParameterSet.TrainingType.DISCRIMINATIVE_VITERBI2;
            int chosen = 0;
            if (this.training.isViterbiLike()) {
                double chosenScore = Double.NEGATIVE_INFINITY;
                for (int r = 0; r < sa.getNumberOfResults(); r++) {
                    StorableResult candidateResult = (StorableResult) sa.getResultAt(r);
                    int[] candidate = ((AbstractHMM.AllowedStatesGroups) candidateResult.getResultInstance()).groups;
                    double current = viterbi2
                            ? this.fill(seq, candidate, startPos, endPos, null, null)
                            : this.logProb(startPos, endPos, seq, candidate, this.score);
                    if (current > chosenScore) {
                        chosenScore = current;
                        chosen = r;
                    }
                }
            }
            StorableResult chosenResult = (StorableResult) sa.getResultAt(chosen);
            allowedStatesGroup = ((AbstractHMM.AllowedStatesGroups) chosenResult.getResultInstance()).groups;
            if (viterbi2) {
                return this.fill(seq, allowedStatesGroup, startPos, endPos, indices, partialDer);
            }
        }
        double s = this.logScoreAndPartialDerivation(seq, startPos, endPos, indices, partialDer, allowedStatesGroup, this.score, 1.0);
        if (this.train && this.training.isDiscrimnative()) {
            s -= this.logScoreAndPartialDerivation(seq, startPos, endPos, indices, partialDer, null, HigherOrderHMM.Type.LIKELIHOOD, -1.0);
        }
        return s;
    }

    /**
     * Computes the log score of {@code seq[startPos..endPos]} with a backward pass and reports the partial
     * derivatives of that score.
     * <p>
     * After the pass, the accumulated gradient is read from {@code gradient[0][0]}. For each parameter index
     * {@code p} with a non-zero entry, {@code p} is appended to {@code indices} and
     * {@code factor * gradient[0][0][p]} is appended to {@code partialDer}. getLogScoreAndPartialDerivation
     * passes {@code factor = -1.0} when it subtracts the likelihood term in discriminative training.
     *
     * @param seq the sequence
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param indices receives indices of parameters with a non-zero partial derivative
     * @param partialDer receives the corresponding (scaled) partial derivatives
     * @param allowedStatesGroup restriction passed to getAllowedContext(...); {@code null} for none
     * @param score the scoring mode passed to merge(...)
     * @param factor multiplier applied to every reported partial derivative
     * @return the log score, read from {@code bwdMatrix[0][0]}
     * @throws RuntimeException wrapping any exception raised during the computation (via getRunTimeException)
     */
    protected double logScoreAndPartialDerivation(Sequence seq, int startPos, int endPos, IntList indices, DoubleList partialDer, int[] allowedStatesGroup, HigherOrderHMM.Type score, double factor) {
        try {
            int children;
            int n;
            int context;
            int maxOrder = this.transition.getMaximalMarkovOrder();
            // With Markov order 0 the terminal row starts at log(1)=0 instead of -inf (see the fills below).
            boolean zero = maxOrder == 0;
            int l = endPos - startPos + 1;
            this.provideMatrix(1, endPos - startPos + 1);
            // Reset both gradient accumulator planes.
            int idx2 = 0;
            while (idx2 < this.gradient[1].length) {
                Arrays.fill(this.gradient[0][idx2], 0.0);
                Arrays.fill(this.gradient[1][idx2], 0.0);
                ++idx2;
            }
            DifferentiableTransition diffTransition = (DifferentiableTransition)this.transition;
            // Reset the per-state index / derivative lists.
            int stateID = 0;
            while (stateID < this.states.length) {
                this.indicesState[stateID].clear();
                this.partDerState[stateID].clear();
                ++stateID;
            }
            // Terminal column l (one past the last symbol): only silent states are considered here.
            this.fillFilter(endPos, seq);
            int[] allContext = this.getAllowedContext(endPos + 1, startPos, allowedStatesGroup, maxOrder);
            Arrays.fill(this.bwdMatrix[l], zero ? 0.0 : Double.NEGATIVE_INFINITY);
            int x = allContext.length - 1;
            while (x >= 0) {
                context = allContext[x];
                n = this.transition.getNumberOfChildren(l, context);
                children = 0;
                // A context may end the path if the order is 0 or its last state is marked final.
                double val = zero || this.finalState[this.transition.getLastContextState(l, context)] ? 0.0 : Double.NEGATIVE_INFINITY;
                stateID = 0;
                while (stateID < n) {
                    this.transition.fillTransitionInformation(l, context, stateID, this.container);
                    if (this.filterRes[this.container[0]] && this.states[this.container[0]].isSilent()) {
                        this.indicesTransition[children].clear();
                        this.partDerTransition[children].clear();
                        this.backwardIntermediate[children] = this.bwdMatrix[l][this.container[1]] + diffTransition.getLogScoreAndPartialDerivation(l, context, stateID, this.indicesTransition[children], this.partDerTransition[children], seq, endPos);
                        // Only children with a finite contribution are recorded for merge(...).
                        if (this.backwardIntermediate[children] != Double.NEGATIVE_INFINITY) {
                            this.index[0][children] = this.container[0];
                            this.index[1][children] = this.container[1];
                            this.index[2][children] = this.container[2];
                            ++children;
                        }
                    }
                    ++stateID;
                }
                // NOTE(review): merge(...) combines the children's values into bwdMatrix[l][context];
                // presumably it also accumulates the gradient contributions (implementation not visible here).
                this.merge(this.bwdMatrix, this.backwardIntermediate, children, l, context, val, score);
                --x;
            }
            // Backward over the sequence; endPos is decremented alongside l, so endPos == startPos + l inside the loop.
            while (--l >= 0) {
                this.fillLogEmissionAndPartialDer(endPos, seq, true);
                this.fillFilter(endPos, seq);
                allContext = this.getAllowedContext(endPos, startPos, allowedStatesGroup, maxOrder);
                x = allContext.length - 1;
                while (x >= 0) {
                    context = allContext[x];
                    n = this.transition.getNumberOfChildren(l, context);
                    children = 0;
                    stateID = 0;
                    while (stateID < n) {
                        this.indicesTransition[children].clear();
                        this.partDerTransition[children].clear();
                        this.transition.fillTransitionInformation(l, context, stateID, this.container);
                        if (this.filterRes[this.container[0]]) {
                            // Backward value of the successor (container[2] = position step) + emission + transition.
                            this.backwardIntermediate[children] = this.bwdMatrix[l + this.container[2]][this.container[1]] + this.logEmission[this.getIndex(this.container[0])] + diffTransition.getLogScoreAndPartialDerivation(l, context, stateID, this.indicesTransition[children], this.partDerTransition[children], seq, endPos);
                            if (this.backwardIntermediate[children] != Double.NEGATIVE_INFINITY) {
                                this.index[0][children] = this.container[0];
                                this.index[1][children] = this.container[1];
                                this.index[2][children] = this.container[2];
                                ++children;
                            }
                        }
                        ++stateID;
                    }
                    this.merge(this.bwdMatrix, this.backwardIntermediate, children, l, context, Double.NEGATIVE_INFINITY, score);
                    --x;
                }
                --endPos;
                if (l - 1 < 0) continue;
                // Prepare the next (earlier) row.
                Arrays.fill(this.bwdMatrix[l - 1], Double.NEGATIVE_INFINITY);
            }
            // Report only non-zero partial derivatives, scaled by factor.
            int p = 0;
            while (p < this.numberOfParameters) {
                if (this.gradient[0][0][p] != 0.0) {
                    indices.add(p);
                    partialDer.add(factor * this.gradient[0][0][p]);
                }
                ++p;
            }
            return this.bwdMatrix[0][0];
        }
        catch (Exception e) {
            throw DifferentiableHigherOrderHMM.getRunTimeException(e);
        }
    }

    /**
     * Evaluates {@link #fill(Sequence, int[], int, int, IntList, DoubleList)} for the given
     * region without collecting any gradient information.
     *
     * @param seq the sequence
     * @param allowedStatesGroup the state group restricting the first score of {@code fill}
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @return the value returned by {@code fill}
     */
    public double test(Sequence seq, int[] allowedStatesGroup, int startPos, int endPos) {
        // no index/derivative lists: fill skips all gradient bookkeeping
        final double result = this.fill(seq, allowedStatesGroup, startPos, endPos, null, null);
        return result;
    }

    /**
     * Runs a max-based dynamic programme over the positions {@code startPos..endPos} of
     * {@code seq} that keeps two scores per context. It returns a normalised log score
     * comparing the two scores at layer 0, and optionally collects the gradient of that value.
     *
     * <p>Despite their names, BOTH {@code bwdMatrix} and {@code fwdMatrix} are filled from the
     * last layer towards layer 0. In every layer:
     * <ul>
     * <li>for contexts that are also returned for {@code allowedStatesGroup} by
     *     {@code getAllowedContext}, the {@code bwdMatrix} entry is extended from the successor's
     *     {@code bwdMatrix} value and the {@code fwdMatrix} entry from the successor's
     *     {@code fwdMatrix} value;</li>
     * <li>for all other contexts, no {@code bwdMatrix} children are collected, so the entry takes
     *     the fallback value handed to {@link #max}. The {@code fwdMatrix} entry is extended
     *     from the larger of the successor's {@code fwdMatrix}/{@code bwdMatrix} values, and
     *     {@code index[3]} records which one was used so that {@link #max} copies the matching
     *     gradient ({@code gradient2} vs. {@code gradient}).</li>
     * </ul>
     * NOTE(review): presumably {@code bwdMatrix} therefore holds the best path restricted to the
     * allowed group and {@code fwdMatrix} the best competing path leaving it at least once. TODO
     * confirm.
     *
     * <p>The last layer ({@code l = endPos - startPos + 1}) only considers silent states and
     * adds no emission score. The remaining layers {@code l-1..0} consume the positions
     * {@code endPos} down to {@code startPos}.
     *
     * <p>The result is {@code logScore[0] - ls}, where {@code logScore = {bwd[0][0], fwd[0][0]}}
     * and {@code ls} is returned by {@code Normalisation.logSumNormalisation} over both entries.
     * Assuming that call returns the log of the summed exponentials and writes the normalised
     * weights into {@code prop}, the result is {@code log(e^b / (e^b + e^f))}. The reported
     * gradient {@code prop[1] * (gradient - gradient2)} is consistent with that derivative.
     *
     * @param seq the sequence
     * @param allowedStatesGroup state group passed to {@code getAllowedContext} to select the
     *        contexts of the first ({@code bwdMatrix}) score
     * @param startPos first position (inclusive)
     * @param endPos last position (inclusive)
     * @param indices if non-null, receives the indices of all parameters with a non-zero
     *        partial derivative; if null, no gradient is computed
     * @param partialDer receives the corresponding partial derivatives; must be non-null whenever
     *        {@code indices} is non-null
     * @return the normalised log score described above
     */
    public double fill(Sequence seq, int[] allowedStatesGroup, int startPos, int endPos, IntList indices, DoubleList partialDer) {
        try {
            double emTrans;
            boolean allowedState;
            int n;
            int context;
            int stateID;
            int maxOrder = this.transition.getMaximalMarkovOrder();
            boolean zero = maxOrder == 0;
            int l = endPos - startPos + 1;
            // NOTE(review): presumably ensures bwdMatrix/fwdMatrix have enough layers; provideMatrix is not visible here
            this.provideMatrix(0, endPos - startPos + 1);
            this.provideMatrix(1, endPos - startPos + 1);
            DifferentiableTransition diffTransition = (DifferentiableTransition)this.transition;
            if (indices != null) {
                // reset both gradient buffers (gradient for bwd, gradient2 for fwd) and the
                // per-state emission derivative lists before accumulating
                int idx2 = 0;
                while (idx2 < this.gradient[1].length) {
                    Arrays.fill(this.gradient[0][idx2], 0.0);
                    Arrays.fill(this.gradient[1][idx2], 0.0);
                    Arrays.fill(this.gradient2[0][idx2], 0.0);
                    Arrays.fill(this.gradient2[1][idx2], 0.0);
                    ++idx2;
                }
                stateID = 0;
                while (stateID < this.states.length) {
                    this.indicesState[stateID].clear();
                    this.partDerState[stateID].clear();
                    ++stateID;
                }
            }
            // ---- last layer l: only silent states, no emission score ----
            this.fillFilter(endPos, seq);
            int[] all = this.getAllowedContext(endPos + 1, startPos, null, maxOrder);
            int[] allowed = this.getAllowedContext(endPos + 1, startPos, allowedStatesGroup, maxOrder);
            // walks 'allowed' in parallel with 'all', both from the end; assumes 'allowed' is an
            // ordered subsequence of 'all' (TODO confirm against getAllowedContext)
            int a = allowed.length - 1;
            // default for entries not overwritten below: 0 for order 0, otherwise -infinity
            Arrays.fill(this.bwdMatrix[l], zero ? 0.0 : Double.NEGATIVE_INFINITY);
            Arrays.fill(this.fwdMatrix[l], zero ? 0.0 : Double.NEGATIVE_INFINITY);
            int x = all.length - 1;
            while (x >= 0) {
                context = all[x];
                n = this.transition.getNumberOfChildren(l, context);
                allowedState = a >= 0 ? context == allowed[a] : false;
                this.childrenBW.clear();
                this.childrenFW.clear();
                // fallback used by max when no silent child exists: 0 for order 0 or when the
                // last state of the context is final, otherwise -infinity
                double val = zero || this.finalState[this.transition.getLastContextState(l, context)] ? 0.0 : Double.NEGATIVE_INFINITY;
                stateID = 0;
                while (stateID < n) {
                    this.transition.fillTransitionInformation(l, context, stateID, this.container);
                    if (this.filterRes[this.container[0]] && this.states[this.container[0]].isSilent()) {
                        emTrans = 0.0;
                        if (indices != null) {
                            this.indicesTransition[stateID].clear();
                            this.partDerTransition[stateID].clear();
                            emTrans += diffTransition.getLogScoreAndPartialDerivation(l, context, stateID, this.indicesTransition[stateID], this.partDerTransition[stateID], seq, endPos);
                        } else {
                            emTrans += diffTransition.getLogScoreFor(l, context, stateID, seq, endPos);
                        }
                        if (allowedState) {
                            this.backwardIntermediate[stateID] = emTrans + this.bwdMatrix[l][this.container[1]];
                            this.childrenBW.add(stateID);
                            this.forwardIntermediate[stateID] = Double.NEGATIVE_INFINITY;
                        } else {
                            this.forwardIntermediate[stateID] = emTrans + this.fwdMatrix[l][this.container[1]];
                        }
                        this.childrenFW.add(stateID);
                        // remember (state, successor context, layer offset, source matrix) for
                        // the gradient bookkeeping in max/miniMerge
                        this.index[0][stateID] = this.container[0];
                        this.index[1][stateID] = this.container[1];
                        this.index[2][stateID] = this.container[2];
                        this.index[3][stateID] = 0;
                    }
                    ++stateID;
                }
                this.bwdMatrix[l][context] = this.max(this.backwardIntermediate, this.childrenBW, l, context, val, indices == null ? null : this.gradient, null);
                this.fwdMatrix[l][context] = this.max(this.forwardIntermediate, this.childrenFW, l, context, Double.NEGATIVE_INFINITY, indices == null ? null : this.gradient2, this.gradient);
                if (allowedState) {
                    --a;
                }
                --x;
            }
            // ---- layers l-1 .. 0: emitting transitions, one position to the left per layer ----
            while (--l >= 0) {
                this.fillLogEmissionAndPartialDer(endPos, seq, indices != null);
                this.fillFilter(endPos, seq);
                all = this.getAllowedContext(endPos, startPos, null, maxOrder);
                allowed = this.getAllowedContext(endPos, startPos, allowedStatesGroup, maxOrder);
                a = allowed.length - 1;
                x = all.length - 1;
                while (x >= 0) {
                    context = all[x];
                    n = this.transition.getNumberOfChildren(l, context);
                    allowedState = a >= 0 ? context == allowed[a] : false;
                    this.childrenBW.clear();
                    this.childrenFW.clear();
                    stateID = 0;
                    while (stateID < n) {
                        this.transition.fillTransitionInformation(l, context, stateID, this.container);
                        if (this.filterRes[this.container[0]]) {
                            emTrans = this.logEmission[this.getIndex(this.container[0])];
                            if (indices != null) {
                                this.indicesTransition[stateID].clear();
                                this.partDerTransition[stateID].clear();
                                emTrans += diffTransition.getLogScoreAndPartialDerivation(l, context, stateID, this.indicesTransition[stateID], this.partDerTransition[stateID], seq, endPos);
                            } else {
                                emTrans += diffTransition.getLogScoreFor(l, context, stateID, seq, endPos);
                            }
                            this.index[0][stateID] = this.container[0];
                            this.index[1][stateID] = this.container[1];
                            this.index[2][stateID] = this.container[2];
                            this.index[3][stateID] = 0;
                            if (allowedState) {
                                this.backwardIntermediate[stateID] = emTrans + this.bwdMatrix[l + this.container[2]][this.container[1]];
                                this.childrenBW.add(stateID);
                                this.forwardIntermediate[stateID] = emTrans + this.fwdMatrix[l + this.container[2]][this.container[1]];
                            } else if (this.fwdMatrix[l + this.container[2]][this.container[1]] > this.bwdMatrix[l + this.container[2]][this.container[1]]) {
                                this.forwardIntermediate[stateID] = emTrans + this.fwdMatrix[l + this.container[2]][this.container[1]];
                            } else {
                                // continue from the successor's bwd score; index[3] = 1 makes max
                                // copy the gradient from 'gradient' instead of 'gradient2'
                                this.index[3][stateID] = 1;
                                this.forwardIntermediate[stateID] = emTrans + this.bwdMatrix[l + this.container[2]][this.container[1]];
                            }
                            this.childrenFW.add(stateID);
                        }
                        ++stateID;
                    }
                    this.bwdMatrix[l][context] = this.max(this.backwardIntermediate, this.childrenBW, l, context, Double.NEGATIVE_INFINITY, indices == null ? null : this.gradient, null);
                    this.fwdMatrix[l][context] = this.max(this.forwardIntermediate, this.childrenFW, l, context, Double.NEGATIVE_INFINITY, indices == null ? null : this.gradient2, this.gradient);
                    if (allowedState) {
                        --a;
                    }
                    --x;
                }
                --endPos;
            }
            // combine both scores of the start context; prop receives the normalised weights
            this.logScore[0] = this.bwdMatrix[0][0];
            this.logScore[1] = this.fwdMatrix[0][0];
            double ls = Normalisation.logSumNormalisation(this.logScore, 0, 2, this.prop, 0);
            if (indices != null) {
                // only non-zero partial derivatives are reported
                int p = 0;
                while (p < this.numberOfParameters) {
                    double v = this.prop[1] * (this.gradient[0][0][p] - this.gradient2[0][0][p]);
                    if (v != 0.0) {
                        indices.add(p);
                        partialDer.add(v);
                    }
                    ++p;
                }
            }
            return this.logScore[0] - ls;
        }
        catch (Exception e) {
            // any (checked) exception is rethrown as a runtime exception
            throw DifferentiableHigherOrderHMM.getRunTimeException(e);
        }
    }

    /**
     * Stores, for every state, the log emission score of the single position {@code endPos} in
     * {@code logEmission}. With {@code grad} set, the per-state index/derivative lists are
     * cleared and refilled with the partial derivatives of that score.
     *
     * @param endPos the position whose emission is scored (used as start and end)
     * @param seq the sequence
     * @param grad whether partial derivatives are collected as well
     * @throws OperationNotSupportedException if a state cannot score the sequence
     * @throws WrongLengthException if the position does not fit the sequence
     */
    protected void fillLogEmissionAndPartialDer(int endPos, Sequence seq, boolean grad) throws OperationNotSupportedException, WrongLengthException {
        for (int s = 0; s < this.states.length; s++) {
            if (!grad) {
                this.logEmission[s] = this.states[s].getLogScoreFor(endPos, endPos, seq);
                continue;
            }
            this.indicesState[s].clear();
            this.partDerState[s].clear();
            DifferentiableState diffState = (DifferentiableState) this.states[s];
            this.logEmission[s] = diffState.getLogScoreAndPartialDerivation(endPos, endPos, this.indicesState[s], this.partDerState[s], seq);
        }
    }

    /**
     * Returns the maximum of {@code intermediate} over the entries listed in {@code children}.
     * If a gradient buffer is given, it also records the gradient of the winning child.
     *
     * <p>When {@code children} is empty, {@code val} is returned and (if requested) the
     * gradient row of {@code context} is zeroed. Otherwise the gradient row of the winner's
     * successor is copied into row {@code context} of the buffer for {@code layer % 2}, and the
     * winner's transition/emission derivatives are added on top. The successor row is taken from
     * {@code altGradient} when one is given and {@code index[3]} of the winner is non-zero, and
     * from {@code gradient} otherwise.
     *
     * @param intermediate candidate scores, indexed by child id
     * @param children ids of the admissible children
     * @param layer current layer
     * @param context current context
     * @param val value returned when there is no admissible child
     * @param gradient gradient buffer to update, or {@code null} to skip gradient work
     * @param altGradient alternative source buffer for the successor gradient, may be {@code null}
     * @return the maximal candidate score, or {@code val} if there is none
     */
    protected double max(double[] intermediate, IntList children, int layer, int context, double val, double[][][] gradient, double[][][] altGradient) {
        if (children.length() == 0) {
            if (gradient != null) {
                this.resetGradient(gradient, layer, context, 0.0);
            }
            return val;
        }
        final int best = ToolBox.getMaxIndex(children, intermediate);
        if (gradient == null) {
            return intermediate[best];
        }
        final int currentRow = layer % 2;
        final int successorRow = (layer + this.index[2][best]) % 2;
        final int successorContext = this.index[1][best];
        final boolean fromAlt = altGradient != null && this.index[3][best] != 0;
        final double[] source = fromAlt ? altGradient[successorRow][successorContext] : gradient[successorRow][successorContext];
        System.arraycopy(source, 0, gradient[currentRow][context], 0, this.numberOfParameters);
        this.miniMerge(best, 1.0, currentRow, context, gradient);
        return intermediate[best];
    }

    /**
     * Combines the first {@code anz} candidate scores of {@code intermediate} into
     * {@code matrix[layer][context]} and updates the corresponding gradient row.
     *
     * <ul>
     * <li>{@code anz == 0}: the entry is set to {@code val} and the gradient row is zeroed.</li>
     * <li>{@link HigherOrderHMM.Type#VITERBI}: the maximal candidate wins. Its successor's
     *     gradient row is copied and its own derivatives are added with weight 1.</li>
     * <li>otherwise: the entry becomes the value returned by
     *     {@code Normalisation.logSumNormalisation}, which also overwrites {@code intermediate}
     *     in place. The overwritten values are then used as the weights for the successor
     *     gradients and the derivatives of each candidate.</li>
     * </ul>
     *
     * @param matrix score matrix to write into
     * @param intermediate candidate scores (modified in place in the non-Viterbi case)
     * @param anz number of valid candidates at the front of {@code intermediate}
     * @param layer current layer
     * @param context current context
     * @param val value used when there is no candidate
     * @param score how candidates are combined
     */
    protected void merge(double[][] matrix, double[] intermediate, int anz, int layer, int context, double val, HigherOrderHMM.Type score) {
        if (anz == 0) {
            matrix[layer][context] = val;
            this.resetGradient(this.gradient, layer, context, 0.0);
            return;
        }
        final int currentRow = layer % 2;
        final double[] target = this.gradient[currentRow][context];
        if (score == HigherOrderHMM.Type.VITERBI) {
            final int best = ToolBox.getMaxIndex(0, anz, intermediate);
            final double[] source = this.gradient[(layer + this.index[2][best]) % 2][this.index[1][best]];
            System.arraycopy(source, 0, target, 0, this.numberOfParameters);
            this.miniMerge(best, 1.0, currentRow, context, this.gradient);
            matrix[layer][context] = intermediate[best];
            return;
        }
        matrix[layer][context] = Normalisation.logSumNormalisation(intermediate, 0, anz, intermediate, 0);
        Arrays.fill(target, 0.0);
        for (int child = 0; child < anz; child++) {
            final double weight = intermediate[child];
            final double[] source = this.gradient[(layer + this.index[2][child]) % 2][this.index[1][child]];
            for (int p = 0; p < this.numberOfParameters; p++) {
                target[p] += weight * source[p];
            }
            this.miniMerge(child, weight, currentRow, context, this.gradient);
        }
    }

    /**
     * Adds the weighted partial derivatives of candidate {@code i} to row {@code context} of
     * {@code gradient[h]}. These are the transition derivatives stored for {@code i} and the
     * emission derivatives of the state recorded in {@code index[0][i]}.
     *
     * @param i candidate (child) id
     * @param weight factor applied to every partial derivative
     * @param h gradient buffer selector (layer parity)
     * @param context row to update
     * @param gradient gradient buffers
     */
    private void miniMerge(int i, double weight, int h, int context, double[][][] gradient) {
        final double[] target = gradient[h][context];
        for (int k = 0; k < this.indicesTransition[i].length(); k++) {
            target[this.indicesTransition[i].get(k)] += weight * this.partDerTransition[i].get(k);
        }
        final int stateIndex = this.getIndex(this.index[0][i]);
        for (int k = 0; k < this.indicesState[stateIndex].length(); k++) {
            target[this.indicesState[stateIndex].get(k)] += weight * this.partDerState[stateIndex].get(k);
        }
    }

    /**
     * Sets every entry of the gradient row of {@code context}, in the buffer selected by the
     * parity of {@code layer}, to {@code val}.
     */
    private void resetGradient(double[][][] gradient, int layer, int context, double val) {
        final double[] row = gradient[layer % 2][context];
        Arrays.fill(row, val);
    }

    /**
     * Returns the event-space size for the parameter with global index {@code index}.
     *
     * <p>Emission parameters come first, in emission order. If {@code index} falls into the
     * range of an emission, that emission's event-space size is returned. Otherwise the
     * (unchanged) global index is handed to the transition.
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        int offset = 0;
        for (int e = 0; e < this.emission.length; e++) {
            final DifferentiableEmission current = (DifferentiableEmission) this.emission[e];
            final int count = current.getNumberOfParameters();
            final boolean inRange = index >= offset && index < offset + count;
            if (count > 0 && inRange) {
                return current.getSizeOfEventSpace();
            }
            offset += count;
        }
        return ((DifferentiableTransition) this.transition).getSizeOfEventSpace(index);
    }

    /**
     * Collects the sampling groups of all emissions (in emission order), followed by those of
     * the transition.
     *
     * @param parameterOffset offset passed unchanged to every component's
     *        {@code fillSamplingGroups}
     * @return the collected sampling groups, in collection order
     */
    @Override
    public int[][] getSamplingGroups(int parameterOffset) {
        LinkedList<int[]> list = new LinkedList<int[]>();
        for (int i = 0; i < this.emission.length; i++) {
            ((DifferentiableEmission) this.emission[i]).fillSamplingGroups(parameterOffset, list);
        }
        ((DifferentiableTransition) this.transition).fillSamplingGroups(parameterOffset, list);
        // typed toArray: the decompiled "(T[])" cast referred to an undeclared type variable
        // and did not compile
        return list.toArray(new int[0][]);
    }

    /**
     * Returns a short description containing the maximal Markov order of the transition and
     * the training type.
     */
    @Override
    public String getInstanceName() {
        final int order = this.transition.getMaximalMarkovOrder();
        return "differentiable HMM(" + order + ", " + this.training + ")";
    }

    /**
     * Switches training mode and selects the matching score type.
     *
     * <p>Outside training the score is always {@code LIKELIHOOD}. In training it is
     * {@code LIKELIHOOD} for the {@code LIKELIHOOD} and {@code DISCRIMINATIVE_LIKELIHOOD}
     * training types, and {@code VITERBI} for every other training type.
     *
     * @param train whether the model is being trained
     */
    public void setTrain(boolean train) {
        this.train = train;
        if (!train) {
            this.score = HigherOrderHMM.Type.LIKELIHOOD;
            return;
        }
        final boolean likelihoodBased = this.training == NumericalHMMTrainingParameterSet.TrainingType.LIKELIHOOD
                || this.training == NumericalHMMTrainingParameterSet.TrainingType.DISCRIMINATIVE_LIKELIHOOD;
        this.score = likelihoodBased ? HigherOrderHMM.Type.LIKELIHOOD : HigherOrderHMM.Type.VITERBI;
    }

    /**
     * Debugging helper: builds a {@link LogGenDisMixFunction} for this model and {@code data},
     * evaluates its gradient at the current parameter values and prints the result to standard
     * output. The printed summary is the length, minimum, maximum, Euclidean norm divided by
     * the length, and the full vector.
     *
     * <p>The model is put into training mode for the evaluation and taken out of it afterwards.
     * The reset is done in a {@code finally} block, so a failing evaluation no longer leaves the
     * model in training mode.
     *
     * @param data the data; every sequence gets weight 1
     * @throws Exception if building or evaluating the objective function fails
     */
    public void check(DataSet data) throws Exception {
        // position 0 stays 0.0, model parameters go to 1..n
        // (NOTE(review): presumably the class parameter slot of the objective)
        double[] params = new double[1 + this.getNumberOfParameters()];
        System.arraycopy(this.getCurrentParameterValues(), 0, params, 1, params.length - 1);
        this.setTrain(true);
        try {
            double[][] w = new double[1][data.getNumberOfElements()];
            Arrays.fill(w[0], 1.0);
            LogGenDisMixFunction log = new LogGenDisMixFunction(((NumericalHMMTrainingParameterSet)this.trainingParameter).getNumberOfThreads(), new DifferentiableHigherOrderHMM[]{this}, new DataSet[]{data}, w, null, LearningPrinciple.getBeta(this.ess == 0.0 ? LearningPrinciple.ML : LearningPrinciple.MAP), true, false);
            log.reset();
            double[] grad = log.evaluateGradientOfFunction(params);
            int min = ToolBox.getMinIndex(grad);
            int max = ToolBox.getMaxIndex(grad);
            System.out.println("grad\t" + grad.length + "\t" + grad[min] + "\t" + grad[max] + "\t" + DifferentiableHigherOrderHMM.dist(grad, null) + "\t" + Arrays.toString(grad));
        } finally {
            // always leave training mode, even if the evaluation failed
            this.setTrain(false);
        }
    }

    /**
     * Returns the Euclidean distance between {@code a} and {@code b}, divided by
     * {@code a.length}. A {@code null} {@code b} is treated as the zero vector.
     */
    private static double dist(double[] a, double[] b) {
        double sumOfSquares = 0.0;
        for (int i = 0; i < a.length; i++) {
            final double reference = b == null ? 0.0 : b[i];
            final double delta = a[i] - reference;
            sumOfSquares += delta * delta;
        }
        return Math.sqrt(sumOfSquares) / (double) a.length;
    }
}

