package projects.dream2016.mix;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.ConstantStartDistance;
import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.algorithms.optimization.NegativeDifferentiableFunction;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.MixtureDiffSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.DifferentiableStatisticalModelWrapperTrainSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import java.util.Arrays;
import java.util.LinkedList;
import javax.naming.OperationNotSupportedException;
import org.apache.xmlgraphics.image.loader.spi.ImagePreloader;

/* loaded from: input_file:projects/dream2016/mix/NewMixtureClassifier.class */
public class NewMixtureClassifier extends AbstractScoreBasedClassifier implements OptimizableClassifier {
    // Component classifiers; component[k] is the classifier attached to mixture component k (see getLogProb).
    private AbstractScoreBasedClassifier[] component;
    // Mixture model over sequences that decides which component classifier is responsible for a sequence.
    private MyMixtureScoringFunction model;
    // Prior for the mixture model; a fresh instance is handed to the numerical trainer in init.
    private LogPrior mixPrior;
    // Strategy for combining the component classifiers at prediction time (DOC or VOC).
    private Vote vote;
    // Training strategy; for COMBINED all components must implement OptimizableClassifier.
    private Training training;
    // Number of threads passed to the numerical trainer of the mixture model.
    private int threads;
    // Number of starts (validated to be positive by the main constructor).
    private int starts;
    // Components viewed as OptimizableClassifier; non-null only for Training.COMBINED.
    private OptimizableClassifier[] optComponent;
    // Scratch buffer with one entry per class; overwritten by getLogProb.
    private double[] helpArray;
    // Scratch buffer with one entry per component classifier.
    private double[] logClassifierProbs;
    // Scratch buffer with one log mixture score per component; overwritten by getLogProb.
    private double[] logMixScores;
    // Buffers for the joint (COMBINED) optimization; allocated outside the visible code (TODO confirm where).
    private double[] params;
    private double[] mixParams;
    private double[] mixGrad;
    private IntList[] indices;
    private DoubleList[] partialDer;
    // Numerical optimizer settings: algo and linEps are passed to DifferentiableStatisticalModelWrapperTrainSM
    // in init; eps is not used in the visible code. Defaults are assigned in the constructors.
    private byte algo;
    private double eps;
    private double linEps;
    // Root tag of the XML representation (returned by getXMLTag).
    private static final String XML_TAG = "MixtureClassifier";
    // Decompiler-generated lookup table for switch statements over Vote.
    private static /* synthetic */ int[] $SWITCH_TABLE$projects$dream2016$mix$NewMixtureClassifier$Vote;

    /* loaded from: input_file:projects/dream2016/mix/NewMixtureClassifier$DataHandler.class */
    /**
     * Holds the training data split among the mixture components together with optional weights.
     * <p>
     * Split {@code i} is the data handed to component classifier {@code i}; inside a split, entry {@code j}
     * is one {@link DataSet} (presumably one per class, matching the classifier training API).
     */
    public static class DataHandler {
        /** Data sets indexed as [split][entry]. */
        private DataSet[][] splits;
        /** Weights indexed as [split][entry][sequence]; an entry is {@code null} if it is unweighted. */
        private double[][][] weights;

        /**
         * Creates the handler from per-split sequence lists and weight lists.
         *
         * @param doubleListArr weights indexed as [split][entry]; a {@code null} or empty list marks the entry as
         *                      unweighted. The array is not modified.
         * @param linkedListArr sequences indexed as [split][entry]; if exactly one array is given, its data sets
         *                      are shared by all splits, otherwise at least one array per split is required
         * @throws EmptyDataSetException  if a data set cannot be built from an empty sequence list
         * @throws WrongAlphabetException if the sequences do not share a common alphabet
         * @throws IllegalArgumentException if fewer sequence arrays than splits are given (and not exactly one)
         */
        public DataHandler(DoubleList[][] doubleListArr, LinkedList<Sequence>[]... linkedListArr) throws EmptyDataSetException, WrongAlphabetException {
            int numberOfSplits = doubleListArr.length;
            int entriesPerSplit = doubleListArr[0].length;
            boolean shared = linkedListArr.length == 1;
            if (!shared && linkedListArr.length < numberOfSplits) {
                throw new IllegalArgumentException("Expected one sequence array or at least " + numberOfSplits + " (one per split), but got " + linkedListArr.length + ".");
            }
            this.splits = new DataSet[numberOfSplits][];
            this.weights = new double[numberOfSplits][][];

            // With a single sequence array, build the data sets once and reuse them for every split.
            DataSet[] sharedData = null;
            if (shared) {
                sharedData = new DataSet[entriesPerSplit];
                for (int j = 0; j < entriesPerSplit; j++) {
                    sharedData[j] = new DataSet("", linkedListArr[0][j].toArray(new Sequence[0]));
                }
            }

            for (int i = 0; i < numberOfSplits; i++) {
                this.splits[i] = new DataSet[entriesPerSplit];
                this.weights[i] = new double[entriesPerSplit][];
                for (int j = 0; j < entriesPerSplit; j++) {
                    this.splits[i][j] = shared ? sharedData[j] : new DataSet("", linkedListArr[i][j].toArray(new Sequence[0]));
                    DoubleList w = doubleListArr[i][j];
                    // No weights collected: store null (unweighted) without touching the caller's array.
                    this.weights[i][j] = (w == null || w.length() == 0) ? null : w.toArray();
                }
            }
        }

        /**
         * @return the number of splits (one per mixture component)
         */
        public int getNumberOfSplits() {
            return this.splits.length;
        }

        /**
         * @param i the split index
         * @return the data sets of split {@code i} (internal array, not copied)
         */
        public DataSet[] getSplits(int i) {
            return this.splits[i];
        }

        /**
         * @param i the split index
         * @return the weights of split {@code i} (internal array, not copied); entries are {@code null} if unweighted
         */
        public double[][] getWeightsOfSplits(int i) {
            return this.weights[i];
        }
    }

    /* loaded from: input_file:projects/dream2016/mix/NewMixtureClassifier$MyMixtureScoringFunction.class */
    /**
     * A {@link MixtureDiffSM} that exposes its per-component log scores and partial derivatives so that it can be
     * used as the gating model of a {@link NewMixtureClassifier}. It reports itself as not normalized.
     */
    public static class MyMixtureScoringFunction extends MixtureDiffSM {
        public MyMixtureScoringFunction(int i, boolean z, DifferentiableStatisticalModel[] differentiableStatisticalModelArr) throws CloneNotSupportedException {
            super(i, z, differentiableStatisticalModelArr);
        }

        public MyMixtureScoringFunction(StringBuffer stringBuffer) throws NonParsableException {
            super(stringBuffer);
        }

        /**
         * Writes the component scores of {@code sequence} (starting at position 0) into {@code dArr}.
         * <p>
         * The values are those computed by {@code fillComponentScores}; presumably log hidden potential plus
         * component log score, as computed explicitly in {@link #getPartialDerivations} — TODO confirm.
         *
         * @param sequence the sequence to score
         * @param dArr     output array with at least one entry per component
         */
        public void getLogScores(Sequence sequence, double[] dArr) {
            fillComponentScores(sequence, 0);
            System.arraycopy(this.componentScore, 0, dArr, 0, this.componentScore.length);
        }

