/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.mixture;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.termination.TerminationCondition;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.sampling.BurnInTest;
import de.jstacs.sequenceScores.statisticalModels.trainable.TrainableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.trainable.mixture.AbstractMixtureTrainSM;
import de.jstacs.utils.random.MRGParams;
import de.jstacs.utils.random.MultivariateRandomGenerator;
import java.util.Random;

/**
 * A two-component mixture over the two DNA strands that uses a single underlying model.
 *
 * <p>Component 0 ("forward strand") scores a sequence directly with {@code model[0]}.
 * Component 1 ("backward strand") scores the reverse complement of the sequence with the same
 * {@code model[0]}. The model therefore requires an alphabet that supports reverse complements.
 */
public class StrandTrainSM
extends AbstractMixtureTrainSM {

    /**
     * Creates a strand model; all public constructors delegate here.
     *
     * @param model the model used for both strands
     * @param starts passed to {@link AbstractMixtureTrainSM} (reported as "number of starts")
     * @param estimateComponentProbs whether the strand probabilities are estimated
     *        ({@code true}) or fixed to {@code forwardStrandProb} ({@code false})
     * @param componentHyperParams hyperparameters for the component probabilities, or {@code null}
     * @param forwardStrandProb the probability of the forward strand, must lie in [0,1];
     *        the backward strand gets {@code 1 - forwardStrandProb}
     * @param algorithm the training algorithm (EM or Gibbs sampling)
     * @param alpha passed through to {@link AbstractMixtureTrainSM}
     * @param tc passed through to {@link AbstractMixtureTrainSM}
     * @param parametrization passed through to {@link AbstractMixtureTrainSM}
     * @param initialIteration passed through; the EM constructors pass 0
     * @param stationaryIteration passed through; the EM constructors pass 0
     * @param burnInTest passed through; the EM constructors pass {@code null}
     * @throws IllegalArgumentException if {@code forwardStrandProb} is not in [0,1], or if the
     *         superclass rejects an argument
     * @throws WrongAlphabetException if the model's alphabet cannot be reverse-complemented
     */
    protected StrandTrainSM(TrainableStatisticalModel model, int starts, boolean estimateComponentProbs, double[] componentHyperParams, double forwardStrandProb, AbstractMixtureTrainSM.Algorithm algorithm, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parametrization, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        super(model.getLength(), new TrainableStatisticalModel[]{model}, null, 2, starts, estimateComponentProbs, componentHyperParams, StrandTrainSM.strandWeights(forwardStrandProb), algorithm, alpha, tc, parametrization, initialIteration, stationaryIteration, burnInTest);
        if (!this.alphabets.isReverseComplementable()) {
            throw new WrongAlphabetException("The given model uses an AlphabetContainer that can not be used for building a reverse complement.");
        }
    }

    /**
     * Builds the weight vector {forward, backward} from the forward strand probability.
     * It is static so it can run inside the {@code super(...)} argument list.
     *
     * @param forwardStrandProb the probability of the forward strand
     * @return {@code {forwardStrandProb, 1 - forwardStrandProb}}
     * @throws IllegalArgumentException if the value is NaN or outside [0,1]
     */
    private static double[] strandWeights(double forwardStrandProb) {
        // Negated comparison so that NaN is rejected as well.
        if (!(forwardStrandProb >= 0.0 && forwardStrandProb <= 1.0)) {
            throw new IllegalArgumentException("The probability of the forward strand has to be in [0,1], but was " + forwardStrandProb + ".");
        }
        return new double[]{forwardStrandProb, 1.0 - forwardStrandProb};
    }

    /**
     * EM training with estimated strand probabilities (initial forward probability 0.5).
     */
    public StrandTrainSM(TrainableStatisticalModel model, int starts, double[] componentHyperParams, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parametrization) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        this(model, starts, true, componentHyperParams, 0.5, AbstractMixtureTrainSM.Algorithm.EM, alpha, tc, parametrization, 0, 0, null);
    }

    /**
     * EM training with a fixed forward strand probability.
     */
    public StrandTrainSM(TrainableStatisticalModel model, int starts, double forwardStrandProb, double alpha, TerminationCondition tc, AbstractMixtureTrainSM.Parameterization parametrization) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        this(model, starts, false, null, forwardStrandProb, AbstractMixtureTrainSM.Algorithm.EM, alpha, tc, parametrization, 0, 0, null);
    }

    /**
     * Gibbs sampling with estimated strand probabilities (initial forward probability 0.5).
     */
    public StrandTrainSM(TrainableStatisticalModel model, int starts, double[] componentHyperParams, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        this(model, starts, true, componentHyperParams, 0.5, AbstractMixtureTrainSM.Algorithm.GIBBS_SAMPLING, 0.0, null, AbstractMixtureTrainSM.Parameterization.LAMBDA, initialIteration, stationaryIteration, burnInTest);
    }

    /**
     * Gibbs sampling with a fixed forward strand probability.
     */
    public StrandTrainSM(TrainableStatisticalModel model, int starts, double forwardStrandProb, int initialIteration, int stationaryIteration, BurnInTest burnInTest) throws CloneNotSupportedException, IllegalArgumentException, WrongAlphabetException {
        this(model, starts, false, null, forwardStrandProb, AbstractMixtureTrainSM.Algorithm.GIBBS_SAMPLING, 0.0, null, AbstractMixtureTrainSM.Parameterization.LAMBDA, initialIteration, stationaryIteration, burnInTest);
    }

    /**
     * Restores a model from its XML/StringBuffer representation (handled by the superclass).
     *
     * @param stringBuff the serialized representation
     * @throws NonParsableException if it cannot be parsed
     */
    public StrandTrainSM(StringBuffer stringBuff) throws NonParsableException {
        super(stringBuff);
    }

    /**
     * Stores the training data as one data set of interleaved strands.
     * Index {@code 2*i} holds the i-th input sequence and index {@code 2*i+1} holds its
     * reverse complement. The other methods of this class rely on this layout.
     *
     * @param s the training sequences
     * @throws Exception if building the reverse complement or the data set fails
     */
    @Override
    public void setTrainData(DataSet s) throws Exception {
        int n = s.getNumberOfElements();
        Sequence[] seq = new Sequence[2 * n];
        for (int i = 0; i < n; ++i) {
            seq[2 * i] = s.getElementAt(i);
            seq[2 * i + 1] = seq[2 * i].reverseComplement();
        }
        this.sample = new DataSet[]{new DataSet("sample of both strands from " + s.getAnnotation(), seq)};
    }

    /**
     * Draws random initial strand responsibilities for every original sequence.
     * It then accumulates the per-strand totals and computes initial parameters from them.
     *
     * @param dataWeights per-original-sequence weights, or {@code null} for weight 1
     * @param m generator for the random responsibilities
     * @param params generator parameters, one per original sequence
     * @return the sequence weights, laid out like the interleaved training sample
     * @throws Exception if parameter estimation fails
     */
    @Override
    protected double[][] doFirstIteration(double[] dataWeights, MultivariateRandomGenerator m, MRGParams[] params) throws Exception {
        int total = this.sample[0].getNumberOfElements();
        double[][] seqweights = this.createSeqWeightsArray();
        double[] w = new double[2];
        this.initWithPrior(w);
        // The sample holds forward/reverse pairs, so there are total/2 original sequences.
        int numberOfSequences = total / 2;
        if (dataWeights == null) {
            for (int i = 0; i < numberOfSequences; ++i) {
                // Fill the pair [2*i, 2*i+1] in place.
                m.generate(seqweights[0], 2 * i, 2, params[i]);
                w[0] += seqweights[0][2 * i];
                w[1] += seqweights[0][2 * i + 1];
            }
        } else {
            for (int i = 0; i < numberOfSequences; ++i) {
                double[] help = m.generate(2, params[i]);
                for (int strand = 0; strand < 2; ++strand) {
                    seqweights[0][2 * i + strand] = dataWeights[i] * help[strand];
                    w[strand] += seqweights[0][2 * i + strand];
                }
            }
        }
        this.getNewParameters(0, seqweights, w);
        return seqweights;
    }

    /**
     * Recomputes strand responsibilities for every forward/reverse pair and accumulates the totals.
     *
     * @param dataWeights per-original-sequence weights, or {@code null} for weight 1
     * @param w output: receives prior-initialized per-strand totals
     * @param seqweights output: per-element weights, laid out like the interleaved sample
     * @return the sum of the {@code modifyWeights} results, each scaled by its sequence weight
     * @throws Exception if scoring fails
     */
    @Override
    protected double getNewWeights(double[] dataWeights, double[] w, double[][] seqweights) throws Exception {
        double logLikelihood = 0.0;
        double currentWeight = 1.0;
        this.initWithPrior(w);
        double[] help = new double[2];
        // pos walks over forward/reverse pairs; seqIdx indexes the original sequences.
        for (int pos = 0, seqIdx = 0; pos < seqweights[0].length; pos += 2, ++seqIdx) {
            if (dataWeights != null) {
                currentWeight = dataWeights[seqIdx];
            }
            help[0] = this.model[0].getLogProbFor(this.sample[0].getElementAt(pos)) + this.logWeights[0];
            help[1] = this.model[0].getLogProbFor(this.sample[0].getElementAt(pos + 1)) + this.logWeights[1];
            // modifyWeights rewrites help in place; the values it leaves are used as per-strand
            // weights below (presumably normalized posteriors — see AbstractMixtureTrainSM).
            logLikelihood += this.modifyWeights(help) * currentWeight;
            seqweights[0][pos] = help[0] * currentWeight;
            w[0] += seqweights[0][pos];
            seqweights[0][pos + 1] = help[1] * currentWeight;
            w[1] += seqweights[0][pos + 1];
        }
        return logLikelihood;
    }

    /**
     * Returns a human-readable summary of the model and its training configuration.
     */
    @Override
    public String toString() {
        // No presizing: the previous capacity of model.length * 100000 chars was far larger than needed.
        StringBuilder sb = new StringBuilder();
        sb.append("Strand model with parameter estimation by " + this.getNameOfAlgorithm() + ": \n");
        sb.append("number of starts:\t" + this.starts + "\n");
        switch (this.algorithm) {
            case EM: {
                sb.append(this.weights[0] + "\tforward strand\n");
                sb.append(this.weights[1] + "\tbackward strand\n\n");
                sb.append(this.model[0].toString());
                break;
            }
            case GIBBS_SAMPLING: {
                // Constructors do not reject a null burn-in test, so guard against it here.
                String burnIn = this.burnInTest == null ? "none" : this.burnInTest.getInstanceName();
                sb.append("burn in test              :\t" + burnIn + "\n");
                sb.append("length of stationary phase:\t" + this.stationaryIteration + "\n");
                sb.append("strand model component:" + this.model[0].getInstanceName() + "\n");
                break;
            }
            default: {
                throw new IllegalArgumentException("The type of algorithm is unknown.");
            }
        }
        return sb.toString();
    }

    /**
     * Emits sequences from the underlying model.
     * Each sequence is kept with probability {@code weights[0]}; otherwise it is replaced by its
     * reverse complement.
     *
     * @param n the number of sequences
     * @param lengths the lengths, forwarded to the underlying model
     * @return the emitted sequences
     * @throws NotTrainedException if the model is not trained
     * @throws Exception if emission fails
     */
    @Override
    protected Sequence[] emitDataSetUsingCurrentParameterSet(int n, int ... lengths) throws NotTrainedException, Exception {
        DataSet emitted = this.model[0].emitDataSet(n, lengths);
        Random r = new Random();
        Sequence[] seq = new Sequence[emitted.getNumberOfElements()];
        for (int i = 0; i < seq.length; ++i) {
            Sequence current = emitted.getElementAt(i);
            seq[i] = r.nextDouble() < this.weights[0] ? current : current.reverseComplement();
        }
        return seq;
    }

    /**
     * Returns the log-weight of a strand component plus the model score of {@code s[start..end]}
     * on that strand.
     * For the backward strand, the window is mirrored onto the reverse complement
     * (position {@code i} maps to {@code length - 1 - i}).
     *
     * @param component 0 for the forward strand, 1 for the backward strand
     * @param s the sequence
     * @param start the first position of the window
     * @param end the last position of the window, mirrored as an inclusive bound
     * @return the component's log score
     * @throws IndexOutOfBoundsException if {@code component} is neither 0 nor 1
     * @throws Exception if scoring fails
     */
    @Override
    protected double getLogProbUsingCurrentParameterSetFor(int component, Sequence s, int start, int end) throws Exception {
        if (component == 0) {
            return this.logWeights[0] + this.model[0].getLogProbFor(s, start, end);
        }
        if (component == 1) {
            int length = s.getLength();
            return this.logWeights[1] + this.model[0].getLogProbFor(s.reverseComplement(), length - end - 1, length - start - 1);
        }
        throw new IndexOutOfBoundsException("component has to be in [0,1]; 0 = forward strand, 1 = backward strand");
    }
}

