/*
 * Decompiled with CFR 0.152.
 */
package projects.tals;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;
import java.text.NumberFormat;
import java.util.Arrays;

public class TALgetterMixture
extends AbstractDifferentiableStatisticalModel {
    /** Equivalent sample size of the prior. */
    private double ess;
    /** Whether the parameters have been initialized (randomly or restored from XML). */
    private boolean isInitialized = false;
    /**
     * Logit-scale parameters: one per RVD symbol plus one trailing, position-dependent
     * parameter. {@code probs[i]} is the logistic transform of {@code params[i]}.
     */
    private double[] params;
    /** Probabilities derived from {@link #params} via the logistic function. */
    private double[] probs;
    /** First Beta-prior hyperparameter per parameter. */
    private double[] HyperParams;
    /** Sum of both Beta-prior hyperparameters per parameter. */
    private double[] HyperSum;
    /**
     * Number of trailing positions that receive the position-dependent term.
     * NOTE(review): only assigned in {@link #fromXML(StringBuffer)} within this class, so it is
     * 0 for freshly constructed instances — confirm this is intended.
     */
    private int p_anz;
    /**
     * If {@code true}, the position-dependent term applies to every position.
     * NOTE(review): only assigned in {@link #fromXML(StringBuffer)} within this class.
     */
    private boolean p_gesamte_seq;

    /**
     * Creates a new, uninitialized model over the RVD alphabet in {@code alphabetsRVD}.
     *
     * @param alphabetsRVD alphabet container whose first alphabet holds the RVD symbols
     * @param length       factor by which {@code ess} is scaled when computing the prior
     *                     hyperparameter sums (the model itself is created with length 1)
     * @param ess          equivalent sample size of the prior
     * @param priorImp     prior importance per RVD symbol; entry {@code i} scales the first
     *                     hyperparameter of parameter {@code i}. Parameters without an entry keep
     *                     a first hyperparameter of 0.
     * @throws Exception                if the first alphabet does not have a positive length
     * @throws IllegalArgumentException if {@code priorImp} has more entries than the model has
     *                                  parameters
     */
    public TALgetterMixture(AlphabetContainer alphabetsRVD, int length, double ess, double[] priorImp) throws Exception {
        super(alphabetsRVD, 1);
        double alphabetLength = alphabetsRVD.getAlphabetLengthAt(0);
        // written as !(x > 0) so that a NaN length is rejected as well
        if (!(alphabetLength > 0.0)) {
            throw new Exception("Alphabet wrong");
        }
        int numParams = (int) alphabetLength + 1;
        if (priorImp.length > numParams) {
            throw new IllegalArgumentException("priorImp has " + priorImp.length
                    + " entries, but the model has only " + numParams + " parameters");
        }
        this.ess = ess;
        this.HyperParams = new double[numParams];
        this.HyperSum = new double[numParams];
        Arrays.fill(this.HyperSum, ess * (double) length / alphabetLength);
        for (int i = 0; i < priorImp.length; i++) {
            this.HyperParams[i] = this.HyperSum[i] * priorImp[i];
        }
        // The trailing positional parameter always gets a Beta prior with mean 0.5,
        // regardless of any priorImp entry for it.
        this.HyperParams[numParams - 1] = 0.5 * this.HyperSum[numParams - 1];
        this.params = new double[numParams];
        this.probs = new double[numParams];
    }

    /**
     * Restores a model from its XML representation (see {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public TALgetterMixture(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a deep copy of this model; all parameter and hyperparameter arrays are copied.
     */
    @Override
    public TALgetterMixture clone() throws CloneNotSupportedException {
        TALgetterMixture clone = (TALgetterMixture) super.clone();
        clone.params = this.params.clone();
        clone.probs = this.probs.clone();
        clone.HyperSum = this.HyperSum.clone();
        clone.HyperParams = this.HyperParams.clone();
        return clone;
    }

    /**
     * Adds the gradient of {@link #getLogPriorTerm()} with respect to the logit parameters to
     * {@code grad}, starting at {@code start}. As in {@link #getLogPriorTerm()}, the trailing
     * positional parameter is excluded.
     */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        for (int j = 0; j < this.getNumberOfParameters() - 1; j++, start++) {
            // d/d(theta) [a*log(p) + (b-a)*log(1-p)] with p = sigmoid(theta) equals a - b*p
            grad[start] += this.HyperParams[j] - this.HyperSum[j] * this.probs[j];
        }
    }

    /** Returns the equivalent sample size passed at construction. */
    @Override
    public double getESS() {
        return this.ess;
    }

    /** Returns 0, i.e. the model is treated as normalized. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    /** Always returns negative infinity. */
    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /**
     * Returns the log density of independent Beta priors on all probabilities except the
     * trailing positional one. Parameter {@code c} uses shape parameters
     * {@code HyperParams[c]} and {@code HyperSum[c] - HyperParams[c]}.
     */
    @Override
    public double getLogPriorTerm() {
        double logPrior = 0.0;
        for (int c = 0; c < this.params.length - 1; c++) {
            double alpha = this.HyperParams[c];
            double beta = this.HyperSum[c] - alpha;
            double logBeta = Gamma.logOfGamma(alpha) + Gamma.logOfGamma(beta) - Gamma.logOfGamma(this.HyperSum[c]);
            logPrior += alpha * Math.log(this.probs[c]) + beta * Math.log1p(-this.probs[c]);
            logPrior -= logBeta;
        }
        return logPrior;
    }

    /** Always returns 0. */
    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 0;
    }

    /** Returns a copy of the current logit-scale parameters. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.params.clone();
    }

    @Override
    public String getInstanceName() {
        return "TALgetterMixture";
    }

    /**
     * Returns the current probability of the RVD symbol at position {@code pos} of {@code rvds}.
     *
     * @param rvds a sequence over the RVD alphabet
     * @param pos  the position in {@code rvds}
     * @return {@code probs[rvds.discreteVal(pos)]}
     */
    public double getImportance(Sequence rvds, int pos) {
        return this.probs[rvds.discreteVal(pos)];
    }

    /**
     * Returns the log score for position {@code start}.
     *
     * The score is the log-probability of the RVD at {@code start - 1} of the "reference"
     * annotation's sequence. When the positional term applies, it adds the log-probability of
     * the trailing parameter multiplied by a position-dependent exponent.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        Sequence rvdSeq = getReferenceRvdSequence(seq);
        double logScore = Math.log(this.probs[rvdSeq.discreteVal(start - 1)]);
        if (this.hasPositionalTerm(seq, start)) {
            logScore += this.getPositionalExponent(seq, start) * Math.log(this.probs[this.probs.length - 1]);
        }
        return logScore;
    }

    /**
     * Same as {@link #getLogScoreFor(Sequence, int)}, but also appends the partial derivatives
     * with respect to the logit parameters. The positional parameter, if used, is appended
     * first, then the RVD parameter.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        Sequence rvdSeq = getReferenceRvdSequence(seq);
        int indexRvd = rvdSeq.discreteVal(start - 1);
        double logScore = Math.log(this.probs[indexRvd]);
        if (this.hasPositionalTerm(seq, start)) {
            int last = this.probs.length - 1;
            double exponent = this.getPositionalExponent(seq, start);
            logScore += exponent * Math.log(this.probs[last]);
            indices.add(last);
            // d/d(theta) of k*log(sigmoid(theta)) is k*(1 - sigmoid(theta))
            partialDer.add(exponent * (1.0 - this.probs[last]));
        }
        indices.add(indexRvd);
        partialDer.add(1.0 - this.probs[indexRvd]);
        return logScore;
    }

    /** Extracts the RVD sequence stored in the "reference" annotation of {@code seq}. */
    private static Sequence getReferenceRvdSequence(Sequence seq) {
        ReferenceSequenceAnnotation anno = (ReferenceSequenceAnnotation) seq.getSequenceAnnotationByType("reference", 0);
        return anno.getReferenceSequence();
    }

    /** Tells whether the trailing positional parameter contributes at position {@code start}. */
    private boolean hasPositionalTerm(Sequence seq, int start) {
        return this.p_gesamte_seq || start >= seq.getLength() - this.p_anz;
    }

    /**
     * Returns the exponent of the trailing positional probability at position {@code start}.
     * Only meaningful when {@link #hasPositionalTerm(Sequence, int)} is {@code true}.
     */
    private int getPositionalExponent(Sequence seq, int start) {
        return this.p_gesamte_seq ? start - 1 : start - (seq.getLength() - 1 - this.p_anz);
    }

    /** Returns the number of parameters, i.e. the number of RVD symbols plus one. */
    @Override
    public int getNumberOfParameters() {
        return this.params.length;
    }

    /** Ignores the data and weights and delegates to {@link #initializeFunctionRandomly(boolean)}. */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(freeParams);
    }

    /**
     * Draws each parameter as {@code log(u)} with {@code u} uniform in [0, 1), so every
     * probability {@code u / (u + 1)} lies in [0, 0.5).
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        for (int c = 0; c < this.params.length; c++) {
            double temp = Math.random();
            this.params[c] = Math.log(temp);
            this.probs[c] = temp / (temp + 1.0);
        }
        this.isInitialized = true;
    }

    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /**
     * Sets the logit parameters from {@code params}, starting at {@code start}, and updates the
     * derived probabilities.
     */
    @Override
    public void setParameters(double[] params, int start) {
        for (int j = 0; j < this.getNumberOfParameters(); j++, start++) {
            this.params[j] = params[start];
            this.probs[j] = logistic(this.params[j]);
        }
    }

    /**
     * Numerically stable logistic function {@code 1 / (1 + exp(-x))}.
     *
     * The naive form {@code exp(x) / (1 + exp(x))} overflows to {@code Inf/Inf = NaN} for
     * large {@code x}, including the {@code +Infinity} parameters set by
     * {@link #addAndSet(AlphabetContainer, String[])}. This form maps {@code +Infinity} to 1
     * and {@code -Infinity} to 0.
     */
    private static double logistic(double x) {
        if (x >= 0.0) {
            return 1.0 / (1.0 + Math.exp(-x));
        }
        double e = Math.exp(x);
        return e / (1.0 + e);
    }

    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(xml, this.length, "length");
        XMLParser.appendObjectWithTags(xml, this.ess, "ess");
        XMLParser.appendObjectWithTags(xml, this.HyperParams, "hyperParams");
        XMLParser.appendObjectWithTags(xml, this.HyperSum, "hyperSum");
        XMLParser.appendObjectWithTags(xml, this.isInitialized, "isInitialized");
        XMLParser.appendObjectWithTags(xml, this.p_anz, "panz");
        XMLParser.appendObjectWithTags(xml, this.p_gesamte_seq, "pgesseq");
        XMLParser.appendObjectWithTags(xml, this.params, "params");
        XMLParser.appendObjectWithTags(xml, this.probs, "probs");
        XMLParser.addTags(xml, "TALMSF");
        return xml;
    }

    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, "TALMSF");
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(xml, "alphabets");
        this.length = XMLParser.extractObjectForTags(xml, "length", Integer.TYPE);
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.HyperParams = (double[]) XMLParser.extractObjectForTags(xml, "hyperParams");
        this.HyperSum = (double[]) XMLParser.extractObjectForTags(xml, "hyperSum");
        this.isInitialized = XMLParser.extractObjectForTags(xml, "isInitialized", Boolean.TYPE);
        this.p_anz = XMLParser.extractObjectForTags(xml, "panz", Integer.TYPE);
        this.p_gesamte_seq = XMLParser.extractObjectForTags(xml, "pgesseq", Boolean.TYPE);
        this.params = (double[]) XMLParser.extractObjectForTags(xml, "params");
        this.probs = (double[]) XMLParser.extractObjectForTags(xml, "probs");
    }

    /**
     * Lists each RVD symbol with its formatted probability, one per line and tab-separated.
     * The trailing positional parameter is omitted.
     */
    @Override
    public String toString(NumberFormat nf) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.probs.length - 1; i++) {
            sb.append(String.valueOf(this.alphabets.getSymbol(0, i)) + "\t" + nf.format(this.probs[i]) + "\n");
        }
        return sb.toString();
    }

    /**
     * Switches the model to the (larger) alphabet {@code con}.
     *
     * Existing per-symbol values are kept at their old codes, and the trailing positional
     * parameter is carried over. Every newly created symbol slot, and every symbol listed in
     * {@code rvds}, gets probability 1.0 (logit {@code +Infinity}). New hyperparameters are set
     * to 1.0.
     *
     * NOTE(review): with {@code HyperParams == HyperSum == 1} and probability 1, the Beta prior
     * term in {@link #getLogPriorTerm()} is degenerate for the new symbols (second shape
     * parameter 0) and likely non-finite. Also, the old codes are assumed to map to the same
     * symbols in {@code con}. Confirm both assumptions.
     *
     * @param con  the new alphabet container; its first alphabet must be a {@link DiscreteAlphabet}
     * @param rvds symbols to (re)set to probability 1.0
     * @throws WrongAlphabetException   if a symbol in {@code rvds} is not in the new alphabet
     * @throws IllegalArgumentException if the new alphabet has fewer symbols than the current one
     */
    public void addAndSet(AlphabetContainer con, String[] rvds) throws WrongAlphabetException {
        DiscreteAlphabet alph = (DiscreteAlphabet) con.getAlphabetAt(0);
        int newSize = (int) alph.length() + 1;
        if (newSize < this.probs.length) {
            // Shrinking would otherwise silently overwrite the trailing positional entry
            // (or fail inside System.arraycopy).
            throw new IllegalArgumentException("new alphabet has " + (newSize - 1)
                    + " symbols, but the model already has " + (this.probs.length - 1));
        }

        double[] nProbs = new double[newSize];
        double[] nParams = new double[newSize];
        Arrays.fill(nProbs, 1.0);
        Arrays.fill(nParams, Double.POSITIVE_INFINITY);
        System.arraycopy(this.probs, 0, nProbs, 0, this.probs.length - 1);
        nProbs[newSize - 1] = this.probs[this.probs.length - 1];
        System.arraycopy(this.params, 0, nParams, 0, this.params.length - 1);
        nParams[newSize - 1] = this.params[this.params.length - 1];

        double[] nHyperParams = new double[newSize];
        double[] nHyperSum = new double[newSize];
        Arrays.fill(nHyperParams, 1.0);
        Arrays.fill(nHyperSum, 1.0);
        System.arraycopy(this.HyperParams, 0, nHyperParams, 0, this.HyperParams.length - 1);
        System.arraycopy(this.HyperSum, 0, nHyperSum, 0, this.HyperParams.length - 1);
        nHyperParams[newSize - 1] = this.HyperParams[this.HyperParams.length - 1];
        nHyperSum[newSize - 1] = this.HyperSum[this.HyperSum.length - 1];

        for (String rvd : rvds) {
            int idx = alph.getCode(rvd);
            nProbs[idx] = 1.0;
            nParams[idx] = Double.POSITIVE_INFINITY;
        }

        // Commit only after all lookups succeeded, so a WrongAlphabetException leaves the
        // model unchanged.
        this.params = nParams;
        this.probs = nProbs;
        this.HyperParams = nHyperParams;
        this.HyperSum = nHyperSum;
        this.alphabets = con;
    }
}