        /**
         * Appends the partial derivatives of a weighted log-likelihood term for {@code sequence} to the given lists.
         * <p>
         * For component {@code k} the factor is {@code dArr[k] - componentScore[k]}, where {@code componentScore}
         * holds the result of {@code Normalisation.logSumNormalisation} (presumably the normalized component
         * posteriors). Component parameters are reported with offset {@code paramRef[k]}; hidden parameters are
         * reported with offset {@code paramRef[length]} for the first {@code paramRef[length+1]-paramRef[length]}
         * components only.
         *
         * @param sequence   the sequence
         * @param dArr       target value per component (presumably responsibilities) — TODO confirm with callers
         * @param intList    receives parameter indices
         * @param doubleList receives the matching partial derivatives
         */
        public void getPartialDerivations(Sequence sequence, double[] dArr, IntList intList, DoubleList doubleList) {
            for (int i = 0; i < this.function.length; i++) {
                this.iList[i].clear();
                this.dList[i].clear();
                this.componentScore[i] = this.logHiddenPotential[i] + this.function[i].getLogScoreAndPartialDerivation(sequence, 0, this.iList[i], this.dList[i]);
            }
            // Normalizes componentScore in place.
            Normalisation.logSumNormalisation(this.componentScore);
            int length = this.paramRef.length - 2;
            int numberOfHiddenParameters = this.paramRef[length + 1] - this.paramRef[length];
            for (int k = 0; k < this.function.length; k++) {
                double d = dArr[k] - this.componentScore[k];
                for (int j = 0; j < this.iList[k].length(); j++) {
                    intList.add(this.paramRef[k] + this.iList[k].get(j));
                    doubleList.add(this.dList[k].get(j) * d);
                }
                // Components without a free hidden parameter contribute no hidden-parameter derivative.
                if (k < numberOfHiddenParameters) {
                    intList.add(this.paramRef[length] + k);
                    doubleList.add(d);
                }
            }
        }

        @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.AbstractMixtureDiffSM
        protected boolean determineIsNormalized() {
            return false;
        }

