package projects.encodedream;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.GaussianNetwork;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Random;
import projects.dimont.Interpolation;

/* loaded from: input_file:projects/encodedream/UnsupervisedTraining.class */
public class UnsupervisedTraining {
    /** Source of the feature vectors; queried via reset(), findChr(...), the getCurrent*() accessors and readNextFeatureVector(). */
    private FeatureReader reader;
    /** Handed to GenDisMixClassifierParameterSet as its last constructor argument in train(); presumably the number of compute threads. */
    private int threads;
    /** Looked up with each name passed to reader.findChr(...); the value caps how many feature vectors are collected for that name. */
    private HashMap<String, Integer> sizes;
    /** Selects which reader values (motif maximum, DNase minimum, or both) provide the starting values. */
    private Init init;
    /** Selects how the data set used for re-scoring between training rounds is chosen. */
    private Select select;

    /* loaded from: input_file:projects/encodedream/UnsupervisedTraining$Init.class */
    /**
     * Source of the values used to derive the initial weights.
     */
    public enum Init {
        MOTIF,
        DNASE,
        BOTH;

        /**
         * Decompiler-generated counterpart of {@link #values()}.
         *
         * @return a freshly allocated array holding all constants in declaration order
         */
        public static Init[] valuesCustom() {
            // values() already hands out a private copy; cloning keeps the explicit extra copy of the original.
            return values().clone();
        }
    }

    /* loaded from: input_file:projects/encodedream/UnsupervisedTraining$Select.class */
    /**
     * Strategy for choosing the data set that is scored between two training rounds.
     */
    public enum Select {
        ALTERNATE,
        RANDOM,
        FULL;

        /**
         * Decompiler-generated counterpart of {@link #values()}.
         *
         * @return a freshly allocated array holding all constants in declaration order
         */
        public static Select[] valuesCustom() {
            // values() already hands out a private copy; cloning keeps the explicit extra copy of the original.
            return values().clone();
        }
    }

    /**
     * Creates a trainer; all arguments are stored as given, without copying or validation.
     *
     * @param featureReader reader that supplies the feature vectors
     * @param i             value forwarded to the classifier parameter set in {@code train}
     *                      (presumably the number of threads)
     * @param hashMap       per-name limits on the number of feature vectors to collect
     * @param init          which values seed the initial weights
     * @param select        how the scoring data set is chosen in each round
     */
    public UnsupervisedTraining(FeatureReader featureReader, int i, HashMap<String, Integer> hashMap, Init init, Select select) {
        this.select = select;
        this.init = init;
        this.sizes = hashMap;
        this.threads = i;
        this.reader = featureReader;
    }

    /**
     * Builds a two-class {@link GenDisMixClassifier} over Gaussian networks and trains it on the
     * given data with the given weights.
     *
     * <p>Element length and alphabet are taken from {@code dataSetArr[0]}. The same
     * {@link GaussianNetwork} instance is handed to the classifier for both classes.
     *
     * @param dataSetArr data sets passed on to {@code GenDisMixClassifier.train}
     * @param dArr       weights passed on to {@code GenDisMixClassifier.train}
     * @return the trained classifier
     * @throws Exception if constructing or training the classifier fails
     */
    public GenDisMixClassifier train(DataSet[] dataSetArr, double[][] dArr) throws Exception {
        DataSet reference = dataSetArr[0];
        // One empty inner array per position of the reference data set.
        int[][] structure = new int[reference.getElementLength()][0];
        GaussianNetwork network = new GaussianNetwork(structure);
        GenDisMixClassifierParameterSet parameters = new GenDisMixClassifierParameterSet(
                reference.getAlphabetContainer(),
                network.getLength(),
                (byte) 20,
                1.0E-6d,
                1.0E-6d,
                1.0E-4d,
                false,
                OptimizableFunction.KindOfParameter.PLUGIN,
                true,
                this.threads);
        GenDisMixClassifier classifier = new GenDisMixClassifier(
                parameters,
                DoesNothingLogPrior.defaultInstance,
                LearningPrinciple.MCL,
                network,
                network);
        classifier.train(dataSetArr, dArr);
        return classifier;
    }

