package projects.tals.linear;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.data.sequences.annotation.SequenceAnnotation;
import de.jstacs.results.Result;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.utils.ComparableElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.ToolBox;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import org.apache.batik.util.XMLConstants;

/* loaded from: input_file:projects/tals/linear/MSDFunction.class */
/**
 * Multi-threaded objective function measuring the weighted squared deviation between the log
 * score of a {@link DifferentiableSequenceScore} and a numeric target value stored as a sequence
 * annotation, plus an optional penalty on the parameters.
 *
 * <p>For every sequence the term {@code weight * (logScore - target)^2} is accumulated. A penalty
 * of {@code con * |p|} (Laplace) or {@code con * p^2} is added for every parameter {@code p} whose
 * index is at least {@code penaltyOff}. Value and gradient are finally divided by the superclass
 * field {@code sum[cl]}.
 *
 * <p>Each worker processes a contiguous range of complete data sets; partial data sets are
 * rejected with a {@link RuntimeException}.
 */
public class MSDFunction extends AbstractMultiThreadedOptimizableFunction {
    /** Largest number of mismatches a sequence may have and still pass the mismatch filter. */
    private static final int MAX_MISMATCHES = 3;

    /** One score per worker; slot 0 is the instance given to the constructor, the others are filled by {@link #reset()}. */
    private DifferentiableSequenceScore[] scores;
    /** Target values, indexed as [worker][data set - first data set of the worker][sequence]. */
    private double[][][] yi;
    /** Per-worker gradient accumulators, indexed as [worker][parameter]. */
    private double[][] grads;
    /** Per-worker partial function values. */
    private double[] vals;
    /** Per-worker scratch lists receiving the parameter indices of the partial derivatives. */
    private IntList[] indices;
    /** Per-worker scratch lists receiving the partial derivatives. */
    private DoubleList[] partDers;
    /** Multiplier of the penalty term. */
    private double con;
    /** {@code true} for an absolute-value penalty, {@code false} for a squared penalty. */
    private boolean laplace;
    /** Index of the first penalised parameter. */
    private int penaltyOff;
    /** Annotation type whose identifier holds the numeric target value of a sequence. */
    private String sortTag;

    /**
     * Creates the function and precomputes the target values for the initial worker layout.
     *
     * @param con multiplier of the penalty term
     * @param laplace {@code true} for an absolute-value penalty, {@code false} for a squared one
     * @param score the score to be fitted; stored in slot 0 of the per-worker score array
     * @param threads number of workers; forwarded to the superclass and used to size the
     *        per-worker score array
     * @param data the data sets, forwarded to the superclass
     * @param weights the sequence weights, forwarded to the superclass
     * @param penaltyOff index of the first parameter that is penalised
     * @param sortTag annotation type whose identifier is parsed as the target value
     * @throws IllegalArgumentException declared for the superclass constructor
     */
    public MSDFunction(double con, boolean laplace, DifferentiableSequenceScore score, int threads, DataSet[] data, double[][] weights, int penaltyOff, String sortTag) throws IllegalArgumentException {
        super(threads, data, weights, false, false);
        this.scores = new DifferentiableSequenceScore[threads];
        this.scores[0] = score;
        precomputeIndexes(sortTag);
        this.con = con;
        this.laplace = laplace;
        this.penaltyOff = penaltyOff;
        this.sortTag = sortTag;
    }

    /**
     * Returns the number of parameters of the score in slot 0.
     *
     * @return the dimension of the parameter space
     */
    @Override // de.jstacs.algorithms.optimization.Function
    public int getDimensionOfScope() {
        return this.scores[0].getNumberOfParameters();
    }

    /**
     * Verifies that a worker range covers complete data sets only.
     *
     * @param startSeq first sequence index of the range; must be 0
     * @param endClass index of the last data set of the range
     * @param endSeq end sequence index; must equal the size of the last data set
     * @throws RuntimeException if the range starts or ends inside a data set
     */
    private void requireWholeDataSets(int startSeq, int endClass, int endSeq) {
        if (startSeq != 0 || endSeq != this.data[endClass].getNumberOfElements()) {
            throw new RuntimeException("MSDFunction only supports workers that process complete data sets (startSeq=" + startSeq + ", endSeq=" + endSeq + ", endClass=" + endClass + ")");
        }
    }

