/*
 * Decompiled with CFR 0.152.
 */
package projects.tals.rnaseq;

import de.jstacs.DataType;
import de.jstacs.io.FileManager;
import de.jstacs.parameters.EnumParameter;
import de.jstacs.parameters.ExpandableParameterSet;
import de.jstacs.parameters.FileParameter;
import de.jstacs.parameters.Parameter;
import de.jstacs.parameters.ParameterException;
import de.jstacs.parameters.ParameterSet;
import de.jstacs.parameters.ParameterSetContainer;
import de.jstacs.parameters.SimpleParameter;
import de.jstacs.parameters.SimpleParameterSet;
import de.jstacs.results.Result;
import de.jstacs.results.ResultSet;
import de.jstacs.results.TextResult;
import de.jstacs.tools.JstacsTool;
import de.jstacs.tools.ProgressUpdater;
import de.jstacs.tools.Protocol;
import de.jstacs.tools.ToolParameterSet;
import de.jstacs.tools.ToolResult;
import de.jstacs.tools.ui.cli.CLI;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import htsjdk.samtools.AlignmentBlock;
import htsjdk.samtools.BAMIndex;
import htsjdk.samtools.BAMIndexMetaData;
import htsjdk.samtools.SAMRecord;
import htsjdk.samtools.SAMRecordIterator;
import htsjdk.samtools.SamInputResource;
import htsjdk.samtools.SamReader;
import htsjdk.samtools.SamReaderFactory;
import htsjdk.samtools.ValidationStringency;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.LinkedList;
import java.util.List;

/**
 * Jstacs tool that filters genome-wide target-site predictions by differential RNA-seq coverage.
 *
 * <p>For each of the top predictions, the per-base read coverage in a region of
 * {@code +/- regionWidth} bases around the predicted position is collected from every treatment
 * and control BAM file. Coverage is normalized by the number of aligned records listed in the BAM
 * index (scaled to one million) and smoothed with a sliding window. Predictions whose
 * treatment/control log ratio passes the threshold and the stretch criteria of
 * {@link Entry#getProfileResult} are reported together with their coverage profiles.
 */
