package de.jstacs.sequenceScores.statisticalModels.differentiable;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.StrandedLocatedSequenceAnnotationWithLength;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.Mutable;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.AbstractMixtureDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.StrandDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.text.NumberFormat;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/NormalizedDiffSM.class */
/**
 * Wrapper that turns an arbitrary {@link DifferentiableStatisticalModel} into a normalized one.
 *
 * <p>The log score of the wrapped model is shifted by the wrapped model's log normalization
 * constant (see {@link #getLogScoreFor(Sequence, int)}), so this class reports a log normalization
 * constant of {@code 0}. The normalization constant and the per-parameter correction terms of the
 * gradient are cached and recomputed whenever the parameters of the wrapped model change.
 *
 * <p>NOTE: the fields below intentionally have no initializers. {@link #fromXML(StringBuffer)}
 * assigns them and is presumably invoked from the superclass constructor; a field initializer
 * would run afterwards and overwrite the parsed values.
 */
public final class NormalizedDiffSM extends AbstractDifferentiableStatisticalModel implements Mutable {
    /** The wrapped model; always a private copy, never the instance handed in by the caller. */
    private DifferentiableStatisticalModel nsf;

    /** Number of recommended starts reported by {@link #getNumberOfRecommendedStarts()}. */
    private int starts;

    /** Cached log normalization constant of {@link #nsf}; negative infinity while uninitialized. */
    private double logNorm;

    /**
     * Cached values {@code exp(logPartialNorm(i) - logNorm)} for every parameter {@code i} of
     * {@link #nsf}; {@code null} while the wrapped model is not initialized.
     */
    private double[] proportion;

    /**
     * Returns a normalized model for the given one.
     *
     * @param differentiableStatisticalModel the model to be normalized
     * @param i the number of recommended starts, only used if a wrapper has to be created
     * @return a clone of the given model if it reports to be normalized already, otherwise a new
     *         {@link NormalizedDiffSM} wrapping it
     * @throws Exception if cloning or wrapping fails
     */
    public static final DifferentiableStatisticalModel getNormalizedVersion(DifferentiableStatisticalModel differentiableStatisticalModel, int i) throws Exception {
        if (differentiableStatisticalModel.isNormalized()) {
            return (DifferentiableStatisticalModel) differentiableStatisticalModel.mo114clone();
        }
        return new NormalizedDiffSM(differentiableStatisticalModel, i);
    }

    /**
     * Creates a normalized wrapper around a clone of the given model.
     *
     * @param differentiableStatisticalModel the model to be wrapped; it is cloned, not stored
     * @param i the requested number of starts; the effective value is the maximum of this and the
     *        number of starts recommended by the wrapped model
     * @throws IllegalArgumentException if the model is a {@link VariableLengthDiffSM} or if
     *         {@code i} is not positive
     * @throws Exception if cloning or the initial precomputation fails
     */
    public NormalizedDiffSM(DifferentiableStatisticalModel differentiableStatisticalModel, int i) throws Exception {
        super(differentiableStatisticalModel.getAlphabetContainer(), differentiableStatisticalModel.getLength());
        if (differentiableStatisticalModel instanceof VariableLengthDiffSM) {
            throw new IllegalArgumentException("Instances of VariableLengthDiffSM can not be wrapped in a NormalizedDiffSM.");
        }
        if (i <= 0) {
            throw new IllegalArgumentException("The number of starts has to be positive.");
        }
        this.starts = Math.max(i, differentiableStatisticalModel.getNumberOfRecommendedStarts());
        this.nsf = (DifferentiableStatisticalModel) differentiableStatisticalModel.mo114clone();
        precomputeIfPossible();
    }

