/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.differentiable.mixture;

import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.ComplementableDiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.StrandedLocatedSequenceAnnotationWithLength;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.Mutable;
import de.jstacs.sequenceScores.statisticalModels.differentiable.DifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.NormalizedDiffSM;
import de.jstacs.sequenceScores.statisticalModels.differentiable.mixture.AbstractMixtureDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.text.NumberFormat;
import java.util.Arrays;

public class StrandDiffSM
extends AbstractMixtureDiffSM
implements Mutable {
    /** Strategy consulted by {@code initializeUsingPlugIn} to decide on which strand(s) the data is presented to the wrapped function. */
    private InitMethod initMethod;
    /** Fraction in [0,1] of the wrapped function's ESS that is attributed to the forward strand; the remainder goes to the reverse strand. */
    private double forwardPartOfESS;

    /**
     * Shared constructor that all public (non-XML) constructors delegate to.
     *
     * @param function the function that is scored on both strands; its alphabet must be reverse complementable
     * @param starts forwarded unchanged to the super constructor
     * @param optimizeHidden forwarded unchanged to the super constructor
     * @param plugIn forwarded unchanged to the super constructor
     * @param forwardPartOfESS the part of the ESS assigned to the forward strand, has to be in [0,1]
     * @param initMethod the strategy used later by the plug-in initialization
     * @throws CloneNotSupportedException declared because the super constructor may throw it
     * @throws WrongAlphabetException if the alphabet of <code>function</code> is not reverse complementable
     * @throws IllegalArgumentException if <code>forwardPartOfESS</code> is not a number in [0,1]
     */
    private StrandDiffSM(DifferentiableStatisticalModel function, int starts, boolean optimizeHidden, boolean plugIn, double forwardPartOfESS, InitMethod initMethod) throws CloneNotSupportedException, WrongAlphabetException {
        super(function.getLength(), starts, 2, optimizeHidden, plugIn, function);
        if (!function.getAlphabetContainer().isReverseComplementable()) {
            throw new WrongAlphabetException("The given AlphabetContainer can not be used for building a reverse complement.");
        }
        // Negated range test on purpose: every comparison with NaN is false, so
        // "x < 0 || x > 1" would silently accept NaN, whereas this form rejects it.
        if (!(forwardPartOfESS >= 0.0 && forwardPartOfESS <= 1.0)) {
            throw new IllegalArgumentException("The part of the ESS for the forward strand has to be in [0,1].");
        }
        this.forwardPartOfESS = forwardPartOfESS;
        this.initMethod = initMethod;
        this.computeLogGammaSum();
    }

    /**
     * Creates a strand mixture whose hidden (strand) parameters are optimized:
     * the shared constructor is called with <code>optimizeHidden = true</code>.
     *
     * @param function the function that is scored on both strands
     * @param forwardPartOfESS the part of the ESS assigned to the forward strand, has to be in [0,1]
     * @param starts forwarded unchanged to the shared constructor
     * @param plugIn forwarded unchanged to the shared constructor
     * @param initMethod the strategy used by the plug-in initialization
     * @throws CloneNotSupportedException forwarded from the shared constructor
     * @throws WrongAlphabetException if the alphabet of <code>function</code> is not reverse complementable
     */
    public StrandDiffSM(DifferentiableStatisticalModel function, double forwardPartOfESS, int starts, boolean plugIn, InitMethod initMethod) throws CloneNotSupportedException, WrongAlphabetException {
        this(function, starts, true, plugIn, forwardPartOfESS, initMethod);
    }

    /**
     * Creates a strand mixture with a fixed forward-strand probability: the
     * shared constructor is called with <code>optimizeHidden = false</code> and
     * the hidden parameters are then set from <code>forward</code>.
     * The same value is also handed to the shared constructor as the forward part of the ESS.
     *
     * @param function the function that is scored on both strands
     * @param starts forwarded unchanged to the shared constructor
     * @param plugIn forwarded unchanged to the shared constructor
     * @param initMethod the strategy used by the plug-in initialization
     * @param forward the probability of the forward strand, has to be in [0,1]
     * @throws CloneNotSupportedException forwarded from the shared constructor
     * @throws WrongAlphabetException if the alphabet of <code>function</code> is not reverse complementable
     * @throws IllegalArgumentException if <code>forward</code> is not a number in [0,1]
     */
    public StrandDiffSM(DifferentiableStatisticalModel function, int starts, boolean plugIn, InitMethod initMethod, double forward) throws CloneNotSupportedException, WrongAlphabetException {
        this(function, starts, false, plugIn, forward, initMethod);
        // Negated range test on purpose: "forward < 0 || forward > 1" is false for NaN,
        // which would otherwise end up as NaN hidden parameters via Math.log.
        if (!(forward >= 0.0 && forward <= 1.0)) {
            throw new IllegalArgumentException("The value for forward is no probability.");
        }
        this.setForwardProb(forward);
    }

    /**
     * Sets the two hidden parameters to the logarithms of the forward
     * probability and of its complement and hands them to
     * <code>setHiddenParameters</code>.
     *
     * @param forward the probability of the forward strand; not validated here
     */
    protected void setForwardProb(double forward) {
        final double logForward = Math.log(forward);
        final double logReverse = Math.log(1.0 - forward);
        this.hiddenParameter[0] = logForward;
        this.hiddenParameter[1] = logReverse;
        this.setHiddenParameters(this.hiddenParameter, 0);
    }

    /**
     * Restores an instance from its XML representation by delegating to the
     * super constructor.
     *
     * @param xml the XML representation
     * @throws NonParsableException declared because the super constructor may throw it
     */
    public StrandDiffSM(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns the log normalization constant of the single wrapped function.
     * The component index is ignored: both strands are scored by the same function.
     *
     * @param i the component index (unused)
     * @return the log normalization constant of <code>function[0]</code>
     */
    @Override
    protected double getLogNormalizationConstantForComponent(int i) {
        return this.function[0].getLogNormalizationConstant();
    }

    /**
     * Returns the logarithm of the partial normalization constant for one parameter.
     *
     * @param parameterIndex the index of the parameter
     * @return {@link Double#NEGATIVE_INFINITY} if this instance is normalized; otherwise the
     *         value composed from the hidden potentials and the wrapped function
     * @throws Exception declared by the overridden method; may come from the wrapped function
     */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        if (this.isNormalized()) {
            return Double.NEGATIVE_INFINITY;
        }
        if (Double.isNaN(this.norm)) {
            // the cached normalization is stale, recompute before using the potentials
            this.precomputeNorm();
        }
        final int[] location = this.getIndices(parameterIndex);
        // NOTE(review): location[0] == 1 is treated as "hidden parameter" - presumably
        // one past the single wrapped function; confirm against getIndices.
        final boolean isHiddenParameter = location[0] == 1;
        if (!isHiddenParameter) {
            return Normalisation.getLogSum(this.logHiddenPotential) + this.function[location[0]].getLogPartialNormalizationConstant(location[1]);
        }
        return this.logHiddenPotential[location[1]] + this.function[0].getLogNormalizationConstant();
    }

    /**
     * Returns the hyperparameter of a hidden parameter: the ESS of the wrapped
     * function scaled by the forward part (index 0) or by its complement (index 1).
     *
     * @param index 0 for the forward strand, 1 for the reverse strand
     * @return the share of the ESS belonging to the requested strand
     * @throws IndexOutOfBoundsException if <code>index</code> is neither 0 nor 1
     */
    @Override
    public double getHyperparameterForHiddenParameter(int index) {
        if (index != 0 && index != 1) {
            throw new IndexOutOfBoundsException();
        }
        final double share = index == 0 ? this.forwardPartOfESS : 1.0 - this.forwardPartOfESS;
        return share * this.function[0].getESS();
    }

    /**
     * Returns the probability of the forward strand, i.e. the first hidden
     * potential divided by the sum of both hidden potentials.
     *
     * @return the forward strand probability
     */
    public double getForwardProbability() {
        return this.hiddenPotential[0] / (this.hiddenPotential[0] + this.hiddenPotential[1]);
    }

    /**
     * Returns the equivalent sample size, which is taken unchanged from the
     * single wrapped function.
     *
     * @return the ESS of <code>function[0]</code>
     */
    @Override
    public double getESS() {
        return this.function[0].getESS();
    }

    /**
     * Initializes the wrapped function (and, if they are optimized, the hidden
     * parameters) from data, according to the configured {@link InitMethod}.
     *
     * <ul>
     * <li><code>INIT_BOTH_STRANDS</code>: every sequence is first assigned to a strand at random
     * (forward with probability <code>p</code>), the function is initialized on that data, and the
     * sequences are then re-assigned to the strand of the maximal component before the final
     * initialization.</li>
     * <li><code>INIT_BACKWARD_STRAND</code>: the function is initialized on the reverse complements.</li>
     * <li>otherwise: the function is initialized on the data as given.</li>
     * </ul>
     *
     * <code>data[index]</code> is replaced only temporarily. It is restored to the caller's data set
     * before this method returns, also when the initialization fails with an exception.
     *
     * @param index the index of the data set (and weights) to be used
     * @param freeParams forwarded unchanged to the wrapped function
     * @param data the data sets; <code>data[index]</code> may be <code>null</code>
     * @param weights the sequence weights; may be <code>null</code>, as may <code>weights[index]</code>,
     *        in which case every sequence counts with weight 1
     * @throws Exception if the wrapped function cannot be initialized or a reverse complement cannot be built
     */
    @Override
    protected void initializeUsingPlugIn(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        final DataSet myData = data[index];
        // weighted strand counts: [0] forward, [1] reverse
        final double[] stat = new double[2];
        try {
            switch (this.initMethod) {
                case INIT_BOTH_STRANDS: {
                    // probability used for the random strand assignment
                    double p;
                    if (this.optimizeHidden) {
                        final double[] h = new double[2];
                        if (this.getESS() == 0.0) {
                            h[0] = 1.0;
                            h[1] = 1.0;
                        } else {
                            h[0] = this.getHyperparameterForHiddenParameter(0);
                            h[1] = this.getHyperparameterForHiddenParameter(1);
                        }
                        p = DirichletMRG.DEFAULT_INSTANCE.generate(2, new DirichletMRGParams(h))[0];
                    } else {
                        p = this.hiddenPotential[0] / (this.hiddenPotential[0] + this.hiddenPotential[1]);
                    }
                    if (myData != null) {
                        final Sequence[] seqs = new Sequence[myData.getNumberOfElements()];
                        double w = 1.0;
                        // first pass: random strand assignment
                        for (int i = 0; i < seqs.length; i++) {
                            if (weights != null && weights[index] != null) {
                                w = weights[index][i];
                            }
                            if (r.nextDouble() < p) {
                                stat[0] += w;
                                seqs[i] = myData.getElementAt(i);
                            } else {
                                stat[1] += w;
                                seqs[i] = myData.getElementAt(i).reverseComplement();
                            }
                        }
                        data[index] = new DataSet("randomly strand scrambled", seqs);
                        this.function[0].initializeFunction(index, freeParams, data, weights);
                        if (this.optimizeHidden) {
                            this.computeHiddenParameter(stat, true);
                        }
                        // second pass: assign each sequence to the strand the preliminary model prefers
                        Arrays.fill(stat, 0.0);
                        for (int i = 0; i < seqs.length; i++) {
                            if (weights != null && weights[index] != null) {
                                w = weights[index][i];
                            }
                            if (this.getIndexOfMaximalComponentFor(myData.getElementAt(i), 0) == 0) {
                                stat[0] += w;
                                seqs[i] = myData.getElementAt(i);
                            } else {
                                stat[1] += w;
                                seqs[i] = myData.getElementAt(i).reverseComplement();
                            }
                        }
                        data[index] = new DataSet("strand scrambled", seqs);
                    }
                    // without data, data[index] already is null and stat stays at zero
                    break;
                }
                case INIT_BACKWARD_STRAND: {
                    // A missing data set is passed through unchanged, as in the other
                    // branches, instead of failing with a NullPointerException.
                    if (myData != null) {
                        final Sequence[] rcs = new Sequence[myData.getNumberOfElements()];
                        for (int i = 0; i < rcs.length; i++) {
                            rcs[i] = myData.getElementAt(i).reverseComplement();
                        }
                        data[index] = new DataSet("backward strand", rcs);
                    }
                }
                // intentional fall-through: the backward strand uses the same statistics as the default
                default: {
                    stat[0] = this.forwardPartOfESS;
                    stat[1] = 1.0 - this.forwardPartOfESS;
                }
            }
            this.function[0].initializeFunction(index, freeParams, data, weights);
        } finally {
            // never leave the caller's array pointing at one of the temporary data sets
            data[index] = myData;
        }
        if (this.optimizeHidden) {
            this.computeHiddenParameter(stat, true);
        }
    }

    /**
     * Builds the instance name from the name of the wrapped function; if the
     * hidden parameters are not optimized, the hidden potentials are appended.
     *
     * @return the instance name, of the form <code>strand-mixture(...)</code>
     */
    @Override
    public String getInstanceName() {
        final StringBuilder name = new StringBuilder("strand-mixture(");
        name.append(this.function[0].getInstanceName());
        if (!this.optimizeHidden) {
            name.append(", ").append(Arrays.toString(this.hiddenPotential));
        }
        return name.append(")").toString();
    }

    /**
     * Fills the two component scores for a sequence: component 0 is the score
     * of the sequence itself, component 1 the score of its reverse complement,
     * each increased by the corresponding log hidden potential.
     *
     * @param seq the sequence
     * @param start the start position on the forward strand
     * @throws RuntimeException if the reverse strand cannot be scored; this covers a function
     *         with <code>length == 0</code> and <code>start != 0</code>. The original exception is
     *         attached as the cause.
     */
    @Override
    protected void fillComponentScores(Sequence seq, int start) {
        this.componentScore[0] = this.logHiddenPotential[0] + this.function[0].getLogScoreFor(seq, start);
        try {
            final int reverseStart;
            if (this.length != 0) {
                // the mirrored start of a window of this.length positions on the reverse complement
                reverseStart = seq.getLength() - start - this.length;
            } else if (start == 0) {
                reverseStart = 0;
            } else {
                throw new Exception("strand scoring for variable length function");
            }
            this.componentScore[1] = this.logHiddenPotential[1] + this.function[0].getLogScoreFor(seq.reverseComplement(), reverseStart);
        } catch (Exception doesNotHappen) {
            // Same message as before, but the original exception is chained as the cause
            // instead of only copying its stack trace, so no diagnostic information is lost.
            throw new RuntimeException(doesNotHappen.getClass().getName() + ": " + doesNotHappen.getMessage(), doesNotHappen);
        }
    }

    /**
     * Computes the log score of a sequence under the strand mixture and
     * collects the partial derivations.
     *
     * The wrapped function is evaluated on the sequence and on its reverse
     * complement. Its partial derivations are weighted with the normalized
     * component scores and appended first, followed by one entry per hidden parameter.
     *
     * @param seq the sequence
     * @param start the start position on the forward strand
     * @param indices receives the indices of the parameters with a partial derivation
     * @param partialDer receives the partial derivations, parallel to <code>indices</code>
     * @return the log score of the sequence
     * @throws RuntimeException if the reverse strand cannot be scored; this covers a function
     *         with <code>length == 0</code> and <code>start != 0</code>. The original exception is
     *         attached as the cause.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        this.iList[0].clear();
        this.dList[0].clear();
        this.componentScore[0] = this.logHiddenPotential[0] + this.function[0].getLogScoreAndPartialDerivation(seq, start, this.iList[0], this.dList[0]);
        this.iList[1].clear();
        this.dList[1].clear();
        try {
            final int reverseStart;
            if (this.length != 0) {
                // the mirrored start of a window of this.length positions on the reverse complement
                reverseStart = seq.getLength() - start - this.length;
            } else if (start == 0) {
                reverseStart = 0;
            } else {
                throw new Exception("strand scoring for variable length function");
            }
            this.componentScore[1] = this.logHiddenPotential[1] + this.function[0].getLogScoreAndPartialDerivation(seq.reverseComplement(), reverseStart, this.iList[1], this.dList[1]);
        } catch (Exception doesNotHappen) {
            // Same message as before, but the original exception is chained as the cause
            // instead of only copying its stack trace, so no diagnostic information is lost.
            throw new RuntimeException(doesNotHappen.getClass().getName() + ": " + doesNotHappen.getMessage(), doesNotHappen);
        }
        // normalizes componentScore in place and returns the log of the sum
        final double logScore = Normalisation.logSumNormalisation(this.componentScore, 0, 2, this.componentScore, 0);
        // derivations of the wrapped function, weighted by the normalized component score of each strand
        for (int c = 0; c < this.logHiddenPotential.length; c++) {
            for (int j = 0; j < this.iList[c].length(); j++) {
                indices.add(this.iList[c].get(j));
                partialDer.add(this.componentScore[c] * this.dList[c].get(j));
            }
        }
        // derivations with respect to the hidden parameters
        final int numberOfHiddenParameters = this.paramRef[2] - this.paramRef[1];
        for (int j = 0; j < numberOfHiddenParameters; j++) {
            indices.add(this.paramRef[1] + j);
            partialDer.add(this.componentScore[j] - (this.isNormalized() ? this.hiddenPotential[j] : 0.0));
        }
        return logScore;
    }

    /**
     * Serializes the fields specific to this class: the forward part of the ESS
     * and the initialization method, under the tags
     * <code>forwardPartOfESS</code> and <code>initMethod</code>.
     *
     * @return a buffer containing both tagged values, in that order
     */
    @Override
    protected StringBuffer getFurtherInformation() {
        final StringBuffer xml = new StringBuffer(100);
        XMLParser.appendObjectWithTags(xml, this.forwardPartOfESS, "forwardPartOfESS");
        XMLParser.appendObjectWithTags(xml, (Object)this.initMethod, "initMethod");
        return xml;
    }

    /**
     * Restores the fields specific to this class from XML: the forward part of
     * the ESS and the initialization method, read from the tags
     * <code>forwardPartOfESS</code> and <code>initMethod</code>.
     *
     * @param xml the XML representation containing both tags
     * @throws NonParsableException if one of the values cannot be extracted
     */
    @Override
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        this.forwardPartOfESS = XMLParser.extractObjectForTags(xml, "forwardPartOfESS", Double.TYPE);
        this.initMethod = XMLParser.extractObjectForTags(xml, "initMethod", InitMethod.class);
    }

    /**
     * Delegates to the super implementation without adding any behavior.
     *
     * @param freeParams forwarded unchanged to <code>super.init</code>
     */
    @Override
    protected void init(boolean freeParams) {
        super.init(freeParams);
    }

    /**
     * Returns a textual representation: the normalized forward and reverse
     * strand potentials, followed by the representation of the wrapped function.
     *
     * @param nf the format used for all numbers
     * @return the textual representation
     */
    @Override
    public String toString(NumberFormat nf) {
        final double total = this.hiddenPotential[0] + this.hiddenPotential[1];
        final StringBuilder text = new StringBuilder(1500);
        text.append("forward: ").append(nf.format(this.hiddenPotential[0] / total)).append("\n");
        text.append("reverse: ").append(nf.format(this.hiddenPotential[1] / total)).append("\n\n");
        text.append(this.function[0].toString(nf));
        return text.toString();
    }

    /**
     * Forwards the modification to the wrapped function if it is
     * {@link Mutable}. After a successful modification the length is taken
     * over from the function, the instance is re-initialized and the cached
     * normalization is invalidated.
     *
     * @param offsetLeft forwarded unchanged to the wrapped function
     * @param offsetRight forwarded unchanged to the wrapped function
     * @return <code>true</code> if the wrapped function was modified, <code>false</code> otherwise
     */
    @Override
    public boolean modify(int offsetLeft, int offsetRight) {
        if (!(this.function[0] instanceof Mutable)) {
            return false;
        }
        final boolean modified = ((Mutable)((Object)this.function[0])).modify(offsetLeft, offsetRight);
        if (modified) {
            this.length = this.function[0].getLength();
            this.init(this.freeParams);
            // NaN marks the normalization as "has to be recomputed"
            this.norm = Double.NaN;
        }
        return modified;
    }

    /**
     * Computes the conditional distributions that describe the reverse
     * complement of the words described by <code>condDistr</code>.
     *
     * <code>condDistr[o][h][a]</code> is read as the probability of symbol <code>a</code> at
     * position <code>o</code> given the context with index <code>h</code>, where the context index is
     * built from the preceding symbols in base <code>abc.length()</code>. Every word is enumerated,
     * its joint probability is accumulated at the reverse complementary word, and all
     * positions except the first are normalized per context afterwards.
     *
     * @param abc the alphabet, providing its size and the complementary code of each symbol
     * @param condDistr the conditional distributions, one entry per position; must not be empty
     * @return the conditional distributions of the reverse complement, with the same dimensions
     */
    public static double[][][] getReverseComplementDistributions(ComplementableDiscreteAlphabet abc, double[][][] condDistr) {
        final int alphabetSize = (int)abc.length();
        final int order = condDistr.length;
        final int numberOfWords = condDistr[order - 1].length * alphabetSize;

        final double[][][] result = new double[order][][];
        for (int pos = 0; pos < order; pos++) {
            result[pos] = new double[condDistr[pos].length][alphabetSize];
        }

        final int[] word = new int[order];
        for (int code = 0; code < numberOfWords; code++) {
            // decode the word index into its symbols, last position first
            int rest = code;
            for (int pos = order - 1; pos >= 0; pos--) {
                word[pos] = rest % alphabetSize;
                rest /= alphabetSize;
            }

            // joint probability of the word under the given distributions
            double joint = 1.0;
            int context = 0;
            for (int pos = 0; pos < order; pos++) {
                joint *= condDistr[pos][context][word[pos]];
                context = context * alphabetSize + word[pos];
            }

            // complement every symbol ...
            for (int pos = 0; pos < order; pos++) {
                word[pos] = abc.getComplementaryCode(word[pos]);
            }

            // ... and read the word backwards while adding the joint probability
            context = 0;
            for (int pos = 0; pos < order; pos++) {
                final double[] row = result[pos][context];
                final int symbol = word[order - 1 - pos];
                row[symbol] += joint;
                context = context * alphabetSize + symbol;
            }
        }

        // turn the accumulated joint probabilities into conditional ones;
        // position 0 is left as accumulated
        for (int pos = 1; pos < order; pos++) {
            for (int context = 0; context < result[pos].length; context++) {
                final double[] row = result[pos][context];
                double sum = 0.0;
                for (int symbol = 0; symbol < alphabetSize; symbol++) {
                    sum += row[symbol];
                }
                for (int symbol = 0; symbol < alphabetSize; symbol++) {
                    row[symbol] /= sum;
                }
            }
        }
        return result;
    }

    /**
     * Returns the strand whose component is maximal for the given sequence and
     * start position.
     *
     * @param seq the sequence
     * @param startPos the start position
     * @return <code>FORWARD</code> if component 0 is maximal, <code>REVERSE</code> otherwise
     */
    public StrandedLocatedSequenceAnnotationWithLength.Strand getStrand(Sequence seq, int startPos) {
        final int bestComponent = this.getIndexOfMaximalComponentFor(seq, startPos);
        if (bestComponent == 0) {
            return StrandedLocatedSequenceAnnotationWithLength.Strand.FORWARD;
        }
        return StrandedLocatedSequenceAnnotationWithLength.Strand.REVERSE;
    }

    /**
     * Checks whether a function is a strand model. A {@link NormalizedDiffSM}
     * answers for itself; any other function is a strand model exactly if it is
     * a {@link StrandDiffSM}.
     *
     * @param nsf the function to be checked
     * @return <code>true</code> if the function is a strand model
     */
    public static boolean isStrandModel(DifferentiableStatisticalModel nsf) {
        final boolean isNormalizedWrapper = nsf instanceof NormalizedDiffSM;
        return isNormalizedWrapper ? ((NormalizedDiffSM)nsf).isStrandModel() : nsf instanceof StrandDiffSM;
    }

    /**
     * The strategies for presenting data to the wrapped function during the
     * plug-in initialization.
     */
    public enum InitMethod {
        /** Initialize on the data as given. */
        INIT_FORWARD_STRAND,
        /** Initialize on the reverse complements of the data. */
        INIT_BACKWARD_STRAND,
        /** Initialize on data where each sequence is assigned to one of the two strands. */
        INIT_BOTH_STRANDS
    }
}

