/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.models;

import de.jstacs.NonParsableException;
import de.jstacs.NotTrainedException;
import de.jstacs.WrongAlphabetException;
import de.jstacs.algorithms.optimization.LimitedMedianStartDistance;
import de.jstacs.algorithms.optimization.NegativeDifferentiableFunction;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.algorithms.optimization.termination.AbstractTerminationCondition;
import de.jstacs.algorithms.optimization.termination.SmallDifferenceOfFunctionEvaluationsCondition;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction;
import de.jstacs.classifier.scoringFunctionBased.gendismix.LearningPrinciple;
import de.jstacs.classifier.scoringFunctionBased.gendismix.LogGenDisMixFunction;
import de.jstacs.classifier.scoringFunctionBased.logPrior.CompositeLogPrior;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.WrongLengthException;
import de.jstacs.io.XMLParser;
import de.jstacs.models.AbstractModel;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.scoringFunctions.IndependentProductScoringFunction;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.ScoringFunction;
import de.jstacs.scoringFunctions.UniformScoringFunction;
import de.jstacs.scoringFunctions.homogeneous.UniformHomogeneousScoringFunction;
import de.jstacs.utils.SafeOutputStream;
import java.io.OutputStream;

/**
 * Model that wraps a {@link NormalizableScoringFunction} so that it can be used through the
 * {@link AbstractModel} interface.
 *
 * <p>Log probabilities are computed as the log score of the wrapped function minus its log
 * normalization constant. Unless the wrapped function is a uniform function, training is done
 * by numerical optimization of a {@link LogGenDisMixFunction} via {@link Optimizer}.
 */
public class NormalizableScoringFunctionModel extends AbstractModel {
    /** Sink for the progress messages written during training. */
    private SafeOutputStream out;
    /** The wrapped scoring function; replaced by the trained instance in {@link #train(Sample, double[])}. */
    protected NormalizableScoringFunction nsf;
    /** Cached log normalization constant of {@link #nsf}; negative infinity while untrained. */
    private double logNorm;
    /** Line search epsilon handed to the optimizer (scaled by a data dependent factor). */
    private double lineps;
    /** Start distance handed to the optimizer (scaled by a data dependent factor). */
    private double startD;
    /** Termination condition handed to the optimizer. */
    private AbstractTerminationCondition tc;
    /** Algorithm selector handed to {@link Optimizer#optimize}. */
    private byte algo;
    /** Number of threads handed to the {@link LogGenDisMixFunction}. */
    private int threads;
    private static final String XML_TAG = "NormalizableScoringFunctionModel";
    /** Tag under which the termination condition is stored. */
    private static final String TERMINATION_CONDITION_TAG = "terminationCondition";
    /** Tag that an earlier, inconsistent {@link #toXML()} used for the termination condition. */
    private static final String LEGACY_TERMINATION_CONDITION_TAG = "tc";

    /**
     * Creates a model around a clone of the given scoring function.
     *
     * @param nsf the scoring function to wrap; it is cloned
     * @param threads the number of threads used during training, has to be positive
     * @param algo the algorithm selector passed to the optimizer
     * @param tc the termination condition for the optimization; it is cloned
     * @param lineps the line search epsilon, has to be non-negative
     * @param startD the start distance for the line search, has to be positive
     * @throws CloneNotSupportedException if {@code nsf} or {@code tc} cannot be cloned
     * @throws IllegalArgumentException if {@code threads}, {@code lineps} or {@code startD} is out of range
     */
    public NormalizableScoringFunctionModel(NormalizableScoringFunction nsf, int threads, byte algo, AbstractTerminationCondition tc, double lineps, double startD) throws CloneNotSupportedException {
        super(nsf.getAlphabetContainer(), nsf.getLength());
        if (threads < 1) {
            throw new IllegalArgumentException("The number of threads has to be positive.");
        }
        this.threads = threads;
        this.tc = tc.clone();
        if (lineps < 0.0) {
            throw new IllegalArgumentException("The value of lineps has to be non-negative.");
        }
        this.lineps = lineps;
        if (startD <= 0.0) {
            throw new IllegalArgumentException("The value of startD has to be positive.");
        }
        this.startD = startD;
        this.algo = algo;
        this.nsf = (NormalizableScoringFunction)nsf.clone();
        this.logNorm = this.isTrained() ? nsf.getLogNormalizationConstant() : Double.NEGATIVE_INFINITY;
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
    }

