/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models;

import de.jstacs.NotTrainedException;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sampling.BurnInTest;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.HigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SamplingState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SimpleSamplingState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.SamplingEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.HMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.SamplingHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.BasicHigherOrderTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.SamplingTransition;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.TransitionWithSufficientStatistic;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import java.io.IOException;
import java.util.Arrays;

/**
 * A higher-order hidden Markov model whose parameters are obtained by Gibbs sampling.
 *
 * <p>Training alternates between two steps for every sampling run ("start"):
 * <ul>
 * <li>sampling a state path for each training sequence under the current parameters, which
 * fills the sufficient statistics;</li>
 * <li>drawing new parameters from those statistics.</li>
 * </ul>
 * The drawn parameters are handed to the {@link SamplingTransition} and {@link SamplingEmission}
 * objects via {@code acceptParameters()}. After training, they are re-read with
 * {@link #parseParameterSet(int, int)} / {@link #parseNextParameterSet()}.
 *
 * <p>All scoring and decoding methods iterate over the parameter sets stored from the end of the
 * burn-in phase onwards, as reported by the {@link BurnInTest}.
 */
public class SamplingHigherOrderHMM extends HigherOrderHMM {
    /** Decides how many sampled parameter sets belong to the burn-in phase. */
    protected BurnInTest burnInTest;
    /** {@code true} once {@link #train(DataSet, double[])} has finished successfully. */
    protected boolean hasSampled;
    /** Number of independent sampling runs; re-derived from the training parameter set. */
    private int numberOfStarts;
    /** Reusable buffer for the most recently sampled or decoded state path. */
    private IntList path;
    private static final String XML_TAG = "SamplingHigherOrderHMM";

    /**
     * Creates a new, untrained sampling HMM.
     *
     * @param trainingParameterSet parameters controlling the Gibbs sampling
     * @param name                 the names of the states
     * @param emissionIdx          for each state, the index of its emission in {@code emission}
     * @param forward              for each state, the strand flag passed to its state object
     * @param emission             the (sampling) emissions shared by the states
     * @param te                   the transition elements defining the transition structure
     * @throws Exception if the super class cannot build the model
     */
    public SamplingHigherOrderHMM(SamplingHMMTrainingParameterSet trainingParameterSet, String[] name, int[] emissionIdx, boolean[] forward, SamplingEmission[] emission, TransitionElement ... te) throws Exception {
        super((HMMTrainingParameterSet) trainingParameterSet, name, emissionIdx, forward, emission, (BasicHigherOrderTransition.AbstractTransitionElement[]) te);
        this.hasSampled = false;
        this.path = new IntList();
    }

    /**
     * Restores a model from its XML representation.
     *
     * @param xml the XML representation
     * @throws NonParsableException if {@code xml} cannot be parsed
     */
    public SamplingHigherOrderHMM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a deep copy. The path buffer and the burn-in test, if present, are cloned as well.
     */
    @Override
    public SamplingHigherOrderHMM clone() throws CloneNotSupportedException {
        SamplingHigherOrderHMM clone = (SamplingHigherOrderHMM) super.clone();
        clone.path = this.path.clone();
        if (this.burnInTest != null) {
            clone.burnInTest = this.burnInTest.clone();
        }
        return clone;
    }

    /** Creates one {@link SimpleSamplingState} per entry of {@code emissionIdx}. */
    @Override
    protected void createStates() {
        this.states = new SimpleSamplingState[this.emissionIdx.length];
        for (int i = 0; i < this.emissionIdx.length; i++) {
            this.states[i] = new SimpleSamplingState((SamplingEmission) this.emission[this.emissionIdx[i]], this.name[i], this.forward[i]);
        }
    }

