/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models;

import de.jstacs.data.WrongLengthException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.models.DifferentiableHigherOrderHMM;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.SimpleDifferentiableState;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.emissions.DifferentiableEmission;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.states.filter.Filter;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.training.MaxHMMTrainingParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.TransitionElement;
import javax.naming.OperationNotSupportedException;

/**
 * A {@link DifferentiableHigherOrderHMM} whose log-emission buffer is filled per
 * emission (one entry per element of {@link #dEmission}) instead of per state.
 * <p>
 * Every state is a {@link RefSimpleDifferentiableState} that remembers the index of
 * its emission in {@link #dEmission}. {@link #getIndex(int)} exposes that index.
 * Several entries of {@code emissionIdx} may refer to the same emission.
 * <p>
 * NOTE(review): emissions are always evaluated with the strand flag set to
 * {@code true}, although each state is constructed with its own {@code forward[i]}
 * value. Confirm that reverse-strand states are not expected with this class.
 */
public class FastDifferentiableHigherOrderHMM extends DifferentiableHigherOrderHMM {

    // NOTE(review): both fields are assigned in createStates(). If the superclass
    // constructor calls createStates(), a field initializer here would run after
    // super(...) returns and overwrite those values, so keep them uninitialized.

    /** The states, typed so each state's emission index is reachable; aliased by {@code states}. */
    RefSimpleDifferentiableState[] dStates;

    /** The elements of {@code emission}, each cast to {@link DifferentiableEmission}. */
    DifferentiableEmission[] dEmission;

    /**
     * Creates an HMM without a type, state groups, filters, or transition index.
     * Each of these is passed as {@code null} to the full constructor.
     *
     * @param trainingParameterSet the training parameters, forwarded to the superclass
     * @param name the state names
     * @param emissionIdx for each state, the index of its emission in {@code emission}
     * @param emission the emissions
     * @param ess forwarded to the superclass (presumably an equivalent sample size)
     * @param te the transition elements
     * @throws Exception if the superclass constructor fails
     */
    public FastDifferentiableHigherOrderHMM(MaxHMMTrainingParameterSet trainingParameterSet, String[] name,
            int[] emissionIdx, DifferentiableEmission[] emission, double ess, TransitionElement... te)
            throws Exception {
        this((String) null, (int[][]) null, trainingParameterSet, name, (Filter[]) null, emissionIdx, emission,
                ess, (int[]) null, te);
    }

    /**
     * Creates an HMM by delegating to the superclass constructor.
     * <p>
     * The superclass argument between {@code emissionIdx} and {@code emission} is
     * passed as {@code null} (presumably the per-state strand flags; TODO confirm).
     *
     * @param type forwarded to the superclass
     * @param statesGroups forwarded to the superclass
     * @param trainingParameterSet the training parameters, forwarded to the superclass
     * @param name the state names
     * @param filter forwarded to the superclass
     * @param emissionIdx for each state, the index of its emission in {@code emission}
     * @param emission the emissions
     * @param ess forwarded to the superclass (presumably an equivalent sample size)
     * @param transIndex forwarded to the superclass
     * @param te the transition elements
     * @throws Exception if the superclass constructor fails
     */
    public FastDifferentiableHigherOrderHMM(String type, int[][] statesGroups,
            MaxHMMTrainingParameterSet trainingParameterSet, String[] name, Filter[] filter, int[] emissionIdx,
            DifferentiableEmission[] emission, double ess, int[] transIndex, TransitionElement... te)
            throws Exception {
        super(type, statesGroups, trainingParameterSet, name, filter, emissionIdx, null, emission, ess, transIndex,
                te);
    }

