/*
 * Decompiled with CFR 0.152.
 */
package projects.encodedream;

import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.zip.GZIPOutputStream;
import projects.encodedream.FeatureReader;

/**
 * Scores genomic bins with an ensemble of {@link GenDisMixClassifier}s and aggregates the
 * per-bin scores into predictions smoothed over a window of neighbouring bins.
 */
public class Predictor {
    /** Ensemble members; aggregated predictions are averaged over all of them. */
    private final GenDisMixClassifier[] cls;
    /** Source of per-bin feature vectors (chromosome, start position, sequence). */
    private final FeatureReader reader;
    /** Number of preceding bins included in each aggregation window. */
    private final int binsBefore;
    /** Number of following bins included in each aggregation window. */
    private final int binsAfter;

    /**
     * @param cls the classifiers whose scores are combined
     * @param reader the feature reader used by the chromosome-based {@code predict} methods
     * @param binsBefore number of bins before the current one included in each window
     * @param binsAfter number of bins after the current one included in each window
     */
    public Predictor(GenDisMixClassifier[] cls, FeatureReader reader, int binsBefore, int binsAfter) {
        this.cls = cls;
        this.reader = reader;
        this.binsBefore = binsBefore;
        this.binsAfter = binsAfter;
    }

    /**
     * Predicts all bins of the given chromosomes and writes them to a temporary gzipped TSV file
     * with lines of the form {@code chromosome<TAB>start<TAB>prediction}. The file is marked for
     * deletion when the JVM exits.
     *
     * @param sizes number of bins (feature vectors) to score per chromosome
     * @param chroms chromosomes to predict, in output order; if {@code null}, all keys of
     *        {@code sizes} in sorted order
     * @return the temporary file holding the predictions
     * @throws IllegalArgumentException if a requested chromosome has no entry in {@code sizes}
     * @throws IOException if writing or finishing the compressed file fails
     * @throws Exception propagated from the feature reader or the classifiers
     */
    public File predict(HashMap<String, Integer> sizes, LinkedList<String> chroms) throws Exception {
        if (chroms == null) {
            chroms = new LinkedList<String>(sizes.keySet());
            Collections.sort(chroms);
        }
        File temp = File.createTempFile("preds", ".tsv.gz");
        temp.deleteOnExit();
        PrintWriter wr = new PrintWriter(new GZIPOutputStream(new FileOutputStream(temp)));
        try {
            for (String chr : chroms) {
                Integer size = sizes.get(chr);
                if (size == null) {
                    throw new IllegalArgumentException("No size given for chromosome " + chr + ".");
                }
                double[] preds = this.predict(chr, size.intValue());
                this.writeChromosome(wr, chr, preds);
            }
        } finally {
            // Closing also finishes the GZIP stream; done even if prediction fails part-way.
            wr.close();
        }
        // PrintWriter never throws IOExceptions itself; surface them so a truncated file is
        // not returned silently.
        if (wr.checkError()) {
            throw new IOException("Failed to write predictions to " + temp + ".");
        }
        return temp;
    }

    /**
     * Writes one line per feature vector of {@code chr}, pairing the i-th vector of that
     * chromosome with {@code preds[i]}. Stops at the first vector of a different chromosome or
     * when the reader is exhausted.
     */
    private void writeChromosome(PrintWriter wr, String chr, double[] preds) throws Exception {
        this.reader.reset();
        this.reader.findChr(chr);
        int i = 0;
        do {
            String current = this.reader.getCurrentChromosome();
            int start = this.reader.getCurrentStart();
            if (!chr.equals(current)) {
                return;
            }
            if (i >= preds.length) {
                throw new IllegalStateException("Feature files contain more than " + preds.length
                        + " bins for chromosome " + chr + ".");
            }
            wr.println(chr + "\t" + start + "\t" + preds[i]);
            ++i;
        } while (this.reader.readNextFeatureVector());
    }

