/*
 * Decompiled with CFR 0.152.
 */
package de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.shared;

import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.graphs.tensor.SymmetricTensor;
import de.jstacs.algorithms.optimization.termination.TerminationCondition;
import de.jstacs.data.WrongAlphabetException;
import de.jstacs.io.NonParsableException;
import de.jstacs.io.XMLParser;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.DiscreteGraphicalTrainSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.FSDAGTrainSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.StructureLearner;
import de.jstacs.sequenceScores.statisticalModels.trainable.discrete.inhomogeneous.parameters.IDGTrainSMParameterSet;
import de.jstacs.sequenceScores.statisticalModels.trainable.mixture.AbstractMixtureTrainSM;
import de.jstacs.sequenceScores.statisticalModels.trainable.mixture.MixtureTrainSM;

/**
 * A {@link MixtureTrainSM} whose components are {@link FSDAGTrainSM}s that all share one common
 * graph structure.
 *
 * <p>Whenever the superclass requests new parameters
 * ({@link #getNewParameters(int, double[][], double[])}), one tensor per component is computed,
 * the tensors are combined, a single structure is derived from the combination, and every
 * component is then trained with that same structure. The non-XML constructors configure the
 * superclass with {@link AbstractMixtureTrainSM.Algorithm#EM} and
 * {@link AbstractMixtureTrainSM.Parameterization#LAMBDA}.
 */
