package projects.encodedream;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.GaussianNetwork;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:projects/encodedream/IterativeTraining.class */
/**
 * Trains a growing list of {@link GenDisMixClassifier}s.
 *
 * <p>The first classifier is trained on the initial data delivered by the {@link FeatureReader}.
 * Each further round:
 * <ol>
 *   <li>scores the prediction chromosomes with the ensemble of all classifiers trained so far;</li>
 *   <li>adds positions labelled {@code 'U'} whose score reaches a data-driven threshold to the
 *       second (index 1) training data set, with weight 1;</li>
 *   <li>trains one more classifier on the enlarged data.</li>
 * </ol>
 */
public class IterativeTraining {

    /**
     * Scales the total weight of the current index-1 data set into the number of {@code 'U'}
     * positions that may be selected per round (see {@link #computeThreshold}).
     */
    private static final double NEW_NEGATIVE_FRACTION = 0.15d;

    private final FeatureReader reader;
    private final int threads;
    /** Number of positions to visit per chromosome name. */
    private final HashMap<String, Integer> sizes;

    /**
     * @param featureReader source of training data, labels and feature vectors
     * @param i number of threads handed to the classifier's parameter set
     * @param hashMap number of positions per chromosome name; must contain every chromosome
     *        later passed to {@link #iterativeTraining}
     */
    public IterativeTraining(FeatureReader featureReader, int i, HashMap<String, Integer> hashMap) {
        this.reader = featureReader;
        this.threads = i;
        this.sizes = hashMap;
    }

    /**
     * Trains a single MCL {@link GenDisMixClassifier} with a {@link DoesNothingLogPrior}.
     *
     * <p>The same {@link GaussianNetwork} instance is passed for both classes. It is built with
     * an empty parent array for each of the {@code getElementLength()} dimensions (presumably a
     * network without edges — TODO confirm against the Jstacs docs).
     *
     * @param dataSetArr one data set per class; element 0 supplies alphabet and element length
     * @param dArr per-class, per-sequence weights, passed unchanged to {@code train}
     * @return the trained classifier
     * @throws Exception if construction or numerical optimisation fails
     */
    public GenDisMixClassifier train(DataSet[] dataSetArr, double[][] dArr) throws Exception {
        int length = dataSetArr[0].getElementLength();
        GaussianNetwork gaussianNetwork = new GaussianNetwork(new int[length][0]);
        // Positional optimiser settings. Per the Jstacs API these are presumably the algorithm
        // code, eps, line-search eps and start distance — TODO confirm.
        GenDisMixClassifierParameterSet params = new GenDisMixClassifierParameterSet(
                dataSetArr[0].getAlphabetContainer(), gaussianNetwork.getLength(), (byte) 20,
                1.0E-6d, 1.0E-6d, 1.0E-4d, false, OptimizableFunction.KindOfParameter.PLUGIN,
                true, this.threads);
        GenDisMixClassifier classifier = new GenDisMixClassifier(params,
                DoesNothingLogPrior.defaultInstance, LearningPrinciple.MCL, gaussianNetwork,
                gaussianNetwork);
        classifier.train(dataSetArr, dArr);
        return classifier;
    }

    /**
     * Runs iterative training and returns all classifiers trained along the way.
     *
     * @param i total number of training rounds; values below 1 still yield the initial classifier
     * @param hashSet chromosomes handed to {@link FeatureReader#getInitialData}; may be
     *        {@code null}, otherwise it must contain every entry of {@code linkedList}
     * @param linkedList chromosomes that are scored and mined for new {@code 'U'} examples
     * @param d percentile (as passed to {@link ToolBox#percentile}) of the {@code 'S'}/{@code 'B'}
     *        scores used as a lower bound for the selection threshold
     * @param i2 forwarded unchanged to the {@link Predictor} constructor
     * @param i3 forwarded unchanged to the {@link Predictor} constructor
     * @return the classifiers in training order: {@code max(1, i)} entries
     * @throws IllegalArgumentException if {@code hashSet} does not contain all of
     *         {@code linkedList}, or a chromosome has no entry in {@code sizes}
     * @throws Exception if reading, predicting or training fails
     */
    public GenDisMixClassifier[] iterativeTraining(int i, HashSet<String> hashSet,
            LinkedList<String> linkedList, double d, int i2, int i3) throws Exception {
        if (hashSet != null && !hashSet.containsAll(linkedList)) {
            throw new IllegalArgumentException("Chromosomes " + linkedList
                    + " are not all contained in the training chromosomes " + hashSet);
        }
        Pair<DataSet[], double[][]> initialData = this.reader.getInitialData(hashSet);
        LinkedList<GenDisMixClassifier> classifiers = new LinkedList<GenDisMixClassifier>();
        classifiers.add(train(initialData.getFirstElement(), initialData.getSecondElement()));

        DataSet foreground = initialData.getFirstElement()[0];
        DataSet background = initialData.getFirstElement()[1];
        // NOTE: weights[1] is replaced (not copied) every round, so the Pair's array is mutated.
        double[][] weights = initialData.getSecondElement();

        for (int round = 1; round < i; round++) {
            Predictor predictor = new Predictor(classifiers.toArray(new GenDisMixClassifier[0]),
                    this.reader, i2, i3);
            double[][] scores = predictAll(predictor, linkedList);
            double threshold = computeThreshold(scores, linkedList, d, weights[1]);
            LinkedList<?> newNegatives = collectUnboundAtOrAbove(scores, linkedList, threshold);

            // A DataSet cannot be built from an empty collection, so skip the union when nothing
            // was selected; replaceNaN is still applied as in the non-empty case.
            DataSet merged = newNegatives.isEmpty()
                    ? background
                    : DataSet.union(background, newDataSet(newNegatives));
            background = FeatureReader.replaceNaN(merged);

            int oldLength = weights[1].length;
            double[] grown = Arrays.copyOf(weights[1], oldLength + newNegatives.size());
            Arrays.fill(grown, oldLength, grown.length, 1.0d); // new examples get weight 1
            weights[1] = grown;

            classifiers.add(train(new DataSet[] {foreground, background}, weights));
        }
        return classifiers.toArray(new GenDisMixClassifier[0]);
    }

