/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.clustering.distances;

import de.jstacs.clustering.hierachical.ClusterTree;
import de.jstacs.data.DeBruijnGraphSequenceGenerator;
import de.jstacs.data.alphabets.DNAAlphabet;
import de.jstacs.data.alphabets.DNAAlphabetContainer;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.CyclicSequenceAdaptor;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.sequenceScores.statisticalModels.StatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.PFMWrapperTrainSM;
import de.jstacs.utils.PFMComparator;
import de.jstacs.utils.Pair;

/**
 * Motif comparison based on score profiles over cyclic de Bruijn sequences.
 * <p>
 * A motif is turned into a profile by scoring every window (start positions {@code 0 .. N-1}) of a
 * cyclic sequence produced by {@link DeBruijnGraphSequenceGenerator}. Two motifs are compared by the
 * Pearson correlation of their profiles over a range of relative cyclic shifts. The same machinery is
 * used to align and average the motifs of a {@link ClusterTree} into a single representative PWM.
 */
public class DeBruijnMotifComparison {

    /**
     * Computes the sum and the sum of squares of the entries of {@code profile}.
     *
     * @param profile the profile
     * @return the two-element array {@code {sum, sumOfSquares}}
     */
    private static double[] getStatistics(double[] profile) {
        double sum = 0.0;
        double sq = 0.0;
        for (double v : profile) {
            sum += v;
            sq += v * v;
        }
        return new double[]{sum, sq};
    }

    /**
     * Computes the un-normalized cross product
     * {@code sum_i profile1[i] * profile2[(i + shift) mod profile2.length]}, i.e. {@code profile2}
     * is treated as cyclic.
     *
     * @param profile1 the first profile
     * @param profile2 the second (cyclic) profile
     * @param shift the non-negative offset applied to {@code profile2}
     * @return the cross product for the given shift
     */
    private static double getCross(double[] profile1, double[] profile2, int shift) {
        double cr = 0.0;
        for (int i = 0; i < profile1.length; ++i) {
            // modulo (instead of a single wrap) keeps shifts >= profile2.length in bounds
            cr += profile1[i] * profile2[(i + shift) % profile2.length];
        }
        return cr;
    }

    /**
     * Finds the relative cyclic shift of two profiles that maximizes their Pearson correlation.
     * <p>
     * For every {@code i} in {@code [0, maxShift]} two alignments are evaluated:
     * {@code profile1[k]} against {@code profile2[k + i]} (reported as offset {@code i}) and
     * {@code profile2[k]} against {@code profile1[k + i]} (reported as offset {@code -i}); indices are
     * taken cyclically. Both profiles are assumed to have the same length (the normalization uses
     * {@code profile1.length} for both).
     *
     * @param profile1 the first profile
     * @param profile2 the second profile
     * @param maxShift the largest absolute shift to evaluate
     * @return the best offset and its correlation. If no correlation is selected (e.g.
     *         {@code maxShift < 0}, or a zero variance making every correlation NaN, which never
     *         compares greater) the result is {@code (0, -Infinity)}
     */
    public static Pair<Integer, Double> compare(double[] profile1, double[] profile2, int maxShift) {
        double[] fullStat1 = getStatistics(profile1);
        double[] fullStat2 = getStatistics(profile2);
        final double invLength = 1.0 / (double) profile1.length;
        // product of the (scaled) standard deviations: sqrt(Sxx * Syy)
        final double fac = Math.sqrt((fullStat1[1] - invLength * fullStat1[0] * fullStat1[0])
                * (fullStat2[1] - invLength * fullStat2[0] * fullStat2[0]));
        // sum(x) * sum(y) / N, subtracted from the raw cross product to obtain the covariance term
        final double meanTerm = invLength * fullStat1[0] * fullStat2[0];

        double max = Double.NEGATIVE_INFINITY;
        int maxOff = 0;
        for (int i = 0; i <= maxShift; ++i) {
            double corPlus = (getCross(profile1, profile2, i) - meanTerm) / fac;
            if (corPlus > max) {
                max = corPlus;
                maxOff = i;
            }
            double corMinus = (getCross(profile2, profile1, i) - meanTerm) / fac;
            if (corMinus > max) {
                max = corMinus;
                maxOff = -i;
            }
        }
        return new Pair<Integer, Double>(maxOff, max);
    }

    /**
     * Generates the cyclic sequences for order {@code n} over the first alphabet of {@code model}
     * (which must be a {@link DiscreteAlphabet}) and computes the model's profiles on them.
     *
     * @param model the motif model
     * @param n the order passed to {@link DeBruijnGraphSequenceGenerator#generate}
     * @param revcom whether to score the reverse-complementary strand
     * @param exp whether to exponentiate the log-scores
     * @return one profile per generated sequence
     * @throws Exception if sequence generation or scoring fails
     */
    public static double[][] getProfilesForMotif(StatisticalModel model, int n, boolean revcom, boolean exp) throws Exception {
        CyclicSequenceAdaptor[] ad = DeBruijnGraphSequenceGenerator.generate((DiscreteAlphabet) model.getAlphabetContainer().getAlphabetAt(0), n);
        return getProfilesForMotif(ad, model, revcom, exp);
    }

