/*
 * Decompiled with CFR 0.152.
 */
import de.jstacs.data.DNADataSet;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.SplitSequenceAnnotationParser;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.PFMComparator;
import de.jstacs.utils.REnvironment;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.LinkedList;

/**
 * Command-line tool that filters an annotated DNA data set by several
 * weight / p-value thresholds and, for every threshold combination, plots
 * (via R) a sequence logo, a heat map of pairwise mutual information, a heat
 * map of neighbouring-position dependencies and a line plot of the mutual
 * information between adjacent positions.
 *
 * <p>Usage: {@code PlotLogoAndMI <annotated-fasta> [rc]}. Each sequence must
 * carry annotations of type {@code weight} and {@code pval}. If the second
 * argument equals {@code rc}, the selected sequences are reverse-complemented
 * before any statistics are computed.</p>
 */
public class PlotLogoAndMI {

    /**
     * Entry point.
     *
     * @param args {@code args[0]} is the path of the input file (also used as
     *             prefix of the four output PDFs); optional {@code args[1]}
     *             equal to {@code "rc"} requests reverse complements
     * @throws IllegalArgumentException if no input file is given
     * @throws Exception if reading the data or talking to R fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            throw new IllegalArgumentException("Usage: PlotLogoAndMI <annotated-fasta> [rc]");
        }
        // The "rc" flag is optional; the original code crashed when it was omitted.
        final boolean reverseComplement = args.length > 1 && "rc".equals(args[1]);

        final double[] wt = new double[]{0.5, 0.7, 0.9, 0.95};
        final double[] pt = new double[]{0.001};

        // Load the data and parse the per-sequence annotations once; they do
        // not depend on the thresholds iterated below.
        DataSet all = new DNADataSet(args[0], '>', new SplitSequenceAnnotationParser(":", ";"));
        final int total = all.getNumberOfElements();
        double[] weights = new double[total];
        double[] pvals = new double[total];
        for (int i = 0; i < total; i++) {
            Sequence seq = all.getElementAt(i);
            weights[i] = Double.parseDouble(seq.getSequenceAnnotationByType("weight", 0).getIdentifier());
            pvals[i] = Double.parseDouble(seq.getSequenceAnnotationByType("pval", 0).getIdentifier());
        }

        REnvironment re = new REnvironment();
        try {
            re.voidEval("library(seqLogo);library(gplots);");
            StringBuffer pcLogo = new StringBuffer();
            StringBuffer pcDep = new StringBuffer();
            StringBuffer pcNn = new StringBuffer();
            StringBuffer pcNLogo = new StringBuffer();
            // The R matrix "mins" is created lazily by the first threshold
            // combination that actually selects sequences.
            boolean minsCreated = false;

            for (int s = 0; s < pt.length; s++) {
                for (int t = 0; t < wt.length; t++) {
                    // Unique index per (s, t) pair; equals s + t while pt has one entry.
                    final int idx = s * wt.length + t;

                    LinkedList<Sequence> seqs = new LinkedList<Sequence>();
                    DoubleList ws = new DoubleList();
                    for (int i = 0; i < total; i++) {
                        if (weights[i] >= wt[t] && pvals[i] <= pt[s]) {
                            seqs.add(all.getElementAt(i));
                            ws.add(weights[i]);
                        }
                    }
                    if (seqs.isEmpty()) {
                        continue;
                    }

                    Sequence[] selected = seqs.toArray(new Sequence[0]);
                    if (reverseComplement) {
                        for (int i = 0; i < selected.length; i++) {
                            selected[i] = selected[i].reverseComplement();
                        }
                    }
                    DataSet ds = new DataSet("", selected);
                    double[] w = ws.toArray();

                    // Weighted position frequency matrix, normalised per position.
                    double[][] pwm = PFMComparator.getPFM(ds, w);
                    for (int i = 0; i < pwm.length; i++) {
                        Normalisation.sumNormalisation(pwm[i]);
                    }
                    re.createMatrix("pwm" + idx, pwm);

                    // Pairwise mutual information and its first off-diagonal.
                    double[][] mis = PlotLogoAndMI.computeMIs(ds, w);
                    double[] nn = new double[mis.length - 1];
                    for (int i = 0; i < nn.length; i++) {
                        nn[i] = mis[i][i + 1];
                    }
                    re.createMatrix("mis" + idx, mis);
                    if (!minsCreated) {
                        re.voidEval("mins<-matrix(nrow=" + wt.length * pt.length + ",ncol=" + nn.length + ")");
                        minsCreated = true;
                    }
                    re.createVector("min" + idx, nn);
                    re.voidEval("mins[" + (idx + 1) + ",]<-min" + idx + ";");
                    pcLogo.append("seqLogo(t(pwm" + idx + "));");

                    double[][][] pwms = PlotLogoAndMI.getConditionalPWMs(ds, w);
                    IntList colsep = new IntList();
                    double[][] compressed = PlotLogoAndMI.buildCompressed(pwm, pwms, colsep);
                    re.createVector("colsep", colsep.toArray());
                    re.createMatrix("compr" + idx, compressed);
                    // NOTE(review): nothing written by buildCompressed equals 1.2, so this
                    // replacement looks like a leftover; kept to preserve the R script.
                    re.voidEval("compr" + idx + "[compr" + idx + "==1.2]<-NA;");
                    pcNLogo.append("heatmap.2(t(compr" + idx + "),zlim=c(0,1),colsep=colsep,sepcolor=1,key=T,trace=\"none\",Rowv=NULL,Colv=NULL,main=\"w>=" + wt[t] + ",p<=" + pt[s] + "\");");
                    pcDep.append("heatmap.2(mis" + idx + ",Rowv=NULL,Colv=NULL,key=T,trace=\"none\",main=\"w>=" + wt[t] + ",p<=" + pt[s] + "\");");
                }
            }

            if (!minsCreated) {
                // No combination selected anything: there are no plot commands
                // and no "mins" matrix in R, so plotting could only fail.
                System.err.println("No sequence passed any weight/p-value threshold; nothing to plot.");
                return;
            }

            pcNn.append("plot(mins[1,],t=\"l\",xlab=\"i: (i,i+1)\",ylab=\"MI\",ylim=c(0,max(mins,na.rm=T)));");
            pcNn.append("for(i in 2:nrow(mins)){lines(mins[i,],col=i);}");

            PlotLogoAndMI.plotBestEffort(re, pcLogo, args[0] + "_logos.pdf");
            PlotLogoAndMI.plotBestEffort(re, pcDep, args[0] + "_deps.pdf");
            PlotLogoAndMI.plotBestEffort(re, pcNLogo, args[0] + "_nlogos.pdf");
            PlotLogoAndMI.plotBestEffort(re, pcNn, args[0] + "_nndep.pdf");
        } finally {
            // Always release the R connection, also when an exception escapes.
            re.close();
        }
    }

    /**
     * Runs the given R plot commands into a PDF of size 8 x 5, overwriting an
     * existing file. Plotting is best effort: a failure is reported on
     * standard error and does not prevent the remaining plots.
     *
     * @param re           the R connection
     * @param plotCommands the R commands producing the plot(s)
     * @param fileName     the PDF file to write
     */
    private static void plotBestEffort(REnvironment re, StringBuffer plotCommands, String fileName) {
        try {
            re.plotToPDF(plotCommands, 8.0, 5.0, fileName, true);
        } catch (Exception e) {
            System.err.println("Could not create " + fileName + ": " + e);
        }
    }

