/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.StrandedLocatedSequenceAnnotationWithLength;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.Mutable;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.VariableLengthDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.AbstractMixtureDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.StrandDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;

/**
 * Wraps a {@link DifferentiableStatisticalModel} and divides its score by the wrapped model's
 * normalization constant, so that this wrapper reports itself as normalized
 * ({@link #isNormalized()} returns {@code true}, {@link #getLogNormalizationConstant()} returns 0).
 *
 * <p>The log normalization constant of the wrapped model and, for every parameter {@code i},
 * the value {@code exp(nsf.getLogPartialNormalizationConstant(i) - logNorm)} are cached.
 * The cache is refreshed whenever the wrapped model's parameters change through this wrapper.
 */
public final class NormalizedDiffSM
extends AbstractDifferentiableStatisticalModel
implements Mutable {
    /** The wrapped model; always a private clone of the instance handed in by the caller. */
    private DifferentiableStatisticalModel nsf;
    /** Value reported by {@link #getNumberOfRecommendedStarts()}. */
    private int starts;
    /** Cached log normalization constant of {@link #nsf} (NEGATIVE_INFINITY while {@link #nsf} is uninitialized). */
    private double logNorm;
    /**
     * Cached {@code exp(nsf.getLogPartialNormalizationConstant(i) - logNorm)} for every parameter
     * index {@code i} of {@link #nsf}; {@code null} while {@link #nsf} is uninitialized.
     */
    private double[] proportion;

    /**
     * Returns a normalized version of {@code nsf}.
     *
     * @param nsf the model to normalize
     * @param starts the number of starts passed on to {@link #NormalizedDiffSM(DifferentiableStatisticalModel, int)}
     * @return a clone of {@code nsf} if it is already normalized, otherwise a new wrapper around a clone of it
     * @throws Exception if cloning or wrapping fails
     */
    public static final DifferentiableStatisticalModel getNormalizedVersion(DifferentiableStatisticalModel nsf, int starts) throws Exception {
        if (nsf.isNormalized()) {
            return (DifferentiableStatisticalModel)nsf.clone();
        }
        return new NormalizedDiffSM(nsf, starts);
    }

    /**
     * Creates a wrapper around a clone of {@code nsf}.
     *
     * @param nsf the model to wrap; must not be a {@link VariableLengthDiffSM}
     * @param starts a positive number of starts; the effective value is the maximum of this and
     *        {@code nsf.getNumberOfRecommendedStarts()}
     * @throws IllegalArgumentException if {@code nsf} is a {@link VariableLengthDiffSM} or {@code starts <= 0}
     * @throws Exception if cloning {@code nsf} or computing its normalization constants fails
     */
    public NormalizedDiffSM(DifferentiableStatisticalModel nsf, int starts) throws Exception {
        super(nsf.getAlphabetContainer(), nsf.getLength());
        if (nsf instanceof VariableLengthDiffSM) {
            throw new IllegalArgumentException("Models of type VariableLengthDiffSM are not supported.");
        }
        if (starts <= 0) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        this.starts = Math.max(starts, nsf.getNumberOfRecommendedStarts());
        this.nsf = (DifferentiableStatisticalModel)nsf.clone();
        this.precomputeIfPossible();
    }

    /**
     * Restores a wrapper from its XML representation (see {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if parsing fails or the cached constants cannot be computed;
     *         in the latter case the original exception is attached as the cause
     */
    public NormalizedDiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
        try {
            this.precomputeIfPossible();
        }
        catch (Exception e) {
            NonParsableException n = new NonParsableException(e.getMessage());
            n.setStackTrace(e.getStackTrace());
            n.initCause(e);
            throw n;
        }
    }

    /**
     * Returns a deep copy: both the wrapped model and the cached derivative terms are copied,
     * so updating the clone's parameters does not affect this instance.
     */
    @Override
    public NormalizedDiffSM clone() throws CloneNotSupportedException {
        NormalizedDiffSM clone = (NormalizedDiffSM)super.clone();
        clone.nsf = (DifferentiableStatisticalModel)this.nsf.clone();
        // super.clone() copies only the array reference, and precompute() writes into an existing
        // array in place, so without this copy the two instances would overwrite each other's cache.
        clone.proportion = this.proportion == null ? null : this.proportion.clone();
        return clone;
    }

    /** Delegates to the wrapped model. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return this.nsf.getSizeOfEventSpaceForRandomVariablesOfParameter(index);
    }

    /** Returns 0, because this wrapper already divides by the wrapped model's normalization constant. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    /** Returns NEGATIVE_INFINITY (log of 0), because the normalization constant of this wrapper is constant. */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /** Delegates to the wrapped model. */
    @Override
    public double getESS() {
        return this.nsf.getESS();
    }

    /**
     * Returns the wrapped model's log prior term minus {@code ESS * logNorm}, where
     * {@code logNorm} is the cached log normalization constant of the wrapped model.
     */
    @Override
    public double getLogPriorTerm() {
        return this.nsf.getLogPriorTerm() - this.nsf.getESS() * this.logNorm;
    }

    /**
     * Adds the wrapped model's prior gradient to {@code grad}, then subtracts
     * {@code ESS * proportion[i]} from {@code grad[start + i]} for every parameter {@code i}.
     *
     * @param grad the gradient array to update in place
     * @param start the offset of this model's first parameter in {@code grad}
     * @throws Exception if the wrapped model fails
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        this.nsf.addGradientOfLogPriorTerm(grad, start);
        double ess = this.nsf.getESS();
        for (int i = 0; i < this.proportion.length; ++i) {
            int n = start + i;
            grad[n] = grad[n] - ess * this.proportion[i];
        }
    }

    /** Initializes the wrapped model from data and refreshes the cached constants. */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.nsf.initializeFunction(index, freeParams, data, weights);
        this.precompute();
    }

    /** Randomly initializes the wrapped model and refreshes the cached constants. */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.nsf.initializeFunctionRandomly(freeParams);
        this.precompute();
    }

    /**
     * Reads the wrapped model (tag {@code function}) and the number of starts (tag {@code starts})
     * from the XML, and takes the alphabet and length from the wrapped model.
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer b = XMLParser.extractForTag(xml, this.getClass().getSimpleName());
        this.nsf = XMLParser.extractObjectForTags(b, "function", DifferentiableStatisticalModel.class);
        this.alphabets = this.nsf.getAlphabetContainer();
        this.length = this.nsf.getLength();
        this.starts = XMLParser.extractObjectForTags(b, "starts", Integer.TYPE);
    }

    /** Returns the wrapped model's instance name prefixed with {@code "normalized "}. */
    @Override
    public String getInstanceName() {
        return "normalized " + this.nsf.getInstanceName();
    }

    /** Returns the wrapped model's log score minus the cached log normalization constant. */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.nsf.getLogScoreFor(seq, start) - this.logNorm;
    }

    /**
     * Returns the normalized log score and appends its partial derivatives.
     *
     * <p>The wrapped model first appends its own derivatives. Then, for every parameter index
     * {@code i}, the derivative {@code -proportion[i]} of the {@code -logNorm} term is appended.
     * Entries for the same index are therefore split over two list positions.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        double score = this.nsf.getLogScoreAndPartialDerivation(seq, start, indices, partialDer) - this.logNorm;
        for (int i = 0; i < this.proportion.length; ++i) {
            indices.add(i);
            partialDer.add(-this.proportion[i]);
        }
        return score;
    }

    /** Delegates to the wrapped model; this wrapper adds no parameters of its own. */
    @Override
    public int getNumberOfParameters() {
        return this.nsf.getNumberOfParameters();
    }

    /** Delegates to the wrapped model. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.nsf.getCurrentParameterValues();
    }

    /**
     * Sets the wrapped model's parameters and refreshes the cached constants.
     *
     * @throws RuntimeException wrapping (as cause) any exception raised while refreshing the constants
     */
    @Override
    public void setParameters(double[] params, int start) {
        this.nsf.setParameters(params, start);
        try {
            this.precompute();
        }
        catch (Exception e) {
            throw asRuntimeException(e);
        }
    }

    /** Computes the cached constants if the wrapped model is initialized; otherwise marks them as unavailable. */
    private void precomputeIfPossible() throws Exception {
        if (this.nsf.isInitialized()) {
            this.precompute();
        } else {
            this.logNorm = Double.NEGATIVE_INFINITY;
            this.proportion = null;
        }
    }

    /**
     * Recomputes {@link #logNorm} and {@link #proportion} from the wrapped model. The array is
     * reallocated when the wrapped model's parameter count no longer matches its length, for
     * example after re-initialization with a different parameterization.
     */
    private void precompute() throws Exception {
        int numberOfParameters = this.nsf.getNumberOfParameters();
        if (this.proportion == null || this.proportion.length != numberOfParameters) {
            this.proportion = new double[numberOfParameters];
        }
        this.logNorm = this.nsf.getLogNormalizationConstant();
        for (int i = 0; i < this.proportion.length; ++i) {
            this.proportion[i] = Math.exp(this.nsf.getLogPartialNormalizationConstant(i) - this.logNorm);
        }
    }

    /**
     * Wraps a checked exception from a method that cannot declare it. The original message and
     * stack trace are kept, and the original exception is attached as the cause.
     */
    private static RuntimeException asRuntimeException(Exception e) {
        RuntimeException r = new RuntimeException(e.getMessage(), e);
        r.setStackTrace(e.getStackTrace());
        return r;
    }

    /** Delegates to the wrapped model. */
    @Override
    public boolean isInitialized() {
        return this.nsf.isInitialized();
    }

    /** Serializes the wrapped model (tag {@code function}) and the number of starts (tag {@code starts}). */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(100000);
        XMLParser.appendObjectWithTags(xml, this.nsf, "function");
        XMLParser.appendObjectWithTags(xml, this.starts, "starts");
        XMLParser.addTags(xml, this.getClass().getSimpleName());
        return xml;
    }

    /** Returns the number of starts fixed at construction (or read from XML). */
    @Override
    public int getNumberOfRecommendedStarts() {
        return this.starts;
    }

    /** Always {@code true}. */
    @Override
    public boolean isNormalized() {
        return true;
    }

    @Override
    public String toString() {
        return "normalized variante of\n" + this.nsf.toString();
    }

    /**
     * Returns a clone of the wrapped model.
     *
     * @return a clone of the wrapped model
     * @throws CloneNotSupportedException if the wrapped model cannot be cloned
     */
    public DifferentiableStatisticalModel getFunction() throws CloneNotSupportedException {
        return (DifferentiableStatisticalModel)this.nsf.clone();
    }

    /**
     * Forwards the modification to the wrapped model if it is {@link Mutable}. On success, the
     * length is updated from the wrapped model and the cached constants are rebuilt.
     *
     * @return {@code true} if the wrapped model was modified, {@code false} otherwise (including
     *         when it is not {@link Mutable})
     * @throws RuntimeException wrapping (as cause) any exception raised while rebuilding the constants
     */
    @Override
    public boolean modify(int offsetLeft, int offsetRight) {
        if (!(this.nsf instanceof Mutable)) {
            return false;
        }
        boolean res = ((Mutable)this.nsf).modify(offsetLeft, offsetRight);
        if (res) {
            // The parameter count may have changed, so drop the old array before recomputing.
            this.proportion = null;
            this.length = this.nsf.getLength();
            try {
                this.precompute();
            }
            catch (Exception e) {
                throw asRuntimeException(e);
            }
        }
        return res;
    }

    /**
     * Returns whether the wrapped model is a {@link StrandDiffSM}, looking through nested
     * {@code NormalizedDiffSM} wrappers.
     */
    public boolean isStrandModel() {
        if (this.nsf instanceof NormalizedDiffSM) {
            return ((NormalizedDiffSM)this.nsf).isStrandModel();
        }
        return this.nsf instanceof StrandDiffSM;
    }

    /**
     * Returns the strand reported by the wrapped {@link StrandDiffSM} (looking through nested
     * wrappers), or {@code FORWARD} if the wrapped model is not a strand model.
     */
    public StrandedLocatedSequenceAnnotationWithLength.Strand getStrand(Sequence seq, int startPos) {
        if (this.nsf instanceof NormalizedDiffSM) {
            return ((NormalizedDiffSM)this.nsf).getStrand(seq, startPos);
        }
        if (this.nsf instanceof StrandDiffSM) {
            return ((StrandDiffSM)this.nsf).getStrand(seq, startPos);
        }
        return StrandedLocatedSequenceAnnotationWithLength.Strand.FORWARD;
    }

    /**
     * Calls {@code initializeHiddenUniformly()} on the wrapped model if it is an
     * {@link AbstractMixtureDiffSM} (looking through nested wrappers); otherwise does nothing.
     * The cached constants are not refreshed here.
     */
    public void initializeHiddenUniformly() {
        if (this.nsf instanceof NormalizedDiffSM) {
            ((NormalizedDiffSM)this.nsf).initializeHiddenUniformly();
        } else if (this.nsf instanceof AbstractMixtureDiffSM) {
            ((AbstractMixtureDiffSM)this.nsf).initializeHiddenUniformly();
        }
    }
}