public class SharedStructureMixture
extends MixtureTrainSM {
    private static final long serialVersionUID = -9019277262448497915L;
    /** Computes the per-component tensors; recreated on clone and on XML restore. */
    private StructureLearner sl;
    /** The model type handed to {@link StructureLearner#getStructure} when deriving the shared structure. */
    private StructureLearner.ModelType modelType;
    /** The non-negative order used for tensor computation and structure learning. */
    private byte order;
    /**
     * The learning type used for tensor computation; set to
     * {@link StructureLearner.LearningType#ML_OR_MAP} by the non-XML constructors and restored
     * from XML otherwise.
     */
    private StructureLearner.LearningType method;
    /** The XML tag that encloses the XML representation of this class. */
    private static final String XML_TAG = "SharedStructureMixture";

    /**
     * Collects the value of {@link DiscreteGraphicalTrainSM#getESS()} of every component; the
     * result is passed to the superclass constructor as the class hyper-parameters.
     *
     * @param m the components
     * @return an array with {@code m[i].getESS()} at index {@code i}
     */
    private static double[] getClassHyperParams(DiscreteGraphicalTrainSM[] m) {
        double[] erg = new double[m.length];
        for (int i = 0; i < erg.length; ++i) {
            erg[i] = m[i].getESS();
        }
        return erg;
    }

    /**
     * Creates a shared-structure mixture whose component probabilities are estimated
     * (superclass is called with {@code estimateComponentProbs = true} and no weights).
     *
     * @param m the components; all of them will be trained with the same structure
     * @param model the model type used when deriving the shared structure
     * @param order the order of the structure, must be non-negative
     * @param starts passed to the superclass (presumably the number of EM starts)
     * @param alpha passed to the superclass
     * @param tc the termination condition passed to the superclass
     * @throws IllegalArgumentException if {@code order} is negative or the superclass rejects an argument
     * @throws WrongAlphabetException if thrown by the superclass constructor
     * @throws CloneNotSupportedException if thrown by the superclass constructor
     */
    public SharedStructureMixture(FSDAGTrainSM[] m, StructureLearner.ModelType model, byte order, int starts, double alpha, TerminationCondition tc) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(m, model, order, starts, true, null, alpha, tc);
    }

    /**
     * Creates a shared-structure mixture with the given component weights (superclass is called
     * with {@code estimateComponentProbs = false}).
     *
     * @param m the components; all of them will be trained with the same structure
     * @param model the model type used when deriving the shared structure
     * @param order the order of the structure, must be non-negative
     * @param starts passed to the superclass (presumably the number of EM starts)
     * @param weights the component weights passed to the superclass
     * @param alpha passed to the superclass
     * @param tc the termination condition passed to the superclass
     * @throws IllegalArgumentException if {@code order} is negative or the superclass rejects an argument
     * @throws WrongAlphabetException if thrown by the superclass constructor
     * @throws CloneNotSupportedException if thrown by the superclass constructor
     */
    public SharedStructureMixture(FSDAGTrainSM[] m, StructureLearner.ModelType model, byte order, int starts, double[] weights, double alpha, TerminationCondition tc) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        this(m, model, order, starts, false, weights, alpha, tc);
    }

    /**
     * Common constructor used by the public constructors.
     *
     * @param m the components; {@code m[0]} supplies the length and alphabet
     * @param model the model type used when deriving the shared structure
     * @param order the order of the structure, must be non-negative
     * @param starts passed to the superclass (presumably the number of EM starts)
     * @param estimateComponentProbs passed to the superclass
     * @param weights the component weights passed to the superclass (may be {@code null})
     * @param alpha passed to the superclass
     * @param tc the termination condition passed to the superclass
     * @throws IllegalArgumentException if {@code order} is negative or the superclass rejects an argument
     * @throws WrongAlphabetException if thrown by the superclass constructor
     * @throws CloneNotSupportedException if thrown by the superclass constructor
     */
    protected SharedStructureMixture(FSDAGTrainSM[] m, StructureLearner.ModelType model, byte order, int starts, boolean estimateComponentProbs, double[] weights, double alpha, TerminationCondition tc) throws IllegalArgumentException, WrongAlphabetException, CloneNotSupportedException {
        super(m[0].getLength(), m, starts, estimateComponentProbs, SharedStructureMixture.getClassHyperParams(m), weights, AbstractMixtureTrainSM.Algorithm.EM, alpha, tc, AbstractMixtureTrainSM.Parameterization.LAMBDA, 0, 0, null);
        // Validate before initialising any state of this class.
        if (order < 0) {
            throw new IllegalArgumentException("The value of order has to be non-negative.");
        }
        this.sl = new StructureLearner(m[0].getAlphabetContainer(), this.getLength());
        this.modelType = model;
        this.order = order;
        this.method = StructureLearner.LearningType.ML_OR_MAP;
    }

    /**
     * Restores an instance from its XML representation (as produced by {@link #toXML()}).
     *
     * @param xml the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    public SharedStructureMixture(StringBuffer xml) throws NonParsableException {
        super(xml);
    }

    /**
     * Clones this instance via the superclass and gives the clone its own, freshly created
     * {@link StructureLearner}, so the clone does not share the learner with this instance.
     *
     * @return the clone
     * @throws CloneNotSupportedException if the superclass clone fails
     */
    @Override
    public SharedStructureMixture clone() throws CloneNotSupportedException {
        SharedStructureMixture clone = (SharedStructureMixture)super.clone();
        clone.sl = new StructureLearner(this.alphabets, this.length);
        return clone;
    }

    /**
     * Returns the structure of the first component. Training assigns the same structure to all
     * components, so this describes the shared structure.
     *
     * @return the structure as reported by {@link FSDAGTrainSM#getStructure()}
     * @throws NotTrainedException if the first component reports that it is not trained
     */
    public String getStructure() throws NotTrainedException {
        return ((FSDAGTrainSM)this.model[0]).getStructure();
    }

    /**
     * Returns {@code "ssMixModel(<dimension> <component description>)"}, where the component
     * description is built from the model type, order, learning type and the ESS of the first
     * component.
     *
     * @return the instance name
     */
    @Override
    public String getInstanceName() {
        return "ssMixModel(" + this.dimension + " " + IDGTrainSMParameterSet.getModelInstanceName(this.modelType, this.order, this.method, ((FSDAGTrainSM)this.model[0]).getESS()) + ")";
    }

    /**
     * Serialises the model type, order and learning type, followed by the superclass XML, all
     * enclosed in the {@code SharedStructureMixture} tag.
     *
     * @return the XML representation
     */
    @Override
    public StringBuffer toXML() {
        StringBuffer xml = new StringBuffer(100000);
        XMLParser.appendObjectWithTags(xml, (Object)this.modelType, "model");
        XMLParser.appendObjectWithTags(xml, this.order, "order");
        XMLParser.appendObjectWithTags(xml, (Object)this.method, "method");
        xml.append(super.toXML());
        XMLParser.addTags(xml, XML_TAG);
        return xml;
    }

    /**
     * Restores the fields written by {@link #toXML()} and then delegates to the superclass.
     *
     * @param representation the XML representation
     * @throws NonParsableException if the representation cannot be parsed
     */
    @Override
    protected void fromXML(StringBuffer representation) throws NonParsableException {
        StringBuffer xml = XMLParser.extractForTag(representation, XML_TAG);
        this.modelType = XMLParser.extractObjectForTags(xml, "model", StructureLearner.ModelType.class);
        this.order = XMLParser.extractObjectForTags(xml, "order", byte.class);
        this.method = XMLParser.extractObjectForTags(xml, "method", StructureLearner.LearningType.class);
        super.fromXML(xml);
        // Build the learner only after the superclass has restored its state: during XML
        // construction the alphabet container and length are presumably not yet populated before
        // super.fromXML(...) returns, so creating the learner earlier would use unset values.
        this.sl = new StructureLearner(this.getAlphabetContainer(), this.getLength());
    }

    /**
     * Learns one shared structure from all components and retrains every component with it, then
     * updates the component probabilities.
     *
     * @param iteration the current iteration (not used by this implementation)
     * @param seqWeights {@code seqWeights[i]} holds the sequence weights used for component {@code i}
     * @param w passed on to {@link #getNewComponentProbs(double[])}
     * @throws Exception if tensor computation, structure learning or training fails
     */
    @Override
    protected void getNewParameters(int iteration, double[][] seqWeights, double[] w) throws Exception {
        SymmetricTensor[] parts = new SymmetricTensor[this.dimension];
        // Each component's tensor enters the combined tensor with factor 1.0 (presumably a weighted sum).
        double[] x = new double[this.dimension];
        for (int i = 0; i < this.dimension; ++i) {
            // The learner is shared across components, so set this component's ESS before use.
            this.sl.setESS(((FSDAGTrainSM)this.model[i]).getESS());
            parts[i] = this.sl.getTensor(this.sample[0], seqWeights[i], this.order, this.method);
            x[i] = 1.0;
        }
        FSDAGTrainSM.train(this.model, StructureLearner.getStructure(new SymmetricTensor(parts, x), this.modelType, this.order), seqWeights, this.sample[0]);
        this.getNewComponentProbs(w);
    }
}

