/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.models.utils;

import de.jstacs.WrongAlphabetException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.sequences.IntSequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.models.Model;
import java.io.IOException;

/**
 * Static helpers that evaluate {@link Model}s by exhaustively enumerating every sequence of a
 * given length over the model's discrete alphabet.
 *
 * <p>All {@code |alphabet|^length} sequences are visited, so the enumeration-based methods are
 * only practical for short lengths and small alphabets.
 */
public class ModelTester {

    /**
     * Computes the Kullback-Leibler divergence {@code D(m1 || m2)} over all sequences of the given
     * length, i.e. {@code sum_x p1(x) * (log p1(x) - log p2(x))}.
     *
     * <p>Sequences with zero probability under {@code m1} contribute nothing, following the usual
     * convention {@code 0 * log(0 / q) = 0}.
     *
     * @param m1 the reference model; its alphabet container drives the enumeration
     * @param m2 the model compared against {@code m1}
     * @param length the length of the enumerated sequences
     * @return the divergence (natural-log based if {@code getLogProbFor} is; TODO confirm)
     * @throws IllegalArgumentException if the alphabet of {@code m1} is not discrete or
     *     {@code length} is negative
     * @throws Exception if a model fails to evaluate a sequence
     */
    public static double getKLDivergence(Model m1, Model m2, int length) throws Exception {
        SeqIterator s = new SeqIterator(m1.getAlphabetContainer(), length);
        double kl = 0.0;
        do {
            Sequence seq = s.getSequence();
            double logP1 = m1.getLogProbFor(seq);
            // Skip p1 == 0: otherwise exp(-inf) * (-inf - x) evaluates to NaN and poisons the sum.
            if (logP1 != Double.NEGATIVE_INFINITY) {
                kl += Math.exp(logP1) * (logP1 - m2.getLogProbFor(seq));
            }
        } while (s.next());
        return kl;
    }

    /**
     * Computes the symmetrized Kullback-Leibler divergence {@code D(m1 || m2) + D(m2 || m1)},
     * evaluated as {@code sum_x (p1(x) - p2(x)) * log(p1(x) / p2(x))}.
     *
     * <p>The result is {@code +Infinity} if exactly one model assigns zero probability to some
     * sequence.
     *
     * @param m1 the first model; its alphabet container drives the enumeration
     * @param m2 the second model
     * @param length the length of the enumerated sequences
     * @return the symmetric divergence
     * @throws IllegalArgumentException if the alphabet of {@code m1} is not discrete or
     *     {@code length} is negative
     * @throws Exception if a model fails to evaluate a sequence
     */
    public static double getSymKLDivergence(Model m1, Model m2, int length) throws Exception {
        SeqIterator s = new SeqIterator(m1.getAlphabetContainer(), length);
        double kl = 0.0;
        do {
            Sequence seq = s.getSequence();
            double p1 = m1.getProbFor(seq);
            double p2 = m2.getProbFor(seq);
            // Equal probabilities contribute exactly 0; skipping them also avoids 0 * log(0/0) = NaN.
            if (p1 != p2) {
                kl += (p1 - p2) * Math.log(p1 / p2);
            }
        } while (s.next());
        return kl;
    }

    /**
     * Returns the unweighted log-likelihood of {@code data} under {@code m}, i.e. the sum of
     * {@code m.getLogProbFor(x)} over all elements {@code x} of {@code data}.
     *
     * @param m the model
     * @param data the sample
     * @return the log-likelihood
     * @throws Exception if the model fails to evaluate an element
     */
    public static double getLogLikelihood(Model m, Sample data) throws Exception {
        return ModelTester.getLogLikelihood(m, data, null);
    }