    /**
     * Accumulates the gradient of the squared-deviation terms of one worker into
     * {@code grads[index]}.
     *
     * @param index worker index
     * @param startClass first data set of the worker (inclusive)
     * @param startSeq must be 0
     * @param endClass last data set of the worker (inclusive)
     * @param endSeq must equal the size of data set {@code endClass}
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction
    protected void evaluateGradientOfFunction(int index, int startClass, int startSeq, int endClass, int endSeq) {
        requireWholeDataSets(startSeq, endClass, endSeq);
        final double[] gradient = this.grads[index];
        final IntList paramIndices = this.indices[index];
        final DoubleList derivatives = this.partDers[index];
        Arrays.fill(gradient, 0.0d);
        for (int c = startClass; c <= endClass; c++) {
            final int size = this.data[c].getNumberOfElements();
            for (int s = 0; s < size; s++) {
                paramIndices.clear();
                derivatives.clear();
                double logScore = this.scores[index].getLogScoreAndPartialDerivation(this.data[c].getElementAt(s), paramIndices, derivatives);
                // Outer derivative of weight * (logScore - target)^2.
                double outer = 2.0d * (logScore - this.yi[index][c - startClass][s]) * this.weights[c][s];
                for (int k = 0; k < paramIndices.length(); k++) {
                    gradient[paramIndices.get(k)] += outer * derivatives.get(k);
                }
            }
        }
    }

    /**
     * Sums the per-worker gradients into slot 0, adds the penalty gradient and divides by
     * {@code sum[cl]}.
     *
     * @return a copy of the combined gradient
     * @throws EvaluationException declared by the superclass
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction
    protected double[] joinGradients() throws EvaluationException {
        final double[] total = this.grads[0];
        for (int w = 1; w < this.grads.length; w++) {
            for (int p = 0; p < total.length; p++) {
                total[p] += this.grads[w][p];
            }
        }
        for (int p = Math.max(0, this.penaltyOff); p < this.params.length; p++) {
            if (this.laplace) {
                // Derivative of con * |p|; a parameter of exactly zero is treated as positive.
                total[p] += (this.params[p] < 0.0d ? -1 : 1) * this.con;
            } else {
                total[p] += 2.0d * this.params[p] * this.con;
            }
        }
        for (int p = 0; p < this.params.length; p++) {
            total[p] /= this.sum[this.cl];
        }
        return total.clone();
    }

    /**
     * Computes the weighted sum of squared deviations of one worker and stores it in
     * {@code vals[index]}.
     *
     * @param index worker index
     * @param startClass first data set of the worker (inclusive)
     * @param startSeq must be 0
     * @param endClass last data set of the worker (inclusive)
     * @param endSeq must equal the size of data set {@code endClass}
     * @throws EvaluationException declared by the superclass
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction
    protected void evaluateFunction(int index, int startClass, int startSeq, int endClass, int endSeq) throws EvaluationException {
        requireWholeDataSets(startSeq, endClass, endSeq);
        double total = 0.0d;
        for (int c = startClass; c <= endClass; c++) {
            final int size = this.data[c].getNumberOfElements();
            for (int s = 0; s < size; s++) {
                double deviation = this.scores[index].getLogScoreFor(this.data[c].getElementAt(s)) - this.yi[index][c - startClass][s];
                total += deviation * deviation * this.weights[c][s];
            }
        }
        this.vals[index] = total;
    }

    /**
     * Sums the per-worker values, adds the penalty and divides by {@code sum[cl]}.
     *
     * @return the function value
     * @throws EvaluationException declared by the superclass
     * @throws DimensionException declared by the superclass
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction
    protected double joinFunction() throws EvaluationException, DimensionException {
        double value = ToolBox.sum(this.vals);
        for (int p = Math.max(0, this.penaltyOff); p < this.params.length; p++) {
            if (this.laplace) {
                value += Math.abs(this.params[p]) * this.con;
            } else {
                value += this.params[p] * this.params[p] * this.con;
            }
        }
        return value / this.sum[this.cl];
    }

    /** Nothing to do: all parameters are set per worker in {@link #setParams(int)}. */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction
    protected void setThreadIndependentParameters() throws DimensionException {
    }