    /**
     * Creates a model from its XML representation.
     *
     * @param stringBuff the XML representation
     * @throws NonParsableException if the representation could not be parsed
     */
    public NormalizableScoringFunctionModel(StringBuffer stringBuff) throws NonParsableException {
        super(stringBuff);
    }

    /**
     * Returns a copy with its own clones of the scoring function and the termination condition.
     *
     * <p>The copy writes to {@link SafeOutputStream#DEFAULT_STREAM} unless this instance's
     * output stream does nothing, in which case the copy's stream is created from {@code null}.
     */
    @Override
    public NormalizableScoringFunctionModel clone() throws CloneNotSupportedException {
        NormalizableScoringFunctionModel clone = (NormalizableScoringFunctionModel)super.clone();
        clone.nsf = (NormalizableScoringFunction)this.nsf.clone();
        clone.tc = this.tc.clone();
        clone.setOutputStream(this.out.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM);
        return clone;
    }

    /**
     * Trains the wrapped scoring function on the given weighted sample.
     *
     * <p>An {@link IndependentProductScoringFunction} is trained component by component on the
     * matching parts of the data and then rebuilt from the trained components; any other function
     * is trained as a whole. Afterwards the cached log normalization constant is refreshed from
     * the resulting function.
     *
     * @param data the training sample
     * @param weights the weights of the sequences in {@code data}
     * @throws WrongAlphabetException if the alphabets of sample and model do not match
     * @throws WrongLengthException if the element length of the sample does not fit the model
     * @throws Exception if the training itself fails
     */
    @Override
    public void train(Sample data, double[] weights) throws Exception {
        if (!data.getAlphabetContainer().checkConsistency(this.alphabets)) {
            throw new WrongAlphabetException("The AlphabetConatainer of the sample and the model do not match.");
        }
        if (this.length != 0 && this.length != data.getElementLength()) {
            throw new WrongLengthException("The length of the elements of the sample is not suitable for the model.");
        }
        if (this.nsf instanceof IndependentProductScoringFunction) {
            IndependentProductScoringFunction ipsf = (IndependentProductScoringFunction)this.nsf;
            NormalizableScoringFunction[] nsfs = ipsf.getFunctions();
            Sample[] part = new Sample[1];
            Sample[] packedData = new Sample[]{data};
            double[][] packedWeights = new double[][]{weights};
            for (int i = 0; i < nsfs.length; ++i) {
                int a = ipsf.extractSequenceParts(i, packedData, part);
                double[][] partWeights = ipsf.extractWeights(a, packedWeights);
                nsfs[i] = this.train(part[0], partWeights[0], nsfs[i]);
            }
            this.nsf = new IndependentProductScoringFunction(ipsf.getEss(), true, nsfs, ipsf.getIndices(), ipsf.getPartialLengths(), ipsf.getReverseSwitches());
        } else {
            this.nsf = this.train(data, weights, this.nsf);
        }
        // Take the constant from the final function. Setting it per component (as done before)
        // left the constant of the last trained component in place for product functions.
        this.logNorm = this.isTrained() ? this.nsf.getLogNormalizationConstant() : Double.NEGATIVE_INFINITY;
    }

