package supplementary.cookbook.recipes;

import cern.colt.matrix.impl.AbstractFormatter;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.Arrays;
import org.biojava.bio.program.tagvalue.TagValueParser;
import projects.dimont.DimontParameterSet;
import projects.dispom.DispomParameterSet;

/* loaded from: input_file:supplementary/cookbook/recipes/PositionWeightMatrixDiffSM.class */
/**
 * A position weight matrix (PWM) implemented as a differentiable statistical model.
 *
 * <p>Each position {@code l} holds one log-scale parameter per alphabet symbol {@code a} in
 * {@code parameters[l][a]}. The probability of symbol {@code a} at position {@code l} is
 * {@code exp(parameters[l][a]) / sum_b exp(parameters[l][b])}. The model is therefore normalized by
 * {@code Z = prod_l sum_b exp(parameters[l][b])}.
 *
 * <p>Parameters are laid out linearly as {@code index = l * |alphabet| + a}. The prior contributes
 * {@code ess / |alphabet|} pseudocounts to every parameter.
 */
public class PositionWeightMatrixDiffSM extends AbstractDifferentiableStatisticalModel {

    /** Log-scale (unnormalized) parameters; {@code parameters[position][symbol]}. */
    protected double[][] parameters;

    /** Equivalent sample size of the prior, spread uniformly over the alphabet. */
    private double ess;

    /** Whether the parameters have been initialized (from data, randomly, or from XML). */
    private boolean isInitialized;

    /** Cached log normalization constant; {@code null} whenever it must be recomputed. */
    protected Double norm;

    /**
     * Creates a new, uninitialized PWM.
     *
     * @param alphabetContainer the alphabet; must be simple (same alphabet at every position) and discrete
     * @param length the number of positions of the PWM
     * @param ess the equivalent sample size of the prior; must be non-negative
     * @throws IllegalArgumentException if the alphabet is not simple and discrete, or if {@code ess}
     *     is negative or NaN
     */
    public PositionWeightMatrixDiffSM(AlphabetContainer alphabetContainer, int length, double ess)
            throws IllegalArgumentException {
        super(alphabetContainer, length);
        if (!alphabetContainer.isSimple() || !alphabetContainer.isDiscrete()) {
            throw new IllegalArgumentException(
                    "This PWM can handle only discrete alphabets with the same alphabet at each position.");
        }
        // A negative (or NaN) ESS would yield negative/NaN pseudocounts in initializeFunction.
        if (!(ess >= 0.0d)) {
            throw new IllegalArgumentException("The equivalent sample size must be non-negative, but was " + ess + ".");
        }
        this.parameters = new double[length][(int) alphabetContainer.getAlphabetLengthAt(0)];
        this.ess = ess;
        this.isInitialized = false;
        this.norm = null;
    }