    /**
     * Replaces data and weights and re-reads the target values for the resulting worker layout.
     *
     * @param dataSetArr the new data sets
     * @param dArr the new weights
     * @throws IllegalArgumentException declared by the superclass
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction, de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractOptimizableFunction, de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction
    public void setDataAndWeights(DataSet[] dataSetArr, double[][] dArr) throws IllegalArgumentException {
        super.setDataAndWeights(dataSetArr, dArr);
        precomputeIndexes(this.sortTag);
    }

    /**
     * Allocates the per-worker buffers and parses the target value of every sequence from the
     * identifier of its annotation of type {@code str}.
     *
     * <p>Does nothing while the worker array is still {@code null}.
     *
     * @param str annotation type holding the target value
     * @throws RuntimeException if a worker range does not consist of complete data sets
     */
    public void precomputeIndexes(String str) {
        if (this.worker == null) {
            return;
        }
        final int workers = this.worker.length;
        this.vals = new double[workers];
        this.indices = new IntList[workers];
        this.partDers = new DoubleList[workers];
        this.grads = new double[workers][getDimensionOfScope()];
        this.yi = new double[workers][][];
        for (int w = 0; w < workers; w++) {
            this.indices[w] = new IntList();
            this.partDers[w] = new DoubleList();
            // Layout of the range: {startClass, startSeq, endClass, endSeq}.
            int[] range = this.worker[w].getIndices();
            int startClass = range[0];
            int endClass = range[2];
            requireWholeDataSets(range[1], endClass, range[3]);
            this.yi[w] = new double[(endClass - startClass) + 1][];
            for (int c = startClass; c <= endClass; c++) {
                final int size = this.data[c].getNumberOfElements();
                double[] targets = new double[size];
                for (int s = 0; s < size; s++) {
                    targets[s] = Double.parseDouble(this.data[c].getElementAt(s).getSequenceAnnotationByType(str, 0).getIdentifier());
                }
                this.yi[w][c - startClass] = targets;
            }
        }
    }

    /**
     * Copies the current parameters into the score of the given worker.
     *
     * @param i worker index
     * @throws DimensionException declared by the superclass
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction
    protected void setParams(int i) throws DimensionException {
        this.scores[i].setParameters(this.params, 0);
    }

    /**
     * Copies the current parameter values of the score in slot 0 into {@code dArr}. The requested
     * kind of parameters is not consulted.
     *
     * @param kindOfParameter ignored
     * @param dArr destination array, filled from index 0
     * @throws Exception declared by the superclass
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractOptimizableFunction
    public void getParameters(OptimizableFunction.KindOfParameter kindOfParameter, double[] dArr) throws Exception {
        double[] current = this.scores[0].getCurrentParameterValues();
        System.arraycopy(current, 0, dArr, 0, current.length);
    }

    /**
     * Fills the score slots of all workers except the first with clones of the score in slot 0.
     *
     * @throws Exception if cloning fails
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction
    public void reset() throws Exception {
        for (int w = 1; w < this.scores.length; w++) {
            this.scores[w] = this.scores[0].clone();
        }
    }

    /**
     * Assigns each worker a contiguous range of complete data sets.
     *
     * <p>The cost of a data set is {@code n * n * sqrt(averageLength)} with {@code n} its number
     * of sequences. Each worker greedily takes data sets while its load stays within the average
     * remaining cost per remaining worker; the last worker takes everything that is left. The
     * greedy extension never takes a data set that a later worker needs in order to receive at
     * least one.
     *
     * <p>NOTE(review): the decompiler reports that this method was originally {@code protected};
     * it is kept {@code public} so that existing callers keep compiling.
     *
     * @throws IllegalArgumentException if there are fewer data sets than workers
     * @throws RuntimeException if an existing worker is not waiting
     */
    @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction
    public void prepareThreads() {
        final int workers = this.worker.length;
        final int sets = this.data.length;
        if (sets < workers) {
            throw new IllegalArgumentException("Cannot distribute " + sets + " data set(s) over " + workers + " worker(s); every worker needs at least one complete data set");
        }
        double[] cost = new double[sets];
        for (int c = 0; c < sets; c++) {
            double size = this.data[c].getNumberOfElements();
            cost[c] = size * size * Math.sqrt(this.data[c].getAverageElementLength());
        }
        double remaining = ToolBox.sum(cost);
        double target = remaining / workers;
        int start = 0;
        for (int w = 0; w < workers; w++) {
            final int workersAfter = workers - 1 - w;
            // Leave one data set for each worker that still has to be served.
            final int lastAllowed = sets - 1 - workersAfter;
            int end = start;
            double load = cost[end];
            while (end < lastAllowed && load + cost[end + 1] <= target) {
                load += cost[end + 1];
                end++;
            }
            remaining -= load;
            if (workersAfter == 0) {
                end = sets - 1;
            }
            if (this.worker[w] == null) {
                this.worker[w] = new AbstractMultiThreadedOptimizableFunction.Worker(w, start, 0, end, this.data[end].getNumberOfElements());
                this.worker[w].start();
            } else {
                if (!this.worker[w].isWaiting()) {
                    stopThreads();
                    throw new RuntimeException("Worker " + w + " is not waiting; cannot reassign its data range");
                }
                this.worker[w].setIndices(start, 0, end, this.data[end].getNumberOfElements());
            }
            start = end + 1;
            if (workersAfter > 0) {
                target = remaining / workersAfter;
            }
        }
    }