    /**
     * Trains a single scoring function and returns the best instance over all recommended starts.
     *
     * <p>Uniform functions are returned unchanged. This helper does not modify the state of the
     * model apart from writing progress messages to the output stream.
     *
     * @param data the training sample
     * @param weights the weights of the sequences in {@code data}
     * @param nsf the function to train; it is not modified, clones are optimized instead
     * @return the trained function, or {@code nsf} itself for uniform functions
     * @throws IllegalStateException if no start produced a value above negative infinity
     * @throws Exception if initialization or optimization fails
     */
    private NormalizableScoringFunction train(Sample data, double[] weights, NormalizableScoringFunction nsf) throws Exception {
        if (nsf instanceof UniformScoringFunction || nsf instanceof UniformHomogeneousScoringFunction) {
            return nsf;
        }
        Sample.WeightedSampleFactory wsf = new Sample.WeightedSampleFactory(Sample.WeightedSampleFactory.SortOperation.NO_SORT, data, weights);
        Sample small = wsf.getSample();
        double[] smallWeights = wsf.getWeights();
        NormalizableScoringFunction best = null;
        double max = Double.NEGATIVE_INFINITY;
        double ess = nsf.getEss();
        // Scaling factor applied to the start distance and the line search epsilon below.
        double fac = data.getNumberOfElements();
        fac = fac / (ess + fac) * (ess == 0.0 ? 1.0 : 2.0);
        // The array is created as NormalizableScoringFunction[], so every element is of that type.
        ScoringFunction[] score = new NormalizableScoringFunction[]{(NormalizableScoringFunction)nsf.clone()};
        CompositeLogPrior prior = new CompositeLogPrior();
        // Without an equivalent sample size use ML, otherwise MAP.
        double[] beta = LearningPrinciple.getBeta(ess == 0.0 ? LearningPrinciple.ML : LearningPrinciple.MAP);
        LogGenDisMixFunction f = new LogGenDisMixFunction(this.threads, score, new Sample[]{small}, new double[][]{smallWeights}, prior, beta, true, false);
        try {
            NegativeDifferentiableFunction minusF = new NegativeDifferentiableFunction(f);
            LimitedMedianStartDistance sd = new LimitedMedianStartDistance(5, this.startD * fac);
            for (int i = 0; i < nsf.getNumberOfRecommendedStarts(); ++i) {
                this.out.writeln("start: " + i);
                score[0].initializeFunction(0, false, new Sample[]{small}, new double[][]{smallWeights});
                f.reset(score);
                double[] params = f.getParameters(OptimizableFunction.KindOfParameter.PLUGIN);
                sd.reset();
                Optimizer.optimize(this.algo, minusF, params, this.tc, this.lineps * fac, sd, this.out);
                double current = f.evaluateFunction(params);
                if (current > max) {
                    best = (NormalizableScoringFunction)score[0];
                    max = current;
                }
                // Continue with a fresh clone so that the kept candidate is not touched again.
                score[0] = (NormalizableScoringFunction)nsf.clone();
            }
            this.out.writeln("best: " + max);
            if (best == null) {
                throw new IllegalStateException("No start of the numerical optimization yielded a value greater than negative infinity; the scoring function could not be trained.");
            }
        } finally {
            // Stop the worker threads even if initialization or optimization failed.
            f.stopThreads();
        }
        // Explicit collection request retained from the original implementation.
        System.gc();
        return best;
    }

    /**
     * Returns the probability of the given part of the sequence, i.e. the exponential of
     * {@link #getLogProbFor(Sequence, int, int)}.
     */
    @Override
    public double getProbFor(Sequence sequence, int startpos, int endpos) throws NotTrainedException, Exception {
        return Math.exp(this.getLogProbFor(sequence, startpos, endpos));
    }

    /**
     * Returns the log score of the wrapped function at {@code startpos} minus the cached log
     * normalization constant.
     *
     * <p>{@code endpos} is only used for validating the arguments; the score itself is obtained
     * from {@code nsf.getLogScore(sequence, startpos)}.
     *
     * @param sequence the sequence to score
     * @param startpos the start position (inclusive), has to be non-negative
     * @param endpos the end position (inclusive), has to be smaller than the sequence length
     * @return the normalized log score
     * @throws NotTrainedException if the model is not trained
     * @throws WrongAlphabetException if the alphabets of sequence and model do not match
     * @throws WrongLengthException if the model has a fixed length that differs from the requested range
     * @throws IllegalArgumentException if the positions are out of range
     */
    @Override
    public double getLogProbFor(Sequence sequence, int startpos, int endpos) throws NotTrainedException, Exception {
        if (!this.isTrained()) {
            throw new NotTrainedException();
        }
        if (!sequence.getAlphabetContainer().checkConsistency(this.alphabets)) {
            throw new WrongAlphabetException("The AlphabetContainer of the sequence and the model do not match.");
        }
        if (startpos < 0) {
            throw new IllegalArgumentException("Check start position.");
        }
        if (endpos + 1 < startpos || endpos >= sequence.getLength()) {
            throw new IllegalArgumentException("Check end position.");
        }
        if (this.length != 0 && this.length != endpos - startpos + 1) {
            throw new WrongLengthException("Check length of the sequence.");
        }
        return this.nsf.getLogScore(sequence, startpos) - this.logNorm;
    }