public class DerTALE
implements JstacsTool {

    /** Reads with a mapping quality at or below this value are ignored. */
    private static final int MAPPING_QUALITY_THRESHOLD = 20;

    /** Coverage is normalized to this many aligned records (i.e. "per million"). */
    private static final double PER_MILLION = 1000000.0;

    public static void main(String[] args) throws Exception {
        CLI cli = new CLI(new DerTALE());
        cli.run(args);
    }

    /**
     * Creates the tool parameters.
     *
     * <p>The order of the parameters defines the indices read in {@link #run}.
     *
     * @throws IllegalStateException if a parameter cannot be created; continuing would silently
     *         drop parameters and shift the indices used by {@link #run}
     */
    @Override
    public ToolParameterSet getToolParameters() {
        LinkedList<Parameter> pars = new LinkedList<Parameter>();
        try {
            pars.add(new FileParameter("Predictions", "Predictions output file", "tsv,tabular", true));
            pars.add(new ParameterSetContainer("Treatment", "", new ExpandableParameterSet(new SimpleParameterSet(new FileParameter("Treatment BAM", "BAM file of mapped reads from treatment experiment. BAM file must have an index with additional extension .bai.", "bam", true)), "Treatment data", "")));
            pars.add(new ParameterSetContainer("Control", "", new ExpandableParameterSet(new SimpleParameterSet(new FileParameter("Control BAM", "BAM file of mapped reads from control experiment. BAM file must have an index with additional extension .bai.", "bam", true)), "Control data", "")));
            pars.add(new SimpleParameter(DataType.INT, "Number of predictions", "Number of (top) predictions considered", true, 100));
            pars.add(new SimpleParameter(DataType.INT, "Region width", "Number of bases around the predicted site", true, 3000));
            pars.add(new SimpleParameter(DataType.INT, "Window width", "Width of the window considered for differential abundance", true, 300));
            pars.add(new SimpleParameter(DataType.DOUBLE, "Pseudo count", "Pseudo count on the count profile", true, 1.0));
            pars.add(new EnumParameter(Compare.class, "Measure for comparing replicates", true, "MEAN"));
            pars.add(new SimpleParameter(DataType.DOUBLE, "Threshold", "Threshold on the log differential abundance", true, 1.0));
        }
        catch (CloneNotSupportedException e) {
            throw new IllegalStateException("Could not create the parameters of " + this.getToolName() + ".", e);
        }
        catch (ParameterException e) {
            throw new IllegalStateException("Could not create the parameters of " + this.getToolName() + ".", e);
        }
        return new ToolParameterSet(this.getShortName(), pars.toArray(new Parameter[0]));
    }

    /**
     * Runs the differential-abundance filter.
     *
     * <p>Parameters (in the order created by {@link #getToolParameters()}):
     * prediction file, treatment BAMs, control BAMs, number of predictions, region width,
     * window width, pseudo count, replicate comparison measure, threshold.
     *
     * @return a tool result whose first text result lists the reported predictions and whose
     *         further text results hold one coverage profile per reported prediction
     * @throws IllegalArgumentException if the region width is negative or the window width is not
     *         between 1 and {@code 2 * regionWidth + 1}
     * @throws RuntimeException if the prediction file format is not recognized, a BAM index is
     *         missing, or a BAM file has no aligned records
     */
    @Override
    public ToolResult run(ToolParameterSet parameters, Protocol protocol, ProgressUpdater progress, int threads) throws Exception {
        FileParameter.FileRepresentation predsFile = ((FileParameter)parameters.getParameterAt(0)).getFileContents();
        String[] posFiles = this.getFiles((ExpandableParameterSet)parameters.getParameterAt(1).getValue());
        String[] negFiles = this.getFiles((ExpandableParameterSet)parameters.getParameterAt(2).getValue());
        // treatment files first, then control files; this order defines the profile columns
        String[] files = new String[posFiles.length + negFiles.length];
        System.arraycopy(posFiles, 0, files, 0, posFiles.length);
        System.arraycopy(negFiles, 0, files, posFiles.length, negFiles.length);
        int numTop = (Integer)parameters.getParameterAt(3).getValue();
        int regionWidth = (Integer)parameters.getParameterAt(4).getValue();
        int windowWidth = (Integer)parameters.getParameterAt(5).getValue();
        double pseudoCount = (Double)parameters.getParameterAt(6).getValue();
        Compare compare = (Compare)((EnumParameter)parameters.getParameterAt(7)).getValue();
        double t = (Double)parameters.getParameterAt(8).getValue();

        if (regionWidth < 0) {
            throw new IllegalArgumentException("Region width must not be negative, but was " + regionWidth + ".");
        }
        // the smoothed profile has (2 * regionWidth + 1) - windowWidth entries, which must not be negative
        if (windowWidth < 1 || windowWidth > 2 * regionWidth + 1) {
            throw new IllegalArgumentException("Window width must be between 1 and " + (2 * regionWidth + 1) + ", but was " + windowWidth + ".");
        }

        LinkedList<Entry> predsList = readPredictions(predsFile.getContent(), numTop);

        SamReaderFactory srf = SamReaderFactory.makeDefault();
        srf.validationStringency(ValidationStringency.SILENT);
        for (int k = 0; k < files.length; k++) {
            addCoverageProfiles(srf, files[k], predsList, regionWidth, pseudoCount, k < posFiles.length);
        }

        LinkedList<TextResult> ress = new LinkedList<TextResult>();
        StringBuilder sb = new StringBuilder();
        sb.append("#" + Entry.getHeader() + "\tlog fold-change\tcenter-max\n");
        for (Entry en : predsList) {
            Pair<double[], String> pair = en.getProfileResult(files, windowWidth, t, compare);
            if (pair.getSecondElement() == null) {
                // prediction did not pass the filter
                continue;
            }
            double[] ri = pair.getFirstElement();
            sb.append(en + "\t" + ri[0] + "\t" + (int)ri[1] + "\n");
            ress.add(new TextResult("Profile for " + en.chr + ":" + en.pos + ":" + en.strand, "", new FileParameter.FileRepresentation("", pair.getSecondElement()), "tsv", this.getToolName(), null, true));
        }
        TextResult tr = new TextResult("Differentially abundant", "", new FileParameter.FileRepresentation("", sb.toString()), "tsv", this.getToolName(), null, true);
        ress.addFirst(tr);
        return new ToolResult("Result of " + this.getToolName(), this.getToolName(), null, new ResultSet(new Result[][]{ress.toArray(new Result[0])}), parameters, this.getToolName(), new Date());
    }

    /**
     * Reads at most {@code numTop} predictions from the content of a prediction file.
     *
     * <p>The format is recognized from the first line. It must start with
     * {@code options_used:}, {@code >} or {@code #}. Empty lines are skipped.
     *
     * @param content the complete content of the prediction file
     * @param numTop the maximum number of predictions to read, taken from the top of the file
     * @return the predictions in file order; empty if the file is empty or {@code numTop <= 0}
     * @throws IOException as declared by {@link BufferedReader#readLine()}
     * @throws RuntimeException if the first line matches none of the known formats
     */
    private static LinkedList<Entry> readPredictions(String content, int numTop) throws IOException {
        LinkedList<Entry> preds = new LinkedList<Entry>();
        BufferedReader br = new BufferedReader(new StringReader(content));
        try {
            String line = br.readLine();
            if (line == null) {
                return preds;
            }
            if (line.startsWith("options_used:")) {
                // whitespace-separated: column 1 sequence name, column 2 strand (Plus/Minus), column 4 position
                while (preds.size() < numTop && (line = br.readLine()) != null) {
                    if (isBlank(line) || line.startsWith("Best") || line.startsWith("Sequence")) {
                        continue;
                    }
                    String[] parts = line.replaceAll("\\s+", " ").split(" ");
                    String strand = parts[1];
                    if (strand.equals("Plus")) {
                        strand = "+";
                    } else if (strand.equals("Minus")) {
                        strand = "-";
                    }
                    preds.add(new Entry(parts[0], Integer.parseInt(parts[3]), strand));
                }
            } else if (line.startsWith(">")) {
                // every line, including the first, is a prediction; column 2 holds the sequence name,
                // suffixed with "_revcom" for the minus strand, column 4 the position
                while (line != null && preds.size() < numTop) {
                    if (!isBlank(line)) {
                        String[] parts = line.replaceAll("\\s+", " ").split(" ");
                        String strand = parts[1].contains("_revcom") ? "-" : "+";
                        preds.add(new Entry(parts[1].replace("_revcom", ""), Integer.parseInt(parts[3]), strand));
                    }
                    line = br.readLine();
                }
            } else if (line.startsWith("#")) {
                // tab-separated: sequence name, position, strand; lines starting with '#' are skipped
                while (preds.size() < numTop && (line = br.readLine()) != null) {
                    if (isBlank(line) || line.startsWith("#")) {
                        continue;
                    }
                    String[] parts = line.split("\t");
                    preds.add(new Entry(parts[0], Integer.parseInt(parts[1]), parts[2]));
                }
            } else {
                throw new RuntimeException("File format of the prediction file does not look like the output of PrediTALE, Talvez or Target Finder!");
            }
        }
        finally {
            br.close();
        }
        return preds;
    }

    /** Returns whether {@code line} contains only whitespace. */
    private static boolean isBlank(String line) {
        return line.trim().isEmpty();
    }

    /**
     * Adds the normalized coverage profile of one BAM file to every prediction.
     *
     * <p>The BAM file is opened once; each prediction is queried with its own iterator, which is
     * closed before the next query.
     *
     * @param srf the factory used to open the BAM file
     * @param fName path of the BAM file; its index must be at {@code fName + ".bai"}
     * @param preds the predictions whose profiles are extended
     * @param regionWidth number of bases considered on either side of each predicted position
     * @param pseudoCount pseudo count added to every position, scaled like a single read
     * @param treatment {@code true} to add the profiles as treatment, {@code false} as control
     * @throws IOException if the BAM file cannot be closed
     * @throws RuntimeException if the index file is missing or the index lists no aligned records
     */
    private static void addCoverageProfiles(SamReaderFactory srf, String fName, List<Entry> preds, int regionWidth, double pseudoCount, boolean treatment) throws IOException {
        File indexFile = new File(fName + ".bai");
        if (!indexFile.exists()) {
            throw new RuntimeException("No index found for file " + fName + ". The index must be in the same directory as the specified BAM file with filename " + fName + ".bai.");
        }
        SamReader sr = srf.open(SamInputResource.of(new File(fName)).index(indexFile));
        try {
            double count = totalAlignedRecords(sr);
            if (!(count > 0.0)) {
                throw new RuntimeException("No aligned records found in the index of BAM file " + fName + ".");
            }
            double scale = PER_MILLION / count;
            for (Entry en : preds) {
                double[] counts = coverageProfile(sr, en, regionWidth, scale, pseudoCount);
                if (treatment) {
                    en.addPositives(counts);
                } else {
                    en.addNegatives(counts);
                }
            }
        }
        finally {
            sr.close();
        }
    }

    /** Sums the aligned-record counts of all reference sequences, as stored in the BAM index metadata. */
    private static double totalAlignedRecords(SamReader sr) {
        BAMIndex index = sr.indexing().getIndex();
        int numSequences = sr.getFileHeader().getSequenceDictionary().size();
        double count = 0.0;
        for (int i = 0; i < numSequences; i++) {
            BAMIndexMetaData meta = index.getMetaData(i);
            count += (double)meta.getAlignedRecordCount();
        }
        return count;
    }

    /**
     * Computes the normalized per-base coverage around one prediction.
     *
     * <p>Index {@code i} of the result corresponds to reference position
     * {@code en.getPos() - regionWidth + i}. Every position starts at {@code scale * pseudoCount}.
     * Each aligned base of a read with mapping quality above {@link #MAPPING_QUALITY_THRESHOLD}
     * adds {@code scale}.
     *
     * @return an array of length {@code 2 * regionWidth + 1}
     */
    private static double[] coverageProfile(SamReader sr, Entry en, int regionWidth, double scale, double pseudoCount) {
        int start = en.getPos() - regionWidth;
        int end = en.getPos() + regionWidth;
        double[] counts = new double[2 * regionWidth + 1];
        Arrays.fill(counts, scale * pseudoCount);
        // htsjdk query coordinates are 1-based and inclusive; clamp the start for predictions
        // close to the sequence start (the array origin stays at 'start')
        SAMRecordIterator sri = sr.query(en.getChr(), Math.max(1, start), end, true);
        try {
            while (sri.hasNext()) {
                SAMRecord rec = sri.next();
                if (rec.getMappingQuality() <= MAPPING_QUALITY_THRESHOLD) {
                    continue;
                }
                List<AlignmentBlock> blocks = rec.getAlignmentBlocks();
                for (AlignmentBlock ab : blocks) {
                    int blockStart = ab.getReferenceStart();
                    int from = Math.max(start, blockStart) - start;
                    // 'end' is inclusive, so the last array slot (index end - start) is filled as well
                    int to = Math.min(end + 1, blockStart + ab.getLength()) - start;
                    for (int i = from; i < to; i++) {
                        counts[i] += scale;
                    }
                }
            }
        }
        finally {
            // htsjdk allows only one open iterator per reader, and the reader is reused
            sri.close();
        }
        return counts;
    }

    /** Collects the file names entered in the expandable parameter set, in entry order. */
    private String[] getFiles(ExpandableParameterSet value) {
        String[] vals = new String[value.getNumberOfParameters()];
        for (int i = 0; i < vals.length; i++) {
            vals[i] = ((ParameterSet)value.getParameterAt(i).getValue()).getParameterAt(0).getValue().toString();
        }
        return vals;
    }

    @Override
    public String getToolName() {
        return "DerTALE";
    }

    @Override
    public String getToolVersion() {
        return "0.1";
    }

    @Override
    public String getShortName() {
        return "dertale";
    }

    @Override
    public String getDescription() {
        return "filters genome-wide predictions for differential expression";
    }

    /**
     * Reads the help text from the classpath resource {@code projects/tals/rnaseq/DerTALE.txt}.
     *
     * @return the help text, or an empty string if the resource is missing or cannot be read
     */
    @Override
    public String getHelpText() {
        InputStream in = DerTALE.class.getClassLoader().getResourceAsStream("projects/tals/rnaseq/DerTALE.txt");
        if (in == null) {
            // resource not bundled: fall back to an empty help text, as on read errors
            return "";
        }
        try {
            return FileManager.readInputStream(in).toString();
        }
        catch (IOException e) {
            e.printStackTrace();
            return "";
        }
        finally {
            try {
                in.close();
            }
            catch (IOException ignored) {
                // closing a classpath resource stream is best effort
            }
        }
    }

    /** No default result infos are provided. */
    @Override
    public JstacsTool.ResultEntry[] getDefaultResultInfos() {
        return null;
    }

    /** No test cases are provided. */
    @Override
    public ToolResult[] getTestCases(String path) {
        return null;
    }

    /** The tool keeps no state between runs, so there is nothing to clear. */
    @Override
    public void clear() {
    }

    @Override
    public String[] getReferences() {
        return new String[]{"@article{erkes19preditale,\n\ttitle = {{PrediTALE}: A novel model learned from quantitative data allows for new perspectives on {TALE} targeting},\n\tauthor = {Erkes, Annett AND M\\\"ucke, Stefanie AND Reschke, Maik AND Boch, Jens AND Grau, Jan},\n\tjournal = {PLOS Computational Biology},\n\tyear = {2019},\n\tvolume = {15},\n\tnumber = {7},\n\tpages = {1-31},\n\tdoi = {10.1371/journal.pcbi.1007206}\n\t}\n"};
    }

    /** How the smoothed coverage of replicates is summarized before computing the log ratio. */
    private enum Compare {
        /** Lowest treatment replicate versus highest control replicate. */
        EXTREMES,
        /** Median of the treatment replicates versus median of the control replicates. */
        MEDIAN,
        /** Mean of the treatment replicates versus mean of the control replicates. */
        MEAN;

    }

    /**
     * A single prediction (sequence name, position, strand) together with the coverage profiles
     * collected for it from the treatment ("positives") and control ("negatives") BAM files.
     */
    private static class Entry {
        /** Half-width (in positions) of the neighbourhood of the prediction checked for being fully above threshold. */
        private static final int SPAN_WIDTH = 50;
        /** Minimum length of a stretch of above-threshold windows for a prediction to be reported. */
        private static final int MIN_STRETCH_LENGTH = 400;

        private final String chr;
        private final int pos;
        private final String strand;
        private LinkedList<double[]> positives;
        private LinkedList<double[]> negatives;

        public Entry(String chr, int pos, String strand) {
            this.chr = chr;
            this.pos = pos;
            this.strand = strand;
        }

        /** Returns the tab-separated header matching {@link #toString()}. */
        public static String getHeader() {
            return "Chr\tPosition\tStrand";
        }

        @Override
        public String toString() {
            return this.chr + "\t" + this.pos + "\t" + this.strand;
        }

        /** Adds a treatment profile; all treatment profiles must have the same length. */
        public void addPositives(double[] positives) {
            this.positives = append(this.positives, positives);
        }

        /** Adds a control profile; all control profiles must have the same length. */
        public void addNegatives(double[] negatives) {
            this.negatives = append(this.negatives, negatives);
        }

        /**
         * Appends {@code profile} to {@code profiles}, creating the list on first use.
         *
         * @throws IllegalArgumentException if the profile length differs from earlier profiles
         */
        private static LinkedList<double[]> append(LinkedList<double[]> profiles, double[] profile) {
            if (profiles == null) {
                profiles = new LinkedList<double[]>();
            } else if (profiles.getFirst().length != profile.length) {
                throw new IllegalArgumentException("Profile length " + profile.length + " differs from previous profile length " + profiles.getFirst().length + ".");
            }
            profiles.add(profile);
            return profiles;
        }

        public String getChr() {
            return this.chr;
        }

        public int getPos() {
            return this.pos;
        }

        public String getStrand() {
            return this.strand;
        }

        /**
         * Computes sliding-window sums of {@code profile}.
         *
         * <p>The result has {@code profile.length - windowWidth} entries. Entry {@code i} is the sum
         * of the {@code windowWidth} values starting at {@code profile[i]}, assuming
         * {@code ToolBox.sum(0, w, ...)} sums the first {@code w} entries.
         * NOTE(review): the final full window (ending at the last position) is not included — confirm intended.
         */
        private static double[] getSmoothedProfile(double[] profile, int windowWidth) {
            double windowSum = ToolBox.sum(0, windowWidth, profile);
            double[] res = new double[profile.length - windowWidth];
            for (int i = windowWidth; i < profile.length; i++) {
                res[i - windowWidth] = windowSum;
                windowSum += profile[i] - profile[i - windowWidth];
            }
            return res;
        }

        /** Applies {@link #getSmoothedProfile} to every profile in {@code profs}, keeping the order. */
        private static double[][] getSmoothedProfiles(LinkedList<double[]> profs, int windowWidth) {
            double[][] dProfs = new double[profs.size()][];
            int i = 0;
            for (double[] prof : profs) {
                dProfs[i++] = getSmoothedProfile(prof, windowWidth);
            }
            return dProfs;
        }

        /**
         * Summarizes the replicate values at smoothed index {@code idx}.
         *
         * @param lowest for {@link Compare#EXTREMES}: {@code true} to take the minimum,
         *        {@code false} to take the maximum; ignored otherwise
         */
        private static double summarize(double[][] profs, int idx, Compare comp, boolean lowest) {
            if (comp == Compare.EXTREMES) {
                double res = lowest ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;
                for (double[] prof : profs) {
                    double v = prof[idx];
                    if (lowest ? v < res : v > res) {
                        res = v;
                    }
                }
                return res;
            }
            double[] column = new double[profs.length];
            for (int j = 0; j < profs.length; j++) {
                column[j] = profs[j][idx];
            }
            return comp == Compare.MEAN ? ToolBox.mean(column) : ToolBox.median(column);
        }

        /**
         * Appends the genomic position for raw profile index {@code idx}, followed by the raw values
         * of all treatment and then all control profiles at that index.
         */
        private void appendRawRow(StringBuilder sb, int mid, int idx) {
            sb.append(this.pos - mid + idx);
            for (double[] p : this.positives) {
                sb.append("\t" + p[idx]);
            }
            for (double[] n : this.negatives) {
                sb.append("\t" + n[idx]);
            }
        }

        /**
         * Compares the smoothed treatment and control profiles of this prediction and decides
         * whether it is reported.
         *
         * <p>For every window, treatment replicates and control replicates are each summarized
         * according to {@code comp}. The natural-log ratio treatment/control is then computed, and a
         * window is "above threshold" if this ratio exceeds {@code t}. The prediction is reported if
         * all of the following hold:
         * <ul>
         * <li>the maximal ratio exceeds {@code t};</li>
         * <li>the longest recorded stretch is at least {@link #MIN_STRETCH_LENGTH};</li>
         * <li>either no recorded stretch was marked as spanning the prediction, or some other
         * recorded stretch is at least {@link #MIN_STRETCH_LENGTH} long.</li>
         * </ul>
         *
         * @param header the BAM file names (treatment first, then control), used as column headers
         * @param windowWidth width of the smoothing window
         * @param t threshold on the log ratio
         * @param comp how replicates are summarized
         * @return a pair whose first element is {@code {maximal log ratio, offset of the best window
         *         center relative to the prediction}} (offset {@code -1} if no ratio exceeded
         *         {@code -Infinity}) and whose second element is the tab-separated profile table, or
         *         {@code null} if the prediction is not reported
         */
        public Pair<double[], String> getProfileResult(String[] header, int windowWidth, double t, Compare comp) {
            double[][] posProfs = getSmoothedProfiles(this.positives, windowWidth);
            double[][] negProfs = getSmoothedProfiles(this.negatives, windowWidth);
            int mid = this.positives.getFirst().length / 2;
            int half = windowWidth / 2;
            int numWindows = posProfs[0].length;
            int relIdx = -1;
            double max = Double.NEGATIVE_INFINITY;
            // offset of each above-threshold window center relative to the prediction, MIN_VALUE otherwise
            int[] posAboveThreshold = new int[numWindows];
            Arrays.fill(posAboveThreshold, Integer.MIN_VALUE);

            StringBuilder sb = new StringBuilder();
            sb.append("Position");
            for (String h : header) {
                sb.append("\t" + h.replaceAll(".*/SRR", "SRR").replace("/", "_"));
            }
            sb.append("\tabove threshold\n");
            // leading positions that are no window center
            for (int i = 0; i < half; i++) {
                appendRawRow(sb, mid, i);
                sb.append("\tNA\n");
            }
            // window i is centered at raw index i + half
            for (int i = 0; i < numWindows; i++) {
                appendRawRow(sb, mid, i + half);
                double treatment = summarize(posProfs, i, comp, true);
                double control = summarize(negProfs, i, comp, false);
                double rat = Math.log(treatment) - Math.log(control);
                if (rat > max) {
                    max = rat;
                    relIdx = i + half - mid;
                }
                if (rat > t) {
                    posAboveThreshold[i] = i + half - mid;
                    sb.append("\tTRUE\n");
                } else {
                    sb.append("\tFALSE\n");
                }
            }
            // trailing positions that are no window center
            for (int i = numWindows + half; i < numWindows + windowWidth; i++) {
                appendRawRow(sb, mid, i);
                sb.append("\tNA\n");
            }

            // Scan for stretches of consecutive above-threshold windows.
            // NOTE(review): this logic is kept exactly as in the original and looks unintended in
            // several places — please confirm:
            //  * currentStretch is only reset when a stretch ends that is longer than the current
            //    maximum, so shorter stretches accumulate into the following ones;
            //  * a stretch running until the last window is never compared with the maximum;
            //  * stretchLengths only records stretches that were a new maximum;
            //  * countSpans is never reset, so spanningStretch marks the first recorded stretch at
            //    whose end all 2 * SPAN_WIDTH + 1 window centers within +/- SPAN_WIDTH of the
            //    prediction have been seen above threshold.
            int maxStretchLength = 0;
            int currentStretch = 0;
            ArrayList<Integer> stretchLengths = new ArrayList<Integer>();
            int spanningStretch = -1;
            int countSpans = 0;
            for (int i = 0; i < numWindows; i++) {
                if (posAboveThreshold[i] <= SPAN_WIDTH && posAboveThreshold[i] >= -SPAN_WIDTH) {
                    ++countSpans;
                }
                if (posAboveThreshold[i] > Integer.MIN_VALUE) {
                    ++currentStretch;
                } else if (currentStretch > maxStretchLength) {
                    maxStretchLength = currentStretch;
                    stretchLengths.add(currentStretch);
                    if (spanningStretch == -1 && countSpans == SPAN_WIDTH * 2 + 1) {
                        spanningStretch = stretchLengths.size() - 1;
                    }
                    currentStretch = 0;
                }
            }
            boolean secondStretch = false;
            if (spanningStretch != -1) {
                for (int i = 0; i < stretchLengths.size(); i++) {
                    if (i != spanningStretch && stretchLengths.get(i) >= MIN_STRETCH_LENGTH) {
                        secondStretch = true;
                    }
                }
            }

            // secondStretch can only be true if a spanning stretch was found
            boolean report = max > t && maxStretchLength >= MIN_STRETCH_LENGTH && (spanningStretch == -1 || secondStretch);
            return new Pair<double[], String>(new double[]{max, relIdx}, report ? sb.toString() : null);
        }
    }
}

