/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.discrete.homogeneous;

import de.jstacs.NotTrainedException;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.EmptyDataSetException;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.io.NonParsableException;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.Constraint;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.DGTrainSMParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.DiscreteGraphicalTrainSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.homogeneous.parameters.HomogeneousTrainSMParameterSet;
import java.util.Random;

/**
 * Abstract base class for homogeneous discrete models whose maximal Markov order is given by
 * {@link #order}.
 * <p>
 * {@link #powers} caches the powers of the size of the alphabet at position 0
 * ({@code powers[i] = |A|^i}). They are used to map a window of consecutive symbols to an index
 * into a (conditional) probability table held by {@link HomCondProb}.
 */
public abstract class HomogeneousTrainSM
extends DiscreteGraphicalTrainSM {
    /**
     * {@code powers[i]} = (size of the alphabet at position 0)^i, for
     * {@code i = 0 .. max(order, 1)}. Filled in {@link #set(DGTrainSMParameterSet, boolean)}.
     */
    protected int[] powers;
    /** The Markov order, read from parameter index 2 of the parameter set in {@code set}. */
    protected byte order;

    /**
     * Creates a new model from the given parameter set.
     *
     * @param params the parameters of the model
     * @throws CloneNotSupportedException propagated from the super constructor
     * @throws IllegalArgumentException propagated from the super constructor
     * @throws NonParsableException propagated from the super constructor
     */
    public HomogeneousTrainSM(HomogeneousTrainSMParameterSet params) throws CloneNotSupportedException, IllegalArgumentException, NonParsableException {
        super(params);
    }

    /**
     * Restores a model from its XML representation.
     *
     * @param stringBuff the XML representation
     * @throws NonParsableException if the representation cannot be parsed (propagated from super)
     */
    public HomogeneousTrainSM(StringBuffer stringBuff) throws NonParsableException {
        super(stringBuff);
    }

    /**
     * Samples {@code no} sequences from the trained model.
     *
     * @param no the number of sequences to sample
     * @param length either a single length used for every sequence, or exactly {@code no}
     *        lengths, one per sequence
     * @return a data set annotated with {@code "sampled from " + getInstanceName()}
     * @throws NotTrainedException if the model has not been trained yet
     * @throws IllegalArgumentException if {@code length} has neither 1 nor {@code no} entries
     * @throws EmptyDataSetException propagated from the {@link DataSet} constructor
     * @throws WrongAlphabetException propagated from {@link #getRandomSequence(Random, int)}
     * @throws WrongSequenceTypeException propagated from {@link #getRandomSequence(Random, int)}
     */
    @Override
    public final DataSet emitDataSet(int no, int ... length) throws NotTrainedException, IllegalArgumentException, EmptyDataSetException, WrongAlphabetException, WrongSequenceTypeException {
        if (!this.trained) {
            throw new NotTrainedException();
        }
        Sequence[] seq = new Sequence[no];
        boolean sameLength = length.length == 1;
        if (!sameLength && length.length != no) {
            throw new IllegalArgumentException("The dimension of the array length is not correct.");
        }
        // Use one generator for the whole sample. A fresh Random per sequence is wasteful and
        // produces many separately seeded short streams instead of one stream.
        Random r = new Random();
        for (int i = 0; i < no; i++) {
            seq[i] = this.getRandomSequence(r, sameLength ? length[0] : length[i]);
        }
        return new DataSet("sampled from " + this.getInstanceName(), seq);
    }

    /**
     * Draws a single random sequence from the model.
     *
     * @param var1 the random generator to use
     * @param var2 the length of the sequence
     * @return the sampled sequence
     * @throws WrongAlphabetException if the sequence cannot be created
     * @throws WrongSequenceTypeException if the sequence cannot be created
     */
    protected abstract Sequence getRandomSequence(Random var1, int var2) throws WrongAlphabetException, WrongSequenceTypeException;

    /**
     * Returns the Markov order of this model.
     *
     * @return {@link #order}
     */
    @Override
    public byte getMaximalMarkovOrder() {
        return this.order;
    }

    /**
     * This model provides no numerical characteristics.
     *
     * @return always {@code null}
     */
    @Override
    public NumericalResultSet getNumericalCharacteristics() throws Exception {
        return null;
    }

    /**
     * Validates the arguments via {@link #check(Sequence, int, int)} and then returns
     * {@link #logProbFor(Sequence, int, int)}.
     *
     * @param sequence the sequence
     * @param startpos the first position to score
     * @param endpos the last position to score; must be smaller than the sequence length
     * @return the log probability of the subsequence
     * @throws NotTrainedException if the model has not been trained (via {@code check})
     * @throws Exception propagated from {@code check} / {@code logProbFor}
     */
    @Override
    public final double getLogProbFor(Sequence sequence, int startpos, int endpos) throws NotTrainedException, Exception {
        this.check(sequence, startpos, endpos);
        return this.logProbFor(sequence, startpos, endpos);
    }

    /**
     * Trains on the given data sets with a {@code null} weight array for every data set.
     * NOTE(review): presumably {@code null} weights mean "unweighted" in the implementing
     * {@link #train(DataSet[], double[][])} — confirm in subclasses.
     *
     * @param data the data sets
     * @throws Exception propagated from {@link #train(DataSet[], double[][])}
     */
    public void train(DataSet[] data) throws Exception {
        this.train(data, new double[data.length][]);
    }

    /**
     * Trains on the given data sets using the given per-sequence weights.
     *
     * @param var1 the data sets
     * @param var2 the weights, one array per data set (entries may be {@code null})
     * @throws Exception if training fails
     */
    public abstract void train(DataSet[] var1, double[][] var2) throws Exception;

    /**
     * Applies the parameters, reads the Markov order from parameter index 2 and fills
     * {@link #powers} with the powers of the size of the alphabet at position 0.
     */
    @Override
    protected void set(DGTrainSMParameterSet params, boolean trained) throws CloneNotSupportedException, NonParsableException {
        super.set(params, trained);
        this.order = (Byte)params.getParameterAt(2).getValue();
        // At least two entries so that powers[1] (the alphabet size) exists even for order 0.
        this.powers = new int[Math.max(this.order + 1, 2)];
        this.powers[0] = 1;
        int alphabetSize = (int)this.alphabets.getAlphabetLengthAt(0);
        for (int i = 1; i < this.powers.length; i++) {
            this.powers[i] = alphabetSize * this.powers[i - 1];
        }
    }

    /**
     * Performs the checks of the super class and additionally rejects
     * {@code endpos >= sequence.getLength()}.
     *
     * @throws IllegalArgumentException if {@code endpos} is not inside the sequence
     */
    @Override
    protected void check(Sequence sequence, int startpos, int endpos) throws NotTrainedException, IllegalArgumentException {
        super.check(sequence, startpos, endpos);
        if (endpos >= sequence.getLength()) {
            throw new IllegalArgumentException("This endposition is impossible. Try: endposistion < sequence.length");
        }
    }

    /**
     * Draws an entry from the frequencies {@code distr.getFreq(start) .. distr.getFreq(end)}
     * (both inclusive). It subtracts successive frequencies from {@code randNo} until
     * {@code randNo} no longer exceeds the current frequency.
     *
     * @param distr the table holding the frequencies
     * @param start the index of the first candidate entry
     * @param end the index of the last candidate entry (inclusive)
     * @param randNo the random number; NOTE(review): presumably drawn from [0,1) and the block
     *        normalized — if {@code randNo} exceeds the block's mass, the result is
     *        {@code end - start + 1}
     * @return the offset of the chosen entry relative to {@code start}
     */
    protected final int chooseFromDistr(Constraint distr, int start, int end, double randNo) {
        int c = start;
        // Check the bound before reading the frequency so we never index past 'end'.
        while (c <= end && randNo > distr.getFreq(c)) {
            randNo -= distr.getFreq(c++);
        }
        return c - start;
    }

    /**
     * Computes the log probability of the subsequence without argument checks.
     *
     * @param var1 the sequence
     * @param var2 the first position
     * @param var3 the last position
     * @return the log probability
     */
    protected abstract double logProbFor(Sequence var1, int var2, int var3);

    /**
     * Deep-copies an array of {@link HomCondProb}s using the copy constructor.
     *
     * @param p the array to copy
     * @return a new array holding a copy of every element
     */
    protected HomCondProb[] cloneHomProb(HomCondProb[] p) {
        HomCondProb[] condProb = new HomCondProb[p.length];
        for (int i = 0; i < condProb.length; i++) {
            condProb[i] = new HomCondProb(p[i]);
        }
        return condProb;
    }

    /**
     * A {@link Constraint} storing the (conditional) probabilities of the homogeneous model
     * together with a cache of their natural logarithms.
     * <p>
     * An entry index encodes the symbols at the used positions as a number in base
     * {@code powers[1]}; the first used position is the most significant digit.
     */
    protected class HomCondProb
    extends Constraint {
        /** Cached {@code Math.log(freq[i])}; {@code null} until first computed. */
        private double[] lnFreq;
        private static final String XML_TAG = "HomCondProb";

        /**
         * Creates a new table.
         *
         * @param pos the used (relative) positions
         * @param n the number of entries
         */
        public HomCondProb(int[] pos, int n) {
            super(pos, n);
        }

        /**
         * Restores a table from XML; the logarithms are recomputed in
         * {@link #extractAdditionalInfo(StringBuffer)}.
         *
         * @param xml the XML representation
         * @throws NonParsableException if the representation cannot be parsed
         */
        public HomCondProb(StringBuffer xml) throws NonParsableException {
            super(xml);
        }

        /**
         * Copy constructor. It copies the used positions and the frequencies, and recomputes the
         * logarithms if {@code old} had them. The counts of {@code old} are not copied.
         *
         * @param old the table to copy
         */
        public HomCondProb(HomCondProb old) {
            this(old.usedPositions, old.freq.length);
            System.arraycopy(old.freq, 0, this.freq, 0, this.freq.length);
            if (old.lnFreq != null) {
                this.lnFreq(0, this.freq.length);
            }
        }

        /**
         * Estimates the frequencies from the counts using the pseudocount
         * {@code ess / getNumberOfSpecificConstraints()}. With a single used position the whole
         * table is one distribution. Otherwise, each consecutive block of {@code powers[1]}
         * entries (one context) is estimated separately.
         *
         * @param ess the equivalent sample size
         */
        @Override
        public void estimate(double ess) {
            double pc = ess / (double)this.getNumberOfSpecificConstraints();
            if (this.usedPositions.length == 1) {
                this.estimateUnConditional(0, this.freq.length, pc, false);
            } else {
                int a = HomogeneousTrainSM.this.powers[1];
                for (int block = 0; block < this.freq.length; block += a) {
                    this.estimateUnConditional(block, block + a, pc, false);
                }
            }
        }

        /**
         * Returns the cached natural logarithm of the frequency at {@code index}.
         *
         * @param index the entry index
         * @return {@code Math.log(freq[index])} as computed at the last estimation/parse/copy
         */
        public double getLnFreq(int index) {
            return this.lnFreq[index];
        }

        /**
         * Computes the entry index for the symbols at {@code start + usedPositions[k]}.
         *
         * @param seq the sequence
         * @param start the offset in the sequence
         * @return the entry index (first used position is the most significant digit)
         */
        @Override
        public int satisfiesSpecificConstraint(Sequence seq, int start) {
            int erg = 0;
            int p = this.usedPositions.length - 1;
            for (int counter = 0; counter < this.usedPositions.length; counter++, p--) {
                erg += HomogeneousTrainSM.this.powers[p] * seq.discreteVal(start + this.usedPositions[counter]);
            }
            return erg;
        }

        /**
         * Returns the used positions as {@code "p0, p1, ... -> pLast"}, or just {@code "p0"} for a
         * single position.
         */
        @Override
        public String toString() {
            int l = this.usedPositions.length - 1;
            StringBuilder sb = new StringBuilder();
            if (l > 0) {
                sb.append(this.usedPositions[0]);
                for (int i = 1; i < l; i++) {
                    sb.append(", ").append(this.usedPositions[i]);
                }
                sb.append(" -> ");
            }
            return sb.append(this.usedPositions[l]).toString();
        }

        /**
         * Adds {@code weight} to the counts of every window ending at positions
         * {@code start .. seq.getLength() - 1}, using a rolling index.
         *
         * @param seq the sequence
         * @param weight the weight added per window
         * @param start the first position whose symbol is appended to the index
         * @param prevIndex the index of the window ending just before {@code start}
         */
        public final void addAll(Sequence seq, double weight, int start, int prevIndex) {
            int window = HomogeneousTrainSM.this.powers[HomogeneousTrainSM.this.order];
            int a = HomogeneousTrainSM.this.powers[1];
            int l = seq.getLength();
            for (; start < l; start++) {
                // Keep the last 'order' symbols (mod a^order), shift them and append the new symbol.
                prevIndex = prevIndex % window * a + seq.discreteVal(start);
                this.counts[prevIndex] += weight;
            }
        }

        /** Nothing extra is written; the logarithms are recomputed when parsing. */
        @Override
        protected void appendAdditionalInfo(StringBuffer xml) {
        }

        @Override
        protected String getXMLTag() {
            return XML_TAG;
        }

        /**
         * Estimates the block {@code [start, end)} via the super class and refreshes the cached
         * logarithms of that block.
         */
        @Override
        protected void estimateUnConditional(int start, int end, double pc, boolean exceptionWhenNoData) {
            super.estimateUnConditional(start, end, pc, exceptionWhenNoData);
            this.lnFreq(start, end);
        }

        /** Recomputes {@code lnFreq[i] = Math.log(freq[i])} for {@code i in [start, end)}. */
        private void lnFreq(int start, int end) {
            if (this.lnFreq == null) {
                this.lnFreq = new double[this.freq.length];
            }
            for (int i = start; i < end; i++) {
                this.lnFreq[i] = Math.log(this.freq[i]);
            }
        }

        /** Recomputes all cached logarithms after the frequencies have been parsed. */
        @Override
        protected void extractAdditionalInfo(StringBuffer xml) throws NonParsableException {
            this.lnFreq(0, this.freq.length);
        }

        /**
         * Returns a description {@code "P(X_last=a | X_prev=b, ..., X_first=c)"} of entry
         * {@code i}, or {@code "P(X_p=a)"} for a single used position.
         *
         * @param con the alphabets used to decode the symbols
         * @param i the entry index
         * @return the description
         */
        @Override
        public String getDescription(AlphabetContainer con, int i) {
            int n = this.usedPositions.length;
            String[] terms = new String[n];
            for (int j = 0; j < n; j++) {
                int p = HomogeneousTrainSM.this.powers[n - 1 - j];
                DiscreteAlphabet d = (DiscreteAlphabet)con.getAlphabetAt(this.usedPositions[j]);
                terms[j] = "X_" + this.usedPositions[j] + "=" + d.getSymbolAt(i / p);
                i %= p;
            }
            // Build "P(last | previous, ..., first)" directly instead of patching a joined string,
            // so symbols that themselves contain ", " cannot misplace the bar.
            StringBuilder res = new StringBuilder("P(").append(terms[n - 1]);
            for (int j = n - 2; j >= 0; j--) {
                res.append(j == n - 2 ? " | " : ", ").append(terms[j]);
            }
            return res.append(')').toString();
        }
    }
}

