/*
 * Decompiled with CFR 0.152.
 */
package projects.encodedream;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.GaussianNetwork;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import projects.encodedream.FeatureReader;
import projects.encodedream.Predictor;

/**
 * Trains a sequence of {@link GenDisMixClassifier}s in several rounds.
 *
 * <p>The first classifier is trained on the initial data supplied by the
 * {@link FeatureReader}. In every further round, all classifiers trained so far
 * are used (via a {@link Predictor}) to score the positions of a set of
 * chromosomes; positions labeled {@code 'U'} whose score reaches a threshold are
 * appended to the background data set, and a new classifier is trained on the
 * extended data.
 */
public class IterativeTraining {
    /** Source of the initial data, the labels and the feature vectors/sequences. */
    private final FeatureReader reader;
    /** Number of threads handed to the classifier parameter set. */
    private final int threads;
    /** Number of positions to visit per chromosome, keyed by chromosome name. */
    private final HashMap<String, Integer> sizes;

    /**
     * Creates a new trainer.
     *
     * @param reader  reader providing training data, labels and sequences
     * @param threads number of threads passed on to the classifier parameters
     * @param sizes   number of positions per chromosome; must contain an entry for
     *                every chromosome later passed in {@code itChroms}
     */
    public IterativeTraining(FeatureReader reader, int threads, HashMap<String, Integer> sizes) {
        this.reader = reader;
        this.threads = threads;
        this.sizes = sizes;
    }

    /**
     * Trains one classifier using the same {@link GaussianNetwork} instance as the
     * score for both classes and the {@link LearningPrinciple#MCL} principle.
     *
     * @param data    the data sets; {@code data[0]} determines sequence length and
     *                alphabet of the model
     * @param weights the weights handed to {@code GenDisMixClassifier.train},
     *                parallel to {@code data}
     * @return the trained classifier
     * @throws Exception if model construction or training fails
     */
    public GenDisMixClassifier train(DataSet[] data, double[][] weights) throws Exception {
        // One (empty) int[] per position of the input vectors.
        // NOTE(review): presumably an empty array means "no parents" for that position - confirm.
        GaussianNetwork gn = new GaussianNetwork(new int[data[0].getElementLength()][0]);
        GenDisMixClassifierParameterSet params = new GenDisMixClassifierParameterSet(data[0].getAlphabetContainer(), gn.getLength(), 20, 1.0E-6, 1.0E-6, 1.0E-4, false, OptimizableFunction.KindOfParameter.PLUGIN, true, this.threads);
        GenDisMixClassifier cl = new GenDisMixClassifier(params, (LogPrior)DoesNothingLogPrior.defaultInstance, LearningPrinciple.MCL, gn, gn);
        cl.train(data, weights);
        return cl;
    }

    /**
     * Runs the iterative training procedure.
     *
     * @param iterations  total number of classifiers to train; values below 2 yield
     *                    only the initial classifier
     * @param trainChroms chromosomes used for the initial data, passed to
     *                    {@code FeatureReader.getInitialData}; may be {@code null},
     *                    in which case the subset check is skipped
     * @param itChroms    chromosomes scanned in each round for additional
     *                    background examples
     * @param perc        percentile of the scores at {@code 'S'}/{@code 'B'}
     *                    positions used as one candidate for the threshold
     * @param binsBefore  passed through to {@link Predictor}
     * @param binsAfter   passed through to {@link Predictor}
     * @return all trained classifiers in training order
     * @throws IllegalArgumentException if {@code trainChroms} is given but does not
     *                                  contain all of {@code itChroms}
     * @throws Exception                if reading, prediction or training fails
     */
    public GenDisMixClassifier[] iterativeTraining(int iterations, HashSet<String> trainChroms, LinkedList<String> itChroms, double perc, int binsBefore, int binsAfter) throws Exception {
        if (trainChroms != null && !trainChroms.containsAll(itChroms)) {
            throw new IllegalArgumentException("Iteration chromosomes " + itChroms + " must all be contained in the training chromosomes " + trainChroms);
        }
        Pair<DataSet[], double[][]> pair = this.reader.getInitialData(trainChroms);
        DataSet fg = pair.getFirstElement()[0];
        DataSet bg = pair.getFirstElement()[1];
        double[][] weights = pair.getSecondElement();

        LinkedList<GenDisMixClassifier> clList = new LinkedList<GenDisMixClassifier>();
        clList.add(this.train(pair.getFirstElement(), weights));

        for (int i = 1; i < iterations; ++i) {
            Predictor pred = new Predictor(clList.toArray(new GenDisMixClassifier[0]), this.reader, binsBefore, binsAfter);
            // Predictions are computed first; the reader is reset afterwards before each pass over the labels.
            double[][] preds = this.predictAll(pred, itChroms);
            double th = this.selectionThreshold(preds, itChroms, weights[1], perc);
            LinkedList<Sequence> seqs = this.collectHighScoringUnlabeled(preds, itChroms, th);

            bg = FeatureReader.replaceNaN(DataSet.union(bg, new DataSet("", seqs)));
            // Keep the existing background weights and give every added sequence weight 1.
            double[] temp = new double[weights[1].length + seqs.size()];
            Arrays.fill(temp, 1.0);
            System.arraycopy(weights[1], 0, temp, 0, weights[1].length);
            weights[1] = temp;

            clList.add(this.train(new DataSet[]{fg, bg}, weights));
        }
        return clList.toArray(new GenDisMixClassifier[0]);
    }