    /**
     * Builds the matrix shown in the neighbour-dependency heat map.
     *
     * <p>Row 0 is {@code pwm[0]}. For every position {@code i >= 1} five rows
     * follow: first {@code pwm[i - 1]}, then one row per symbol {@code j}
     * whose entry {@code l} is {@code pwms[l][i][j] * pwm[i - 1][l]}. The index
     * of the first row of every such group is appended to {@code colsep}.</p>
     *
     * <p>NOTE(review): {@code pwm[0]} therefore appears in rows 0 and 1 and
     * the last position's {@code pwm} row is never emitted; this mirrors the
     * original layout and should be confirmed with the author.</p>
     *
     * @param pwm    normalised position frequency matrix, {@code [position][symbol]}
     * @param pwms   conditional distributions as returned by
     *               {@link #getConditionalPWMs(DataSet, double[])}
     * @param colsep receives the row indices at which a new position starts
     * @return the assembled matrix with {@code 1 + 5 * (pwm.length - 1)} rows
     */
    private static double[][] buildCompressed(double[][] pwm, double[][][] pwms, IntList colsep) {
        double[][] compressed = new double[1 + 5 * (pwm.length - 1)][4];
        compressed[0] = pwm[0];
        int row = 1;
        for (int pos = 1; pos < pwms[0].length; pos++) {
            colsep.add(row);
            compressed[row] = pwm[pos - 1];
            row++;
            for (int j = 0; j < pwms.length; j++) {
                for (int l = 0; l < 4; l++) {
                    compressed[row][l] = pwms[l][pos][j] * pwm[pos - 1][l];
                }
                row++;
            }
        }
        return compressed;
    }