    /**
     * Alternates between training a classifier on the full feature set and re-deriving the
     * training weights from classifier scores.
     *
     * <p>Each round trains a classifier on the full data with the current weights and records it.
     * Unless it was the last round, a scoring data set is then chosen via {@code getCurr}. For
     * {@link Select#FULL} the classifier just trained does the scoring; otherwise a separate
     * classifier is trained on the scoring data set with the same weights. The scores are
     * combined with the previous values in {@code updateVals} and turned into new weights by
     * {@code updateWeights}.
     *
     * <p>After the final full-data training no scoring set is chosen and no scoring classifier
     * is trained, because neither would be read again.
     *
     * @param i          number of refinement rounds after the initial training; values
     *                   {@code <= 0} yield only the initial classifier
     * @param linkedList names passed to the reader when collecting the feature vectors
     * @param d          parameter forwarded to the weight interpolation
     * @param d2         factor applied to the standardized previous values when they are
     *                   combined with the new scores
     * @return the classifiers trained on the full data, in training order
     * @throws Exception if reading the features or training a classifier fails
     */
    public GenDisMixClassifier[] iterativeTraining(int i, LinkedList<String> linkedList, double d, double d2) throws Exception {
        LinkedList<Sequence> allFeatures = new LinkedList<>();
        LinkedList<Sequence> motifFeatures = new LinkedList<>();
        LinkedList<Sequence> dnaseFeatures = new LinkedList<>();
        Pair<double[], double[][]> initial = getInitialWeights(linkedList, d, allFeatures, dnaseFeatures, motifFeatures);
        double[] values = initial.getFirstElement();
        double[][] weights = initial.getSecondElement();

        DataSet fullData = FeatureReader.replaceNaN(new DataSet("", allFeatures));
        DataSet dnaseData = FeatureReader.replaceNaN(new DataSet("", dnaseFeatures));
        DataSet motifData = FeatureReader.replaceNaN(new DataSet("", motifFeatures));

        LinkedList<GenDisMixClassifier> classifiers = new LinkedList<>();
        DataSet scoringData = null;
        for (int round = 0; ; round++) {
            GenDisMixClassifier fullClassifier = train(new DataSet[]{fullData, fullData}, weights);
            classifiers.add(fullClassifier);
            if (round >= i) {
                // Nothing after the last full training consumes a scoring classifier, so stop here.
                break;
            }
            scoringData = getCurr(fullData, dnaseData, motifData, scoringData);
            GenDisMixClassifier scorer = fullClassifier;
            if (this.select != Select.FULL) {
                scorer = train(new DataSet[]{scoringData, scoringData}, weights);
            }
            values = updateVals(values, scorer, scoringData, d2);
            weights = updateWeights(values, d);
        }
        return classifiers.toArray(new GenDisMixClassifier[0]);
    }

    /**
     * Chooses the data set that is scored in the next round, depending on {@code select}.
     *
     * <ul>
     *   <li>{@link Select#ALTERNATE}: on the first call ({@code dataSet4 == null}) returns
     *       {@code dataSet2} when {@code init} is MOTIF or BOTH and {@code dataSet3} otherwise;
     *       afterwards returns whichever of the two was not returned last (compared by
     *       reference).</li>
     *   <li>{@link Select#FULL}: returns {@code dataSet}.</li>
     *   <li>otherwise: returns a new data set containing half of the positions (rounded down)
     *       of every element of {@code dataSet}, drawn without replacement.</li>
     * </ul>
     *
     * <p>NOTE(review): the generator is re-created with the fixed seed 127 on every call, so for
     * a given element length the same positions are drawn each time — confirm this is intended.
     *
     * @param dataSet  data set with all features
     * @param dataSet2 first partial data set (the caller passes the DNase features)
     * @param dataSet3 second partial data set (the caller passes the motif features)
     * @param dataSet4 data set returned by the previous call, or {@code null} on the first call
     * @return the data set to score next
     * @throws EmptyDataSetException  if the random sub-selection yields an empty data set
     * @throws WrongAlphabetException if the sub-selected sequences cannot form a data set
     */
    private DataSet getCurr(DataSet dataSet, DataSet dataSet2, DataSet dataSet3, DataSet dataSet4) throws EmptyDataSetException, WrongAlphabetException {
        if (this.select == Select.ALTERNATE) {
            if (dataSet4 == null) {
                return (this.init == Init.MOTIF || this.init == Init.BOTH) ? dataSet2 : dataSet3;
            }
            return dataSet4 == dataSet2 ? dataSet3 : dataSet2;
        }
        if (this.select == Select.FULL) {
            return dataSet;
        }

        Random random = new Random(127L);
        int[] positions = new int[dataSet.getElementLength()];
        for (int p = 0; p < positions.length; p++) {
            positions[p] = p;
        }
        // Partial Fisher-Yates shuffle: slot k receives a random entry from the not-yet-fixed
        // tail [k, length), so the chosen positions are pairwise distinct. The previous code
        // never wrote the displaced entry back and drew from the whole range, which allowed
        // the same position to be chosen several times.
        int[] chosen = new int[positions.length / 2];
        for (int k = 0; k < chosen.length; k++) {
            int pick = k + random.nextInt(positions.length - k);
            int displaced = positions[k];
            positions[k] = positions[pick];
            positions[pick] = displaced;
            chosen[k] = positions[k];
        }
        // Every chosen position contributes a piece of length one.
        int[] lengths = new int[chosen.length];
        Arrays.fill(lengths, 1);
        Sequence[] reduced = new Sequence[dataSet.getNumberOfElements()];
        for (int s = 0; s < reduced.length; s++) {
            reduced[s] = dataSet.getElementAt(s).getCompositeSequence(chosen, lengths);
        }
        return new DataSet("", reduced);
    }

