/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.SamplingDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.DifferentiableStatisticalModelWrapperTrainSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.DifferentiableState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SimpleDifferentiableState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.HMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.MaxHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.NumericalHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.DifferentiableTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.LinkedList;

public class DifferentiableHigherOrderHMM
extends HigherOrderHMM
implements SamplingDifferentiableStatisticalModel {
    /** Total number of parameters (emissions followed by transition); -1 if offsets could not be assigned. */
    protected int numberOfParameters;
    /** Value reported by {@link #getESS()}; validated to be non-negative in the constructor. */
    protected double ess;
    /** Whether scores are computed as LIKELIHOOD (log-sum over paths) or VITERBI (best path). */
    protected HigherOrderHMM.Type score;
    /** Scratch: for each child slot, the (state, context, layer offset) taken from {@code container}. */
    protected int[][] index;
    /** Scratch: two alternating layers of per-context gradient vectors, indexed [layer % 2][context][parameter]. */
    protected double[][][] gradient;
    /** Scratch: parameter indices of the emission derivatives of each state at the current position. */
    protected IntList[] indicesState;
    /** Scratch: parameter indices of the transition derivatives of each child slot. */
    protected IntList[] indicesTransition;
    /** Scratch: emission partial derivatives matching {@link #indicesState}. */
    protected DoubleList[] partDerState;
    /** Scratch: transition partial derivatives matching {@link #indicesTransition}. */
    protected DoubleList[] partDerTransition;

    /**
     * Creates a differentiable higher-order HMM.
     *
     * @param trainingParameterSet parameters used by {@link #train(DataSet, double[])}
     * @param name                 the state names
     * @param emissionIdx          for each state, the index of its emission in {@code emission}
     * @param forward              for each state, its strand flag (passed on to the state)
     * @param emission             the differentiable emissions
     * @param likelihood           {@code true} for likelihood scoring, {@code false} for Viterbi scoring
     * @param ess                  the value returned by {@link #getESS()}; must be non-negative
     * @param te                   the transition elements
     * @throws IllegalArgumentException if {@code ess} is negative or NaN
     * @throws Exception               if the super class cannot be constructed
     */
    public DifferentiableHigherOrderHMM(MaxHMMTrainingParameterSet trainingParameterSet, String[] name, int[] emissionIdx, boolean[] forward, DifferentiableEmission[] emission, boolean likelihood, double ess, TransitionElement ... te) throws Exception {
        super((HMMTrainingParameterSet)trainingParameterSet, name, emissionIdx, forward, emission, (BasicHigherOrderTransition.AbstractTransitionElement[])te);
        this.getOffsets();
        this.score = likelihood ? HigherOrderHMM.Type.LIKELIHOOD : HigherOrderHMM.Type.VITERBI;
        // "!(ess >= 0)" also rejects NaN, which "ess < 0" silently let through.
        if (!(ess >= 0.0)) {
            throw new IllegalArgumentException("ess must be non-negative, but was " + ess);
        }
        this.ess = ess;
    }

    /**
     * Restores an instance from its XML representation.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public DifferentiableHigherOrderHMM(StringBuffer xml) throws NonParsableException {
        super(xml);
        this.getOffsets();
    }

    /** Appends {@code ess} and {@code score} after the information written by the super class. */
    @Override
    protected void appendFurtherInformation(StringBuffer xml) {
        super.appendFurtherInformation(xml);
        XMLParser.appendObjectWithTags(xml, this.ess, "ess");
        XMLParser.appendObjectWithTags(xml, (Object)this.score, "score");
    }

    /** Reads {@code ess} and {@code score} after the information read by the super class. */
    @Override
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        super.extractFurtherInformation(xml);
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.score = XMLParser.extractObjectForTags(xml, "score", HigherOrderHMM.Type.class);
    }

    /**
     * Allocates the gradient scratch buffers (only while {@code container} is still unset),
     * then delegates to the super class.
     */
    @Override
    protected void createHelperVariables() {
        if (this.container == null) {
            int maxOrder = this.transition.getMaximalMarkovOrder();
            int maxIndexes = 0;
            for (int order = 0; order <= maxOrder; order++) {
                maxIndexes = Math.max(maxIndexes, this.transition.getNumberOfIndexes(order));
            }
            // Only inspect gradient[0][0] when that row exists (maxIndexes may be 0).
            if (this.gradient == null
                    || this.gradient[0].length != maxIndexes
                    || (maxIndexes > 0 && this.gradient[0][0].length != this.numberOfParameters)) {
                this.gradient = new double[2][maxIndexes][this.numberOfParameters];
                this.index = new int[3][maxIndexes];
            }
            if (this.indicesState == null) {
                int maxChildren = this.transition.getMaximalNumberOfChildren();
                try {
                    this.indicesState = (IntList[])ArrayHandler.createArrayOf((Cloneable)new IntList(), this.states.length);
                    this.partDerState = (DoubleList[])ArrayHandler.createArrayOf((Cloneable)new DoubleList(), this.states.length);
                    this.indicesTransition = (IntList[])ArrayHandler.createArrayOf((Cloneable)new IntList(), maxChildren);
                    this.partDerTransition = (DoubleList[])ArrayHandler.createArrayOf((Cloneable)new DoubleList(), maxChildren);
                }
                catch (CloneNotSupportedException cnse) {
                    throw DifferentiableHigherOrderHMM.getRunTimeException(cnse);
                }
            }
        }
        super.createHelperVariables();
    }

    /** Creates one {@link SimpleDifferentiableState} per entry of {@code emissionIdx}. */
    @Override
    protected void createStates() {
        this.states = new SimpleDifferentiableState[this.emissionIdx.length];
        for (int i = 0; i < this.emissionIdx.length; i++) {
            this.states[i] = new SimpleDifferentiableState((DifferentiableEmission)this.emission[this.emissionIdx[i]], this.name[i], this.forward[i]);
        }
    }

    /**
     * Clones this model. The gradient and state-index scratch buffers are hidden from
     * {@code super.clone()} by temporarily nulling them, and they are always restored on this
     * instance afterwards, even if cloning fails.
     */
    @Override
    public DifferentiableHigherOrderHMM clone() throws CloneNotSupportedException {
        double[][][] savedGradient = this.gradient;
        IntList[] savedIndicesState = this.indicesState;
        this.gradient = null;
        this.indicesState = null;
        try {
            return (DifferentiableHigherOrderHMM)super.clone();
        }
        finally {
            this.gradient = savedGradient;
            this.indicesState = savedIndicesState;
        }
    }

    @Override
    public double getESS() {
        return this.ess;
    }

    /** Adds the prior gradient of every emission, then of the transition, to {@code grad}. */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        for (int e = 0; e < this.emission.length; e++) {
            ((DifferentiableEmission)this.emission[e]).addGradientOfLogPriorTerm(grad, start);
        }
        ((DifferentiableTransition)this.transition).addGradientForLogPriorTerm(grad, start);
    }

    /**
     * Assigns consecutive parameter offsets to the emissions and then to the transition,
     * storing the running total in {@link #numberOfParameters}. If any component reports -1,
     * {@code numberOfParameters} stays -1 and no helper variables are created.
     */
    private void getOffsets() {
        this.numberOfParameters = 0;
        for (int e = 0; e < this.emission.length; e++) {
            this.numberOfParameters = ((DifferentiableEmission)this.emission[e]).setParameterOffset(this.numberOfParameters);
            if (this.numberOfParameters == -1) {
                return;
            }
        }
        this.numberOfParameters = ((DifferentiableTransition)this.transition).setParameterOffset(this.numberOfParameters);
        if (this.numberOfParameters == -1) {
            return;
        }
        this.createHelperVariables();
    }

    @Override
    public int getNumberOfParameters() {
        return this.numberOfParameters;
    }

    @Override
    public int getNumberOfRecommendedStarts() {
        return this.trainingParameter.getNumberOfStarts();
    }

    /**
     * Returns the current parameters, filled by every emission and then by the transition.
     *
     * @throws IllegalArgumentException if the number of parameters is undefined (-1)
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        int n = this.getNumberOfParameters();
        if (n == -1) {
            throw new IllegalArgumentException("The number of parameters is not defined.");
        }
        double[] params = new double[n];
        for (int e = 0; e < this.emission.length; e++) {
            ((DifferentiableEmission)this.emission[e]).fillCurrentParameter(params);
        }
        ((DifferentiableTransition)this.transition).fillParameters(params);
        return params;
    }

    @Override
    public boolean isInitialized() {
        return true;
    }

    /** Passes {@code params} and {@code start} to every emission and then to the transition. */
    @Override
    public void setParameters(double[] params, int start) {
        for (int e = 0; e < this.emission.length; e++) {
            ((DifferentiableEmission)this.emission[e]).setParameter(params, start);
        }
        ((DifferentiableTransition)this.transition).setParameters(params, start);
    }

    /** Randomly initializes the model (unless {@code skipInit} is set) and recomputes parameter offsets. */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        if (this.skipInit) {
            return;
        }
        this.initializeRandomly();
        this.getOffsets();
    }

    /**
     * Initializes the model (unless {@code skipInit} is set). With numerical training parameters,
     * the initialization is random. Otherwise the model is trained on {@code data[index]} and the
     * parameter offsets are recomputed.
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (this.skipInit) {
            return;
        }
        if (this.trainingParameter instanceof NumericalHMMTrainingParameterSet) {
            this.initializeFunctionRandomly(freeParams);
        } else {
            this.train(data[index], weights == null ? null : weights[index]);
            this.getOffsets();
        }
    }

    /**
     * Trains the model. With {@link NumericalHMMTrainingParameterSet}, training is delegated to a
     * {@link DifferentiableStatisticalModelWrapperTrainSM}, and the trained emissions and transition
     * are copied back into this instance. Otherwise the super class trains the model.
     */
    @Override
    public void train(DataSet data, double[] weights) throws Exception {
        if (this.trainingParameter instanceof NumericalHMMTrainingParameterSet) {
            NumericalHMMTrainingParameterSet params = (NumericalHMMTrainingParameterSet)this.trainingParameter;
            DifferentiableStatisticalModelWrapperTrainSM model = new DifferentiableStatisticalModelWrapperTrainSM(this, params.getNumberOfThreads(), params.getAlgorithm(), params.getTerminationCondition(), params.getLineEps(), params.getStartDistance());
            model.setOutputStream(this.sostream);
            model.train(data, weights);
            DifferentiableHigherOrderHMM hmm = (DifferentiableHigherOrderHMM)model.getFunction();
            this.emission = hmm.emission;
            this.createStates();
            this.transition = hmm.transition;
        } else {
            super.train(data, weights);
        }
    }

    @Override
    public boolean isNormalized() {
        return true;
    }

    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public double getInitialClassParam(double classProb) {
        return Math.log(classProb);
    }

    @Override
    public double getLogScoreFor(Sequence seq) {
        return this.getLogScoreFor(seq, 0);
    }

    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.getLogScoreFor(seq, start, seq.getLength() - 1);
    }

    /**
     * Returns the log score of {@code seq} between {@code start} and {@code end} (inclusive).
     * The original implementation ignored {@code end} and always scored up to the last position.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start, int end) {
        return this.logProb(start, end, seq);
    }

    /** Fills the backward (or Viterbi) matrix for {@code score} and returns its entry [0][0]. */
    @Override
    protected double logProb(int startpos, int endpos, Sequence sequence) {
        try {
            this.fillBwdOrViterbiMatrix(this.score, startpos, endpos, 0.0, sequence);
        }
        catch (Exception e) {
            throw DifferentiableHigherOrderHMM.getRunTimeException(e);
        }
        return this.bwdMatrix[0][0];
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, IntList indices, DoubleList partialDer) {
        return this.getLogScoreAndPartialDerivation(seq, 0, indices, partialDer);
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int startPos, IntList indices, DoubleList partialDer) {
        return this.getLogScoreAndPartialDerivation(seq, startPos, seq.getLength() - 1, indices, partialDer);
    }

    /**
     * Computes the log score of {@code seq[startPos..endPos]} with a backward pass and, alongside,
     * the gradient with respect to all parameters. Non-zero partial derivatives are appended to
     * {@code indices}/{@code partialDer}.
     *
     * The last layer ({@code l = endPos - startPos + 1}) only follows silent states. Every earlier
     * layer combines emission, transition and the already computed next layer. Gradients are kept
     * in two alternating buffers ({@code layer % 2}).
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int startPos, int endPos, IntList indices, DoubleList partialDer) {
        try {
            boolean zero = this.transition.getMaximalMarkovOrder() == 0;
            int l = endPos - startPos + 1;
            this.provideMatrix(1, l);
            for (int c = 0; c < this.gradient[1].length; c++) {
                Arrays.fill(this.gradient[0][c], 0.0);
                Arrays.fill(this.gradient[1][c], 0.0);
            }
            DifferentiableTransition diffTransition = (DifferentiableTransition)this.transition;
            // No emission is consumed in the last layer, so the state derivatives must be empty.
            for (int s = 0; s < this.states.length; s++) {
                this.indicesState[s].clear();
                this.partDerState[s].clear();
            }
            // Last layer: only transitions into silent states, plus an optional final-state term.
            for (int context = this.bwdMatrix[l].length - 1; context >= 0; context--) {
                int n = this.transition.getNumberOfChildren(l, context);
                int children = 0;
                double val = zero || this.finalState[this.transition.getLastContextState(l, context)] ? 0.0 : Double.NEGATIVE_INFINITY;
                for (int child = 0; child < n; child++) {
                    this.transition.fillTransitionInformation(l, context, child, this.container);
                    if (this.states[this.container[0]].isSilent()) {
                        this.indicesTransition[children].clear();
                        this.partDerTransition[children].clear();
                        this.backwardIntermediate[children] = this.bwdMatrix[l][this.container[1]] + diffTransition.getLogScoreAndPartialDerivation(l, context, child, this.indicesTransition[children], this.partDerTransition[children], seq, endPos);
                        if (this.backwardIntermediate[children] != Double.NEGATIVE_INFINITY) {
                            this.storeChildIndex(children);
                            ++children;
                        }
                    }
                }
                if (children == 0) {
                    this.bwdMatrix[l][context] = val;
                    this.resetGradient(l, context, 0.0);
                } else {
                    this.merge(children, l, context, val);
                }
            }
            // Remaining layers, from back to front; pos is the sequence position emitted in this layer.
            for (int layer = l - 1, pos = endPos; layer >= 0; layer--, pos--) {
                for (int s = 0; s < this.states.length; s++) {
                    this.indicesState[s].clear();
                    this.partDerState[s].clear();
                    this.logEmission[s] = ((DifferentiableState)this.states[s]).getLogScoreAndPartialDerivation(pos, pos, this.indicesState[s], this.partDerState[s], seq);
                }
                for (int context = this.bwdMatrix[layer].length - 1; context >= 0; context--) {
                    int n = this.transition.getNumberOfChildren(layer, context);
                    int children = 0;
                    for (int child = 0; child < n; child++) {
                        this.indicesTransition[children].clear();
                        this.partDerTransition[children].clear();
                        this.transition.fillTransitionInformation(layer, context, child, this.container);
                        this.backwardIntermediate[children] = this.bwdMatrix[layer + this.container[2]][this.container[1]] + this.logEmission[this.container[0]] + diffTransition.getLogScoreAndPartialDerivation(layer, context, child, this.indicesTransition[children], this.partDerTransition[children], seq, pos);
                        if (this.backwardIntermediate[children] != Double.NEGATIVE_INFINITY) {
                            this.storeChildIndex(children);
                            ++children;
                        }
                    }
                    if (children == 0) {
                        this.bwdMatrix[layer][context] = Double.NEGATIVE_INFINITY;
                        this.resetGradient(layer, context, 0.0);
                    } else {
                        this.merge(children, layer, context, Double.NEGATIVE_INFINITY);
                    }
                }
            }
            double[] total = this.gradient[0][0];
            for (int p = 0; p < this.numberOfParameters; p++) {
                if (total[p] != 0.0) {
                    indices.add(p);
                    partialDer.add(total[p]);
                }
            }
            return this.bwdMatrix[0][0];
        }
        catch (Exception e) {
            throw DifferentiableHigherOrderHMM.getRunTimeException(e);
        }
    }

    /** Copies the (state, context, layer offset) triple from {@code container} into child slot {@code slot}. */
    private void storeChildIndex(int slot) {
        this.index[0][slot] = this.container[0];
        this.index[1][slot] = this.container[1];
        this.index[2][slot] = this.container[2];
    }

    /**
     * Combines the first {@code anz} child scores in {@code backwardIntermediate} into
     * {@code bwdMatrix[layer][context]} and the corresponding gradient.
     *
     * VITERBI: takes the best child if it beats {@code extra}; otherwise the cell gets
     * {@code extra} and a zero gradient.
     *
     * LIKELIHOOD: log-sums the children (and {@code extra} if finite). The normalised weights left
     * in {@code backwardIntermediate} by {@link Normalisation#logSumNormalisation} then form the
     * weighted sum of the child gradients.
     */
    private void merge(int anz, int layer, int context, double extra) {
        int h = layer % 2;
        double[] target = this.gradient[h][context];
        if (this.score == HigherOrderHMM.Type.VITERBI) {
            int best = ToolBox.getMaxIndex(0, anz, this.backwardIntermediate);
            if (this.backwardIntermediate[best] > extra) {
                System.arraycopy(this.gradient[(layer + this.index[2][best]) % 2][this.index[1][best]], 0, target, 0, this.numberOfParameters);
                this.miniMerge(best, 1.0, h, context);
                this.bwdMatrix[layer][context] = this.backwardIntermediate[best];
            } else {
                // The extra term carries no parameters; do not keep a stale gradient in this slot.
                Arrays.fill(target, 0.0);
                this.bwdMatrix[layer][context] = extra;
            }
        } else {
            this.bwdMatrix[layer][context] = extra != Double.NEGATIVE_INFINITY
                    ? Normalisation.logSumNormalisation(this.backwardIntermediate, 0, anz, new double[]{extra}, this.backwardIntermediate, 0)
                    : Normalisation.logSumNormalisation(this.backwardIntermediate, 0, anz, this.backwardIntermediate, 0);
            Arrays.fill(target, 0.0);
            for (int i = 0; i < anz; i++) {
                double weight = this.backwardIntermediate[i];
                double[] childGradient = this.gradient[(layer + this.index[2][i]) % 2][this.index[1][i]];
                for (int p = 0; p < this.numberOfParameters; p++) {
                    target[p] += weight * childGradient[p];
                }
                this.miniMerge(i, weight, h, context);
            }
        }
    }

    /**
     * Adds {@code weight} times the transition derivatives of child slot {@code i} and the emission
     * derivatives of that child's state to {@code gradient[h][context]}.
     */
    private void miniMerge(int i, double weight, int h, int context) {
        double[] target = this.gradient[h][context];
        IntList transitionIndices = this.indicesTransition[i];
        DoubleList transitionDerivatives = this.partDerTransition[i];
        for (int p = 0; p < transitionIndices.length(); p++) {
            target[transitionIndices.get(p)] += weight * transitionDerivatives.get(p);
        }
        int state = this.index[0][i];
        IntList stateIndices = this.indicesState[state];
        DoubleList stateDerivatives = this.partDerState[state];
        for (int p = 0; p < stateIndices.length(); p++) {
            target[stateIndices.get(p)] += weight * stateDerivatives.get(p);
        }
    }

    /** Sets every entry of the gradient for ({@code layer % 2}, {@code context}) to {@code val}. */
    private void resetGradient(int layer, int context, double val) {
        Arrays.fill(this.gradient[layer % 2][context], val);
    }

    /**
     * Returns the event-space size for parameter {@code index}. If the index falls into an emission's
     * parameter range (emissions are laid out consecutively from 0), the emission answers; otherwise
     * the transition does.
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        int off = 0;
        for (int i = 0; i < this.emission.length; i++) {
            DifferentiableEmission em = (DifferentiableEmission)this.emission[i];
            int num = em.getNumberOfParameters();
            if (num > 0 && index >= off && index < off + num) {
                return em.getSizeOfEventSpace();
            }
            off += num;
        }
        return ((DifferentiableTransition)this.transition).getSizeOfEventSpace(index);
    }

    /** Collects the sampling groups of all emissions and then of the transition. */
    @Override
    public int[][] getSamplingGroups(int parameterOffset) {
        LinkedList<int[]> list = new LinkedList<int[]>();
        for (int i = 0; i < this.emission.length; i++) {
            ((DifferentiableEmission)this.emission[i]).fillSamplingGroups(parameterOffset, list);
        }
        ((DifferentiableTransition)this.transition).fillSamplingGroups(parameterOffset, list);
        return list.toArray(new int[0][]);
    }

    @Override
    public String getInstanceName() {
        return "differentiable HMM(" + this.transition.getMaximalMarkovOrder() + ", " + this.score + ")";
    }
}

