/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.models.mixture.motif;

import de.jstacs.NonParsableException;
import de.jstacs.WrongAlphabetException;
import de.jstacs.algorithms.optimization.termination.TerminationCondition;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.models.Model;
import de.jstacs.models.mixture.AbstractMixtureModel;
import de.jstacs.models.mixture.StrandModel;
import de.jstacs.models.mixture.gibbssampling.BurnInTest;
import de.jstacs.models.mixture.motif.HiddenMotifMixture;
import de.jstacs.models.mixture.motif.positionprior.PositionPrior;
import de.jstacs.models.mixture.motif.positionprior.UniformPositionPrior;
import de.jstacs.motifDiscovery.MotifDiscoverer;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.MRGParams;
import de.jstacs.utils.random.MultivariateRandomGenerator;
import java.util.Arrays;
import java.util.LinkedList;

/**
 * Two-component mixture model for finding a single motif in a set of sequences.
 *
 * <p>Component 0 ("motif") explains a sequence by the background model plus one occurrence of the
 * motif model at an unknown start position, weighted by a {@link PositionPrior}. Component 1
 * ("no motif") explains it by the background model alone.
 *
 * <p>For training, every window of motif length of every sequence is stored as a candidate
 * occurrence in {@code sample[0]} (see {@link #setTrainData(Sample)}). The sequence weights of
 * these candidates are the start-position weights used to estimate the motif model. Unless only
 * the motif model is trained, the background model is estimated from the complete sequences and
 * from the flanks left and right of every candidate start.
 *
 * <p>NOTE(review): this class was recovered with a decompiler (CFR). Several index computations
 * were reconstructed by hand and should be compared against the original source if available.
 */