    /** Computes the per-position predictions for every chromosome in {@code itChroms}, in list order. */
    private double[][] predictAll(Predictor pred, LinkedList<String> itChroms) throws Exception {
        double[][] preds = new double[itChroms.size()][];
        int l = 0;
        for (String chr : itChroms) {
            preds[l++] = pred.predict(chr, this.sizes.get(chr));
        }
        return preds;
    }

    /**
     * Derives the score threshold above which {@code 'U'} positions are added to the background:
     * the larger of the {@code perc} percentile of scores at {@code 'S'}/{@code 'B'} positions and
     * a percentile of the {@code 'U'} scores chosen from the current background weight sum.
     */
    private double selectionThreshold(double[][] preds, LinkedList<String> itChroms, double[] bgWeights, double perc) throws Exception {
        DoubleList scPos = new DoubleList();
        DoubleList scNeg = new DoubleList();
        this.reader.reset();
        int l = 0;
        for (String chr : itChroms) {
            this.reader.findChr(chr);
            int size = this.sizes.get(chr);
            for (int j = 0; j < size; ++j) {
                char lab = this.reader.getCurrentLabel();
                if (lab == 'S' || lab == 'B') {
                    scPos.add(preds[l][j]);
                } else if (lab == 'U') {
                    scNeg.add(preds[l][j]);
                }
                // Advance on EVERY position so that preds[l][j] stays aligned with the label
                // that is read; previously the reader only moved on 'U' positions.
                // Assumes readNextFeatureVector() steps to the next position and returns
                // false when none is left - TODO confirm against FeatureReader.
                if (!this.reader.readNextFeatureVector()) {
                    break;
                }
            }
            ++l;
        }
        // 15% of the summed background weights; perc2 is the percentile that leaves
        // roughly that many 'U' scores above th2.
        double prev = ToolBox.sum(bgWeights) * 0.15;
        double perc2 = 1.0 - prev / (double)scNeg.length();
        double th = ToolBox.percentile(scPos.toArray(), perc);
        double th2 = ToolBox.percentile(scNeg.toArray(), perc2);
        return Math.max(th, th2);
    }

    /** Collects the sequences of all {@code 'U'} positions whose prediction is at least {@code th}. */
    private LinkedList<Sequence> collectHighScoringUnlabeled(double[][] preds, LinkedList<String> itChroms, double th) throws Exception {
        LinkedList<Sequence> seqs = new LinkedList<Sequence>();
        this.reader.reset();
        int l = 0;
        for (String chr : itChroms) {
            this.reader.findChr(chr);
            int size = this.sizes.get(chr);
            for (int j = 0; j < size; ++j) {
                if (preds[l][j] >= th && this.reader.getCurrentLabel() == 'U') {
                    seqs.add(this.reader.getCurrentSequence());
                }
                // Advance regardless of whether the position was selected; otherwise the
                // reader stalls and later predictions are matched against a stale position.
                if (!this.reader.readNextFeatureVector()) {
                    break;
                }
            }
            ++l;
        }
        return seqs;
    }
}

