/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.classifiers.differentiableSequenceScoreBased;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.ConstantStartDistance;
import de.jstacs.algorithms.optimization.MultiThreadedFunction;
import de.jstacs.algorithms.optimization.termination.AbstractTerminationCondition;
import de.jstacs.classifiers.AbstractScoreBasedClassifier;
import de.jstacs.classifiers.ClassDimensionException;
import de.jstacs.classifiers.differentiableSequenceScoreBased.DiffSSBasedOptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.OptimizableFunction;
import de.jstacs.classifiers.differentiableSequenceScoreBased.ScoreClassifierParameterSet;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.MutableMotifDiscovererToolbox;
import de.jstacs.motifDiscovery.history.History;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.SafeOutputStream;
import java.io.OutputStream;
import java.util.Arrays;

/**
 * Base class for classifiers that model every class by its own
 * {@link DifferentiableSequenceScore} and learn the parameters of all scores together
 * with one additive weight per class by numerical optimization.
 *
 * <p>Subclasses provide the objective function via
 * {@link #getFunction(DataSet[], double[][])} and the XML tag via {@link #getXMLTag()}.
 * The score of a sequence for class {@code i} is
 * {@code getClassWeight(i) + score[i].getLogScoreFor(seq, 0)}.
 */
public abstract class ScoreClassifier extends AbstractScoreBasedClassifier {

    /** One score per class; entry {@code i} is used for class {@code i}. */
    protected DifferentiableSequenceScore[] score;

    /** Private copy of the parameters that control the optimization. */
    protected ScoreClassifierParameterSet params;

    /**
     * {@code true} after {@link #doOptimization(DataSet[], double[][])} has finished
     * successfully; reset to {@code false} at the start of every
     * {@link #train(DataSet[], double[][])}.
     */
    protected boolean hasBeenOptimized;

    /**
     * Objective value returned by the last optimization (or the value supplied to the
     * constructor); {@link #NOT_TRAINED_VALUE} if none is available.
     */
    private double lastScore;

    /** Receives the progress messages written during the optimization. */
    protected SafeOutputStream sostream;

    /** Value reported by {@link #getLastScore()} when no objective value is available. */
    public static final double NOT_TRAINED_VALUE = Double.NaN;

    /**
     * Optional template handed to {@code MutableMotifDiscovererToolbox.createHistoryArray}
     * when the optimization histories are created; {@code null} by default. It is shared
     * (not cloned) by {@link #clone()} and is not written to XML.
     */
    protected History template = null;

    /**
     * Creates a new classifier from the given parameters and per-class scores.
     *
     * @param params the parameters of the classifier; a clone is stored
     * @param lastScore the value reported by {@link #getLastScore()}; it is only kept if
     *        all scores are already initialized, otherwise {@link #NOT_TRAINED_VALUE} is used
     * @param score one score per class; the array and its elements are cloned
     * @throws CloneNotSupportedException if a score or the parameter set cannot be cloned
     * @throws IllegalArgumentException if a score has a length other than 0 or the length
     *         of the classifier, or an alphabet container inconsistent with the classifier's
     */
    public ScoreClassifier(ScoreClassifierParameterSet params, double lastScore, DifferentiableSequenceScore ... score) throws CloneNotSupportedException {
        super(params.getAlphabetContainer(), params.getLength(), score.length);
        int len = this.getLength();
        AlphabetContainer con = this.getAlphabetContainer();
        for (int i = 0; i < score.length; ++i) {
            int l = score[i].getLength();
            // a score length of 0 is accepted for any classifier length
            boolean lengthOk = l == 0 || l == len;
            if (!lengthOk || !con.checkConsistency(score[i].getAlphabetContainer())) {
                throw new IllegalArgumentException("Please check the length (" + l + " vs. " + len + ")" + " and the AlphabetContainer of the DifferentiableSequenceScore with index " + i + ".");
            }
        }
        this.score = (DifferentiableSequenceScore[])ArrayHandler.clone((Cloneable[])score);
        this.hasBeenOptimized = false;
        this.lastScore = this.isInitialized() ? lastScore : NOT_TRAINED_VALUE;
        this.set((ScoreClassifierParameterSet)params.clone());
    }

