package projects.plantdream;

import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.classifiers.performanceMeasures.AucPR;
import de.jstacs.classifiers.performanceMeasures.AucROC;
import de.jstacs.classifiers.performanceMeasures.NumericalPerformanceMeasureParameterSet;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.alphabets.ContinuousAlphabet;
import de.jstacs.data.sequences.annotation.SequenceAnnotationParser;
import de.jstacs.io.FileManager;
import de.jstacs.io.SparseStringExtractor;
import de.jstacs.sequenceScores.statisticalModels.differentiable.continuous.GaussianNetwork;
import de.jstacs.utils.Pair;
import de.jstacs.utils.ToolBox;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.util.zip.GZIPInputStream;
import projects.dimont.Interpolation;

/* loaded from: input_file:projects/plantdream/TrainUnsupervisedAlternateShift.class */
/**
 * Trains a classifier on continuous data without class labels by alternating between two feature
 * views of the same records: motif features and DNase features (see {@link Init}).
 * <p>
 * The same data set is passed for both classes; class membership is expressed only through
 * per-sequence weights derived from a score vector. Each iteration does three things:
 * <ol>
 * <li>trains a {@code ShiftedMixtureDiffSM}-based classifier on the current view with the current
 * weights;</li>
 * <li>updates the score vector from that classifier's scores;</li>
 * <li>switches to the other view.</li>
 * </ol>
 * Finally, a {@link GaussianNetwork} classifier is trained on all features with the final
 * weights. It is written to XML and evaluated on the training and test data.
 */
public class TrainUnsupervisedAlternateShift {

    /** Number of windows each feature view is split into (see the offsets in {@link #main}). */
    private static final int NUM_WINDOWS = 5;

    /** Source of the initial score vector. */
    public enum Init {
        /** Initial scores are read from the motif feature view. */
        MOTIF,
        /** Initial scores are read from the DNase feature view. */
        DNASE,
        /** Initial scores combine both views (z-scored, shifted to min 0, then summed or multiplied). */
        BOTH;

        /**
         * Returns a fresh copy of all constants. This is equivalent to {@link #values()}; it is
         * kept only because it is part of the public interface.
         *
         * @return all constants of this enum, in declaration order
         */
        public static Init[] valuesCustom() {
            Init[] all = values();
            Init[] copy = new Init[all.length];
            System.arraycopy(all, 0, copy, 0, all.length);
            return copy;
        }
    }

