/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.utils;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.IntSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.sequenceScores.SequenceScore;
import de.jstacs.sequenceScores.statisticalModels.StatisticalModel;
import de.jstacs.utils.Normalisation;
import java.io.IOException;
import java.util.Arrays;

/**
 * Static helpers for evaluating {@link StatisticalModel}s and {@link SequenceScore}s.
 *
 * <p>Most methods work by exhaustively enumerating every sequence of a fixed length over a
 * discrete alphabet. Their running time therefore grows with the number of such sequences,
 * i.e. the product of the alphabet sizes over all positions.
 */
public class StatisticalModelTester {

    /**
     * Computes the Kullback-Leibler divergence
     * {@code D(m1 || m2) = sum_x P1(x) * (log P1(x) - log P2(x))}
     * over all sequences {@code x} of the given length, enumerated over the alphabet of {@code m1}.
     *
     * <p>Sequences with probability zero under {@code m1} contribute nothing (convention
     * {@code 0 * log 0 = 0}); evaluating the product directly would give
     * {@code 0 * -Infinity = NaN}. If {@code m2} assigns probability zero to a sequence that has
     * positive probability under {@code m1}, the result is {@code +Infinity}.
     *
     * @param m1 the reference model; its alphabet defines the enumerated sequences
     * @param m2 the model compared against {@code m1}
     * @param length the length of the enumerated sequences
     * @return the divergence, using the natural logarithm
     * @throws Exception if the alphabet is not discrete or a model fails to score a sequence
     */
    public static double getKLDivergence(StatisticalModel m1, StatisticalModel m2, int length) throws Exception {
        SeqIterator s = new SeqIterator(m1.getAlphabetContainer(), length);
        double kl = 0.0;
        do {
            Sequence seq = s.getSequence();
            double logP1 = m1.getLogProbFor(seq);
            if (logP1 == Double.NEGATIVE_INFINITY) {
                // 0 * log(0 / q) := 0; skipping avoids 0 * -Infinity = NaN
                continue;
            }
            kl += Math.exp(logP1) * (logP1 - m2.getLogProbFor(seq));
        } while (s.next());
        return kl;
    }

    /**
     * Computes the symmetric Kullback-Leibler divergence
     * {@code D(m1 || m2) + D(m2 || m1) = sum_x (P1(x) - P2(x)) * (log P1(x) - log P2(x))}
     * over all sequences of the given length, enumerated over the alphabet of {@code m1}.
     *
     * <p>Sequences with probability zero under both models contribute nothing; evaluating them
     * directly would give {@code -Infinity - -Infinity = NaN}. If exactly one model assigns
     * probability zero to a sequence, the result is {@code +Infinity}.
     *
     * @param m1 the first model; its alphabet defines the enumerated sequences
     * @param m2 the second model
     * @param length the length of the enumerated sequences
     * @return the symmetric divergence, using the natural logarithm
     * @throws Exception if the alphabet is not discrete or a model fails to score a sequence
     */
    public static double getSymKLDivergence(StatisticalModel m1, StatisticalModel m2, int length) throws Exception {
        SeqIterator s = new SeqIterator(m1.getAlphabetContainer(), length);
        double kl = 0.0;
        do {
            Sequence seq = s.getSequence();
            double logP1 = m1.getLogProbFor(seq);
            double logP2 = m2.getLogProbFor(seq);
            if (logP1 == Double.NEGATIVE_INFINITY && logP2 == Double.NEGATIVE_INFINITY) {
                // impossible under both models: contributes 0, avoids (-inf) - (-inf) = NaN
                continue;
            }
            kl += (Math.exp(logP1) - Math.exp(logP2)) * (logP1 - logP2);
        } while (s.next());
        return kl;
    }

    /**
     * Returns the unweighted log-likelihood of {@code data} under {@code m}.
     *
     * @param m the model
     * @param data the data set
     * @return the sum of {@code m.getLogProbFor(x)} over all elements {@code x} of {@code data}
     * @throws Exception if the model fails to score an element
     */
    public static double getLogLikelihood(StatisticalModel m, DataSet data) throws Exception {
        return getLogLikelihood(m, data, null);
    }