    /**
     * Returns the (optionally weighted) log-likelihood of {@code data} under {@code m}.
     *
     * @param m the model
     * @param data the sample
     * @param weights one weight per element of {@code data}, or {@code null} for weight 1 everywhere
     * @return {@code sum_i weights[i] * m.getLogProbFor(data_i)}
     * @throws IllegalArgumentException if {@code weights} is non-null and its length differs from
     *     the number of elements in {@code data}
     * @throws Exception if the model fails to evaluate an element
     */
    public static double getLogLikelihood(Model m, Sample data, double[] weights) throws Exception {
        int d = data.getNumberOfElements();
        if (weights != null && d != weights.length) {
            throw new IllegalArgumentException("The weights and the sample does not match.");
        }
        double erg = 0.0;
        Sample.ElementEnumerator ei = new Sample.ElementEnumerator(data);
        for (int counter = 0; counter < d; ++counter) {
            double logProb = m.getLogProbFor(ei.nextElement());
            erg += weights == null ? logProb : weights[counter] * logProb;
        }
        return erg;
    }

    /**
     * Sums the probabilities of all sequences of length {@code constraint.length} that satisfy
     * {@code constraint}.
     *
     * @param m the model
     * @param constraint one entry per position: {@code -1} matches any symbol, any other value must
     *     equal the symbol index at that position
     * @return the marginal probability of the constrained positions
     * @throws IOException if {@code m} has a fixed length different from {@code constraint.length}
     * @throws IllegalArgumentException if the alphabet of {@code m} is not discrete
     * @throws Exception if the model fails to evaluate a sequence
     */
    public static double getMarginalDistribution(Model m, int[] constraint) throws Exception {
        checkLength(m, constraint.length, "This model");
        double erg = 0.0;
        SeqIterator s = new SeqIterator(m.getAlphabetContainer(), constraint.length);
        do {
            if (s.isSatisfied(constraint)) {
                erg += m.getProbFor(s.getSequence());
            }
        } while (s.next());
        return erg;
    }

    /**
     * Returns the largest absolute difference {@code |p1(x) - p2(x)|} over all sequences {@code x}
     * of the given length.
     *
     * @param m1 the first model; its alphabet container drives the enumeration
     * @param m2 the second model
     * @param length the length of the enumerated sequences
     * @return the maximal deviation ({@code 0.0} if all deviations are {@code 0} or {@code NaN})
     * @throws IOException if a model has a fixed length different from {@code length} or the
     *     alphabets are inconsistent
     * @throws Exception if a model fails to evaluate a sequence
     */
    public static double getMaxOfDeviation(Model m1, Model m2, int length) throws Exception {
        checkComparable(m1, m2, length);
        double max = 0.0;
        SeqIterator s = new SeqIterator(m1.getAlphabetContainer(), length);
        do {
            Sequence seq = s.getSequence();
            // Same evaluation form for both models, matching getSumOfDeviation.
            double deviation = Math.abs(m1.getProbFor(seq, 0, s.last) - m2.getProbFor(seq, 0, s.last));
            // Explicit comparison (not Math.max) so that NaN deviations are ignored.
            if (deviation > max) {
                max = deviation;
            }
        } while (s.next());
        return max;
    }

    /**
     * Returns a sequence of the given length with maximal log-probability under {@code m}. On ties
     * the first such sequence in enumeration order (position 0 varies fastest) is returned.
     *
     * @param m the model
     * @param length the length of the enumerated sequences
     * @return the most probable sequence
     * @throws IllegalArgumentException if the alphabet of {@code m} is not discrete or
     *     {@code length} is negative
     * @throws Exception if the model fails to evaluate a sequence
     */
    public static Sequence getMostProbableSequence(Model m, int length) throws Exception {
        SeqIterator s = new SeqIterator(m.getAlphabetContainer(), length);
        Sequence best = s.getSequence();
        double bestLogProb = m.getLogProbFor(best);
        while (s.next()) {
            Sequence current = s.getSequence();
            double logProb = m.getLogProbFor(current);
            if (logProb > bestLogProb) {
                bestLogProb = logProb;
                best = current;
            }
        }
        return best;
    }

