/*
 * Decompiled with CFR 0.152.
 */
package projects.dream2016.mix;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.ConstantStartDistance;
import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.algorithms.optimization.NegativeDifferentiableFunction;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifier;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.GenDisMixClassifierParameterSet;
import de.jstacs.classifiers.differentiableSequenceScoreBased.gendismix.LearningPrinciple;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifiers.differentiableSequenceScoreBased.logPrior.LogPrior;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.MixtureDiffSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.DifferentiableStatisticalModelWrapperTrainSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.Pair;
import java.util.Arrays;
import java.util.LinkedList;
import javax.naming.OperationNotSupportedException;
import projects.dream2016.mix.MSPClassifierObjective;
import projects.dream2016.mix.OptimizableClassifier;

public class NewMixtureClassifier
extends AbstractScoreBasedClassifier
implements OptimizableClassifier {
    // Component classifiers, one per mixture component.
    private AbstractScoreBasedClassifier[] component;
    // Mixture over the component models; its per-component log scores weight the component classifiers.
    private MyMixtureScoringFunction model;
    // Prior on the mixture parameters (DoesNothingLogPrior if none was serialised).
    private LogPrior mixPrior;
    // How component predictions are combined in getScore (DOC or VOC).
    private Vote vote;
    // Training mode; COMBINED enables joint numerical optimisation of all parameters.
    private Training training;
    // Number of threads handed to MSPClassifierObjective during combined training.
    private int threads;
    // Number of random restarts performed in combined training.
    private int starts;
    // The components viewed as OptimizableClassifier; null unless training == COMBINED.
    private OptimizableClassifier[] optComponent;
    // Scratch buffer with one entry per class.
    private double[] helpArray;
    // Scratch buffer with one entry per component.
    private double[] logClassifierProbs;
    // Scratch buffer receiving the mixture's per-component log scores.
    private double[] logMixScores;
    // Full parameter vector: mixture parameters followed by each component's parameters; null if unknown.
    private double[] params;
    // Current mixture parameters (the prefix of params).
    private double[] mixParams;
    // Scratch buffer for the prior gradient of the mixture parameters.
    private double[] mixGrad;
    // Per-component scratch lists for parameter indices and the matching partial derivatives.
    private IntList[] indices;
    private DoubleList[] partialDer;
    // Algorithm code passed to Optimizer.optimize. NOTE(review): the meaning of 10 is not visible here; confirm against Optimizer's constants.
    private byte algo = (byte)10;
    // Threshold of the SmallDifferenceOfFunctionEvaluationsCondition used in combined training.
    private double eps = 1.0E-7;
    // Passed to Optimizer.optimize after the termination condition; presumably the line-search tolerance.
    private double linEps = 1.0E-7;
    // XML tag used for (de)serialisation.
    private static final String XML_TAG = "MixtureClassifier";

    /**
     * Creates a mixture classifier. The weight of each component comes from a mixture over
     * {@code componentSF}. The class posteriors of each component come from the matching entry of
     * {@code componentClassifiers}.
     *
     * @param threads number of threads handed to the objective used in combined training
     * @param training training mode; {@code COMBINED} enables joint numerical optimisation
     * @param starts number of random restarts for training; must be positive
     * @param componentSF one model per component used to build the mixture (cloned); must be
     *        non-null and have the same length as {@code componentClassifiers}
     * @param componentClassifiers the component classifiers (cloned); the first one defines alphabet,
     *        sequence length and number of classes
     * @param vote how component predictions are combined at classification time
     * @param prior prior on the mixture parameters; {@code null} is treated as
     *        {@link DoesNothingLogPrior#defaultInstance}
     * @throws IllegalArgumentException if {@code starts} is not positive, or if {@code componentSF}
     *         is {@code null} or its length differs from {@code componentClassifiers}
     * @throws Exception if cloning or building the mixture fails
     */
    public NewMixtureClassifier(int threads, Training training, int starts, DifferentiableStatisticalModel[] componentSF, AbstractScoreBasedClassifier[] componentClassifiers, Vote vote, LogPrior prior) throws Exception {
        super(componentClassifiers[0].getAlphabetContainer(), componentClassifiers[0].getLength(), componentClassifiers[0].getNumberOfClasses());
        // Validate everything before any model is built from the arguments.
        if (starts <= 0) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        if (componentSF == null || componentSF.length != componentClassifiers.length) {
            throw new IllegalArgumentException("The number of component models ("
                    + (componentSF == null ? "null" : String.valueOf(componentSF.length))
                    + ") has to match the number of component classifiers (" + componentClassifiers.length + ").");
        }
        this.helpArray = new double[this.getNumberOfClasses()];
        // In combined mode the restarts are performed by train(), so the mixture itself uses a single start.
        this.model = new MyMixtureScoringFunction(training == Training.COMBINED ? 1 : starts, true, (DifferentiableStatisticalModel[])ArrayHandler.clone((Cloneable[])componentSF));
        this.starts = starts;
        this.component = (AbstractScoreBasedClassifier[])ArrayHandler.clone((Cloneable[])componentClassifiers);
        this.logClassifierProbs = new double[componentClassifiers.length];
        this.logMixScores = new double[componentSF.length];
        this.training = training;
        this.optComponent = training == Training.COMBINED ? ArrayHandler.cast(OptimizableClassifier.class, this.component) : null;
        this.threads = threads;
        // Same default as the XML deserialisation path.
        this.mixPrior = prior == null ? DoesNothingLogPrior.defaultInstance : prior;
        this.vote = vote;
    }

    /**
     * Restores a classifier from its XML representation.
     * NOTE(review): presumably the superclass constructor calls
     * {@code extractFurtherClassifierInfosFromXML}, which fills this class's fields; confirm in
     * AbstractScoreBasedClassifier.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public NewMixtureClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Restores the fields written by {@code getFurtherClassifierInfos()}: the mixture model, the
     * component classifiers, the optional mixture prior, vote, training mode, threads and starts.
     * In {@code COMBINED} mode the components are also cast to {@link OptimizableClassifier} and
     * {@code reset()} is called to rebuild the parameter bookkeeping.
     *
     * @param xml the XML representation
     * @throws NonParsableException if a tag is missing, the prior cannot be instantiated via its
     *         {@code (StringBuffer)} constructor, or {@code reset()} fails
     */
    @Override
    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        super.extractFurtherClassifierInfosFromXML(xml);
        this.model = XMLParser.extractObjectForTags(xml, "model", MyMixtureScoringFunction.class);
        this.component = XMLParser.extractObjectForTags(xml, "componentClassifier", AbstractScoreBasedClassifier[].class);
        StringBuffer pr = XMLParser.extractForTag(xml, "mixPrior");
        if (pr != null) {
            String className = XMLParser.extractObjectForTags(pr, "className", String.class);
            try {
                this.mixPrior = (LogPrior)Class.forName(className).getConstructor(StringBuffer.class).newInstance(pr);
            }
            catch (NoSuchMethodException e) {
                NonParsableException n = new NonParsableException("You must provide a constructor " + className + "(StringBuffer).");
                n.setStackTrace(e.getStackTrace());
                throw n;
            }
            catch (Exception e) {
                NonParsableException n = new NonParsableException("problem at " + className + ": " + e.getMessage());
                n.setStackTrace(e.getStackTrace());
                throw n;
            }
        } else {
            // No <mixPrior> tag is written for DoesNothingLogPrior.
            this.mixPrior = DoesNothingLogPrior.defaultInstance;
        }
        this.vote = XMLParser.extractObjectForTags(xml, "vote", Vote.class);
        this.training = XMLParser.extractObjectForTags(xml, "training", Training.class);
        this.threads = XMLParser.extractObjectForTags(xml, "threads", Integer.TYPE);
        this.starts = XMLParser.extractObjectForTags(xml, "starts", Integer.TYPE);
        // Allocate the scratch buffers first so the object is fully set up before reset() runs.
        this.helpArray = new double[this.getNumberOfClasses()];
        this.logClassifierProbs = new double[this.component.length];
        this.logMixScores = new double[this.logClassifierProbs.length];
        if (this.training == Training.COMBINED) {
            this.optComponent = ArrayHandler.cast(OptimizableClassifier.class, this.component);
            try {
                this.reset();
            }
            catch (Exception e) {
                // Keep the original stack trace, as the other catch blocks above do.
                NonParsableException n = new NonParsableException(e.getMessage());
                n.setStackTrace(e.getStackTrace());
                throw n;
            }
        } else {
            this.optComponent = null;
        }
    }

    /**
     * Serialises this classifier's own state (model, components, optional prior, vote, training
     * mode, threads, starts) after the superclass information.
     *
     * @return the XML buffer
     */
    @Override
    protected StringBuffer getFurtherClassifierInfos() {
        final StringBuffer buffer = super.getFurtherClassifierInfos();
        XMLParser.appendObjectWithTags(buffer, model, "model");
        XMLParser.appendObjectWithTags(buffer, component, "componentClassifier");
        // A DoesNothingLogPrior is implied by the absence of the <mixPrior> tag.
        final boolean informativePrior = !(mixPrior instanceof DoesNothingLogPrior);
        if (informativePrior) {
            final StringBuffer priorXml = new StringBuffer(1000);
            priorXml.append("<mixPrior>\n");
            XMLParser.appendObjectWithTags(priorXml, mixPrior.getClass().getName(), "className");
            priorXml.append(mixPrior.toXML()).append("\t</mixPrior>\n");
            buffer.append(priorXml);
        }
        XMLParser.appendObjectWithTags(buffer, (Object) vote, "vote");
        XMLParser.appendObjectWithTags(buffer, (Object) training, "training");
        XMLParser.appendObjectWithTags(buffer, threads, "threads");
        XMLParser.appendObjectWithTags(buffer, starts, "starts");
        return buffer;
    }

    /**
     * Creates a deep copy of this classifier. Components and model are cloned, and the prior is
     * re-instantiated. Every array the copy writes to during evaluation (scratch buffers and
     * parameter arrays) is duplicated, so that original and copy can be used independently.
     * NOTE(review): the new prior instance is not {@code set} to the cloned model here; confirm
     * that a later {@code reset()} takes care of this.
     *
     * @return the copy
     * @throws CloneNotSupportedException if a component or the model cannot be cloned
     */
    @Override
    public NewMixtureClassifier clone() throws CloneNotSupportedException {
        NewMixtureClassifier clone = (NewMixtureClassifier)super.clone();
        clone.component = (AbstractScoreBasedClassifier[])ArrayHandler.clone((Cloneable[])this.component);
        if (this.optComponent != null) {
            clone.optComponent = ArrayHandler.cast(OptimizableClassifier.class, clone.component);
        }
        clone.model = (MyMixtureScoringFunction)this.model.clone();
        clone.mixPrior = this.mixPrior.getNewInstance();
        clone.logClassifierProbs = this.logClassifierProbs.clone();
        clone.logMixScores = this.logMixScores.clone();
        clone.helpArray = this.helpArray.clone();
        // Each array is copied on its own, so none of them stays shared when params happens to be null.
        if (this.params != null) {
            clone.params = this.params.clone();
        }
        if (this.mixParams != null) {
            clone.mixParams = this.mixParams.clone();
        }
        if (this.mixGrad != null) {
            clone.mixGrad = this.mixGrad.clone();
        }
        // Fresh scratch lists; sized by the arrays themselves rather than by optComponent.
        if (this.indices != null) {
            clone.indices = new IntList[this.indices.length];
            for (int i = 0; i < clone.indices.length; i++) {
                clone.indices[i] = new IntList();
            }
        }
        if (this.partialDer != null) {
            clone.partialDer = new DoubleList[this.partialDer.length];
            for (int i = 0; i < clone.partialDer.length; i++) {
                clone.partialDer[i] = new DoubleList();
            }
        }
        return clone;
    }

    /**
     * Replaces the strategy used by {@code getScore} to combine the component predictions.
     *
     * @param v the new vote
     */
    public void setVote(Vote v) {
        vote = v;
    }

    /**
     * Returns the log class posterior of class {@code i} for {@code seq}, using the currently
     * configured vote.
     */
    @Override
    protected double getScore(Sequence seq, int i, boolean check) throws IllegalArgumentException, NotTrainedException, Exception {
        final Vote currentVote = vote;
        return getLogProb(seq, i, check, currentVote);
    }

    /**
     * Computes the log class posterior of class {@code i} for {@code seq}.
     * <ul>
     * <li>{@code DOC}: only the component with maximal mixture score is used; its class scores are
     * log-normalised.</li>
     * <li>{@code VOC}: the log-normalised class posteriors of all components are combined, weighted
     * by the normalised mixture scores.</li>
     * </ul>
     *
     * @param seq the sequence
     * @param i the class index
     * @param check whether to check {@code seq} against this classifier first
     * @param vote the combination strategy
     * @return the log class posterior
     * @throws EvaluationException wrapping any failure, including an unsupported vote
     */
    private double getLogProb(Sequence seq, int i, boolean check, Vote vote) throws EvaluationException {
        try {
            if (check) {
                this.check(seq);
            }
            switch (vote) {
                case DOC: {
                    int k = this.model.getIndexOfMaximalComponentFor(seq, 0);
                    for (int j = 0; j < this.helpArray.length; j++) {
                        this.helpArray[j] = this.component[k].getScore(seq, j);
                    }
                    return this.helpArray[i] - Normalisation.getLogSum(this.helpArray);
                }
                case VOC: {
                    this.model.getLogScores(seq, this.logMixScores);
                    double logSum = Normalisation.getLogSum(this.logMixScores);
                    for (int k = 0; k < this.logMixScores.length; k++) {
                        for (int j = 0; j < this.helpArray.length; j++) {
                            this.helpArray[j] = this.component[k].getScore(seq, j);
                        }
                        // add log p_k(class i | seq) to the (unnormalised) log mixture score of component k
                        this.logMixScores[k] += this.helpArray[i] - Normalisation.getLogSum(this.helpArray);
                    }
                    return Normalisation.getLogSum(this.logMixScores) - logSum;
                }
                default:
                    throw new IllegalArgumentException("Unsupported vote: " + vote);
            }
        }
        catch (Exception e) {
            EvaluationException ee = new EvaluationException(e.getClass() + ": " + e.getMessage());
            ee.setStackTrace(e.getStackTrace());
            throw ee;
        }
    }

    /**
     * Returns a generic annotation: one entry naming the classifier, followed by one placeholder
     * entry per class.
     *
     * @return the annotation, of length {@code getNumberOfClasses() + 1}
     */
    @Override
    public CategoricalResult[] getClassifierAnnotation() {
        final int numberOfClasses = getNumberOfClasses();
        final CategoricalResult[] annotation = new CategoricalResult[numberOfClasses + 1];
        annotation[0] = new CategoricalResult("classifier", "a <b>short</b> description of the classifier", getInstanceName());
        for (int c = 0; c < numberOfClasses; c++) {
            annotation[c + 1] = new CategoricalResult("class info " + c, "some information about the class", "");
        }
        return annotation;
    }

    /**
     * Builds a name of the form
     * {@code SimpleName(model: <model>; classifier: <c0>, <c1>, ...)}.
     *
     * @return the instance name
     */
    @Override
    public String getInstanceName() {
        final StringBuilder name = new StringBuilder();
        name.append(getClass().getSimpleName())
            .append("(model: ")
            .append(model.getInstanceName())
            .append("; classifier: ")
            .append(component[0].getInstanceName());
        for (int idx = 1; idx < component.length; idx++) {
            name.append(", ").append(component[idx].getInstanceName());
        }
        return name.append(")").toString();
    }

    /**
     * This classifier reports no numerical characteristics.
     *
     * @return an empty result set
     */
    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return new NumericalResultSet(new LinkedList());
    }

    /**
     * Returns the XML tag used when (de)serialising this classifier.
     *
     * @return {@code "MixtureClassifier"}
     */
    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    /**
     * Reports whether all component classifiers and the mixture model are initialised.
     *
     * @return {@code true} iff every component and the model are initialised
     */
    @Override
    public boolean isInitialized() {
        for (AbstractScoreBasedClassifier c : component) {
            if (!c.isInitialized()) {
                return false;
            }
        }
        return model.isInitialized();
    }

    /**
     * NOTE(review): this body is a compiler-generated stub for an "unresolved compilation problem".
     * It ALWAYS throws an {@link Error}. According to the message, the original implementation
     * constructed a DifferentiableStatisticalModelWrapperTrainSM around the mixture model with a
     * constructor signature that does not exist. Until the body is restored, {@code train} and
     * {@code initialize} cannot succeed.
     *
     * @param dataSetArray presumably one data set per class; TODO confirm
     * @param dArray presumably the matching weights; TODO confirm
     * @param bl meaning not visible from this stub
     * @return never returns normally
     */
    private DataHandler init(DataSet[] dataSetArray, double[][] dArray, boolean bl) throws Exception {
        throw new Error("Unresolved compilation problem: \n\tThe constructor DifferentiableStatisticalModelWrapperTrainSM(NewMixtureClassifier.MyMixtureScoringFunction, int, byte, SmallDifferenceOfFunctionEvaluationsCondition, double, int, LogPrior) is undefined\n");
    }

    /**
     * Trains the classifier on one data set per class.
     * <p>
     * In {@code COMBINED} mode, {@code starts} restarts are performed. Each restart:
     * <ol>
     * <li>initialises everything randomly;</li>
     * <li>calls {@code init} on the second (0.05) part of a 0.95/0.05 partition of each data set;</li>
     * <li>splits the full data by the mixture's normalised component scores and trains every
     * component classifier on its split;</li>
     * <li>optimises all parameters jointly with an {@link MSPClassifierObjective}.</li>
     * </ol>
     * The restart with the largest objective value is kept. The objective's worker threads are
     * stopped after every restart, including failed ones.
     * <p>
     * In every other mode, {@code init} is called on the full data, and each component classifier
     * is trained once on its split.
     *
     * @param s one data set per class
     * @param weights optional weights, one array per data set; {@code null} (or {@code null}
     *        entries) mean weight 1. The caller's array is not modified.
     * @throws Exception in combined mode the exception of the last failed restart if no restart
     *         succeeded; otherwise any exception from initialisation or component training
     */
    @Override
    public void train(DataSet[] s, double[][] weights) throws Exception {
        if (this.training == Training.COMBINED) {
            // Shallow copy: filling in default weights must not modify the caller's array.
            weights = weights == null ? new double[s.length][] : weights.clone();
            for (int i = 0; i < s.length; i++) {
                if (weights[i] == null) {
                    weights[i] = new double[s[i].getNumberOfElements()];
                    Arrays.fill(weights[i], 1.0);
                }
            }
            // null means "no start improved on -infinity yet"
            NewMixtureClassifier bestClone = null;
            double best = Double.NEGATIVE_INFINITY;
            DataSet[] sampled = new DataSet[s.length];
            double[][] sampledWeights = new double[s.length][];
            SmallDifferenceOfFunctionEvaluationsCondition term = new SmallDifferenceOfFunctionEvaluationsCondition(this.eps);
            ConstantStartDistance sd = new ConstantStartDistance(1.0);
            boolean trained = false;
            Exception last = null;
            for (int start = 0; start < this.starts; start++) {
                System.out.println(String.valueOf(start) + " ========================");
                MSPClassifierObjective obj = null;
                try {
                    this.initializeRandomly();
                    for (int j = 0; j < sampled.length; j++) {
                        Pair<DataSet[], double[][]> p = s[j].partition(weights[j], DataSet.PartitionMethod.PARTITION_BY_NUMBER_OF_ELEMENTS, 0.95, 0.05);
                        sampled[j] = p.getFirstElement()[1];
                        sampledWeights[j] = p.getSecondElement()[1];
                    }
                    // Called for its effect on the mixture; the split is recomputed on the full data below.
                    this.init(sampled, sampledWeights, true);
                    DataHandler dh = NewMixtureClassifier.splitData(this.model, s, weights, false);
                    for (int j = 0; j < this.component.length; j++) {
                        this.component[j].train(dh.getSplits(j), dh.getWeightsOfSplits(j));
                    }
                    obj = new MSPClassifierObjective(this.threads, this, s, weights, true);
                    NegativeDifferentiableFunction neg = new NegativeDifferentiableFunction(obj);
                    obj.reset();
                    double[] optParams = obj.getParameters(OptimizableFunction.KindOfParameter.PLUGIN);
                    Optimizer.optimize(this.algo, neg, optParams, term, this.linEps, sd, System.out);
                    double current = obj.evaluateFunction(optParams);
                    System.out.println("start " + start + ":\t" + current);
                    if (current > best) {
                        bestClone = this.clone();
                        best = current;
                    }
                    trained = true;
                }
                catch (Exception e) {
                    last = e;
                    System.out.println("An exception was thrown. " + e.getMessage());
                }
                finally {
                    // Always release the objective's worker threads, also when this start failed.
                    if (obj != null) {
                        try {
                            obj.stopThreads();
                        }
                        catch (Exception e) {
                            last = e;
                            System.out.println("An exception was thrown. " + e.getMessage());
                        }
                    }
                }
            }
            if (!trained) {
                throw last;
            }
            // If no start beat -infinity (e.g. NaN objectives), keep the current state.
            // Swapping with "this" would null out component and model.
            if (bestClone != null) {
                this.component = bestClone.component;
                bestClone.component = null;
                this.model = bestClone.model;
                bestClone.model = null;
            }
            this.optComponent = ArrayHandler.cast(OptimizableClassifier.class, this.component);
            this.params = null;
            this.reset();
        } else {
            DataHandler dh = this.init(s, weights, true);
            for (int i = 0; i < this.component.length; i++) {
                DataSet[] sp = dh.getSplits(i);
                System.out.println(String.valueOf(i) + ")\t" + sp[0].getNumberOfElements() + " vs. " + sp[1].getNumberOfElements());
                this.component[i].train(sp, dh.getWeightsOfSplits(i));
            }
        }
    }

    /**
     * Returns the mixture model's string form followed by each component classifier's string form.
     *
     * @return a multi-line description
     */
    @Override
    public String toString() {
        final StringBuilder out = new StringBuilder().append(model);
        for (int idx = 0; idx < component.length; idx++) {
            out.append("\nclassifier ").append(idx).append(":\n").append(component[idx]);
        }
        return out.toString();
    }

    /**
     * Splits the data by mixture component.
     * <p>
     * If {@code doc} is {@code true}, each sequence goes, with its weight, to the component with the
     * maximal score (hard assignment). Otherwise every component sees all sequences, each weighted
     * by its original weight times the component's normalised mixture score (soft assignment). The
     * sequences are then stored only once and shared by all splits.
     *
     * @param mix the mixture model
     * @param data one data set per class
     * @param weights optional weights per data set; {@code null} or {@code null} entries mean 1
     * @param doc hard ({@code true}) or soft ({@code false}) assignment
     * @return the per-component splits
     * @throws EmptyDataSetException if a data set cannot be created
     * @throws WrongAlphabetException if a data set cannot be created
     */
    public static DataHandler splitData(MyMixtureScoringFunction mix, DataSet[] data, double[][] weights, boolean doc) throws EmptyDataSetException, WrongAlphabetException {
        final int numComponents = mix.getNumberOfComponents();
        LinkedList[][] seqList = new LinkedList[numComponents][data.length];
        DoubleList[][] weightsList = new DoubleList[numComponents][data.length];
        for (int comp = 0; comp < numComponents; comp++) {
            for (int cls = 0; cls < data.length; cls++) {
                seqList[comp][cls] = new LinkedList();
                weightsList[comp][cls] = new DoubleList();
            }
        }
        final double[] posterior = new double[mix.getNumberOfComponents()];
        for (int cls = 0; cls < data.length; cls++) {
            final boolean weighted = weights != null && weights[cls] != null;
            final int size = data[cls].getNumberOfElements();
            for (int n = 0; n < size; n++) {
                Sequence seq = data[cls].getElementAt(n);
                double w = weighted ? weights[cls][n] : 1.0;
                if (!doc) {
                    seqList[0][cls].add(seq);
                    mix.getLogScores(seq, posterior);
                    Normalisation.logSumNormalisation(posterior);
                    for (int comp = 0; comp < posterior.length; comp++) {
                        weightsList[comp][cls].add(w * posterior[comp]);
                    }
                } else {
                    int winner = mix.getIndexOfMaximalComponentFor(seq, 0);
                    seqList[winner][cls].add(seq);
                    weightsList[winner][cls].add(w);
                }
            }
        }
        if (!doc) {
            return new DataHandler(weightsList, new LinkedList[][]{seqList[0]});
        }
        return new DataHandler(weightsList, seqList);
    }

    /**
     * Adds the prior gradient of all parameters to {@code grad}, starting at {@code start}. The
     * mixture parameters come first, followed by each component's block.
     */
    @Override
    public void addGradient(double[] grad, int start) throws EvaluationException {
        Arrays.fill(mixGrad, 0.0);
        mixPrior.addGradientFor(mixParams, mixGrad);
        int offset = start;
        for (double g : mixGrad) {
            grad[offset] = grad[offset] + g;
            offset++;
        }
        for (OptimizableClassifier c : optComponent) {
            c.addGradient(grad, offset);
            offset += c.getNumberOfParameters();
        }
    }

    /**
     * Returns a new array holding the current mixture parameters followed by each component's
     * parameters of the given kind.
     */
    @Override
    public double[] getCurrentParameterValues(OptimizableFunction.KindOfParameter kind) throws Exception {
        final int expected = getNumberOfParameters();
        return getCurrentParameterValues(kind, expected, null);
    }

    /**
     * Writes the mixture parameters followed by each component's parameters into
     * {@code arrayToFill}, or into a new array of length {@code params.length} if it is
     * {@code null}. This also refreshes {@code mixParams} from the model.
     * The value of {@code numOfParams} is ignored.
     *
     * @return the filled array
     */
    private double[] getCurrentParameterValues(OptimizableFunction.KindOfParameter kind, int numOfParams, double[] arrayToFill) throws Exception {
        final double[] target = arrayToFill != null ? arrayToFill : new double[this.params.length];
        this.mixParams = this.model.getCurrentParameterValues();
        int offset = this.mixParams.length;
        System.arraycopy(this.mixParams, 0, target, 0, offset);
        for (OptimizableClassifier c : this.optComponent) {
            double[] part = c.getCurrentParameterValues(kind);
            System.arraycopy(part, 0, target, offset, part.length);
            offset += part.length;
        }
        return target;
    }

    /**
     * Returns the mixture prior's value plus the log prior terms of all components.
     */
    @Override
    public double getLogPriorTerm() throws DimensionException, EvaluationException {
        double total = mixPrior.evaluateFunction(mixParams);
        for (OptimizableClassifier c : optComponent) {
            total += c.getLogPriorTerm();
        }
        return total;
    }

    /**
     * Returns log P(classIndex | seq) under the mixture of optimisable components.
     */
    @Override
    public double getLogProb(int classIndex, Sequence seq) throws EvaluationException {
        final boolean verbose = false;
        return getLogProb(classIndex, seq, verbose);
    }

    /**
     * Computes the log-sum over components of (component log class posterior + normalised log
     * mixture score). The flag {@code out} is currently unused.
     */
    private double getLogProb(int classIndex, Sequence seq, boolean out) throws EvaluationException {
        model.getLogScores(seq, logMixScores);
        final double logNormaliser = Normalisation.getLogSum(logMixScores);
        for (int c = 0; c < optComponent.length; c++) {
            logClassifierProbs[c] = optComponent[c].getLogProb(classIndex, seq) + logMixScores[c] - logNormaliser;
        }
        return Normalisation.getLogSum(logClassifierProbs);
    }

    /**
     * Returns log P(classIndex | seq) under the mixture of optimisable components. The partial
     * derivatives are appended to {@code indices}/{@code partialDer}.
     * <p>
     * Parameter layout: mixture parameters first, then each component's block in order. Component
     * derivatives are weighted by the component's normalised responsibility. The derivatives of the
     * mixture parameters are appended by the model.
     *
     * @return the log class posterior
     * @throws IllegalStateException if a partial derivative is NaN; the message carries the
     *         diagnostic details
     */
    @Override
    public double getLogProbAndPartialDerivations(int classIndex, Sequence seq, IntList indices, DoubleList partialDer) {
        this.model.getLogScores(seq, this.logMixScores);
        double nlms = Normalisation.getLogSum(this.logMixScores);
        for (int i = 0; i < this.optComponent.length; i++) {
            this.indices[i].clear();
            this.partialDer[i].clear();
            this.logClassifierProbs[i] = this.optComponent[i].getLogProbAndPartialDerivations(classIndex, seq, this.indices[i], this.partialDer[i]) + this.logMixScores[i] - nlms;
        }
        // Returns the log-sum and turns logClassifierProbs into normalised responsibilities in place.
        double result = Normalisation.logSumNormalisation(this.logClassifierProbs);
        int offset = this.mixParams.length;
        for (int i = 0; i < this.indices.length; i++) {
            for (int j = 0; j < this.indices[i].length(); j++) {
                indices.add(offset + this.indices[i].get(j));
                partialDer.add(this.logClassifierProbs[i] * this.partialDer[i].get(j));
            }
            offset += this.optComponent[i].getNumberOfParameters();
        }
        this.model.getPartialDerivations(seq, this.logClassifierProbs, indices, partialDer);
        // Fail loudly but let the caller decide what to do; never terminate the JVM from here.
        for (int k = 0; k < partialDer.length(); k++) {
            if (Double.isNaN(partialDer.get(k))) {
                throw new IllegalStateException("partial derivation became: NaN (entry " + k
                        + ", class " + classIndex + ", sequence " + seq
                        + ", parameters " + Arrays.toString(this.params) + ")");
            }
        }
        return result;
    }

    /**
     * Returns the total number of parameters, or -1 if it is not (yet) known.
     */
    @Override
    public int getNumberOfParameters() {
        if (params == null) {
            return -1;
        }
        return params.length;
    }

    /**
     * Initialises the classifier from data by delegating to {@code init}; the returned split is
     * discarded.
     */
    @Override
    public void initialize(DataSet[] s, double[][] weights) throws Exception {
        init(s, weights, true);
    }

    /**
     * Randomly initialises the mixture model and every optimisable component.
     */
    @Override
    public void initializeRandomly() throws Exception {
        model.initializeFunctionRandomly(false);
        for (OptimizableClassifier c : optComponent) {
            c.initializeRandomly();
        }
    }

    /**
     * Re-binds the prior to the model and resets all components. It then recomputes the total
     * parameter count (-1 if any part reports -1) and refreshes the parameter arrays, and it
     * allocates the per-component scratch lists once.
     */
    @Override
    public void reset() throws Exception {
        this.mixPrior.set(true, this.model);
        for (OptimizableClassifier c : this.optComponent) {
            c.reset();
        }
        int total = this.model.getNumberOfParameters();
        if (total != -1) {
            for (OptimizableClassifier c : this.optComponent) {
                int n = c.getNumberOfParameters();
                if (n == -1) {
                    total = -1;
                    break;
                }
                total += n;
            }
        }
        if (total == -1) {
            this.params = null;
        } else {
            boolean reallocate = this.params == null || this.params.length != total;
            if (reallocate) {
                this.params = new double[total];
                this.mixParams = new double[this.model.getNumberOfParameters()];
                this.mixGrad = new double[this.mixParams.length];
            }
            this.getCurrentParameterValues(OptimizableFunction.KindOfParameter.PLUGIN, total, this.params);
        }
        if (this.indices == null) {
            final int count = this.optComponent.length;
            this.indices = new IntList[count];
            this.partialDer = new DoubleList[count];
            for (int c = 0; c < count; c++) {
                this.indices[c] = new IntList();
                this.partialDer[c] = new DoubleList();
            }
        }
    }

    /**
     * Copies this classifier's parameters from {@code params}, starting at {@code start}. The
     * mixture block is copied first, then each component's block in order.
     */
    @Override
    public void setParameters(double[] params, int start) throws Exception {
        System.arraycopy(params, start, this.mixParams, 0, this.mixParams.length);
        System.arraycopy(params, start, this.params, 0, this.params.length);
        this.model.setParameters(params, start);
        int offset = start + this.model.getNumberOfParameters();
        for (OptimizableClassifier c : this.optComponent) {
            c.setParameters(params, offset);
            offset += c.getNumberOfParameters();
        }
    }

    /**
     * Holds, for each split (mixture component), one data set and one weight array per class.
     */
    public static class DataHandler {
        private final DataSet[][] splits;
        private final double[][][] weights;

        /**
         * Builds the splits.
         *
         * @param weightsList {@code weightsList[s][c]} holds the weights of split {@code s}, class
         *        {@code c}. An empty list yields a {@code null} weight array for that entry. The
         *        argument is not modified.
         * @param seqList either a single entry whose sequences (per class) are shared by all
         *        splits, or one entry per split
         * @throws EmptyDataSetException if a data set cannot be created
         * @throws WrongAlphabetException if a data set cannot be created
         */
        public DataHandler(DoubleList[][] weightsList, LinkedList<Sequence>[] ... seqList) throws EmptyDataSetException, WrongAlphabetException {
            final int noOfSplits = weightsList.length;
            final int noOfClasses = weightsList[0].length;
            this.splits = new DataSet[noOfSplits][];
            this.weights = new double[noOfSplits][][];
            // With a single sequence list, all splits share the same data sets and differ only in weights.
            DataSet[] shared = null;
            if (seqList.length == 1) {
                shared = new DataSet[noOfClasses];
                for (int c = 0; c < noOfClasses; c++) {
                    shared[c] = new DataSet("", seqList[0][c].toArray(new Sequence[0]));
                }
            }
            for (int s = 0; s < noOfSplits; s++) {
                this.splits[s] = new DataSet[noOfClasses];
                this.weights[s] = new double[noOfClasses][];
                for (int c = 0; c < noOfClasses; c++) {
                    this.splits[s][c] = shared != null ? shared[c] : new DataSet("", seqList[s][c].toArray(new Sequence[0]));
                    // Empty weight lists leave the entry null.
                    if (weightsList[s][c].length() > 0) {
                        this.weights[s][c] = weightsList[s][c].toArray();
                    }
                }
            }
        }

        /** @return the number of splits */
        public int getNumberOfSplits() {
            return this.splits.length;
        }

        /**
         * @param i the split index
         * @return the per-class data sets of split {@code i}
         */
        public DataSet[] getSplits(int i) {
            return this.splits[i];
        }

        /**
         * @param i the split index
         * @return the per-class weights of split {@code i}; entries may be {@code null}
         */
        public double[][] getWeightsOfSplits(int i) {
            return this.weights[i];
        }
    }

    /**
     * Mixture model that exposes its per-component log scores and allows partial derivatives to be
     * computed for externally supplied component responsibilities.
     */
    public static class MyMixtureScoringFunction
    extends MixtureDiffSM {
        public MyMixtureScoringFunction(int starts, boolean plugIn, DifferentiableStatisticalModel[] component) throws CloneNotSupportedException {
            super(starts, plugIn, component);
        }

        public MyMixtureScoringFunction(StringBuffer xml) throws NonParsableException {
            super(xml);
        }

        /**
         * Copies the per-component scores computed by {@code fillComponentScores(seq, 0)} into
         * {@code logScores}. NOTE(review): presumably hidden log potential plus component log
         * score (cf. getPartialDerivations); confirm in MixtureDiffSM.
         *
         * @param seq the sequence
         * @param logScores receives one value per component
         */
        public void getLogScores(Sequence seq, double[] logScores) {
            this.fillComponentScores(seq, 0);
            System.arraycopy(this.componentScore, 0, logScores, 0, this.componentScore.length);
        }

        /**
         * Appends partial derivatives with respect to the component parameters and the hidden
         * parameters. Each component's derivatives are scaled by
         * {@code delta_i = gamma[i] - p(i | seq)}, where {@code p(i | seq)} is the model's own
         * normalised component score.
         *
         * @param seq the sequence
         * @param gamma externally supplied component responsibilities
         * @param indices receives the parameter indices
         * @param partialDer receives the matching derivatives
         */
        public void getPartialDerivations(Sequence seq, double[] gamma, IntList indices, DoubleList partialDer) {
            for (int i = 0; i < this.function.length; i++) {
                this.iList[i].clear();
                this.dList[i].clear();
                this.componentScore[i] = this.logHiddenPotential[i] + this.function[i].getLogScoreAndPartialDerivation(seq, 0, this.iList[i], this.dList[i]);
            }
            Normalisation.logSumNormalisation(this.componentScore);
            // NOTE(review): presumably the last paramRef segment holds the hidden (mixture-weight) parameters.
            int k = this.paramRef.length - 2;
            int l = this.paramRef[k + 1] - this.paramRef[k];
            for (int i = 0; i < this.function.length; i++) {
                double delta_i = gamma[i] - this.componentScore[i];
                for (int j = 0; j < this.iList[i].length(); j++) {
                    indices.add(this.paramRef[i] + this.iList[i].get(j));
                    partialDer.add(this.dList[i].get(j) * delta_i);
                }
                if (i < l) {
                    indices.add(this.paramRef[k] + i);
                    partialDer.add(delta_i);
                }
            }
        }

        @Override
        protected boolean determineIsNormalized() {
            return false;
        }

        /**
         * Initialises the mixture from one data set per component. Each component model is trained
         * on its own data set through a DifferentiableStatisticalModelWrapperTrainSM. The data set
         * sizes are then passed to {@code computeHiddenParameter}.
         *
         * @param s one data set per component
         * @param weights must be {@code null} or contain only {@code null} entries; weighted
         *        initialisation is not supported
         * @throws OperationNotSupportedException if weights are given; this is checked before any
         *         component is trained, so the model is left unchanged
         * @throws Exception if training a component fails
         */
        public void init(DataSet[] s, double[][] weights) throws Exception {
            double[] stat = new double[this.function.length];
            // Reject unsupported input before any component model is trained or replaced.
            if (weights != null) {
                for (int i = 0; i < stat.length; i++) {
                    if (weights[i] != null) {
                        throw new OperationNotSupportedException("yet");
                    }
                }
            }
            SmallDifferenceOfFunctionEvaluationsCondition termination = new SmallDifferenceOfFunctionEvaluationsCondition(1.0E-11);
            for (int i = 0; i < stat.length; i++) {
                // NOTE(review): arguments are presumably (model, threads, algorithm, termination, line search, start distance).
                // The explicit byte cast is required because invocation contexts do not narrow int constants.
                DifferentiableStatisticalModelWrapperTrainSM myModel = new DifferentiableStatisticalModelWrapperTrainSM(this.function[i], 2, (byte) 10, termination, 1.0E-9, 1.0);
                // weights are known to be absent here (validated above)
                myModel.train(s[i], null);
                this.function[i] = myModel.getFunction();
                stat[i] = s[i].getNumberOfElements();
            }
            this.computeHiddenParameter(stat, true);
        }
    }

    public static class OptimizableMSPClassifier
    extends GenDisMixClassifier
    implements OptimizableClassifier {
        // Parameter vector the prior is evaluated on; may be null.
        private double[] parameter;
        // Scratch buffer for the prior gradient (same length as the prior's gradient output).
        private double[] grad;
        // Scratch buffer with one log class score per class.
        private double[] helpArray;
        // Per-class scratch lists for parameter indices and partial derivatives.
        private IntList[] indi;
        private DoubleList[] partDer;

        /**
         * Creates a GenDisMixClassifier that uses the MSP learning principle and allocates the
         * per-class scratch buffers.
         * NOTE(review): the meaning of the {@code 0.0} argument passed to the superclass is not
         * visible here; confirm in GenDisMixClassifier.
         *
         * @param params the classifier parameter set
         * @param logPrior the prior on the class-model parameters
         * @param score one model per class
         * @throws CloneNotSupportedException if the models cannot be cloned
         */
        public OptimizableMSPClassifier(GenDisMixClassifierParameterSet params, LogPrior logPrior, DifferentiableStatisticalModel ... score) throws CloneNotSupportedException {
            super(params, logPrior, 0.0, LearningPrinciple.getBeta(LearningPrinciple.MSP), score);
            this.init();
        }

        /**
         * Restores the classifier from XML and allocates the per-class scratch buffers.
         *
         * @param xml the XML representation
         * @throws NonParsableException if the XML cannot be parsed
         */
        public OptimizableMSPClassifier(StringBuffer xml) throws NonParsableException {
            super(xml);
            this.init();
        }

        /**
         * Deep-copies the parameter and gradient arrays, when present, and gives the copy its own
         * scratch buffers.
         */
        @Override
        public OptimizableMSPClassifier clone() throws CloneNotSupportedException {
            final OptimizableMSPClassifier copy = (OptimizableMSPClassifier) super.clone();
            if (parameter != null) {
                copy.parameter = parameter.clone();
                copy.grad = grad.clone();
            }
            copy.init();
            return copy;
        }

        private void init() {
            this.helpArray = new double[this.getNumberOfClasses()];
            this.indi = new IntList[this.helpArray.length];
            this.partDer = new DoubleList[this.helpArray.length];
            int i = 0;
            while (i < this.helpArray.length) {
                this.indi[i] = new IntList();
                this.partDer[i] = new DoubleList();
                ++i;
            }
        }

        @Override
        public void addGradient(double[] grad, int start) throws EvaluationException {
            Arrays.fill(this.grad, 0.0);
            this.prior.addGradientFor(this.parameter, this.grad);
            int i = 0;
            while (i < this.grad.length) {
                int n = start + i;
                grad[n] = grad[n] + this.grad[i];
                ++i;
            }
        }

        @Override
        public double getLogPriorTerm() throws DimensionException, EvaluationException {
            return this.prior.evaluateFunction(this.parameter);
        }

        @Override
        public double getLogProb(int classIndex, Sequence seq) {
            int i = 0;
            while (i < this.score.length) {
                this.helpArray[i] = this.getClassWeight(i) + this.score[i].getLogScoreFor(seq);
                ++i;
            }
            return this.helpArray[classIndex] - Normalisation.getLogSum(this.helpArray);
        }

        @Override
        public double getLogProbAndPartialDerivations(int classIndex, Sequence seq, IntList indices, DoubleList partialDer) {
            int i = 0;
            while (i < this.score.length) {
                this.indi[i].clear();
                this.partDer[i].clear();
                this.helpArray[i] = this.getClassWeight(i) + this.score[i].getLogScoreAndPartialDerivation(seq, this.indi[i], this.partDer[i]);
                ++i;
            }
            double logProb = this.helpArray[classIndex] - Normalisation.logSumNormalisation(this.helpArray);
            int i2 = 0;
            int offset = this.score.length;
            while (i2 < this.score.length) {
                indices.add(i2);
                partialDer.add((double)(i2 == classIndex ? 1 : 0) - this.helpArray[i2]);
                int j = 0;
                while (j < this.indi[i2].length()) {
                    indices.add(offset + this.indi[i2].get(j));
                    partialDer.add(this.partDer[i2].get(j) * ((double)(i2 == classIndex ? 1 : 0) - this.helpArray[i2]));
                    ++j;
                }
                offset += this.score[i2].getNumberOfParameters();
                ++i2;
            }
            return logProb;
        }

        @Override
        public int getNumberOfParameters() {
            int i = 0;
            int num = 0;
            while (i < this.score.length) {
                int a;
                if ((a = this.score[i++].getNumberOfParameters()) == -1) {
                    return -1;
                }
                num += a;
            }
            return this.getNumberOfClasses() + num;
        }

        @Override
        public void initialize(DataSet[] data, double[][] weights) throws Exception {
            int i = 0;
            while (i < this.score.length) {
                this.score[i].initializeFunction(i, false, data, weights);
                ++i;
            }
            this.setClassWeights(false, new double[this.getNumberOfClasses()]);
            this.fillParameters();
        }

        @Override
        public void initializeRandomly() throws Exception {
            int i = 0;
            while (i < this.score.length) {
                this.score[i].initializeFunctionRandomly(false);
                ++i;
            }
            this.setClassWeights(false, new double[this.getNumberOfClasses()]);
            this.fillParameters();
        }

        @Override
        public void train(DataSet[] d, double[][] weight) throws Exception {
            super.train(d, weight);
            this.fillParameters();
        }

        @Override
        public void setParameters(double[] params, int start) throws Exception {
            this.setClassWeights(false, params, start);
            start += this.getNumberOfClasses();
            int i = 0;
            while (i < this.score.length) {
                this.score[i].setParameters(params, start);
                start += this.score[i].getNumberOfParameters();
                ++i;
            }
            this.fillParameters();
        }

        private void fillParameters() throws Exception {
            int start = this.getNumberOfParameters();
            if (this.parameter == null || this.parameter.length != start) {
                this.parameter = new double[start];
                this.grad = new double[start];
            }
            start = this.getNumberOfClasses();
            int i = 0;
            while (i < start) {
                this.parameter[i] = this.getClassWeight(i);
                ++i;
            }
            i = 0;
            while (i < this.score.length) {
                double[] help = this.score[i].getCurrentParameterValues();
                System.arraycopy(help, 0, this.parameter, start, help.length);
                start += help.length;
                ++i;
            }
        }

        @Override
        public double[] getCurrentParameterValues(OptimizableFunction.KindOfParameter kind) throws Exception {
            if (this.parameter == null) {
                this.fillParameters();
            }
            double[] res = (double[])this.parameter.clone();
            switch (kind) {
                case PLUGIN: {
                    break;
                }
                case LAST: {
                    break;
                }
                case ZEROS: {
                    Arrays.fill(res, 0, this.score.length, 0.0);
                    break;
                }
                default: {
                    throw new IllegalArgumentException("Unknown kind of parameter");
                }
            }
            return res;
        }

        @Override
        public void reset() throws Exception {
            this.prior.set(false, this.score);
        }

        @Override
        protected OptimizableFunction.KindOfParameter preoptimize(OptimizableFunction f) throws Exception {
            return OptimizableFunction.KindOfParameter.ZEROS;
        }
    }

    /**
     * Selects a training mode: one of the two {@code SEPARATELY_*} variants or {@code COMBINED}.
     * NOTE(review): the exact meaning of each constant is defined by the code that consumes this
     * enum, which is outside this view. The declaration order (and therefore the ordinals) is kept.
     */
    public static enum Training {
        SEPARATELY_DOC, SEPARATELY_VOC, COMBINED
    }

    /**
     * Selects a voting mode: {@code DOC} or {@code VOC}.
     * NOTE(review): the semantics are defined by the consumers of this enum, which are outside
     * this view. The declaration order (and therefore the ordinals) is kept.
     */
    public static enum Vote {
        DOC, VOC
    }
}