    /**
     * Returns the (optionally weighted) log-likelihood of {@code data} under {@code m}.
     *
     * @param m the model
     * @param data the data set
     * @param weights one weight per element of {@code data}, or {@code null} for weight 1 everywhere
     * @return the sum of {@code weights[i] * m.getLogProbFor(x_i)} over all elements
     * @throws IllegalArgumentException if {@code weights} is non-null and its length differs
     *         from the number of elements in {@code data}
     * @throws Exception if the model fails to score an element
     */
    public static double getLogLikelihood(StatisticalModel m, DataSet data, double[] weights) throws Exception {
        int n = data.getNumberOfElements();
        DataSet.ElementEnumerator elements = new DataSet.ElementEnumerator(data);
        double logLikelihood = 0.0;
        if (weights == null) {
            for (int i = 0; i < n; i++) {
                logLikelihood += m.getLogProbFor(elements.nextElement());
            }
        } else {
            if (n != weights.length) {
                throw new IllegalArgumentException("The weights and the data set does not match.");
            }
            for (int i = 0; i < n; i++) {
                logLikelihood += weights[i] * m.getLogProbFor(elements.nextElement());
            }
        }
        return logLikelihood;
    }

    /**
     * Computes marginal probabilities for a set of position-wise constraints.
     *
     * <p>Each constraint is an array of symbol indices, one per position. The value {@code -1}
     * acts as a wildcard; any other value must equal the symbol index at that position. For each
     * constraint the method returns the summed probability of all sequences satisfying it.
     *
     * @param m the model
     * @param constraint the constraints; all must have the same length, which is also the
     *        length of the enumerated sequences
     * @return one marginal probability per constraint (an empty array if no constraint is given)
     * @throws IOException if the constraints differ in length, or do not match the model's
     *         fixed length (a model length of 0 accepts any length)
     * @throws Exception if the alphabet is not discrete or the model fails to score a sequence
     */
    public static double[] getMarginalDistribution(StatisticalModel m, int[] ... constraint) throws Exception {
        if (constraint.length == 0) {
            return new double[0];
        }
        int l = constraint[0].length;
        int len = m.getLength();
        for (int[] c : constraint) {
            if (c.length != l) {
                throw new IOException("All constraints must have the same length.");
            }
            if (len != 0 && len != c.length) {
                throw new IOException("This model can only classify sequences of length " + m.getLength() + ".");
            }
        }
        // holds log-marginals during accumulation; converted to probabilities at the end
        double[] marginals = new double[constraint.length];
        Arrays.fill(marginals, Double.NEGATIVE_INFINITY);
        SeqIterator s = new SeqIterator(m.getAlphabetContainer(), l);
        do {
            // score each sequence at most once, and only if some constraint needs it
            boolean scored = false;
            double logP = 0.0;
            for (int i = 0; i < constraint.length; i++) {
                if (s.isSatisfied(constraint[i])) {
                    if (!scored) {
                        logP = m.getLogProbFor(s.getSequence());
                        scored = true;
                    }
                    marginals[i] = Normalisation.getLogSum(marginals[i], logP);
                }
            }
        } while (s.next());
        for (int i = 0; i < marginals.length; i++) {
            marginals[i] = Math.exp(marginals[i]);
        }
        return marginals;
    }

    /**
     * Returns {@code max_x |P1(x) - P2(x)|} over all sequences {@code x} of the given length.
     * Probabilities are obtained via {@code getLogProbFor(x, 0, length - 1)}.
     *
     * @param m1 the first model; its alphabet defines the enumerated sequences
     * @param m2 the second model
     * @param length the length of the enumerated sequences
     * @return the largest absolute probability difference (at least 0)
     * @throws IOException if a model has a fixed length other than {@code length}, or the
     *         alphabets are not consistent
     * @throws Exception if the alphabet is not discrete or a model fails to score a sequence
     */
    public static double getMaxOfDeviation(StatisticalModel m1, StatisticalModel m2, int length) throws Exception {
        checkComparable(m1, m2, length);
        double max = 0.0;
        SeqIterator s = new SeqIterator(m1.getAlphabetContainer(), length);
        do {
            Sequence seq = s.getSequence();
            double p1 = Math.exp(m1.getLogProbFor(seq, 0, s.last));
            double p2 = Math.exp(m2.getLogProbFor(seq, 0, s.last));
            double deviation = Math.abs(p1 - p2);
            if (deviation > max) {
                max = deviation;
            }
        } while (s.next());
        return max;
    }