    /**
     * Scores every chromosome in order.
     *
     * <p>Assumes {@link Predictor#predict} returns one score per position as {@code double[]}.
     */
    private double[][] predictAll(Predictor predictor, List<String> chromosomes) throws Exception {
        double[][] scores = new double[chromosomes.size()][];
        int index = 0;
        for (String chr : chromosomes) {
            scores[index++] = predictor.predict(chr, chromosomeSize(chr));
        }
        return scores;
    }

    /**
     * Computes the score threshold for selecting new {@code 'U'} examples.
     *
     * <p>The threshold is the maximum of two values:
     * <ul>
     *   <li>the {@code positivePercentile} percentile of the {@code 'S'}/{@code 'B'} scores;</li>
     *   <li>the {@code 'U'}-score percentile at fraction
     *       {@code 1 - NEW_NEGATIVE_FRACTION * sum(backgroundWeights) / #U}, clamped to [0,1].
     *       This presumably admits about {@code NEW_NEGATIVE_FRACTION * sum(backgroundWeights)}
     *       positions, assuming {@link ToolBox#percentile} returns a fractional quantile.</li>
     * </ul>
     *
     * <p>Edge cases:
     * <ul>
     *   <li>with no {@code 'U'} positions, {@link Double#POSITIVE_INFINITY} is returned;</li>
     *   <li>with no {@code 'S'}/{@code 'B'} positions, only the {@code 'U'} bound is used.</li>
     * </ul>
     */
    private double computeThreshold(double[][] scores, List<String> chromosomes,
            double positivePercentile, double[] backgroundWeights) throws Exception {
        DoubleList positiveScores = new DoubleList();
        DoubleList unboundScores = new DoubleList();
        this.reader.reset();
        int index = 0;
        for (String chr : chromosomes) {
            this.reader.findChr(chr);
            int size = chromosomeSize(chr);
            double[] chrScores = scores[index++];
            for (int pos = 0; pos < size; pos++) {
                char label = this.reader.getCurrentLabel();
                if (label == 'S' || label == 'B') {
                    positiveScores.add(chrScores[pos]);
                } else if (label == 'U') {
                    unboundScores.add(chrScores[pos]);
                }
                if (!this.reader.readNextFeatureVector()) {
                    break;
                }
            }
        }
        if (unboundScores.length() == 0) {
            return Double.POSITIVE_INFINITY; // nothing can be selected anyway
        }
        double positiveBound = positiveScores.length() == 0
                ? Double.NEGATIVE_INFINITY
                : ToolBox.percentile(positiveScores.toArray(), positivePercentile);
        double fraction = 1.0d
                - (ToolBox.sum(backgroundWeights) * NEW_NEGATIVE_FRACTION) / unboundScores.length();
        fraction = Math.min(1.0d, Math.max(0.0d, fraction));
        double unboundBound = ToolBox.percentile(unboundScores.toArray(), fraction);
        return Math.max(positiveBound, unboundBound);
    }

    /**
     * Collects, in reader order, the current sequence of every {@code 'U'} position whose score
     * is {@code >= threshold}.
     *
     * <p>The element type is whatever {@link FeatureReader#getCurrentSequence} returns, hence the
     * raw list.
     */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private LinkedList<?> collectUnboundAtOrAbove(double[][] scores, List<String> chromosomes,
            double threshold) throws Exception {
        LinkedList selected = new LinkedList();
        this.reader.reset();
        int index = 0;
        for (String chr : chromosomes) {
            this.reader.findChr(chr);
            int size = chromosomeSize(chr);
            double[] chrScores = scores[index++];
            for (int pos = 0; pos < size; pos++) {
                if (chrScores[pos] >= threshold && this.reader.getCurrentLabel() == 'U') {
                    selected.add(this.reader.getCurrentSequence());
                }
                if (!this.reader.readNextFeatureVector()) {
                    break;
                }
            }
        }
        return selected;
    }

    /** Wraps the collected sequences in an unannotated {@link DataSet}. */
    @SuppressWarnings({"rawtypes", "unchecked"})
    private static DataSet newDataSet(LinkedList<?> sequences) throws Exception {
        return new DataSet("", (LinkedList) sequences);
    }

    /** Looks up the number of positions of {@code chr}, failing with a clear message if absent. */
    private int chromosomeSize(String chr) {
        Integer size = this.sizes.get(chr);
        if (size == null) {
            throw new IllegalArgumentException("No size known for chromosome " + chr);
        }
        return size.intValue();
    }
}
