/*
 * Decompiled with CFR 0.152.
 */
package projects.tals;

import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.DataSet;
import de.jstacs.data.DiscreteSequenceEnumerator;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.data.alphabets.Alphabet;
import de.jstacs.data.alphabets.DiscreteAlphabet;
import de.jstacs.data.alphabets.DoubleSymbolException;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.data.sequences.annotation.ReferenceSequenceAnnotation;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.differentiable.AbstractDifferentiableStatisticalModel;
import de.jstacs.sequenceScores.statisticalModels.differentiable.homogeneous.HomogeneousMMDiffSM;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import java.text.NumberFormat;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Length-1 differentiable model that scores a single position of a sequence with one of several
 * order-0 {@link HomogeneousMMDiffSM} components.
 *
 * <p>The component used for position {@code start} is chosen by the symbol at position
 * {@code start - 1} of the sequence stored in the {@code "reference"}
 * {@link ReferenceSequenceAnnotation} of the scored sequence (the RVD sequence). That symbol is
 * mapped to a component index via {@link #getIndex(Sequence, int)}.
 *
 * <p>The global parameter vector is the concatenation of the parameter vectors of all components
 * in index order. Every method that translates between global and component parameter indices
 * uses this same layout.
 */
public class TALgetterRVDDependentComponent
extends AbstractDifferentiableStatisticalModel {
    /** Alphabet of the RVD (reference) sequence whose symbols select the component. */
    protected AlphabetContainer alphabetsRVD;
    /** Equivalent sample size of the prior. */
    protected double ess;
    /**
     * The {@code length} constructor argument. It scales the prior hyper-parameters and is passed
     * as the sequence length to components created in {@link #addAndSet(String[], double[][])}.
     */
    protected int priorLength;
    /** One order-0 homogeneous Markov model per (mapped) RVD symbol. */
    protected HomogeneousMMDiffSM[] hmm_c;
    // Deliberately no initializer. An explicit "= false" would execute after super(xml) has
    // already restored this flag in fromXML and would silently reset it. The JVM default (false)
    // is exactly what the other constructor needs.
    private boolean isInitialized;

    /**
     * Creates a new component with informative Dirichlet-style hyper-parameters.
     *
     * @param alphabets    alphabet of the scored sequence
     * @param alphabetsRVD alphabet of the RVD (reference) sequence
     * @param length       prior length; multiplies the per-symbol hyper-parameters
     * @param ess          equivalent sample size
     * @param priorImp     importance weight per RVD symbol. The weights are normalised by the sum
     *                     over all entries of this array.
     * @param priorPrefs   {@code priorPrefs[c][s]}: prior preference of RVD {@code c} for symbol
     *                     {@code s} of {@code alphabets}
     * @throws Exception if either alphabet does not have a positive length at position 0, or if a
     *                   component cannot be created
     */
    public TALgetterRVDDependentComponent(AlphabetContainer alphabets, AlphabetContainer alphabetsRVD, int length, double ess, double[] priorImp, double[][] priorPrefs) throws Exception {
        super(alphabets, 1);
        if (!(alphabets.getAlphabetLengthAt(0) > 0.0 && alphabetsRVD.getAlphabetLengthAt(0) > 0.0)) {
            throw new Exception("Alphabet ist nicht OK!");
        }
        this.alphabetsRVD = alphabetsRVD;
        this.ess = ess;
        this.priorLength = length;

        // NOTE(review): a zero sum of priorImp yields NaN hyper-parameters; callers must avoid this.
        double norm = 0.0;
        for (double imp : priorImp) {
            norm += imp;
        }

        this.hmm_c = new HomogeneousMMDiffSM[this.getNumberOfSymbols(alphabetsRVD)];
        double[][][] hypi = new double[this.hmm_c.length][1][(int) alphabets.getAlphabetLengthAt(0)];
        double[] priorImSum = new double[this.hmm_c.length];
        // Accumulate hyper-parameters per mapped component; several RVDs may share one component.
        for (int c = 0; c < alphabetsRVD.getAlphabetLengthAt(0); c++) {
            int component = this.getMappedIndex(alphabetsRVD, c);
            double[] counts = hypi[component][0];
            for (int s = 0; s < counts.length; s++) {
                counts[s] += priorPrefs[c][s] * ess * priorImp[c] / norm * length;
            }
            priorImSum[component] += priorImp[c];
        }
        for (int c = 0; c < this.hmm_c.length; c++) {
            this.hmm_c[c] = new HomogeneousMMDiffSM(alphabets, 0, ess * priorImSum[c] / norm, hypi[c], true, true, 1);
        }
    }

    /**
     * Restores a component from its XML representation (see {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public TALgetterRVDDependentComponent(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /** Returns a copy that holds deep clones of all components. */
    @Override
    public TALgetterRVDDependentComponent clone() throws CloneNotSupportedException {
        TALgetterRVDDependentComponent clone = (TALgetterRVDDependentComponent) super.clone();
        clone.hmm_c = (HomogeneousMMDiffSM[]) ArrayHandler.clone((Cloneable[]) this.hmm_c);
        return clone;
    }

    /**
     * Returns the number of components to create for the given RVD alphabet. Subclasses that merge
     * RVDs override this together with {@link #getMappedIndex(AlphabetContainer, int)}.
     */
    protected int getNumberOfSymbols(AlphabetContainer con) {
        return (int) con.getAlphabetLengthAt(0);
    }

    /** Maps an RVD symbol code to a component index; the identity by default. */
    protected int getMappedIndex(AlphabetContainer con, int original) {
        return original;
    }

    /** Returns the component index for the symbol at position {@code pos} of {@code seq}. */
    protected int getIndex(Sequence seq, int pos) {
        return this.getMappedIndex(seq.getAlphabetContainer(), seq.discreteVal(pos));
    }

    /** Returns the reference (RVD) sequence attached to {@code seq} under type {@code "reference"}. */
    private static Sequence getReferenceSequence(Sequence seq) {
        ReferenceSequenceAnnotation annotation = (ReferenceSequenceAnnotation) seq.getSequenceAnnotationByType("reference", 0);
        return annotation.getReferenceSequence();
    }

    /**
     * Returns the index of the first global parameter belonging to component {@code component},
     * i.e. the total parameter count of components {@code 0 .. component-1}.
     */
    private int getParameterOffset(int component) {
        int offset = 0;
        for (int c = 0; c < component; c++) {
            offset += this.hmm_c[c].getNumberOfParameters();
        }
        return offset;
    }

    /**
     * Returns {@code true} if {@code xml} contains a start tag named {@code tag}. The tag may carry
     * attributes or be self-closing.
     */
    private static boolean containsTag(StringBuffer xml, String tag) {
        String open = "<" + tag;
        for (int pos = xml.indexOf(open); pos >= 0; pos = xml.indexOf(open, pos + 1)) {
            int after = pos + open.length();
            if (after < xml.length()) {
                char next = xml.charAt(after);
                if (next == '>' || next == '/' || Character.isWhitespace(next)) {
                    return true;
                }
            }
        }
        return false;
    }

    /** Adds the prior gradients of all components, each at its own parameter offset. */
    @Override
    public void addGradientOfLogPriorTerm(double[] grad, int start) throws Exception {
        int offset = start;
        for (HomogeneousMMDiffSM hmm : this.hmm_c) {
            hmm.addGradientOfLogPriorTerm(grad, offset);
            offset += hmm.getNumberOfParameters();
        }
    }

    @Override
    public double getESS() {
        return this.ess;
    }

    @Override
    public double getInitialClassParam(double classProb) {
        return 0.0;
    }

    /** The model is normalised, so the log normalisation constant is 0. */
    @Override
    public double getLogNormalizationConstant() {
        return 0.0;
    }

    @Override
    public double getLogPartialNormalizationConstant(int parameterIndex) throws Exception {
        return Double.NEGATIVE_INFINITY;
    }

    /** Returns the sum of the log prior terms of all components. */
    @Override
    public double getLogPriorTerm() {
        double logPrior = 0.0;
        for (HomogeneousMMDiffSM hmm : this.hmm_c) {
            logPrior += hmm.getLogPriorTerm();
        }
        return logPrior;
    }

    @Override
    public int getSizeOfEventSpaceForRandomVariablesOfParameter(int index) {
        return 0;
    }

    @Override
    public boolean isNormalized() {
        return true;
    }

    /** Returns the concatenated parameter vectors of all components. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] params = new double[this.getNumberOfParameters()];
        int offset = 0;
        for (HomogeneousMMDiffSM hmm : this.hmm_c) {
            int count = hmm.getNumberOfParameters();
            System.arraycopy(hmm.getCurrentParameterValues(), 0, params, offset, count);
            offset += count;
        }
        return params;
    }

    @Override
    public String getInstanceName() {
        return "TALgetterRVDDependentComponent";
    }

    @Override
    public int getLength() {
        return 1;
    }

    /**
     * Scores position {@code start} of {@code seq} with the component selected by the RVD at
     * position {@code start - 1} of the reference sequence.
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        Sequence rvdSeq = getReferenceSequence(seq);
        return this.hmm_c[this.getIndex(rvdSeq, start - 1)].getLogScoreFor(seq, start, start);
    }

    /**
     * Returns the normalised probabilities of all symbols of the scored alphabet under the
     * component selected by position {@code pos} of {@code rvds}.
     */
    public double[] getSpecificities(Sequence rvds, int pos) {
        HomogeneousMMDiffSM hmm = this.hmm_c[this.getIndex(rvds, pos)];
        double[] spec = new double[(int) hmm.getAlphabetContainer().getAlphabetLengthAt(0)];
        DiscreteSequenceEnumerator dse = new DiscreteSequenceEnumerator(hmm.getAlphabetContainer(), 1, false);
        for (int i = 0; dse.hasMoreElements(); i++) {
            spec[i] = hmm.getLogScoreFor((Sequence) dse.nextElement());
        }
        Normalisation.logSumNormalisation(spec);
        return spec;
    }

    /**
     * Sets the specificities of the given RVDs.
     *
     * <p>RVDs not yet in the RVD alphabet are appended to it, each only once even if it occurs
     * several times in {@code rvds}. A new, randomly initialised component is created for each of
     * them. The component of each RVD then gets the element-wise logarithm of the matching row of
     * {@code specs} as parameters. If an RVD occurs several times, its last row wins.
     *
     * @param rvds  RVD symbols
     * @param specs {@code specs[i]}: values for {@code rvds[i]}; logged before being set
     * @return the (possibly extended) RVD alphabet container
     */
    public AlphabetContainer addAndSet(String[] rvds, double[][] specs) throws WrongAlphabetException, IllegalArgumentException, DoubleSymbolException {
        DiscreteAlphabet alph = (DiscreteAlphabet) this.alphabetsRVD.getAlphabetAt(0);

        // Unknown RVDs in first-seen order, without duplicates. Duplicates would create duplicate
        // alphabet symbols and surplus components.
        Set<String> unknown = new LinkedHashSet<String>();
        for (String rvd : rvds) {
            if (!alph.isSymbol(rvd)) {
                unknown.add(rvd);
            }
        }

        if (!unknown.isEmpty()) {
            int oldSize = (int) alph.length();
            String[] newAlph = new String[oldSize + unknown.size()];
            for (int i = 0; i < oldSize; i++) {
                newAlph[i] = alph.getSymbolAt(i);
            }
            HomogeneousMMDiffSM[] extended = new HomogeneousMMDiffSM[newAlph.length];
            System.arraycopy(this.hmm_c, 0, extended, 0, this.hmm_c.length);
            int next = oldSize;
            for (String rvd : unknown) {
                newAlph[next] = rvd;
                extended[next] = new HomogeneousMMDiffSM(this.alphabets, 0, this.ess / (double) extended.length, this.priorLength);
                extended[next].initializeFunctionRandomly(false);
                next++;
            }
            this.hmm_c = extended;
            alph = new DiscreteAlphabet(true, newAlph);
            this.alphabetsRVD = new AlphabetContainer((Alphabet) alph);
        }

        for (int i = 0; i < rvds.length; i++) {
            int idx = alph.getCode(rvds[i]);
            double[] logSpec = new double[specs[i].length];
            for (int s = 0; s < logSpec.length; s++) {
                logSpec[s] = Math.log(specs[i][s]);
            }
            this.hmm_c[idx].setParameters(logSpec, 0);
        }
        return this.alphabetsRVD;
    }

    /**
     * Computes the log score of position {@code start} together with its partial derivatives.
     * Component-local parameter indices are shifted by the component's offset in the global
     * parameter vector. The shift uses the same layout as
     * {@link #setParameters(double[], int)}.
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        Sequence rvdSeq = getReferenceSequence(seq);
        int component = this.getIndex(rvdSeq, start - 1);
        IntList local = new IntList();
        double logScore = this.hmm_c[component].getLogScoreAndPartialDerivation(seq, start, start, local, partialDer);
        int offset = this.getParameterOffset(component);
        for (int i = 0; i < local.length(); i++) {
            indices.add(local.get(i) + offset);
        }
        return logScore;
    }

    /** Returns the total number of parameters over all components. */
    @Override
    public int getNumberOfParameters() {
        return this.getParameterOffset(this.hmm_c.length);
    }

    @Override
    public int getNumberOfRecommendedStarts() {
        return 1;
    }

    /** Ignores {@code index}, {@code data} and {@code weights} and initialises randomly. */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.initializeFunctionRandomly(freeParams);
    }

    /** Randomly initialises every component and marks this model as initialised. */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        for (HomogeneousMMDiffSM hmm : this.hmm_c) {
            hmm.initializeFunctionRandomly(freeParams);
        }
        this.isInitialized = true;
    }

    @Override
    public boolean isInitialized() {
        return this.isInitialized;
    }

    /** Distributes consecutive slices of {@code params}, starting at {@code start}, to the components. */
    @Override
    public void setParameters(double[] params, int start) {
        int offset = start;
        for (HomogeneousMMDiffSM hmm : this.hmm_c) {
            hmm.setParameters(params, offset);
            offset += hmm.getNumberOfParameters();
        }
    }

    /** Serialises this component, including {@code priorLength}, inside a {@code TALANSF} tag. */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.alphabetsRVD, "alphabetsRVD");
        XMLParser.appendObjectWithTags(xml, this.alphabets, "alphabets");
        XMLParser.appendObjectWithTags(xml, this.length, "length");
        XMLParser.appendObjectWithTags(xml, this.priorLength, "priorLength");
        XMLParser.appendObjectWithTags(xml, this.ess, "ess");
        XMLParser.appendObjectWithTags(xml, this.hmm_c, "hmmc");
        XMLParser.appendObjectWithTags(xml, this.isInitialized, "isInitialized");
        XMLParser.addTags(xml, "TALANSF");
        return xml;
    }

    /** Returns one line per component: RVD symbol, tab, component description. */
    @Override
    public String toString(NumberFormat nf) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < this.hmm_c.length; i++) {
            sb.append(this.alphabetsRVD.getSymbol(0, i)).append('\t');
            sb.append(this.hmm_c[i].toString(nf)).append('\n');
        }
        return sb.toString();
    }

    /**
     * Restores the state written by {@link #toXML()}. XML from older versions lacks
     * {@code priorLength}; it then falls back to 0, the value such objects had after loading.
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, "TALANSF");
        this.alphabetsRVD = (AlphabetContainer) XMLParser.extractObjectForTags(xml, "alphabetsRVD");
        this.alphabets = (AlphabetContainer) XMLParser.extractObjectForTags(xml, "alphabets");
        this.length = XMLParser.extractObjectForTags(xml, "length", Integer.TYPE);
        this.priorLength = containsTag(xml, "priorLength") ? XMLParser.extractObjectForTags(xml, "priorLength", Integer.TYPE) : 0;
        this.ess = XMLParser.extractObjectForTags(xml, "ess", Double.TYPE);
        this.hmm_c = (HomogeneousMMDiffSM[]) XMLParser.extractObjectForTags(xml, "hmmc");
        this.isInitialized = XMLParser.extractObjectForTags(xml, "isInitialized", Boolean.TYPE);
    }
}