    /**
     * Recreates an instance from its XML representation; parsing is done by the superclass.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the superclass cannot parse {@code xml}
     */
    public FastDifferentiableHigherOrderHMM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Builds {@link #dStates} and {@link #dEmission}, and makes {@code states} refer
     * to the same array as {@link #dStates}.
     * <p>
     * State {@code s} uses emission {@code dEmission[emissionIdx[s]]}, name
     * {@code name[s]}, and strand flag {@code forward[s]}.
     */
    @Override
    protected void createStates() {
        final int stateCount = this.emissionIdx.length;
        this.dStates = new RefSimpleDifferentiableState[stateCount];
        // The generic and the typed state arrays are the same object.
        this.states = this.dStates;

        final int emissionCount = this.emission.length;
        this.dEmission = new DifferentiableEmission[emissionCount];
        for (int k = 0; k < emissionCount; k++) {
            this.dEmission[k] = (DifferentiableEmission) this.emission[k];
        }

        for (int s = 0; s < stateCount; s++) {
            final int emissionIndex = this.emissionIdx[s];
            this.dStates[s] = new RefSimpleDifferentiableState(emissionIndex, this.dEmission[emissionIndex],
                    this.name[s], this.forward[s]);
        }
    }

    /**
     * For every emission {@code k}, stores in {@code logEmission[k]} the value of
     * {@code dEmission[k].getLogProbFor(true, startPos, startPos, seq)}.
     *
     * @param startPos the position of {@code seq} to evaluate (used as both start and end)
     * @param seq the sequence
     * @throws OperationNotSupportedException if thrown by an emission
     * @throws WrongLengthException if thrown by an emission
     */
    @Override
    protected void fillLogEmission(int startPos, Sequence seq)
            throws OperationNotSupportedException, WrongLengthException {
        for (int k = 0; k < this.dEmission.length; k++) {
            this.logEmission[k] = this.dEmission[k].getLogProbFor(true, startPos, startPos, seq);
        }
    }

    /**
     * For every emission {@code k}, stores its log probability at {@code endPos} in
     * {@code logEmission[k]}. The strand flag is always {@code true}.
     * <p>
     * If {@code grad} is set, {@code indicesState[k]} and {@code partDerState[k]}
     * are cleared first. They are then passed to
     * {@link DifferentiableEmission#getLogProbAndPartialDerivationFor}, which fills
     * in the partial derivatives. Otherwise only {@code getLogProbFor} is called.
     *
     * @param endPos the position of {@code seq} to evaluate (used as both start and end)
     * @param seq the sequence
     * @param grad whether partial derivatives are collected as well
     * @throws OperationNotSupportedException if thrown by an emission
     * @throws WrongLengthException if thrown by an emission
     */
    @Override
    protected void fillLogEmissionAndPartialDer(int endPos, Sequence seq, boolean grad)
            throws OperationNotSupportedException, WrongLengthException {
        // Named flag instead of a local called "forward", which would shadow the inherited field.
        final boolean onForwardStrand = true;
        for (int k = 0; k < this.dEmission.length; k++) {
            if (!grad) {
                this.logEmission[k] = this.dEmission[k].getLogProbFor(onForwardStrand, endPos, endPos, seq);
                continue;
            }
            this.indicesState[k].clear();
            this.partDerState[k].clear();
            this.logEmission[k] = this.dEmission[k].getLogProbAndPartialDerivationFor(onForwardStrand, endPos,
                    endPos, this.indicesState[k], this.partDerState[k], seq);
        }
    }

    /**
     * Returns the index, in {@link #dEmission} and in {@code logEmission}, of the
     * emission used by state {@code i}.
     *
     * @param i the state index
     * @return the emission index recorded for that state
     */
    @Override
    protected int getIndex(int i) {
        final RefSimpleDifferentiableState state = this.dStates[i];
        return state.idx;
    }

    /**
     * A {@link SimpleDifferentiableState} that also records the index of its
     * emission in the enclosing HMM's {@link #dEmission} array.
     */
    private static class RefSimpleDifferentiableState extends SimpleDifferentiableState {

        /** Index of this state's emission in the enclosing HMM's emission array. */
        final int idx;

        /**
         * @param idx the emission index to record
         * @param e the emission, passed to the superclass
         * @param name the state name, passed to the superclass
         * @param forward the strand flag, passed to the superclass
         */
        public RefSimpleDifferentiableState(int idx, DifferentiableEmission e, String name, boolean forward) {
            super(e, name, forward);
            this.idx = idx;
        }
    }
}

