/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements;

import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.hmm.transitions.elements.ReferenceBasedTransitionElement;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Hashtable;

/**
 * Transition element whose transition probabilities depend on a value read from a reference
 * sequence annotation (see {@link #getIndex(int, Sequence)}).
 *
 * <p>For a reference value {@code tc} at the current position a scaling factor
 * {@code d = 1 + (scalingFactor - 1) * (1 - tc)} is computed. The self transition
 * ({@code diagElement}) then gets probability {@code (p_diag - 1 + d) / d} and every other
 * transition gets {@code p_j / d}. The off-diagonal probability mass is therefore divided by
 * {@code d}. For {@code tc = 1} the factor is 1 and the plain parameters are used.
 *
 * <p>NOTE(review): the reference values are presumably in [0, 1]. This class does not check that.
 */
public class DistanceBasedScaledTransitionElement
extends ReferenceBasedTransitionElement {
    /** Graphviz edge style string, set in {@link #init()}. */
    protected String arrowOptions;
    /** Scaling factor {@code d} used when the reference value is 0. */
    protected double scalingFactor;
    /** Per child state: sum of the weights passed to {@link #addToStatistic} since the last reset. */
    protected double[] statisticsTransitionProb;
    /** Per sequence: the self-transition weight accumulated at each sequence position. */
    protected Hashtable<Sequence, double[]> diagonalWeights;
    private static final String XML_TAG = "DISTANCE_BASED_SCALED_TRANSITION_ELEMENT";
    /** Starting point of the Newton iteration for the diagonal parameter. */
    private static final double INITIAL_DIAGONAL_ELEMENT = 0.001;
    /** Newton stops once two successive iterates differ by at most this amount. */
    private static final double NEWTON_TOLERANCE = 1.0E-10;
    /** Upper bound on Newton steps, so non-convergence cannot loop forever. */
    private static final int MAX_NEWTON_ITERATIONS = 1000;

    /**
     * Creates a new element without weights.
     *
     * @param context the context (previous states)
     * @param states the possible child states
     * @param probabilities the initial transition probabilities
     * @param ess the equivalent sample size
     * @param scalingFactor the scaling factor used for reference value 0
     * @param annotationID the identifier of the reference sequence annotation
     */
    public DistanceBasedScaledTransitionElement(int[] context, int[] states, double[] probabilities, double ess, double scalingFactor, String annotationID) {
        this(context, states, probabilities, ess, scalingFactor, annotationID, null);
    }

    /**
     * Creates a new element.
     *
     * @param context the context (previous states)
     * @param states the possible child states
     * @param probabilities the initial transition probabilities
     * @param ess the equivalent sample size
     * @param scalingFactor the scaling factor used for reference value 0
     * @param annotationID the identifier of the reference sequence annotation
     * @param weight passed unchanged to the super class; may be {@code null}
     */
    public DistanceBasedScaledTransitionElement(int[] context, int[] states, double[] probabilities, double ess, double scalingFactor, String annotationID, double[] weight) {
        super(context, states, ess, probabilities, annotationID, weight);
        this.scalingFactor = scalingFactor;
        this.init();
    }

    /** Allocates the statistics containers and sets the default edge style. */
    @Override
    protected void init() {
        super.init();
        this.statisticsTransitionProb = new double[this.states.length];
        this.diagonalWeights = new Hashtable<Sequence, double[]>();
        this.arrowOptions = "style=\"setlinewidth(1)\"";
    }

    /**
     * Restores an element from its XML representation.
     *
     * <p>NOTE(review): it is not visible here whether {@code super(xml)} triggers {@link #init()}.
     * If it does not, the statistics fields remain {@code null} for deserialised instances.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public DistanceBasedScaledTransitionElement(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a copy of this element. The collected statistics are deep-copied, so the clone
     * does not share mutable arrays with this instance.
     */
    @Override
    public DistanceBasedScaledTransitionElement clone() throws CloneNotSupportedException {
        DistanceBasedScaledTransitionElement clone = (DistanceBasedScaledTransitionElement) super.clone();
        clone.scalingFactor = this.scalingFactor;
        if (this.statisticsTransitionProb != null) {
            clone.statisticsTransitionProb = this.statisticsTransitionProb.clone();
        }
        if (this.diagonalWeights != null) {
            clone.diagonalWeights = new Hashtable<Sequence, double[]>();
            for (Sequence seq : this.diagonalWeights.keySet()) {
                clone.diagonalWeights.put(seq, this.diagonalWeights.get(seq).clone());
            }
        }
        return clone;
    }

    /**
     * Returns the discrete value of the reference sequence at {@code pos}. The reference
     * sequence comes from the annotation of type "reference" with this element's annotation ID.
     */
    protected double getIndex(int pos, Sequence seq) {
        return ((ReferenceSequenceAnnotation) seq.getSequenceAnnotationByTypeAndIdentifier("reference", this.annotationID)).getReferenceSequence().discreteVal(pos);
    }

    /**
     * Adds {@code weight} to the statistic of child {@code childIdx}. For the self transition,
     * the weight is also added to the per-position weights of {@code sequence}, which are
     * needed to estimate the diagonal parameter.
     */
    @Override
    public void addToStatistic(int childIdx, double weight, Sequence sequence, int sequencePosition) {
        this.statisticsTransitionProb[childIdx] += weight;
        if (childIdx == this.diagElement) {
            double[] epsilons = this.diagonalWeights.get(sequence);
            if (epsilons == null) {
                epsilons = new double[sequence.getLength()];
                this.diagonalWeights.put(sequence, epsilons);
            }
            // Accumulate (like statisticsTransitionProb) so that the per-position weights
            // always sum to the diagonal entry of statisticsTransitionProb.
            epsilons[sequencePosition] += weight;
        }
    }

    /** Zeroes all collected statistics, keeping the allocated per-sequence arrays. */
    @Override
    public void resetStatistic() {
        Arrays.fill(this.statisticsTransitionProb, 0.0);
        for (double[] diagEpsilons : this.diagonalWeights.values()) {
            Arrays.fill(diagEpsilons, 0.0);
        }
        super.resetStatistic();
    }

    /**
     * Updates the parameters from the collected statistics.
     *
     * <p>The pending counts are added to {@code statistic}. The diagonal parameter is then found
     * numerically (see {@link #determineDiagonalElement}). The remaining mass
     * {@code 1 - p_diag} is split among the other children in proportion to their statistics.
     */
    @Override
    public void estimateFromStatistic() {
        double sumOfGammas = 0.0;
        for (int i = 0; i < this.states.length; ++i) {
            this.statistic[i] += this.statisticsTransitionProb[i];
            sumOfGammas += this.statistic[i];
        }
        this.logNorm = 0.0;
        double diagonal = this.determineDiagonalElement(sumOfGammas, INITIAL_DIAGONAL_ELEMENT, true, false);
        this.parameters[this.diagElement] = diagonal;
        double offDiagonalMass = sumOfGammas - this.statistic[this.diagElement];
        for (int j = 0; j < this.states.length; ++j) {
            if (j == this.diagElement) {
                continue;
            }
            this.parameters[j] = (1.0 - diagonal) * this.statistic[j] / offDiagonalMass;
        }
    }

    /**
     * Solves {@code g(x) = sumOfGammas} for the diagonal parameter {@code x} with Newton's method.
     * Here {@code g(x) = sum_seq sum_{t >= |context|} d_t / (x - 1 + d_t) * eps_t + alpha / x} and
     * {@code alpha = hyperParameters[diagElement]}.
     *
     * <p>Iteration stops when two successive iterates differ by at most
     * {@link #NEWTON_TOLERANCE}, when a step yields NaN, or after {@link #MAX_NEWTON_ITERATIONS}
     * steps.
     *
     * @param sumOfGammas the target value (total statistic of this element)
     * @param currentDiagonalElement the starting point
     * @param firstStep whether to print the start/end banners (only used if {@code output})
     * @param output whether to print the iteration trace to stdout
     * @return the last Newton iterate
     */
    private double determineDiagonalElement(double sumOfGammas, double currentDiagonalElement, boolean firstStep, boolean output) {
        String stateSymbol = "" + this.diagElement;
        if (firstStep && output) {
            System.out.println("<------ Start '" + stateSymbol + "' ------>");
            System.out.println("\tdiag = " + currentDiagonalElement);
        }
        double current = currentDiagonalElement;
        for (int iteration = 0; iteration < MAX_NEWTON_ITERATIONS; ++iteration) {
            double[] valueAndDerivative = this.evaluateDiagonalObjective(current);
            double next = current - (valueAndDerivative[0] - sumOfGammas) / valueAndDerivative[1];
            if (output) {
                System.out.println("\tdiag = " + next + "\t\tf = " + (this.evaluateDiagonalObjective(next)[0] - sumOfGammas));
            }
            // Written as !(a > b) so that a NaN step also terminates the iteration.
            boolean converged = !(Math.abs(next - current) > NEWTON_TOLERANCE);
            current = next;
            if (converged) {
                break;
            }
        }
        if (firstStep && output) {
            System.out.println("<------ End   '" + stateSymbol + "' ------>");
        }
        return current;
    }

    /**
     * Evaluates {@code g(diag)} and its derivative {@code g'(diag)} for the Newton iteration in
     * {@link #determineDiagonalElement}. An empty {@code diagonalWeights} table contributes only
     * the prior term.
     *
     * @return {@code {g(diag), g'(diag)}}
     */
    private double[] evaluateDiagonalObjective(double diag) {
        double value = 0.0;
        double derivative = 0.0;
        for (Sequence seq : this.diagonalWeights.keySet()) {
            double[] diagEpsilons = this.diagonalWeights.get(seq);
            int length = seq.getLength();
            for (int t = this.context.length; t < length; ++t) {
                double d = this.getDistanceBasedScalingFactor(this.getIndex(t, seq));
                double denominator = diag - 1.0 + d;
                value += d / denominator * diagEpsilons[t];
                derivative -= d / (denominator * denominator) * diagEpsilons[t];
            }
        }
        double prior = this.hyperParameters[this.diagElement];
        value += prior / diag;
        derivative -= prior / (diag * diag);
        return new double[]{value, derivative};
    }

    /** Returns the log-probability of the transition to {@code state} at {@code sequencePosition}. */
    @Override
    public double getLogScoreFor(int state, Sequence sequence, int sequencePosition) {
        return Math.log(this.getTransitionProb(state, this.getIndex(sequencePosition, sequence)));
    }

    /** Returns {@code d = 1 + (scalingFactor - 1) * (1 - tc)} for reference value {@code tc}. */
    private double getDistanceBasedScalingFactor(double tc) {
        return 1.0 + (this.scalingFactor - 1.0) * (1.0 - tc);
    }

    /** Returns the scaled probability of child {@code j} for reference value {@code tc}. */
    private double getTransitionProb(int j, double tc) {
        double d = this.getDistanceBasedScalingFactor(tc);
        if (j == this.diagElement) {
            return (this.parameters[j] - 1.0 + d) / d;
        }
        return this.parameters[j] / d;
    }

    /**
     * Appends one Graphviz edge per child, labelled with the unscaled parameter.
     * NOTE(review): the {@code arrowOption} argument is not used here.
     */
    @Override
    protected void appendTransitions(StringBuffer res, String contextNodeRepresentation, NumberFormat nf, String arrowOption, boolean graphical) {
        for (int s = 0; s < this.states.length; ++s) {
            res.append("\t" + contextNodeRepresentation + "->" + this.states[s] + getArrowOption(nf, this.parameters[s], this.getGraphvizEdgeWeight(s), "\n", graphical));
        }
    }

    /** Appends the super class information plus the scaling factor. */
    @Override
    protected void appendFurtherInformation(StringBuffer xml) {
        super.appendFurtherInformation(xml);
        XMLParser.appendObjectWithTags(xml, this.scalingFactor, "scalingFactor");
    }

    /** Reads the super class information plus the scaling factor. */
    @Override
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        super.extractFurtherInformation(xml);
        this.scalingFactor = (Double) XMLParser.extractObjectForTags(xml, "scalingFactor");
    }

    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    /**
     * Lists the transition probabilities for reference value 1, where the scaling factor is 1
     * and the plain parameters are shown.
     */
    @Override
    public String toString(String[] stateNames) {
        if (this.parameters.length == 0) {
            return "";
        }
        StringBuffer sb = new StringBuffer();
        String c = this.getContext(stateNames);
        if (c.length() == 0) {
            c = "|";
        }
        for (int i = 0; i < this.parameters.length; ++i) {
            sb.append("P(" + this.getLabel(stateNames, this.states[i]) + c + ") \t= " + this.getTransitionProb(i, 1.0));
            sb.append("\t");
        }
        sb.append("\n");
        return sb.toString();
    }
}