    /**
     * Restores an instance from its XML representation as produced by {@link #toXML()}.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the representation can not be parsed or the cached values
     *         can not be computed for the restored model
     */
    public NormalizedDiffSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
        try {
            precomputeIfPossible();
        } catch (Exception e) {
            // Only the (String) constructor of NonParsableException is known here, hence the
            // origin of the failure is preserved by copying the stack trace.
            NonParsableException nonParsableException = new NonParsableException(e.getMessage());
            nonParsableException.setStackTrace(e.getStackTrace());
            throw nonParsableException;
        }
    }

    /**
     * Creates a deep copy of this instance.
     *
     * <p>Besides the wrapped model, the cached {@link #proportion} array is copied as well.
     * {@link #precompute()} overwrites that array in place, so sharing it would let a parameter
     * change in one instance silently corrupt the gradient terms of the other.
     *
     * @return an independent copy of this instance
     * @throws CloneNotSupportedException if the superclass or the wrapped model can not be cloned
     */
    @Override
    /* renamed from: clone */
    public NormalizedDiffSM mo114clone() throws CloneNotSupportedException {
        NormalizedDiffSM copy = (NormalizedDiffSM) super.mo114clone();
        copy.nsf = (DifferentiableStatisticalModel) this.nsf.mo114clone();
        if (this.proportion != null) {
            copy.proportion = this.proportion.clone();
        }
        return copy;
    }

    /** Delegates to the wrapped model. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int i) {
        return this.nsf.getSizeOfEventSpaceForRandomVariablesOfParameter(i);
    }

    /**
     * @return always {@code 0}, since the score returned by this class is already shifted by the
     *         log normalization constant of the wrapped model
     */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0d;
    }

    /**
     * @param i the parameter index (ignored)
     * @return always {@link Double#NEGATIVE_INFINITY}
     */
    @Override
    public double getLogPartialNormalizationConstant(int i) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /** Delegates to the wrapped model. */
    @Override
    public double getESS() {
        return this.nsf.getESS();
    }

    /**
     * @return the log prior term of the wrapped model minus its ESS times the cached log
     *         normalization constant
     */
    @Override
    public double getLogPriorTerm() {
        return this.nsf.getLogPriorTerm() - (this.nsf.getESS() * this.logNorm);
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} to the given array.
     *
     * @param dArr the gradient array that is modified in place
     * @param i the offset of the first parameter of this model within {@code dArr}
     * @throws Exception if the wrapped model fails to add its part of the gradient
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] dArr, int i) throws Exception {
        this.nsf.addGradientOfLogPriorTerm(dArr, i);
        double ess = this.nsf.getESS();
        // Derivative of the "- ESS * logNorm" summand of the prior term.
        for (int p = 0; p < this.proportion.length; p++) {
            dArr[i + p] -= ess * this.proportion[p];
        }
    }

    /** Initializes the wrapped model from data and refreshes the cached values. */
    @Override
    public void initializeFunction(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        this.nsf.initializeFunction(i, z, dataSetArr, dArr);
        precompute();
    }

    /** Initializes the wrapped model randomly and refreshes the cached values. */
    @Override
    public void initializeFunctionRandomly(boolean z) throws Exception {
        this.nsf.initializeFunctionRandomly(z);
        precompute();
    }

    /**
     * Reads the wrapped model and the number of starts from the XML representation and derives
     * the alphabet container and the length from the wrapped model.
     *
     * @param stringBuffer the XML representation as produced by {@link #toXML()}
     * @throws NonParsableException if the representation can not be parsed
     */
    @Override
    protected void fromXML(StringBuffer stringBuffer) throws NonParsableException {
        StringBuffer content = XMLParser.extractForTag(stringBuffer, getClass().getSimpleName());
        this.nsf = (DifferentiableStatisticalModel) XMLParser.extractObjectForTags(content, "function", DifferentiableStatisticalModel.class);
        this.alphabets = this.nsf.getAlphabetContainer();
        this.length = this.nsf.getLength();
        this.starts = ((Integer) XMLParser.extractObjectForTags(content, "starts", Integer.TYPE)).intValue();
    }

    /** @return the instance name of the wrapped model prefixed with {@code "normalized "} */
    @Override
    public String getInstanceName() {
        return "normalized " + this.nsf.getInstanceName();
    }

    /** @return the log score of the wrapped model minus the cached log normalization constant */
    @Override
    public double getLogScoreFor(Sequence sequence, int i) {
        return this.nsf.getLogScoreFor(sequence, i) - this.logNorm;
    }

    /**
     * Computes the normalized log score and collects its partial derivations.
     *
     * <p>After the wrapped model has filled both lists, one additional entry per parameter is
     * appended: the parameter index and {@code -proportion[index]}, i.e. the derivative of the
     * subtracted log normalization constant.
     *
     * @param sequence the sequence to be scored
     * @param i the start position within the sequence
     * @param intList receives the parameter indices
     * @param doubleList receives the partial derivations, parallel to {@code intList}
     * @return the log score of the wrapped model minus the cached log normalization constant
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        double normalizedScore = this.nsf.getLogScoreAndPartialDerivation(sequence, i, intList, doubleList) - this.logNorm;
        for (int p = 0; p < this.proportion.length; p++) {
            intList.add(p);
            doubleList.add(-this.proportion[p]);
        }
        return normalizedScore;
    }

    /** Delegates to the wrapped model. */
    @Override
    public int getNumberOfParameters() {
        return this.nsf.getNumberOfParameters();
    }

    /** Delegates to the wrapped model. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.nsf.getCurrentParameterValues();
    }

    /**
     * Sets the parameters of the wrapped model and refreshes the cached values.
     *
     * @param dArr the array containing the new parameter values
     * @param i the offset of the first parameter of this model within {@code dArr}
     * @throws RuntimeException if the cached values can not be recomputed; the original exception
     *         is available as the cause
     */
    @Override
    public void setParameters(double[] dArr, int i) {
        this.nsf.setParameters(dArr, i);
        try {
            precompute();
        } catch (Exception e) {
            throw asUnchecked(e);
        }
    }

    /**
     * Wraps a checked failure of {@link #precompute()} for methods whose signature does not allow
     * checked exceptions, keeping the message and chaining the original exception as cause.
     */
    private static RuntimeException asUnchecked(Exception e) {
        return new RuntimeException(e.getMessage(), e);
    }

    /**
     * Refreshes the cached values if the wrapped model is initialized; otherwise resets them to
     * the "uninitialized" state (negative infinity and {@code null}).
     */
    private void precomputeIfPossible() throws Exception {
        if (this.nsf.isInitialized()) {
            precompute();
        } else {
            this.logNorm = Double.NEGATIVE_INFINITY;
            this.proportion = null;
        }
    }

    /**
     * Recomputes {@link #logNorm} and {@link #proportion} from the current state of the wrapped
     * model.
     *
     * <p>The array is reallocated not only when it is missing but also when its length no longer
     * matches the number of parameters of the wrapped model, so a re-initialization that changes
     * the parameter count can not leave a stale, wrongly sized array behind.
     */
    private void precompute() throws Exception {
        int numberOfParameters = this.nsf.getNumberOfParameters();
        if (this.proportion == null || this.proportion.length != numberOfParameters) {
            this.proportion = new double[numberOfParameters];
        }
        this.logNorm = this.nsf.getLogNormalizationConstant();
        for (int p = 0; p < this.proportion.length; p++) {
            this.proportion[p] = Math.exp(this.nsf.getLogPartialNormalizationConstant(p) - this.logNorm);
        }
    }

    /** Delegates to the wrapped model. */
    @Override
    public boolean isInitialized() {
        return this.nsf.isInitialized();
    }

    /**
     * @return the XML representation consisting of the wrapped model (tag {@code function}) and
     *         the number of starts (tag {@code starts}), enclosed in the simple class name
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(100000);
        XMLParser.appendObjectWithTags(xml, this.nsf, "function");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.starts), "starts");
        XMLParser.addTags(xml, getClass().getSimpleName());
        return xml;
    }

    /** @return the number of starts determined at construction time or read from XML */
    @Override
    public int getNumberOfRecommendedStarts() {
        return this.starts;
    }

    /** @return always {@code true} */
    @Override
    public boolean isNormalized() {
        return true;
    }

    /** @return the textual representation of the wrapped model preceded by a header line */
    @Override
    public String toString(NumberFormat numberFormat) {
        return "normalized variante of\n" + this.nsf.toString(numberFormat);
    }

    /**
     * @return a clone of the wrapped model; changes to it do not affect this instance
     * @throws CloneNotSupportedException if the wrapped model can not be cloned
     */
    public DifferentiableStatisticalModel getFunction() throws CloneNotSupportedException {
        return (DifferentiableStatisticalModel) this.nsf.mo114clone();
    }

    /**
     * Forwards the modification to the wrapped model if that model is {@link Mutable}.
     *
     * @param i first argument passed through to {@link Mutable#modify(int, int)}
     * @param i2 second argument passed through to {@link Mutable#modify(int, int)}
     * @return {@code false} if the wrapped model is not {@link Mutable}, otherwise the result of
     *         its {@code modify} call
     * @throws RuntimeException if the cached values can not be recomputed after a successful
     *         modification; the original exception is available as the cause
     */
    @Override
    public boolean modify(int i, int i2) {
        if (!(this.nsf instanceof Mutable)) {
            return false;
        }
        boolean modified = ((Mutable) this.nsf).modify(i, i2);
        if (modified) {
            // Discard the cache and adopt the wrapped model's (possibly changed) length.
            this.proportion = null;
            this.length = this.nsf.getLength();
            try {
                precompute();
            } catch (Exception e) {
                throw asUnchecked(e);
            }
        }
        return modified;
    }

    /**
     * @return {@code true} if the innermost wrapped model (looking through nested
     *         {@link NormalizedDiffSM}s) is a {@link StrandDiffSM}
     */
    public boolean isStrandModel() {
        if (this.nsf instanceof NormalizedDiffSM) {
            return ((NormalizedDiffSM) this.nsf).isStrandModel();
        }
        return this.nsf instanceof StrandDiffSM;
    }

    /**
     * Determines the strand for the given sequence and start position.
     *
     * @param sequence the sequence
     * @param i the start position within the sequence
     * @return the strand reported by the innermost {@link StrandDiffSM} (looking through nested
     *         {@link NormalizedDiffSM}s), or {@code FORWARD} if there is no such model
     */
    public StrandedLocatedSequenceAnnotationWithLength.Strand getStrand(Sequence sequence, int i) {
        if (this.nsf instanceof NormalizedDiffSM) {
            return ((NormalizedDiffSM) this.nsf).getStrand(sequence, i);
        }
        if (this.nsf instanceof StrandDiffSM) {
            return ((StrandDiffSM) this.nsf).getStrand(sequence, i);
        }
        return StrandedLocatedSequenceAnnotationWithLength.Strand.FORWARD;
    }

    /**
     * Forwards the call to the wrapped model if it is a nested {@link NormalizedDiffSM} or an
     * {@link AbstractMixtureDiffSM}; does nothing otherwise. The cached values of this instance
     * are not refreshed by this method.
     */
    public void initializeHiddenUniformly() {
        if (this.nsf instanceof NormalizedDiffSM) {
            ((NormalizedDiffSM) this.nsf).initializeHiddenUniformly();
        } else if (this.nsf instanceof AbstractMixtureDiffSM) {
            ((AbstractMixtureDiffSM) this.nsf).initializeHiddenUniformly();
        }
    }
}
