/*
 * Decompiled with CFR 0.152.
 */
package projects.taleningner;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.IntSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.io.FileManager;
import de.jstacs.io.XMLParser;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.PFMComparator;
import java.util.HashSet;
import projects.tals.TALgetterDiffSM;

/**
 * Computes scalar features of RVD sequences, either directly on the RVD symbols or on the
 * nucleotide sequence obtained from a trained {@link TALgetterDiffSM} model via
 * {@code getBestPossibleScore}.
 *
 * <p>RVD symbols are grouped into three categories: "strong" ({@code HD}, {@code NN}),
 * "weak" ({@code NI}, {@code NK}, {@code NG}, {@code N*}, {@code SN}, {@code SH}) and
 * "intermediate" (every other symbol).
 */
public class MonomerScoring {
    /** Model file loaded by {@link #MonomerScoring(AlphabetContainer)}. */
    public static final String DEFAULT_MODEL_PATH =
            "/Users/dev/Desktop/TAL-Chips/Designer/talgetter_training/retrained.xml";

    /** Small offset added to counts and to several features in {@link #getValues(Sequence)}. */
    private static final double EPS = 1.0E-6;

    private final HashSet<String> strong;
    private final HashSet<String> weak;
    private final TALgetterDiffSM model;
    private final DiscreteAlphabet rvds;

    /**
     * Creates a scorer backed by the model stored at {@link #DEFAULT_MODEL_PATH}.
     *
     * @param rvds alphabet container whose first alphabet is the RVD alphabet
     * @throws Exception if the model file cannot be read or parsed
     */
    public MonomerScoring(AlphabetContainer rvds) throws Exception {
        this(rvds, DEFAULT_MODEL_PATH);
    }

    /**
     * Creates a scorer backed by the model stored in the given XML file.
     *
     * @param rvds alphabet container whose first alphabet is the RVD alphabet
     * @param modelPath path of an XML file containing the model under the {@code "model"} tag
     * @throws Exception if the model file cannot be read or parsed
     * @throws IllegalArgumentException if the model's RVD alphabet is inconsistent with {@code rvds}
     */
    public MonomerScoring(AlphabetContainer rvds, String modelPath) throws Exception {
        this.rvds = (DiscreteAlphabet) rvds.getAlphabetAt(0);
        this.strong = new HashSet<>();
        this.strong.add("HD");
        this.strong.add("NN");
        this.weak = new HashSet<>();
        this.weak.add("NI");
        this.weak.add("NK");
        this.weak.add("NG");
        this.weak.add("N*");
        this.weak.add("SN");
        this.weak.add("SH");
        this.model = XMLParser.extractObjectForTags(FileManager.readFile(modelPath), "model", TALgetterDiffSM.class);
        // Symbols are decoded with this.rvds, so the model must use a consistent RVD alphabet.
        this.check(this.model.getRVDAlphabet());
        // NOTE(review): presumably freezes the model's parameters for scoring — confirm.
        this.model.fix();
    }

    /** @return the model used for nucleotide-level features */
    public TALgetterDiffSM getModel() {
        return this.model;
    }

    /** @return the RVD alphabet (first alphabet of the container passed to the constructor) */
    public DiscreteAlphabet getRvds() {
        return this.rvds;
    }

    /**
     * Shannon entropy (natural logarithm) of the symbol frequencies in {@code seq}.
     *
     * @param seq RVD sequence
     * @param nucl if {@code true}, the entropy is computed on the nucleotide sequence derived from
     *     {@code seq} by the model instead of on the RVDs themselves
     * @return the entropy; {@code 0} for an empty sequence
     */
    public double getMonomerEntropy(Sequence seq, boolean nucl) throws WrongAlphabetException, WrongSequenceTypeException {
        Sequence work = nucl ? this.toBestNucleotides(seq) : seq;
        int length = work.getLength();
        double[] counts = new double[alphabetSize(work)];
        for (int i = 0; i < length; i++) {
            counts[work.discreteVal(i)] += 1.0;
        }
        double en = 0.0;
        for (double count : counts) {
            if (count > 0.0) {
                double p = count / (double) length;
                en -= p * Math.log(p);
            }
        }
        return en;
    }