    /**
     * Runs the alternating training procedure.
     *
     * @param strArr command-line arguments, by index:
     *     <ol start="0">
     *     <li>gzipped training file; also the prefix of the {@code _UAW.xml} output file</li>
     *     <li>second gzipped training file (united with the first)</li>
     *     <li>gzipped test file, passed first to the final evaluation</li>
     *     <li>gzipped test file, passed second to the final evaluation</li>
     *     <li>int passed as last argument of {@code GenDisMixClassifierParameterSet}
     *         (presumably the number of threads &mdash; confirm against Jstacs)</li>
     *     <li>double parameter passed to {@code Interpolation.getWeight}</li>
     *     <li>number of alternating training iterations</li>
     *     <li>double used when updating scores: additive offset (multiplicative mode) or divisor
     *         (additive mode)</li>
     *     <li>{@code true} to combine scores multiplicatively, {@code false} additively</li>
     *     <li>name of the {@link Init} strategy ({@code MOTIF}, {@code DNASE} or {@code BOTH})</li>
     *     </ol>
     * @throws IllegalArgumentException if fewer than 10 arguments are given
     * @throws Exception if reading data, training or evaluation fails
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr.length < 10) {
            throw new IllegalArgumentException("Usage: TrainUnsupervisedAlternateShift <train1.gz> <train2.gz> "
                    + "<test1.gz> <test2.gz> <threads> <interpolation> <iterations> <updateParam> "
                    + "<multiply:true|false> <MOTIF|DNASE|BOTH>");
        }
        int threads = Integer.parseInt(strArr[4]);
        double interpolation = Double.parseDouble(strArr[5]);
        int iterations = Integer.parseInt(strArr[6]);
        double updateParam = Double.parseDouble(strArr[7]);
        boolean multiply = Boolean.parseBoolean(strArr[8]);
        Init init = Init.valueOf(strArr[9]);

        AlphabetContainer alphabetContainer = new AlphabetContainer(new ContinuousAlphabet());
        DataSet union = DataSet.union(
                readGzippedSparse(alphabetContainer, strArr[0]),
                readGzippedSparse(alphabetContainer, strArr[1]));

        // Each record holds five windows: 9 DNase values at offsets 0, 11, 22, 33, 44 and
        // 2 motif values at offsets 9, 20, 31, 42, 53.
        DataSet motifFeatures = union.getCompositeDataSet(new int[]{9, 20, 31, 42, 53}, new int[]{2, 2, 2, 2, 2});
        DataSet dnaseFeatures = union.getCompositeDataSet(new int[]{0, 11, 22, 33, 44}, new int[]{9, 9, 9, 9, 9});

        Pair<double[], double[][]> initial = getWeights(motifFeatures, dnaseFeatures,
                centerOffset(motifFeatures), centerOffset(dnaseFeatures), init, interpolation, multiply);
        double[] values = initial.getFirstElement();
        double[][] weights = initial.getSecondElement();

        NumericalPerformanceMeasureParameterSet measures =
                new NumericalPerformanceMeasureParameterSet(new AucROC(), new AucPR());

        // Alternate views: DNase first, then motif, then DNase again, ...
        DataSet current = dnaseFeatures;
        DataSet other = motifFeatures;
        for (int iteration = 0; iteration < iterations; iteration++) {
            int windowLength = current.getElementLength() / NUM_WINDOWS;
            ShiftedMixtureDiffSM model = new ShiftedMixtureDiffSM(current.getElementLength(),
                    new GaussianNetwork(new int[windowLength * 3][0]), windowLength);
            GenDisMixClassifier classifier = new GenDisMixClassifier(
                    createParameters(alphabetContainer, current.getElementLength(), threads),
                    (LogPrior) null, LearningPrinciple.MCL, model, model);
            // Same data for both classes; the weights carry the soft class assignment.
            classifier.train(new DataSet[]{current, current}, weights);
            System.out.println(classifier.evaluate(measures, true, new DataSet[]{current, current}, weights));

            values = updateVals(values, classifier, current, updateParam, multiply);
            weights = getWeights(values, interpolation);

            DataSet swap = current;
            current = other;
            other = swap;
        }

        double[][] finalWeights = getWeights(values, interpolation);
        GaussianNetwork network = new GaussianNetwork(new int[union.getElementLength()][0]);
        GenDisMixClassifier finalClassifier = new GenDisMixClassifier(
                createParameters(alphabetContainer, union.getElementLength(), threads),
                (LogPrior) null, LearningPrinciple.MCL, network, network);
        finalClassifier.train(new DataSet[]{union, union}, finalWeights);
        FileManager.writeFile(strArr[0] + "_UAW.xml", finalClassifier.toXML());
        System.out.println(finalClassifier.evaluate(measures, true, new DataSet[]{union, union}, finalWeights));

        DataSet testFirst = readGzippedSparse(alphabetContainer, strArr[2]);
        DataSet testSecond = readGzippedSparse(alphabetContainer, strArr[3]);
        System.out.println(finalClassifier.evaluate(measures, true, testFirst, testSecond));
    }

    /**
     * Reads a gzipped, tab-delimited file into a {@link DataSet} and closes the file afterwards.
     * <p>
     * NOTE(review): this assumes the {@code DataSet} constructor consumes the extractor
     * completely, so the underlying stream can be closed once it returns.
     */
    private static DataSet readGzippedSparse(AlphabetContainer alphabet, String path) throws Exception {
        // Separate resources so the file is closed even if the gzip header is invalid.
        try (FileInputStream file = new FileInputStream(path);
             Reader reader = new InputStreamReader(new GZIPInputStream(file))) {
            return new DataSet(alphabet,
                    new SparseStringExtractor(reader, '>', "", (SequenceAnnotationParser) null), "\t");
        }
    }

    /** Builds the classifier parameter set shared by the per-view and final classifiers. */
    private static GenDisMixClassifierParameterSet createParameters(AlphabetContainer alphabet, int length,
            int threads) throws Exception {
        return new GenDisMixClassifierParameterSet(alphabet, length, (byte) 20, 1.0E-6d, 1.0E-6d, 1.0E-4d,
                false, OptimizableFunction.KindOfParameter.ZEROS, true, threads);
    }

    /**
     * Returns the position of the first value of the middle window:
     * {@code windowLength * (NUM_WINDOWS - 1) / 2}.
     */
    private static int centerOffset(DataSet features) {
        return ((features.getElementLength() / NUM_WINDOWS) * (NUM_WINDOWS - 1)) / 2;
    }