    /**
     * Converts per-element values into the pair of weight vectors used for training.
     *
     * @param dArr values handed to {@code Interpolation.getWeight} with the {@code RANK_LOG} method
     * @param d    parameter forwarded to {@code Interpolation.getWeight}
     * @return a two-row array: row 0 holds the weights returned by
     *         {@code Interpolation.getWeight}, row 1 the result of
     *         {@code Interpolation.getBgWeight} applied to row 0
     * @throws Exception if the interpolation fails
     */
    private double[][] updateWeights(double[] dArr, double d) throws Exception {
        double[] foreground = Interpolation.getWeight(null, dArr, d, Interpolation.RANK_LOG);
        double[] background = Interpolation.getBgWeight(foreground);
        // The element type must be double[]; the decompiled "new double[]{...}" did not type-check.
        return new double[][]{foreground, background};
    }

    /**
     * Combines the previous values with fresh classifier scores.
     *
     * <p>Both the previous values and the scores of {@code genDisMixClassifier} on
     * {@code dataSet} are passed through {@code ToolBox.zscore}; the result is
     * {@code zscore(scores)[k] + d * zscore(dArr)[k]} for every index {@code k} of the scores.
     *
     * @param dArr                previous values, one per element of {@code dataSet}
     * @param genDisMixClassifier classifier used for scoring
     * @param dataSet             data set to score
     * @param d                   factor applied to the standardized previous values
     * @return the combined values (the array returned by {@code ToolBox.zscore} for the scores,
     *         updated in place)
     * @throws Exception if scoring fails
     */
    private static double[] updateVals(double[] dArr, GenDisMixClassifier genDisMixClassifier, DataSet dataSet, double d) throws Exception {
        double[] previous = ToolBox.zscore(dArr);
        double[] combined = ToolBox.zscore(genDisMixClassifier.getScores(dataSet));
        int idx = 0;
        while (idx < combined.length) {
            combined[idx] += previous[idx] * d;
            idx++;
        }
        return combined;
    }

    /**
     * Reads the feature vectors for the given names and derives the initial values and weights.
     *
     * <p>The reader is reset, then positioned on each name in turn via {@code findChr}. For
     * every name at most {@code sizes.get(name)} feature vectors are collected; collection for a
     * name also stops as soon as {@code readNextFeatureVector()} returns {@code false}. A name
     * missing from {@code sizes} results in a {@link NullPointerException}.
     *
     * <p>Initial values depend on {@code init}: the motif maxima (MOTIF), the DNase minima
     * (DNASE), or (BOTH) the sum of both after each has been passed through
     * {@code ToolBox.zscore} and shifted so that its minimum is zero.
     *
     * @param linkedList  names to read, in order
     * @param d           parameter forwarded to {@code updateWeights}
     * @param linkedList2 receives the result of {@code getCurrentSequence()} for every vector
     * @param linkedList3 receives the result of {@code getCurrentDNaseSequence()} for every vector
     * @param linkedList4 receives the result of {@code getCurrentMotifsSequence()} for every vector
     * @return the initial values together with the weights derived from them
     * @throws Exception if reading or the weight computation fails
     */
    private Pair<double[], double[][]> getInitialWeights(LinkedList<String> linkedList, double d, LinkedList<Sequence> linkedList2, LinkedList<Sequence> linkedList3, LinkedList<Sequence> linkedList4) throws Exception {
        this.reader.reset();
        DoubleList motifValues = new DoubleList();
        DoubleList dnaseValues = new DoubleList();
        // Iterate instead of calling get(index), which walks the linked list on every access.
        for (String name : linkedList) {
            this.reader.findChr(name);
            int limit = this.sizes.get(name).intValue();
            for (int count = 0; count < limit; count++) {
                if (this.init == Init.BOTH || this.init == Init.MOTIF) {
                    motifValues.add(this.reader.getCurrentMotifMax(0));
                }
                if (this.init == Init.BOTH || this.init == Init.DNASE) {
                    dnaseValues.add(this.reader.getCurrentDNaseMin());
                }
                linkedList2.add(this.reader.getCurrentSequence());
                linkedList3.add(this.reader.getCurrentDNaseSequence());
                linkedList4.add(this.reader.getCurrentMotifsSequence());
                if (!this.reader.readNextFeatureVector()) {
                    break;
                }
            }
        }

        double[] motif = motifValues.toArray();
        double[] dnase = dnaseValues.toArray();
        double[] values;
        if (this.init == Init.MOTIF) {
            values = motif;
        } else if (this.init == Init.DNASE) {
            values = dnase;
        } else {
            // Use the returned arrays (as updateVals does) so the standardization takes effect
            // whether or not zscore also modifies its argument in place.
            motif = ToolBox.zscore(motif);
            dnase = ToolBox.zscore(dnase);
            double motifMin = ToolBox.min(motif);
            for (int k = 0; k < motif.length; k++) {
                motif[k] -= motifMin;
            }
            double dnaseMin = ToolBox.min(dnase);
            for (int k = 0; k < dnase.length; k++) {
                dnase[k] -= dnaseMin;
            }
            // Both lists receive exactly one entry per vector in this mode, so the lengths match.
            for (int k = 0; k < motif.length; k++) {
                motif[k] += dnase[k];
            }
            values = motif;
        }
        return new Pair<>(values, updateWeights(values, d));
    }
}