    /**
     * Computes, for each cyclic sequence, the profile of {@code model}'s scores over all window start
     * positions {@code 0 .. N-1} (N = length of the cyclic sequence).
     * <p>
     * Entry {@code j} always refers to the forward-strand window starting at {@code j}. If
     * {@code revcom} is set, that window is scored on the reverse-complementary strand.
     *
     * @param ad the cyclic sequences to score
     * @param model the motif model
     * @param revcom whether to score the reverse-complementary strand
     * @param exp whether to exponentiate the values returned by {@code getLogProbFor}
     * @return one profile of length N per sequence
     * @throws Exception if reverse-complementing or scoring fails
     */
    public static double[][] getProfilesForMotif(CyclicSequenceAdaptor[] ad, StatisticalModel model, boolean revcom, boolean exp) throws Exception {
        final int motifLength = model.getLength();
        double[][] profiles = new double[ad.length][];
        for (int i = 0; i < ad.length; ++i) {
            Sequence seq = ad[i];
            if (revcom) {
                seq = ((CyclicSequenceAdaptor) seq).reverseComplement();
            }
            int origLength = ((CyclicSequenceAdaptor) seq).getLength();
            // Extended to origLength + motifLength - 1 so that every start in [0, origLength) has a full
            // window (presumably by continuing the cyclic sequence -- confirm in CyclicSequenceAdaptor).
            Sequence extended = ((CyclicSequenceAdaptor) seq).getSuperSequence(origLength + motifLength - 1);
            double[] profile = new double[origLength];
            for (int j = 0; j < origLength; ++j) {
                int start = j;
                if (revcom) {
                    // Forward window j..j+L-1 starts at origLength - j - L on the reverse strand (assuming
                    // reverseComplement() maps p to origLength-1-p). Windows crossing the sequence end are
                    // wrapped cyclically into [0, origLength). The former formula
                    // extended.getLength() - j - 1 placed them at 2L-3 .. L-1 instead.
                    start = ((origLength - j - motifLength) % origLength + origLength) % origLength;
                }
                double score = model.getLogProbFor(extended, start);
                profile[j] = exp ? Math.exp(score) : score;
            }
            profiles[i] = profile;
        }
        return profiles;
    }

