/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.differentiable;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.MultiDimensionalSequence;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.text.NumberFormat;
import java.util.Arrays;

/**
 * Wraps a {@link DifferentiableSequenceScore} so that it can be evaluated on
 * {@link MultiDimensionalSequence}s.
 * <p>
 * For a {@link MultiDimensionalSequence} the log score is the arithmetic mean of the wrapped
 * function's log scores over all contained sequences (all evaluated at the same {@code start}),
 * and each reported partial derivative is the mean of the corresponding partial derivatives.
 * Any other {@link Sequence} is passed straight through to the wrapped function. Parameters,
 * parameter count and initialization are delegated to the wrapped function.
 */
public class MultiDimensionalSequenceWrapperDiffSS
extends AbstractDifferentiableSequenceScore {
    // NOTE: none of the instance fields below may get an initializer. The XML constructor only
    // calls super(xml), so fromXML presumably runs inside the superclass constructor; a field
    // initializer would execute afterwards and overwrite what fromXML assigned.

    /** The wrapped score; it owns all parameters of this instance. */
    private DifferentiableSequenceScore function;
    /** Scratch buffer for parameter indices of a single sub-sequence; allocated lazily. */
    private IntList iList;
    /** Scratch buffer for partial derivatives of a single sub-sequence; allocated lazily. */
    private DoubleList dList;
    /** Scratch buffer accumulating partial derivatives over all sub-sequences; allocated lazily. */
    private double[] gradient;
    private static final String XML_TAG = MultiDimensionalSequenceWrapperDiffSS.class.getSimpleName();

    /**
     * Creates a wrapper around a clone of {@code function}.
     *
     * @param function the score to wrap; it is cloned, so later changes to the argument do not
     *                 affect this wrapper
     * @throws IllegalArgumentException propagated from the superclass constructor
     * @throws CloneNotSupportedException if {@code function} cannot be cloned
     */
    public MultiDimensionalSequenceWrapperDiffSS(DifferentiableSequenceScore function) throws IllegalArgumentException, CloneNotSupportedException {
        super(function.getAlphabetContainer(), function.getLength());
        this.function = function.clone();
    }

    /**
     * Restores a wrapper from its XML representation as produced by {@link #toXML()}.
     *
     * @param xml the XML representation
     * @throws NonParsableException if {@code xml} cannot be parsed
     */
    public MultiDimensionalSequenceWrapperDiffSS(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Parses the wrapped function from {@code xml} and takes the alphabet container and length
     * from it.
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        xml = XMLParser.extractForTag(xml, XML_TAG);
        this.function = XMLParser.extractObjectForTags(xml, "function", DifferentiableSequenceScore.class);
        this.alphabets = this.function.getAlphabetContainer();
        this.length = this.function.getLength();
    }

    /** Serializes the wrapped function, enclosed in this class' XML tag. */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.function, "function");
        XMLParser.addTags(xml, XML_TAG);
        return xml;
    }

    /**
     * Returns a deep copy: the wrapped function is cloned, and every allocated scratch buffer
     * is copied independently so the clone never shares mutable state with this instance.
     */
    @Override
    public MultiDimensionalSequenceWrapperDiffSS clone() throws CloneNotSupportedException {
        MultiDimensionalSequenceWrapperDiffSS clone = (MultiDimensionalSequenceWrapperDiffSS) super.clone();
        clone.function = this.function.clone();
        clone.gradient = this.gradient == null ? null : this.gradient.clone();
        clone.iList = this.iList == null ? null : this.iList.clone();
        clone.dList = this.dList == null ? null : this.dList.clone();
        return clone;
    }

    /** Delegates to the wrapped function. */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        return this.function.getCurrentParameterValues();
    }

    @Override
    public String getInstanceName() {
        return "multidimensional wrapper of " + this.function.getInstanceName();
    }

    /**
     * Returns the log score of {@code seq} starting at {@code start}.
     *
     * @return for a {@link MultiDimensionalSequence}, the mean of the wrapped function's log
     *         scores over all contained sequences; otherwise the wrapped function's log score
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        if (!(seq instanceof MultiDimensionalSequence)) {
            return this.function.getLogScoreFor(seq, start);
        }
        MultiDimensionalSequence mdSeq = (MultiDimensionalSequence) seq;
        int n = mdSeq.getNumberOfSequences();
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            sum += this.function.getLogScoreFor(mdSeq.getSequence(i), start);
        }
        return sum / n;
    }

    /**
     * Makes sure the scratch buffers exist and that {@link #gradient} can hold at least as many
     * entries as the wrapped function currently has parameters. The parameter count may change
     * after the first call, e.g. via {@link #initializeFunction}, so a cached size is not safe.
     */
    private void ensureBuffers() {
        // Clamp so that a negative count can never make the allocation throw.
        int numParams = Math.max(0, this.function.getNumberOfParameters());
        if (this.gradient == null || this.gradient.length < numParams) {
            this.gradient = new double[numParams];
        }
        if (this.iList == null) {
            this.iList = new IntList();
        }
        if (this.dList == null) {
            this.dList = new DoubleList();
        }
    }

    /**
     * Computes the log score of {@code seq} and appends its partial derivatives to
     * {@code indices}/{@code partialDer}.
     * <p>
     * For a {@link MultiDimensionalSequence}, the derivatives of all contained sequences are
     * summed per parameter index and divided by the number of sequences. Only parameters with a
     * non-zero accumulated derivative are appended. For any other sequence the call is delegated
     * unchanged to the wrapped function.
     *
     * @return the (mean) log score, analogous to {@link #getLogScoreFor(Sequence, int)}
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        if (!(seq instanceof MultiDimensionalSequence)) {
            return this.function.getLogScoreAndPartialDerivation(seq, start, indices, partialDer);
        }
        MultiDimensionalSequence mdSeq = (MultiDimensionalSequence) seq;
        int n = mdSeq.getNumberOfSequences();
        ensureBuffers();
        Arrays.fill(this.gradient, 0.0);
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            this.iList.clear();
            this.dList.clear();
            sum += this.function.getLogScoreAndPartialDerivation(mdSeq.getSequence(i), start, this.iList, this.dList);
            for (int j = 0; j < this.iList.length(); j++) {
                int index = this.iList.get(j);
                if (index >= this.gradient.length) {
                    // Defensive: the wrapped function reported an index past the buffer, so grow it.
                    this.gradient = Arrays.copyOf(this.gradient, index + 1);
                }
                this.gradient[index] += this.dList.get(j);
            }
        }
        for (int k = 0; k < this.gradient.length; k++) {
            if (this.gradient[k] != 0.0) {
                indices.add(k);
                partialDer.add(this.gradient[k] / n);
            }
        }
        return sum / n;
    }

    /** Delegates to the wrapped function. */
    @Override
    public int getNumberOfParameters() {
        return this.function.getNumberOfParameters();
    }

    /** Delegates to the wrapped function. */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        this.function.initializeFunction(index, freeParams, data, weights);
    }

    /** Delegates to the wrapped function. */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.function.initializeFunctionRandomly(freeParams);
    }

    /** Delegates to the wrapped function. */
    @Override
    public boolean isInitialized() {
        return this.function.isInitialized();
    }

    /** Delegates to the wrapped function. */
    @Override
    public void setParameters(double[] params, int start) {
        this.function.setParameters(params, start);
    }

    @Override
    public String toString(NumberFormat nf) {
        return "wrapper of " + this.function.getInstanceName() + ":\n" + this.function.toString(nf);
    }
}