    /**
     * Returns the log prior term of the wrapped function minus its equivalent sample size times
     * the cached log normalization constant.
     */
    @Override
    public double getLogPriorTerm() throws Exception {
        return this.nsf.getLogPriorTerm() - this.nsf.getEss() * this.logNorm;
    }

    /** Returns the instance name of the wrapped function prefixed with {@code "model using "}. */
    @Override
    public String getInstanceName() {
        return "model using " + this.nsf.getInstanceName();
    }

    /** Returns whether the wrapped function is initialized. */
    @Override
    public boolean isTrained() {
        return this.nsf.isInitialized();
    }

    /** Always returns {@code null}; this model does not provide numerical characteristics. */
    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return null;
    }

    /** Returns the string representation of the wrapped function. */
    @Override
    public String toString() {
        return this.nsf.toString();
    }

    /**
     * Restores the state of the model from its XML representation.
     *
     * <p>The termination condition is read from the {@code terminationCondition} tag, from the
     * {@code tc} tag written by an earlier inconsistent {@link #toXML()}, or is rebuilt from an
     * {@code eps} value if neither tag is present.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the representation could not be parsed
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer rep = XMLParser.extractForTag(xml, XML_TAG);
        this.nsf = XMLParser.extractObjectForTags(rep, "NormalizableScoringFunction", NormalizableScoringFunction.class);
        this.threads = XMLParser.extractObjectForTags(rep, "threads", Integer.TYPE);
        this.algo = XMLParser.extractObjectForTags(rep, "algorithm", Byte.TYPE);
        if (XMLParser.hasTag(rep, TERMINATION_CONDITION_TAG, null, null)) {
            this.tc = (AbstractTerminationCondition)XMLParser.extractObjectForTags(rep, TERMINATION_CONDITION_TAG);
        } else if (XMLParser.hasTag(rep, LEGACY_TERMINATION_CONDITION_TAG, null, null)) {
            this.tc = (AbstractTerminationCondition)XMLParser.extractObjectForTags(rep, LEGACY_TERMINATION_CONDITION_TAG);
        } else {
            try {
                this.tc = new SmallDifferenceOfFunctionEvaluationsCondition(XMLParser.extractObjectForTags(rep, "eps", Double.TYPE));
            }
            catch (Exception e) {
                NonParsableException n = new NonParsableException(e.getMessage());
                // Keep the original failure attached for diagnosis.
                n.initCause(e);
                throw n;
            }
        }
        this.lineps = XMLParser.extractObjectForTags(rep, "lineps", Double.TYPE);
        this.startD = XMLParser.extractObjectForTags(rep, "startDistance", Double.TYPE);
        this.logNorm = this.isTrained() ? this.nsf.getLogNormalizationConstant() : Double.NEGATIVE_INFINITY;
        this.alphabets = this.nsf.getAlphabetContainer();
        this.length = this.nsf.getLength();
        this.setOutputStream(SafeOutputStream.DEFAULT_STREAM);
    }

    /**
     * Returns the XML representation of this model.
     *
     * <p>The termination condition is written under the same tag that
     * {@link #fromXML(StringBuffer)} reads first, so that the representation can be parsed again.
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(100000);
        XMLParser.appendObjectWithTags(xml, this.nsf, "NormalizableScoringFunction");
        XMLParser.appendObjectWithTags(xml, this.threads, "threads");
        XMLParser.appendObjectWithTags(xml, this.algo, "algorithm");
        XMLParser.appendObjectWithTags(xml, this.tc, TERMINATION_CONDITION_TAG);
        XMLParser.appendObjectWithTags(xml, this.lineps, "lineps");
        XMLParser.appendObjectWithTags(xml, this.startD, "startDistance");
        XMLParser.addTags(xml, XML_TAG);
        return xml;
    }

    /**
     * Sets the stream that receives the progress messages written during training.
     *
     * @param o the target stream; it is wrapped in a {@link SafeOutputStream}
     */
    public final void setOutputStream(OutputStream o) {
        this.out = new SafeOutputStream(o);
    }

    /**
     * Returns a clone of the wrapped scoring function.
     *
     * @throws CloneNotSupportedException if the function cannot be cloned
     */
    public NormalizableScoringFunction getFunction() throws CloneNotSupportedException {
        return (NormalizableScoringFunction)this.nsf.clone();
    }
}

