package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.motifDiscovery.MotifDiscoverer;
import de.jstacs.motifDiscovery.MutableMotifDiscoverer;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.VariableLengthDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.Arrays;
import javax.naming.OperationNotSupportedException;
import org.biojavax.bio.seq.Position;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/mixture/MixtureDiffSM.class */
/**
 * A mixture of {@link DifferentiableStatisticalModel} components.
 *
 * <p>For a sequence {@code x}, component {@code c} contributes
 * {@code logHiddenPotential[c] + score_c(x)}. The mixture log-score is the log-sum of these
 * contributions (see {@link #getLogScoreAndPartialDerivation}).
 *
 * <p>Motifs of components that implement {@link MotifDiscoverer} are exposed through one
 * global motif index space. The motifs of component {@code c} occupy the global indices
 * {@code motifsRef[c]} (inclusive) to {@code motifsRef[c + 1]} (exclusive).
 */
public class MixtureDiffSM extends AbstractMixtureDiffSM implements MutableMotifDiscoverer {

    /**
     * Number of rounds performed by {@link #initializeUsingPlugIn}: one random-weight round
     * followed by rounds that reweight each sequence by its normalised component scores.
     */
    private static final int PLUG_IN_ROUNDS = 3;

    /**
     * Prefix sums of the per-component motif counts.
     *
     * <p>{@code motifsRef[0] == 0}, and {@code motifsRef[c + 1] - motifsRef[c]} is the number
     * of motifs of component {@code c}. That count is 0 for components that are not
     * {@link MotifDiscoverer}s.
     */
    private int[] motifsRef;

    /**
     * Creates a mixture of the given components.
     *
     * @param i passed to the super constructor as its second argument (NOTE(review):
     *          presumably the number of starts — confirm in {@code AbstractMixtureDiffSM})
     * @param z passed to the super constructor as its fifth argument (NOTE(review):
     *          presumably whether plug-in initialisation is used — confirm)
     * @param differentiableStatisticalModelArr the mixture components. The first component
     *          defines the mixture length. Components of length 0 are accepted for any
     *          mixture length.
     * @throws IllegalArgumentException if no component is given, if a component of non-zero
     *          length does not match the mixture length, or if a component's
     *          AlphabetContainer is not consistent with the mixture's
     * @throws CloneNotSupportedException propagated from the super constructor
     */
    public MixtureDiffSM(int i, boolean z, DifferentiableStatisticalModel... differentiableStatisticalModelArr) throws CloneNotSupportedException {
        super(lengthOfFirstComponent(differentiableStatisticalModelArr), i, differentiableStatisticalModelArr.length, true, z, differentiableStatisticalModelArr);
        for (int c = 0; c < differentiableStatisticalModelArr.length; c++) {
            int componentLength = differentiableStatisticalModelArr[c].getLength();
            if (componentLength != 0 && this.length != componentLength) {
                throw new IllegalArgumentException("The length of component " + c + " is " + componentLength + " but should be " + this.length + ".");
            }
            if (!this.alphabets.checkConsistency(differentiableStatisticalModelArr[c].getAlphabetContainer())) {
                throw new IllegalArgumentException("The AlphabetContainer of component " + c + " is not suitable.");
            }
        }
        computeLogGammaSum();
        init();
    }

    /**
     * Returns the length of the first component.
     *
     * <p>This runs inside the {@code super(...)} call, so an empty or missing component array
     * yields a descriptive exception instead of an {@code ArrayIndexOutOfBoundsException} or
     * {@code NullPointerException}.
     */
    private static int lengthOfFirstComponent(DifferentiableStatisticalModel[] components) {
        if (components == null || components.length == 0) {
            throw new IllegalArgumentException("At least one component is required.");
        }
        return components[0].getLength();
    }

    /** Rebuilds {@link #motifsRef} from the current components. */
    private void init() {
        this.motifsRef = new int[this.function.length + 1];
        for (int c = 0; c < this.function.length; c++) {
            int motifs = 0;
            if (this.function[c] instanceof MotifDiscoverer) {
                motifs = ((MotifDiscoverer) this.function[c]).getNumberOfMotifs();
            }
            this.motifsRef[c + 1] = this.motifsRef[c] + motifs;
        }
    }