    /**
     * Restores a classifier from its XML representation.
     * NOTE(review): the fields are presumably filled by the superclass constructor calling
     * {@link #extractFurtherClassifierInfosFromXML(StringBuffer)} — confirm in
     * {@code AbstractScoreBasedClassifier}.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public ScoreClassifier(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a copy whose parameter set and scores are cloned. If this instance's stream
     * does nothing, the copy is given {@code null} as stream; otherwise it writes to
     * {@link SafeOutputStream#DEFAULT_STREAM}.
     */
    @Override
    public ScoreClassifier clone() throws CloneNotSupportedException {
        ScoreClassifier clone = (ScoreClassifier)super.clone();
        clone.params = (ScoreClassifierParameterSet)this.params.clone();
        clone.score = (DifferentiableSequenceScore[])ArrayHandler.clone((Cloneable[])this.score);
        clone.setOutputStream(this.sostream.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM);
        return clone;
    }

    /** Returns the simple name of the concrete class. */
    @Override
    public String getInstanceName() {
        return this.getClass().getSimpleName();
    }

    /**
     * Returns the classifier name followed by the instance name of each class score
     * (entries are labelled {@code "class info 0"}, {@code "class info 1"}, ...).
     */
    @Override
    public CategoricalResult[] getClassifierAnnotation() {
        CategoricalResult[] res = new CategoricalResult[this.score.length + 1];
        res[0] = new CategoricalResult("classifier", "a <b>short</b> description of the classifier", this.getInstanceName());
        for (int i = 0; i < this.score.length; i++) {
            res[i + 1] = new CategoricalResult("class info " + i, "some information about the class", this.score[i].getInstanceName());
        }
        return res;
    }