    /**
     * Returns the sequence of the given length with the highest log score under {@code m}.
     * Ties go to the sequence enumerated first.
     *
     * @param m the scoring function; its alphabet defines the enumerated sequences
     * @param length the length of the enumerated sequences
     * @return the highest-scoring sequence
     * @throws Exception if the alphabet is not discrete or {@code m} fails to score a sequence
     */
    public static Sequence getMostProbableSequence(SequenceScore m, int length) throws Exception {
        SeqIterator s = new SeqIterator(m.getAlphabetContainer(), length);
        Sequence best = s.getSequence();
        double bestScore = m.getLogScoreFor(best);
        while (s.next()) {
            Sequence candidate = s.getSequence();
            double score = m.getLogScoreFor(candidate);
            if (score > bestScore) {
                bestScore = score;
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Computes the Shannon entropy {@code -sum_x P(x) log P(x)} (natural logarithm) of the
     * model's distribution over all sequences of the given length. Sequences with infinite
     * log-probability are left out of the sum.
     *
     * @param m the model
     * @param length the length of the enumerated sequences
     * @return the entropy in nats
     * @throws IOException if the model has a fixed length other than {@code length}, or a
     *         sequence has a log-probability greater than 0
     * @throws Exception if the alphabet is not discrete or the model fails to score a sequence
     */
    public static double getShannonEntropy(StatisticalModel m, int length) throws Exception {
        checkLength(m, length, "This model");
        double entropy = 0.0;
        SeqIterator s = new SeqIterator(m.getAlphabetContainer(), length);
        do {
            double logP = m.getLogProbFor(s.getSequence());
            if (logP > 0.0) {
                throw new IOException("The probability of sequence " + s.getSequence() + " is not correct (" + Math.exp(logP) + ").");
            }
            if (!Double.isInfinite(logP)) {
                entropy -= Math.exp(logP) * logP;
            }
        } while (s.next());
        return entropy;
    }

    /**
     * Computes the Shannon entropy in bits, i.e. {@link #getShannonEntropy} divided by {@code ln 2}.
     *
     * @param m the model
     * @param length the length of the enumerated sequences
     * @return the entropy in bits
     * @throws Exception see {@link #getShannonEntropy(StatisticalModel, int)}
     */
    public static double getShannonEntropyInBits(StatisticalModel m, int length) throws Exception {
        return getShannonEntropy(m, length) / Math.log(2.0);
    }

    /**
     * Returns {@code sum_x |P1(x) - P2(x)|} over all sequences {@code x} of the given length.
     * Probabilities are obtained via {@code getLogProbFor(x, 0, length - 1)}.
     *
     * @param m1 the first model; its alphabet defines the enumerated sequences
     * @param m2 the second model
     * @param length the length of the enumerated sequences
     * @return the summed absolute probability difference
     * @throws IOException if a model has a fixed length other than {@code length}, or the
     *         alphabets are not consistent
     * @throws Exception if the alphabet is not discrete or a model fails to score a sequence
     */
    public static double getSumOfDeviation(StatisticalModel m1, StatisticalModel m2, int length) throws Exception {
        checkComparable(m1, m2, length);
        double sum = 0.0;
        SeqIterator s = new SeqIterator(m1.getAlphabetContainer(), length);
        do {
            Sequence seq = s.getSequence();
            sum += Math.abs(Math.exp(m1.getLogProbFor(seq, 0, s.last)) - Math.exp(m2.getLogProbFor(seq, 0, s.last)));
        } while (s.next());
        return sum;
    }

    /**
     * Returns the total probability mass the model assigns to all sequences of the given
     * length. The summation is done in log space.
     *
     * @param m the model
     * @param length the length of the enumerated sequences
     * @return the summed probability
     * @throws IOException if the model has a fixed length other than {@code length}, or a
     *         sequence has a log-probability greater than 0
     * @throws Exception if the alphabet is not discrete or the model fails to score a sequence
     */
    public static double getSumOfDistribution(StatisticalModel m, int length) throws Exception {
        checkLength(m, length, "This model");
        double logSum = Double.NEGATIVE_INFINITY;
        SeqIterator s = new SeqIterator(m.getAlphabetContainer(), length);
        do {
            double logP = m.getLogProbFor(s.getSequence());
            if (logP > 0.0) {
                throw new IOException("The probability (" + Math.exp(logP) + ") for sequence \"" + s.getSequence() + "\" is not in [0,1].");
            }
            logSum = Normalisation.getLogSum(logSum, logP);
        } while (s.next());
        return Math.exp(logSum);
    }

    /**
     * Returns {@code 2 * logL - 2 * k}, where {@code logL} is the log-likelihood of {@code s}.
     * Note: this is the negation of the textbook AIC {@code 2k - 2 logL}.
     *
     * @param m the model
     * @param s the data set
     * @param k the number of free parameters
     * @return the criterion value
     * @throws Exception if the model fails to score an element
     */
    public static double getValueOfAIC(StatisticalModel m, DataSet s, int k) throws Exception {
        return 2.0 * getLogLikelihood(m, s) - (double) (2 * k);
    }

    /**
     * Returns {@code 2 * logL - k * ln(n)}, where {@code logL} is the log-likelihood of
     * {@code s} and {@code n} its number of elements. Note: this is the negation of the
     * textbook BIC {@code k ln(n) - 2 logL}.
     *
     * @param m the model
     * @param s the data set
     * @param k the number of free parameters
     * @return the criterion value
     * @throws Exception if the model fails to score an element
     */
    public static double getValueOfBIC(StatisticalModel m, DataSet s, int k) throws Exception {
        return 2.0 * getLogLikelihood(m, s) - (double) k * StrictMath.log(s.getNumberOfElements());
    }

    /**
     * Throws if {@code m} has a fixed length (non-zero {@code getLength()}) other than {@code length}.
     *
     * @param description the subject of the error message, e.g. {@code "This model"}
     */
    private static void checkLength(StatisticalModel m, int length, String description) throws IOException {
        if (m.getLength() != 0 && m.getLength() != length) {
            throw new IOException(description + " can only classify sequences of length " + m.getLength() + ".");
        }
    }

    /** Validates that two models accept sequences of {@code length} over consistent alphabets. */
    private static void checkComparable(StatisticalModel m1, StatisticalModel m2, int length) throws IOException {
        checkLength(m1, length, "The model m1");
        checkLength(m2, length, "This model m2");
        if (!m1.getAlphabetContainer().checkConsistency(m2.getAlphabetContainer())) {
            throw new IOException("The models are training on different alphabets.");
        }
    }

    /**
     * Enumerates every sequence of a fixed length over a discrete alphabet, odometer-style:
     * position 0 changes fastest. The iterator starts at the all-zero sequence; call
     * {@link #next()} until it returns {@code false}.
     */
    private static class SeqIterator {
        /** Current symbol indices at positions {@code 0..l-1}; index {@code l} is an overflow sentinel. */
        private final int[] seq;
        /** Whether every position shares the same alphabet. */
        private final boolean simple;
        /** Largest symbol index per position; only index 0 is used (and allocated) for simple alphabets. */
        private final int[] a;
        /** The sequence length. */
        private final int l;
        /** The last position, {@code l - 1}. */
        private final int last;
        private final AlphabetContainer abc;

        private SeqIterator(AlphabetContainer abc, int length) throws IllegalArgumentException {
            if (!abc.isDiscrete()) {
                throw new IllegalArgumentException("The model is not discrete.");
            }
            if (length < 0) {
                throw new IllegalArgumentException("The length must not be negative.");
            }
            this.abc = abc;
            this.simple = abc.isSimple();
            this.a = new int[this.simple ? 1 : length];
            for (int i = 0; i < this.a.length; i++) {
                this.a[i] = (int) abc.getAlphabetLengthAt(i) - 1;
            }
            this.l = length;
            this.last = length - 1;
            this.seq = new int[length + 1];
        }

        /**
         * Advances to the next sequence.
         *
         * @return {@code false} once every sequence has been visited
         */
        private boolean next() {
            int pos = 0;
            // carry: reset every position already at its largest symbol. The loop is bounded by l
            // so it never compares the sentinel against an alphabet size (which overran the array
            // for single-symbol alphabets).
            while (pos < this.l && this.seq[pos] == this.a[this.simple ? 0 : pos]) {
                this.seq[pos++] = 0;
            }
            this.seq[pos]++;
            // the sentinel becomes non-zero exactly when the carry ran past the last position
            return this.seq[this.l] == 0;
        }

        /**
         * Checks a constraint against the current sequence. {@code -1} in the constraint is a
         * wildcard; any other entry must equal the symbol index at that position.
         */
        private boolean isSatisfied(int[] constr) {
            for (int i = 0; i < constr.length; i++) {
                if (constr[i] != -1 && constr[i] != this.seq[i]) {
                    return false;
                }
            }
            return true;
        }

        /** Returns a new {@link IntSequence} holding the current symbols at positions {@code 0..l-1}. */
        private Sequence getSequence() throws WrongAlphabetException, WrongSequenceTypeException {
            return new IntSequence(this.abc, this.seq, 0, this.l);
        }
    }
}