    /**
     * Mutual information (natural logarithm) between adjacent symbols of {@code seq}.
     *
     * <p>The joint distribution is estimated from the {@code length - 1} adjacent pairs, while
     * both marginals are estimated from all {@code length} positions.
     *
     * @param seq RVD sequence
     * @param nucl if {@code true}, compute on the model-derived nucleotide sequence instead
     * @return the mutual information estimate
     */
    public double getMutualInformation(Sequence seq, boolean nucl) throws WrongAlphabetException, WrongSequenceTypeException {
        Sequence work = nucl ? this.toBestNucleotides(seq) : seq;
        int length = work.getLength();
        int size = alphabetSize(work);
        double[][] joint = new double[size][size];
        double[] marginal = new double[size];
        for (int i = 0; i < length; i++) {
            int current = work.discreteVal(i);
            if (i < length - 1) {
                joint[current][work.discreteVal(i + 1)] += 1.0;
            }
            marginal[current] += 1.0;
        }
        double pairs = (double) (length - 1);
        double en = 0.0;
        for (int a = 0; a < size; a++) {
            for (int b = 0; b < size; b++) {
                if (joint[a][b] > 0.0) {
                    double pab = joint[a][b] / pairs;
                    en += pab * Math.log(pab / (marginal[a] / (double) length) / (marginal[b] / (double) length));
                }
            }
        }
        return en;
    }

    /**
     * Shannon entropy (natural logarithm) of the distribution of adjacent symbol pairs.
     *
     * @param seq RVD sequence
     * @param nucl if {@code true}, compute on the model-derived nucleotide sequence instead
     * @return the dimer entropy; {@code 0} for sequences with fewer than two symbols
     */
    public double getDimerEntropy(Sequence seq, boolean nucl) throws WrongAlphabetException, WrongSequenceTypeException {
        Sequence work = nucl ? this.toBestNucleotides(seq) : seq;
        int length = work.getLength();
        int size = alphabetSize(work);
        double[][] counts = new double[size][size];
        for (int i = 0; i < length - 1; i++) {
            counts[work.discreteVal(i)][work.discreteVal(i + 1)] += 1.0;
        }
        double pairs = (double) (length - 1);
        double en = 0.0;
        for (double[] row : counts) {
            for (double count : row) {
                if (count > 0.0) {
                    double p = count / pairs;
                    en -= p * Math.log(p);
                }
            }
        }
        return en;
    }

    /**
     * Builds the feature vector for an RVD sequence. Entries, in order:
     * <ol start="0">
     *   <li>length + eps</li>
     *   <li>(number of strong RVDs + eps) / (length + 3 eps)</li>
     *   <li>(number of weak RVDs + eps) / (length + 3 eps)</li>
     *   <li>distance of the first strong RVD from the start + eps</li>
     *   <li>distance of the last strong RVD from the end + eps</li>
     *   <li>longest stretch of weak RVDs + eps</li>
     *   <li>entries 0..2 of {@link #getNucleotideContent(Sequence)} (three entries)</li>
     *   <li>{@link #getMixedness(Sequence)} + eps</li>
     *   <li>total best score / length</li>
     *   <li>entries 1 and 2 of {@link #getBestScores(Sequence)}</li>
     * </ol>
     *
     * @param seq RVD sequence
     * @return a 13-element feature vector
     * @throws Exception if the model-derived nucleotide content cannot be computed
     * @throws IllegalArgumentException if {@code seq} does not use the RVD alphabet
     */
    public double[] getValues(Sequence seq) throws Exception {
        this.check(seq.getAlphabetContainer());
        double length = seq.getLength();
        double[] bestScores = this.getBestScores(seq);
        double[] content = this.getNucleotideContent(seq);
        return new double[] {
            length + EPS,
            ((double) this.getNumberOfStrongRVDs(seq) + EPS) / (length + 3.0 * EPS),
            ((double) this.getNumberOfWeakRVDs(seq) + EPS) / (length + 3.0 * EPS),
            this.getDistanceOfStrongFromStart(seq) + EPS,
            this.getDistanceOfStrongFromEnd(seq) + EPS,
            this.getMaximumWeakStretchLength(seq) + EPS,
            content[0],
            content[1],
            content[2],
            this.getMixedness(seq) + EPS,
            bestScores[0] / length,
            bestScores[1],
            bestScores[2]
        };
    }

    /**
     * Verifies that {@code totest} consists of exactly one alphabet consistent with the RVD alphabet.
     *
     * @throws IllegalArgumentException (message {@code "Wrong alphabet"}) otherwise
     */
    private void check(AlphabetContainer totest) {
        if (totest.getNumberOfAlphabets() != 1 || !totest.getAlphabetAt(0).checkConsistency(this.rvds)) {
            throw new IllegalArgumentException("Wrong alphabet");
        }
    }

    /** @return number of positions of {@code seq} holding a strong RVD */
    public int getNumberOfStrongRVDs(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        return this.countIn(seq, this.strong);
    }

