/*
 * Decompiled with CFR 0.152.
 */
package projects.encodedream;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.GaussianNetwork;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Random;
import projects.dimont.Interpolation;
import projects.encodedream.FeatureReader;

/**
 * Iterative, unsupervised training of {@link GenDisMixClassifier}s on feature vectors provided by a
 * {@link FeatureReader}.
 *
 * <p>Initial per-region values are derived from motif and/or DNase features (see {@link Init}) and
 * converted into per-region weights for two classes. A classifier is trained on the full feature
 * vectors using these weights. The values are then refined, iteration by iteration, from the scores
 * of a classifier trained on a selected feature view (see {@link Select}).
 */
public class UnsupervisedTraining {

    /** Seed of the random generator used to draw {@link Select#RANDOM} feature subsets. */
    private static final long RANDOM_SEED = 127L;

    private final FeatureReader reader;
    private final int threads;
    private final HashMap<String, Integer> sizes;
    private final Init init;
    private final Select select;

    /**
     * Random generator for {@link Select#RANDOM}. It is kept as a field (instead of being re-created
     * with a fixed seed on every call) so that successive iterations draw different feature subsets
     * while results stay reproducible across runs.
     */
    private final Random random = new Random(RANDOM_SEED);

    /**
     * Creates a new training procedure.
     *
     * @param reader the source of feature vectors; it is reset and positioned per chromosome
     * @param threads the number of threads passed to the classifier parameter set
     * @param sizes maps each training chromosome to the number of feature vectors read for it
     * @param init which features determine the initial values
     * @param select which feature view is used to refine the values in each iteration
     */
    public UnsupervisedTraining(FeatureReader reader, int threads, HashMap<String, Integer> sizes, Init init, Select select) {
        this.reader = reader;
        this.threads = threads;
        this.sizes = sizes;
        this.init = init;
        this.select = select;
    }

    /**
     * Trains a {@link GenDisMixClassifier} using {@link LearningPrinciple#MCL} and a
     * {@link DoesNothingLogPrior}. Both classes are modelled by a {@link GaussianNetwork} constructed
     * with an empty parent array for every position of {@code data[0]}.
     *
     * @param data the data sets of the classes; length and alphabet are taken from {@code data[0]}
     * @param weights the per-element weights of each data set
     * @return the trained classifier
     * @throws Exception if the classifier cannot be created or trained
     */
    public GenDisMixClassifier train(DataSet[] data, double[][] weights) throws Exception {
        GaussianNetwork gn = new GaussianNetwork(new int[data[0].getElementLength()][0]);
        GenDisMixClassifierParameterSet params = new GenDisMixClassifierParameterSet(
                data[0].getAlphabetContainer(), gn.getLength(), 20, 1.0E-6, 1.0E-6, 1.0E-4, false,
                OptimizableFunction.KindOfParameter.PLUGIN, true, this.threads);
        GenDisMixClassifier cl = new GenDisMixClassifier(params, (LogPrior) DoesNothingLogPrior.defaultInstance, LearningPrinciple.MCL, gn, gn);
        cl.train(data, weights);
        return cl;
    }