    /**
     * Recursively builds a representative PWM for a cluster by aligning and averaging the
     * representatives of its sub-clusters.
     * <p>
     * Sub-cluster representatives are merged left to right. Each one is aligned to the current
     * representative on the forward or reverse-complementary strand, whichever correlates better, and
     * then averaged with weights equal to the number of cluster elements on each side.
     *
     * @param tree the cluster; leaves must hold {@link PFMWrapperTrainSM} models
     * @param n the order of the de Bruijn sequences used for alignment
     * @return the representative PWM and one row per cluster element: {@code {offset, strand}} with
     *         strand {@code 1} (forward) or {@code -1} (reverse complement). Row order presumably
     *         follows the sub-trees' element order -- confirm against {@code getClusterElements()}.
     *         Returns {@code null} if the tree consists of a single element that is not a
     *         {@link PFMWrapperTrainSM}.
     * @throws IllegalArgumentException if a single-element sub-cluster is not a {@link PFMWrapperTrainSM}
     * @throws Exception if profile computation fails
     */
    public static Pair<double[][], int[][]> getClusterRepresentative(ClusterTree<StatisticalModel> tree, int n) throws Exception {
        if (tree.getNumberOfElements() == 1) {
            StatisticalModel element = tree.getClusterElements()[0];
            if (element instanceof PFMWrapperTrainSM) {
                return new Pair<double[][], int[][]>(((PFMWrapperTrainSM) element).getPWM(), new int[][]{{0, 1}});
            }
            return null;
        }
        int[][] shifts = new int[tree.getNumberOfElements()][2];
        ClusterTree<StatisticalModel>[] subs = tree.getSubTrees();
        double[][][] reps = new double[subs.length][][];
        int[][][] prevShifts = new int[subs.length][][];
        for (int i = 0; i < subs.length; ++i) {
            Pair<double[][], int[][]> pair = getClusterRepresentative(subs[i], n);
            if (pair == null) {
                // a single-element sub-cluster without a PWM cannot be merged
                throw new IllegalArgumentException("Cannot build a representative for sub-cluster " + i
                        + ": its only element is not a PFMWrapperTrainSM");
            }
            reps[i] = pair.getFirstElement();
            prevShifts[i] = pair.getSecondElement();
        }

        // Start from the first sub-cluster; copy its member shifts normalized to a minimum offset of 0.
        // NOTE(review): these offsets are not adjusted later when another sub-cluster is aligned with a
        // negative shift (i.e. to the left of the current representative) -- confirm the convention.
        double[][] rep = reps[0];
        int minPrevShift = Integer.MAX_VALUE;
        for (int[] s : prevShifts[0]) {
            minPrevShift = Math.min(minPrevShift, s[0]);
        }
        int g = 0;
        for (int[] s : prevShifts[0]) {
            int[] copy = s.clone();
            copy[0] -= minPrevShift;
            shifts[g++] = copy;
        }

        double n1 = subs[0].getNumberOfElements();
        // generated once so that all compared profiles are computed on the same sequence(s)
        CyclicSequenceAdaptor[] deBruijn = null;
        for (int i = 1; i < reps.length; ++i) {
            double[][] other = reps[i];
            PFMWrapperTrainSM model = new PFMWrapperTrainSM(DNAAlphabetContainer.SINGLETON, null, rep, 0.0);
            if (deBruijn == null) {
                deBruijn = DeBruijnGraphSequenceGenerator.generate((DiscreteAlphabet) model.getAlphabetContainer().getAlphabetAt(0), n);
            }
            double[][] prof1 = getProfilesForMotif(deBruijn, model, false, false);
            PFMWrapperTrainSM model2 = new PFMWrapperTrainSM(DNAAlphabetContainer.SINGLETON, null, other, 0.0);
            double[][] prof2 = getProfilesForMotif(deBruijn, model2, false, false);
            double[][] prof2Rc = getProfilesForMotif(deBruijn, model2, true, false);

            // NOTE(review): the second term uses reps.length (the number of sub-clusters); the symmetric
            // counterpart of the first term would be rep.length. Kept unchanged because it defines the
            // search range of every alignment -- confirm which was intended.
            int maxShift = Math.max(rep.length - other.length, other.length - reps.length);
            Pair<Integer, Double> fwd = compare(prof1[0], prof2[0], maxShift);
            Pair<Integer, Double> rev = compare(prof1[0], prof2Rc[0], maxShift);

            int shift = fwd.getFirstElement();
            int rc = 1;
            double[][] mat = (double[][]) ArrayHandler.clone((Cloneable[]) other);
            if (fwd.getSecondElement() < rev.getSecondElement()) {
                shift = rev.getFirstElement();
                rc = -1;
                mat = PFMComparator.getReverseComplement(DNAAlphabet.SINGLETON, mat);
            }

            double n2 = subs[i].getNumberOfElements();
            rep = mergeAligned(rep, n1, mat, n2, shift);
            n1 += n2;

            // Translate the member shifts of sub-cluster i into the coordinates of the merged matrix.
            int[][] sub = prevShifts[i];
            StatisticalModel[] members = rc < 0 ? subs[i].getClusterElements() : null;
            minPrevShift = Integer.MAX_VALUE;
            for (int j = 0; j < sub.length; ++j) {
                if (rc < 0) {
                    // mirror the member's offset inside the reverse-complemented sub-representative
                    sub[j][0] = -(members[j].getLength() - mat.length) - sub[j][0];
                }
                minPrevShift = Math.min(minPrevShift, sub[j][0]);
            }
            for (int[] s : sub) {
                shifts[g][0] = s[0] - minPrevShift + shift;
                shifts[g][1] = s[1] * rc;
                ++g;
            }
        }
        return new Pair<double[][], int[][]>(rep, shifts);
    }

    /**
     * Averages two PWMs position-wise with weights {@code n1} and {@code n2}. If {@code shift} is
     * non-negative, {@code mat} starts {@code shift} positions after {@code rep}; otherwise {@code rep}
     * starts {@code -shift} positions after {@code mat}. Positions not covered by one of the matrices
     * contribute the uniform DNA value 0.25.
     *
     * @param rep the current representative PWM (rows = positions)
     * @param n1 the weight of {@code rep}
     * @param mat the PWM to merge in, already in the chosen orientation
     * @param n2 the weight of {@code mat}
     * @param shift the offset of {@code mat} relative to {@code rep}
     * @return the merged PWM, spanning both aligned matrices
     */
    private static double[][] mergeAligned(double[][] rep, double n1, double[][] mat, double n2, int shift) {
        int totL = shift >= 0 ? Math.max(rep.length, mat.length + shift) : Math.max(rep.length - shift, mat.length);
        double[][] com = new double[totL][rep[0].length];
        for (int j = 0; j < com.length; ++j) {
            for (int k = 0; k < com[j].length; ++k) {
                double a;
                double b;
                if (shift >= 0) {
                    a = j < rep.length ? rep[j][k] : 0.25;
                    b = j >= shift && j - shift < mat.length ? mat[j - shift][k] : 0.25;
                } else {
                    a = j >= -shift && j + shift < rep.length ? rep[j + shift][k] : 0.25;
                    b = j < mat.length ? mat[j][k] : 0.25;
                }
                com[j][k] = (a * n1 + b * n2) / (n1 + n2);
            }
        }
        return com;
    }
}