    /**
     * Walks through values sorted in ascending order, subtracting their weights from the total,
     * and stops at the first position where the remaining weight falls below half of the total.
     * The result is the weight-averaged value of that position and its predecessor, or the value
     * itself if it is the first position.
     *
     * @param sorted values in ascending order
     * @param weights weights aligned with {@code sorted}
     * @param total sum of all weights
     * @param fallback result if the threshold is never crossed
     * @param verbose whether to print each comparison to standard output
     * @return the weighted median as described, possibly NaN if both involved weights are zero
     */
    private static double weightedMedian(double[] sorted, double[] weights, double total, double fallback, boolean verbose) {
        final double half = total * 0.5d;
        double remaining = total;
        for (int k = 0; k < sorted.length; k++) {
            remaining -= weights[k];
            if (verbose) {
                System.out.println(remaining + "<" + half + "?");
            }
            // The factor slightly above one presumably guards against rounding; TODO confirm.
            if (remaining * 1.000001d < half) {
                if (k == 0) {
                    // No predecessor to average with: the first value alone carries the majority.
                    return sorted[0];
                }
                double here = weights[k];
                double before = weights[k - 1];
                return ((sorted[k] * here) + (sorted[k - 1] * before)) / (here + before);
            }
        }
        return fallback;
    }

    /**
     * Computes a weighted median and the weighted median of the absolute deviations from it.
     *
     * <p>A median that is NaN or infinite is replaced by 0; a deviation that is NaN, infinite or
     * at most {@code 1.0E-6} is replaced by 1. The inputs are not modified.
     *
     * @param dArr the values
     * @param dArr2 the weights, aligned with the values
     * @param z whether to print intermediate values to standard output
     * @return a two-element array {@code {median, medianAbsoluteDeviation}}
     */
    private static double[] getMAD(double[] dArr, double[] dArr2, boolean z) {
        double[] values = dArr.clone();
        double[] weights = dArr2.clone();
        ToolBox.sortAlongWith(values, new double[][]{weights});
        if (z) {
            System.out.println(Arrays.toString(values) + " " + Arrays.toString(weights));
        }
        final double total = ToolBox.sum(weights);
        double median = weightedMedian(values, weights, total, 0.0d, z);
        if (Double.isNaN(median) || Double.isInfinite(median)) {
            median = 0.0d;
        }
        double[] deviations = new double[values.length];
        for (int k = 0; k < values.length; k++) {
            deviations[k] = Math.abs(median - values[k]);
        }
        // The weights are already aligned with the sorted values and therefore with the deviations.
        ToolBox.sortAlongWith(deviations, new double[][]{weights});
        double mad = weightedMedian(deviations, weights, total, 1.0d, false);
        if (mad <= 1.0E-6d || Double.isNaN(mad) || Double.isInfinite(mad)) {
            mad = 1.0d;
        }
        return new double[]{median, mad};
    }

