/*
 * Decompiled with CFR 0.152.
 */
package projects.tals.linear;

import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.AbstractMultiThreadedOptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.data.sequences.annotation.SequenceAnnotation;
import de.jstacs.results.Result;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.utils.ComparableElement;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.ToolBox;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedList;

/**
 * Multi-threaded objective measuring the weighted squared deviation between the log-scores of a
 * {@link DifferentiableSequenceScore} and numeric targets stored in a sequence annotation, plus an
 * optional L1 ({@code laplace == true}) or L2 penalty on the parameters.
 *
 * <p>The value is
 * {@code (sum_i weights[c][i] * (s(x_i) - y_i)^2 + con * sum_{j >= penaltyOff} pen(params[j])) / sum[cl]},
 * where {@code y_i} is parsed from the annotation of type {@code sortTag}, {@code pen} is
 * {@code |.|} (L1) or {@code (.)^2} (L2), and {@code sum[cl]} is the normaliser inherited from the
 * superclass (presumably the total weight — confirm against the superclass).
 *
 * <p>Work is split across worker threads by whole classes only. Every worker range must start at
 * sequence 0 of its first class and end after the last sequence of its last class, which
 * {@link #prepareThreads()} guarantees.
 */
public class MSDFunction
extends AbstractMultiThreadedOptimizableFunction {
    /** One score instance per worker; entries {@code 1..} are filled by {@link #reset()}. */
    private DifferentiableSequenceScore[] scores;
    /** Targets, indexed {@code yi[worker][class - startClassOfWorker][sequence]}. */
    private double[][][] yi;
    /** Per-worker gradient accumulators. */
    private double[][] grads;
    /** Per-worker partial function values. */
    private double[] vals;
    /** Per-worker scratch buffers for the parameter indices of partial derivatives. */
    private IntList[] indices;
    /** Per-worker scratch buffers for the partial derivative values. */
    private DoubleList[] partDers;
    /** Penalty strength. */
    private double con;
    /** {@code true}: L1 penalty; {@code false}: L2 penalty. */
    private boolean laplace;
    /** Parameters with index {@code >= penaltyOff} are penalised. */
    private int penaltyOff;
    /** Annotation type holding the numeric target of each sequence. */
    private String sortTag;

    /**
     * Creates the objective.
     *
     * @param con penalty strength
     * @param laplace {@code true} for an L1 penalty, {@code false} for an L2 penalty
     * @param score the score whose log-scores are fitted; copies for the other threads are created
     *     by {@link #reset()}
     * @param threads number of worker threads
     * @param data one data set per class; each worker handles whole classes
     * @param weights per-sequence weights, indexed {@code weights[class][sequence]}
     * @param penaltyOff first parameter index that is penalised
     * @param sortTag annotation type holding each sequence's numeric target
     * @throws IllegalArgumentException if the data cannot be distributed or targets are missing
     */
    public MSDFunction(double con, boolean laplace, DifferentiableSequenceScore score, int threads, DataSet[] data, double[][] weights, int penaltyOff, String sortTag) throws IllegalArgumentException {
        super(threads, data, weights, false, false);
        this.scores = new DifferentiableSequenceScore[threads];
        this.scores[0] = score;
        this.con = con;
        this.laplace = laplace;
        this.penaltyOff = penaltyOff;
        this.sortTag = sortTag;
        this.precomputeIndexes(sortTag);
    }

    /** Returns the number of parameters of the underlying score. */
    @Override
    public int getDimensionOfScope() {
        return this.scores[0].getNumberOfParameters();
    }

    /**
     * Accumulates into {@code grads[index]} the gradient of
     * {@code sum w * (s - y)^2} over all sequences of classes {@code startClass..endClass}.
     */
    @Override
    protected void evaluateGradientOfFunction(int index, int startClass, int startSeq, int endClass, int endSeq) {
        this.checkWholeClasses(startSeq, endClass, endSeq);
        double[] grad = this.grads[index];
        IntList idx = this.indices[index];
        DoubleList der = this.partDers[index];
        DifferentiableSequenceScore score = this.scores[index];
        Arrays.fill(grad, 0.0);
        for (int c = startClass; c <= endClass; c++) {
            double[] targets = this.yi[index][c - startClass];
            int n = this.data[c].getNumberOfElements();
            for (int i = 0; i < n; i++) {
                idx.clear();
                der.clear();
                double s = score.getLogScoreAndPartialDerivation(this.data[c].getElementAt(i), idx, der);
                // chain rule: d/dtheta [w * (s - y)^2] = 2 * w * (s - y) * ds/dtheta
                double v = 2.0 * (s - targets[i]) * this.weights[c][i];
                for (int j = 0; j < idx.length(); j++) {
                    grad[idx.get(j)] += v * der.get(j);
                }
            }
        }
    }

    /**
     * Sums the per-worker gradients, adds the penalty gradient, and divides by
     * {@code sum[cl]}.
     *
     * <p>For the L1 penalty, the subgradient {@code +con} is used at exactly zero.
     *
     * @return a copy of the joined gradient
     */
    @Override
    protected double[] joinGradients() throws EvaluationException {
        double[] total = this.grads[0];
        for (int t = 1; t < this.grads.length; t++) {
            for (int j = 0; j < total.length; j++) {
                total[j] += this.grads[t][j];
            }
        }
        double norm = this.sum[this.cl];
        for (int i = 0; i < this.params.length; i++) {
            if (i >= this.penaltyOff) {
                total[i] += this.laplace
                        ? (this.params[i] < 0.0 ? -1.0 : 1.0) * this.con
                        : 2.0 * this.params[i] * this.con;
            }
            total[i] /= norm;
        }
        return total.clone();
    }

    /**
     * Stores in {@code vals[index]} the weighted sum of squared deviations between log-scores and
     * targets over all sequences of classes {@code startClass..endClass}.
     */
    @Override
    protected void evaluateFunction(int index, int startClass, int startSeq, int endClass, int endSeq) throws EvaluationException {
        this.checkWholeClasses(startSeq, endClass, endSeq);
        DifferentiableSequenceScore score = this.scores[index];
        double val = 0.0;
        for (int c = startClass; c <= endClass; c++) {
            double[] targets = this.yi[index][c - startClass];
            int n = this.data[c].getNumberOfElements();
            for (int i = 0; i < n; i++) {
                double v = score.getLogScoreFor(this.data[c].getElementAt(i)) - targets[i];
                val += v * v * this.weights[c][i];
            }
        }
        this.vals[index] = val;
    }

    /** Sums the per-worker values, adds the penalty, and divides by {@code sum[cl]}. */
    @Override
    protected double joinFunction() throws EvaluationException, DimensionException {
        double val = ToolBox.sum(this.vals);
        for (int i = Math.max(0, this.penaltyOff); i < this.params.length; i++) {
            val += this.laplace
                    ? Math.abs(this.params[i]) * this.con
                    : this.params[i] * this.params[i] * this.con;
        }
        return val / this.sum[this.cl];
    }

    /** Nothing to prepare: parameters are pushed per thread in {@link #setParams(int)}. */
    @Override
    protected void setThreadIndependentParameters() throws DimensionException {
    }

    /** Sets new data and weights and recomputes the per-worker buffers and targets. */
    @Override
    public void setDataAndWeights(DataSet[] data, double[][] weights) throws IllegalArgumentException {
        super.setDataAndWeights(data, weights);
        this.precomputeIndexes(this.sortTag);
    }

    /**
     * Allocates the per-worker buffers and parses each sequence's target from the annotation of
     * type {@code sortTag}.
     *
     * <p>This does nothing while {@code worker} is still {@code null}, which presumably happens when
     * it is reached via {@link #setDataAndWeights} during superclass construction.
     *
     * @param sortTag annotation type holding the numeric target
     * @throws IllegalArgumentException if a sequence lacks the target annotation
     */
    public void precomputeIndexes(String sortTag) {
        if (this.worker == null) {
            return;
        }
        int numWorkers = this.worker.length;
        this.vals = new double[numWorkers];
        this.indices = new IntList[numWorkers];
        this.partDers = new DoubleList[numWorkers];
        this.grads = new double[numWorkers][this.getDimensionOfScope()];
        this.yi = new double[numWorkers][][];
        for (int w = 0; w < numWorkers; w++) {
            this.indices[w] = new IntList();
            this.partDers[w] = new DoubleList();
            // getIndices() layout, as used here: {startClass, startSeq, endClass, endSeq}
            int[] range = this.worker[w].getIndices();
            int startClass = range[0];
            int endClass = range[2];
            this.checkWholeClasses(range[1], endClass, range[3]);
            this.yi[w] = new double[endClass - startClass + 1][];
            for (int c = startClass; c <= endClass; c++) {
                DataSet set = this.data[c];
                double[] targets = new double[set.getNumberOfElements()];
                this.yi[w][c - startClass] = targets;
                for (int k = 0; k < targets.length; k++) {
                    targets[k] = parseAnnotation(set.getElementAt(k), sortTag);
                }
            }
        }
    }

    /** Pushes the current parameters into the score of worker {@code index}. */
    @Override
    protected void setParams(int index) throws DimensionException {
        this.scores[index].setParameters(this.params, 0);
    }

    /**
     * Copies the current parameter values of the first score into {@code erg}.
     *
     * <p>{@code kind} is ignored.
     */
    @Override
    public void getParameters(OptimizableFunction.KindOfParameter kind, double[] erg) throws Exception {
        double[] current = this.scores[0].getCurrentParameterValues();
        System.arraycopy(current, 0, erg, 0, current.length);
    }

    /** Replaces the scores of workers {@code 1..} with fresh clones of the first score. */
    @Override
    public void reset() throws Exception {
        for (int i = 1; i < this.scores.length; i++) {
            this.scores[i] = this.scores[0].clone();
        }
    }

    /**
     * Assigns each worker a contiguous range of whole classes.
     *
     * <p>Classes are costed with {@link #classCost(DataSet)}. Each worker except the last takes its
     * first class and then further classes while its accumulated cost stays within an equal share
     * of the cost not yet assigned. At least one class is always left for every remaining worker.
     * The last worker takes all remaining classes.
     *
     * @throws IllegalArgumentException if there are fewer classes than workers
     * @throws IllegalStateException if an existing worker is busy
     */
    @Override
    protected void prepareThreads() {
        int numWorkers = this.worker.length;
        int numClasses = this.data.length;
        if (numClasses < numWorkers) {
            throw new IllegalArgumentException("Cannot distribute " + numClasses + " classes over " + numWorkers
                    + " threads: every thread needs at least one whole class");
        }
        double[] sizes = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            sizes[c] = classCost(this.data[c]);
        }
        double remaining = ToolBox.sum(sizes);
        double part = remaining / (double) numWorkers;
        int startClass = 0;
        for (int w = 0; w < numWorkers; w++) {
            int endClass;
            if (w == numWorkers - 1) {
                endClass = numClasses - 1;
            } else {
                // last class this worker may take while still leaving one per remaining worker
                int lastAllowed = numClasses - (numWorkers - w);
                endClass = startClass;
                double curr = sizes[endClass];
                while (endClass < lastAllowed && curr + sizes[endClass + 1] <= part) {
                    curr += sizes[endClass + 1];
                    ++endClass;
                }
                remaining -= curr;
                part = remaining / (double) (numWorkers - w - 1);
            }
            int endSeq = this.data[endClass].getNumberOfElements();
            if (this.worker[w] == null) {
                this.worker[w] = new AbstractMultiThreadedOptimizableFunction.Worker(w, startClass, 0, endClass, endSeq);
                this.worker[w].start();
            } else {
                if (!this.worker[w].isWaiting()) {
                    this.stopThreads();
                    throw new IllegalStateException("Worker " + w + " is busy and cannot be reassigned");
                }
                this.worker[w].setIndices(startClass, 0, endClass, endSeq);
            }
            startClass = endClass + 1;
        }
    }

    /**
     * Throws if a worker range does not cover complete classes.
     *
     * <p>This class only supports whole-class ranges (see {@link #prepareThreads()}).
     */
    private void checkWholeClasses(int startSeq, int endClass, int endSeq) {
        if (startSeq != 0 || endSeq != this.data[endClass].getNumberOfElements()) {
            throw new IllegalStateException("Worker range must cover whole classes, but got startSeq=" + startSeq
                    + ", endClass=" + endClass + ", endSeq=" + endSeq);
        }
    }

    /**
     * Load-balancing cost of one class: {@code n^2 * sqrt(average sequence length)}.
     *
     * <p>This is shared by {@link #prepareThreads()} and {@link #splitByTagAndSort}.
     */
    private static double classCost(DataSet set) {
        double n = set.getNumberOfElements();
        return n * n * Math.sqrt(set.getAverageElementLength());
    }

    /**
     * Parses the identifier of the first annotation of the given type as a double.
     *
     * @throws IllegalArgumentException if the sequence has no such annotation
     */
    private static double parseAnnotation(Sequence seq, String type) {
        SequenceAnnotation an = seq.getSequenceAnnotationByType(type, 0);
        if (an == null) {
            throw new IllegalArgumentException("Sequence lacks an annotation of type " + type);
        }
        return Double.parseDouble(an.getIdentifier());
    }

    /**
     * Computes the weighted median of {@code ws} (weights {@code gw}) and the weighted median
     * absolute deviation around it.
     *
     * <p>Both inputs are copied, not modified. A median that is NaN or infinite falls back to
     * {@code 0}. A MAD that is {@code <= 1e-6}, NaN, or infinite falls back to {@code 1}.
     *
     * @return {@code {median, mad}}
     */
    private static double[] getMAD(double[] ws, double[] gw, boolean print) {
        ws = ws.clone();
        gw = gw.clone();
        ToolBox.sortAlongWith(ws, new double[][]{gw});
        if (print) {
            System.out.println(String.valueOf(Arrays.toString(ws)) + " " + Arrays.toString(gw));
        }
        double fullSum = ToolBox.sum(gw);
        double median = weightedMedianOfSorted(ws, gw, fullSum, 0.0, print);
        if (Double.isNaN(median) || Double.isInfinite(median)) {
            median = 0.0;
        }
        double[] mads = new double[ws.length];
        for (int j = 0; j < ws.length; j++) {
            mads[j] = Math.abs(median - ws[j]);
        }
        ToolBox.sortAlongWith(mads, new double[][]{gw});
        double mad = weightedMedianOfSorted(mads, gw, fullSum, 1.0, false);
        if (mad <= 1.0E-6 || Double.isNaN(mad) || Double.isInfinite(mad)) {
            mad = 1.0;
        }
        return new double[]{median, mad};
    }

    /**
     * Weighted median of ascending {@code values}.
     *
     * <p>Walks the values until the weight not yet passed drops below half of {@code fullSum}
     * (with a 1e-6 relative tolerance). It then interpolates between that value and its
     * predecessor using their weights. If that happens at the first element, the first value is
     * returned, because there is no predecessor to interpolate with. Returns {@code fallback} if the
     * threshold is never crossed.
     */
    private static double weightedMedianOfSorted(double[] values, double[] weights, double fullSum, double fallback, boolean print) {
        double percentile = fullSum * 0.5;
        double rest = fullSum;
        for (int j = 0; j < values.length; j++) {
            rest -= weights[j];
            if (print) {
                System.out.println(String.valueOf(rest) + "<" + percentile + "?");
            }
            if (rest * 1.000001 < percentile) {
                if (j == 0) {
                    return values[0];
                }
                double w1 = weights[j];
                double w2 = weights[j - 1];
                return (values[j] * w1 + values[j - 1] * w2) / (w1 + w2);
            }
        }
        return fallback;
    }

    /**
     * Splits {@code data} into one data set per value of the {@code splitTag} annotation.
     *
     * <p>Sequences without that annotation form the group {@code "null"}. Within each group,
     * sequences are sorted by descending target ({@code sortTag}) and re-annotated (see
     * {@link #sortAndAnnotateGroup}). Optionally, sequences are then filtered by mismatches to
     * their reference. When {@code numThreads > 1}, the groups are reordered for load balancing.
     *
     * @param numThreads number of threads the groups are balanced for
     * @param data the sequences to split
     * @param splitTag annotation type defining the groups
     * @param sortTag annotation type holding the numeric target
     * @param globalWeightTag annotation type holding each sequence's weight
     * @param filter whether to drop unmasked sequences with more than 3 mismatches to their
     *     reference
     * @param normalize whether to z-score the targets per group using the weighted mean and
     *     standard deviation
     * @return one data set per group
     * @throws IllegalArgumentException if a sequence lacks a target or weight annotation
     */
    public static DataSet[] splitByTagAndSort(int numThreads, DataSet data, String splitTag, String sortTag, String globalWeightTag, boolean filter, boolean normalize) throws EmptyDataSetException, WrongAlphabetException {
        HashMap<String, LinkedList<Sequence>> groups = groupByTag(data, splitTag);
        DataSet[] ds = new DataSet[groups.size()];
        int group = 0;
        for (String key : groups.keySet()) {
            Sequence[] seqs = sortAndAnnotateGroup(key, group, groups.get(key).toArray(new Sequence[0]), sortTag, globalWeightTag, normalize);
            if (filter) {
                seqs = filterByMismatches(seqs);
            }
            ds[group] = new DataSet("", seqs);
            ++group;
        }
        return numThreads > 1 ? balanceForThreads(ds, numThreads) : ds;
    }

    /**
     * Groups the sequences of {@code data} by the identifier of their {@code splitTag} annotation.
     *
     * <p>Sequences without that annotation use the key {@code "null"}.
     */
    private static HashMap<String, LinkedList<Sequence>> groupByTag(DataSet data, String splitTag) {
        HashMap<String, LinkedList<Sequence>> groups = new HashMap<String, LinkedList<Sequence>>();
        for (int i = 0; i < data.getNumberOfElements(); i++) {
            Sequence seq = data.getElementAt(i);
            SequenceAnnotation sa = seq.getSequenceAnnotationByType(splitTag, 0);
            String key = sa == null ? "null" : sa.getIdentifier();
            LinkedList<Sequence> list = groups.get(key);
            if (list == null) {
                list = new LinkedList<Sequence>();
                groups.put(key, list);
            }
            list.add(seq);
        }
        return groups;
    }

    /**
     * Sorts one group by descending target and rebuilds each sequence's annotations.
     *
     * <p>Each sequence keeps:
     * <ul>
     *   <li>a new {@code "intgroup"} annotation holding {@code group};</li>
     *   <li>its {@code "reference"} and (if present) {@code "mask"} annotations;</li>
     *   <li>its {@code globalWeightTag} annotation;</li>
     *   <li>a new {@code sortTag} annotation holding the (optionally z-scored) target;</li>
     *   <li>its {@code "mms"} annotation, if present, re-added via {@code annotate(true, ...)}.</li>
     * </ul>
     * Whether {@code annotate(false, ...)} drops any other annotations is presumed, not shown here.
     *
     * <p>If the weighted standard deviation is not positive (including NaN from rounding), a
     * diagnostic is printed to stderr and 1 is used instead.
     */
    private static Sequence[] sortAndAnnotateGroup(String key, int group, Sequence[] seqs, String sortTag, String globalWeightTag, boolean normalize) throws EmptyDataSetException, WrongAlphabetException {
        @SuppressWarnings("unchecked")
        ComparableElement<Sequence, Double>[] ws = new ComparableElement[seqs.length];
        double sum = 0.0;
        double sumsq = 0.0;
        double n = 0.0;
        for (int j = 0; j < seqs.length; j++) {
            double w = parseAnnotation(seqs[j], sortTag);
            double gw = parseAnnotation(seqs[j], globalWeightTag);
            sum += w * gw;
            sumsq += w * w * gw;
            n += gw;
            // negated so that the ascending sort yields descending targets
            ws[j] = new ComparableElement<Sequence, Double>(seqs[j], -w);
        }
        double mean = sum / n;
        double sd = Math.sqrt(sumsq / n - mean * mean);
        if (n == 0.0) {
            mean = 0.0;
            sd = 1.0;
        }
        // negation also catches NaN from a slightly negative variance
        if (!(sd > 0.0)) {
            System.err.println(String.valueOf(key) + " " + mean + " " + sd);
            sd = 1.0;
        }
        Arrays.sort(ws);
        for (int j = 0; j < ws.length; j++) {
            Sequence seq = ws[j].getElement();
            double lw = -ws[j].getWeight();
            if (normalize) {
                lw = (lw - mean) / sd;
            }
            SequenceAnnotation mask = seq.getSequenceAnnotationByType("mask", 0);
            SequenceAnnotation mms = seq.getSequenceAnnotationByType("mms", 0);
            SequenceAnnotation groupAn = plainAnnotation("intgroup", String.valueOf(group));
            ReferenceSequenceAnnotation refAn = (ReferenceSequenceAnnotation) seq.getSequenceAnnotationByType("reference", 0);
            SequenceAnnotation weightAn = seq.getSequenceAnnotationByType(globalWeightTag, 0);
            SequenceAnnotation targetAn = plainAnnotation(sortTag, String.valueOf(lw));
            Sequence annotated = mask != null
                    ? seq.annotate(false, groupAn, refAn, mask, weightAn, targetAn)
                    : seq.annotate(false, groupAn, refAn, weightAn, targetAn);
            if (mms != null) {
                annotated = annotated.annotate(true, mms);
            }
            seqs[j] = annotated;
        }
        return seqs;
    }

    /** Creates an annotation of the given type and identifier without attached results. */
    private static SequenceAnnotation plainAnnotation(String type, String identifier) {
        return new SequenceAnnotation(type, identifier, (Result[][]) new Result[0][]);
    }

    /**
     * Keeps masked sequences (a {@code "mask"} annotation containing {@code 'X'}) unconditionally.
     * Keeps other sequences only if {@link #countMismatches} is at most 3.
     */
    private static Sequence[] filterByMismatches(Sequence[] seqs) throws WrongAlphabetException {
        ArrayList<Sequence> kept = new ArrayList<Sequence>(seqs.length);
        for (Sequence seq : seqs) {
            SequenceAnnotation mask = seq.getSequenceAnnotationByType("mask", 0);
            boolean masked = mask != null && mask.getIdentifier().indexOf("X") >= 0;
            if (masked || countMismatches(seq) <= 3) {
                kept.add(seq);
            }
        }
        return kept.toArray(new Sequence[0]);
    }

    /**
     * Counts positions where the base at {@code k + 1} of {@code seq} does not match RVD {@code k}
     * of its {@code "reference"} annotation.
     *
     * <p>The accepted codes per RVD are: HD → 1, NI → 0, NG → 3, NN → 0 or 2. These are presumably
     * the DNA codes A=0, C=1, G=2, T=3 — confirm against the alphabet. An RVD outside these four
     * stops the count and returns 0, so such sequences are always kept.
     */
    private static int countMismatches(Sequence seq) throws WrongAlphabetException {
        ReferenceSequenceAnnotation an = (ReferenceSequenceAnnotation) seq.getSequenceAnnotationByType("reference", 0);
        Sequence ref = an.getReferenceSequence();
        AlphabetContainer rvds = ref.getAlphabetContainer();
        int nmm = 0;
        for (int k = 0; k < ref.getLength(); k++) {
            double rvd = ref.discreteVal(k);
            boolean mismatch;
            if (rvd == rvds.getCode(k, "HD")) {
                mismatch = seq.discreteVal(k + 1) != 1;
            } else if (rvd == rvds.getCode(k, "NI")) {
                mismatch = seq.discreteVal(k + 1) != 0;
            } else if (rvd == rvds.getCode(k, "NG")) {
                mismatch = seq.discreteVal(k + 1) != 3;
            } else if (rvd == rvds.getCode(k, "NN")) {
                int base = seq.discreteVal(k + 1);
                mismatch = base != 0 && base != 2;
            } else {
                return 0;
            }
            if (mismatch) {
                ++nmm;
            }
        }
        return nmm;
    }

    /**
     * Reorders groups for {@code numThreads} threads.
     *
     * <p>Groups are visited in the order given by {@code ToolBox.order(costs, true)} (presumably
     * descending cost — confirm). Each group is appended to the thread bin with the currently
     * smallest total {@link #classCost}, and the bins are then concatenated.
     */
    private static DataSet[] balanceForThreads(DataSet[] ds, int numThreads) {
        double[] sizes = new double[ds.length];
        for (int g = 0; g < ds.length; g++) {
            sizes[g] = classCost(ds[g]);
        }
        int[] order = ToolBox.order(sizes, true);
        IntList[] bins = new IntList[numThreads];
        for (int t = 0; t < numThreads; t++) {
            bins[t] = new IntList();
        }
        double[] load = new double[numThreads];
        for (int g : order) {
            int t = ToolBox.getMinIndex(load);
            bins[t].add(g);
            load[t] += sizes[g];
        }
        DataSet[] result = new DataSet[ds.length];
        int k = 0;
        for (IntList bin : bins) {
            for (int l = 0; l < bin.length(); l++) {
                result[k++] = ds[bin.get(l)];
            }
        }
        return result;
    }
}