    /**
     * Runs the iterative training.
     *
     * <p>First, a classifier is trained on the full feature vectors using the initial weights. Then,
     * in each iteration:
     * <ol>
     *   <li>a feature view is selected;</li>
     *   <li>the values are refined from that view's classifier scores (see {@link #updateVals});</li>
     *   <li>the weights are recomputed from the values;</li>
     *   <li>a new classifier is trained on the full feature vectors.</li>
     * </ol>
     *
     * @param iterations the number of refinement iterations
     * @param trainChroms the chromosomes whose feature vectors are used for training
     * @param frac the fraction passed on to {@code Interpolation.getWeight} (presumably the share of
     *     top-ranked regions weighted as the first class — confirm)
     * @param cons the factor with which the previous (z-scored) values enter the refined values
     * @return the classifier trained on the initial weights, followed by one classifier per iteration
     * @throws Exception if reading features or training fails
     */
    public GenDisMixClassifier[] iterativeTraining(int iterations, LinkedList<String> trainChroms, double frac, double cons) throws Exception {
        LinkedList<Sequence> full = new LinkedList<Sequence>();
        LinkedList<Sequence> motifs = new LinkedList<Sequence>();
        LinkedList<Sequence> dnase = new LinkedList<Sequence>();
        Pair<double[], double[][]> pair = this.getInitialWeights(trainChroms, frac, full, dnase, motifs);
        double[] vals = pair.getFirstElement();
        double[][] weights = pair.getSecondElement();
        DataSet fullDS = FeatureReader.replaceNaN(new DataSet("", full));
        DataSet dnaseDS = FeatureReader.replaceNaN(new DataSet("", dnase));
        DataSet motifsDS = FeatureReader.replaceNaN(new DataSet("", motifs));

        LinkedList<GenDisMixClassifier> classifiers = new LinkedList<GenDisMixClassifier>();
        classifiers.add(this.train(new DataSet[]{fullDS, fullDS}, weights));
        DataSet curr = null;
        for (int i = 0; i < iterations; i++) {
            // Classifier on the selected view, trained with the current weights. For FULL, the view
            // is the full data, so the most recent full classifier is reused.
            curr = this.getCurr(fullDS, dnaseDS, motifsDS, curr);
            GenDisMixClassifier currCl = this.select == Select.FULL
                    ? classifiers.getLast()
                    : this.train(new DataSet[]{curr, curr}, weights);
            vals = UnsupervisedTraining.updateVals(vals, currCl, curr, cons);
            weights = this.updateWeights(vals, frac);
            classifiers.add(this.train(new DataSet[]{fullDS, fullDS}, weights));
        }
        return classifiers.toArray(new GenDisMixClassifier[0]);
    }

    /**
     * Returns the feature view to use in the next iteration, depending on {@link #select}.
     * <ul>
     *   <li>{@link Select#ALTERNATE}: toggles between the DNase and motif data sets. It starts with
     *       the DNase view when motif features defined the initial values ({@link Init#MOTIF} or
     *       {@link Init#BOTH}), and with the motif view otherwise.</li>
     *   <li>{@link Select#FULL}: always the full data set.</li>
     *   <li>Otherwise ({@link Select#RANDOM}): a fresh random half of the positions of the full
     *       data set.</li>
     * </ul>
     *
     * @param curr the view returned by the previous call, or {@code null} on the first call
     */
    private DataSet getCurr(DataSet fullDS, DataSet dnaseDS, DataSet motifsDS, DataSet curr) throws EmptyDataSetException, WrongAlphabetException {
        if (this.select == Select.ALTERNATE) {
            if (curr == null) {
                return (this.init == Init.MOTIF || this.init == Init.BOTH) ? dnaseDS : motifsDS;
            }
            return curr == dnaseDS ? motifsDS : dnaseDS;
        }
        if (this.select == Select.FULL) {
            return fullDS;
        }
        return this.randomSubset(fullDS);
    }

    /**
     * Builds a data set that keeps, for every sequence of {@code fullDS}, the same randomly drawn
     * half (rounded down) of its positions. Positions are drawn without replacement.
     */
    private DataSet randomSubset(DataSet fullDS) throws EmptyDataSetException, WrongAlphabetException {
        int numPositions = fullDS.getElementLength();
        int[] positions = new int[numPositions];
        for (int i = 0; i < numPositions; i++) {
            positions[i] = i;
        }
        int[] starts = new int[numPositions / 2];
        // Partial Fisher-Yates shuffle: slot k receives a uniformly drawn position from the not yet
        // selected tail positions[k..n), so no position is chosen twice.
        for (int k = 0; k < starts.length; k++) {
            int idx = k + this.random.nextInt(numPositions - k);
            int temp = positions[k];
            positions[k] = positions[idx];
            positions[idx] = temp;
            starts[k] = positions[k];
        }
        int[] lengths = new int[starts.length];
        Arrays.fill(lengths, 1);
        Sequence[] seqs = new Sequence[fullDS.getNumberOfElements()];
        for (int i = 0; i < seqs.length; i++) {
            seqs[i] = fullDS.getElementAt(i).getCompositeSequence(starts, lengths);
        }
        return new DataSet("", seqs);
    }

    /**
     * Converts values into the weights of the two classes. The first class uses
     * {@code Interpolation.getWeight} with {@code RANK_LOG}; the second uses
     * {@code Interpolation.getBgWeight} of the first (presumably the complementary background
     * weights — confirm).
     */
    private double[][] updateWeights(double[] vals, double frac) throws Exception {
        double[] weights = Interpolation.getWeight(null, vals, frac, Interpolation.RANK_LOG);
        return new double[][]{weights, Interpolation.getBgWeight(weights)};
    }

