/*
 * Decompiled with CFR 0.152.
 */
package projects.taleningner;

import de.jstacs.data.DataSet;
import de.jstacs.data.sequences.Sequence;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.differentiable.AbstractDifferentiableSequenceScore;
import de.jstacs.sequenceScores.differentiable.DifferentiableSequenceScore;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import java.text.NumberFormat;

/**
 * A {@link DifferentiableSequenceScore} defined as the difference of two component scores, i.e.
 * {@code logScore(seq) = pos.logScore(seq) - neg.logScore(seq)}.
 *
 * <p>The parameter vector of this score is the parameter vector of {@code pos} followed by the
 * parameter vector of {@code neg}. Partial derivatives with respect to the parameters of
 * {@code neg} are negated, matching the subtraction in the score.
 */
public class DifferenceDiffSS extends AbstractDifferentiableSequenceScore {

    // These fields intentionally have no initializers. The XML constructor only calls super(xml),
    // which must reach fromXML(...) to populate them; a field initializer would run after
    // super(xml) returns and overwrite the parsed values.

    /** The score whose log value is added. */
    private DifferentiableSequenceScore pos;

    /** The score whose log value is subtracted. */
    private DifferentiableSequenceScore neg;

    /**
     * Creates a new difference score {@code pos - neg}.
     *
     * <p>The alphabet and length of this score are taken from {@code pos}.
     *
     * @param pos the score whose log value is added
     * @param neg the score whose log value is subtracted
     * @throws IllegalArgumentException if the alphabets of {@code pos} and {@code neg} are not
     *     consistent, or if both lengths are positive and differ
     */
    public DifferenceDiffSS(DifferentiableSequenceScore pos, DifferentiableSequenceScore neg) {
        super(pos.getAlphabetContainer(), pos.getLength());
        if (!pos.getAlphabetContainer().checkConsistency(neg.getAlphabetContainer())) {
            throw new IllegalArgumentException(
                    "pos and neg must have consistent alphabet containers");
        }
        int posLength = pos.getLength();
        int negLength = neg.getLength();
        // Lengths are only compared when both are positive; a non-positive length is not checked.
        if (posLength > 0 && negLength > 0 && posLength != negLength) {
            throw new IllegalArgumentException("pos and neg must have the same length, but found "
                    + posLength + " and " + negLength);
        }
        this.pos = pos;
        this.neg = neg;
    }

    /**
     * Restores a difference score from its XML representation as produced by {@link #toXML()}.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    public DifferenceDiffSS(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Returns a deep copy of this score; both component scores are cloned as well.
     *
     * @return the clone
     * @throws CloneNotSupportedException if this object or a component cannot be cloned
     */
    @Override
    public DifferenceDiffSS clone() throws CloneNotSupportedException {
        DifferenceDiffSS copy = (DifferenceDiffSS) super.clone();
        copy.pos = this.pos.clone();
        copy.neg = this.neg.clone();
        return copy;
    }

    /**
     * Initializes both components from exactly two data sets. {@code pos} is initialized with
     * {@code index}, and {@code neg} is initialized with the other class index {@code 1 - index}.
     *
     * @param index the class index used for {@code pos}
     * @param freeParams passed through to both components
     * @param data the data sets; exactly two are required
     * @param weights the weights of the data sets, passed through to both components
     * @throws IllegalArgumentException if {@code data} does not contain exactly two data sets
     * @throws Exception if a component fails to initialize
     */
    @Override
    public void initializeFunction(int index, boolean freeParams, DataSet[] data, double[][] weights) throws Exception {
        if (data.length != 2) {
            throw new IllegalArgumentException(
                    "exactly 2 data sets are required, but got " + data.length);
        }
        this.pos.initializeFunction(index, freeParams, data, weights);
        this.neg.initializeFunction(1 - index, freeParams, data, weights);
    }

    /**
     * Randomly initializes both component scores.
     *
     * @param freeParams passed through to both components
     * @throws Exception if a component fails to initialize
     */
    @Override
    public void initializeFunctionRandomly(boolean freeParams) throws Exception {
        this.pos.initializeFunctionRandomly(freeParams);
        this.neg.initializeFunctionRandomly(freeParams);
    }