    /**
     * Re-creates a mixture from its XML representation.
     *
     * @param stringBuffer the XML representation (parsed by the super class)
     * @throws NonParsableException if the representation cannot be parsed
     */
    public MixtureDiffSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
        init();
    }

    /**
     * Returns a deep clone as produced by the super class, with the motif index table rebuilt.
     */
    @Override
    /* renamed from: clone */
    public MixtureDiffSM mo60clone() throws CloneNotSupportedException {
        MixtureDiffSM copy = (MixtureDiffSM) super.mo60clone();
        copy.init();
        return copy;
    }

    /** Returns the log normalisation constant of component {@code i}. */
    @Override
    protected double getLogNormalizationConstantForComponent(int i) {
        return this.function[i].getLogNormalizationConstant();
    }

    /**
     * Returns the log partial normalisation constant for parameter {@code i}.
     *
     * <p>Returns negative infinity if the model is normalised. Otherwise:
     * <ul>
     *   <li>for hidden parameters, the value comes from {@code partNorm};</li>
     *   <li>for component parameters, it is the component's hidden log-potential plus the
     *       component's own partial normalisation constant.</li>
     * </ul>
     *
     * @param i the global parameter index
     * @return the log partial normalisation constant
     * @throws Exception propagated from the component
     */
    @Override
    public double getLogPartialNormalizationConstant(int i) throws Exception {
        if (isNormalized()) {
            return Double.NEGATIVE_INFINITY;
        }
        if (Double.isNaN(this.norm)) {
            precomputeNorm();
        }
        int[] indices = getIndices(i);
        int component = indices[0];
        int local = indices[1];
        if (component == this.function.length) {
            // index points into the hidden (mixture-weight) parameters
            return this.partNorm[local];
        }
        return this.logHiddenPotential[component] + this.function[component].getLogPartialNormalizationConstant(local);
    }

    /** Returns the equivalent sample size of component {@code i}, used as hyperparameter of its weight. */
    @Override
    public double getHyperparameterForHiddenParameter(int i) {
        return this.function[i].getESS();
    }

    /** Returns the sum of the equivalent sample sizes of all components. */
    @Override
    public double getESS() {
        double ess = 0.0d;
        for (int c = 0; c < this.function.length; c++) {
            ess += this.function[c].getESS();
        }
        return ess;
    }

    /**
     * Splits each sequence weight randomly across the components and sets the hidden parameters
     * from the resulting per-component weight sums.
     *
     * <p>For every sequence, a vector is drawn from a Dirichlet distribution. Its parameters are
     * the per-component hyperparameters, or all 1 if the total ESS is 0.
     *
     * @param sequenceWeights the sequence weights, or {@code null} for weight 1 everywhere
     * @param numberOfSequences the number of sequences
     * @return {@code [component][sequence]} weights
     */
    private double[][] getRandomWeights(double[] sequenceWeights, int numberOfSequences) {
        Arrays.fill(this.hiddenParameter, 0.0d);
        double[][] componentWeights = new double[this.function.length][numberOfSequences];
        double[] alpha = new double[getNumberOfComponents()];
        if (getESS() == 0.0d) {
            Arrays.fill(alpha, 1.0d);
        } else {
            for (int c = 0; c < alpha.length; c++) {
                alpha[c] = getHyperparameterForHiddenParameter(c);
            }
        }
        DirichletMRGParams params = new DirichletMRGParams(alpha);
        double[] draw = new double[alpha.length];
        for (int s = 0; s < numberOfSequences; s++) {
            DirichletMRG.DEFAULT_INSTANCE.generate(draw, 0, draw.length, params);
            double weight = sequenceWeights == null ? 1.0d : sequenceWeights[s];
            for (int c = 0; c < draw.length; c++) {
                componentWeights[c][s] = weight * draw[c];
                this.hiddenParameter[c] += componentWeights[c][s];
            }
        }
        computeHiddenParameter(this.hiddenParameter, true);
        return componentWeights;
    }

    /**
     * Initialises the components with the plug-in approach.
     *
     * <p>The first round draws random component weights (see {@link #getRandomWeights}). The
     * remaining rounds reweight each sequence by its normalised component scores times its
     * original weight. After every round, each component is initialised on data set
     * {@code i} with its weights.
     *
     * <p>{@code dArr[i]} is temporarily replaced and is always restored, even if a component
     * throws.
     *
     * @param i index of the data set used for initialisation
     * @param z passed through to the components' {@code initializeFunction}
     * @param dataSetArr the data sets
     * @param dArr the sequence weights per data set. May be {@code null}, and {@code dArr[i]}
     *          may be {@code null}; both mean weight 1.
     * @throws Exception propagated from the components
     */
    @Override
    protected void initializeUsingPlugIn(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        if (dArr == null) {
            dArr = new double[dataSetArr.length][];
        }
        DataSet data = dataSetArr[i];
        int numberOfSequences = data.getNumberOfElements();
        double[] sequenceWeights = dArr[i];
        try {
            double[][] componentWeights = null;
            for (int round = 0; round < PLUG_IN_ROUNDS; round++) {
                if (round == 0) {
                    componentWeights = getRandomWeights(sequenceWeights, numberOfSequences);
                } else {
                    for (int s = 0; s < numberOfSequences; s++) {
                        fillComponentScores(data.getElementAt(s), 0);
                        Normalisation.logSumNormalisation(this.componentScore);
                        double weight = sequenceWeights == null ? 1.0d : sequenceWeights[s];
                        for (int c = 0; c < this.function.length; c++) {
                            componentWeights[c][s] = this.componentScore[c] * weight;
                        }
                    }
                }
                for (int c = 0; c < this.function.length; c++) {
                    dArr[i] = componentWeights[c];
                    this.function[c].initializeFunction(i, z, dataSetArr, dArr);
                }
            }
        } finally {
            dArr[i] = sequenceWeights;
        }
    }

    /**
     * Redraws random component weights and forwards them to every component that is a
     * {@link MutableMotifDiscoverer}.
     *
     * <p>{@code dArr[i]} is temporarily replaced and is always restored.
     *
     * @param i index of the data set
     * @param dataSetArr the data sets
     * @param dArr the sequence weights per data set. May be {@code null}, and {@code dArr[i]}
     *          may be {@code null}.
     * @throws Exception propagated from the components
     */
    @Override
    public void adjustHiddenParameters(int i, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        if (dArr == null) {
            dArr = new double[dataSetArr.length][];
        }
        double[] sequenceWeights = dArr[i];
        double[][] componentWeights = getRandomWeights(sequenceWeights, dataSetArr[i].getNumberOfElements());
        try {
            for (int c = 0; c < this.function.length; c++) {
                dArr[i] = componentWeights[c];
                if (this.function[c] instanceof MutableMotifDiscoverer) {
                    ((MutableMotifDiscoverer) this.function[c]).adjustHiddenParameters(i, dataSetArr, dArr);
                }
            }
        } finally {
            dArr[i] = sequenceWeights;
        }
    }

    /** Returns {@code mixture(name_0, name_1, ...)} built from the component instance names. */
    @Override
    public String getInstanceName() {
        StringBuilder name = new StringBuilder("mixture(").append(this.function[0].getInstanceName());
        for (int c = 1; c < this.function.length; c++) {
            name.append(", ").append(this.function[c].getInstanceName());
        }
        return name.append(")").toString();
    }

    /**
     * Returns the last position scored for a sequence starting at {@code start}.
     *
     * <p>The end is {@code start + length - 1} for a fixed mixture length, or the end of the
     * sequence when the mixture length is 0.
     */
    private int getEndPosition(Sequence sequence, int start) {
        return this.length != 0 ? start + this.length - 1 : sequence.getLength() - 1;
    }

    /**
     * Fills {@code componentScore[c]} with {@code logHiddenPotential[c]} plus the log-score of
     * component {@code c} on {@code sequence} starting at {@code i}.
     *
     * <p>{@link VariableLengthDiffSM} components are scored explicitly up to
     * {@link #getEndPosition}.
     */
    @Override
    protected void fillComponentScores(Sequence sequence, int i) {
        for (int c = 0; c < this.function.length; c++) {
            double score;
            if (this.function[c] instanceof VariableLengthDiffSM) {
                score = ((VariableLengthDiffSM) this.function[c]).getLogScoreFor(sequence, i, getEndPosition(sequence, i));
            } else {
                score = this.function[c].getLogScoreFor(sequence, i);
            }
            this.componentScore[c] = this.logHiddenPotential[c] + score;
        }
    }

    /**
     * Computes the mixture log-score and appends its partial derivatives.
     *
     * <p>Component derivatives are scaled by the normalised component scores. For hidden
     * parameters, the derivative is the normalised score minus the hidden potential if the
     * model is normalised, and the normalised score alone otherwise.
     *
     * @param sequence the sequence
     * @param i the start position
     * @param intList receives the global parameter indices
     * @param doubleList receives the corresponding partial derivatives
     * @return the log-sum of the component contributions
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        int last = this.paramRef.length - 1;
        int numberOfHiddenParameters = this.paramRef[last] - this.paramRef[last - 1];
        for (int c = 0; c < this.function.length; c++) {
            this.iList[c].clear();
            this.dList[c].clear();
            double score;
            if (this.function[c] instanceof VariableLengthDiffSM) {
                score = ((VariableLengthDiffSM) this.function[c]).getLogScoreAndPartialDerivation(sequence, i, getEndPosition(sequence, i), this.iList[c], this.dList[c]);
            } else {
                score = this.function[c].getLogScoreAndPartialDerivation(sequence, i, this.iList[c], this.dList[c]);
            }
            this.componentScore[c] = this.logHiddenPotential[c] + score;
        }
        // normalises componentScore[0..function.length) in place and returns the log-sum
        double logScore = Normalisation.logSumNormalisation(this.componentScore, 0, this.function.length, this.componentScore, 0);
        for (int c = 0; c < this.function.length; c++) {
            int offset = this.paramRef[c];
            for (int k = 0; k < this.iList[c].length(); k++) {
                intList.add(offset + this.iList[c].get(k));
                doubleList.add(this.componentScore[c] * this.dList[c].get(k));
            }
        }
        int hiddenOffset = this.paramRef[this.function.length];
        boolean normalized = isNormalized();
        for (int k = 0; k < numberOfHiddenParameters; k++) {
            intList.add(hiddenOffset + k);
            doubleList.add(this.componentScore[k] - (normalized ? this.hiddenPotential[k] : 0.0d));
        }
        return logScore;
    }

    /**
     * Returns a textual representation listing, for every component, its weight
     * {@code p(c)} followed by the component's own representation.
     *
     * @param numberFormat the format used for the weights and passed to the components
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        if (Double.isNaN(this.norm)) {
            precomputeNorm();
        }
        boolean normalized = isNormalized();
        StringBuilder out = new StringBuilder(this.function.length * 1000);
        for (int c = 0; c < this.function.length; c++) {
            double weight = normalized ? this.hiddenPotential[c] : Math.exp(this.partNorm[c] - this.norm);
            out.append("p(").append(c).append(") = ").append(numberFormat.format(weight)).append("\n")
                    .append(this.function[c].toString(numberFormat)).append("\n");
        }
        return out.toString();
    }

    /**
     * Returns the component that owns global motif index {@code i}.
     *
     * @throws IndexOutOfBoundsException if {@code i} is not a valid global motif index
     */
    private int getComponentFor(int i) {
        int total = this.motifsRef[this.function.length];
        if (i < 0 || i >= total) {
            throw new IndexOutOfBoundsException("Motif index " + i + " is out of range [0, " + total + ").");
        }
        // smallest c with i < motifsRef[c + 1]; components without motifs are skipped
        int c = 0;
        while (i >= this.motifsRef[c + 1]) {
            c++;
        }
        return c;
    }

    /**
     * Initialises global motif {@code i} in its owning component if that component is a
     * {@link MutableMotifDiscoverer}. Otherwise prints a warning.
     */
    @Override
    public void initializeMotif(int i, DataSet dataSet, double[] dArr) throws Exception {
        int component = getComponentFor(i);
        if (this.function[component] instanceof MutableMotifDiscoverer) {
            ((MutableMotifDiscoverer) this.function[component]).initializeMotif(i - this.motifsRef[component], dataSet, dArr);
        } else {
            System.out.println("WARNING: Not possible!");
        }
    }

    /**
     * Randomly initialises global motif {@code i} in its owning component if that component
     * is a {@link MutableMotifDiscoverer}. Otherwise prints a warning.
     */
    @Override
    public void initializeMotifRandomly(int i) throws Exception {
        int component = getComponentFor(i);
        if (this.function[component] instanceof MutableMotifDiscoverer) {
            ((MutableMotifDiscoverer) this.function[component]).initializeMotifRandomly(i - this.motifsRef[component]);
        } else {
            System.out.println("WARNING: Not possible!");
        }
    }

    /**
     * Delegates the modification of global motif {@code i} to its owning component.
     *
     * <p>If the modification succeeded, the super class's {@code init(freeParams)} is re-run.
     *
     * @return {@code false} if the component is not mutable, otherwise the component's result
     */
    @Override
    public boolean modifyMotif(int i, int i2, int i3) throws Exception {
        int component = getComponentFor(i);
        if (!(this.function[component] instanceof MutableMotifDiscoverer)) {
            return false;
        }
        boolean modified = ((MutableMotifDiscoverer) this.function[component]).modifyMotif(i - this.motifsRef[component], i2, i3);
        if (modified) {
            init(this.freeParams);
        }
        return modified;
    }

    /**
     * Maps local motif {@code i2} of component {@code i} to its global motif index.
     *
     * @throws IndexOutOfBoundsException if {@code i2} is negative or not below the number of
     *          motifs of component {@code i}
     */
    @Override
    public int getGlobalIndexOfMotifInComponent(int i, int i2) {
        int global = this.motifsRef[i] + i2;
        if (i2 < 0 || global >= this.motifsRef[i + 1]) {
            throw new IndexOutOfBoundsException("Component " + i + " has only " + (this.motifsRef[i + 1] - this.motifsRef[i]) + " motifs.");
        }
        return global;
    }

    @Override
    public int getIndexOfMaximalComponentFor(Sequence sequence) throws Exception {
        return getIndexOfMaximalComponentFor(sequence, 0);
    }

    /** Returns the length of global motif {@code i} as reported by its owning component. */
    @Override
    public int getMotifLength(int i) {
        int component = getComponentFor(i);
        return ((MotifDiscoverer) this.function[component]).getMotifLength(i - this.motifsRef[component]);
    }

    /** Returns the total number of motifs over all components. */
    @Override
    public int getNumberOfMotifs() {
        return this.motifsRef[this.function.length];
    }

    /** Returns the number of motifs of component {@code i}. */
    @Override
    public int getNumberOfMotifsInComponent(int i) {
        return this.motifsRef[i + 1] - this.motifsRef[i];
    }

    /**
     * Returns the profile of component {@code i} for motif {@code i2}.
     *
     * <p>The profile is the position-wise log-sum over all sub-components of the component's
     * {@link MotifDiscoverer} that contain motif {@code i2}. Only
     * {@link MotifDiscoverer.KindOfProfile#UNNORMALIZED_JOINT} is supported.
     *
     * <p>NOTE(review): {@code i2} is matched against the sub-model's global motif indices, not
     * this mixture's; confirm that this is the intended contract.
     *
     * @throws OperationNotSupportedException for any other profile kind
     * @throws IllegalArgumentException if component {@code i} is not a {@link MotifDiscoverer}
     *          or no sub-component contains motif {@code i2}
     */
    @Override
    public double[] getProfileOfScoresFor(int i, int i2, Sequence sequence, int i3, MotifDiscoverer.KindOfProfile kindOfProfile) throws Exception {
        if (kindOfProfile != MotifDiscoverer.KindOfProfile.UNNORMALIZED_JOINT) {
            throw new OperationNotSupportedException("Currently it is only allowed to used KindOfProfile.UNNORMALIZED_JOINT");
        }
        if (!(this.function[i] instanceof MotifDiscoverer)) {
            throw new IllegalArgumentException("Component " + i + " is not a MotifDiscoverer.");
        }
        MotifDiscoverer discoverer = (MotifDiscoverer) this.function[i];
        double[] profile = null;
        int subComponents = discoverer.getNumberOfComponents();
        for (int sub = 0; sub < subComponents; sub++) {
            if (!containsMotif(discoverer, sub, i2)) {
                continue;
            }
            double[] subProfile = discoverer.getProfileOfScoresFor(sub, i2, sequence, i3, kindOfProfile);
            if (profile == null) {
                // copy: the array is accumulated in place below and may be owned by the sub-model
                profile = subProfile.clone();
            } else {
                for (int p = 0; p < profile.length; p++) {
                    profile[p] = Normalisation.getLogSum(profile[p], subProfile[p]);
                }
            }
        }
        if (profile == null) {
            throw new IllegalArgumentException("Motif " + i2 + " is not contained in component " + i + ".");
        }
        return profile;
    }

    /** Returns whether sub-component {@code sub} of {@code discoverer} contains global motif {@code motif}. */
    private static boolean containsMotif(MotifDiscoverer discoverer, int sub, int motif) {
        int motifs = discoverer.getNumberOfMotifsInComponent(sub);
        for (int m = 0; m < motifs; m++) {
            if (discoverer.getGlobalIndexOfMotifInComponent(sub, m) == motif) {
                return true;
            }
        }
        return false;
    }

    /** Always returns equal probabilities {0.5, 0.5} for both strands. */
    @Override
    public double[] getStrandProbabilitiesFor(int i, int i2, Sequence sequence, int i3) throws Exception {
        return new double[]{0.5d, 0.5d};
    }
}