    /**
     * Scores {@code size} consecutive feature vectors, starting at the first vector of
     * {@code chr}, and aggregates them with {@link #aggregate(double[][])}.
     *
     * @param chr chromosome to start reading at
     * @param size number of feature vectors to score
     * @return one aggregated prediction per scored vector
     * @throws RuntimeException if {@code chr} is not found by the feature reader
     */
    public double[] predict(String chr, int size) throws Exception {
        this.reader.reset();
        if (!this.reader.findChr(chr)) {
            throw new RuntimeException("Did not find chromosome " + chr + " in feature files.");
        }
        double[][] scores = new double[this.cls.length][size];
        for (int i = 0; i < size; ++i) {
            Sequence seq = this.reader.getCurrentSequence();
            for (int j = 0; j < this.cls.length; ++j) {
                // Score difference between class 0 and class 1 for this bin.
                scores[j][i] = this.cls[j].getScore(seq, 0) - this.cls[j].getScore(seq, 1);
            }
            // The return value is ignored: reading continues for exactly `size` bins.
            // NOTE(review): if the reader runs out (or crosses into the next chromosome) before
            // `size` bins, the remaining bins are scored from whatever the reader then holds --
            // confirm that `size` always matches the feature files.
            this.reader.readNextFeatureVector();
        }
        return this.aggregate(scores);
    }

    /**
     * Aggregates per-bin scores into one prediction per bin.
     * <p>
     * For every row {@code scs[i]} and bin {@code j}, each score {@code s} in the window
     * {@code [j - binsBefore, j + binsAfter]} (clipped to the row) is mapped to
     * {@code p = 1 / (1 + exp(-s))}. The row contributes {@code 1 - prod(1 - p)} over that window.
     * Reading {@code p} as a per-bin probability, this is the probability that at least one bin
     * in the window is positive, assuming independent bins. Contributions are averaged over all
     * rows. A window whose log-sum is NaN contributes 0.
     *
     * @param scs scores, one row per ensemble member; the result has {@code scs[0].length}
     *        entries
     * @return the averaged window predictions, or an empty array if {@code scs} has no rows
     */
    public double[] aggregate(double[][] scs) {
        if (scs.length == 0) {
            return new double[0];
        }
        int numBins = scs[0].length;
        double[] preds = new double[numBins];
        for (double[] row : scs) {
            // log(1 - sigmoid(s)) per bin, computed once instead of once per overlapping window.
            double[] logComplement = new double[row.length];
            for (int k = 0; k < row.length; ++k) {
                logComplement[k] = Math.log1p(-1.0 / (1.0 + Math.exp(-row[k])));
            }
            for (int j = 0; j < numBins; ++j) {
                int start = Math.max(0, j - this.binsBefore);
                int end = Math.min(row.length, j + this.binsAfter + 1);
                double sum = 0.0;
                for (int k = start; k < end; ++k) {
                    sum += logComplement[k];
                }
                if (Double.isNaN(sum)) {
                    sum = 0.0;
                }
                preds[j] += 1.0 - Math.exp(sum);
            }
        }
        // Average over the rows actually supplied, so the method does not depend on cls.
        for (int j = 0; j < numBins; ++j) {
            preds[j] /= (double) scs.length;
        }
        return preds;
    }

    /**
     * Scores every sequence of {@code data} with each classifier and aggregates the results
     * with {@link #aggregate(double[][])}. Sequences at adjacent indices of {@code data} are
     * treated as neighbouring bins.
     *
     * @param data the sequences to score
     * @return one aggregated prediction per sequence
     * @throws Exception propagated from the classifiers
     */
    public double[] predict(DataSet data) throws Exception {
        double[][] scs = new double[this.cls.length][];
        for (int i = 0; i < this.cls.length; ++i) {
            scs[i] = this.cls[i].getScores(data);
        }
        return this.aggregate(scs);
    }
}