    /**
     * Computes {@code pos - neg} for the sequence starting at {@code start} and appends the
     * non-zero partial derivatives.
     *
     * <p>The indices reported by {@code neg} are shifted by the number of parameters of
     * {@code pos}, and the derivative values appended by {@code neg} are negated.
     *
     * @param seq the sequence
     * @param start the start position within {@code seq}
     * @param indices receives the parameter indices of the appended derivatives
     * @param partialDer receives the partial derivative values
     * @return the log score difference
     */
    @Override
    public double getLogScoreAndPartialDerivation(Sequence seq, int start, IntList indices, DoubleList partialDer) {
        double score = this.pos.getLogScoreAndPartialDerivation(seq, start, indices, partialDer);

        // neg's indices are collected separately so they can be shifted past pos's parameters.
        int negDerivativesStart = partialDer.length();
        IntList negIndices = new IntList();
        score -= this.neg.getLogScoreAndPartialDerivation(seq, start, negIndices, partialDer);
        // The derivative of -neg is -(derivative of neg): negate only the values neg appended.
        partialDer.multiply(negDerivativesStart, partialDer.length(), -1.0);

        int parameterOffset = this.pos.getNumberOfParameters();
        for (int i = 0; i < negIndices.length(); i++) {
            indices.add(negIndices.get(i) + parameterOffset);
        }
        return score;
    }

    /**
     * Returns the total number of parameters, which is the parameter count of {@code pos} plus
     * that of {@code neg}.
     *
     * @return the number of parameters
     */
    @Override
    public int getNumberOfParameters() {
        return this.pos.getNumberOfParameters() + this.neg.getNumberOfParameters();
    }

    /**
     * Returns the current parameters of {@code pos} followed by those of {@code neg}.
     *
     * @return a new array holding the concatenated parameter values
     * @throws Exception if a component cannot provide its parameters
     */
    @Override
    public double[] getCurrentParameterValues() throws Exception {
        double[] values = new double[this.getNumberOfParameters()];
        int posCount = this.pos.getNumberOfParameters();
        System.arraycopy(this.pos.getCurrentParameterValues(), 0, values, 0, posCount);
        System.arraycopy(this.neg.getCurrentParameterValues(), 0, values, posCount,
                this.neg.getNumberOfParameters());
        return values;
    }

    /**
     * Sets the parameters of both components from {@code params}. The parameters of {@code pos}
     * are read starting at {@code start}, followed directly by those of {@code neg}.
     *
     * @param params the parameter array
     * @param start the offset of the first parameter of {@code pos} within {@code params}
     */
    @Override
    public void setParameters(double[] params, int start) {
        this.pos.setParameters(params, start);
        this.neg.setParameters(params, start + this.pos.getNumberOfParameters());
    }

    /**
     * Returns a name of the form {@code "<pos name> - <neg name>"}.
     *
     * @return the instance name
     */
    @Override
    public String getInstanceName() {
        return this.pos.getInstanceName() + " - " + this.neg.getInstanceName();
    }

    /**
     * Returns the log score of {@code pos} minus the log score of {@code neg} for the sequence
     * starting at {@code start}.
     *
     * @param seq the sequence
     * @param start the start position within {@code seq}
     * @return the log score difference
     */
    @Override
    public double getLogScoreFor(Sequence seq, int start) {
        return this.pos.getLogScoreFor(seq, start) - this.neg.getLogScoreFor(seq, start);
    }

    /**
     * Returns whether both component scores are initialized.
     *
     * @return {@code true} if and only if both {@code pos} and {@code neg} are initialized
     */
    @Override
    public boolean isInitialized() {
        return this.pos.isInitialized() && this.neg.isInitialized();
    }

    /**
     * Returns the string representations of both components, separated by a line containing
     * {@code "-"}.
     *
     * @param nf the number format passed to both components
     * @return the string representation
     */
    @Override
    public String toString(NumberFormat nf) {
        return this.pos.toString(nf) + "\n-\n" + this.neg.toString(nf);
    }

    /**
     * Serializes both components inside {@code <pos>} and {@code <neg>} tags, wrapped in a
     * {@code <DifferenceDiffSS>} tag.
     *
     * @return the XML representation
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer();
        XMLParser.appendObjectWithTags(xml, this.pos, "pos");
        XMLParser.appendObjectWithTags(xml, this.neg, "neg");
        XMLParser.addTags(xml, "DifferenceDiffSS");
        return xml;
    }

    /**
     * Restores both components from the XML written by {@link #toXML()}. The alphabet and
     * length of this score are then taken from the restored {@code pos} component.
     *
     * @param xml the XML representation
     * @throws NonParsableException if the XML cannot be parsed
     */
    @Override
    protected void fromXML(StringBuffer xml) throws NonParsableException {
        StringBuffer content = XMLParser.extractForTag(xml, "DifferenceDiffSS");
        this.pos = (DifferentiableSequenceScore) XMLParser.extractObjectForTags(content, "pos");
        this.neg = (DifferentiableSequenceScore) XMLParser.extractObjectForTags(content, "neg");
        this.alphabets = this.pos.getAlphabetContainer();
        this.length = this.pos.getLength();
    }
}