    /** @return number of positions of {@code seq} holding an RVD that is neither strong nor weak */
    public int getNumberOfIntermediateRVDs(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        int num = 0;
        for (int i = 0; i < seq.getLength(); i++) {
            if (!this.isIn(this.strong, seq, i) && !this.isIn(this.weak, seq, i)) {
                ++num;
            }
        }
        return num;
    }

    /** @return number of positions of {@code seq} holding a weak RVD */
    public int getNumberOfWeakRVDs(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        return this.countIn(seq, this.weak);
    }

    /**
     * Average length of the non-strong gaps preceding each strong RVD. A non-empty trailing gap
     * after the last strong RVD is included as well.
     *
     * @return the average gap; {@code NaN} for an empty sequence
     */
    public double getAverageDistanceBetweenStrong(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        return this.averageGap(seq, this.strong);
    }

    /**
     * Average length of the non-weak gaps preceding each weak RVD. A non-empty trailing gap after
     * the last weak RVD is included as well.
     *
     * @return the average gap; {@code NaN} for an empty sequence
     */
    public double getAverageDistanceBetweenWeak(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        return this.averageGap(seq, this.weak);
    }

    /**
     * @return index of the first strong RVD, or the sequence length if there is none
     * @throws IllegalArgumentException if {@code seq} does not use the RVD alphabet
     */
    public double getDistanceOfStrongFromStart(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        int length = seq.getLength();
        for (int i = 0; i < length; i++) {
            if (this.isIn(this.strong, seq, i)) {
                return i;
            }
        }
        return length;
    }

    /**
     * @return number of positions after the last strong RVD, or the sequence length if there is none
     * @throws IllegalArgumentException if {@code seq} does not use the RVD alphabet
     */
    public double getDistanceOfStrongFromEnd(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        int length = seq.getLength();
        for (int offset = 0; offset < length; offset++) {
            if (this.isIn(this.strong, seq, length - offset - 1)) {
                return offset;
            }
        }
        return length;
    }

    /** @return length of the longest run of consecutive non-strong RVDs (leading/trailing runs included) */
    public double getMaximumDistanceBetweenStrong(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        return this.longestRun(seq, this.strong, false);
    }

    /**
     * Average weak-stretch length, where every non-weak position closes a (possibly empty) stretch
     * and a non-empty trailing stretch is counted as well.
     *
     * <p>NOTE(review): zero-length stretches between adjacent non-weak RVDs enter the average;
     * confirm this is intended.
     *
     * @return the average; {@code NaN} for an empty sequence
     */
    public double getAverageWeakStretchLength(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        double sum = 0.0;
        double stretches = 0.0;
        int run = 0;
        for (int i = 0; i < seq.getLength(); i++) {
            if (this.isIn(this.weak, seq, i)) {
                ++run;
            } else {
                sum += (double) run;
                stretches += 1.0;
                run = 0;
            }
        }
        if (run != 0) {
            sum += (double) run;
            stretches += 1.0;
        }
        return sum / stretches;
    }

    /** @return length of the longest run of consecutive weak RVDs */
    public double getMaximumWeakStretchLength(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        return this.longestRun(seq, this.weak, true);
    }

    /** @return length of the longest run of consecutive strong RVDs */
    public double getMaximumStrongStretchLength(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        return this.longestRun(seq, this.strong, true);
    }

    /**
     * Average length of the runs of identical consecutive symbols, i.e. length / number of runs.
     *
     * @return the mean run length; {@code 1.0} for an empty sequence (as before)
     */
    public double getMixedness(Sequence seq) {
        int length = seq.getLength();
        if (length == 0) {
            return 1.0;
        }
        int runs = 1;
        for (int i = 1; i < length; i++) {
            if (seq.discreteVal(i) != seq.discreteVal(i - 1)) {
                ++runs;
            }
        }
        return (double) length / (double) runs;
    }

    /**
     * For every position, the distance to the next occurrence of the same symbol (or the sequence
     * length if the symbol does not recur), averaged over positions and divided by the length.
     *
     * @return the normalised mean recurrence distance; {@code NaN} for an empty sequence
     */
    public double getMixedness2(Sequence seq) {
        int length = seq.getLength();
        double sum = 0.0;
        double n = 0.0;
        for (int i = 0; i < length; i++) {
            int distance = length;
            for (int j = i + 1; j < length; j++) {
                if (seq.discreteVal(i) == seq.discreteVal(j)) {
                    distance = j - i;
                    break;
                }
            }
            sum += (double) distance;
            n += 1.0;
        }
        return sum / n / (double) length;
    }

