package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture;

import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.ComplementableDiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.StrandedLocatedSequenceAnnotationWithLength;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.Mutable;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.NormalizedDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.Arrays;

/* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/mixture/StrandDiffSM.class */
public class StrandDiffSM extends AbstractMixtureDiffSM implements Mutable {
    private InitMethod initMethod;
    private double forwardPartOfESS;
    private static /* synthetic */ int[] $SWITCH_TABLE$de$jstacs$sequenceScores$statisticalModels$differentiable$mixture$StrandDiffSM$InitMethod;

    /* loaded from: input_file:de/jstacs/sequenceScores/statisticalModels/differentiable/mixture/StrandDiffSM$InitMethod.class */
    /**
     * Selects which strand orientation of the data is used when the model is
     * initialized by {@code initializeUsingPlugIn}.
     */
    public enum InitMethod {
        INIT_FORWARD_STRAND,
        INIT_BACKWARD_STRAND,
        INIT_BOTH_STRANDS;

        /**
         * Returns all constants of this enum in declaration order.
         *
         * @return a newly allocated array; modifying it does not affect later calls
         */
        public static InitMethod[] valuesCustom() {
            // values() already hands out a private array; clone() keeps the extra copy
            // the original made.
            return values().clone();
        }
    }

    /**
     * Shared constructor used by the public constructors.
     *
     * @param differentiableStatisticalModel the model scored on both strands; its length is
     *        handed to the superclass and its alphabet must be reverse complementable
     * @param i passed through to the superclass constructor
     * @param z passed through to the superclass constructor
     *        (NOTE(review): presumably "optimize hidden parameters" - confirm against the superclass)
     * @param z2 passed through to the superclass constructor
     *        (NOTE(review): presumably "use plug-in initialization" - confirm against the superclass)
     * @param d the part of the ESS assigned to the forward strand, in [0,1]
     * @param initMethod the strand initialization strategy
     * @throws CloneNotSupportedException declared for the superclass constructor
     * @throws WrongAlphabetException if the alphabet of the model is not reverse complementable
     * @throws IllegalArgumentException if {@code d} is NaN or outside [0,1]
     */
    private StrandDiffSM(DifferentiableStatisticalModel differentiableStatisticalModel, int i, boolean z, boolean z2, double d, InitMethod initMethod) throws CloneNotSupportedException, WrongAlphabetException {
        super(differentiableStatisticalModel.getLength(), i, 2, z, z2, differentiableStatisticalModel);
        if (!differentiableStatisticalModel.getAlphabetContainer().isReverseComplementable()) {
            throw new WrongAlphabetException("The given AlphabetContainer can not be used for building a reverse complement.");
        }
        // Negated range test on purpose: every comparison with NaN is false, so
        // "d < 0 || d > 1" would silently accept NaN.
        if (!(d >= 0.0d && d <= 1.0d)) {
            throw new IllegalArgumentException("The part of the ESS for the forward strand has to be in [0,1].");
        }
        this.forwardPartOfESS = d;
        this.initMethod = initMethod;
        computeLogGammaSum();
    }

    /**
     * Creates a strand mixture for the given model; delegates to the private
     * constructor with {@code true} as its third argument.
     *
     * @param differentiableStatisticalModel the model scored on both strands
     * @param d the part of the ESS assigned to the forward strand, in [0,1]
     * @param i passed through to the superclass constructor
     * @param z passed through to the superclass constructor
     * @param initMethod the strand initialization strategy
     * @throws CloneNotSupportedException declared for the superclass constructor
     * @throws WrongAlphabetException if the alphabet of the model is not reverse complementable
     */
    public StrandDiffSM(DifferentiableStatisticalModel differentiableStatisticalModel, double d, int i, boolean z, InitMethod initMethod) throws CloneNotSupportedException, WrongAlphabetException {
        this(differentiableStatisticalModel, i, true, z, d, initMethod);
    }