    /**
     * Computes the Shannon entropy {@code -sum_x p(x) * ln p(x)} (in nats) of the distribution over
     * all sequences of the given length.
     *
     * @param m the model
     * @param length the length of the enumerated sequences
     * @return the entropy in nats
     * @throws IOException if {@code m} has a fixed length different from {@code length}, or if
     *     some sequence has a probability outside {@code [0, 1]}
     * @throws Exception if the model fails to evaluate a sequence
     */
    public static double getShannonEntropy(Model m, int length) throws Exception {
        checkLength(m, length, "This model");
        double erg = 0.0;
        SeqIterator s = new SeqIterator(m.getAlphabetContainer(), length);
        do {
            double p = m.getProbFor(s.getSequence());
            if (p < 0.0 || p > 1.0) {
                throw new IOException("The probability of sequence " + s.getSequence() + " is not correct (" + p + ").");
            }
            // 0 * log(0) is taken as 0.
            if (p != 0.0) {
                erg -= p * Math.log(p);
            }
        } while (s.next());
        return erg;
    }

    /**
     * Computes the Shannon entropy in bits; see {@link #getShannonEntropy(Model, int)}.
     *
     * @param m the model
     * @param length the length of the enumerated sequences
     * @return the entropy in bits
     * @throws Exception as {@link #getShannonEntropy(Model, int)}
     */
    public static double getShannonEntropyInBits(Model m, int length) throws Exception {
        return ModelTester.getShannonEntropy(m, length) / Math.log(2.0);
    }

    /**
     * Returns the sum of absolute differences {@code sum_x |p1(x) - p2(x)|} over all sequences
     * {@code x} of the given length.
     *
     * @param m1 the first model; its alphabet container drives the enumeration
     * @param m2 the second model
     * @param length the length of the enumerated sequences
     * @return the summed deviation
     * @throws IOException if a model has a fixed length different from {@code length} or the
     *     alphabets are inconsistent
     * @throws Exception if a model fails to evaluate a sequence
     */
    public static double getSumOfDeviation(Model m1, Model m2, int length) throws Exception {
        checkComparable(m1, m2, length);
        double sum = 0.0;
        SeqIterator s = new SeqIterator(m1.getAlphabetContainer(), length);
        do {
            Sequence seq = s.getSequence();
            sum += Math.abs(m1.getProbFor(seq, 0, s.last) - m2.getProbFor(seq, 0, s.last));
        } while (s.next());
        return sum;
    }

    /**
     * Sums the probabilities of all sequences of the given length. For a proper distribution the
     * result should be (numerically close to) 1.
     *
     * @param m the model
     * @param length the length of the enumerated sequences
     * @return the total probability mass
     * @throws IOException if {@code m} has a fixed length different from {@code length}, or if
     *     some sequence has a probability outside {@code [0, 1]}
     * @throws Exception if the model fails to evaluate a sequence
     */
    public static double getSumOfDistribution(Model m, int length) throws Exception {
        checkLength(m, length, "This model");
        double erg = 0.0;
        SeqIterator s = new SeqIterator(m.getAlphabetContainer(), length);
        do {
            double p = m.getProbFor(s.getSequence());
            if (p < 0.0 || p > 1.0) {
                throw new IOException("The probability (" + p + ") for sequence \"" + s.getSequence() + "\" is not in [0,1].");
            }
            erg += p;
        } while (s.next());
        return erg;
    }

    /**
     * Returns {@code 2 * logL - 2 * k}, i.e. the <em>negated</em> Akaike information criterion
     * (standard AIC is {@code 2k - 2 logL}), so larger values indicate a better model.
     *
     * @param m the model
     * @param s the sample used for the log-likelihood
     * @param k the number of free parameters of {@code m}
     * @return the negated AIC
     * @throws Exception if the model fails to evaluate an element
     */
    public static double getValueOfAIC(Model m, Sample s, int k) throws Exception {
        return 2.0 * ModelTester.getLogLikelihood(m, s) - (double) (2 * k);
    }

    /**
     * Returns {@code 2 * logL - k * ln(n)} with {@code n} the number of elements in {@code s}, i.e.
     * the <em>negated</em> Bayesian information criterion, so larger values indicate a better model.
     *
     * @param m the model
     * @param s the sample used for the log-likelihood
     * @param k the number of free parameters of {@code m}
     * @return the negated BIC
     * @throws Exception if the model fails to evaluate an element
     */
    public static double getValueOfBIC(Model m, Sample s, int k) throws Exception {
        return 2.0 * ModelTester.getLogLikelihood(m, s) - (double) k * StrictMath.log(s.getNumberOfElements());
    }