    /** @return length of the longest run of identical consecutive symbols; {@code 0} for an empty sequence */
    public double getLongestStretch(Sequence seq) {
        int length = seq.getLength();
        if (length == 0) {
            return 0.0;
        }
        int longest = 1;
        int run = 1;
        for (int i = 1; i < length; i++) {
            run = seq.discreteVal(i) == seq.discreteVal(i - 1) ? run + 1 : 1;
            // Updating on every step also accounts for the final run, which the loop never closes.
            longest = Math.max(longest, run);
        }
        return longest;
    }

    /**
     * Partial log-scores of the model's best path for {@code seq}.
     *
     * <p>NOTE(review): the meaning of the {@code getPartialLogScoreFor} offset/length arguments is
     * defined by {@link TALgetterDiffSM}; the labels below follow the local variable names.
     *
     * @return {@code {total, first, last, first10}}
     */
    public double[] getBestScores(Sequence seq) {
        this.check(seq.getAlphabetContainer());
        double[] scs = new double[seq.getLength() + 1];
        int[] bs = new int[seq.getLength() + 1];
        // Called for its effect on scs/bs (presumably filled with the best path); the return value
        // was never used.
        this.model.getBestPossibleScore(seq, scs, bs);
        double total = this.model.getPartialLogScoreFor(seq, bs, 0, 0, bs.length);
        double first = this.model.getPartialLogScoreFor(seq, bs, 0, 1, 5);
        double first10 = this.model.getPartialLogScoreFor(seq, bs, 0, 1, 10);
        double last = this.model.getPartialLogScoreFor(seq, bs, 0, bs.length - 5, 5);
        return new double[] {total, first, last, first10};
    }

    /**
     * Normalised nucleotide composition of the model-derived nucleotide sequence, with a
     * pseudocount of {@code 1e-6} added to every count before normalisation.
     *
     * @param seq RVD sequence
     * @return symbol frequencies summing to 1 (order as returned by {@code PFMComparator.getCounts})
     * @throws Exception if the nucleotide sequence or data set cannot be built
     */
    public double[] getNucleotideContent(Sequence seq) throws Exception {
        double[] c = PFMComparator.getCounts(new DataSet("", this.toBestNucleotides(seq)));
        for (int i = 0; i < c.length; i++) {
            c[i] += EPS;
        }
        Normalisation.sumNormalisation(c);
        return c;
    }

    // Converts an RVD sequence into the nucleotide sequence read from the model's best path.
    private Sequence toBestNucleotides(Sequence seq) throws WrongAlphabetException, WrongSequenceTypeException {
        this.check(seq.getAlphabetContainer());
        double[] scs = new double[seq.getLength() + 1];
        int[] bs = new int[seq.getLength() + 1];
        this.model.getBestPossibleScore(seq, scs, bs);
        IntSequence temp = new IntSequence((AlphabetContainer) DNAAlphabetContainer.SINGLETON, bs);
        // bs has one more entry than seq has RVDs; index 0 is dropped
        // (NOTE(review): presumably the position preceding the first RVD — confirm).
        return temp.getSubSequence(1);
    }

    // Size of the first alphabet of seq's container.
    private static int alphabetSize(Sequence seq) {
        return (int) seq.getAlphabetContainer().getAlphabetLengthAt(0);
    }

    // True if the RVD symbol at position pos of seq belongs to category.
    private boolean isIn(HashSet<String> category, Sequence seq, int pos) {
        return category.contains(this.rvds.getSymbolAt(seq.discreteVal(pos)));
    }

    // Number of positions of seq whose RVD belongs to category.
    private int countIn(Sequence seq, HashSet<String> category) {
        int num = 0;
        for (int i = 0; i < seq.getLength(); i++) {
            if (this.isIn(category, seq, i)) {
                ++num;
            }
        }
        return num;
    }

    // Longest run of consecutive positions whose membership in category equals member.
    private double longestRun(Sequence seq, HashSet<String> category, boolean member) {
        int longest = 0;
        int run = 0;
        for (int i = 0; i < seq.getLength(); i++) {
            if (this.isIn(category, seq, i) == member) {
                ++run;
                longest = Math.max(longest, run);
            } else {
                run = 0;
            }
        }
        return longest;
    }

    // Mean length of the gaps before each member of category, plus a non-empty trailing gap.
    private double averageGap(Sequence seq, HashSet<String> category) {
        double sum = 0.0;
        double gaps = 0.0;
        int gap = 0;
        for (int i = 0; i < seq.getLength(); i++) {
            if (this.isIn(category, seq, i)) {
                sum += (double) gap;
                gap = 0;
                gaps += 1.0;
            } else {
                ++gap;
            }
        }
        if (gap != 0) {
            sum += (double) gap;
            gaps += 1.0;
        }
        return sum / gaps;
    }
}