    /**
     * Creates a strand mixture with a given forward strand probability; delegates to
     * the private constructor with {@code false} as its third argument and then stores
     * the probability via {@link #setForwardProb(double)}.
     *
     * @param differentiableStatisticalModel the model scored on both strands
     * @param i passed through to the superclass constructor
     * @param z passed through to the superclass constructor
     * @param initMethod the strand initialization strategy
     * @param d the probability of the forward strand, in [0,1]; it is also used as the
     *        forward part of the ESS
     * @throws CloneNotSupportedException declared for the superclass constructor
     * @throws WrongAlphabetException if the alphabet of the model is not reverse complementable
     * @throws IllegalArgumentException if {@code d} is NaN or outside [0,1]
     */
    public StrandDiffSM(DifferentiableStatisticalModel differentiableStatisticalModel, int i, boolean z, InitMethod initMethod, double d) throws CloneNotSupportedException, WrongAlphabetException {
        // The probability is validated while the arguments of this(...) are evaluated.
        // A check placed after this(...) can never fire, because the delegated
        // constructor rejects the same range first with an ESS-related message.
        this(differentiableStatisticalModel, i, false, z, requireForwardProbability(d), initMethod);
        setForwardProb(d);
    }

    /**
     * Checks that the given value is a probability.
     *
     * @param d the value to check
     * @return {@code d}, unchanged
     * @throws IllegalArgumentException if {@code d} is NaN or outside [0,1]
     */
    private static double requireForwardProbability(double d) {
        if (!(d >= 0.0d && d <= 1.0d)) {
            throw new IllegalArgumentException("The value for forward is no probability.");
        }
        return d;
    }

    /**
     * Stores {@code log(d)} and {@code log(1 - d)} as the two hidden parameters and
     * hands the updated array to {@code setHiddenParameters}.
     *
     * @param d the value whose logarithm becomes hidden parameter 0
     */
    protected void setForwardProb(double d) {
        final double logForward = Math.log(d);
        final double logBackward = Math.log(1.0d - d);
        this.hiddenParameter[0] = logForward;
        this.hiddenParameter[1] = logBackward;
        setHiddenParameters(this.hiddenParameter, 0);
    }

    /**
     * Creates an instance from its XML representation; all parsing is delegated to
     * the superclass constructor.
     *
     * @param stringBuffer the XML representation
     * @throws NonParsableException if the superclass constructor cannot parse the representation
     */
    public StrandDiffSM(StringBuffer stringBuffer) throws NonParsableException {
        super(stringBuffer);
    }

    /**
     * Returns the logarithm of the normalization constant of the internal model.
     * The component index is ignored, since both components use {@code function[0]}.
     *
     * @param i the component index (unused)
     * @return the log normalization constant of {@code function[0]}
     */
    @Override
    protected double getLogNormalizationConstantForComponent(int i) {
        final double logNorm = this.function[0].getLogNormalizationConstant();
        return logNorm;
    }

    /**
     * Returns the logarithm of the partial normalization constant for one parameter.
     *
     * @param i the global parameter index, resolved via {@code getIndices}
     * @return negative infinity if the model is normalized; otherwise the sum of a hidden
     *         potential term and a term of the internal model, chosen by the resolved index
     * @throws Exception if the internal model or the norm precomputation fails
     */
    @Override
    public double getLogPartialNormalizationConstant(int i) throws Exception {
        if (isNormalized()) {
            return Double.NEGATIVE_INFINITY;
        }
        if (Double.isNaN(this.norm)) {
            // The cached norm is marked stale by NaN.
            precomputeNorm();
        }
        final int[] location = getIndices(i);
        final int owner = location[0];
        final int local = location[1];
        if (owner == 1) {
            // First index 1 addresses logHiddenPotential.
            return this.logHiddenPotential[local] + this.function[0].getLogNormalizationConstant();
        }
        return Normalisation.getLogSum(this.logHiddenPotential) + this.function[owner].getLogPartialNormalizationConstant(local);
    }