    /**
     * Refines the values.
     *
     * @return the z-scored scores of {@code cl} on {@code curr}, plus {@code cons} times the z-scored
     *     previous values (element-wise)
     */
    private static double[] updateVals(double[] vals, GenDisMixClassifier cl, DataSet curr, double cons) throws Exception {
        double[] previous = ToolBox.zscore(vals);
        double[] sc = ToolBox.zscore(cl.getScores(curr));
        for (int i = 0; i < sc.length; i++) {
            sc[i] = sc[i] + previous[i] * cons;
        }
        return sc;
    }

    /**
     * Reads the feature vectors of all training chromosomes and computes the initial values and
     * weights.
     *
     * <p>For every feature vector, the full, DNase and motif sequences are appended to
     * {@code full}, {@code dnase} and {@code motifs}, respectively. The initial values are:
     * <ul>
     *   <li>the motif scores ({@code getCurrentMotifMax(0)}) for {@link Init#MOTIF};</li>
     *   <li>the DNase scores ({@code getCurrentDNaseMin()}) for {@link Init#DNASE};</li>
     *   <li>otherwise, the sum of both after z-scoring each and shifting each to a minimum of 0.</li>
     * </ul>
     *
     * @return the initial values and the weights derived from them
     * @throws IllegalArgumentException if {@link #sizes} has no entry for a training chromosome
     */
    private Pair<double[], double[][]> getInitialWeights(LinkedList<String> trainChroms, double frac, LinkedList<Sequence> full, LinkedList<Sequence> dnase, LinkedList<Sequence> motifs) throws Exception {
        boolean useMotif = this.init == Init.BOTH || this.init == Init.MOTIF;
        boolean useDNase = this.init == Init.BOTH || this.init == Init.DNASE;
        this.reader.reset();
        DoubleList motifScores = new DoubleList();
        DoubleList dnaseScores = new DoubleList();
        for (String chr : trainChroms) {
            this.reader.findChr(chr);
            Integer size = this.sizes.get(chr);
            if (size == null) {
                throw new IllegalArgumentException("No size known for chromosome " + chr);
            }
            for (int j = 0; j < size; ++j) {
                if (useMotif) {
                    motifScores.add(this.reader.getCurrentMotifMax(0));
                }
                if (useDNase) {
                    dnaseScores.add(this.reader.getCurrentDNaseMin());
                }
                full.add(this.reader.getCurrentSequence());
                dnase.add(this.reader.getCurrentDNaseSequence());
                motifs.add(this.reader.getCurrentMotifsSequence());
                // NOTE(review): the boolean result is ignored (as before); presumably the sizes
                // guarantee that enough feature vectors are available — confirm.
                this.reader.readNextFeatureVector();
            }
        }
        double[] ms = motifScores.toArray();
        double[] ds = dnaseScores.toArray();
        double[] vals;
        if (this.init == Init.MOTIF) {
            vals = ms;
        } else if (this.init == Init.DNASE) {
            vals = ds;
        } else {
            // Bring both scores to a common scale before summing them. The results of zscore are
            // assigned back, consistent with how its return value is used in updateVals.
            ms = ToolBox.zscore(ms);
            ds = ToolBox.zscore(ds);
            shiftToZeroMin(ms);
            shiftToZeroMin(ds);
            vals = ms;
            for (int i = 0; i < vals.length; i++) {
                vals[i] = vals[i] + ds[i];
            }
        }
        return new Pair<double[], double[][]>(vals, this.updateWeights(vals, frac));
    }

    /** Subtracts the minimum of {@code values} from every entry, in place. */
    private static void shiftToZeroMin(double[] values) {
        double mi = ToolBox.min(values);
        for (int i = 0; i < values.length; i++) {
            values[i] = values[i] - mi;
        }
    }

    /** Features that determine the initial values. */
    public static enum Init {
        /** Motif scores only. */
        MOTIF,
        /** DNase scores only. */
        DNASE,
        /** Combination of motif and DNase scores. */
        BOTH
    }

    /** Feature view used to refine the values in each iteration. */
    public static enum Select {
        /** Alternate between the DNase and the motif feature sequences. */
        ALTERNATE,
        /** A random half of the positions of the full feature sequences. */
        RANDOM,
        /** The full feature sequences. */
        FULL
    }
}