    /**
     * Returns the last objective value (only if the classifier has been optimized),
     * followed by the number of parameters of every class score.
     */
    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        int offset = this.hasBeenOptimized ? 1 : 0;
        NumericalResult[] pars = new NumericalResult[this.score.length + offset];
        if (this.hasBeenOptimized) {
            pars[0] = new NumericalResult("Last score", "The final score after the optimization", this.lastScore);
        }
        for (int i = 0; i < this.score.length; i++) {
            pars[i + offset] = new NumericalResult("Number of parameters " + (i + 1), "The number of parameters for scoring function " + (i + 1) + ", -1 indicates unknown number of parameters.", this.score[i].getNumberOfParameters());
        }
        return new NumericalResultSet(new NumericalResult[][]{pars});
    }

    /** Returns {@code true} iff every class score reports that it is initialized. */
    @Override
    public boolean isInitialized() {
        for (DifferentiableSequenceScore s : this.score) {
            if (!s.isInitialized()) {
                return false;
            }
        }
        return true;
    }

    /** Returns whether the last call of {@link #train(DataSet[], double[][])} finished its optimization. */
    public boolean hasBeenOptimized() {
        return this.hasBeenOptimized;
    }

    /**
     * Sets the stream that receives the optimization progress messages.
     *
     * @param o the target stream, wrapped via {@link SafeOutputStream#getSafeOutputStream(OutputStream)};
     *        {@code null} is handed to that method as well
     */
    public void setOutputStream(OutputStream o) {
        this.sostream = SafeOutputStream.getSafeOutputStream(o);
    }

    /**
     * Trains the classifier by optimizing the objective returned by
     * {@link #getFunction(DataSet[], double[][])}.
     *
     * <p>If {@code data} holds more than one data set, entry {@code i} belongs to class
     * {@code i}. If it holds exactly one data set, that set is used for every class and
     * {@code weights[i]} supplies the class-specific weights. Data sets whose element length
     * differs from the classifier length are handed to the weighted data set factory together
     * with the classifier length.
     *
     * @param data the training data
     * @param weights per-class weights (rows may be {@code null}); {@code null} is replaced
     *        by an array of {@code data.length} {@code null} rows
     * @throws IllegalArgumentException if data and weights do not match, or a data set is
     *         defined over an inconsistent alphabet
     * @throws ClassDimensionException if the number of weight rows differs from the number of classes
     */
    @Override
    public void train(DataSet[] data, double[][] weights) throws Exception {
        this.hasBeenOptimized = false;
        if (weights == null) {
            weights = new double[data.length][];
        }
        if (data.length > 1 && data.length != weights.length) {
            throw new IllegalArgumentException("data and weights do not match");
        }
        if (this.score.length != weights.length) {
            throw new ClassDimensionException();
        }
        DataSet[] reduced = new DataSet[data.length];
        double[][] newWeights = new double[weights.length][];
        AlphabetContainer abc = this.getAlphabetContainer();
        int l = this.getLength();
        boolean oneDataSetPerClass = data.length > 1;
        // j indexes the data set; it only advances when every class has its own data set
        for (int i = 0, j = 0; i < this.score.length; i++) {
            DataSet current = data[j];
            if (weights[i] != null && current.getNumberOfElements() != weights[i].length) {
                throw new IllegalArgumentException("At least for one data set: The dimension of the data set and the weight do not match.");
            }
            // a shared data set is only checked and stored once (for class 0)
            boolean firstUse = i == 0 || oneDataSetPerClass;
            if (firstUse && !abc.checkConsistency(current.getAlphabetContainer())) {
                throw new IllegalArgumentException("At least one data set is not defined over the correct alphabets.");
            }
            DataSet.WeightedDataSetFactory wsf = current.getElementLength() != l
                    ? new DataSet.WeightedDataSetFactory(DataSet.WeightedDataSetFactory.SortOperation.NO_SORT, current, weights[i], l)
                    : new DataSet.WeightedDataSetFactory(DataSet.WeightedDataSetFactory.SortOperation.NO_SORT, current, weights[i]);
            if (firstUse) {
                reduced[i] = wsf.getDataSet();
            }
            newWeights[i] = wsf.getWeights();
            if (oneDataSetPerClass) {
                j++;
            }
        }
        this.lastScore = this.doOptimization(reduced, newWeights);
    }

    /**
     * Returns the number of optimization starts; delegates to
     * {@link AbstractDifferentiableStatisticalModel#getNumberOfStarts} for the class scores.
     */
    protected int getIterations() {
        return AbstractDifferentiableStatisticalModel.getNumberOfStarts(this.score);
    }

    /**
     * Runs {@link #getIterations()} optimization starts and keeps the scores and class
     * weights of the start with the highest objective value.
     *
     * <p>Each start re-initializes the scores via {@link #createStructure(DataSet[], double[][])}.
     * With more than one start, the scores are reset to a copy of their state before the
     * first start after every start. Worker threads of a {@link MultiThreadedFunction} are
     * stopped even if the optimization fails.
     *
     * @param reduced the (possibly length-reduced) training data
     * @param newWeights the weights belonging to {@code reduced}, one row per class
     * @return the best objective value
     * @throws IllegalStateException if no start produced an objective value greater than
     *         {@code -Infinity}; the current scores are left in place in that case
     * @throws Exception if creating the function or optimizing fails
     */
    protected double doOptimization(DataSet[] reduced, double[][] newWeights) throws Exception {
        byte algo = (Byte)this.params.getParameterForName("algorithm").getValue();
        AbstractTerminationCondition tc = this.params.getTerminantionCondition();
        double linEps = (Double)this.params.getParameterForName("line epsilon").getValue();
        double startDist = (Double)this.params.getParameterForName("start distance").getValue();
        int iterations = this.getIterations();
        this.sostream.writeln(this.getInstanceName());
        DifferentiableSequenceScore[] bestSF = new DifferentiableSequenceScore[this.score.length];
        // untouched copy of the scores; every further start begins again from this state
        DifferentiableSequenceScore[] secure = iterations > 1 ? (DifferentiableSequenceScore[])ArrayHandler.clone((Cloneable[])this.score) : null;
        History[][] hist = MutableMotifDiscovererToolbox.createHistoryArray(this.score, this.template);
        int[][] minimalNewLength = MutableMotifDiscovererToolbox.createMinimalNewLengthArray(this.score);
        ConstantStartDistance sd = new ConstantStartDistance(startDist);
        DiffSSBasedOptimizableFunction f = this.getFunction(reduced, newWeights);
        try {
            double[] best = null;
            double max = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < iterations; i++) {
                this.createStructure(reduced, newWeights);
                f.reset(this.score);
                if (i == 0) {
                    this.sostream.writeln("optimizing " + f.getDimensionOfScope() + " parameters");
                }
                this.sostream.writeln("start " + (i + 1) + ":");
                OptimizableFunction.KindOfParameter plugIn = this.preoptimize(f);
                MutableMotifDiscovererToolbox.clearHistoryArray(hist);
                sd.reset();
                // res[0][0] is read as the objective value, res[1] as the class weights
                double[][] res = MutableMotifDiscovererToolbox.optimize(this.score, f, algo, tc, linEps, sd, this.sostream, false, hist, minimalNewLength, plugIn, true);
                double current = res[0][0];
                if (current > max) {
                    // keep references to the optimized scores; they are replaced below, not mutated
                    System.arraycopy(this.score, 0, bestSF, 0, this.score.length);
                    best = res[1];
                    max = current;
                    System.gc();
                }
                if (iterations > 1) {
                    this.score = (DifferentiableSequenceScore[])ArrayHandler.clone((Cloneable[])secure);
                }
            }
            // without an improving start bestSF only holds nulls; never install it
            if (!(max > Double.NEGATIVE_INFINITY)) {
                throw new IllegalStateException("No optimization start (of " + iterations + ") yielded an objective value greater than -Infinity.");
            }
            this.sostream.writeln("best = " + max);
            this.score = bestSF;
            this.setClassWeights(false, best);
            this.hasBeenOptimized = true;
            return max;
        } finally {
            // release worker threads on success and on failure alike
            if (f instanceof MultiThreadedFunction) {
                f.stopThreads();
            }
        }
    }

    /**
     * Hook executed before each optimization start. The default returns the
     * {@link OptimizableFunction.KindOfParameter} stored in the parameter set under that
     * enum's simple class name.
     *
     * @param f the function that is about to be optimized (unused by the default)
     * @return the kind of parameters used to start the optimization
     * @throws Exception if a subclass's pre-optimization fails
     */
    protected OptimizableFunction.KindOfParameter preoptimize(OptimizableFunction f) throws Exception {
        return (OptimizableFunction.KindOfParameter)((Object)this.params.getParameterForName(OptimizableFunction.KindOfParameter.class.getSimpleName()).getValue());
    }

    /**
     * Initializes every class score, either randomly or from the data.
     *
     * <p>If the scores are initialized from data and a single data set is given together
     * with more than one weight row, the data set is repeated so that the data array has one
     * entry per weight row.
     *
     * @param data the data (ignored if {@code initRandomly})
     * @param weights the weights (ignored if {@code initRandomly})
     * @param initRandomly whether to call {@code initializeFunctionRandomly} instead of
     *        {@code initializeFunction}
     * @throws Exception if a score cannot be initialized
     */
    protected void createStructure(DataSet[] data, double[][] weights, boolean initRandomly) throws Exception {
        DataSet[] d;
        boolean freeParams = this.params.useOnlyFreeParameter();
        if (!initRandomly && data.length == 1 && weights != null && weights.length > 1) {
            d = new DataSet[weights.length];
            Arrays.fill(d, data[0]);
        } else {
            d = data;
        }
        for (int i = 0; i < this.score.length; i++) {
            if (initRandomly) {
                this.score[i].initializeFunctionRandomly(freeParams);
            } else {
                this.score[i].initializeFunction(i, freeParams, d, weights);
            }
        }
    }

    /**
     * Initializes the scores randomly and then overwrites class weights and score parameters
     * with the given values.
     *
     * <p>Layout of {@code parameters}: first the class weights (one per class, or one fewer
     * if only free parameters are used, in which case the last class weight is 0), then the
     * parameters of each class score in class order.
     * NOTE(review): a score reporting -1 (unknown) from {@code getNumberOfParameters()}
     * would break the offset computation — confirm this cannot happen here.
     *
     * @param parameters the parameter vector
     * @throws Exception if initializing or setting the parameters fails
     */
    public void initUsingParameters(double[] parameters) throws Exception {
        this.createStructure(null, null, true);
        int numberOfClassWeights = this.score.length - (this.params.useOnlyFreeParameter() ? 1 : 0);
        double[] cw = new double[this.score.length];
        for (int i = 0; i < numberOfClassWeights; i++) {
            cw[i] = parameters[i];
        }
        this.setClassWeights(false, cw, 0);
        int off = numberOfClassWeights;
        for (DifferentiableSequenceScore s : this.score) {
            s.setParameters(parameters, off);
            off += s.getNumberOfParameters();
        }
    }

    /**
     * Initializes every class score from the data; shorthand for
     * {@code createStructure(data, weights, false)}.
     */
    protected void createStructure(DataSet[] data, double[][] weights) throws Exception {
        this.createStructure(data, weights, false);
    }

    /**
     * Restores parameters, optimization flag, last objective value and class scores from XML.
     * Calling {@code set} also resets the output stream to the default stream.
     * The {@link #template} is not restored.
     */
    @Override
    protected void extractFurtherClassifierInfosFromXML(StringBuffer xml) throws NonParsableException {
        super.extractFurtherClassifierInfosFromXML(xml);
        this.set((ScoreClassifierParameterSet)XMLParser.extractObjectForTags(xml, "params"));
        this.hasBeenOptimized = XMLParser.extractObjectForTags(xml, "hasBeenOptimized", Boolean.TYPE);
        this.lastScore = XMLParser.extractObjectForTags(xml, "lastScore", Double.TYPE);
        this.score = (DifferentiableSequenceScore[])XMLParser.extractObjectForTags(xml, "score");
    }

    /**
     * Creates the objective function that is maximized during training.
     *
     * @param var1 the (possibly length-reduced) training data
     * @param var2 the weights belonging to the data, one row per class
     * @return the objective function
     * @throws Exception if the function cannot be created
     */
    protected abstract DiffSSBasedOptimizableFunction getFunction(DataSet[] var1, double[][] var2) throws Exception;

    /** Appends parameters, optimization flag, last objective value and class scores to the XML. */
    @Override
    protected StringBuffer getFurtherClassifierInfos() {
        StringBuffer xml = super.getFurtherClassifierInfos();
        XMLParser.appendObjectWithTags(xml, this.params, "params");
        XMLParser.appendObjectWithTags(xml, this.hasBeenOptimized, "hasBeenOptimized");
        XMLParser.appendObjectWithTags(xml, this.lastScore, "lastScore");
        XMLParser.appendObjectWithTags(xml, this.score, "score");
        return xml;
    }

    /**
     * Returns the class weight of class {@code i} plus the log-score of the class-{@code i}
     * score for {@code seq} starting at position 0.
     *
     * @param seq the sequence
     * @param i the class index
     * @param check whether to validate {@code seq} via {@code check(seq)} first
     */
    @Override
    protected double getScore(Sequence seq, int i, boolean check) throws IllegalArgumentException, NotTrainedException, Exception {
        if (check) {
            this.check(seq);
        }
        return this.getClassWeight(i) + this.score[i].getLogScoreFor(seq, 0);
    }

    /** Returns the objective value of the last optimization, or {@link #NOT_TRAINED_VALUE}. */
    public double getLastScore() {
        return this.lastScore;
    }

    /** Returns a clone of the score of class {@code i}. */
    public DifferentiableSequenceScore getDifferentiableSequenceScore(int i) throws CloneNotSupportedException {
        return this.score[i].clone();
    }

    /** Returns clones of all class scores. */
    public DifferentiableSequenceScore[] getDifferentiableSequenceScores() throws CloneNotSupportedException {
        return (DifferentiableSequenceScore[])ArrayHandler.clone((Cloneable[])this.score);
    }

    @Override
    protected abstract String getXMLTag();

    /** Stores {@code params} (without copying) and resets the output stream to the default stream. */
    private void set(ScoreClassifierParameterSet params) {
        this.params = params;
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
    }

    /** Returns a clone of the current parameter set. */
    public ScoreClassifierParameterSet getCurrentParameterSet() throws CloneNotSupportedException {
        return (ScoreClassifierParameterSet)this.params.clone();
    }
}

