/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.scoringFunctions;

import de.jstacs.NonParsableException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.sequences.annotation.StrandedLocatedSequenceAnnotationWithLength;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.Mutable;
import de.jstacs.scoringFunctions.AbstractNormalizableScoringFunction;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.mix.AbstractMixtureScoringFunction;
import de.jstacs.scoringFunctions.mix.StrandScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;

public final class NormalizedScoringFunction
extends AbstractNormalizableScoringFunction
implements Mutable {
    private NormalizableScoringFunction nsf;
    private int starts;
    private double logNorm;
    private double[] proportion;

    public static final NormalizableScoringFunction getNormalizedVersion(NormalizableScoringFunction nsf, int starts) throws Exception {
        if (nsf.isNormalized()) {
            return (NormalizableScoringFunction)nsf.clone();
        }
        return new NormalizedScoringFunction(nsf, starts);
    }

    public NormalizedScoringFunction(NormalizableScoringFunction nsf, int starts) throws Exception {
        super(nsf.getAlphabetContainer(), nsf.getLength());
        if (starts <= 0) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        this.starts = Math.max(starts, nsf.getNumberOfRecommendedStarts());
        this.nsf = (NormalizableScoringFunction)nsf.clone();
        this.precomputeIfPossible();
    }

    public NormalizedScoringFunction(StringBuffer xml) throws NonParsableException {
        super(xml);
        try {
            this.precomputeIfPossible();
        }
        catch (Exception e) {
            NonParsableException n = new NonParsableException(e.getMessage());
            n.setStackTrace(e.getStackTrace());
            throw n;
        }
    }

    @Override
    public NormalizedScoringFunction clone() throws CloneNotSupportedException {
        NormalizedScoringFunction clone = (NormalizedScoringFunction)super.clone();
        clone.nsf = (NormalizableScoringFunction)this.nsf.clone();
        return clone;
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return this.nsf.getSizeOfEventSpaceForRandomVariablesOfParameter(index);
    }

    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    @Override
    public double getEss() {
        return this.nsf.getEss();
    }

    @Override
    public double getLogPriorTerm() {
        return this.nsf.getLogPriorTerm() - this.nsf.getEss() * this.logNorm;
    }

    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        this.nsf.addGradientOfLogPriorTerm(grad, start);
        double e = this.nsf.getEss();
        for (int i = 0; i < this.proportion.length; ++i) {
            int n = start + i;
            grad[n] = grad[n] - e * this.proportion[i];
        }
    }

    @Override
    public void initializeFunction(int index, boolean freeParams, Sample[] data, double[][] weights) throws Exception {
        this.nsf.initializeFunction(index, freeParams, data, weights);
        this.precompute();
    }

    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.nsf.initializeFunctionRandomly(freeParams);
        this.precompute();
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer b = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.nsf = (NormalizableScoringFunction)XMLParser.extractStorableForTag(b, "function");
        this.alphabets = this.nsf.getAlphabetContainer();
        this.length = this.nsf.getLength();
        this.starts = XMLParser.extractIntForTag(b, "starts");
    }

    @Override
    public String getInstanceName() {
        return "normalized " + this.nsf.getInstanceName();
    }

    @Override
    public double getLogScore(Sequence seq, int start) {
        return this.nsf.getLogScore(seq, start) - this.logNorm;
    }

    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        double score = this.nsf.getLogScoreAndPartialDerivation(seq, start, indices, partialDer) - this.logNorm;
        try {
            for (int i = 0; i < this.proportion.length; ++i) {
                indices.add(i);
                partialDer.add(-this.proportion[i]);
            }
        }
        catch (Exception e) {
            RuntimeException r = new RuntimeException(e.getMessage());
            r.setStackTrace(e.getStackTrace());
            throw r;
        }
        return score;
    }

    @Override
    public int getNumberOfParameters() {
        return this.nsf.getNumberOfParameters();
    }

    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.nsf.getCurrentParameterValues();
    }

    @Override
    public void setParameters(double[] params, int start) {
        this.nsf.setParameters(params, start);
        try {
            this.precompute();
        }
        catch (Exception e) {
            RuntimeException r = new RuntimeException(e.getMessage());
            r.setStackTrace(e.getStackTrace());
            throw r;
        }
    }

    private void precomputeIfPossible() throws Exception {
        if (this.nsf.isInitialized()) {
            this.precompute();
        } else {
            this.logNorm = Double.NEGATIVE_INFINITY;
            this.proportion = null;
        }
    }

    private void precompute() throws Exception {
        if (this.proportion == null) {
            this.proportion = new double[this.nsf.getNumberOfParameters()];
        }
        this.logNorm = this.nsf.getLogNormalizationConstant();
        for (int i = 0; i < this.proportion.length; ++i) {
            this.proportion[i] = Math.exp(this.nsf.getLogPartialNormalizationConstant(i) - this.logNorm);
        }
    }

    @Override
    public boolean isInitialized() {
        return this.nsf.isInitialized();
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(100000);
        XMLParser.appendStorableWithTags(xml, this.nsf, "function");
        XMLParser.appendIntWithTags(xml, this.starts, "starts");
        XMLParser.addTags(xml, this.getClass().getSimpleName());
        return xml;
    }

    @Override
    public int getNumberOfRecommendedStarts() {
        return this.starts;
    }

    @Override
    public boolean isNormalized() {
        return true;
    }

    public String toString() {
        return "normalized variante of\n" + this.nsf.toString();
    }

    public NormalizableScoringFunction getFunction() throws CloneNotSupportedException {
        return (NormalizableScoringFunction)this.nsf.clone();
    }

    @Override
    public boolean modify(int offsetLeft, int offsetRight) {
        if (this.nsf instanceof Mutable) {
            boolean res = ((Mutable)((Object)this.nsf)).modify(offsetLeft, offsetRight);
            if (res) {
                this.proportion = null;
                this.length = this.nsf.getLength();
                try {
                    this.precompute();
                }
                catch (Exception e) {
                    RuntimeException r = new RuntimeException(e.getMessage());
                    r.setStackTrace(e.getStackTrace());
                    throw r;
                }
            }
            return res;
        }
        return false;
    }

    public boolean isStrandScoringFunction() {
        if (this.nsf instanceof NormalizedScoringFunction) {
            return ((NormalizedScoringFunction)this.nsf).isStrandScoringFunction();
        }
        return this.nsf instanceof StrandScoringFunction;
    }

    public StrandedLocatedSequenceAnnotationWithLength.Strand getStrand(Sequence seq, int startPos) {
        if (this.nsf instanceof NormalizedScoringFunction) {
            return ((NormalizedScoringFunction)this.nsf).getStrand(seq, startPos);
        }
        if (this.nsf instanceof StrandScoringFunction) {
            return ((StrandScoringFunction)this.nsf).getStrand(seq, startPos);
        }
        return StrandedLocatedSequenceAnnotationWithLength.Strand.FORWARD;
    }

    public void initializeHiddenUniformly() {
        if (this.nsf instanceof NormalizedScoringFunction) {
            ((NormalizedScoringFunction)this.nsf).initializeHiddenUniformly();
        } else if (this.nsf instanceof AbstractMixtureScoringFunction) {
            ((AbstractMixtureScoringFunction)this.nsf).initializeHiddenUniformly();
        }
    }
}