public class SingleHiddenMotifMixture
extends HiddenMotifMixture {
    /**
     * Entry {@code i} is the index of training sequence {@code i} (the complete sequence) in
     * {@code sample[1]}. The additional last entry marks the end of the data. Filled by
     * {@link #setTrainData(Sample)}.
     */
    private int[] refBgSample;
    /** If {@code true}, only the motif model is optimised; the background model is left untouched. */
    private boolean trainOnlyMotifModel;
    /**
     * Maximal Markov order of the background model. It determines how many positions to the left
     * and right of a motif occurrence are rescored when the motif is inserted.
     */
    protected byte bgMaxMarkovOrder;

    /**
     * Creates a new mixture from a motif model and a background model.
     *
     * @param motif the model of the motif occurrence (component 0)
     * @param bg the background model (component 1 and the flanks of component 0)
     * @param trainOnlyMotifModel if {@code true}, the background model is not optimised
     * @param starts the number of starts of the training algorithm (passed to the super-class)
     * @param componentHyperParams hyper-parameters for the component probabilities
     * @param weights the component probabilities {motif, no motif}; the super-class receives
     *        {@code weights == null}, so they are presumably estimated when {@code null}
     * @param posPrior the prior on the motif start; a {@link UniformPositionPrior} if {@code null}
     * @param algorithm the training algorithm
     * @param alpha passed to the super-class
     * @param tc the termination condition of the training algorithm
     * @param parametrization the parameterization passed to the super-class
     * @param initialIteration passed to the super-class (used by Gibbs sampling)
     * @param stationaryIteration passed to the super-class (used by Gibbs sampling)
     * @param burnInTest the burn-in test (used by Gibbs sampling)
     * @throws CloneNotSupportedException if a model could not be cloned
     * @throws IllegalArgumentException if an argument is not accepted
     * @throws WrongAlphabetException if the alphabets of the models do not match
     */
    protected SingleHiddenMotifMixture(Model motif, Model bg, boolean trainOnlyMotifModel, int starts, double[] componentHyperParams, double[] weights, PositionPrior posPrior, AbstractMixtureModel.Algorithm algorithm, double alpha, TerminationCondition tc, AbstractMixtureModel.Parameterization parametrization, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        super(new Model[]{motif, bg}, new boolean[]{true, !trainOnlyMotifModel}, 2, starts, weights == null, componentHyperParams, weights, posPrior == null ? new UniformPositionPrior() : posPrior, algorithm, alpha, tc, parametrization, initialIteration, stationaryIteration, burnInTest);
        this.bgMaxMarkovOrder = this.model[1].getMaximalMarkovOrder();
        this.trainOnlyMotifModel = trainOnlyMotifModel;
    }

    /**
     * Creates an EM-trained mixture whose component probabilities are estimated using the given
     * hyper-parameters.
     *
     * @param motif the model of the motif occurrence (component 0)
     * @param bg the background model
     * @param trainOnlyMotifModel if {@code true}, the background model is not optimised
     * @param starts the number of starts of the training algorithm
     * @param componentHyperParams hyper-parameters for the component probabilities
     * @param posPrior the prior on the motif start; a {@link UniformPositionPrior} if {@code null}
     * @param alpha passed to the super-class
     * @param tc the termination condition of the EM
     * @param parametrization the parameterization passed to the super-class
     * @throws CloneNotSupportedException if a model could not be cloned
     * @throws IllegalArgumentException if an argument is not accepted
     * @throws WrongAlphabetException if the alphabets of the models do not match
     */
    public SingleHiddenMotifMixture(Model motif, Model bg, boolean trainOnlyMotifModel, int starts, double[] componentHyperParams, PositionPrior posPrior, double alpha, TerminationCondition tc, AbstractMixtureModel.Parameterization parametrization) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        this(motif, bg, trainOnlyMotifModel, starts, componentHyperParams, null, posPrior, AbstractMixtureModel.Algorithm.EM, alpha, tc, parametrization, 0, 0, null);
    }

    /**
     * Creates an EM-trained mixture with fixed component probabilities.
     *
     * @param motif the model of the motif occurrence (component 0)
     * @param bg the background model
     * @param trainOnlyMotifModel if {@code true}, the background model is not optimised
     * @param starts the number of starts of the training algorithm
     * @param motifProb the probability of the motif component; the no-motif component gets
     *        {@code 1 - motifProb}
     * @param posPrior the prior on the motif start; a {@link UniformPositionPrior} if {@code null}
     * @param alpha passed to the super-class
     * @param tc the termination condition of the EM
     * @param parametrization the parameterization passed to the super-class
     * @throws CloneNotSupportedException if a model could not be cloned
     * @throws IllegalArgumentException if an argument is not accepted
     * @throws WrongAlphabetException if the alphabets of the models do not match
     */
    public SingleHiddenMotifMixture(Model motif, Model bg, boolean trainOnlyMotifModel, int starts, double motifProb, PositionPrior posPrior, double alpha, TerminationCondition tc, AbstractMixtureModel.Parameterization parametrization) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        this(motif, bg, trainOnlyMotifModel, starts, null, new double[]{motifProb, 1.0 - motifProb}, posPrior, AbstractMixtureModel.Algorithm.EM, alpha, tc, parametrization, 0, 0, null);
    }

    /**
     * Restores a mixture from its XML representation.
     *
     * @param xml the XML representation
     * @throws NonParsableException if {@code xml} could not be parsed
     */
    public SingleHiddenMotifMixture(StringBuffer xml) throws NonParsableException {
        super(xml);
        this.bgMaxMarkovOrder = this.model[1].getMaximalMarkovOrder();
        // Only the motif model is trained iff none of the models after the motif model is optimised.
        int i = 1;
        while (i < this.model.length && !this.optimizeModel[i]) {
            ++i;
        }
        this.trainOnlyMotifModel = i == this.model.length;
    }

    /**
     * Builds the internal training samples from {@code data}.
     *
     * <p>{@code sample[0]} receives every subsequence of motif length of every sequence, i.e. all
     * candidate motif occurrences, in order. If the background model is trained as well,
     * {@code sample[1]} receives, for each sequence, the complete sequence followed by the left and
     * right flank of every candidate start. Otherwise {@code sample[1]} is {@code data} itself.
     * {@link #refBgSample} records where each complete sequence is located in {@code sample[1]}.
     *
     * @param data the training sequences
     * @throws Exception if a subsequence or sample could not be created
     */
    @Override
    protected void setTrainData(Sample data) throws Exception {
        LinkedList<Sequence> fg = new LinkedList<Sequence>();
        LinkedList<Sequence> bg = new LinkedList<Sequence>();
        int motifLength = this.getMotifLength(0);
        boolean rev = this.model[0] instanceof StrandModel;
        int numSeqs = data.getNumberOfElements();
        this.refBgSample = new int[numSeqs + 1];
        int i = 0;
        while (i < numSeqs) {
            Sequence s = data.getElementAt(i);
            if (rev) {
                // NOTE(review): the return value of reverseComplement() is discarded. If it returns
                // a new Sequence instead of modifying s, this call has no effect -- confirm intent.
                s.reverseComplement();
            }
            if (!this.trainOnlyMotifModel) {
                bg.add(s);
            }
            int lastStart = s.getLength() - motifLength;
            for (int start = 0; start <= lastStart; ++start) {
                fg.add(s.getSubSequence(start, motifLength));
                if (!this.trainOnlyMotifModel) {
                    bg.add(s.getSubSequence(0, start));
                    bg.add(s.getSubSequence(start + motifLength));
                }
            }
            // Advance i in BOTH cases. refBgSample[i] then marks where sequence i (the next one) starts
            // in sample[1]: its own index if sample[1] == data, else its offset in the bg list.
            ++i;
            this.refBgSample[i] = this.trainOnlyMotifModel ? i : bg.size();
        }
        Sequence[] empty = new Sequence[]{};
        this.sample = new Sample[]{new Sample("possible motifs", fg.toArray(empty)), data};
        if (bg.size() != 0) {
            this.sample[1] = new Sample("possible background", bg.toArray(empty));
        }
    }

    /**
     * Creates the arrays for the sequence weights: one entry per candidate motif occurrence and, if
     * the background model is trained, one entry per element of {@code sample[1]}.
     *
     * @return the (zero-initialised) weight arrays
     */
    @Override
    protected double[][] createSeqWeightsArray() {
        double[] motifWeights = new double[this.sample[0].getNumberOfElements()];
        if (this.trainOnlyMotifModel) {
            return new double[][]{motifWeights};
        }
        return new double[][]{motifWeights, new double[this.sample[1].getNumberOfElements()]};
    }

    /**
     * Draws random initial sequence weights and estimates initial parameters from them.
     *
     * <p>If the component probabilities are fixed and the motif probability is 1, only a
     * distribution over the start positions is drawn for each sequence. Otherwise the drawn vector
     * is {@code [no motif, start 0, ..., last start]}.
     *
     * @param dataWeights the weights of the training sequences, or {@code null} for weight 1
     * @param m the random generator for the initial weights
     * @param params the generator parameters, one per training sequence
     * @return the initial sequence weights
     * @throws Exception if the parameters could not be estimated
     */
    @Override
    protected double[][] doFirstIteration(double[] dataWeights, MultivariateRandomGenerator m, MRGParams[] params) throws Exception {
        int fgStart = 0;
        int bgStart = 0;
        int numSeqs = this.refBgSample.length - 1;
        int motifLength = this.getMotifLength(0);
        double[][] seqweights = this.createSeqWeightsArray();
        double[] w = new double[2];
        this.initWithPrior(w);
        if (!this.estimateComponentProbs && this.weights[0] == 1.0) {
            for (int i = 0; i < numSeqs; ++i) {
                int numStarts = this.sample[1].getElementAt(this.refBgSample[i]).getLength() - motifLength + 1;
                m.generate(seqweights[0], fgStart, numStarts, params[i]);
                double d = dataWeights == null ? 1.0 : dataWeights[i];
                int fgEnd = fgStart + numStarts;
                if (this.trainOnlyMotifModel) {
                    for (; fgStart < fgEnd; ++fgStart) {
                        seqweights[0][fgStart] *= d;
                    }
                    continue;
                }
                // the complete sequence gets no weight: every sequence contains the motif
                seqweights[1][bgStart++] = 0.0;
                for (; fgStart < fgEnd; ++fgStart) {
                    seqweights[0][fgStart] *= d;
                    double v = seqweights[0][fgStart];
                    // left and right flank of this start share the weight of the start
                    seqweights[1][bgStart++] = v;
                    seqweights[1][bgStart++] = v;
                }
            }
        } else {
            for (int i = 0; i < numSeqs; ++i) {
                // one extra entry (index 0) for "no motif"
                int numEntries = this.sample[1].getElementAt(this.refBgSample[i]).getLength() - motifLength + 2;
                double[] helpArray = m.generate(numEntries, params[i]);
                double d = dataWeights == null ? 1.0 : dataWeights[i];
                w[0] = w[0] + (1.0 - helpArray[0]) * d;
                w[1] = w[1] + helpArray[0] * d;
                if (this.trainOnlyMotifModel) {
                    for (int j = 1; j < helpArray.length; ++j, ++fgStart) {
                        seqweights[0][fgStart] = helpArray[j] * d;
                    }
                } else {
                    seqweights[1][bgStart++] = helpArray[0] * d;
                    for (int j = 1; j < helpArray.length; ++j, ++fgStart) {
                        seqweights[0][fgStart] = helpArray[j] * d;
                        double v = seqweights[0][fgStart];
                        seqweights[1][bgStart++] = v;
                        seqweights[1][bgStart++] = v;
                    }
                }
            }
        }
        this.getNewParameters(0, seqweights, w);
        return seqweights;
    }

    /**
     * E-step: computes the new sequence weights and component weights from the current
     * parameters.
     *
     * @param dataWeights the weights of the training sequences, or {@code null} for weight 1
     * @param w output: the component weights (initialised with the prior)
     * @param seqweights output: the sequence weights, laid out as by {@link #createSeqWeightsArray()}
     * @return the weighted log-likelihood of the training data
     * @throws Exception if a probability could not be computed
     */
    @Override
    protected double getNewWeights(double[] dataWeights, double[] w, double[][] seqweights) throws Exception {
        double ll = 0.0;
        double currentWeight = 1.0;
        int seqIndex = 0;
        int motifLength = this.getMotifLength(0);
        int w0Index = 0;
        int w1Index = 0;
        double[] help = new double[2];
        this.initWithPrior(w);
        while (w0Index < this.sample[0].getNumberOfElements()) {
            Sequence seq = this.sample[1].getElementAt(this.refBgSample[seqIndex]);
            if (dataWeights != null) {
                currentWeight = dataWeights[seqIndex];
            }
            int l = seq.getLength();
            double logPSeq = this.model[1].getLogProbFor(seq, 0, l - 1);
            int end = l - motifLength;
            int b1 = -this.bgMaxMarkovOrder;
            int b2 = motifLength + this.bgMaxMarkovOrder - 1;
            int start = w0Index;
            for (int j = 0; j <= end; ++j, ++b1, ++b2, ++w0Index) {
                int s = Math.max(b1, 0);
                int e = Math.min(b2, l - 1);
                // Score relative to the pure background: remove the background log-probability of
                // the window [s, e] (the motif plus bgMaxMarkovOrder positions on each side, clipped
                // to the sequence), then add the motif and the background of the window's left and
                // right parts.
                seqweights[0][w0Index] = this.posPrior.getLogPriorForPositions(l, j) + this.model[0].getLogProbFor(this.sample[0].getElementAt(w0Index), 0, motifLength - 1) - this.model[1].getLogProbFor(seq, s, e) + this.model[1].getLogProbFor(seq, s, j - 1) + this.model[1].getLogProbFor(seq, j + motifLength, e);
            }
            ll += currentWeight * (logPSeq + this.modify(help, seqweights[0], start, w0Index));
            help[0] = help[0] * currentWeight;
            w[0] = w[0] + help[0];
            w[1] = w[1] + currentWeight * help[1];
            if (this.trainOnlyMotifModel) {
                for (; start < w0Index; ++start) {
                    seqweights[0][start] *= help[0];
                }
            } else {
                seqweights[1][w1Index++] = currentWeight * help[1];
                for (; start < w0Index; ++start) {
                    seqweights[0][start] *= help[0];
                    double v = seqweights[0][start];
                    seqweights[1][w1Index++] = v;
                    seqweights[1][w1Index++] = v;
                }
            }
            ++seqIndex;
        }
        return ll;
    }

    /**
     * Normalises the start-position scores of one sequence and the two component scores.
     *
     * <p>For EM, {@code startpos[start..end)} is normalised via
     * {@link Normalisation#logSumNormalisation}. {@code containsMotif} is filled with the motif and
     * no-motif scores and normalised the same way.
     *
     * @param containsMotif output array of length 2: {motif, no motif}
     * @param startpos the log-scores of the start positions (normalised in place)
     * @param start first index of the current sequence in {@code startpos}
     * @param end index after the last entry of the current sequence
     * @return the value returned by the normalisation of {@code containsMotif}
     * @throws IllegalArgumentException for Gibbs sampling or an unknown algorithm
     */
    protected double modify(double[] containsMotif, double[] startpos, int start, int end) {
        switch (this.algorithm) {
            case EM: {
                containsMotif[0] = this.logWeights[0] + Normalisation.logSumNormalisation(startpos, start, end, startpos, start);
                containsMotif[1] = this.logWeights[1];
                return Normalisation.logSumNormalisation(containsMotif, 0, 2, containsMotif, 0);
            }
            case GIBBS_SAMPLING: {
                throw new IllegalArgumentException("Gibbs Sampling currently not implemented.");
            }
        }
        throw new IllegalArgumentException("The type of algorithm is unknown.");
    }

    /**
     * Returns the log-probability of {@code seq[start..end]} under one component, including the
     * component weight.
     *
     * @param component 0 for the motif component, 1 for the no-motif component
     * @param seq the sequence
     * @param start first position (inclusive)
     * @param end last position (inclusive)
     * @return the weighted log-probability
     * @throws Exception if a probability could not be computed
     * @throws IndexOutOfBoundsException for any other component
     */
    @Override
    protected double getLogProbUsingCurrentParameterSetFor(int component, Sequence seq, int start, int end) throws Exception {
        switch (component) {
            case 0: {
                int l = end - start + 1;
                int motifLength = this.getMotifLength(0);
                int lastStart = l - motifLength;
                int b1 = -this.bgMaxMarkovOrder;
                int b2 = motifLength + this.bgMaxMarkovOrder - 1;
                double all = this.model[1].getLogProbFor(seq, start, end);
                double res = Double.NEGATIVE_INFINITY;
                for (int current = 0; current <= lastStart; ++current, ++b1, ++b2) {
                    // window bounds are relative to start
                    int s = Math.max(b1, 0);
                    int e = Math.min(b2, l - 1);
                    res = Normalisation.getLogSum(res, this.posPrior.getLogPriorForPositions(l, current) + this.model[0].getLogProbFor(seq, start + current) - this.model[1].getLogProbFor(seq, start + s, start + e) + this.model[1].getLogProbFor(seq, start + s, start + current - 1) + this.model[1].getLogProbFor(seq, start + current + motifLength, start + e));
                }
                return all + this.logWeights[0] + res;
            }
            case 1: {
                return this.logWeights[1] + this.model[1].getLogProbFor(seq, start, end);
            }
        }
        throw new IndexOutOfBoundsException("This model has only two components (0=motif, 1=no motif).");
    }

    /**
     * Returns the scores of all motif start positions in {@code sequence} from {@code startpos} on.
     *
     * <p>Entry {@code k} refers to a motif starting at {@code startpos + k}. For
     * {@code UNNORMALIZED_JOINT} the log-weight of the component is added. For
     * {@code NORMALIZED_CONDITIONAL} the scores are shifted so that their log-sum is 0.
     *
     * @param component must be 0
     * @param motif must be 0
     * @param sequence the sequence
     * @param startpos the first position of {@code sequence} to consider
     * @param kind the kind of profile
     * @return the log-scores of the start positions
     * @throws Exception if a probability could not be computed
     * @throws IndexOutOfBoundsException for any other component or motif
     */
    @Override
    public double[] getProfileOfScoresFor(int component, int motif, Sequence sequence, int startpos, MotifDiscoverer.KindOfProfile kind) throws Exception {
        if (component == 0 && motif == 0) {
            int motifLength = this.getMotifLength(motif);
            int l = sequence.getLength() - startpos;
            int len = l - motifLength + 1;
            switch (this.algorithm) {
                case EM: {
                    double d = kind == MotifDiscoverer.KindOfProfile.UNNORMALIZED_JOINT ? this.logWeights[component] : 0.0;
                    double[] weights = new double[len];
                    int b1 = -this.bgMaxMarkovOrder;
                    int b2 = motifLength + this.bgMaxMarkovOrder - 1;
                    for (int current = 0; current < len; ++current, ++b1, ++b2) {
                        // b1/b2 are relative to startpos; convert the window to absolute positions
                        // (as getLogProbUsingCurrentParameterSetFor does with start + s / start + e).
                        int s = startpos + Math.max(b1, 0);
                        int e = startpos + Math.min(b2, l - 1);
                        int pos = startpos + current;
                        // the prior is defined on the subsequence of length l, so use the relative position
                        weights[current] = d + this.posPrior.getLogPriorForPositions(l, current) + this.model[0].getLogProbFor(sequence, pos) - this.model[1].getLogProbFor(sequence, s, e) + this.model[1].getLogProbFor(sequence, s, pos - 1) + this.model[1].getLogProbFor(sequence, pos + motifLength, e);
                    }
                    if (kind == MotifDiscoverer.KindOfProfile.NORMALIZED_CONDITIONAL) {
                        double logNorm = Normalisation.getLogSum(0, weights.length, weights);
                        for (int current = 0; current < len; ++current) {
                            weights[current] -= logNorm;
                        }
                    }
                    return weights;
                }
                case GIBBS_SAMPLING: {
                    throw new IllegalArgumentException("Gibbs Sampling currently not implemented.");
                }
            }
            throw new IllegalArgumentException("The type of algorithm is unknown.");
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Returns the minimal sequence length: 0 if the no-motif component can be used (estimated or
     * non-zero weight), otherwise the length of the motif model.
     *
     * @return the minimal sequence length
     */
    @Override
    public int getMinimalSequenceLength() {
        if (this.estimateComponentProbs || this.weights[1] != 0.0) {
            return 0;
        }
        return this.model[0].getLength();
    }

    /**
     * Returns the length of the motif.
     *
     * @param motif must be 0
     * @return the length of the motif model
     * @throws IndexOutOfBoundsException for any other motif index
     */
    @Override
    public int getMotifLength(int motif) {
        if (motif == 0) {
            return this.model[0].getLength();
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Returns the number of motifs, which is always 1.
     *
     * @return 1
     */
    @Override
    public int getNumberOfMotifs() {
        return 1;
    }

    /**
     * Returns the number of motifs in a component: 1 for the motif component, 0 for the no-motif
     * component.
     *
     * @param component 0 or 1
     * @return the number of motifs in {@code component}
     * @throws IndexOutOfBoundsException for any other component
     */
    @Override
    public int getNumberOfMotifsInComponent(int component) {
        switch (component) {
            case 0: {
                return 1;
            }
            case 1: {
                return 0;
            }
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Returns the strand scores of a motif occurrence at {@code startpos}.
     *
     * <p>If the motif model is a {@link StrandModel}, the log-probabilities for strand 0 and 1 are
     * passed through {@link Normalisation#logSumNormalisation(double[])}; presumably this turns
     * them into probabilities in place -- confirm. Otherwise {1, 0} is returned.
     *
     * @param component must be 0
     * @param motif must be 0
     * @param sequence the sequence
     * @param startpos the start of the motif occurrence
     * @return the scores of the two strands
     * @throws Exception if a probability could not be computed
     * @throws IndexOutOfBoundsException for any other component or motif
     */
    @Override
    public double[] getStrandProbabilitiesFor(int component, int motif, Sequence sequence, int startpos) throws Exception {
        if (component == 0 && motif == 0) {
            if (this.model[0] instanceof StrandModel) {
                StrandModel strandModel = (StrandModel) this.model[0];
                Sequence help = sequence.getSubSequence(startpos, this.model[0].getLength());
                double[] logProbs = new double[]{strandModel.getLogProbFor(0, help), strandModel.getLogProbFor(1, help)};
                Normalisation.logSumNormalisation(logProbs);
                return logProbs;
            }
            return new double[]{1.0, 0.0};
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Returns the global index of a motif in a component; only motif 0 of component 0 exists.
     *
     * @param component must be 0
     * @param motif must be 0
     * @return 0
     * @throws IndexOutOfBoundsException for any other component or motif
     */
    @Override
    public int getGlobalIndexOfMotifInComponent(int component, int motif) {
        if (component == 0 && motif == 0) {
            return 0;
        }
        throw new IndexOutOfBoundsException();
    }

    /**
     * Trains the background model on {@code data}.
     *
     * @param data the training data
     * @param weights the weights of the sequences in {@code data}
     * @throws Exception if the background model could not be trained
     */
    @Override
    public void trainBgModel(Sample data, double[] weights) throws Exception {
        this.model[1].train(data, weights);
    }

    /**
     * Moves the start-position weights of every sequence by {@code shift} (circularly within the
     * sequence, in opposite directions for the two strands). The background flank weights are
     * rebuilt from the shifted weights, and the parameters of all trained models are re-estimated
     * from them.
     *
     * @param shift the shift of the motif start
     * @param originalWeights the current sequence weights
     * @param newWeights output: the shifted sequence weights (same layout)
     * @throws Exception if the parameters could not be estimated
     */
    private void estimateShiftedParameters(int shift, double[][] originalWeights, double[][] newWeights) throws Exception {
        for (int k = 0; k < newWeights.length; ++k) {
            Arrays.fill(newWeights[k], 0.0);
        }
        int motifLength = this.model[0].getLength();
        int i = 0;
        int seqIndex = 0;
        int bgIndex = 0;
        while (i < this.sample[0].getNumberOfElements()) {
            // [start, end] are the (inclusive) candidate indices of the current sequence in sample[0]
            int start = i;
            int end = start + this.sample[1].getElementAt(this.refBgSample[seqIndex++]).getLength() - motifLength;
            for (; i <= end; ++i) {
                double[] strandProbs = this.getStrandProbabilitiesFor(0, 0, this.sample[0].getElementAt(i), 0);
                newWeights[0][this.getIndexForCircularShift(start, end, i - shift)] += strandProbs[0] * originalWeights[0][i];
                newWeights[0][this.getIndexForCircularShift(start, end, i + shift)] += strandProbs[1] * originalWeights[0][i];
            }
            if (this.trainOnlyMotifModel) {
                continue;
            }
            // the complete sequence keeps its weight
            newWeights[1][bgIndex] = originalWeights[1][bgIndex];
            ++bgIndex;
            // two flanks were created for EVERY start in [start, end] (inclusive), see setTrainData;
            // stopping one early would misalign bgIndex for all following sequences
            for (; start <= end; ++start) {
                double v = newWeights[0][start];
                newWeights[1][bgIndex++] = v;
                newWeights[1][bgIndex++] = v;
            }
        }
        for (int k = 0; k < newWeights.length; ++k) {
            this.getNewParametersForModel(k, 1, k, newWeights[k]);
        }
    }

    /**
     * Maps {@code proposal} circularly into the inclusive index range [{@code start}, {@code end}].
     * {@code start - 1} wraps to {@code end} and {@code end + 1} wraps to {@code start}; shifts
     * larger than the range wrap repeatedly, so the result always lies within the range.
     *
     * @param start first index of the range (inclusive)
     * @param end last index of the range (inclusive), must not be smaller than {@code start}
     * @param proposal the index to map
     * @return the wrapped index
     */
    private int getIndexForCircularShift(int start, int end, int proposal) {
        int size = end - start + 1;
        int offset = (proposal - start) % size;
        if (offset < 0) {
            offset += size;
        }
        return start + offset;
    }

    /**
     * Runs one start of the training algorithm.
     *
     * <p>For EM, the EM is run to convergence. Then every shift of the motif in
     * [-motifLength/2, motifLength/2] is evaluated by re-estimating the parameters from shifted
     * weights. If a non-zero shift scores best, it is applied and the EM is repeated; otherwise
     * training stops.
     *
     * @param start the index of this start
     * @param dataWeights the weights of the training sequences, or {@code null}
     * @param m the random generator for the initial weights
     * @param params the generator parameters
     * @return the value stored in {@code best} by the final EM run
     * @throws Exception if training fails
     * @throws IllegalArgumentException for an unknown algorithm
     */
    @Override
    protected double iterate(int start, double[] dataWeights, MultivariateRandomGenerator m, MRGParams[] params) throws Exception {
        this.sostream.writeln("========== start: " + start + " ==========");
        switch (this.algorithm) {
            case EM: {
                this.seqWeights = this.doFirstIteration(dataWeights, m, params);
                double[][] help = null;
                Cloneable[] backup = null;
                int maxShift = this.model[0].getLength() / 2;
                int shift;
                do {
                    this.best = this.continueIterations(dataWeights, this.seqWeights);
                    if (help == null) {
                        help = this.createSeqWeightsArray();
                    }
                    // keep the converged models so that each trial shift starts from the same state
                    backup = (Model[]) ArrayHandler.clone((Cloneable[]) this.model);
                    shift = 0;
                    double bestShifted = Double.NEGATIVE_INFINITY;
                    for (int i = -maxShift; i <= maxShift; ++i) {
                        this.estimateShiftedParameters(i, this.seqWeights, help);
                        double current = this.continueIterations(dataWeights, help, 0, 0);
                        if (bestShifted < current) {
                            shift = i;
                            bestShifted = current;
                        }
                        this.sostream.writeln(i + "\t" + current + "\t" + shift);
                        this.model = (Model[]) ArrayHandler.clone((Cloneable[]) backup);
                    }
                    if (shift != 0) {
                        this.estimateShiftedParameters(shift, this.seqWeights, help);
                    }
                } while (shift != 0);
                break;
            }
            case GIBBS_SAMPLING: {
                this.extendSampling(start);
                this.burnInTest.setCurrentSamplingIndex(start);
                this.seqWeights = this.doFirstIteration(dataWeights, m, params);
                this.samplingStopped();
                this.continueIterations(dataWeights, this.seqWeights, this.initialIteration, start);
                break;
            }
            default: {
                throw new IllegalArgumentException("The type of algorithm is unknown.");
            }
        }
        this.algorithmHasBeenRun = true;
        return this.best;
    }
}

