/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.scoringFunctions.mix;

import de.jstacs.NonParsableException;
import de.jstacs.WrongAlphabetException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.alphabets.ComplementableDiscreteAlphabet;
import de.jstacs.data.sequences.annotation.StrandedLocatedSequenceAnnotationWithLength;
import de.jstacs.io.XMLParser;
import de.jstacs.motifDiscovery.Mutable;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.scoringFunctions.mix.AbstractMixtureScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;
import java.util.Arrays;

/**
 * Strand-aware two-component mixture scoring function.
 * <p>
 * A single wrapped {@link NormalizableScoringFunction} is evaluated twice per
 * sequence: once on the forward strand (component 0) and once on the reverse
 * complement (component 1). The two component scores are combined using the
 * hidden mixture parameters managed by {@link AbstractMixtureScoringFunction}.
 * The wrapped function's alphabet must be reverse-complementable.
 */
public class StrandScoringFunction
extends AbstractMixtureScoringFunction
implements Mutable {
    /** Determines how {@link #initializeUsingPlugIn} prepares the data for the wrapped function. */
    private InitMethod initMethod;
    /** Share of the wrapped function's ESS assigned to the forward strand; always in [0,1]. */
    private double forwardPartOfESS;

    /**
     * Common constructor used by the public constructors.
     *
     * @param function the function scored on both strands
     * @param starts number of starts passed to the mixture super class
     * @param optimizeHidden whether the strand (hidden) parameters are optimized
     * @param plugIn whether plug-in initialization is used
     * @param forwardPartOfESS share of the ESS for the forward strand, in [0,1]
     * @param initMethod the initialization strategy
     * @throws WrongAlphabetException if the alphabet cannot be reverse-complemented
     * @throws IllegalArgumentException if {@code forwardPartOfESS} is not in [0,1] (NaN included)
     */
    private StrandScoringFunction(NormalizableScoringFunction function, int starts, boolean optimizeHidden, boolean plugIn, double forwardPartOfESS, InitMethod initMethod) throws CloneNotSupportedException, WrongAlphabetException {
        super(function.getLength(), starts, 2, optimizeHidden, plugIn, function);
        if (!function.getAlphabetContainer().isReverseComplementable()) {
            throw new WrongAlphabetException("The given AlphabetContainer can not be used for building a reverse complement.");
        }
        // Written as a negated range test so that NaN is rejected as well.
        if (!(forwardPartOfESS >= 0.0 && forwardPartOfESS <= 1.0)) {
            throw new IllegalArgumentException("The part of the ESS for the forward strand has to be in [0,1].");
        }
        this.forwardPartOfESS = forwardPartOfESS;
        this.initMethod = initMethod;
        this.computeLogGammaSum();
    }

    /**
     * Creates a strand mixture whose strand probabilities are optimized.
     *
     * @param function the function scored on both strands
     * @param forwardPartOfESS share of the ESS for the forward strand, in [0,1]
     * @param starts number of starts
     * @param plugIn whether plug-in initialization is used
     * @param initMethod the initialization strategy
     */
    public StrandScoringFunction(NormalizableScoringFunction function, double forwardPartOfESS, int starts, boolean plugIn, InitMethod initMethod) throws CloneNotSupportedException, WrongAlphabetException {
        this(function, starts, true, plugIn, forwardPartOfESS, initMethod);
    }

    /**
     * Creates a strand mixture with a fixed (not optimized) forward-strand probability.
     * <p>
     * {@code forward} is also used as the forward share of the ESS. It is therefore
     * range-checked by the delegated constructor, which throws
     * {@link IllegalArgumentException} if it is not in [0,1].
     *
     * @param function the function scored on both strands
     * @param starts number of starts
     * @param plugIn whether plug-in initialization is used
     * @param initMethod the initialization strategy
     * @param forward the fixed probability of the forward strand
     */
    public StrandScoringFunction(NormalizableScoringFunction function, int starts, boolean plugIn, InitMethod initMethod, double forward) throws CloneNotSupportedException, WrongAlphabetException {
        this(function, starts, false, plugIn, forward, initMethod);
        this.setForwardProb(forward);
    }

    /**
     * Sets the hidden parameters to {@code log(forward)} and {@code log(1 - forward)}.
     *
     * @param forward the probability of the forward strand
     */
    protected void setForwardProb(double forward) {
        this.hiddenParameter[0] = Math.log(forward);
        this.hiddenParameter[1] = Math.log(1.0 - forward);
        this.setHiddenParameters(this.hiddenParameter, 0);
    }

    /**
     * Restores an instance from its XML representation.
     *
     * @param xml the XML representation
     * @throws NonParsableException if {@code xml} cannot be parsed
     */
    public StrandScoringFunction(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Both components share the same wrapped function, so every component has
     * that function's normalization constant.
     */
    protected double getNormalizationConstantForComponent(int i) {
        return this.function[0].getNormalizationConstant();
    }

    /**
     * Returns the partial derivative of the normalization constant with respect to
     * the parameter {@code parameterIndex}. The result is 0 if the function is normalized.
     */
    public double getPartialNormalizationConstant(int parameterIndex) throws Exception {
        if (this.isNormalized) {
            return 0.0;
        }
        if (this.norm < 0.0) {
            this.precomputeNorm();
        }
        int[] ind = this.getIndices(parameterIndex);
        if (ind[0] == 1) {
            // Index 1 addresses the hidden (strand) parameters, as laid out in paramRef.
            return this.hiddenPotential[ind[1]] * this.function[0].getNormalizationConstant();
        }
        // A parameter of the wrapped function contributes through both strands.
        return (this.hiddenPotential[0] + this.hiddenPotential[1]) * this.function[0].getPartialNormalizationConstant(ind[1]);
    }

    /**
     * Returns the hyperparameter of hidden parameter {@code index}: the forward
     * (0) or reverse (1) share of the wrapped function's ESS.
     *
     * @throws IndexOutOfBoundsException if {@code index} is neither 0 nor 1
     */
    public double getHyperparameterForHiddenParameter(int index) {
        if (index == 0) {
            return this.forwardPartOfESS * this.function[0].getEss();
        }
        if (index == 1) {
            return (1.0 - this.forwardPartOfESS) * this.function[0].getEss();
        }
        throw new IndexOutOfBoundsException();
    }

    /** Returns the ESS of the wrapped function. */
    public double getEss() {
        return this.function[0].getEss();
    }

    /**
     * Initializes the wrapped function from {@code data[index]} according to
     * {@link #initMethod}. When hidden parameters are optimized, the strand
     * parameters are then estimated from the resulting strand statistic.
     * <p>
     * {@code data[index]} is temporarily replaced by the strand-adjusted sample.
     * It is always restored, even if initialization of the wrapped function fails.
     */
    protected void initializeUsingPlugIn(int index, boolean freeParams, Sample[] data, double[][] weights) throws Exception {
        Sample myData = data[index];
        double[] stat = new double[2];
        try {
            switch (this.initMethod) {
                case INIT_BOTH_STRANDS:
                    // p is drawn even when there is no data, so the random stream is consumed identically.
                    double p = this.drawForwardProbability();
                    data[index] = this.scrambleStrands(myData, weights == null ? null : weights[index], p, stat);
                    break;
                case INIT_BACKWARD_STRAND:
                    if (myData != null) {
                        data[index] = new Sample("backward strand", StrandScoringFunction.reverseComplementAll(myData));
                    }
                    // Intentional fall-through: both single-strand inits use the prior strand split.
                default:
                    stat[0] = this.forwardPartOfESS;
                    stat[1] = 1.0 - this.forwardPartOfESS;
            }
            this.function[0].initializeFunction(index, freeParams, data, weights);
        } finally {
            data[index] = myData;
        }
        if (this.optimizeHidden) {
            this.computeHiddenParameter(stat);
        }
    }

    /**
     * Returns the probability used to keep a sequence on the forward strand.
     * If hidden parameters are optimized, it is drawn from a Dirichlet over the
     * strand hyperparameters (uniform when the ESS is 0). Otherwise it is the
     * current normalized forward potential.
     */
    private double drawForwardProbability() throws Exception {
        if (!this.optimizeHidden) {
            return this.hiddenPotential[0] / (this.hiddenPotential[0] + this.hiddenPotential[1]);
        }
        double[] h;
        if (this.getEss() == 0.0) {
            h = new double[]{1.0, 1.0};
        } else {
            h = new double[]{this.getHyperparameterForHiddenParameter(0), this.getHyperparameterForHiddenParameter(1)};
        }
        return DirichletMRG.DEFAULT_INSTANCE.generate(2, new DirichletMRGParams(h))[0];
    }

    /**
     * Randomly keeps each sequence as is (probability {@code p}) or replaces it by its
     * reverse complement. Each sequence's weight is accumulated into {@code stat[0]}
     * or {@code stat[1]} respectively; a weight of 1 is used when no weights are given.
     *
     * @return the scrambled sample, or {@code null} if {@code myData} is {@code null}
     */
    private Sample scrambleStrands(Sample myData, double[] seqWeights, double p, double[] stat) throws Exception {
        if (myData == null) {
            return null;
        }
        Sequence[] seqs = new Sequence[myData.getNumberOfElements()];
        for (int i = 0; i < seqs.length; ++i) {
            double w = seqWeights == null ? 1.0 : seqWeights[i];
            Sequence s = myData.getElementAt(i);
            if (r.nextDouble() < p) {
                stat[0] += w;
                seqs[i] = s;
            } else {
                stat[1] += w;
                seqs[i] = s.reverseComplement();
            }
        }
        return new Sample("strand scrambled", seqs);
    }

    /** Returns the reverse complement of every sequence in {@code sample}, in order. */
    private static Sequence[] reverseComplementAll(Sample sample) throws Exception {
        Sequence[] rcs = new Sequence[sample.getNumberOfElements()];
        for (int i = 0; i < rcs.length; ++i) {
            rcs[i] = sample.getElementAt(i).reverseComplement();
        }
        return rcs;
    }

    /** Returns {@code strand-mixture(<inner>)}, plus the hidden potentials if they are fixed. */
    public String getInstanceName() {
        String erg = "strand-mixture(" + this.function[0].getInstanceName();
        if (!this.optimizeHidden) {
            erg = erg + ", " + Arrays.toString(this.hiddenPotential);
        }
        return erg + ")";
    }

    /**
     * Fills {@code componentScore} with the unnormalized log scores of the forward
     * and the reverse-complement strand.
     *
     * @throws IllegalArgumentException for a variable-length function with {@code start != 0}
     */
    protected void fillComponentScores(Sequence seq, int start) {
        this.componentScore[0] = this.logHiddenPotential[0] + this.function[0].getLogScore(seq, start);
        int rcStart = this.getReverseComplementStart(seq, start);
        this.componentScore[1] = this.logHiddenPotential[1] + this.function[0].getLogScore(StrandScoringFunction.reverseComplementOf(seq), rcStart);
    }

    /**
     * Computes the log score of the strand mixture. It appends to {@code indices} and
     * {@code partialDer} the partial derivatives with respect to the wrapped function's
     * parameters (weighted by component posteriors) and to the hidden parameters.
     *
     * @throws IllegalArgumentException for a variable-length function with {@code start != 0}
     */
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        this.iList[0].clear();
        this.dList[0].clear();
        this.componentScore[0] = this.logHiddenPotential[0] + this.function[0].getLogScoreAndPartialDerivation(seq, start, this.iList[0], this.dList[0]);
        this.iList[1].clear();
        this.dList[1].clear();
        int rcStart = this.getReverseComplementStart(seq, start);
        this.componentScore[1] = this.logHiddenPotential[1] + this.function[0].getLogScoreAndPartialDerivation(StrandScoringFunction.reverseComplementOf(seq), rcStart, this.iList[1], this.dList[1]);
        // Based on how componentScore is used below, this call is assumed to overwrite it
        // with normalized component weights and to return the log of the summed exponentials.
        double logScore = Normalisation.logSumNormalisation(this.componentScore, 0, 2, this.componentScore, 0);
        for (int c = 0; c < this.logHiddenPotential.length; ++c) {
            for (int j = 0; j < this.iList[c].length(); ++j) {
                indices.add(this.iList[c].get(j));
                partialDer.add(this.componentScore[c] * this.dList[c].get(j));
            }
        }
        int numHidden = this.paramRef[2] - this.paramRef[1];
        for (int j = 0; j < numHidden; ++j) {
            indices.add(this.paramRef[1] + j);
            partialDer.add(this.componentScore[j] - (this.isNormalized ? this.hiddenPotential[j] : 0.0));
        }
        return logScore;
    }

    /**
     * Maps {@code start} on the forward strand to the corresponding start on the
     * reverse complement. Variable-length functions ({@code length == 0}) are only
     * supported for {@code start == 0}.
     */
    private int getReverseComplementStart(Sequence seq, int start) {
        if (this.length != 0) {
            return seq.getLength() - start - this.length;
        }
        if (start == 0) {
            return 0;
        }
        throw new IllegalArgumentException("strand scoring for variable length function");
    }

    /** Reverse-complements {@code seq}, wrapping any checked exception (with its cause) unchecked. */
    private static Sequence reverseComplementOf(Sequence seq) {
        try {
            return seq.reverseComplement();
        } catch (RuntimeException e) {
            throw e;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /** Serializes the fields specific to this class. */
    protected StringBuffer getFurtherInformation() {
        StringBuffer erg = new StringBuffer(100);
        XMLParser.appendDoubleWithTags(erg, this.forwardPartOfESS, "forwardPartOfESS");
        XMLParser.appendEnumWithTags(erg, this.initMethod, "initMethod");
        return erg;
    }

    /** Restores the fields written by {@link #getFurtherInformation()}. */
    protected void extractFurtherInformation(StringBuffer xml) throws NonParsableException {
        this.forwardPartOfESS = XMLParser.extractDoubleForTag(xml, "forwardPartOfESS");
        this.initMethod = (InitMethod) XMLParser.extractEnumForTag(xml, "initMethod");
    }

    /** Delegates to the super class; no strand-specific initialization is needed. */
    @Override
    protected void init(boolean freeParams) {
        super.init(freeParams);
    }

    /** Returns the normalized strand probabilities followed by the wrapped function's description. */
    @Override
    public String toString() {
        StringBuilder erg = new StringBuilder(1500);
        double d = this.hiddenPotential[0] + this.hiddenPotential[1];
        erg.append("forward: ").append(this.hiddenPotential[0] / d).append("\n");
        erg.append("reverse: ").append(this.hiddenPotential[1] / d).append("\n\n");
        erg.append(this.function[0].toString());
        return erg.toString();
    }

    /**
     * Forwards the modification to the wrapped function if it is {@link Mutable}.
     * Each flank's contrast is split between the strands: the reverse strand
     * receives the reverse complement of the opposite flank.
     *
     * @return whether the wrapped function was modified; {@code false} if it is not mutable
     */
    public boolean modify(double[] weightsLeft, double[] weightsRight, double[][][][] replacementLeft, double[][][][] replacementRight, int offsetLeft, int offsetRight) {
        if (!(this.function[0] instanceof Mutable)) {
            return false;
        }
        Container[] cont = this.getDistributions(weightsLeft, weightsRight, replacementLeft, replacementRight);
        boolean modified = ((Mutable) this.function[0]).modify(cont[0].weights, cont[1].weights, cont[0].distribution, cont[1].distribution, offsetLeft, offsetRight);
        if (modified) {
            this.length = this.function[0].getLength();
            this.init(this.freeParams);
        }
        return modified;
    }

    /**
     * Delegates to the wrapped function if it is {@link Mutable}, using strand-mixed
     * contrasts. Otherwise returns {@code new int[2]}, i.e. zeros.
     */
    public int[] determineNotSignificantPositions(double samples, double[] weightsLeft, double[] weightsRight, double[][][][] contrastLeft, double[][][][] contrastRight, double sign) {
        if (!(this.function[0] instanceof Mutable)) {
            return new int[2];
        }
        Container[] cont = this.getDistributions(weightsLeft, weightsRight, contrastLeft, contrastRight);
        return ((Mutable) this.function[0]).determineNotSignificantPositions(samples, cont[0].weights, cont[1].weights, cont[0].distribution, cont[1].distribution, sign);
    }

    /**
     * Builds the left (index 0) and right (index 1) contrasts for the wrapped function.
     * Each contains, at the same index, the forward-weighted original flank and the
     * reverse-weighted reverse complement of the opposite flank.
     */
    private Container[] getDistributions(double[] weightsLeft, double[] weightsRight, double[][][][] contrastLeft, double[][][][] contrastRight) {
        int total = weightsLeft.length + weightsRight.length;
        Container left = new Container(total);
        Container right = new Container(total);
        double norm = this.hiddenPotential[0] + this.hiddenPotential[1];
        double forward = this.hiddenPotential[0] / norm;
        double backward = this.hiddenPotential[1] / norm;
        ComplementableDiscreteAlphabet comp = (ComplementableDiscreteAlphabet) this.alphabets.getAlphabetAt(0);
        StrandScoringFunction.fill(0, left, right, weightsLeft, contrastLeft, forward, backward, comp);
        StrandScoringFunction.fill(weightsLeft.length, right, left, weightsRight, contrastRight, forward, backward, comp);
        return new Container[]{left, right};
    }

    /**
     * Writes {@code weights}/{@code contrast} starting at {@code offset}. They go into
     * {@code sameSide} scaled by {@code forward}, and, reverse-complemented, into
     * {@code otherSide} scaled by {@code backward}.
     */
    private static void fill(int offset, Container sameSide, Container otherSide, double[] weights, double[][][][] contrast, double forward, double backward, ComplementableDiscreteAlphabet comp) {
        for (int i = 0; i < weights.length; ++i) {
            int idx = offset + i;
            sameSide.weights[idx] = forward * weights[i];
            sameSide.distribution[idx] = contrast[i];
            otherSide.weights[idx] = backward * weights[i];
            otherSide.distribution[idx] = StrandScoringFunction.getReverseComplementDistributions(comp, contrast[i]);
        }
    }

    /**
     * Converts a chain of conditional distributions over {@code condDistr.length}
     * consecutive positions into the corresponding distributions of the reverse
     * complement.
     * <p>
     * Layout: {@code condDistr[o][h][a]} is the probability of symbol {@code a} at
     * position {@code o} given the preceding {@code o} symbols encoded as {@code h}
     * (base {@code abc.length()}, first symbol most significant). The result uses
     * the same layout. A context with zero probability gets a uniform conditional
     * rather than NaN.
     *
     * @param abc the complementable alphabet
     * @param condDistr the conditional distributions of the forward strand
     * @return the conditional distributions of the reverse complement
     */
    public static double[][][] getReverseComplementDistributions(ComplementableDiscreteAlphabet abc, double[][][] condDistr) {
        int ord = condDistr.length;
        if (ord == 0) {
            return new double[0][][];
        }
        int l = (int) abc.length();
        int anz = condDistr[ord - 1].length * l;
        double[][][] result = new double[ord][][];
        for (int o = 0; o < ord; ++o) {
            result[o] = new double[condDistr[o].length][l];
        }
        int[] assign = new int[ord];
        for (int idx = 0; idx < anz; ++idx) {
            // Decode idx into one symbol per position (position 0 most significant).
            int rest = idx;
            for (int o = ord - 1; o >= 0; --o) {
                assign[o] = rest % l;
                rest /= l;
            }
            // Joint probability of the word via the chain rule.
            double joint = 1.0;
            int ctx = 0;
            for (int o = 0; o < ord; ++o) {
                joint *= condDistr[o][ctx][assign[o]];
                ctx = ctx * l + assign[o];
            }
            for (int o = 0; o < ord; ++o) {
                assign[o] = abc.getComplementaryCode(assign[o]);
            }
            // Accumulate the joint mass along the reversed, complemented word.
            ctx = 0;
            for (int o = 0; o < ord; ++o) {
                int sym = assign[ord - 1 - o];
                result[o][ctx][sym] += joint;
                ctx = ctx * l + sym;
            }
        }
        // Turn joint masses into conditionals. Order 0 already sums to the total joint mass.
        for (int o = 1; o < ord; ++o) {
            for (double[] row : result[o]) {
                double sum = 0.0;
                for (int a = 0; a < l; ++a) {
                    sum += row[a];
                }
                for (int a = 0; a < l; ++a) {
                    row[a] = sum == 0.0 ? 1.0 / l : row[a] / sum;
                }
            }
        }
        return result;
    }

    /** Returns the strand whose mixture component scores highest for {@code seq} at {@code startPos}. */
    public StrandedLocatedSequenceAnnotationWithLength.Strand getStrand(Sequence seq, int startPos) {
        return this.getIndexOfMaximalComponentFor(seq, startPos) == 0 ? StrandedLocatedSequenceAnnotationWithLength.Strand.FORWARD : StrandedLocatedSequenceAnnotationWithLength.Strand.REVERSE;
    }

    /** Per-flank weights and distributions handed to the wrapped {@link Mutable}. */
    private static final class Container {
        private final double[] weights;
        private final double[][][][] distribution;

        private Container(int len) {
            this.weights = new double[len];
            this.distribution = new double[len][][][];
        }
    }

    /**
     * Initialization strategies for {@link StrandScoringFunction#initializeUsingPlugIn}.
     */
    public static enum InitMethod {
        /** Initialize the wrapped function on the data as given. */
        INIT_FORWARD_STRAND,
        /** Initialize the wrapped function on the reverse complement of the data. */
        INIT_BACKWARD_STRAND,
        /** Randomly assign each sequence to one strand before initializing. */
        INIT_BOTH_STRANDS;

    }
}

