/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements;

import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.BasicTransitionElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;

/**
 * Transition element that keeps, next to the inherited {@code parameters}, a
 * second array {@link #probs} of the same length and an {@link #offset} that
 * is added to local indices whenever gradient entries are addressed.
 *
 * <p>NOTE(review): {@code probs} is filled by
 * {@code Normalisation.logSumNormalisation}; presumably it holds the
 * normalised transition probabilities - confirm against that helper.
 */
public class TransitionElement extends BasicTransitionElement {
    /** Per-parameter values written by {@link #precompute()}; all zero when {@code norm} is false. */
    protected double[] probs;

    /** Index offset set via {@link #setParameterOffset(int)}. */
    protected int offset;

    /**
     * Delegates to the four-argument constructor with a {@code null} weight.
     *
     * @param context forwarded unchanged to the superclass
     * @param states forwarded unchanged to the superclass
     * @param hyperParameters forwarded unchanged to the superclass
     */
    public TransitionElement(int[] context, int[] states, double[] hyperParameters) {
        this(context, states, hyperParameters, null);
    }

    /**
     * Delegates to the five-argument constructor with {@code norm == true}.
     *
     * @param context forwarded unchanged to the superclass
     * @param states forwarded unchanged to the superclass
     * @param hyperParameters forwarded unchanged to the superclass
     * @param weight forwarded unchanged to the superclass
     */
    public TransitionElement(int[] context, int[] states, double[] hyperParameters, double[] weight) {
        this(context, states, hyperParameters, weight, true);
    }

    /**
     * Passes all arguments straight to the superclass constructor.
     *
     * @param context forwarded unchanged to the superclass
     * @param states forwarded unchanged to the superclass
     * @param hyperParameters forwarded unchanged to the superclass
     * @param weight forwarded unchanged to the superclass
     * @param norm forwarded unchanged to the superclass
     */
    public TransitionElement(int[] context, int[] states, double[] hyperParameters, double[] weight, boolean norm) {
        super(context, states, hyperParameters, weight, norm);
    }

    /**
     * Restores an instance from its XML representation via the superclass.
     *
     * @param xml the XML buffer handed to the superclass constructor
     * @throws NonParsableException if the superclass constructor throws it
     */
    public TransitionElement(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Clones this element and gives the copy its own {@link #probs} array.
     *
     * @return the cloned element
     * @throws CloneNotSupportedException if {@code super.clone()} throws it
     */
    @Override
    public TransitionElement clone() throws CloneNotSupportedException {
        TransitionElement copy = (TransitionElement) super.clone();
        copy.probs = probs.clone();
        return copy;
    }

    /**
     * Allocates {@link #probs} with one slot per parameter, then runs the
     * superclass initialisation.
     */
    @Override
    protected void init() {
        // probs must exist before super.init() runs (order kept from the original).
        probs = new double[parameters.length];
        super.init();
    }

    /**
     * Refreshes {@code logNorm} and {@link #probs} from the current parameters.
     * Without normalisation, {@code logNorm} is reset to zero and
     * {@code probs} is zero-filled.
     */
    @Override
    protected void precompute() {
        if (!norm) {
            logNorm = 0.0;
            Arrays.fill(probs, 0.0);
            return;
        }
        logNorm = Normalisation.logSumNormalisation(parameters, 0, parameters.length, probs, 0);
    }

    /**
     * Returns {@code parameters[childIdx] - logNorm} and, when there is more
     * than one parameter, appends the matching (index, value) pairs to the
     * given lists.
     *
     * @param childIdx index of the child whose score is requested
     * @param indices receives the offset-shifted parameter indices
     * @param partialDer receives the values paired with {@code indices}
     * @param sequence not read by this implementation
     * @param sequencePosition not read by this implementation
     * @return {@code parameters[childIdx] - logNorm}
     */
    public double getLogScoreAndPartialDerivation(int childIdx, IntList indices, DoubleList partialDer, Sequence sequence, int sequencePosition) {
        final int count = parameters.length;
        if (count > 1) {
            if (norm) {
                // One entry per parameter carrying the negated probs value.
                for (int k = 0; k < count; k++) {
                    indices.add(offset + k);
                    partialDer.add(-probs[k]);
                }
            }
            // The selected child additionally contributes 1.0.
            indices.add(offset + childIdx);
            partialDer.add(1.0);
        }
        return parameters[childIdx] - logNorm;
    }

    /**
     * Stores the given offset and reports the offset following this element.
     *
     * @param o the offset to store
     * @return {@code o + parameters.length} if there is more than one
     *         parameter, otherwise {@code o}
     */
    public int setParameterOffset(int o) {
        offset = o;
        return parameters.length > 1 ? offset + parameters.length : offset;
    }

    /**
     * Copies the parameters into {@code params} beginning at {@code offset};
     * does nothing when there is at most one parameter.
     *
     * @param params destination array
     * @param offset first destination index (shadows the field of the same name)
     * @return the index after the last written entry, or {@code offset} if
     *         nothing was written
     */
    public int fillParameters(double[] params, int offset) {
        final int count = parameters.length;
        if (count <= 1) {
            return offset;
        }
        for (int k = 0; k < count; k++) {
            params[offset + k] = parameters[k];
        }
        return offset + count;
    }

    /**
     * Reads the parameters from {@code params} beginning at {@code start} and
     * re-runs {@link #precompute()}; does nothing when there is at most one
     * parameter.
     *
     * @param params source array
     * @param start first source index
     * @return the index after the last read entry, or {@code start} if
     *         nothing was read
     */
    public int setParameters(double[] params, int start) {
        final int count = parameters.length;
        if (count <= 1) {
            return start;
        }
        for (int k = 0; k < count; k++) {
            parameters[k] = params[start + k];
        }
        precompute();
        return start + parameters.length;
    }

    /**
     * Adds {@code hyperParameters[i] - sum(hyperParameters) * probs[i]} to
     * {@code gradient[start + offset + i]} for every hyper-parameter; does
     * nothing when there is at most one hyper-parameter.
     *
     * @param gradient the gradient array to update in place
     * @param start base index added to {@link #offset}
     */
    public void addGradientForLogPriorTerm(double[] gradient, int start) {
        final int count = hyperParameters.length;
        if (count <= 1) {
            return;
        }
        double total = 0.0;
        for (double hyper : hyperParameters) {
            total += hyper;
        }
        final int base = start + offset;
        for (int k = 0; k < count; k++) {
            gradient[base + k] += hyperParameters[k] - total * probs[k];
        }
    }

    /**
     * Returns the smallest hyper-parameter as computed by {@code ToolBox.min}.
     *
     * @return the result of {@code ToolBox.min(hyperParameters)}
     */
    public double getMinimalHyperparameter() {
        return ToolBox.min(hyperParameters);
    }

    /**
     * Sums {@code statistic[i] * (parameters[i] - logNorm)} over all parameters.
     *
     * @return the accumulated sum
     */
    public double getLogPosteriorFromStatistic() {
        double acc = 0.0;
        for (int k = 0; k < parameters.length; k++) {
            acc += statistic[k] * (parameters[k] - logNorm);
        }
        return acc;
    }
}