    /**
     * Decides whether a sequence passes the mismatch filter.
     *
     * <p>A sequence whose "mask" annotation contains an {@code X} always passes. Otherwise each
     * position of the reference sequence is compared with the sequence position shifted by one:
     * {@code HD} expects code 1, {@code NI} code 0, {@code NG} code 3 and {@code NN} code 0 or 2
     * (presumably TAL effector RVDs against nucleotide codes; verify against the alphabets in
     * use). A reference symbol other than these four makes the sequence pass immediately.
     *
     * @param seq the annotated sequence
     * @return {@code true} if the sequence is kept
     * @throws WrongAlphabetException if the reference alphabet cannot encode one of the symbols
     */
    private static boolean passesMismatchFilter(Sequence seq) throws WrongAlphabetException {
        SequenceAnnotation mask = seq.getSequenceAnnotationByType("mask", 0);
        if (mask != null && mask.getIdentifier().indexOf("X") >= 0) {
            return true;
        }
        Sequence reference = ((ReferenceSequenceAnnotation) seq.getSequenceAnnotationByType(ReferenceSequenceAnnotation.TYPE, 0)).getReferenceSequence();
        AlphabetContainer alphabet = reference.getAlphabetContainer();
        int mismatches = 0;
        for (int pos = 0; pos < reference.getLength(); pos++) {
            int symbol = reference.discreteVal(pos);
            if (symbol == alphabet.getCode(pos, "HD")) {
                if (seq.discreteVal(pos + 1) != 1) {
                    mismatches++;
                }
            } else if (symbol == alphabet.getCode(pos, "NI")) {
                if (seq.discreteVal(pos + 1) != 0) {
                    mismatches++;
                }
            } else if (symbol == alphabet.getCode(pos, "NG")) {
                if (seq.discreteVal(pos + 1) != 3) {
                    mismatches++;
                }
            } else if (symbol == alphabet.getCode(pos, "NN")) {
                int observed = seq.discreteVal(pos + 1);
                if (observed != 0 && observed != 2) {
                    mismatches++;
                }
            } else {
                // Unknown reference symbol: counting is abandoned and the sequence is kept.
                return true;
            }
        }
        return mismatches <= MAX_MISMATCHES;
    }

    /**
     * Reorders data sets so that consecutive runs form bins of similar total cost.
     *
     * <p>Uses the same cost as {@link #prepareThreads()}. Data sets are visited in the order
     * returned by {@code ToolBox.order(cost, true)}, each is put into the bin with the currently
     * smallest load, and the bins are concatenated.
     *
     * @param sets the data sets
     * @param bins number of bins
     * @return a new array holding the same data sets bin by bin
     */
    private static DataSet[] balanceAcrossBins(DataSet[] sets, int bins) {
        double[] cost = new double[sets.length];
        for (int k = 0; k < sets.length; k++) {
            double size = sets[k].getNumberOfElements();
            cost[k] = size * size * Math.sqrt(sets[k].getAverageElementLength());
        }
        int[] order = ToolBox.order(cost, true);
        IntList[] members = new IntList[bins];
        for (int b = 0; b < bins; b++) {
            members[b] = new IntList();
        }
        double[] load = new double[bins];
        for (int k = 0; k < order.length; k++) {
            int lightest = ToolBox.getMinIndex(load);
            members[lightest].add(order[k]);
            load[lightest] += cost[order[k]];
        }
        DataSet[] reordered = new DataSet[sets.length];
        int next = 0;
        for (int b = 0; b < bins; b++) {
            for (int k = 0; k < members[b].length(); k++) {
                reordered[next++] = sets[members[b].get(k)];
            }
        }
        return reordered;
    }

