/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous;

import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.InhConstraint;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.SequenceIterator;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.FastDirichletMRGParams;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;

/**
 * Constraint over a strictly increasing set of sequence positions.
 * <p>
 * In addition to the state inherited from {@link InhConstraint}, each specific constraint
 * {@code i} (an index into the inherited count array) carries a parameter {@code lambda[i]}
 * and its exponential {@code expLambda[i]}. Every public mutator of this class keeps both
 * arrays consistent ({@code expLambda[i] == Math.exp(lambda[i])}).
 * <p>
 * Symbols may be read from "corrected" positions instead of the used positions (see
 * {@link #getCorrectedPosition(int)}). When no correction is given, the corrected positions
 * are the very same array object as {@code usedPositions}.
 */
public class MEMConstraint
extends InhConstraint {
    // NOTE: these instance fields deliberately have no initializers. The XML constructor only
    // calls super(xml), which presumably restores them via extractAdditionalInfo(). A field
    // initializer would run after super(...) returns and overwrite the restored values.

    /** Element-wise exponential of {@link #lambda}. */
    private double[] expLambda;
    /** One parameter per specific constraint; allocated with the length of {@code counts}. */
    private double[] lambda;
    /**
     * Positions read by {@link #satisfiesSpecificConstraint(SequenceIterator)}.
     * <p>
     * This is either the same array object as {@code usedPositions} (no correction) or an
     * independent copy. The identity comparisons in {@link #clone()} and
     * {@link #appendAdditionalInfo(StringBuffer)} rely on this.
     */
    private int[] corrected_positions;
    private static final String XML_TAG = "MEMConstraint";

    /**
     * Checks that {@code pos} is strictly increasing, i.e. sorted and free of duplicates.
     *
     * @param pos the positions to check
     * @return {@code pos} itself, so the check can be used inside a {@code super(...)} call
     * @throws IllegalArgumentException if some element is not larger than its predecessor
     */
    private static int[] isSorted(int[] pos) throws IllegalArgumentException {
        for (int i = 1; i < pos.length; ++i) {
            if (pos[i - 1] >= pos[i]) {
                throw new IllegalArgumentException("The position array is not unique.");
            }
        }
        return pos;
    }

    /**
     * Creates a constraint whose corrected positions are its used positions.
     * All lambdas start at {@code 0} and all expLambdas at {@code 1}.
     *
     * @param pos the strictly increasing positions of the constraint
     * @param alphabetLength the alphabet lengths, passed on to {@link InhConstraint}
     * @throws IllegalArgumentException if {@code pos} is not strictly increasing
     */
    public MEMConstraint(int[] pos, int[] alphabetLength) throws IllegalArgumentException {
        super(isSorted(pos), alphabetLength);
        this.expLambda = new double[this.counts.length];
        Arrays.fill(this.expLambda, 1.0);
        this.lambda = new double[this.counts.length];
        // Alias on purpose: "no correction" is encoded by identity with usedPositions.
        this.corrected_positions = this.usedPositions;
    }

    /**
     * Creates a constraint that reads its symbols from {@code corrected_positions}.
     * All lambdas start at {@code 0} and all expLambdas at {@code 1}.
     *
     * @param pos the strictly increasing positions of the constraint
     * @param alphabetLength the alphabet lengths, passed on to {@link InhConstraint}
     * @param corrected_positions the positions actually read from a sequence; copied
     * @throws IllegalArgumentException if {@code pos} is not strictly increasing or the two
     *         position arrays differ in length
     */
    public MEMConstraint(int[] pos, int[] alphabetLength, int[] corrected_positions) throws IllegalArgumentException {
        super(isSorted(pos), alphabetLength);
        if (pos.length != corrected_positions.length) {
            throw new IllegalArgumentException("The length of pos and corrected_positions is not equal.");
        }
        this.expLambda = new double[this.counts.length];
        // exp(0) == 1: keep expLambda consistent with the all-zero lambda, as the other
        // constructor and reset() do (previously this array was left at 0.0).
        Arrays.fill(this.expLambda, 1.0);
        this.lambda = new double[this.counts.length];
        this.corrected_positions = corrected_positions.clone();
    }

    /**
     * Restores a constraint from its XML representation.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public MEMConstraint(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a deep copy.
     * <p>
     * If the corrected positions alias {@code usedPositions}, the clone aliases its own
     * {@code usedPositions} as well.
     */
    @Override
    public MEMConstraint clone() throws CloneNotSupportedException {
        MEMConstraint clone = (MEMConstraint)super.clone();
        clone.expLambda = this.expLambda.clone();
        clone.lambda = this.lambda.clone();
        clone.corrected_positions = this.corrected_positions == this.usedPositions
                ? clone.usedPositions
                : this.corrected_positions.clone();
        return clone;
    }

    /**
     * Delegates to the inherited {@code estimateUnConditional} over all {@code freq} entries.
     * The value {@code ess / freq.length} is passed as the per-entry argument.
     *
     * @param ess the equivalent sample size
     */
    @Override
    public void estimate(double ess) {
        this.estimateUnConditional(0, this.freq.length, ess / (double)this.freq.length, true);
    }

    /**
     * Draws {@code expLambda} from a Dirichlet distribution and sets
     * {@code lambda = log(expLambda)}.
     * <p>
     * The hyperparameter is {@code ess / freq.length}; presumably it is used symmetrically
     * for all components.
     *
     * @param ess the equivalent sample size
     */
    public void draw(double ess) {
        FastDirichletMRGParams alpha = new FastDirichletMRGParams(ess / (double)this.freq.length);
        DirichletMRG.DEFAULT_INSTANCE.generate(this.expLambda, 0, this.expLambda.length, alpha);
        for (int i = 0; i < this.expLambda.length; ++i) {
            this.lambda[i] = Math.log(this.expLambda[i]);
        }
    }

    /**
     * @param index index into the corrected positions
     * @return the sequence position read for the {@code index}-th used position
     */
    public int getCorrectedPosition(int index) {
        return this.corrected_positions[index];
    }

    /**
     * @param index the specific constraint
     * @return {@code exp(lambda[index])}
     */
    public double getExpLambda(int index) {
        return this.expLambda[index];
    }

    /**
     * @param index the specific constraint
     * @return {@code lambda[index]}
     */
    public double getLambda(int index) {
        return this.lambda[index];
    }

    /**
     * Multiplies {@code expLambda[index]} by {@code val} and adds {@code log(val)} to
     * {@code lambda[index]}, keeping both arrays consistent.
     *
     * @param index the specific constraint
     * @param val the factor
     */
    public void multiplyExpLambdaWith(int index, double val) {
        this.expLambda[index] *= val;
        this.lambda[index] += Math.log(val);
    }

    /** Resets the inherited state and sets all lambdas to {@code 0} (expLambdas to {@code 1}). */
    @Override
    public void reset() {
        super.reset();
        Arrays.fill(this.expLambda, 1.0);
        Arrays.fill(this.lambda, 0.0);
    }

    /**
     * Computes the index of the specific constraint that {@code sequence} fulfills.
     * The index is {@code sum_i offset[i] * seq[corrected_positions[i]]}.
     *
     * @param sequence the sequence to inspect
     * @return the index of the fulfilled specific constraint
     */
    public int satisfiesSpecificConstraint(SequenceIterator sequence) {
        int erg = 0;
        for (int i = 0; i < this.corrected_positions.length; ++i) {
            erg += this.offset[i] * sequence.seq[this.corrected_positions[i]];
        }
        return erg;
    }

    /**
     * @param index the specific constraint
     * @return {@code freq[index]}
     */
    @Override
    public double getFreq(int index) {
        return this.freq[index];
    }

    /**
     * Sets {@code expLambda[index] = val} and {@code lambda[index] = log(val)}.
     *
     * @param index the specific constraint
     * @param val the new value of expLambda
     */
    public void setExpLambda(int index, double val) {
        this.expLambda[index] = val;
        this.lambda[index] = Math.log(val);
    }

    /**
     * Sets {@code lambda[index] = val} and {@code expLambda[index] = exp(val)}.
     *
     * @param index the specific constraint
     * @param val the new value of lambda
     */
    public void setLambda(int index, double val) {
        this.expLambda[index] = Math.exp(val);
        this.lambda[index] = val;
    }

    /**
     * @return the used positions separated by {@code ", "}, or an empty string if there
     *         are none
     */
    @Override
    public String toString() {
        StringBuilder erg = new StringBuilder();
        for (int i = 0; i < this.usedPositions.length; ++i) {
            if (i > 0) {
                erg.append(", ");
            }
            erg.append(this.usedPositions[i]);
        }
        return erg.toString();
    }

    /**
     * Appends the inherited information and {@code lambda}.
     * <p>
     * {@code expLambda} is not stored; it is recomputed on extraction. The corrected
     * positions are written (wrapped in a {@code corrected} tag) only if they differ, by
     * identity, from {@code usedPositions}.
     */
    @Override
    protected void appendAdditionalInfo(StringBuffer xml) {
        super.appendAdditionalInfo(xml);
        XMLParser.appendObjectWithTags(xml, this.lambda, "lambda");
        if (this.corrected_positions != this.usedPositions) {
            StringBuffer b = new StringBuffer(500);
            XMLParser.appendObjectWithTags(b, this.corrected_positions, "corrected_positions");
            XMLParser.addTags(b, "corrected");
            xml.append(b);
        }
    }

    @Override
    protected String getXMLTag() {
        return XML_TAG;
    }

    /**
     * Reads the inherited information and {@code lambda}, and recomputes {@code expLambda}.
     * <p>
     * If no {@code corrected} tag is present, the corrected positions alias
     * {@code usedPositions}.
     */
    @Override
    protected void extractAdditionalInfo(StringBuffer xml) throws NonParsableException {
        super.extractAdditionalInfo(xml);
        this.lambda = XMLParser.extractObjectForTags(xml, "lambda", double[].class);
        this.expLambda = new double[this.lambda.length];
        for (int i = 0; i < this.lambda.length; ++i) {
            this.expLambda[i] = Math.exp(this.lambda[i]);
        }
        StringBuffer corrected = XMLParser.extractForTag(xml, "corrected");
        this.corrected_positions = corrected == null ? this.usedPositions : XMLParser.extractObjectForTags(corrected, "corrected_positions", int[].class);
    }

    /**
     * Counts the used positions {@code p} of this constraint for which {@code p + offset} is
     * a used position of {@code constr}.
     * <p>
     * This is a single merge-style pass and is only correct if both position arrays are
     * ascending. The int[] constructors enforce this; for XML-restored instances this is
     * assumed.
     *
     * @param offset the shift applied to the positions of {@code constr}
     * @param constr the other constraint
     * @return the number of matching positions
     */
    public int comparePosition(int offset, MEMConstraint constr) {
        int matching = 0;
        int u2 = 0;
        for (int u1 = 0; u1 < this.usedPositions.length; ++u1) {
            while (u2 < constr.usedPositions.length && constr.usedPositions[u2] - offset < this.usedPositions[u1]) {
                ++u2;
            }
            if (u2 < constr.usedPositions.length && constr.usedPositions[u2] - offset == this.usedPositions[u1]) {
                ++matching;
            }
        }
        return matching;
    }

    /**
     * Adds, for every specific constraint of this instance, the log-sum of the listed
     * constraints' parameters to {@code lambda}, then recomputes {@code expLambda}.
     * <p>
     * The method enumerates all symbol assignments to the union of this constraint's
     * positions and the listed constraints' positions (shifted by {@code -offset}). For each
     * assignment it sums {@code params[start[idx] + index_idx]} over the listed constraints,
     * and log-adds that sum into the bucket of this constraint's matching specific constraint.
     *
     * @param offset shift subtracted from the positions of the listed constraints
     * @param list indices into {@code constraint} of the constraints to fold in
     * @param constraint all constraints
     * @param params parameter vector addressed via {@code start}
     * @param start start index in {@code params} for each constraint
     */
    public void addParameters(int offset, IntList list, MEMConstraint[] constraint, double[] params, int[] start) {
        // Union of own positions and the shifted positions of all listed constraints.
        HashSet<Integer> hash = new HashSet<Integer>();
        for (int u = 0; u < this.usedPositions.length; ++u) {
            hash.add(this.usedPositions[u]);
        }
        for (int n = 0; n < list.length(); ++n) {
            int idx = list.get(n);
            for (int u = 0; u < constraint[idx].usedPositions.length; ++u) {
                hash.add(constraint[idx].usedPositions[u] - offset);
            }
        }
        int[] pos = new int[hash.size()];
        int k = 0;
        for (Integer position : hash) {
            pos[k++] = position;
        }
        Arrays.sort(pos);
        // expLambda serves as scratch space for the log-sums (log(0) = -inf start value).
        Arrays.fill(this.expLambda, Double.NEGATIVE_INFINITY);
        // Presumably the largest symbol index. NOTE(review): this assumes the same alphabet
        // size for every position in pos - TODO confirm.
        int alphLength = this.lambda.length / this.offset[0] - 1;
        // Odometer over all assignments; the extra last digit signals overflow (= done).
        int[] assignment = new int[pos.length + 1];
        while (assignment[pos.length] == 0) {
            double p = 0.0;
            for (int n = 0; n < list.length(); ++n) {
                int idx = list.get(n);
                p += params[start[idx] + indexForAssignment(constraint[idx], pos, assignment, offset)];
            }
            int index = indexForAssignment(this, pos, assignment, 0);
            this.expLambda[index] = Normalisation.getLogSum(this.expLambda[index], p);
            // Advance the odometer.
            int u = 0;
            while (u < pos.length && assignment[u] == alphLength) {
                assignment[u] = 0;
                ++u;
            }
            ++assignment[u];
        }
        for (int u = 0; u < this.lambda.length; ++u) {
            this.lambda[u] += this.expLambda[u];
            this.expLambda[u] = Math.exp(this.lambda[u]);
        }
    }

    /**
     * Computes the specific-constraint index of {@code c} under {@code assignment}.
     * <p>
     * {@code assignment[u]} is the symbol at position {@code pos[u]}. A position {@code pos[u]}
     * belongs to {@code c} if it equals the next unmatched {@code c.usedPositions[i] - shift}.
     */
    private static int indexForAssignment(MEMConstraint c, int[] pos, int[] assignment, int shift) {
        int index = 0;
        int i = 0;
        for (int u = 0; u < pos.length && i < c.usedPositions.length; ++u) {
            if (pos[u] == c.usedPositions[i] - shift) {
                index += c.offset[i] * assignment[u];
                ++i;
            }
        }
        return index;
    }
}

