/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements;

import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.ReferenceBasedTransitionElement;
import java.text.NumberFormat;

/**
 * Transition element whose transition probabilities depend on a discrete "transition class"
 * that is looked up, for every sequence position, in a {@link ReferenceSequenceAnnotation}
 * (annotation type {@code "reference"}, identifier {@code annotationID}).
 *
 * <p>For transition class {@code tc} with scaling factor {@code s = scalingFactor[tc]}:
 * <ul>
 * <li>the probability of the diagonal transition (index {@code diagElement}) is
 * {@code (parameters[diagElement] - 1 + s) / s};</li>
 * <li>the probability of every other transition {@code j} is {@code parameters[j] / s}.</li>
 * </ul>
 * These sum to one for every class whenever the shared {@code parameters} sum to one, which
 * {@link #estimateFromStatistic()} ensures.
 */
public class ScaledTransitionElement
extends ReferenceBasedTransitionElement {
    /** Graphviz edge style per transition class; the line width grows with the class index. */
    protected String[] arrowOptions;
    /** One scaling factor per transition class; its length is the number of transition classes. */
    protected double[] scalingFactor;
    private static final String XML_TAG = "SCALED_TRANSITION_ELEMENT";

    /** Starting point of the Newton iteration for the diagonal element. */
    private static final double INITIAL_DIAGONAL_ELEMENT = 0.001;
    /** The Newton iteration stops once successive iterates differ by at most this amount. */
    private static final double NEWTON_TOLERANCE = 1.0E-10;
    /** Upper bound on Newton steps so that a non-converging iteration cannot run forever. */
    private static final int MAX_NEWTON_ITERATIONS = 1000;

    /**
     * Creates a new element.
     *
     * @param context passed to the superclass
     * @param states the target states of this element
     * @param probabilities the initial transition parameters (passed to the superclass)
     * @param ess passed to the superclass
     * @param scalingFactor one scaling factor per transition class; copied defensively
     * @param annotationID identifier of the reference-sequence annotation that provides the
     *        transition class of each position
     * @throws NullPointerException if {@code scalingFactor} is {@code null}
     */
    public ScaledTransitionElement(int[] context, int[] states, double[] probabilities, double ess, double[] scalingFactor, String annotationID) {
        super(context, states, ess, probabilities, annotationID);
        this.scalingFactor = scalingFactor.clone();
        // Re-run initialization now that scalingFactor is known.
        this.init();
    }

    /**
     * Allocates the per-class statistics and the per-class Graphviz arrow styles.
     *
     * <p>While {@code scalingFactor} is still {@code null}, only the superclass initialization is
     * performed. NOTE(review): presumably this happens when the superclass constructor invokes
     * this method before this class's fields are assigned — confirm.
     */
    @Override
    protected void init() {
        super.init();
        if (this.scalingFactor == null) {
            return;
        }
        int numberOfTransitionClasses = this.scalingFactor.length;
        this.statisticsTransitionProb = new double[numberOfTransitionClasses][this.states.length];
        this.arrowOptions = new String[numberOfTransitionClasses];
        for (int i = 0; i < numberOfTransitionClasses; i++) {
            this.arrowOptions[i] = "style=\"setlinewidth(" + (i + 1) + ")\"";
        }
    }

    /**
     * Restores an element from its XML representation. The scaling factors are read in
     * {@link #extractFurtherInformation(StringBuffer)} (NOTE(review): presumably invoked by the
     * superclass while parsing — confirm).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public ScaledTransitionElement(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a deep copy of this element. The arrow styles, the scaling factors and the
     * per-class statistics are copied, so the clone does not share mutable arrays with this
     * instance.
     */
    @Override
    public ScaledTransitionElement clone() throws CloneNotSupportedException {
        ScaledTransitionElement clone = (ScaledTransitionElement) super.clone();
        clone.arrowOptions = this.arrowOptions == null ? null : this.arrowOptions.clone();
        clone.scalingFactor = this.scalingFactor == null ? null : this.scalingFactor.clone();
        // Copied here in case super.clone() copies this array only shallowly.
        if (this.statisticsTransitionProb != null) {
            clone.statisticsTransitionProb = new double[this.statisticsTransitionProb.length][];
            for (int c = 0; c < this.statisticsTransitionProb.length; c++) {
                clone.statisticsTransitionProb[c] = this.statisticsTransitionProb[c].clone();
            }
        }
        return clone;
    }

    /**
     * Returns the transition class at a position.
     *
     * <p>The class is the discrete value at {@code pos} of the reference sequence stored in the
     * annotation of type {@code "reference"} with identifier {@code annotationID}.
     *
     * @param pos the sequence position
     * @param seq the annotated sequence
     * @return the transition class index at {@code pos}
     */
    protected int getIndex(int pos, Sequence seq) {
        return ((ReferenceSequenceAnnotation) seq.getSequenceAnnotationByTypeAndIdentifier("reference", this.annotationID)).getReferenceSequence().discreteVal(pos);
    }

    /** Adds {@code weight} to the statistic of transition {@code childIdx} in the class at {@code sequencePosition}. */
    @Override
    public void addToStatistic(int childIdx, double weight, Sequence sequence, int sequencePosition) {
        this.statisticsTransitionProb[this.getIndex(sequencePosition, sequence)][childIdx] += weight;
    }

    /** Zeroes the per-class statistics, then resets the superclass statistics. */
    @Override
    public void resetStatistic() {
        for (int c = 0; c < this.statisticsTransitionProb.length; c++) {
            for (int j = 0; j < this.states.length; j++) {
                this.statisticsTransitionProb[c][j] = 0.0;
            }
        }
        super.resetStatistic();
    }

    /**
     * Estimates the shared parameters from the accumulated statistics.
     *
     * <p>The per-class statistics are first summed into {@code statistic}. The diagonal element
     * is then found by Newton's method (see {@link #determineDiagonalElement}). The remaining
     * mass {@code 1 - parameters[diagElement]} is distributed over the other transitions in
     * proportion to their statistics. If all off-diagonal statistics are zero, it is split
     * uniformly instead, which avoids a 0/0 = NaN.
     *
     * <p>NOTE(review): the per-class statistics are added onto {@code statistic}. Calling this
     * method twice without {@link #resetStatistic()} in between therefore counts them twice.
     */
    @Override
    public void estimateFromStatistic() {
        double sumOfGammas = 0.0;
        for (int i = 0; i < this.states.length; i++) {
            for (int c = 0; c < this.statisticsTransitionProb.length; c++) {
                this.statistic[i] += this.statisticsTransitionProb[c][i];
            }
            sumOfGammas += this.statistic[i];
        }
        this.logNorm = 0.0;
        this.parameters[this.diagElement] = this.determineDiagonalElement(sumOfGammas, INITIAL_DIAGONAL_ELEMENT, false);

        double remainingMass = 1.0 - this.parameters[this.diagElement];
        double offDiagonalMass = sumOfGammas - this.statistic[this.diagElement];
        int offDiagonalStates = this.states.length - 1;
        for (int j = 0; j < this.states.length; j++) {
            if (j == this.diagElement) {
                continue;
            }
            this.parameters[j] = offDiagonalMass != 0.0
                    ? remainingMass * this.statistic[j] / offDiagonalMass
                    // No off-diagonal evidence: split uniformly instead of producing NaN.
                    : remainingMass / offDiagonalStates;
        }
    }

    /**
     * Solves {@code diagonalObjective(x) = sumOfGammas} for the diagonal element {@code x} by
     * Newton's method.
     *
     * <p>The iteration stops when successive iterates differ by at most
     * {@link #NEWTON_TOLERANCE}, when an iterate is NaN, or after
     * {@link #MAX_NEWTON_ITERATIONS} steps. In the last case the most recent iterate is
     * returned.
     *
     * @param sumOfGammas the right-hand side of the equation
     * @param initialDiagonalElement the starting point of the iteration
     * @param output whether to print the iterates to {@code System.out} for debugging
     * @return the final iterate
     */
    private double determineDiagonalElement(double sumOfGammas, double initialDiagonalElement, boolean output) {
        String stateSymbol = "" + this.diagElement;
        if (output) {
            System.out.println("<------ Start '" + stateSymbol + "' ------>");
            System.out.println("\tdiag = " + initialDiagonalElement);
        }
        double current = initialDiagonalElement;
        double next = current;
        for (int iteration = 0; iteration < MAX_NEWTON_ITERATIONS; iteration++) {
            next = current - (this.diagonalObjective(current) - sumOfGammas) / this.diagonalObjectiveDerivative(current);
            if (output) {
                System.out.println("\tdiag = " + next + "\t\tf = " + (this.diagonalObjective(next) - sumOfGammas));
            }
            // Negated comparison so that a NaN iterate ends the iteration instead of looping.
            if (!(Math.abs(next - current) > NEWTON_TOLERANCE)) {
                break;
            }
            current = next;
        }
        if (output) {
            System.out.println("<------ End   '" + stateSymbol + "' ------>");
        }
        return next;
    }

    /**
     * Computes the objective of the diagonal-element equation.
     *
     * @param x the candidate diagonal element
     * @return {@code sum_c s_c * n_c / (x - 1 + s_c) + a / x}, where {@code s_c} is the scaling
     *         factor of class {@code c}, {@code n_c} is the diagonal statistic of class
     *         {@code c}, and {@code a} is the hyper-parameter of the diagonal element
     */
    private double diagonalObjective(double x) {
        double value = 0.0;
        for (int c = 0; c < this.scalingFactor.length; c++) {
            value += this.scalingFactor[c] / (x - 1.0 + this.scalingFactor[c]) * this.statisticsTransitionProb[c][this.diagElement];
        }
        return value + this.hyperParameters[this.diagElement] / x;
    }

    /** Returns the derivative of {@link #diagonalObjective(double)} with respect to {@code x}. */
    private double diagonalObjectiveDerivative(double x) {
        double value = 0.0;
        for (int c = 0; c < this.scalingFactor.length; c++) {
            value -= this.scalingFactor[c] / Math.pow(x - 1.0 + this.scalingFactor[c], 2.0) * this.statisticsTransitionProb[c][this.diagElement];
        }
        return value - this.hyperParameters[this.diagElement] / Math.pow(x, 2.0);
    }

    /** Returns the log-probability of transition {@code state} in the class found at {@code sequencePosition}. */
    @Override
    public double getLogScoreFor(int state, Sequence sequence, int sequencePosition) {
        return Math.log(this.getTransitionProb(state, this.getIndex(sequencePosition, sequence)));
    }

    /**
     * Returns the probability of transition {@code j} in transition class {@code tc}.
     *
     * @param j the index of the transition
     * @param tc the transition class
     * @return the class-specific transition probability
     */
    private double getTransitionProb(int j, int tc) {
        double scale = this.scalingFactor[tc];
        if (j == this.diagElement) {
            return (this.parameters[j] - 1.0 + scale) / scale;
        }
        return this.parameters[j] / scale;
    }

    /**
     * Appends one Graphviz edge per target state and transition class. Each class uses its own
     * entry of {@link #arrowOptions}; the {@code arrowOption} argument is not used.
     */
    @Override
    protected void appendTransitions(StringBuffer representation, String contextNodeRepresentation, NumberFormat nf, String arrowOption, boolean graphical) {
        for (int s = 0; s < this.states.length; s++) {
            for (int tc = 0; tc < this.scalingFactor.length; tc++) {
                representation.append("\t" + contextNodeRepresentation + "->" + this.states[s]
                        + ScaledTransitionElement.getArrowOption(nf, this.getTransitionProb(s, tc), this.getGraphvizEdgeWeight(s), this.arrowOptions[tc], graphical)
                        + "\n");
            }
        }
    }

    /** Appends the superclass information plus the scaling factors (tag {@code scalingFactor}). */
    @Override
    protected void appendFurtherInformation(StringBuffer xml) {
        super.appendFurtherInformation(xml);
        XMLParser.appendObjectWithTags(xml, this.scalingFactor, "scalingFactor");
    }

    /** Reads the superclass information plus the scaling factors (tag {@code scalingFactor}). */
    @Override
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        super.extractFurtherInformation(xml);
        this.scalingFactor = (double[]) XMLParser.extractObjectForTags(xml, "scalingFactor");
    }

    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    /**
     * Returns a human-readable table of all class-specific transition probabilities.
     *
     * <p>The table has one line per transition class and one tab-separated entry per
     * transition.
     *
     * @param stateNames the names used to label the states
     * @return the table, or the empty string if there are no parameters
     */
    public String toString(String[] stateNames) {
        if (this.parameters.length == 0) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        String context = this.getContext(stateNames);
        String conditionPrefix = context.length() == 0 ? "|" : context + ",";
        for (int tc = 0; tc < this.scalingFactor.length; tc++) {
            String condition = conditionPrefix + "tc=" + tc;
            for (int i = 0; i < this.parameters.length; i++) {
                sb.append("P(" + this.getLabel(stateNames, this.states[i]) + condition + ") \t= " + this.getTransitionProb(i, tc));
                sb.append("\t");
            }
            sb.append("\n");
        }
        return sb.toString();
    }
}

