/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.models.hmm.models;

import de.jstacs.NonParsableException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.XMLParser;
import de.jstacs.models.NormalizableScoringFunctionModel;
import de.jstacs.models.hmm.HMMTrainingParameterSet;
import de.jstacs.models.hmm.models.HigherOrderHMM;
import de.jstacs.models.hmm.states.DifferentiableState;
import de.jstacs.models.hmm.states.SimpleDifferentiableState;
import de.jstacs.models.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.models.hmm.training.MaxHMMTrainingParameterSet;
import de.jstacs.models.hmm.training.NumericalHMMTrainingParameterSet;
import de.jstacs.models.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.models.hmm.transitions.DifferentiableTransition;
import de.jstacs.models.hmm.transitions.elements.TransitionElement;
import de.jstacs.scoringFunctions.SamplingScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.LinkedList;

public class DifferentiableHigherOrderHMM
extends HigherOrderHMM
implements SamplingScoringFunction {
    /**
     * Total number of parameters as computed by {@code getOffsets()}. The value {@code -1} makes
     * {@code getOffsets()} stop early and {@code getCurrentParameterValues()} throw.
     */
    protected int numberOfParameters;
    /** The ess value; non-negative when supplied via the constructor, returned by {@link #getEss()}. */
    protected double ess;
    /** Selects whether scores are computed as {@code LIKELIHOOD} or {@code VITERBI}. */
    protected HigherOrderHMM.Type score;
    /**
     * Scratch buffer allocated as {@code int[3][anz]}: rows 0, 1 and 2 hold {@code container[0]},
     * {@code container[1]} and {@code container[2]} of each retained child transition.
     */
    protected int[][] index;
    /**
     * Scratch buffer allocated as {@code double[2][anz][numberOfParameters]}; the first dimension is
     * addressed with {@code layer % 2}, the second with the context index.
     */
    protected double[][][] gradient;
    /** Per-state lists of parameter indices filled while computing emission derivatives. */
    protected IntList[] indicesState;
    /** Per-child lists of parameter indices filled while computing transition derivatives. */
    protected IntList[] indicesTransition;
    /** Per-state partial derivatives, parallel to {@link #indicesState}. */
    protected DoubleList[] partDerState;
    /** Per-child partial derivatives, parallel to {@link #indicesTransition}. */
    protected DoubleList[] partDerTransition;

    /**
     * Creates a new differentiable higher-order HMM from explicit components.
     *
     * @param trainingParameterSet the training parameters, handed to the super class
     * @param name the state names; state {@code i} is created with {@code name[i]}
     * @param emissionIdx for each state the index of the emission it uses
     * @param forward for each state the flag passed to the created state
     * @param emission the emissions referenced by {@code emissionIdx}
     * @param likelihood {@code true} selects {@code Type.LIKELIHOOD}, {@code false} selects {@code Type.VITERBI}
     * @param ess the ess value; has to be a non-negative number
     * @param te the transition elements, handed to the super class
     * @throws IllegalArgumentException if {@code ess} is negative or NaN
     * @throws Exception if the super class constructor fails
     */
    public DifferentiableHigherOrderHMM(MaxHMMTrainingParameterSet trainingParameterSet, String[] name, int[] emissionIdx, boolean[] forward, DifferentiableEmission[] emission, boolean likelihood, double ess, TransitionElement ... te) throws Exception {
        super((HMMTrainingParameterSet)trainingParameterSet, name, emissionIdx, forward, emission, (BasicHigherOrderTransition.AbstractTransitionElement[])te);
        this.getOffsets();
        this.score = likelihood ? HigherOrderHMM.Type.LIKELIHOOD : HigherOrderHMM.Type.VITERBI;
        // The negated comparison also rejects NaN, which "ess < 0.0" would silently accept.
        if (!(ess >= 0.0)) {
            throw new IllegalArgumentException("The ess has to be non-negative, but was " + ess + ".");
        }
        this.ess = ess;
    }

    /**
     * Appends the information of the super class followed by the {@code ess} and {@code score}
     * values of this instance to the given XML buffer.
     *
     * @param xml the buffer that receives the XML representation
     */
    @Override
    protected void appendFurtherInformation(StringBuffer xml) {
        // The super class writes its part first; the tags of this class follow.
        super.appendFurtherInformation(xml);
        final double essValue = ess;
        final Object scoreType = score;
        XMLParser.appendObjectWithTags(xml, essValue, "ess");
        XMLParser.appendObjectWithTags(xml, scoreType, "score");
    }

    /**
     * Restores the information of the super class and then the {@code ess} and {@code score} values.
     * If either tag cannot be parsed, both fields fall back to defaults
     * ({@code ess = 16.0}, {@code score = LIKELIHOOD}) instead of failing.
     *
     * @param xml the buffer holding the XML representation
     * @throws NonParsableException if the super class cannot parse its part
     */
    @Override
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        super.extractFurtherInformation(xml);
        try {
            ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
            score = XMLParser.extractObjectForTags(xml, "score", HigherOrderHMM.Type.class);
        }
        catch (NonParsableException ignored) {
            // Deliberate best effort: a representation without these tags gets the defaults.
            score = HigherOrderHMM.Type.LIKELIHOOD;
            ess = 16.0;
        }
    }

    /**
     * Allocates the scratch buffers of this class (only while {@code container} is still
     * {@code null}) and afterwards delegates to the super class.
     * <p>
     * The gradient and index buffers are re-created when they are missing or their dimensions do
     * not match; the per-state and per-child lists are created only when {@code indicesState} is
     * {@code null}.
     */
    @Override
    protected void createHelperVariables() {
        if (container == null) {
            // Largest number of indexes over all layers 0..maximal Markov order.
            final int order = transition.getMaximalMarkovOrder();
            int width = 0;
            for (int layer = 0; layer <= order; layer++) {
                final int candidates = transition.getNumberOfIndexes(layer);
                if (candidates > width) {
                    width = candidates;
                }
            }

            final boolean needsBuffers = gradient == null
                    || gradient[0].length != width
                    || gradient[0][0].length != numberOfParameters;
            if (needsBuffers) {
                gradient = new double[2][width][numberOfParameters];
                index = new int[3][width];
            }

            if (indicesState == null) {
                final int maxChildren = transition.getMaximalNumberOfChildren();
                try {
                    indicesState = (IntList[])ArrayHandler.createArrayOf((Cloneable)new IntList(), (int)states.length);
                    partDerState = (DoubleList[])ArrayHandler.createArrayOf((Cloneable)new DoubleList(), (int)states.length);
                    indicesTransition = (IntList[])ArrayHandler.createArrayOf((Cloneable)new IntList(), (int)maxChildren);
                    partDerTransition = (DoubleList[])ArrayHandler.createArrayOf((Cloneable)new DoubleList(), (int)maxChildren);
                }
                catch (CloneNotSupportedException cloneFailure) {
                    throw DifferentiableHigherOrderHMM.getRunTimeException(cloneFailure);
                }
            }
        }
        super.createHelperVariables();
    }

    /**
     * Restores an instance from its XML representation and recomputes the parameter offsets.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public DifferentiableHigherOrderHMM(StringBuffer xml) throws NonParsableException {
        super(xml);
        getOffsets();
    }

    /**
     * Creates one {@link SimpleDifferentiableState} per entry of {@code emissionIdx}; state
     * {@code s} uses {@code emission[emissionIdx[s]]}, {@code name[s]} and {@code forward[s]}.
     */
    @Override
    protected void createStates() {
        states = new SimpleDifferentiableState[emissionIdx.length];
        int s = 0;
        while (s < emissionIdx.length) {
            final DifferentiableEmission stateEmission = (DifferentiableEmission)emission[emissionIdx[s]];
            states[s] = new SimpleDifferentiableState(stateEmission, name[s], forward[s]);
            s++;
        }
    }

    /**
     * Creates a copy of this instance.
     * <p>
     * The {@code gradient} and {@code indicesState} scratch buffers are detached from this instance
     * while the super class performs the copy (presumably so the copy does not share or duplicate
     * them and allocates its own via {@code createHelperVariables()} — verify against the super
     * class). They are re-attached in a {@code finally} block, so this instance stays usable even
     * if cloning fails.
     *
     * @return the copy
     * @throws CloneNotSupportedException if the super class cannot be cloned
     */
    @Override
    public DifferentiableHigherOrderHMM clone() throws CloneNotSupportedException {
        final double[][][] grad = this.gradient;
        final IntList[] ind = this.indicesState;
        this.gradient = null;
        this.indicesState = null;
        try {
            return (DifferentiableHigherOrderHMM)super.clone();
        }
        finally {
            // Always restore the buffers; previously a failing super.clone() left them null.
            this.gradient = grad;
            this.indicesState = ind;
        }
    }

    /**
     * Returns the ess value of this instance.
     *
     * @return the value of the {@code ess} field
     */
    @Override
    public double getEss() {
        return ess;
    }

    /**
     * Lets every emission and finally the transition add its gradient of the log prior term.
     *
     * @param grad the gradient array that is updated
     * @param start the start offset within {@code grad}, passed on unchanged
     * @throws Exception if one of the components fails
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        int k = 0;
        while (k < emission.length) {
            final DifferentiableEmission current = (DifferentiableEmission)emission[k];
            current.addGradientOfLogPriorTerm(grad, start);
            k++;
        }
        final DifferentiableTransition diffTransition = (DifferentiableTransition)transition;
        diffTransition.addGradientForLogPriorTerm(grad, start);
    }

    /**
     * Assigns parameter offsets to all emissions and then to the transition.
     * <p>
     * Each {@code setParameterOffset} call receives the running value of
     * {@code numberOfParameters} and its result becomes the new running value. As soon as a call
     * yields {@code -1} the method stops without creating the helper variables.
     */
    private void getOffsets() {
        numberOfParameters = 0;
        int k = 0;
        while (k < emission.length && numberOfParameters != -1) {
            numberOfParameters = ((DifferentiableEmission)emission[k]).setParameterOffset(numberOfParameters);
            k++;
        }
        if (numberOfParameters == -1) {
            return;
        }
        numberOfParameters = ((DifferentiableTransition)transition).setParameterOffset(numberOfParameters);
        if (numberOfParameters != -1) {
            createHelperVariables();
        }
    }

    /**
     * Returns the number of parameters determined by the last offset computation.
     *
     * @return the value of the {@code numberOfParameters} field
     */
    @Override
    public int getNumberOfParameters() {
        return numberOfParameters;
    }

    /**
     * Returns the number of starts configured in the training parameter set.
     *
     * @return the result of {@code trainingParameter.getNumberOfStarts()}
     */
    @Override
    public int getNumberOfRecommendedStarts() {
        final int starts = trainingParameter.getNumberOfStarts();
        return starts;
    }

    /**
     * Collects the current parameter values of all emissions and of the transition.
     *
     * @return a new array of length {@link #getNumberOfParameters()} filled by the emissions and
     *         the transition
     * @throws IllegalArgumentException if the number of parameters is {@code -1}
     * @throws Exception if one of the components fails
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        final int n = this.getNumberOfParameters();
        if (n == -1) {
            throw new IllegalArgumentException("The number of parameters is unknown (-1); the current parameter values cannot be provided.");
        }
        final double[] params = new double[n];
        for (int e = 0; e < this.emission.length; ++e) {
            ((DifferentiableEmission)this.emission[e]).fillCurrentParameter(params);
        }
        ((DifferentiableTransition)this.transition).fillParameters(params);
        return params;
    }

    /**
     * Reports whether this function is initialized, which is the case exactly when
     * {@code isTrained()} returns {@code true}.
     *
     * @return the result of {@code isTrained()}
     */
    @Override
    public boolean isInitialized() {
        final boolean trained = isTrained();
        return trained;
    }

    /**
     * Passes the given parameter array to every emission and finally to the transition.
     *
     * @param params the parameter values
     * @param start the start offset within {@code params}, passed on unchanged
     */
    @Override
    public void setParameters(double[] params, int start) {
        int k = 0;
        while (k < emission.length) {
            final DifferentiableEmission current = (DifferentiableEmission)emission[k];
            current.setParameter(params, start);
            k++;
        }
        final DifferentiableTransition diffTransition = (DifferentiableTransition)transition;
        diffTransition.setParameters(params, start);
    }

    /**
     * Initializes the model randomly and recomputes the parameter offsets.
     *
     * @param freeParams not used by this implementation
     * @throws Exception if the random initialization fails
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        initializeRandomly();
        getOffsets();
    }

    /**
     * Initializes this function. With a {@link NumericalHMMTrainingParameterSet} the function is
     * initialized randomly; otherwise it is trained on {@code data[index]} (using
     * {@code weights[index]} when {@code weights} is not {@code null}) and the parameter offsets
     * are recomputed.
     *
     * @param index the index of the sample (and weight array) to use
     * @param freeParams passed on to the random initialization
     * @param data the samples
     * @param weights the weights per sample, or {@code null}
     * @throws Exception if initialization or training fails
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, Sample[] data, double[][] weights) throws Exception {
        final boolean numerical = trainingParameter instanceof NumericalHMMTrainingParameterSet;
        if (numerical) {
            initializeFunctionRandomly(freeParams);
            return;
        }
        // Select the sample before the weights to keep the original evaluation order.
        final Sample sample = data[index];
        final double[] sampleWeights = (weights != null) ? weights[index] : null;
        train(sample, sampleWeights);
        getOffsets();
    }

    /**
     * Trains this model.
     * <p>
     * With a {@link NumericalHMMTrainingParameterSet} this instance is wrapped in a
     * {@link NormalizableScoringFunctionModel} that is trained on the data; afterwards the
     * emissions and the transition of the trained function are adopted and the states are
     * re-created. With any other parameter set the training of the super class is used.
     *
     * @param data the training sample
     * @param weights the weights of the sequences
     * @throws Exception if training fails
     */
    @Override
    public void train(Sample data, double[] weights) throws Exception {
        if (!(trainingParameter instanceof NumericalHMMTrainingParameterSet)) {
            super.train(data, weights);
            return;
        }
        final NumericalHMMTrainingParameterSet numeric = (NumericalHMMTrainingParameterSet)trainingParameter;
        final NormalizableScoringFunctionModel wrapper = new NormalizableScoringFunctionModel(this, numeric.getNumberOfThreads(), numeric.getAlgorithm(), numeric.getTerminantionCondition(), numeric.getLineEps(), numeric.getStartDistance());
        wrapper.train(data, weights);
        final DifferentiableHigherOrderHMM trained = (DifferentiableHigherOrderHMM)wrapper.getFunction();
        // Adopt the trained components; the states have to be rebuilt on the new emissions.
        emission = trained.emission;
        createStates();
        transition = trained.transition;
    }

    /**
     * Reports this function as normalized.
     *
     * @return always {@code true}
     */
    @Override
    public boolean isNormalized() {
        return true;
    }

    /**
     * Returns the logarithm of the normalization constant.
     *
     * @return always {@code 0.0}
     */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    /**
     * Returns the logarithm of the partial normalization constant.
     *
     * @param parameterIndex the index of the parameter; not used by this implementation
     * @return always {@link Double#NEGATIVE_INFINITY}
     * @throws Exception declared by the overridden method; not thrown here
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Returns the initial class parameter for the given class probability.
     *
     * @param classProb the class probability
     * @return the natural logarithm of {@code classProb}
     */
    @Override
    public double getInitialClassParam(double classProb) {
        final double logProb = Math.log(classProb);
        return logProb;
    }

    /**
     * Computes the log score of the complete sequence, i.e. starting at position {@code 0}.
     *
     * @param seq the sequence
     * @return the log score
     */
    @Override
    public double getLogScore(Sequence seq) {
        return getLogScore(seq, 0);
    }

    /**
     * Computes the log score of the sequence from {@code start} to its last position by filling
     * the backward (or Viterbi) matrix according to the configured score type.
     *
     * @param seq the sequence
     * @param start the start position
     * @return the entry {@code bwdMatrix[0][0]} after the matrix has been filled
     * @throws RuntimeException wrapping any exception raised during the computation
     */
    @Override
    public double getLogScore(Sequence seq, int start) {
        try {
            final int lastPos = seq.getLength() - 1;
            fillBwdOrViterbiMatrix(score, start, lastPos, 0.0, seq);
            final double logScore = bwdMatrix[0][0];
            return logScore;
        }
        catch (Exception failure) {
            throw DifferentiableHigherOrderHMM.getRunTimeException(failure);
        }
    }

    /**
     * Computes the log score and its partial derivatives for the complete sequence, i.e. starting
     * at position {@code 0}.
     *
     * @param seq the sequence
     * @param indices receives the indices of the parameters with a non-zero partial derivative
     * @param partialDer receives the corresponding partial derivatives
     * @return the log score
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, IntList indices, DoubleList partialDer) {
        final int firstPos = 0;
        return getLogScoreAndPartialDerivation(seq, firstPos, indices, partialDer);
    }

    /**
     * Computes the log score of the sequence from {@code startPos} to its last position together
     * with the partial derivatives of that score with respect to the parameters.
     * <p>
     * The method runs a backward pass over the layers {@code l = length, ..., 0}. For each layer and
     * context it stores the score in {@code bwdMatrix[l][context]} and the gradient in
     * {@code gradient[l % 2][context]}; only two gradient layers are kept. Child transitions whose
     * intermediate score is {@link Double#NEGATIVE_INFINITY} are dropped. How the remaining
     * children are combined (maximum or log-sum) is decided in {@code merge} by the score type.
     *
     * @param seq the sequence
     * @param startPos the start position within {@code seq}
     * @param indices receives the indices of all parameters whose entry in
     *        {@code gradient[0][0]} is non-zero
     * @param partialDer receives the corresponding gradient entries
     * @return the entry {@code bwdMatrix[0][0]}
     * @throws RuntimeException wrapping any exception raised during the computation
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int startPos, IntList indices, DoubleList partialDer) {
        try {
            int children;
            int n;
            int context;
            int stateID;
            // For Markov order zero every context of the last layer is allowed to end the path.
            boolean zero = this.transition.getMaximalMarkovOrder() == 0;
            int endPos = seq.getLength() - 1;
            // Number of sequence positions covered; also the index of the last matrix layer.
            int l = endPos - startPos + 1;
            // NOTE(review): presumably requests the backward matrix for this length — confirm in HigherOrderHMM.
            this.provideMatrix(1, endPos - startPos + 1);
            // Reset both gradient layers.
            for (int idx2 = 0; idx2 < this.gradient[1].length; ++idx2) {
                Arrays.fill(this.gradient[0][idx2], 0.0);
                Arrays.fill(this.gradient[1][idx2], 0.0);
            }
            DifferentiableTransition diffTransition = (DifferentiableTransition)this.transition;
            // No emission derivatives are available for the last layer.
            for (stateID = 0; stateID < this.states.length; ++stateID) {
                this.indicesState[stateID].clear();
                this.partDerState[stateID].clear();
            }
            // Last layer: only transitions into silent states are followed.
            for (context = this.bwdMatrix[l].length - 1; context >= 0; --context) {
                n = this.transition.getNumberOfChildren(l, context);
                children = 0;
                // Score for ending in this context: 0 if allowed (order zero or final state), -infinity otherwise.
                double val = zero || this.finalState[this.transition.getLastContextState(l, context)] ? 0.0 : Double.NEGATIVE_INFINITY;
                for (stateID = 0; stateID < n; ++stateID) {
                    this.transition.fillTransitionInformation(l, context, stateID, this.container);
                    if (!this.states[this.container[0]].isSilent()) continue;
                    this.indicesTransition[children].clear();
                    this.partDerTransition[children].clear();
                    this.backwardIntermediate[children] = this.bwdMatrix[l][this.container[1]] + diffTransition.getLogScoreAndPartialDerivation(l, context, stateID, this.indicesTransition[children], this.partDerTransition[children], seq, endPos);
                    // Impossible children are skipped; their slot is reused by the next child.
                    if (this.backwardIntermediate[children] == Double.NEGATIVE_INFINITY) continue;
                    this.index[0][children] = this.container[0];
                    this.index[1][children] = this.container[1];
                    this.index[2][children] = this.container[2];
                    ++children;
                }
                if (children == 0) {
                    this.bwdMatrix[l][context] = val;
                    this.resetGradient(l, context, 0.0);
                    continue;
                }
                this.merge(children, l, context, val);
            }
            // Remaining layers, from the last sequence position back to the first.
            while (--l >= 0) {
                // Emission log scores and their derivatives for every state at position endPos.
                for (stateID = 0; stateID < this.states.length; ++stateID) {
                    this.indicesState[stateID].clear();
                    this.partDerState[stateID].clear();
                    this.logEmission[stateID] = ((DifferentiableState)this.states[stateID]).getLogScoreAndPartialDerivation(endPos, endPos, this.indicesState[stateID], this.partDerState[stateID], seq);
                }
                for (context = this.bwdMatrix[l].length - 1; context >= 0; --context) {
                    n = this.transition.getNumberOfChildren(l, context);
                    children = 0;
                    for (stateID = 0; stateID < n; ++stateID) {
                        this.indicesTransition[children].clear();
                        this.partDerTransition[children].clear();
                        this.transition.fillTransitionInformation(l, context, stateID, this.container);
                        // Successor score (layer offset container[2], context container[1]) + emission + transition.
                        this.backwardIntermediate[children] = this.bwdMatrix[l + this.container[2]][this.container[1]] + this.logEmission[this.container[0]] + diffTransition.getLogScoreAndPartialDerivation(l, context, stateID, this.indicesTransition[children], this.partDerTransition[children], seq, endPos);
                        if (this.backwardIntermediate[children] == Double.NEGATIVE_INFINITY) continue;
                        this.index[0][children] = this.container[0];
                        this.index[1][children] = this.container[1];
                        this.index[2][children] = this.container[2];
                        ++children;
                    }
                    if (children == 0) {
                        this.bwdMatrix[l][context] = Double.NEGATIVE_INFINITY;
                        this.resetGradient(l, context, 0.0);
                        continue;
                    }
                    this.merge(children, l, context, Double.NEGATIVE_INFINITY);
                }
                --endPos;
            }
            // Report only the non-zero entries of the final gradient (layer 0, context 0).
            for (int p = 0; p < this.numberOfParameters; ++p) {
                if (this.gradient[0][0][p] == 0.0) continue;
                indices.add(p);
                partialDer.add(this.gradient[0][0][p]);
            }
            return this.bwdMatrix[0][0];
        }
        catch (Exception e) {
            throw DifferentiableHigherOrderHMM.getRunTimeException(e);
        }
    }

    /**
     * Combines the first {@code anz} entries of {@code backwardIntermediate} into the score
     * {@code bwdMatrix[layer][context]} and the gradient {@code gradient[layer % 2][context]}.
     * <p>
     * For {@code VITERBI} the best child is chosen; if it beats {@code extra}, the gradient of its
     * successor cell is copied and the child's own derivatives are added with weight 1, otherwise
     * only the score is set to {@code extra} and the gradient is left untouched. For the other
     * score type the children (and {@code extra}, unless it is negative infinity) are combined via
     * {@code Normalisation.logSumNormalisation} and the gradient becomes the weighted sum over all
     * children.
     *
     * @param anz the number of valid children in the scratch buffers
     * @param layer the current layer
     * @param context the current context
     * @param extra an additional log score competing with the children
     */
    private void merge(int anz, int layer, int context, double extra) {
        // Slot of the two-layer gradient buffer that belongs to this layer.
        int h = layer % 2;
        if (this.score == HigherOrderHMM.Type.VITERBI) {
            // NOTE(review): assumes getMaxIndex(0, anz, a) returns the index of the maximum within [0, anz) — confirm.
            int idx = ToolBox.getMaxIndex(0, anz, this.backwardIntermediate);
            if (this.backwardIntermediate[idx] > extra) {
                System.arraycopy(this.gradient[(layer + this.index[2][idx]) % 2][this.index[1][idx]], 0, this.gradient[h][context], 0, this.numberOfParameters);
                this.miniMerge(idx, 1.0, h, context);
                this.bwdMatrix[layer][context] = this.backwardIntermediate[idx];
            } else {
                this.bwdMatrix[layer][context] = extra;
            }
        } else {
            // NOTE(review): backwardIntermediate is passed as source and destination; the loop below uses
            // its entries as weights, so they are presumably normalised in place — confirm in Normalisation.
            this.bwdMatrix[layer][context] = extra != Double.NEGATIVE_INFINITY ? Normalisation.logSumNormalisation(this.backwardIntermediate, 0, anz, new double[]{extra}, this.backwardIntermediate, 0) : Normalisation.logSumNormalisation(this.backwardIntermediate, 0, anz, this.backwardIntermediate, 0);
            Arrays.fill(this.gradient[h][context], 0.0);
            for (int i = 0; i < anz; ++i) {
                // Gradient slot of the successor layer of child i.
                int x = (layer + this.index[2][i]) % 2;
                for (int p = 0; p < this.numberOfParameters; ++p) {
                    double[] dArray = this.gradient[h][context];
                    int n = p;
                    dArray[n] = dArray[n] + this.backwardIntermediate[i] * this.gradient[x][this.index[1][i]][p];
                }
                this.miniMerge(i, this.backwardIntermediate[i], h, context);
            }
        }
    }

    /**
     * Adds the weighted partial derivatives of child {@code i} — first those of its transition,
     * then those of the state stored in {@code index[0][i]} — to {@code gradient[h][context]}.
     *
     * @param i the index of the child in the scratch buffers
     * @param weight the factor applied to every partial derivative
     * @param h the gradient slot (layer modulo 2)
     * @param context the context whose gradient is updated
     */
    private void miniMerge(int i, double weight, int h, int context) {
        final double[] target = gradient[h][context];

        final IntList transitionIdx = indicesTransition[i];
        final DoubleList transitionDer = partDerTransition[i];
        for (int k = 0; k < transitionIdx.length(); k++) {
            target[transitionIdx.get(k)] += weight * transitionDer.get(k);
        }

        final int state = index[0][i];
        final IntList stateIdx = indicesState[state];
        final DoubleList stateDer = partDerState[state];
        for (int k = 0; k < stateIdx.length(); k++) {
            target[stateIdx.get(k)] += weight * stateDer.get(k);
        }
    }

    /**
     * Sets every entry of the gradient belonging to the given layer and context to {@code val}.
     *
     * @param layer the layer; mapped to a gradient slot via {@code layer % 2}
     * @param context the context
     * @param val the value to store
     */
    private void resetGradient(int layer, int context, double val) {
        final double[] target = gradient[layer % 2][context];
        Arrays.fill(target, val);
    }

    /**
     * Returns the size of the event space for the parameter with the given index.
     * <p>
     * The emissions are scanned in order, each one covering as many consecutive indices as it has
     * parameters; the first emission whose range contains {@code index} answers. If no emission
     * matches, the transition is asked.
     *
     * @param index the parameter index
     * @return the size of the event space reported by the responsible emission or the transition
     */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        int lower = 0;
        for (int e = 0; e < emission.length; e++) {
            final int count = ((DifferentiableEmission)emission[e]).getNumberOfParameters();
            final int upper = lower + count;
            final boolean owned = count > 0 && lower <= index && index < upper;
            if (owned) {
                return ((DifferentiableEmission)emission[e]).getSizeOfEventSpace();
            }
            lower = upper;
        }
        final DifferentiableTransition diffTransition = (DifferentiableTransition)transition;
        return diffTransition.getSizeOfEventSpace(index);
    }

    /**
     * Collects the sampling groups of all emissions and of the transition.
     *
     * @param parameterOffset the offset passed on unchanged to every component
     * @return the collected groups in the order in which the components added them
     */
    @Override
    public int[][] getSamplingGroups(int parameterOffset) {
        LinkedList<int[]> list = new LinkedList<int[]>();
        for (int i = 0; i < this.emission.length; ++i) {
            ((DifferentiableEmission)this.emission[i]).fillSamplingGroups(parameterOffset, list);
        }
        ((DifferentiableTransition)this.transition).fillSamplingGroups(parameterOffset, list);
        // The element type is int[], so an empty int[][] is the correct typed seed; the former
        // cast to an undeclared type variable T did not compile.
        return list.toArray(new int[0][]);
    }

    /**
     * Returns a short name of this instance that contains the maximal Markov order of the
     * transition.
     *
     * @return the name in the form {@code differentiable HMM(<order>)}
     */
    @Override
    public String getInstanceName() {
        final int order = transition.getMaximalMarkovOrder();
        return "differentiable HMM(" + order + ")";
    }
}