    /**
     * Throws if {@code m} has a fixed (non-zero) length that differs from {@code length}.
     * {@code name} is the subject used at the start of the error message.
     */
    private static void checkLength(Model m, int length, String name) throws IOException {
        if (m.getLength() != 0 && m.getLength() != length) {
            throw new IOException(name + " can only classify sequences of length " + m.getLength() + ".");
        }
    }

    /** Checks that both models accept sequences of {@code length} and use consistent alphabets. */
    private static void checkComparable(Model m1, Model m2, int length) throws IOException {
        checkLength(m1, length, "The model m1");
        checkLength(m2, length, "This model m2");
        if (!m1.getAlphabetContainer().checkConsistency(m2.getAlphabetContainer())) {
            throw new IOException("The models are training on different alphabets.");
        }
    }

    /**
     * Odometer-style enumerator over all symbol-index tuples of a fixed length for a discrete
     * {@link AlphabetContainer}. Position 0 varies fastest. Usage: read the first tuple via
     * {@link #getSequence()}, then loop while {@link #next()} returns {@code true}.
     */
    private static class SeqIterator {
        /** Current symbol indices in {@code [0, l)}; slot {@code l} becomes non-zero once all tuples are exhausted. */
        private final int[] seq;
        /** Whether every position shares one alphabet (then only {@code a[0]} is used). */
        private final boolean simple;
        /** Largest symbol index per position (a single entry when {@link #simple}). */
        private final int[] a;
        /** Length of the enumerated tuples. */
        private final int l;
        /** Index of the last position, {@code l - 1}. */
        private final int last;
        private final AlphabetContainer abc;

        /**
         * @param abc the alphabet container; must be discrete
         * @param length the tuple length; must be non-negative
         * @throws IllegalArgumentException if {@code abc} is not discrete or {@code length < 0}
         */
        public SeqIterator(AlphabetContainer abc, int length) throws IllegalArgumentException {
            if (!abc.isDiscrete()) {
                throw new IllegalArgumentException("The model is not discrete.");
            }
            if (length < 0) {
                throw new IllegalArgumentException("The length must not be negative.");
            }
            this.abc = abc;
            this.simple = abc.isSimple();
            int positions = this.simple ? 1 : length;
            this.a = new int[positions];
            for (int i = 0; i < positions; ++i) {
                this.a[i] = (int) abc.getAlphabetLengthAt(i) - 1;
            }
            this.l = length;
            this.last = this.l - 1;
            this.seq = new int[length + 1];
        }

        /**
         * Advances to the next tuple.
         *
         * @return {@code true} if a new tuple is available, {@code false} once all were visited
         */
        public boolean next() {
            int pos = 0;
            // Carry: reset every leading position already at its maximal symbol. Bounded by l so
            // single-symbol alphabets (maximum index 0) cannot run past the array end.
            while (pos < this.l && this.seq[pos] == this.a[this.simple ? 0 : pos]) {
                this.seq[pos++] = 0;
            }
            // Advance the first non-maximal position, or (pos == l) raise the overflow flag.
            this.seq[pos]++;
            return this.seq[this.l] == 0;
        }

        /**
         * Returns whether the current tuple matches {@code constr}: entries equal to {@code -1} match
         * anything, any other entry must equal the symbol index at that position.
         */
        private boolean isSatisfied(int[] constr) {
            for (int i = 0; i < constr.length; ++i) {
                if (constr[i] != -1 && constr[i] != this.seq[i]) {
                    return false;
                }
            }
            return true;
        }

        /** Wraps the current tuple (positions {@code 0..l-1}) as an {@link IntSequence}. */
        public Sequence getSequence() throws WrongAlphabetException, WrongSequenceTypeException {
            return new IntSequence(this.abc, this.seq, 0, this.l);
        }
    }
}