    /**
     * Splits a data set into one data set per distinct identifier of the grouping annotation,
     * sorts every group and re-annotates its sequences.
     *
     * <p>Sequences without the grouping annotation share the group {@code "null"}. Within a
     * group the sequences are sorted by their negated value. Each sequence is re-annotated with
     * an "intgroup" annotation holding the group index, its reference-sequence annotation, its
     * "mask" annotation (if present), its weight annotation and its value annotation; an "mms"
     * annotation is re-attached if present. With {@code standardize} the value is replaced by
     * {@code (value - mean) / sd} using the weighted mean and standard deviation of the group; a
     * standard deviation that is NaN or not positive is reported on standard error and replaced
     * by 1.
     *
     * @param bins if greater than 1, the resulting data sets are reordered into this many
     *        cost-balanced bins
     * @param dataSet the data to split
     * @param groupTag annotation type defining the groups
     * @param valueTag annotation type holding the numeric value
     * @param weightTag annotation type holding the numeric weight
     * @param filterMismatches whether to drop sequences failing the mismatch filter
     * @param standardize whether to standardise the values within each group
     * @return one data set per group
     * @throws EmptyDataSetException if a group is empty after filtering
     * @throws WrongAlphabetException if the sequences of a group cannot form a data set or a
     *         symbol cannot be encoded
     */
    public static DataSet[] splitByTagAndSort(int bins, DataSet dataSet, String groupTag, String valueTag, String weightTag, boolean filterMismatches, boolean standardize) throws EmptyDataSetException, WrongAlphabetException {
        // A plain HashMap is kept on purpose: its iteration order determines the group indices.
        HashMap<String, LinkedList<Sequence>> groups = new HashMap<String, LinkedList<Sequence>>();
        for (int k = 0; k < dataSet.getNumberOfElements(); k++) {
            Sequence seq = dataSet.getElementAt(k);
            SequenceAnnotation groupAnnotation = seq.getSequenceAnnotationByType(groupTag, 0);
            String key = groupAnnotation != null ? groupAnnotation.getIdentifier() : "null";
            LinkedList<Sequence> members = groups.get(key);
            if (members == null) {
                members = new LinkedList<Sequence>();
                groups.put(key, members);
            }
            members.add(seq);
        }

        DataSet[] result = new DataSet[groups.size()];
        int groupIndex = 0;
        for (String key : groups.keySet()) {
            Sequence[] seqs = groups.get(key).toArray(new Sequence[0]);
            @SuppressWarnings("unchecked")
            ComparableElement<Sequence, Double>[] ranked = new ComparableElement[seqs.length];
            double weightedSum = 0.0d;
            double weightedSquares = 0.0d;
            double weightTotal = 0.0d;
            for (int k = 0; k < seqs.length; k++) {
                double value = Double.parseDouble(seqs[k].getSequenceAnnotationByType(valueTag, 0).getIdentifier());
                double weight = Double.parseDouble(seqs[k].getSequenceAnnotationByType(weightTag, 0).getIdentifier());
                weightedSum += value * weight;
                weightedSquares += value * value * weight;
                weightTotal += weight;
                ranked[k] = new ComparableElement<Sequence, Double>(seqs[k], Double.valueOf(-value));
            }
            double mean = weightedSum / weightTotal;
            double sd = Math.sqrt((weightedSquares / weightTotal) - (mean * mean));
            if (weightTotal == 0.0d) {
                mean = 0.0d;
                sd = 1.0d;
            }
            // NaN arises when rounding makes the variance of a constant group slightly negative.
            if (Double.isNaN(sd) || sd <= 0.0d) {
                System.err.println(key + " " + mean + " " + sd);
                sd = 1.0d;
            }
            Arrays.sort(ranked);

            for (int k = 0; k < ranked.length; k++) {
                Sequence seq = ranked[k].getElement();
                SequenceAnnotation mask = seq.getSequenceAnnotationByType("mask", 0);
                double value = Double.parseDouble(seq.getSequenceAnnotationByType(valueTag, 0).getIdentifier());
                if (standardize) {
                    value = (value - mean) / sd;
                }
                SequenceAnnotation mismatchAnnotation = seq.getSequenceAnnotationByType("mms", 0);
                SequenceAnnotation groupAnnotation = new SequenceAnnotation("intgroup", String.valueOf(groupIndex), new Result[0][]);
                ReferenceSequenceAnnotation referenceAnnotation = (ReferenceSequenceAnnotation) seq.getSequenceAnnotationByType(ReferenceSequenceAnnotation.TYPE, 0);
                SequenceAnnotation weightAnnotation = seq.getSequenceAnnotationByType(weightTag, 0);
                SequenceAnnotation valueAnnotation = new SequenceAnnotation(valueTag, String.valueOf(value), new Result[0][]);
                if (mask != null) {
                    seq = seq.annotate(false, groupAnnotation, referenceAnnotation, mask, weightAnnotation, valueAnnotation);
                } else {
                    seq = seq.annotate(false, groupAnnotation, referenceAnnotation, weightAnnotation, valueAnnotation);
                }
                if (mismatchAnnotation != null) {
                    seq = seq.annotate(true, mismatchAnnotation);
                }
                seqs[k] = seq;
            }

            if (filterMismatches) {
                ArrayList<Sequence> kept = new ArrayList<Sequence>();
                for (int k = 0; k < seqs.length; k++) {
                    if (passesMismatchFilter(seqs[k])) {
                        kept.add(seqs[k]);
                    }
                }
                seqs = kept.toArray(new Sequence[0]);
            }
            result[groupIndex] = new DataSet("", seqs);
            groupIndex++;
        }

        if (bins > 1) {
            result = balanceAcrossBins(result, bins);
        }
        return result;
    }
}