    /**
     * Returns the hyperparameter of a hidden parameter as its share of the ESS of the
     * internal model.
     *
     * @param i 0 for the forward part, 1 for the remaining part
     * @return {@code forwardPartOfESS * ESS} for index 0, {@code (1 - forwardPartOfESS) * ESS} for index 1
     * @throws IndexOutOfBoundsException for any other index
     */
    @Override
    public double getHyperparameterForHiddenParameter(int i) {
        if (i == 0) {
            return this.forwardPartOfESS * this.function[0].getESS();
        }
        if (i == 1) {
            return (1.0d - this.forwardPartOfESS) * this.function[0].getESS();
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Returns the share of the first hidden potential in the sum of both hidden potentials.
     *
     * @return {@code hiddenPotential[0] / (hiddenPotential[0] + hiddenPotential[1])}
     */
    public double getForwardProbability() {
        final double forward = this.hiddenPotential[0];
        final double backward = this.hiddenPotential[1];
        return forward / (forward + backward);
    }

    /**
     * Returns the equivalent sample size, which is the ESS of the internal model.
     *
     * @return the ESS of {@code function[0]}
     */
    @Override
    public double getESS() {
        final double ess = this.function[0].getESS();
        return ess;
    }

    /**
     * Initializes the internal model from the data set at index {@code i}, oriented
     * according to the configured {@link InitMethod}, and afterwards recomputes the
     * hidden parameters if they are optimized.
     *
     * <ul>
     * <li>{@code INIT_BACKWARD_STRAND}: the internal model is initialized on the reverse
     * complement of every sequence.</li>
     * <li>{@code INIT_BOTH_STRANDS}: every sequence is first assigned to a strand at
     * random, the internal model is initialized on that data, then every sequence is
     * reassigned to the strand given by {@code getIndexOfMaximalComponentFor} and the
     * internal model is initialized again.</li>
     * <li>otherwise: the data is used as given.</li>
     * </ul>
     *
     * The entry {@code dataSetArr[i]} is replaced temporarily and is restored before this
     * method returns, also when initialization fails.
     *
     * @param i the index of the data set (and of the weights) to use
     * @param z passed through to {@code initializeFunction} of the internal model
     * @param dataSetArr the data sets; entry {@code i} may be {@code null}
     * @param dArr the sequence weights; may be {@code null}, as may entry {@code i},
     *        in which case every sequence has weight 1
     * @throws Exception if reverse complementing or initializing the internal model fails
     */
    @Override
    protected void initializeUsingPlugIn(int i, boolean z, DataSet[] dataSetArr, double[][] dArr) throws Exception {
        final DataSet originalData = dataSetArr[i];
        // Statistic handed to computeHiddenParameter: index 0 forward, index 1 backward.
        final double[] strandStatistic = new double[2];
        try {
            switch (this.initMethod) {
                case INIT_BACKWARD_STRAND: {
                    // A missing data set is tolerated here exactly as in the both-strands case.
                    if (originalData != null) {
                        final Sequence[] reversed = new Sequence[originalData.getNumberOfElements()];
                        for (int n = 0; n < reversed.length; n++) {
                            reversed[n] = originalData.getElementAt(n).reverseComplement();
                        }
                        dataSetArr[i] = new DataSet("backward strand", reversed);
                    }
                    strandStatistic[0] = this.forwardPartOfESS;
                    strandStatistic[1] = 1.0d - this.forwardPartOfESS;
                    break;
                }
                case INIT_BOTH_STRANDS: {
                    // The forward probability is determined before the data is inspected, so
                    // the random draw happens even if there is no data.
                    final double forwardProbability;
                    if (this.optimizeHidden) {
                        final double[] alpha = new double[2];
                        if (getESS() == 0.0d) {
                            alpha[0] = 1.0d;
                            alpha[1] = 1.0d;
                        } else {
                            alpha[0] = getHyperparameterForHiddenParameter(0);
                            alpha[1] = getHyperparameterForHiddenParameter(1);
                        }
                        forwardProbability = DirichletMRG.DEFAULT_INSTANCE.generate(2, new DirichletMRGParams(alpha))[0];
                    } else {
                        forwardProbability = this.hiddenPotential[0] / (this.hiddenPotential[0] + this.hiddenPotential[1]);
                    }
                    if (originalData != null) {
                        final boolean weighted = dArr != null && dArr[i] != null;
                        final Sequence[] oriented = new Sequence[originalData.getNumberOfElements()];
                        double weight = 1.0d;

                        // First pass: random strand assignment.
                        for (int n = 0; n < oriented.length; n++) {
                            if (weighted) {
                                weight = dArr[i][n];
                            }
                            if (r.nextDouble() < forwardProbability) {
                                strandStatistic[0] += weight;
                                oriented[n] = originalData.getElementAt(n);
                            } else {
                                strandStatistic[1] += weight;
                                oriented[n] = originalData.getElementAt(n).reverseComplement();
                            }
                        }
                        dataSetArr[i] = new DataSet("randomly strand scrambled", oriented);
                        this.function[0].initializeFunction(i, z, dataSetArr, dArr);
                        if (this.optimizeHidden) {
                            computeHiddenParameter(strandStatistic, true);
                        }

                        // Second pass: reassign every sequence to its best scoring strand.
                        Arrays.fill(strandStatistic, 0.0d);
                        for (int n = 0; n < oriented.length; n++) {
                            if (weighted) {
                                weight = dArr[i][n];
                            }
                            if (getIndexOfMaximalComponentFor(originalData.getElementAt(n), 0) == 0) {
                                strandStatistic[0] += weight;
                                oriented[n] = originalData.getElementAt(n);
                            } else {
                                strandStatistic[1] += weight;
                                oriented[n] = originalData.getElementAt(n).reverseComplement();
                            }
                        }
                        dataSetArr[i] = new DataSet("strand scrambled", oriented);
                    }
                    // Without data the entry stays null and the statistic stays at zero.
                    break;
                }
                default:
                    strandStatistic[0] = this.forwardPartOfESS;
                    strandStatistic[1] = 1.0d - this.forwardPartOfESS;
                    break;
            }
            this.function[0].initializeFunction(i, z, dataSetArr, dArr);
        } finally {
            // Never leave the caller's array pointing at a temporary data set.
            dataSetArr[i] = originalData;
        }
        if (this.optimizeHidden) {
            computeHiddenParameter(strandStatistic, true);
        }
    }

    /**
     * Builds the instance name from the name of the internal model.
     *
     * @return {@code "strand-mixture(<name>)"}; if the hidden parameters are not optimized,
     *         the hidden potentials are appended after the name, separated by a comma
     */
    @Override
    public String getInstanceName() {
        final StringBuilder name = new StringBuilder("strand-mixture(");
        name.append(this.function[0].getInstanceName());
        if (!this.optimizeHidden) {
            name.append(", ").append(Arrays.toString(this.hiddenPotential));
        }
        return name.append(")").toString();
    }

    /**
     * Fills {@code componentScore} with the log scores of both strands: index 0 holds the
     * score of the sequence as given, index 1 the score of its reverse complement, each
     * plus the corresponding log hidden potential.
     *
     * @param sequence the sequence to score
     * @param i the start position on the given sequence
     * @throws RuntimeException wrapping any exception raised while scoring the reverse
     *         complement, including a non-zero start for a model of length 0; the
     *         original exception is available as the cause
     */
    @Override
    protected void fillComponentScores(Sequence sequence, int i) {
        this.componentScore[0] = this.logHiddenPotential[0] + this.function[0].getLogScoreFor(sequence, i);
        try {
            final int reverseStart;
            if (this.length != 0) {
                // Mirror the window [i, i + length) onto the reverse complement.
                reverseStart = sequence.getLength() - i - this.length;
            } else {
                if (i != 0) {
                    throw new Exception("strand scoring for variable length function");
                }
                reverseStart = 0;
            }
            this.componentScore[1] = this.logHiddenPotential[1] + this.function[0].getLogScoreFor(sequence.reverseComplement(), reverseStart);
        } catch (Exception e) {
            // Same message as before, but the cause is chained instead of only copying
            // its stack trace, so the original exception stays inspectable.
            throw new RuntimeException(e.getClass().getName() + ": " + e.getMessage(), e);
        }
    }

    /**
     * Computes the log score of the sequence as the log sum over both strands and
     * appends the partial derivations to the given lists.
     *
     * The derivations of the internal model for each strand are weighted with the
     * normalized component score of that strand. For the hidden parameters (indices
     * {@code paramRef[1]} up to, but excluding, {@code paramRef[2]}) the normalized
     * component score is appended, reduced by the hidden potential if the model is
     * normalized.
     *
     * @param sequence the sequence to score
     * @param i the start position on the given sequence
     * @param intList receives the parameter indices
     * @param doubleList receives the partial derivations matching {@code intList}
     * @return the value returned by {@code Normalisation.logSumNormalisation} for the two component scores
     * @throws RuntimeException wrapping any exception raised from scoring the reverse
     *         complement onwards, including a non-zero start for a model of length 0;
     *         the original exception is available as the cause
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence sequence, int i, IntList intList, DoubleList doubleList) {
        this.iList[0].clear();
        this.dList[0].clear();
        this.componentScore[0] = this.logHiddenPotential[0] + this.function[0].getLogScoreAndPartialDerivation(sequence, i, this.iList[0], this.dList[0]);
        this.iList[1].clear();
        this.dList[1].clear();
        try {
            final int reverseStart;
            if (this.length != 0) {
                // Mirror the window [i, i + length) onto the reverse complement.
                reverseStart = sequence.getLength() - i - this.length;
            } else {
                if (i != 0) {
                    throw new Exception("strand scoring for variable length function");
                }
                reverseStart = 0;
            }
            this.componentScore[1] = this.logHiddenPotential[1] + this.function[0].getLogScoreAndPartialDerivation(sequence.reverseComplement(), reverseStart, this.iList[1], this.dList[1]);

            // Normalizes componentScore in place.
            final double logScore = Normalisation.logSumNormalisation(this.componentScore, 0, 2, this.componentScore, 0);
            for (int strand = 0; strand < this.logHiddenPotential.length; strand++) {
                final int entries = this.iList[strand].length();
                for (int k = 0; k < entries; k++) {
                    intList.add(this.iList[strand].get(k));
                    doubleList.add(this.componentScore[strand] * this.dList[strand].get(k));
                }
            }
            final int hiddenCount = this.paramRef[2] - this.paramRef[1];
            for (int h = 0; h < hiddenCount; h++) {
                intList.add(this.paramRef[1] + h);
                doubleList.add(this.componentScore[h] - (isNormalized() ? this.hiddenPotential[h] : 0.0d));
            }
            return logScore;
        } catch (Exception e) {
            // Same message as before, but the cause is chained instead of only copying
            // its stack trace, so the original exception stays inspectable.
            throw new RuntimeException(e.getClass().getName() + ": " + e.getMessage(), e);
        }
    }

    /**
     * Serializes the fields declared by this class.
     *
     * @return a buffer holding {@code forwardPartOfESS} and {@code initMethod}, each wrapped
     *         in a tag of the same name
     */
    @Override
    protected StringBuffer getFurtherInformation() {
        final StringBuffer xml = new StringBuffer(100);
        final Double forwardPart = Double.valueOf(this.forwardPartOfESS);
        XMLParser.appendObjectWithTags(xml, forwardPart, "forwardPartOfESS");
        XMLParser.appendObjectWithTags(xml, this.initMethod, "initMethod");
        return xml;
    }

    /**
     * Restores the fields declared by this class; counterpart of
     * {@link #getFurtherInformation()}.
     *
     * @param stringBuffer the buffer holding the tags {@code forwardPartOfESS} and {@code initMethod}
     * @throws NonParsableException if one of the tags cannot be extracted
     */
    @Override
    protected void extractFurtherInformation(StringBuffer stringBuffer) throws NonParsableException {
        final Double forwardPart = (Double) XMLParser.extractObjectForTags(stringBuffer, "forwardPartOfESS", Double.TYPE);
        this.forwardPartOfESS = forwardPart.doubleValue();
        final InitMethod method = (InitMethod) XMLParser.extractObjectForTags(stringBuffer, "initMethod", InitMethod.class);
        this.initMethod = method;
    }

    /**
     * Delegates to the superclass implementation without adding behavior.
     * The decompiler reports that this method was originally declared {@code protected}.
     *
     * @param z passed through to {@code super.init}
     */
    @Override
    public void init(boolean z) {
        super.init(z);
    }

    /**
     * Returns a textual representation: the normalized hidden potentials labelled
     * "forward" and "reverse", followed by the representation of the internal model.
     *
     * @param numberFormat the format used for all numbers
     * @return the textual representation
     */
    @Override
    public String toString(NumberFormat numberFormat) {
        final double total = this.hiddenPotential[0] + this.hiddenPotential[1];
        final StringBuilder text = new StringBuilder(1500);
        text.append("forward: ").append(numberFormat.format(this.hiddenPotential[0] / total)).append("\n");
        text.append("reverse: ").append(numberFormat.format(this.hiddenPotential[1] / total)).append("\n\n");
        text.append(this.function[0].toString(numberFormat));
        return text.toString();
    }

    /**
     * Forwards the modification to the internal model if that model is {@link Mutable}.
     * After a successful modification the length is taken over from the internal model,
     * {@code init} is rerun and the cached norm is invalidated.
     *
     * @param i passed through to the internal model
     * @param i2 passed through to the internal model
     * @return {@code true} if the internal model was modified, {@code false} otherwise
     */
    @Override
    public boolean modify(int i, int i2) {
        if (!(this.function[0] instanceof Mutable)) {
            return false;
        }
        final boolean modified = ((Mutable) this.function[0]).modify(i, i2);
        if (modified) {
            this.length = this.function[0].getLength();
            init(this.freeParams);
            this.norm = Double.NaN;
        }
        return modified;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v11, types: [double[][], double[][][]] */
    public static double[][][] getReverseComplementDistributions(ComplementableDiscreteAlphabet complementableDiscreteAlphabet, double[][][] dArr) {
        int length = (int) complementableDiscreteAlphabet.length();
        int length2 = dArr.length;
        int length3 = dArr[length2 - 1].length * length;
        ?? r0 = new double[length2];
        for (int i = 0; i < length2; i++) {
            r0[i] = new double[dArr[i].length][length];
        }
        int[] iArr = new int[length2];
        for (int i2 = 0; i2 < length3; i2++) {
            int i3 = i2;
            for (int i4 = length2 - 1; i4 >= 0; i4--) {
                iArr[i4] = i3 % length;
                i3 /= length;
            }
            int i5 = 0;
            double d = 1.0d;
            for (int i6 = 0; i6 < length2; i6++) {
                d *= dArr[i6][i5][iArr[i6]];
                i5 = (i5 * length) + iArr[i6];
            }
            for (int i7 = 0; i7 < length2; i7++) {
                iArr[i7] = complementableDiscreteAlphabet.getComplementaryCode(iArr[i7]);
            }
            int i8 = 0;
            for (int i9 = 0; i9 < length2; i9++) {
                double[] dArr2 = r0[i9][i8];
                int i10 = iArr[(length2 - 1) - i9];
                dArr2[i10] = dArr2[i10] + d;
                i8 = (i8 * length) + iArr[(length2 - 1) - i9];
            }
        }
        for (int i11 = 1; i11 < length2; i11++) {
            for (int i12 = 0; i12 < r0[i11].length; i12++) {
                double d2 = 0.0d;
                for (int i13 = 0; i13 < length; i13++) {
                    d2 += r0[i11][i12][i13];
                }
                for (int i14 = 0; i14 < length; i14++) {
                    double[] dArr3 = r0[i11][i12];
                    int i15 = i14;
                    dArr3[i15] = dArr3[i15] / d2;
                }
            }
        }
        return r0;
    }

    /**
     * Returns the strand whose component scores best for the given sequence.
     *
     * @param sequence the sequence to score
     * @param i the start position, passed to {@code getIndexOfMaximalComponentFor}
     * @return {@code FORWARD} if component 0 is maximal, {@code REVERSE} otherwise
     */
    public StrandedLocatedSequenceAnnotationWithLength.Strand getStrand(Sequence sequence, int i) {
        if (getIndexOfMaximalComponentFor(sequence, i) == 0) {
            return StrandedLocatedSequenceAnnotationWithLength.Strand.FORWARD;
        }
        return StrandedLocatedSequenceAnnotationWithLength.Strand.REVERSE;
    }

    /**
     * Checks whether the given model is a strand model.
     *
     * @param differentiableStatisticalModel the model to check
     * @return for a {@link NormalizedDiffSM} the result of its own {@code isStrandModel()};
     *         otherwise whether the model is a {@link StrandDiffSM}
     */
    public static boolean isStrandModel(DifferentiableStatisticalModel differentiableStatisticalModel) {
        if (differentiableStatisticalModel instanceof NormalizedDiffSM) {
            return ((NormalizedDiffSM) differentiableStatisticalModel).isStrandModel();
        }
        return differentiableStatisticalModel instanceof StrandDiffSM;
    }

    /**
     * Compiler-generated lookup table that maps the ordinal of an {@link InitMethod}
     * constant to the case label used in switch statements over that enum
     * (forward = 1, backward = 2, both = 3). The table is built on first use and
     * cached in a static field.
     *
     * @return the cached lookup table
     */
    static /* synthetic */ int[] $SWITCH_TABLE$de$jstacs$sequenceScores$statisticalModels$differentiable$mixture$StrandDiffSM$InitMethod() {
        int[] table = $SWITCH_TABLE$de$jstacs$sequenceScores$statisticalModels$differentiable$mixture$StrandDiffSM$InitMethod;
        if (table == null) {
            table = new int[InitMethod.valuesCustom().length];
            // Each entry is guarded separately so that a constant missing at runtime
            // only leaves its own slot at 0.
            try {
                table[InitMethod.INIT_FORWARD_STRAND.ordinal()] = 1;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                table[InitMethod.INIT_BACKWARD_STRAND.ordinal()] = 2;
            } catch (NoSuchFieldError ignored) {
            }
            try {
                table[InitMethod.INIT_BOTH_STRANDS.ordinal()] = 3;
            } catch (NoSuchFieldError ignored) {
            }
            $SWITCH_TABLE$de$jstacs$sequenceScores$statisticalModels$differentiable$mixture$StrandDiffSM$InitMethod = table;
        }
        return table;
    }
}