    /**
     * Lets the transition and all emissions accept their current parameters.
     *
     * @throws IOException if a component fails to accept its parameters
     */
    protected void acceptParameters() throws IOException {
        ((SamplingTransition) this.transition).acceptParameters();
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).acceptParameters();
        }
    }

    /**
     * Draws new parameters for the transition and all emissions from their current statistics.
     *
     * @throws Exception if drawing fails in any component
     */
    protected void drawFromStatistics() throws Exception {
        ((SamplingTransition) this.transition).drawParametersFromStatistic();
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).drawParametersFromStatistic();
        }
    }

    /**
     * Samples a state path for {@code seq[startPos..endPos]} and adds it to the statistics.
     *
     * <p>The path is written into the internal path buffer via {@code samplePath}.
     * NOTE(review): this method does not clear the buffer itself; it relies on
     * {@code samplePath} doing so — confirm.
     *
     * @param startPos first position of the sequence region
     * @param endPos   last position (inclusive) of the sequence region
     * @param weight   weight with which the sampled path is added to the statistics
     * @param seq      the sequence
     * @return {@code bwdMatrix[0][0]} as filled by {@code samplePath}; presumably the log
     *         likelihood of the region under the current parameters — TODO confirm
     * @throws Exception if sampling or adding the statistics fails
     */
    protected double gibbsSampling(int startPos, int endPos, double weight, Sequence seq) throws Exception {
        this.samplePath(this.path, startPos, endPos, seq);
        this.addToStatistics(startPos, weight, seq, this.path);
        return this.bwdMatrix[0][0];
    }

    /**
     * Adds the emissions and transitions along state path {@code p}, starting at sequence
     * position {@code startPos}, to the sufficient statistics.
     *
     * <p>{@code container} is a scratch array shared with the transition. Index 1 holds the
     * current context, as written back by {@code fillTransitionInformation}. A value of 1 at
     * index 2 advances the sequence position and the layer; presumably this means the state
     * consumed a symbol — TODO confirm.
     *
     * @throws IllegalArgumentException if the path contains a transition the model does not allow
     */
    private void addToStatistics(int startPos, double weight, Sequence seq, IntList p) throws Exception {
        int pos = startPos;
        int layer = 0;
        this.container[1] = 0;
        for (int l = 0; l < p.length(); l++) {
            int state = p.get(l);
            int childIdx = this.transition.getChildIdx(layer, this.container[1], state);
            if (childIdx < 0) {
                throw new IllegalArgumentException("Impossible path: state " + state + " at path index " + l);
            }
            ((SamplingState) this.states[state]).addToStatistic(pos, pos, weight, seq);
            ((SamplingTransition) this.transition).addToStatistic(layer, this.container[1], childIdx, weight, seq, pos);
            this.transition.fillTransitionInformation(layer, this.container[1], childIdx, this.container);
            if (this.container[2] == 1) {
                ++pos;
                ++layer;
            }
        }
    }

    /** Sums the log gamma scores of the transition and all states for the current statistics. */
    private double getLogGammaScoreForCurrentStatistics() {
        double score = ((TransitionWithSufficientStatistic) this.transition).getLogGammaScoreFromStatistic();
        for (int state = 0; state < this.states.length; state++) {
            score += ((SamplingState) this.states[state]).getLogGammaScoreForCurrentStatistic();
        }
        return score;
    }

    /**
     * Performs {@code steps} Gibbs iterations for sampling run {@code sampling}.
     *
     * <p>Each iteration does the following:
     * <ol>
     * <li>reset the statistics;</li>
     * <li>sample a path for every sequence;</li>
     * <li>log the score (log prior term plus the per-sequence values returned by
     * {@link #gibbsSampling(int, int, double, Sequence)}) and report it to the burn-in test;</li>
     * <li>draw new parameters and accept them.</li>
     * </ol>
     *
     * @param sampling index of the sampling run
     * @param steps    number of iterations
     * @param append   passed to {@code extendSampling}: whether to continue an existing run
     * @param data     the training sequences
     * @param weights  per-sequence weights, or {@code null} for weight 1 everywhere
     * @throws Exception if any sampling step fails
     */
    protected void gibbsSamplingStep(int sampling, int steps, boolean append, DataSet data, double[] weights) throws Exception {
        int numberOfSequences = data.getNumberOfElements();
        this.burnInTest.setCurrentSamplingIndex(sampling);
        ((SamplingTransition) this.transition).extendSampling(sampling, append);
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).extendSampling(sampling, append);
        }
        this.sostream.writeln(String.valueOf(sampling) + " ----------------------------------------");
        for (int s = 0; s < steps; s++) {
            this.resetStatistics();
            double score = this.getLogPriorTerm();
            for (int n = 0; n < numberOfSequences; n++) {
                Sequence seq = data.getElementAt(n);
                double weight = weights == null ? 1.0 : weights[n];
                score += this.gibbsSampling(0, seq.getLength() - 1, weight, seq);
            }
            this.sostream.writeln(String.valueOf(s) + "\t" + score);
            this.burnInTest.setValue(score);
            this.getNewParameters();
            this.acceptParameters();
        }
    }

    /**
     * Obtains new parameters after an iteration. The default implementation draws them from the
     * current statistics; subclasses may override this.
     *
     * @throws Exception if drawing fails
     */
    protected void getNewParameters() throws Exception {
        this.drawFromStatistics();
    }

    /**
     * Trains the model by Gibbs sampling.
     *
     * <p>The burn-in phase runs {@code stepsPerIteration} iterations per sampling run. It repeats
     * until the number of iterations done per run exceeds the burn-in length reported by the
     * burn-in test. The stationary phase then continues every run for the configured number of
     * steps.
     *
     * @param data    the training sequences
     * @param weights per-sequence weights, or {@code null}
     * @throws IllegalArgumentException if the number of steps per iteration is not positive (the
     *                                  burn-in loop would never terminate)
     * @throws Exception                if sampling fails
     */
    @Override
    public void train(DataSet data, double[] weights) throws Exception {
        SamplingHMMTrainingParameterSet params = (SamplingHMMTrainingParameterSet) this.trainingParameter;
        int numberOfSteps = params.getNumberOfStepsInStationaryPhase();
        int steps = params.getNumberOfStepsPerIteration();
        if (steps <= 0) {
            throw new IllegalArgumentException("The number of steps per iteration must be positive, but was " + steps);
        }
        // Invalidate a previous result so a failed re-training is not reported as trained.
        this.hasSampled = false;
        this.initTraining(data, weights);

        this.sostream.writeln("GIBBS-SAMPLING - Burn-In ==============================");
        int samplingCounter = 0;
        do {
            for (int start = 0; start < this.numberOfStarts; start++) {
                if (samplingCounter == 0) {
                    this.initializeRandomly();
                }
                // Start a fresh run in the first round, extend it afterwards.
                this.gibbsSamplingStep(start, steps, samplingCounter > 0, data, weights);
            }
            samplingCounter += steps;
        } while (samplingCounter <= this.burnInTest.getLengthOfBurnIn());

        this.sostream.writeln("GIBBS-SAMPLING - Final Sampling =======================");
        for (int start = 0; start < this.numberOfStarts; start++) {
            // Continue the runs started during burn-in.
            this.gibbsSamplingStep(start, numberOfSteps, true, data, weights);
        }
        ((SamplingTransition) this.transition).samplingStopped();
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).samplingStopped();
        }
        this.hasSampled = true;
    }

    @Override
    public String getInstanceName() {
        return "Sampling HMM(" + this.transition.getMaximalMarkovOrder() + ")";
    }

    /** Returns whether training has completed. */
    @Override
    public boolean isInitialized() {
        return this.hasSampled;
    }

    /**
     * Returns the log of the average probability of {@code sequence[startpos..endpos]} over all
     * parameter sets sampled after the burn-in.
     *
     * @throws NotTrainedException   if the model has not been trained
     * @throws IllegalStateException if no parameter set after the burn-in is available
     */
    @Override
    protected double logProb(int startpos, int endpos, Sequence sequence) throws Exception {
        if (!this.hasSampled) {
            throw new NotTrainedException();
        }
        double logProb = Double.NEGATIVE_INFINITY;
        int numSamples = 0;
        for (int start = 0; start < this.numberOfStarts; start++) {
            for (boolean more = this.parseParameterSet(start, this.burnInTest.getLengthOfBurnIn()); more; more = this.parseNextParameterSet()) {
                logProb = Normalisation.getLogSum(logProb, super.logProb(startpos, endpos, sequence));
                ++numSamples;
            }
        }
        return logProb - logNumberOfSamples(numSamples);
    }

    /**
     * Loads parameter set {@code idx} of sampling run {@code sampling} into the transition and
     * all emissions.
     *
     * @return {@code true} only if every component could load the set
     * @throws Exception if loading fails
     */
    protected boolean parseParameterSet(int sampling, int idx) throws Exception {
        boolean parsed = ((SamplingTransition) this.transition).parseParameterSet(sampling, idx);
        for (int e = 0; e < this.emission.length; e++) {
            // Non-short-circuit AND: every emission must advance, even if an earlier one failed.
            parsed &= ((SamplingEmission) this.emission[e]).parseParameterSet(sampling, idx);
        }
        return parsed;
    }

    /**
     * Loads the next stored parameter set into the transition and all emissions.
     *
     * @return {@code true} only if every component could load a further set
     * @throws Exception if loading fails
     */
    protected boolean parseNextParameterSet() throws Exception {
        boolean parsed = ((SamplingTransition) this.transition).parseNextParameterSet();
        for (int e = 0; e < this.emission.length; e++) {
            parsed &= ((SamplingEmission) this.emission[e]).parseNextParameterSet();
        }
        return parsed;
    }

    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    /** Appends the burn-in test and the {@code hasSampled} flag to the XML representation. */
    @Override
    protected void appendFurtherInformation(StringBuffer xml) {
        super.appendFurtherInformation(xml);
        XMLParser.appendObjectWithTags(xml, this.burnInTest, "burnInTest");
        XMLParser.appendObjectWithTags(xml, this.hasSampled, "hasSampled");
    }

    /**
     * Restores the burn-in test and the {@code hasSampled} flag. The number of starts is
     * re-derived from the training parameters, and the path buffer is recreated.
     */
    @Override
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        super.extractFurtherInformation(xml);
        this.numberOfStarts = this.trainingParameter.getNumberOfStarts();
        this.burnInTest = XMLParser.extractObjectForTags(xml, "burnInTest", BurnInTest.class);
        this.hasSampled = XMLParser.extractObjectForTags(xml, "hasSampled", Boolean.TYPE);
        this.path = new IntList();
    }

    /**
     * Prepares training: reads the number of starts and the burn-in test from the training
     * parameters, resets the burn-in test, and initializes all components for sampling.
     *
     * @throws Exception if initialization fails
     */
    protected void initTraining(DataSet data, double[] weights) throws Exception {
        this.numberOfStarts = this.trainingParameter.getNumberOfStarts();
        this.burnInTest = ((SamplingHMMTrainingParameterSet) this.trainingParameter).getBurnInTest();
        this.burnInTest.resetAllValues();
        ((SamplingTransition) this.transition).initForSampling(this.numberOfStarts);
        for (int e = 0; e < this.emission.length; e++) {
            ((SamplingEmission) this.emission[e]).initForSampling(this.numberOfStarts);
        }
        this.furtherInits(data, weights);
    }

    /** Hook for subclasses, called at the end of {@link #initTraining}; does nothing here. */
    protected void furtherInits(DataSet data, double[] weights) throws Exception {
    }

    /**
     * Returns the log state posteriors averaged over all parameter sets sampled after the burn-in.
     *
     * @throws NotTrainedException   if the model has not been trained
     * @throws IllegalStateException if no parameter set after the burn-in is available
     */
    @Override
    public double[][] getLogStatePosteriorMatrixFor(int startPos, int endPos, Sequence seq) throws Exception {
        if (!this.hasSampled) {
            throw new NotTrainedException();
        }
        double[][] statePosterior = this.createMatrixForStatePosterior(startPos, endPos);
        double[][] tmp = this.createMatrixForStatePosterior(startPos, endPos);
        for (int s = 0; s < this.states.length; s++) {
            Arrays.fill(statePosterior[s], Double.NEGATIVE_INFINITY);
        }
        int numSamples = 0;
        for (int start = 0; start < this.numberOfStarts; start++) {
            for (boolean more = this.parseParameterSet(start, this.burnInTest.getLengthOfBurnIn()); more; more = this.parseNextParameterSet()) {
                this.fillLogStatePosteriorMatrix(tmp, startPos, endPos, seq, true);
                // Accumulate in log space: statePosterior = log(exp(statePosterior) + exp(tmp)).
                for (int s = 0; s < this.states.length; s++) {
                    for (int l = 0; l < statePosterior[s].length; l++) {
                        statePosterior[s][l] = Normalisation.getLogSum(statePosterior[s][l], tmp[s][l]);
                    }
                }
                ++numSamples;
            }
        }
        double logNumSamples = logNumberOfSamples(numSamples);
        for (int s = 0; s < this.states.length; s++) {
            for (int l = 0; l < statePosterior[s].length; l++) {
                statePosterior[s][l] -= logNumSamples;
            }
        }
        return this.getFinalStatePosterioriMatrix(statePosterior);
    }

    /**
     * Returns the log of the average probability of {@code path} over all parameter sets sampled
     * after the burn-in.
     *
     * @throws NotTrainedException   if the model has not been trained
     * @throws IllegalStateException if no parameter set after the burn-in is available
     */
    @Override
    public double getLogProbForPath(IntList path, int startPos, Sequence seq) throws Exception {
        if (!this.hasSampled) {
            throw new NotTrainedException();
        }
        int numSamples = 0;
        DoubleList logProbs = new DoubleList(5000);
        for (int start = 0; start < this.numberOfStarts; start++) {
            for (boolean more = this.parseParameterSet(start, this.burnInTest.getLengthOfBurnIn()); more; more = this.parseNextParameterSet()) {
                logProbs.add(super.getLogProbForPath(path, startPos, seq));
                ++numSamples;
            }
        }
        // Check before summing so an empty list never reaches getLogSum.
        double logNumSamples = logNumberOfSamples(numSamples);
        return Normalisation.getLogSum(logProbs.toArray()) - logNumSamples;
    }

    /** Equivalent to {@code getViterbiPath(startPos, endPos, seq, ViterbiComputation.MAX)}. */
    @Override
    public Pair<IntList, Double> getViterbiPathFor(int startPos, int endPos, Sequence seq) throws Exception {
        return this.getViterbiPath(startPos, endPos, seq, ViterbiComputation.MAX);
    }

    /**
     * Decodes {@code seq[startPos..endPos]} by trying candidate paths for every parameter set
     * sampled after the burn-in, and keeps the best-scoring one.
     *
     * <p>Candidate paths are Viterbi paths, sampled paths, or both, depending on
     * {@code compute}. Each candidate is scored either by its log probability or by the log
     * gamma score of the statistics it induces. Note that this method resets and overwrites the
     * model's statistics.
     *
     * @param compute the decoding strategy; must not be {@code null}
     * @return the best path and its score; an empty path with score {@code -Infinity} if no
     *         candidate was evaluated
     * @throws NotTrainedException  if the model has not been trained
     * @throws NullPointerException if {@code compute} is {@code null}
     */
    public Pair<IntList, Double> getViterbiPath(int startPos, int endPos, Sequence seq, ViterbiComputation compute) throws Exception {
        if (!this.hasSampled) {
            throw new NotTrainedException();
        }
        if (compute == null) {
            throw new NullPointerException("compute");
        }
        IntList bestPath = new IntList();
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int start = 0; start < this.numberOfStarts; start++) {
            for (boolean more = this.parseParameterSet(start, this.burnInTest.getLengthOfBurnIn()); more; more = this.parseNextParameterSet()) {
                if (compute.sample) {
                    this.resetStatistics();
                    this.path.clear();
                    this.gibbsSampling(startPos, endPos, 1.0, seq);
                    double score = compute.gammaScore
                            ? this.getLogGammaScoreForCurrentStatistics()
                            : super.getLogProbForPath(this.path, startPos, seq);
                    if (score > bestScore) {
                        bestScore = score;
                        copyPath(this.path, bestPath);
                    }
                }
                if (compute.maximize) {
                    this.resetStatistics();
                    this.path.clear();
                    double score = this.viterbi(this.path, startPos, endPos, 0.0, seq, null);
                    // Statistics are always filled so the gamma score (if requested) sees this path.
                    this.addToStatistics(startPos, 1.0, seq, this.path);
                    if (compute.gammaScore) {
                        score = this.getLogGammaScoreForCurrentStatistics();
                    }
                    if (score > bestScore) {
                        bestScore = score;
                        copyPath(this.path, bestPath);
                    }
                }
            }
        }
        return new Pair<IntList, Double>(bestPath, bestScore);
    }

    /** Replaces the contents of {@code target} with those of {@code source}. */
    private static void copyPath(IntList source, IntList target) {
        target.clear();
        for (int i = 0; i < source.length(); i++) {
            target.add(source.get(i));
        }
    }

    /**
     * Returns {@code log(numSamples)}, which turns a log-sum over parameter sets into a
     * log-average.
     *
     * @throws IllegalStateException if {@code numSamples} is not positive; the average would be
     *                               undefined ({@code -Infinity - -Infinity = NaN})
     */
    private static double logNumberOfSamples(int numSamples) {
        if (numSamples <= 0) {
            throw new IllegalStateException("No sampled parameter set is available after the burn-in phase");
        }
        return Math.log(numSamples);
    }

    /** Sums the log posteriors of the transition and all states for the current statistics. */
    protected double getLogPosteriorFromStatistic() {
        double logPosterior = ((SamplingTransition) this.transition).getLogPosteriorFromStatistic();
        for (int state = 0; state < this.states.length; state++) {
            logPosterior += ((SimpleSamplingState) this.states[state]).getLogPosteriorFromStatistic();
        }
        return logPosterior;
    }

    /**
     * Decoding strategies for
     * {@link SamplingHigherOrderHMM#getViterbiPath(int, int, Sequence, ViterbiComputation)}.
     */
    public enum ViterbiComputation {
        /** Viterbi path per parameter set, scored by the value returned from {@code viterbi}. */
        MAX(true, false, false),
        /** Viterbi path per parameter set, scored by the log gamma score of its statistics. */
        MAX_GAMMA(true, false, true),
        /** Sampled path per parameter set, scored by its log probability. */
        SAMPLING(false, true, false),
        /** Sampled path per parameter set, scored by the log gamma score of its statistics. */
        SAMPLING_GAMMA(false, true, true),
        /** Both a sampled and a Viterbi path per parameter set, scored by probability. */
        MAX_AND_SAMPLING(true, true, false),
        /** Both a sampled and a Viterbi path per parameter set, scored by the log gamma score. */
        MAX_AND_SAMPLING_GAMMA(true, true, true);

        /** Whether a Viterbi path is computed for each parameter set. */
        private final boolean maximize;
        /** Whether a path is sampled for each parameter set. */
        private final boolean sample;
        /** Whether candidates are scored by the log gamma score instead of their probability. */
        private final boolean gammaScore;

        private ViterbiComputation(boolean maximize, boolean sample, boolean gammaScore) {
            this.maximize = maximize;
            this.sample = sample;
            this.gammaScore = gammaScore;
        }
    }
}

