/*
 * This file is part of Jstacs.
 *
 * Jstacs is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * Jstacs is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Jstacs.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.models.discrete.inhomogeneous.shared;

import de.jstacs.NonParsableException;
import de.jstacs.algorithms.graphs.tensor.SymmetricTensor;
import de.jstacs.classifier.ClassDimensionException;
import de.jstacs.classifier.modelBased.ModelBasedClassifier;
import de.jstacs.data.Sample;
import de.jstacs.io.XMLParser;
import de.jstacs.models.AbstractModel;
import de.jstacs.models.discrete.inhomogeneous.FSDAGModel;
import de.jstacs.models.discrete.inhomogeneous.StructureLearner;
import de.jstacs.models.discrete.inhomogeneous.StructureLearner.LearningType;
import de.jstacs.models.discrete.inhomogeneous.StructureLearner.ModelType;
import de.jstacs.models.discrete.inhomogeneous.parameters.BayesianNetworkModelParameterSet;
import de.jstacs.results.CategoricalResult;

/**
 * This class enables you to learn the structure on all classes together. A special case is for instance a TAN.
 * 
 * @author Jens Keilwagen
 */
public class SharedStructureClassifier extends ModelBasedClassifier
{
	private ModelType model;
	private byte order;
	private LearningType method;
	private StructureLearner sl;
	
	/**
	 * The main constructor.
	 * 
	 * @param length the sequence length
	 * @param model the model type
	 * @param order the model order
	 * @param method the learning method
	 * @param models the class models
	 *  
	 * @throws IllegalArgumentException if order is below 0
	 * @throws CloneNotSupportedException if at least one model could not be cloned
	 * @throws ClassDimensionException if the class dimension is wrong (below 2)
	 */
	public SharedStructureClassifier( int length, ModelType model, byte order, LearningType method, FSDAGModel... models ) throws IllegalArgumentException,
			CloneNotSupportedException, ClassDimensionException
	{
		super( true, (AbstractModel[]) models );
		this.model = model;
		if( order < 0 )
		{
			throw new IllegalArgumentException( "The value of order has to be non-negative." );
		}
		this.order = order;
		this.method = method;
		sl = new StructureLearner( getAlphabetContainer(), length );
	}

	/**
	 * The constructor for the {@link de.jstacs.Storable} interface.
	 * 
	 * @param xml
	 *            the StringBuffer
	 * 
	 * @throws NonParsableException
	 *             if the StringBuffer is not parsable
	 */
	public SharedStructureClassifier( StringBuffer xml ) throws NonParsableException
	{
		super( xml );
	}
	
	public SharedStructureClassifier clone() throws CloneNotSupportedException
	{
		SharedStructureClassifier clone = (SharedStructureClassifier) super.clone();
		clone.sl = new StructureLearner( getAlphabetContainer(), getLength() );
		return clone;
	}

	public void train( Sample[] data, double[][] weights ) throws IllegalArgumentException, Exception
	{
		int dimension = models.length;
		SymmetricTensor[] parts = new SymmetricTensor[dimension];
		double[] w = new double[dimension];
		for( int i = 0; i < dimension; i++ )
		{
			sl.setESS( ((FSDAGModel)models[i]).getESS() );
			parts[i] = sl.getTensor( data[i], weights[i], order, method );
			w[i] = 1d;
		}
		FSDAGModel.train( models, StructureLearner.getStructure( new SymmetricTensor( parts, w ), model, order ), weights, data );
	}
	
	public String getInstanceName()
	{
		return "shared-structure classifier";
	}
	
	protected void extractFurtherClassifierInfosFromXML( StringBuffer xml ) throws NonParsableException
	{
		super.extractFurtherClassifierInfosFromXML( xml );
		model = XMLParser.extractEnumForTag( xml, "model" );
		order = XMLParser.extractByteForTag( xml, "order" );
		method = XMLParser.extractEnumForTag( xml, "method" );
		sl = new StructureLearner( getAlphabetContainer(), getLength() );
	}
	
	protected StringBuffer getFurtherClassifierInfos( )
	{
		StringBuffer xml = super.getFurtherClassifierInfos( );
		XMLParser.appendEnumWithTags( xml, model, "model" );
		XMLParser.appendByteWithTags( xml, order, "order" );
		XMLParser.appendEnumWithTags( xml, method, "method" );
		return xml;
	}
	
	public CategoricalResult[] getClassifierAnnotation()
	{
		CategoricalResult[] res = new CategoricalResult[models.length+1];
		res[0] = new CategoricalResult( "classifier", "a <b>short</b> description of the classifier", getInstanceName() );
		int i = 0;
		while( i < models.length )
		{
			res[i + 1] = new CategoricalResult( "class info " + i, "some information about the class", BayesianNetworkModelParameterSet.getModelInstanceName( model, order, method, ((FSDAGModel)models[i++]).getESS() ) );
		}
		return res;
	}
}