    /**
     * Creates a PWM from its XML representation (the format written by {@link #toXML()}).
     * The superclass constructor is expected to delegate to {@link #fromXML(StringBuffer)}.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public PositionWeightMatrixDiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Every parameter belongs to a random variable over the (single) alphabet, so the event space
     * size is the alphabet size regardless of {@code index}.
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return (int) this.alphabets.getAlphabetLengthAt(0);
    }

    /**
     * Returns {@code log Z = sum_l logSumExp(parameters[l])}, caching the result in {@link #norm}.
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getLogNormalizationConstant() {
        if (this.norm == null) {
            double logZ = 0.0d;
            for (double[] row : this.parameters) {
                logZ += Normalisation.getLogSum(row);
            }
            this.norm = logZ;
        }
        return this.norm.doubleValue();
    }

    /**
     * Returns the log of the partial derivative of {@code Z} with respect to one parameter.
     *
     * <p>For the parameter at {@code (l, a)}, {@code dZ/dλ_{l,a} = exp(λ_{l,a}) * prod_{l' != l} sum_b exp(λ_{l',b})}.
     * In log space this is {@code log Z - logSumExp(λ_l) + λ_{l,a}}.
     *
     * @param parameterIndex linear parameter index {@code l * |alphabet| + a}
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        if (this.norm == null) {
            getLogNormalizationConstant();
        }
        int alphabetSize = (int) this.alphabets.getAlphabetLengthAt(0);
        int position = parameterIndex / alphabetSize;
        int symbol = parameterIndex % alphabetSize;
        return (this.norm.doubleValue() - Normalisation.getLogSum(this.parameters[position]))
                + this.parameters[position][symbol];
    }

    /**
     * Returns the (unnormalized) log prior term {@code sum_{l,a} (ess / |alphabet|) * parameters[l][a]}.
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel, de.jstacs.sequenceScores.statisticalModels.StatisticalModel
    public double getLogPriorTerm() {
        double pseudocount = this.ess / this.alphabets.getAlphabetLengthAt(0);
        double logPrior = 0.0d;
        for (double[] row : this.parameters) {
            for (double value : row) {
                logPrior += pseudocount * value;
            }
        }
        return logPrior;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} to {@code grad}, starting at index {@code start}.
     * The derivative with respect to every parameter is the constant {@code ess / |alphabet|}.
     *
     * @param grad the gradient vector to add to; existing entries are accumulated, not overwritten
     * @param start the index in {@code grad} of this model's first parameter
     */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        double pseudocount = this.ess / this.alphabets.getAlphabetLengthAt(0);
        int index = start;
        for (double[] row : this.parameters) {
            for (int symbol = 0; symbol < row.length; symbol++, index++) {
                // Must accumulate: other model components may already have written into grad.
                grad[index] += pseudocount;
            }
        }
    }

    /** Returns the equivalent sample size of the prior. */
    @Override // de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel
    public double getESS() {
        return this.ess;
    }

    /**
     * Initializes the parameters with the log of the smoothed, weighted relative symbol frequencies
     * per position in {@code data[index]}. Each cell starts at {@code ess / |alphabet|} pseudocounts.
     *
     * @param index the index of the data set (and weight vector) to use
     * @param freeParams ignored; this model always uses all parameters
     * @param data the data sets
     * @param weights the sequence weights; if {@code weights} or {@code weights[index]} is {@code null},
     *     every sequence gets weight 1
     * @throws IllegalArgumentException if the alphabet or sequence length of the data does not match
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        DataSet dataSet = data[index];
        if (!dataSet.getAlphabetContainer().checkConsistency(this.alphabets) || dataSet.getElementLength() != this.length) {
            throw new IllegalArgumentException("Alphabet or length do not match.");
        }
        double pseudocount = this.ess / this.alphabets.getAlphabetLengthAt(0);
        for (double[] row : this.parameters) {
            Arrays.fill(row, pseudocount);
        }
        double[] sequenceWeights = (weights == null) ? null : weights[index];
        for (int n = 0; n < dataSet.getNumberOfElements(); n++) {
            Sequence sequence = dataSet.getElementAt(n);
            double weight = (sequenceWeights == null) ? 1.0d : sequenceWeights[n];
            for (int position = 0; position < sequence.getLength(); position++) {
                this.parameters[position][sequence.discreteVal(position)] += weight;
            }
        }
        for (double[] row : this.parameters) {
            Normalisation.sumNormalisation(row);
            for (int symbol = 0; symbol < row.length; symbol++) {
                row[symbol] = Math.log(row[symbol]);
            }
        }
        this.norm = null;
        this.isInitialized = true;
    }

    /**
     * Initializes each position with the log of a draw from a symmetric Dirichlet distribution
     * with hyper-parameters {@code ess / |alphabet|}.
     *
     * @param freeParams ignored; this model always uses all parameters
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        int alphabetSize = (int) this.alphabets.getAlphabetLengthAt(0);
        DirichletMRGParams hyperParams = new DirichletMRGParams(this.ess / alphabetSize, alphabetSize);
        for (int position = 0; position < this.parameters.length; position++) {
            double[] row = DirichletMRG.DEFAULT_INSTANCE.generate(alphabetSize, hyperParams);
            for (int symbol = 0; symbol < row.length; symbol++) {
                row[symbol] = Math.log(row[symbol]);
            }
            this.parameters[position] = row;
        }
        this.norm = null;
        this.isInitialized = true;
    }

    /**
     * Returns the unnormalized log score {@code sum_l parameters[l][s_{start+l}]}.
     *
     * @param sequence the sequence
     * @param start the position in {@code sequence} aligned with the first PWM position
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public double getLogScoreFor(Sequence sequence, int start) {
        double score = 0.0d;
        for (int position = 0; position < this.parameters.length; position++) {
            score += this.parameters[position][sequence.discreteVal(position + start)];
        }
        return score;
    }

    /**
     * Returns the log score and records its partial derivatives. The derivative is 1 for exactly one
     * parameter per position (the observed symbol) and 0 elsewhere. Only the non-zero entries are added
     * to {@code indices} and {@code partialDer}.
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double getLogScoreAndPartialDerivation(Sequence sequence, int start, IntList indices, DoubleList partialDer) {
        double score = 0.0d;
        int offset = 0;
        for (int position = 0; position < this.parameters.length; position++) {
            int symbol = sequence.discreteVal(position + start);
            score += this.parameters[position][symbol];
            indices.add(offset + symbol);
            partialDer.add(1.0d);
            offset += this.parameters[position].length;
        }
        return score;
    }

    /** Returns the total number of parameters, i.e. {@code length * |alphabet|}. */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public int getNumberOfParameters() {
        int count = 0;
        for (double[] row : this.parameters) {
            count += row.length;
        }
        return count;
    }

    /** Returns a flat copy of the parameters in linear order {@code l * |alphabet| + a}. */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public double[] getCurrentParameterValues() throws Exception {
        double[] values = new double[getNumberOfParameters()];
        int offset = 0;
        for (double[] row : this.parameters) {
            System.arraycopy(row, 0, values, offset, row.length);
            offset += row.length;
        }
        return values;
    }

    /**
     * Copies the parameters from {@code params}, starting at index {@code start}, in linear order,
     * and invalidates the cached normalization constant.
     */
    @Override // de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore
    public void setParameters(double[] params, int start) {
        int offset = start;
        for (double[] row : this.parameters) {
            System.arraycopy(params, offset, row, 0, row.length);
            offset += row.length;
        }
        this.norm = null;
    }

    @Override // de.jstacs.sequenceScores.SequenceScore
    public String getInstanceName() {
        return "Position weight matrix";
    }

    @Override // de.jstacs.sequenceScores.SequenceScore
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Serializes alphabet, length, parameters, initialization state and ESS, wrapped in a {@code PWM} tag.
     */
    @Override // de.jstacs.Storable
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(xml, Integer.valueOf(this.length), DispomParameterSet.LENGTH);
        XMLParser.appendObjectWithTags(xml, this.parameters, "parameters");
        XMLParser.appendObjectWithTags(xml, Boolean.valueOf(this.isInitialized), "isInitialized");
        XMLParser.appendObjectWithTags(xml, Double.valueOf(this.ess), DimontParameterSet.ESS);
        XMLParser.addTags(xml, "PWM");
        return xml;
    }

    /**
     * Restores the state written by {@link #toXML()} and invalidates the cached normalization constant.
     */
    @Override // de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer content = XMLParser.extractForTag(xml, "PWM");
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(content, "alphabets");
        this.length = ((Integer) XMLParser.extractObjectForTags(content, DispomParameterSet.LENGTH, Integer.TYPE)).intValue();
        this.parameters = (double[][]) XMLParser.extractObjectForTags(content, "parameters");
        this.isInitialized = ((Boolean) XMLParser.extractObjectForTags(content, "isInitialized", Boolean.TYPE)).booleanValue();
        this.ess = ((Double) XMLParser.extractObjectForTags(content, DimontParameterSet.ESS, Double.TYPE)).doubleValue();
        // Parameters were replaced, so any previously cached log Z is stale.
        this.norm = null;
    }

    /**
     * Renders one line per position: the position index followed by the tab-separated normalized
     * probabilities {@code exp(parameters[l][a] - logSumExp(parameters[l]))}.
     */
    @Override // de.jstacs.sequenceScores.SequenceScore
    public String toString(NumberFormat numberFormat) {
        StringBuilder out = new StringBuilder();
        out.append(TagValueParser.EMPTY_LINE_EOR);
        for (int position = 0; position < this.parameters.length; position++) {
            double[] row = this.parameters[position];
            double logSum = Normalisation.getLogSum(row);
            out.append(position);
            for (double value : row) {
                out.append('\t').append(numberFormat.format(Math.exp(value - logSum)));
            }
            out.append(AbstractFormatter.DEFAULT_ROW_SEPARATOR);
        }
        return out.toString();
    }
}