    /**
     * Combines the previous score vector with the scores of {@code classifier} on {@code data}.
     * Both vectors are z-scored first.
     * <ul>
     * <li>Multiplicative mode: both vectors are shifted to a minimum of 0, and the result is
     * {@code (new + factor) * old}.</li>
     * <li>Additive mode: the result is {@code new / factor + old}.</li>
     * </ul>
     *
     * @param previousValues previous score vector, one entry per sequence of {@code data}
     * @param classifier trained classifier providing the new scores
     * @param data data set the classifier is applied to
     * @param factor offset (multiplicative mode) or divisor (additive mode)
     * @param multiply selects the multiplicative mode
     * @return the updated score vector
     */
    private static double[] updateVals(double[] previousValues, GenDisMixClassifier classifier, DataSet data,
            double factor, boolean multiply) throws Exception {
        double[] previous = ToolBox.zscore(previousValues);
        double[] updated = ToolBox.zscore(classifier.getScores(data));
        if (multiply) {
            shiftToZero(previous);
            shiftToZero(updated);
            for (int k = 0; k < updated.length; k++) {
                updated[k] = (updated[k] + factor) * previous[k];
            }
        } else {
            for (int k = 0; k < updated.length; k++) {
                updated[k] = (updated[k] / factor) + previous[k];
            }
        }
        return updated;
    }

    /** Subtracts the minimum of {@code values} from every entry in place and returns {@code values}. */
    private static double[] shiftToZero(double[] values) {
        double min = ToolBox.min(values);
        for (int k = 0; k < values.length; k++) {
            values[k] -= min;
        }
        return values;
    }

    /** Returns the value at {@code position} of every sequence in {@code data}. */
    private static double[] column(DataSet data, int position) throws Exception {
        double[] values = new double[data.getNumberOfElements()];
        for (int n = 0; n < values.length; n++) {
            values[n] = data.getElementAt(n).continuousVal(position);
        }
        return values;
    }

    /**
     * Computes the initial score vector according to {@code init} and derives class weights from it.
     *
     * @param motifFeatures motif view; also passed to {@code Interpolation.getWeight}
     * @param dnaseFeatures DNase view
     * @param motifColumn position read from each motif-view sequence
     * @param dnaseColumn position read from each DNase-view sequence
     * @param init which view(s) provide the initial scores
     * @param interpolation parameter passed to {@code Interpolation.getWeight}
     * @param multiply for {@link Init#BOTH}, combine the views by product instead of sum
     * @return the score vector and the {@code {weights, backgroundWeights}} array
     */
    private static Pair<double[], double[][]> getWeights(DataSet motifFeatures, DataSet dnaseFeatures,
            int motifColumn, int dnaseColumn, Init init, double interpolation, boolean multiply) throws Exception {
        double[] values;
        switch (init) {
            case MOTIF:
                values = column(motifFeatures, motifColumn);
                break;
            case DNASE:
                values = column(dnaseFeatures, dnaseColumn);
                break;
            case BOTH:
                values = combineViews(column(motifFeatures, motifColumn), column(dnaseFeatures, dnaseColumn),
                        multiply);
                break;
            default:
                throw new IllegalArgumentException("Unknown initialization: " + init);
        }
        double[] weight = Interpolation.getWeight(motifFeatures, values, interpolation, Interpolation.RANK_LOG);
        return new Pair<>(values, new double[][]{weight, Interpolation.getBgWeight(weight)});
    }

    /**
     * Z-scores both vectors, shifts each to a minimum of 0, and combines them element-wise by
     * product or sum. The z-scored result is used; the original discarded it.
     */
    private static double[] combineViews(double[] motif, double[] dnase, boolean multiply) {
        double[] combined = shiftToZero(ToolBox.zscore(motif));
        double[] second = shiftToZero(ToolBox.zscore(dnase));
        for (int k = 0; k < combined.length; k++) {
            combined[k] = multiply ? combined[k] * second[k] : combined[k] + second[k];
        }
        return combined;
    }

    /** Derives {@code {weights, backgroundWeights}} from a score vector. */
    private static double[][] getWeights(double[] values, double interpolation) throws Exception {
        double[] weight = Interpolation.getWeight((DataSet) null, values, interpolation, Interpolation.RANK_LOG);
        return new double[][]{weight, Interpolation.getBgWeight(weight)};
    }
}