        /**
         * Trains every component separately on its own data set and then derives the hidden parameters from the
         * data set sizes via {@code computeHiddenParameter}.
         *
         * @param dataSetArr one data set per component
         * @param dArr       must be {@code null} or contain only {@code null} entries; weighted initialization is
         *                   not supported
         * @throws OperationNotSupportedException if weights are given for a component; thrown before that
         *                                        component is trained so it is left unchanged
         * @throws Exception if training a component fails
         */
        public void init(DataSet[] dataSetArr, double[][] dArr) throws Exception {
            double[] componentSizes = new double[this.function.length];
            SmallDifferenceOfFunctionEvaluationsCondition condition = new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-11d);
            for (int i = 0; i < componentSizes.length; i++) {
                // Check first: previously the component was trained and replaced before the exception was raised.
                if (dArr != null && dArr[i] != null) {
                    throw new OperationNotSupportedException("weighted initialization is not supported yet");
                }
                DifferentiableStatisticalModelWrapperTrainSM trainer = new DifferentiableStatisticalModelWrapperTrainSM(this.function[i], 2, (byte) 10, condition, 1.0E-9d, 1.0d);
                trainer.train(dataSetArr[i], null);
                this.function[i] = trainer.getFunction();
                componentSizes[i] = dataSetArr[i].getNumberOfElements();
            }
            computeHiddenParameter(componentSizes, true);
        }
    }

    /* loaded from: input_file:projects/dream2016/mix/NewMixtureClassifier$OptimizableMSPClassifier.class */
    /**
     * A {@link GenDisMixClassifier} using the MSP learning principle that also implements
     * {@link OptimizableClassifier}, so it can be optimized jointly as a component of a {@link NewMixtureClassifier}.
     * <p>
     * Parameter layout: the class weights come first (one per class), followed by the parameters of each class
     * score in class order.
     * <p>
     * The scratch buffers are instance fields, so an instance must not be used from several threads at once.
     */
    public static class OptimizableMSPClassifier extends GenDisMixClassifier implements OptimizableClassifier {
        /** Cached copy of the current parameters in the layout described above. */
        private double[] parameter;
        /** Scratch buffer for the prior gradient; same length as {@link #parameter}. */
        private double[] grad;
        /** Scratch buffer with one entry per class. */
        private double[] helpArray;
        /** Per-class scratch lists of parameter indices reported by the class scores. */
        private IntList[] indi;
        /** Per-class scratch lists of partial derivatives matching {@link #indi}. */
        private DoubleList[] partDer;

        public OptimizableMSPClassifier(GenDisMixClassifierParameterSet genDisMixClassifierParameterSet, LogPrior logPrior, DifferentiableStatisticalModel... differentiableStatisticalModelArr) throws CloneNotSupportedException {
            super(genDisMixClassifierParameterSet, logPrior, 0.0d, LearningPrinciple.getBeta(LearningPrinciple.MSP), differentiableStatisticalModelArr);
            init();
        }

        public OptimizableMSPClassifier(StringBuffer stringBuffer) throws NonParsableException {
            super(stringBuffer);
            init();
        }

        @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier, de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier, de.jstacs.classifiers.AbstractScoreBasedClassifier, de.jstacs.classifiers.AbstractClassifier
        /* renamed from: clone */
        public OptimizableMSPClassifier mo52clone() throws CloneNotSupportedException {
            OptimizableMSPClassifier copy = (OptimizableMSPClassifier) super.mo52clone();
            if (this.parameter != null) {
                copy.parameter = this.parameter.clone();
                copy.grad = this.grad.clone();
            }
            // Scratch buffers must not be shared between clones.
            copy.init();
            return copy;
        }

        /** Allocates the per-class scratch buffers. */
        private void init() {
            int classes = getNumberOfClasses();
            this.helpArray = new double[classes];
            this.indi = new IntList[classes];
            this.partDer = new DoubleList[classes];
            for (int c = 0; c < classes; c++) {
                this.indi[c] = new IntList();
                this.partDer[c] = new DoubleList();
            }
        }

        /**
         * Adds the gradient of the log prior, evaluated at the cached parameters, to {@code dArr} starting at
         * offset {@code i}.
         */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public void addGradient(double[] dArr, int i) throws EvaluationException {
            Arrays.fill(this.grad, 0.0d);
            this.prior.addGradientFor(this.parameter, this.grad);
            for (int k = 0; k < this.grad.length; k++) {
                dArr[i + k] += this.grad[k];
            }
        }

        /** Returns the log prior evaluated at the cached parameters. */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public double getLogPriorTerm() throws DimensionException, EvaluationException {
            return this.prior.evaluateFunction(this.parameter);
        }

        /** Returns the log posterior of class {@code i} for {@code sequence}. */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public double getLogProb(int i, Sequence sequence) {
            for (int c = 0; c < this.score.length; c++) {
                this.helpArray[c] = getClassWeight(c) + this.score[c].getLogScoreFor(sequence);
            }
            return this.helpArray[i] - Normalisation.getLogSum(this.helpArray);
        }

        /**
         * Returns the log posterior of class {@code i} and appends its partial derivatives w.r.t. this classifier's
         * parameters (indices relative to the layout described on the class) to the given lists.
         */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public double getLogProbAndPartialDerivations(int i, Sequence sequence, IntList intList, DoubleList doubleList) {
            for (int c = 0; c < this.score.length; c++) {
                this.indi[c].clear();
                this.partDer[c].clear();
                this.helpArray[c] = getClassWeight(c) + this.score[c].getLogScoreAndPartialDerivation(sequence, this.indi[c], this.partDer[c]);
            }
            // Read the unnormalized score of class i BEFORE helpArray is normalized in place.
            double logJoint = this.helpArray[i];
            double logProb = logJoint - Normalisation.logSumNormalisation(this.helpArray);
            int offset = this.score.length;
            for (int c = 0; c < this.score.length; c++) {
                // d log P(i|x) / d(log score of class c) = [c == i] - P(c|x)
                double factor = (c == i ? 1 : 0) - this.helpArray[c];
                intList.add(c);
                doubleList.add(factor);
                for (int k = 0; k < this.indi[c].length(); k++) {
                    intList.add(offset + this.indi[c].get(k));
                    doubleList.add(this.partDer[c].get(k) * factor);
                }
                offset += this.score[c].getNumberOfParameters();
            }
            return logProb;
        }

        /**
         * Returns the number of class weights plus all class-score parameters, or -1 if any class score reports an
         * unknown number of parameters.
         */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public int getNumberOfParameters() {
            int total = 0;
            for (int c = 0; c < this.score.length; c++) {
                int n = this.score[c].getNumberOfParameters();
                if (n == -1) {
                    return -1;
                }
                total += n;
            }
            return getNumberOfClasses() + total;
        }

        /** Initializes all class scores from the data, resets the class weights to zero and caches the parameters. */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public void initialize(DataSet[] dataSetArr, double[][] dArr) throws Exception {
            for (int c = 0; c < this.score.length; c++) {
                this.score[c].initializeFunction(c, false, dataSetArr, dArr);
            }
            setClassWeights(false, new double[getNumberOfClasses()]);
            fillParameters();
        }

        /** Initializes all class scores randomly, resets the class weights to zero and caches the parameters. */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public void initializeRandomly() throws Exception {
            for (int c = 0; c < this.score.length; c++) {
                this.score[c].initializeFunctionRandomly(false);
            }
            setClassWeights(false, new double[getNumberOfClasses()]);
            fillParameters();
        }

        @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier, de.jstacs.classifiers.AbstractClassifier
        public void train(DataSet[] dataSetArr, double[][] dArr) throws Exception {
            super.train(dataSetArr, dArr);
            fillParameters();
        }

        /** Sets class weights and class-score parameters from {@code dArr} starting at offset {@code i}. */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public void setParameters(double[] dArr, int i) throws Exception {
            setClassWeights(false, dArr, i);
            int offset = i + getNumberOfClasses();
            for (int c = 0; c < this.score.length; c++) {
                this.score[c].setParameters(dArr, offset);
                offset += this.score[c].getNumberOfParameters();
            }
            fillParameters();
        }

        /**
         * Copies the current class weights and class-score parameters into {@link #parameter}.
         *
         * @throws IllegalStateException if the number of parameters is unknown
         */
        private void fillParameters() throws Exception {
            int numberOfParameters = getNumberOfParameters();
            if (numberOfParameters < 0) {
                // Previously surfaced as an uninformative NegativeArraySizeException.
                throw new IllegalStateException("The number of parameters of at least one class score is unknown.");
            }
            if (this.parameter == null || this.parameter.length != numberOfParameters) {
                this.parameter = new double[numberOfParameters];
                this.grad = new double[numberOfParameters];
            }
            int offset = getNumberOfClasses();
            for (int c = 0; c < offset; c++) {
                this.parameter[c] = getClassWeight(c);
            }
            for (int c = 0; c < this.score.length; c++) {
                double[] values = this.score[c].getCurrentParameterValues();
                System.arraycopy(values, 0, this.parameter, offset, values.length);
                offset += values.length;
            }
        }

        /**
         * Returns a copy of the current parameters. For {@code ZEROS} the class weights are set to zero and the
         * class-score parameters are kept; {@code LAST} and {@code PLUGIN} return the parameters unchanged.
         */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public double[] getCurrentParameterValues(OptimizableFunction.KindOfParameter kindOfParameter) throws Exception {
            if (this.parameter == null) {
                fillParameters();
            }
            double[] values = this.parameter.clone();
            switch (kindOfParameter) {
                case ZEROS:
                    Arrays.fill(values, 0, getNumberOfClasses(), 0.0d);
                    break;
                case LAST:
                case PLUGIN:
                    break;
                default:
                    throw new IllegalArgumentException("Unknown kind of parameter");
            }
            return values;
        }

        /** Re-binds the prior to the current class scores. */
        @Override // projects.dream2016.mix.OptimizableClassifier
        public void reset() throws Exception {
            this.prior.set(false, this.score);
        }

        @Override // de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifier
        protected OptimizableFunction.KindOfParameter preoptimize(OptimizableFunction optimizableFunction) throws Exception {
            return OptimizableFunction.KindOfParameter.ZEROS;
        }
    }

    /* loaded from: input_file:projects/dream2016/mix/NewMixtureClassifier$Training.class */
    /**
     * Training strategy of a {@link NewMixtureClassifier}.
     * <p>
     * {@code COMBINED} requires all component classifiers to be {@code OptimizableClassifier}s;
     * {@code SEPARATELY_DOC} and {@code SEPARATELY_VOC} differ in how the data are split among the components.
     */
    public enum Training {
        SEPARATELY_DOC,
        SEPARATELY_VOC,
        COMBINED;

        /**
         * Returns all constants in declaration order as a freshly allocated array (decompiler-renamed counterpart
         * of {@code values()}).
         *
         * @return a new array that the caller may modify
         */
        /* renamed from: values, reason: to resolve conflict with enum method */
        public static Training[] valuesCustom() {
            return values().clone();
        }
    }

    /* loaded from: input_file:projects/dream2016/mix/NewMixtureClassifier$Vote.class */
    /**
     * How the component classifiers of a {@link NewMixtureClassifier} are combined at prediction time.
     * <p>
     * {@code DOC}: only the classifier of the most probable mixture component decides.
     * {@code VOC}: all classifiers vote, weighted by the normalized mixture component scores.
     */
    public enum Vote {
        DOC,
        VOC;

        /**
         * Returns all constants in declaration order as a freshly allocated array (decompiler-renamed counterpart
         * of {@code values()}).
         *
         * @return a new array that the caller may modify
         */
        /* renamed from: values, reason: to resolve conflict with enum method */
        public static Vote[] valuesCustom() {
            return values().clone();
        }
    }

    /**
     * Creates a new mixture classifier.
     *
     * @param i                                 number of threads for the numerical trainer of the mixture model
     * @param training                          training strategy; for {@code COMBINED} every classifier must
     *                                          implement {@code OptimizableClassifier}
     * @param i2                                number of starts, must be positive (the mixture model receives 1
     *                                          instead for {@code COMBINED})
     * @param differentiableStatisticalModelArr mixture components, one per classifier (cloned)
     * @param abstractScoreBasedClassifierArr   component classifiers (cloned); the first one determines alphabet,
     *                                          length and number of classes
     * @param vote                              how the classifiers are combined at prediction time
     * @param logPrior                          prior for the mixture model; {@code null} means
     *                                          {@code DoesNothingLogPrior.defaultInstance}
     * @throws IllegalArgumentException if the component array is {@code null}, the arrays differ in length, or
     *                                  {@code i2} is not positive
     * @throws Exception                if cloning the components fails
     */
    public NewMixtureClassifier(int i, Training training, int i2, DifferentiableStatisticalModel[] differentiableStatisticalModelArr, AbstractScoreBasedClassifier[] abstractScoreBasedClassifierArr, Vote vote, LogPrior logPrior) throws Exception {
        super(abstractScoreBasedClassifierArr[0].getAlphabetContainer(), abstractScoreBasedClassifierArr[0].getLength(), abstractScoreBasedClassifierArr[0].getNumberOfClasses());
        this.algo = (byte) 10;
        this.eps = 1.0E-7d;
        this.linEps = 1.0E-7d;
        // Validate everything before any component is cloned or built.
        if (differentiableStatisticalModelArr == null) {
            // Previously accepted by the guard but dereferenced below, causing a NullPointerException.
            throw new IllegalArgumentException("The mixture components must not be null.");
        }
        if (differentiableStatisticalModelArr.length != abstractScoreBasedClassifierArr.length) {
            throw new IllegalArgumentException("The number of mixture components (" + differentiableStatisticalModelArr.length + ") must equal the number of classifiers (" + abstractScoreBasedClassifierArr.length + ").");
        }
        if (i2 <= 0) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        this.helpArray = new double[getNumberOfClasses()];
        this.model = new MyMixtureScoringFunction(training == Training.COMBINED ? 1 : i2, true, (DifferentiableStatisticalModel[]) ArrayHandler.clone(differentiableStatisticalModelArr));
        this.starts = i2;
        this.component = (AbstractScoreBasedClassifier[]) ArrayHandler.clone(abstractScoreBasedClassifierArr);
        this.logClassifierProbs = new double[abstractScoreBasedClassifierArr.length];
        this.logMixScores = new double[differentiableStatisticalModelArr.length];
        this.training = training;
        if (training == Training.COMBINED) {
            this.optComponent = (OptimizableClassifier[]) ArrayHandler.cast(OptimizableClassifier.class, this.component);
        } else {
            this.optComponent = null;
        }
        this.threads = i;
        // Same default as used when parsing XML without a <mixPrior> tag; avoids NPEs in clone/toXML/training.
        this.mixPrior = logPrior == null ? DoesNothingLogPrior.defaultInstance : logPrior;
        this.vote = vote;
    }

    /**
     * Restores a classifier from its XML representation as written by {@code toXML()}.
     * <p>
     * The superclass constructor parses the XML (presumably via {@code extractFurtherClassifierInfosFromXML}).
     * The optimizer settings are not serialized and are set to the same defaults as in the other constructor.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public NewMixtureClassifier(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
        // Not part of the XML: restore the default optimizer settings.
        this.linEps = 1.0E-7d;
        this.eps = 1.0E-7d;
        this.algo = (byte) 10;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Reads the mixture-specific part of the XML representation: the mixture model, the component classifiers,
     * the optional mixture prior (instantiated reflectively from its class name via a {@code (StringBuffer)}
     * constructor), vote, training strategy, threads and starts.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if a part cannot be parsed, the prior cannot be instantiated, or resetting the
     *                              components for {@code COMBINED} training fails
     */
    @Override // de.jstacs.classifiers.AbstractScoreBasedClassifier, de.jstacs.classifiers.AbstractClassifier
    public void extractFurtherClassifierInfosFromXML(StringBuffer stringBuffer) throws NonParsableException {
        super.extractFurtherClassifierInfosFromXML(stringBuffer);
        this.model = (MyMixtureScoringFunction) XMLParser.extractObjectForTags(stringBuffer, "model", MyMixtureScoringFunction.class);
        this.component = (AbstractScoreBasedClassifier[]) XMLParser.extractObjectForTags(stringBuffer, "componentClassifier", AbstractScoreBasedClassifier[].class);
        StringBuffer priorXML = XMLParser.extractForTag(stringBuffer, "mixPrior");
        if (priorXML != null) {
            String className = (String) XMLParser.extractObjectForTags(priorXML, "className", String.class);
            try {
                this.mixPrior = (LogPrior) Class.forName(className).getConstructor(StringBuffer.class).newInstance(priorXML);
            } catch (NoSuchMethodException e) {
                NonParsableException parseError = new NonParsableException("You must provide a constructor " + className + "(StringBuffer).");
                parseError.setStackTrace(e.getStackTrace());
                throw parseError;
            } catch (Exception e2) {
                // Reflection wraps failures of the prior's constructor (InvocationTargetException has no message);
                // report the underlying problem instead.
                Throwable problem = e2.getCause() != null ? e2.getCause() : e2;
                NonParsableException parseError = new NonParsableException("problem at " + className + ": " + problem.getMessage());
                parseError.setStackTrace(problem.getStackTrace());
                throw parseError;
            }
        } else {
            this.mixPrior = DoesNothingLogPrior.defaultInstance;
        }
        this.vote = (Vote) XMLParser.extractObjectForTags(stringBuffer, "vote", Vote.class);
        this.training = (Training) XMLParser.extractObjectForTags(stringBuffer, "training", Training.class);
        this.threads = ((Integer) XMLParser.extractObjectForTags(stringBuffer, "threads", Integer.TYPE)).intValue();
        this.starts = ((Integer) XMLParser.extractObjectForTags(stringBuffer, "starts", Integer.TYPE)).intValue();
        // Allocate the scratch buffers before reset() so that the object is fully set up when reset runs.
        this.helpArray = new double[getNumberOfClasses()];
        this.logClassifierProbs = new double[this.component.length];
        this.logMixScores = new double[this.logClassifierProbs.length];
        if (this.training == Training.COMBINED) {
            this.optComponent = (OptimizableClassifier[]) ArrayHandler.cast(OptimizableClassifier.class, this.component);
            try {
                reset();
            } catch (Exception e3) {
                // Keep the original stack trace, consistent with the prior-parsing errors above.
                NonParsableException parseError = new NonParsableException(e3.getMessage());
                parseError.setStackTrace(e3.getStackTrace());
                throw parseError;
            }
        } else {
            this.optComponent = null;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Appends the mixture-specific part of the XML representation: model, component classifiers, the mixture
     * prior (only if it is not a {@code DoesNothingLogPrior}; stored with its class name), vote, training
     * strategy, threads and starts.
     *
     * @return the XML representation including the superclass part
     */
    @Override // de.jstacs.classifiers.AbstractScoreBasedClassifier, de.jstacs.classifiers.AbstractClassifier
    public StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = super.getFurtherClassifierInfos();
        XMLParser.appendObjectWithTags(xml, this.model, "model");
        XMLParser.appendObjectWithTags(xml, this.component, "componentClassifier");
        // A missing <mixPrior> tag is read back as DoesNothingLogPrior, so that case (and null) is not written.
        if (this.mixPrior != null && !(this.mixPrior instanceof DoesNothingLogPrior)) {
            // Plain initial capacity; previously borrowed from an unrelated third-party constant.
            StringBuffer priorXML = new StringBuffer(1000);
            priorXML.append("<mixPrior>\n");
            XMLParser.appendObjectWithTags(priorXML, this.mixPrior.getClass().getName(), "className");
            priorXML.append(this.mixPrior.toXML());
            priorXML.append("\t</mixPrior>\n");
            xml.append(priorXML);
        }
        XMLParser.appendObjectWithTags(xml, this.vote, "vote");
        XMLParser.appendObjectWithTags(xml, this.training, "training");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.threads), "threads");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.starts), "starts");
        return xml;
    }

    /**
     * Returns a deep copy: components, model and buffers are cloned, the prior is re-instantiated via
     * {@code getNewInstance()}, and the per-component derivative lists are freshly allocated.
     *
     * @return the copy
     * @throws CloneNotSupportedException if a part cannot be cloned
     */
    @Override // de.jstacs.classifiers.AbstractScoreBasedClassifier, de.jstacs.classifiers.AbstractClassifier
    /* renamed from: clone */
    public NewMixtureClassifier mo52clone() throws CloneNotSupportedException {
        NewMixtureClassifier copy = (NewMixtureClassifier) super.mo52clone();
        copy.component = (AbstractScoreBasedClassifier[]) ArrayHandler.clone(this.component);
        if (this.optComponent != null) {
            // Must view the CLONED components, not the originals.
            copy.optComponent = (OptimizableClassifier[]) ArrayHandler.cast(OptimizableClassifier.class, copy.component);
        }
        copy.model = (MyMixtureScoringFunction) this.model.mo114clone();
        copy.mixPrior = this.mixPrior == null ? null : this.mixPrior.getNewInstance();
        copy.logClassifierProbs = this.logClassifierProbs.clone();
        copy.logMixScores = this.logMixScores.clone();
        if (this.params != null) {
            copy.params = this.params.clone();
            copy.mixParams = this.mixParams == null ? null : this.mixParams.clone();
            copy.mixGrad = this.mixGrad == null ? null : this.mixGrad.clone();
            // Size and fill each list array by its own length (previously bounded by optComponent.length).
            if (this.indices != null) {
                copy.indices = new IntList[this.indices.length];
                for (int k = 0; k < copy.indices.length; k++) {
                    copy.indices[k] = new IntList();
                }
            }
            if (this.partialDer != null) {
                copy.partialDer = new DoubleList[this.partialDer.length];
                for (int k = 0; k < copy.partialDer.length; k++) {
                    copy.partialDer[k] = new DoubleList();
                }
            }
        }
        copy.helpArray = this.helpArray.clone();
        return copy;
    }

    /**
     * Sets how the component classifiers are combined at prediction time.
     *
     * @param vote the new strategy
     * @throws IllegalArgumentException if {@code vote} is {@code null} (would otherwise fail later during scoring)
     */
    public void setVote(Vote vote) {
        if (vote == null) {
            throw new IllegalArgumentException("vote must not be null");
        }
        this.vote = vote;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Scores {@code sequence} for class {@code i} as the log class posterior under the currently configured
     * {@link Vote} strategy.
     *
     * @param sequence the sequence
     * @param i        the class index
     * @param z        whether the sequence is checked against the classifier before scoring
     * @return the log posterior of class {@code i}
     */
    @Override // de.jstacs.classifiers.AbstractScoreBasedClassifier
    public double getScore(Sequence sequence, int i, boolean z) throws IllegalArgumentException, NotTrainedException, Exception {
        final Vote strategy = this.vote;
        return getLogProb(sequence, i, z, strategy);
    }

    /**
     * Computes the log posterior of class {@code i} for {@code sequence}.
     * <ul>
     * <li>{@code DOC}: the classifier of the mixture component with maximal score decides alone.</li>
     * <li>{@code VOC}: {@code log sum_k w_k P_k(i|x) - log sum_k w_k}, where {@code w_k} are the mixture component
     * scores (in log space) and {@code P_k} is the class posterior of classifier {@code k}.</li>
     * </ul>
     * Uses the shared scratch buffers {@code helpArray} and {@code logMixScores}.
     *
     * @param sequence the sequence
     * @param i        the class index
     * @param z        whether to check the sequence first
     * @param vote     the combination strategy
     * @return the log posterior
     * @throws EvaluationException if the check fails (wrapping the original exception's class, message and trace)
     */
    private double getLogProb(Sequence sequence, int i, boolean z, Vote vote) throws EvaluationException {
        if (z) {
            try {
                check(sequence);
            } catch (Exception e) {
                EvaluationException evaluationException = new EvaluationException(e.getClass() + ": " + e.getMessage());
                evaluationException.setStackTrace(e.getStackTrace());
                throw evaluationException;
            }
        }
        // Plain enum switch instead of the decompiler-generated ordinal lookup table.
        switch (vote) {
            case DOC: {
                int best = this.model.getIndexOfMaximalComponentFor(sequence, 0);
                for (int c = 0; c < this.helpArray.length; c++) {
                    this.helpArray[c] = this.component[best].getScore(sequence, c);
                }
                return this.helpArray[i] - Normalisation.getLogSum(this.helpArray);
            }
            case VOC: {
                this.model.getLogScores(sequence, this.logMixScores);
                double logNormalizer = Normalisation.getLogSum(this.logMixScores);
                for (int k = 0; k < this.logMixScores.length; k++) {
                    for (int c = 0; c < this.helpArray.length; c++) {
                        this.helpArray[c] = this.component[k].getScore(sequence, c);
                    }
                    // Add log P_k(i|x) to the log weight of component k.
                    this.logMixScores[k] += this.helpArray[i] - Normalisation.getLogSum(this.helpArray);
                }
                return Normalisation.getLogSum(this.logMixScores) - logNormalizer;
            }
            default:
                throw new IllegalStateException("Unknown vote: " + vote);
        }
    }

    /**
     * Returns the annotation: the instance name followed by one (empty) info entry per class.
     *
     * @return an array of length {@code getNumberOfClasses() + 1}
     */
    @Override // de.jstacs.classifiers.AbstractClassifier
    public CategoricalResult[] getClassifierAnnotation() {
        int classes = getNumberOfClasses();
        CategoricalResult[] annotation = new CategoricalResult[classes + 1];
        annotation[0] = new CategoricalResult("classifier", "a <b>short</b> description of the classifier", getInstanceName());
        for (int c = 0; c < classes; c++) {
            annotation[c + 1] = new CategoricalResult("class info " + c, "some information about the class", "");
        }
        return annotation;
    }

    /**
     * Returns a name of the form {@code SimpleClassName(model: M; classifier: C0, C1, ...)}.
     *
     * @return the instance name
     */
    @Override // de.jstacs.classifiers.AbstractClassifier
    public String getInstanceName() {
        StringBuilder name = new StringBuilder(getClass().getSimpleName());
        name.append("(model: ").append(this.model.getInstanceName()).append("; classifier: ");
        name.append(this.component[0].getInstanceName());
        for (int c = 1; c < this.component.length; c++) {
            name.append(", ").append(this.component[c].getInstanceName());
        }
        return name.append(")").toString();
    }

    /**
     * Returns an empty result set; this classifier reports no numerical characteristics.
     *
     * @return an empty {@link NumericalResultSet}
     */
    @Override // de.jstacs.classifiers.AbstractClassifier
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        LinkedList<NumericalResult> none = new LinkedList<NumericalResult>();
        return new NumericalResultSet(none);
    }

    /**
     * @return the root tag of the XML representation
     */
    @Override // de.jstacs.classifiers.AbstractClassifier
    protected String getXMLTag() {
        final String tag = XML_TAG;
        return tag;
    }

    /**
     * Reports whether all component classifiers and the mixture model are initialized. The components are checked
     * first, in order, and the model only if all of them are.
     *
     * @return {@code true} if everything is initialized
     */
    @Override // de.jstacs.classifiers.AbstractClassifier
    public boolean isInitialized() {
        for (AbstractScoreBasedClassifier classifier : this.component) {
            if (!classifier.isInitialized()) {
                return false;
            }
        }
        return this.model.isInitialized();
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v7, types: [double[], double[][]] */
    private DataHandler init(DataSet[] dataSetArr, double[][] dArr, boolean z) throws Exception {
        LinkedList linkedList = new LinkedList();
        DoubleList doubleList = new DoubleList();
        for (int i = 0; i < dataSetArr.length; i++) {
            for (int i2 = 0; i2 < dataSetArr[i].getNumberOfElements(); i2++) {
                linkedList.add(dataSetArr[i].getElementAt(i2));
                if (dArr != null && dArr[i] != null) {
                    doubleList.add(dArr[i][i2]);
                }
            }
        }
        DataSet[] dataSetArr2 = {new DataSet("all", (Sequence[]) linkedList.toArray(new Sequence[0]))};
        ?? r0 = new double[1];
        r0[0] = doubleList.length() == 0 ? null : doubleList.toArray();
        if (this.model.getNumberOfParameters() <= this.model.getNumberOfComponents() || !z) {
            this.model.initializeFunction(0, false, dataSetArr2, r0);
        } else {
            DifferentiableStatisticalModelWrapperTrainSM differentiableStatisticalModelWrapperTrainSM = new DifferentiableStatisticalModelWrapperTrainSM(this.model, this.threads, this.algo, new SmallDifferenceOfFunctionEvaluationsCondition(0.001d), this.linEps, this.starts, this.mixPrior.getNewInstance());
            differentiableStatisticalModelWrapperTrainSM.train(dataSetArr2[0], r0[0]);
            this.model = (MyMixtureScoringFunction) differentiableStatisticalModelWrapperTrainSM.getFunction();
        }
        System.out.println(this.model);
        DataHandler splitData = splitData(this.model, dataSetArr, dArr, this.training == Training.SEPARATELY_DOC);
        if (this.optComponent != null) {
            for (int i3 = 0; i3 < this.optComponent.length; i3++) {
                this.optComponent[i3].initialize(splitData.getSplits(i3), splitData.getWeightsOfSplits(i3));
            }
        }
        return splitData;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v22, types: [double[], double[][]] */
    /* JADX WARN: Type inference failed for: r0v84, types: [double[]] */
    /* JADX WARN: Type inference failed for: r9v0, types: [projects.dream2016.mix.OptimizableClassifier, projects.dream2016.mix.NewMixtureClassifier] */
    @Override // de.jstacs.classifiers.AbstractClassifier
    public void train(DataSet[] dataSetArr, double[][] dArr) throws Exception {
        if (this.training != Training.COMBINED) {
            DataHandler init = init(dataSetArr, dArr, true);
            for (int i = 0; i < this.component.length; i++) {
                DataSet[] splits = init.getSplits(i);
                System.out.println(String.valueOf(i) + ")\t" + splits[0].getNumberOfElements() + " vs. " + splits[1].getNumberOfElements());
                this.component[i].train(splits, init.getWeightsOfSplits(i));
            }
            return;
        }
        if (dArr == null) {
            dArr = new double[dataSetArr.length];
        }
        for (int i2 = 0; i2 < dataSetArr.length; i2++) {
            if (dArr[i2] == null) {
                dArr[i2] = new double[dataSetArr[i2].getNumberOfElements()];
                Arrays.fill(dArr[i2], 1.0d);
            }
        }
        NewMixtureClassifier newMixtureClassifier = this;
        double d = Double.NEGATIVE_INFINITY;
        DataSet[] dataSetArr2 = new DataSet[dataSetArr.length];
        ?? r0 = new double[dataSetArr.length];
        SmallDifferenceOfFunctionEvaluationsCondition smallDifferenceOfFunctionEvaluationsCondition = new SmallDifferenceOfFunctionEvaluationsCondition(this.eps);
        ConstantStartDistance constantStartDistance = new ConstantStartDistance(1.0d);
        boolean z = false;
        Exception exc = null;
        for (int i3 = 0; i3 < this.starts; i3++) {
            System.out.println(String.valueOf(i3) + " ========================");
            try {
                initializeRandomly();
                for (int i4 = 0; i4 < dataSetArr2.length; i4++) {
                    Pair<DataSet[], double[][]> partition = dataSetArr[i4].partition(dArr[i4], DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, 0.95d, 0.05d);
                    dataSetArr2[i4] = partition.getFirstElement()[1];
                    r0[i4] = partition.getSecondElement()[1];
                }
                init(dataSetArr2, r0, true);
                DataHandler splitData = splitData(this.model, dataSetArr, dArr, false);
                for (int i5 = 0; i5 < this.component.length; i5++) {
                    this.component[i5].train(splitData.getSplits(i5), splitData.getWeightsOfSplits(i5));
                }
                MSPClassifierObjective mSPClassifierObjective = new MSPClassifierObjective(this.threads, this, dataSetArr, dArr, true);
                NegativeDifferentiableFunction negativeDifferentiableFunction = new NegativeDifferentiableFunction(mSPClassifierObjective);
                mSPClassifierObjective.reset();
                double[] parameters = mSPClassifierObjective.getParameters(OptimizableFunction.KindOfParameter.PLUGIN);
                Optimizer.optimize(this.algo, negativeDifferentiableFunction, parameters, smallDifferenceOfFunctionEvaluationsCondition, this.linEps, constantStartDistance, System.out);
                double evaluateFunction = mSPClassifierObjective.evaluateFunction(parameters);
                System.out.println("start " + i3 + ":\t" + evaluateFunction);
                if (evaluateFunction > d) {
                    newMixtureClassifier = mo52clone();
                    d = evaluateFunction;
                }
                z = true;
                mSPClassifierObjective.stopThreads();
            } catch (Exception e) {
                exc = e;
                System.out.println("An exception was thrown. " + e.getMessage());
            }
        }
        if (!z) {
            throw exc;
        }
        this.component = newMixtureClassifier.component;
        newMixtureClassifier.component = null;
        this.optComponent = (OptimizableClassifier[]) ArrayHandler.cast(OptimizableClassifier.class, this.component);
        this.model = newMixtureClassifier.model;
        newMixtureClassifier.model = null;
        this.params = null;
        reset();
    }

    /**
     * Returns the string representation of the mixture model. It is followed by one section per
     * component classifier, each introduced by a {@code "classifier <index>:"} header line.
     *
     * @return a textual representation of this classifier
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder().append(this.model);
        for (int idx = 0; idx < this.component.length; idx++) {
            text.append("\nclassifier ").append(idx).append(":\n").append(this.component[idx]);
        }
        return text.toString();
    }

    /**
     * Splits the given data among the components of a mixture model.
     * <p>
     * If {@code z} is {@code true}, every sequence is assigned (hard) to the component returned by
     * {@code getIndexOfMaximalComponentFor}, keeping its original weight.
     * <p>
     * Otherwise every sequence is kept once, in the sequence lists of component 0, which are the
     * only sequence lists handed to the {@link DataHandler}. Each component then receives the
     * sequence weight multiplied by the component score of that sequence, normalised from the log
     * scores via {@code Normalisation.logSumNormalisation} (soft split).
     *
     * @param myMixtureScoringFunction the mixture model used to assign sequences to components
     * @param dataSetArr the data sets, one per class
     * @param dArr the sequence weights per data set; may be {@code null} or contain {@code null}
     *            entries (meaning weight 1)
     * @param z {@code true} for a hard split, {@code false} for a soft (weighted) split
     * @return the split sequences and weights
     * @throws EmptyDataSetException declared for callers of this method
     * @throws WrongAlphabetException declared for callers of this method
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    public static DataHandler splitData(MyMixtureScoringFunction myMixtureScoringFunction, DataSet[] dataSetArr, double[][] dArr, boolean z) throws EmptyDataSetException, WrongAlphabetException {
        int numberOfComponents = myMixtureScoringFunction.getNumberOfComponents();
        // Indexed as [component][data set]. Generic arrays cannot be created, hence raw types.
        LinkedList[][] linkedListArr = new LinkedList[numberOfComponents][dataSetArr.length];
        DoubleList[][] doubleListArr = new DoubleList[numberOfComponents][dataSetArr.length];
        for (int i = 0; i < numberOfComponents; i++) {
            for (int i2 = 0; i2 < dataSetArr.length; i2++) {
                linkedListArr[i][i2] = new LinkedList();
                doubleListArr[i][i2] = new DoubleList();
            }
        }
        double[] dArr2 = new double[myMixtureScoringFunction.getNumberOfComponents()];
        for (int i3 = 0; i3 < dataSetArr.length; i3++) {
            int numberOfElements = dataSetArr[i3].getNumberOfElements();
            double d = 1.0d;
            for (int i4 = 0; i4 < numberOfElements; i4++) {
                Sequence elementAt = dataSetArr[i3].getElementAt(i4);
                if (dArr != null && dArr[i3] != null) {
                    d = dArr[i3][i4];
                }
                if (z) {
                    int indexOfMaximalComponentFor = myMixtureScoringFunction.getIndexOfMaximalComponentFor(elementAt, 0);
                    linkedListArr[indexOfMaximalComponentFor][i3].add(elementAt);
                    doubleListArr[indexOfMaximalComponentFor][i3].add(d);
                } else {
                    linkedListArr[0][i3].add(elementAt);
                    myMixtureScoringFunction.getLogScores(elementAt, dArr2);
                    Normalisation.logSumNormalisation(dArr2);
                    for (int i5 = 0; i5 < dArr2.length; i5++) {
                        doubleListArr[i5][i3].add(d * dArr2[i5]);
                    }
                }
            }
        }
        // The soft split shares one sequence list (that of component 0) for all components. It
        // must be wrapped in a LinkedList[][], not a LinkedList[]: the decompiled code was a type error.
        return z ? new DataHandler(doubleListArr, linkedListArr) : new DataHandler(doubleListArr, new LinkedList[][] { linkedListArr[0] });
    }

    /**
     * Adds the gradient of the log-prior to {@code dArr}, starting at index {@code i}.
     * <p>
     * The mixture-prior gradient (w.r.t. the mixture parameters) comes first. The log-prior
     * gradients of the optimisable components follow, each at its own parameter offset.
     *
     * @param dArr the gradient vector to add to
     * @param i the offset of this classifier's parameters within {@code dArr}
     * @throws EvaluationException if a prior gradient cannot be evaluated
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public void addGradient(double[] dArr, int i) throws EvaluationException {
        Arrays.fill(this.mixGrad, 0.0d);
        this.mixPrior.addGradientFor(this.mixParams, this.mixGrad);
        int offset = i;
        for (double partial : this.mixGrad) {
            dArr[offset] += partial;
            offset++;
        }
        for (OptimizableClassifier part : this.optComponent) {
            part.addGradient(dArr, offset);
            offset += part.getNumberOfParameters();
        }
    }

    /**
     * Returns a fresh array holding the mixture parameters followed by the parameters of all
     * optimisable components.
     *
     * @param kindOfParameter the kind of parameters requested from the components
     * @return the concatenated current parameter values
     * @throws Exception if a component cannot provide its parameters
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public double[] getCurrentParameterValues(OptimizableFunction.KindOfParameter kindOfParameter) throws Exception {
        final int count = getNumberOfParameters();
        return getCurrentParameterValues(kindOfParameter, count, null);
    }

    /**
     * Writes the mixture parameters followed by the component parameters into {@code dArr}.
     * <p>
     * As a side effect, {@code mixParams} is replaced by the array returned by
     * {@code model.getCurrentParameterValues()}.
     *
     * @param kindOfParameter the kind of parameters requested from the components
     * @param i the expected number of parameters (currently unused)
     * @param dArr the target array, or {@code null} to allocate one of length {@code params.length}
     * @return the filled target array
     * @throws Exception if a model cannot provide its parameters
     */
    private double[] getCurrentParameterValues(OptimizableFunction.KindOfParameter kindOfParameter, int i, double[] dArr) throws Exception {
        double[] target = (dArr != null) ? dArr : new double[this.params.length];
        this.mixParams = this.model.getCurrentParameterValues();
        System.arraycopy(this.mixParams, 0, target, 0, this.mixParams.length);
        int offset = this.mixParams.length;
        for (OptimizableClassifier part : this.optComponent) {
            double[] partParams = part.getCurrentParameterValues(kindOfParameter);
            System.arraycopy(partParams, 0, target, offset, partParams.length);
            offset += partParams.length;
        }
        return target;
    }

    /**
     * Returns the log-prior of the mixture parameters plus the log-prior terms of all optimisable
     * components.
     *
     * @return the summed log-prior term
     * @throws DimensionException if the mixture prior rejects the parameter dimension
     * @throws EvaluationException if a prior cannot be evaluated
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public double getLogPriorTerm() throws DimensionException, EvaluationException {
        double logPrior = this.mixPrior.evaluateFunction(this.mixParams);
        for (OptimizableClassifier part : this.optComponent) {
            logPrior += part.getLogPriorTerm();
        }
        return logPrior;
    }

    /**
     * Returns the log-probability of class {@code i} for {@code sequence} under the mixture of
     * component classifiers.
     *
     * @param i the class index
     * @param sequence the sequence
     * @return the log-probability of the class
     * @throws EvaluationException if a component cannot be evaluated
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public double getLogProb(int i, Sequence sequence) throws EvaluationException {
        final boolean unusedFlag = false;
        return getLogProb(i, sequence, unusedFlag);
    }

    /**
     * Computes log( sum_c w_c(sequence) * P_c(i | sequence) ).
     * <p>
     * Here w_c are the mixture scores normalised over the components, and P_c are the class
     * probabilities of component classifier c. The intermediate per-component values are stored
     * in {@code logClassifierProbs}.
     *
     * @param i the class index
     * @param sequence the sequence
     * @param z currently unused
     * @return the log-probability of the class
     * @throws EvaluationException if a component cannot be evaluated
     */
    private double getLogProb(int i, Sequence sequence, boolean z) throws EvaluationException {
        this.model.getLogScores(sequence, this.logMixScores);
        final double logNorm = Normalisation.getLogSum(this.logMixScores);
        int c = 0;
        for (OptimizableClassifier part : this.optComponent) {
            this.logClassifierProbs[c] = (part.getLogProb(i, sequence) + this.logMixScores[c]) - logNorm;
            c++;
        }
        return Normalisation.getLogSum(this.logClassifierProbs);
    }

    /**
     * Computes the log-probability of class {@code i} for {@code sequence} and appends the partial
     * derivations to {@code intList} / {@code doubleList}.
     * <p>
     * Component derivatives are weighted by the normalised per-component values and shifted past
     * the mixture parameters and all preceding component parameters. The mixture model then
     * appends its own derivatives.
     *
     * @param i the class index
     * @param sequence the sequence
     * @param intList receives the parameter indices of the partial derivations
     * @param doubleList receives the values of the partial derivations
     * @return the log-probability of the class
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public double getLogProbAndPartialDerivations(int i, Sequence sequence, IntList intList, DoubleList doubleList) {
        this.model.getLogScores(sequence, this.logMixScores);
        final double logMixNorm = Normalisation.getLogSum(this.logMixScores);
        final int numComponents = this.optComponent.length;
        for (int c = 0; c < numComponents; c++) {
            IntList compIndices = this.indices[c];
            DoubleList compDerivs = this.partialDer[c];
            compIndices.clear();
            compDerivs.clear();
            double compLogProb = this.optComponent[c].getLogProbAndPartialDerivations(i, sequence, compIndices, compDerivs);
            this.logClassifierProbs[c] = (compLogProb + this.logMixScores[c]) - logMixNorm;
        }
        // Per the jstacs Normalisation API, this returns the log of the sum and (presumably)
        // overwrites the array with the normalised, non-log values used as weights below.
        final double logProb = Normalisation.logSumNormalisation(this.logClassifierProbs);
        int offset = this.mixParams.length;
        for (int c = 0; c < this.indices.length; c++) {
            IntList compIndices = this.indices[c];
            DoubleList compDerivs = this.partialDer[c];
            double weight = this.logClassifierProbs[c];
            for (int k = 0; k < compIndices.length(); k++) {
                intList.add(offset + compIndices.get(k));
                doubleList.add(weight * compDerivs.get(k));
            }
            offset += this.optComponent[c].getNumberOfParameters();
        }
        this.model.getPartialDerivations(sequence, this.logClassifierProbs, intList, doubleList);
        // NOTE(review): debugging guard that terminates the whole JVM on a NaN derivative. Consider
        // an exception instead, once it is confirmed that the multi-threaded objective propagates it.
        for (int k = 0; k < intList.length(); k++) {
            double value = doubleList.get(k);
            if (Double.isNaN(value)) {
                System.out.println("partial derivation became: NaN");
                System.out.println(String.valueOf(i) + "\t" + sequence);
                System.out.println(Arrays.toString(this.params));
                System.out.println(String.valueOf(k) + "\t" + value);
                System.exit(1);
            }
        }
        return logProb;
    }

    /**
     * Returns the length of the combined parameter vector.
     *
     * @return the number of parameters, or -1 if the parameter vector {@code params} is not set
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public int getNumberOfParameters() {
        return (this.params == null) ? -1 : this.params.length;
    }

    /**
     * Initialises the mixture model and the components from the given data. The mixture model is
     * pre-trained when it has more parameters than components. The resulting data split is
     * discarded.
     *
     * @param dataSetArr the data sets, one per class
     * @param dArr the sequence weights per data set; may be {@code null}
     * @throws Exception if initialisation fails
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public void initialize(DataSet[] dataSetArr, double[][] dArr) throws Exception {
        final boolean trainMixture = true;
        init(dataSetArr, dArr, trainMixture);
    }

    /**
     * Randomly initialises the mixture model (without freeing parameters) and then every
     * optimisable component.
     *
     * @throws Exception if any random initialisation fails
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public void initializeRandomly() throws Exception {
        this.model.initializeFunctionRandomly(false);
        for (OptimizableClassifier part : this.optComponent) {
            part.initializeRandomly();
        }
    }

    /**
     * Resets the mixture prior and all components, then rebuilds the combined parameter vector.
     * <p>
     * If the mixture model or any component reports -1 parameters, {@code params} is set to
     * {@code null}. Otherwise {@code params} (and {@code mixParams}/{@code mixGrad}) are
     * reallocated when the total size changed, and then filled with the current values. The
     * per-component derivative buffers are allocated on the first call only.
     *
     * @throws Exception if resetting or reading parameters fails
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public void reset() throws Exception {
        this.mixPrior.set(true, this.model);
        for (OptimizableClassifier part : this.optComponent) {
            part.reset();
        }
        int total = this.model.getNumberOfParameters();
        if (total != -1) {
            for (OptimizableClassifier part : this.optComponent) {
                int partCount = part.getNumberOfParameters();
                if (partCount == -1) {
                    total = -1;
                    break;
                }
                total += partCount;
            }
        }
        if (total == -1) {
            this.params = null;
        } else {
            if (this.params == null || this.params.length != total) {
                this.params = new double[total];
                this.mixParams = new double[this.model.getNumberOfParameters()];
                this.mixGrad = new double[this.mixParams.length];
            }
            getCurrentParameterValues(OptimizableFunction.KindOfParameter.PLUGIN, total, this.params);
        }
        if (this.indices == null) {
            final int numComponents = this.optComponent.length;
            this.indices = new IntList[numComponents];
            this.partialDer = new DoubleList[numComponents];
            for (int c = 0; c < numComponents; c++) {
                this.indices[c] = new IntList();
                this.partialDer[c] = new DoubleList();
            }
        }
    }

    /**
     * Sets all parameters from {@code dArr}, starting at offset {@code i}.
     * <p>
     * The mixture parameters come first, followed by the parameters of each optimisable component.
     * Local copies are kept in {@code mixParams} and {@code params}.
     *
     * @param dArr the array holding the new parameter values
     * @param i the offset of this classifier's parameters within {@code dArr}
     * @throws Exception if a model rejects the parameters
     */
    @Override // projects.dream2016.mix.OptimizableClassifier
    public void setParameters(double[] dArr, int i) throws Exception {
        final int mixCount = this.mixParams.length;
        final int totalCount = this.params.length;
        System.arraycopy(dArr, i, this.mixParams, 0, mixCount);
        System.arraycopy(dArr, i, this.params, 0, totalCount);
        this.model.setParameters(dArr, i);
        int offset = i + this.model.getNumberOfParameters();
        for (OptimizableClassifier part : this.optComponent) {
            part.setParameters(dArr, offset);
            offset += part.getNumberOfParameters();
        }
    }

    /**
     * Lazily builds and caches a table mapping {@code Vote} ordinals to switch labels
     * ({@code DOC -> 1}, {@code VOC -> 2}). Entries not set below stay 0.
     * <p>
     * NOTE(review): this looks like a compiler-generated (synthetic) switch map reproduced by the
     * decompiler. The cache field and {@code Vote.valuesCustom()} are declared outside this view.
     * Confirm before editing by hand.
     *
     * @return the cached ordinal-to-label table
     */
    static /* synthetic */ int[] $SWITCH_TABLE$projects$dream2016$mix$NewMixtureClassifier$Vote() {
        int[] iArr = $SWITCH_TABLE$projects$dream2016$mix$NewMixtureClassifier$Vote;
        // already built: reuse the cached table
        if (iArr != null) {
            return iArr;
        }
        int[] iArr2 = new int[Vote.valuesCustom().length];
        try {
            iArr2[Vote.DOC.ordinal()] = 1;
        } catch (NoSuchFieldError unused) {
            // constant missing at runtime: its entry simply stays 0
        }
        try {
            iArr2[Vote.VOC.ordinal()] = 2;
        } catch (NoSuchFieldError unused2) {
            // constant missing at runtime: its entry simply stays 0
        }
        $SWITCH_TABLE$projects$dream2016$mix$NewMixtureClassifier$Vote = iArr2;
        return iArr2;
    }
}