    /**
     * Computes the symmetric matrix of weighted mutual information between
     * all pairs of positions. The diagonal is left at zero.
     *
     * @param ds the sequences
     * @param w  one weight per sequence of {@code ds}
     * @return a square matrix of size {@code ds.getElementLength()}
     */
    private static double[][] computeMIs(DataSet ds, double[] w) {
        final int length = ds.getElementLength();
        double[][] mis = new double[length][length];
        for (int i = 0; i < length; i++) {
            for (int j = 0; j < i; j++) {
                double mi = PlotLogoAndMI.computeMI(ds, w, i, j);
                mis[i][j] = mi;
                mis[j][i] = mi;
            }
        }
        return mis;
    }

    /**
     * Estimates, for every position, the weighted distribution of its symbol
     * conditional on the symbol at the preceding position.
     *
     * @param ds the sequences
     * @param w  one weight per sequence of {@code ds}
     * @return array indexed {@code [previous symbol][position][symbol]}; each
     *         innermost array is normalised to sum to one, and contexts
     *         without any observation (including all of position 0) are set
     *         to the uniform distribution
     */
    private static double[][][] getConditionalPWMs(DataSet ds, double[] w) {
        double[][][] pwms = new double[4][ds.getElementLength()][4];
        for (int i = 0; i < ds.getNumberOfElements(); i++) {
            Sequence seq = ds.getElementAt(i);
            for (int j = 0; j < seq.getLength() - 1; j++) {
                pwms[seq.discreteVal(j)][j + 1][seq.discreteVal(j + 1)] += w[i];
            }
        }
        for (int i = 0; i < pwms.length; i++) {
            for (int j = 0; j < pwms[i].length; j++) {
                Normalisation.sumNormalisation(pwms[i][j]);
                // Normalising an all-zero vector yields NaN; fall back to uniform.
                if (Double.isNaN(pwms[i][j][0])) {
                    Arrays.fill(pwms[i][j], 0.25);
                }
            }
        }
        return pwms;
    }

    /**
     * Computes the weighted mutual information (natural logarithm) between
     * the symbols at two positions.
     *
     * @param ds the sequences
     * @param w  one weight per sequence of {@code ds}
     * @param p1 first position
     * @param p2 second position
     * @return the mutual information; cells with zero joint weight contribute nothing
     */
    private static double computeMI(DataSet ds, double[] w, int p1, int p2) {
        // Weighted joint counts of (symbol at p1, symbol at p2).
        double[][] joint = new double[4][4];
        for (int i = 0; i < ds.getNumberOfElements(); i++) {
            Sequence seq = ds.getElementAt(i);
            joint[seq.discreteVal(p1)][seq.discreteVal(p2)] += w[i];
        }
        // Marginal counts for both positions.
        double[] marginal1 = new double[4];
        double[] marginal2 = new double[4];
        for (int a = 0; a < joint.length; a++) {
            for (int b = 0; b < joint[a].length; b++) {
                marginal1[a] += joint[a][b];
                marginal2[b] += joint[a][b];
            }
        }
        // Total weight must be taken before the marginals are normalised.
        double sum = ToolBox.sum(marginal1);
        Normalisation.sumNormalisation(marginal1);
        Normalisation.sumNormalisation(marginal2);
        double mi = 0.0;
        for (int a = 0; a < joint.length; a++) {
            for (int b = 0; b < joint[a].length; b++) {
                joint[a][b] /= sum;
                if (joint[a][b] > 0.0) {
                    mi += joint[a][b] * Math.log(joint[a][b] / (marginal1[a] * marginal2[b]));
                }
            }
        }
        return mi;
    }
}

