/*
 * This file is part of Jstacs.
 * 
 * Jstacs is free software: you can redistribute it and/or modify it under the
 * terms of the GNU General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any later
 * version.
 * 
 * Jstacs is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along with
 * Jstacs. If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.classifier.scoringFunctionBased;

import java.io.OutputStream;

import de.jstacs.NonParsableException;
import de.jstacs.NotTrainedException;
import de.jstacs.algorithms.optimization.ConstantStartDistance;
import de.jstacs.algorithms.optimization.DifferentiableFunction;
import de.jstacs.algorithms.optimization.NegativeDifferentiableFunction;
import de.jstacs.algorithms.optimization.Optimizer;
import de.jstacs.classifier.AbstractScoreBasedClassifier;
import de.jstacs.classifier.ClassDimensionException;
import de.jstacs.classifier.scoringFunctionBased.OptimizableFunction.KindOfParameter;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.Sample.WeightedSampleFactory;
import de.jstacs.data.Sample.WeightedSampleFactory.SortOperation;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.XMLParser;
import de.jstacs.results.CategoricalResult;
import de.jstacs.results.NumericalResult;
import de.jstacs.results.NumericalResultSet;
import de.jstacs.scoringFunctions.ScoringFunction;
import de.jstacs.utils.SafeOutputStream;

/**
 * Abstract base class for classifiers that use one {@link ScoringFunction}
 * per class and learn their parameters by numerically optimizing an
 * {@link OptimizableFunction} supplied by subclasses via
 * {@link #getFunction(Sample[], double[][])}. The optimization may be
 * restarted several times; the best start is kept.
 * 
 * @author Jens Keilwagen, Jan Grau
 */
public abstract class ScoreClassifier extends AbstractScoreBasedClassifier {

	/**
	 * The internally used scoring functions, one per class.
	 */
	protected ScoringFunction[] score;

	/**
	 * The parameter set for the classifier.
	 */
	protected ScoreClassifierParameterSet params;

	/**
	 * This boolean indicates whether the classifier has been optimized with the
	 * method {@link de.jstacs.classifier.AbstractClassifier#train(Sample[])} or the weighted
	 * version.
	 */
	protected boolean hasBeenOptimized;

	/**
	 * The value of the optimized function returned by the last successful
	 * optimization; {@link Double#NaN} for a freshly constructed classifier.
	 */
	private double lastScore;

	/**
	 * This stream is used for writing comments and progress information
	 * during training.
	 */
	protected SafeOutputStream sostream;

	/**
	 * The default constructor.
	 * 
	 * @param params
	 *            the parameter set for the classifier (a clone is stored)
	 * @param score
	 *            the {@link ScoringFunction}s for the classes (clones are
	 *            stored); each must have length 0 or the classifier's length
	 *            and an {@link AlphabetContainer} consistent with the
	 *            classifier's one
	 * 
	 * @throws CloneNotSupportedException
	 *             if at least one {@link ScoringFunction} could not be cloned
	 * @throws IllegalArgumentException
	 *             if the length or the {@link AlphabetContainer} of a
	 *             {@link ScoringFunction} does not fit
	 */
	public ScoreClassifier( ScoreClassifierParameterSet params, ScoringFunction... score ) throws CloneNotSupportedException {
		super( params.getAlphabet(), params.getLength(), score.length );
		int i = 0, l, len = getLength();
		AlphabetContainer con = getAlphabetContainer();
		while( i < score.length ) {
			l = score[i].getLength();
			// a length of 0 is accepted besides the classifier's own length
			// (presumably marking a variable-length ScoringFunction - TODO confirm)
			if( ( l == 0 || l == len ) && con.checkConsistency( score[i].getAlphabetContainer() ) ) {
				//everything is okay
				i++;
			} else {
				throw new IllegalArgumentException( "Please check the length and the AlphabetContainer of the ScoringFunction with index " + i
													+ "." );
			}
		}
		this.score = ArrayHandler.clone( score );
		hasBeenOptimized = false;
		lastScore = Double.NaN;
		set( (ScoreClassifierParameterSet)params.clone() );
	}

	/**
	 * This is the constructor for the interface {@link de.jstacs.Storable}.
	 * 
	 * @param xml
	 *            the XML representation
	 * 
	 * @throws NonParsableException
	 *             if the representation could not be parsed.
	 */
	public ScoreClassifier( StringBuffer xml ) throws NonParsableException {
		super( xml );
	}

	/**
	 * Returns a deep copy of this classifier (parameter set and scoring
	 * functions are cloned).
	 * <p>
	 * The clone does not share this instance's {@link OutputStream}: it
	 * writes to {@link SafeOutputStream#DEFAULT_STREAM}, or to nothing if
	 * this instance's stream does nothing.
	 * 
	 * @see de.jstacs.classifier.AbstractScoreBasedClassifier#clone()
	 */
	@Override
	public ScoreClassifier clone() throws CloneNotSupportedException {
		ScoreClassifier clone = (ScoreClassifier)super.clone();
		clone.params = (ScoreClassifierParameterSet)params.clone();
		clone.score = ArrayHandler.clone( score );
		clone.setOutputStream( this.sostream.doesNothing() ? null : SafeOutputStream.DEFAULT_STREAM );
		return clone;
	}

	/* (non-Javadoc)
	 * @see de.jstacs.classifier.AbstractClassifier#getInstanceName()
	 */
	@Override
	public String getInstanceName() {
		return getClass().getSimpleName();
	}

	/**
	 * Returns the name of this classifier followed by the instance name of
	 * the {@link ScoringFunction} of each class.
	 * 
	 * @see de.jstacs.classifier.AbstractClassifier#getClassifierAnnotation()
	 */
	@Override
	public CategoricalResult[] getClassifierAnnotation() {
		CategoricalResult[] res = new CategoricalResult[score.length + 1];
		res[0] = new CategoricalResult( "classifier", "a <b>short</b> description of the classifier", getInstanceName() );
		int i = 0;
		while( i < score.length ) {
			res[i + 1] = new CategoricalResult( "class info " + i, "some information about the class", score[i++].getInstanceName() );
		}
		return res;
	}

	/**
	 * Returns the last optimization score (only if the classifier has been
	 * optimized) followed by the number of parameters of each
	 * {@link ScoringFunction}.
	 * 
	 * @see de.jstacs.classifier.AbstractClassifier#getNumericalCharacteristics()
	 */
	@Override
	public NumericalResultSet getNumericalCharacteristics() throws Exception {

		NumericalResult[] pars = new NumericalResult[score.length + ( hasBeenOptimized ? 1 : 0 )];
		if( hasBeenOptimized ) {
			pars[0] = new NumericalResult( "Last score", "The final score after the optimization", lastScore );
		}
		for( int i = 0; i < score.length; i++ ) {
			pars[i + ( hasBeenOptimized ? 1 : 0 )] = new NumericalResult( "Number of parameters " + ( i + 1 ),
					"The number of parameters for scoring function " + ( i + 1 ) + ", -1 indicates unknown number of parameters.",
					score[i].getNumberOfParameters() );
		}
		return new NumericalResultSet( pars );
	}

	/**
	 * Returns <code>true</code> if every internal {@link ScoringFunction}
	 * reports that it is initialized.
	 * 
	 * @see de.jstacs.classifier.AbstractClassifier#isTrained()
	 */
	@Override
	public boolean isTrained() {
		int i = 0;
		while( i < score.length && score[i].isInitialized() ) {
			i++;
		}
		return i == score.length;
	}

	/**
	 * This method returns <code>true</code> if the classifier has been
	 * optimized by a <code>train</code>-method.
	 * 
	 * @return <code>true</code> if the classifier has been optimized by a
	 *         <code>train</code>-method
	 */
	public boolean hasBeenOptimized() {
		return hasBeenOptimized;
	}

	/**
	 * Sets the {@link OutputStream} that is used e.g. for writing information
	 * while training. It is possible to set <code>o=null</code>, then nothing
	 * will be written.
	 * 
	 * @param o
	 *            the {@link OutputStream}
	 */
	public void setOutputStream( OutputStream o ) {
		sostream = new SafeOutputStream( o );
	}

	/**
	 * Checks the data and weights and converts every sample to the
	 * classifier's length. It then delegates the optimization to
	 * {@link #doOptimization(Sample[], double[][])} and stores its result as
	 * the last score.
	 * 
	 * @throws IllegalArgumentException
	 *             if data and weights do not match, or a sample is defined
	 *             over an incompatible alphabet
	 * @throws ClassDimensionException
	 *             if the number of samples differs from the number of classes
	 * 
	 * @see de.jstacs.classifier.AbstractClassifier#train(de.jstacs.data.Sample[], double[][])
	 */
	@Override
	public void train( Sample[] data, double[][] weights ) throws Exception {
		hasBeenOptimized = false;
		// check
		if( weights != null && data.length != weights.length ) {
			throw new IllegalArgumentException( "data and weights do not match" );
		}
		if( score.length != data.length ) {
			throw new ClassDimensionException();
		}
		if( weights == null ) {
			weights = new double[data.length][];
		}
		WeightedSampleFactory wsf;
		Sample[] reduced = new Sample[data.length];
		double[][] newWeights = new double[data.length][];
		AlphabetContainer abc = getAlphabetContainer();
		for( int l = getLength(), i = 0; i < score.length; i++ ) {
			if( weights[i] != null && data[i].getNumberOfElements() != weights[i].length ) {
				throw new IllegalArgumentException( "At least for one sample: The dimension of the sample and the weight do not match." );
			}
			if( !abc.checkConsistency( data[i].getAlphabetContainer() ) ) {
				throw new IllegalArgumentException( "At least one sample is not defined over the correct alphabets." );
			}
			if( data[i].getElementLength() != l ) {
				// samples of a different length are not rejected; instead the factory is asked
				// for a sample of length l (presumably built from subsequences - TODO confirm)
				wsf = new WeightedSampleFactory( SortOperation.NO_SORT, data[i], weights[i], l );
			} else {
				wsf = new WeightedSampleFactory( SortOperation.NO_SORT, data[i], weights[i] );
			}
			reduced[i] = wsf.getSample();
			newWeights[i] = wsf.getWeights();
		}
		lastScore = doOptimization( reduced, newWeights );
	}

	/**
	 * This method does the optimization of the <code>train</code>-method.
	 * <p>
	 * The function returned by {@link #getFunction(Sample[], double[][])} is
	 * maximized from {@link OptimizableFunction#getNumberOfStarts()} starting
	 * points. Each start re-creates the structure via
	 * {@link #createStructure(Sample[], double[][])}. The scoring functions
	 * and parameters of the start with the largest function value are kept,
	 * and the class weights are set from these parameters.
	 * 
	 * @param reduced
	 *            the samples
	 * @param newWeights
	 *            the weights
	 * 
	 * @return the value of the optimization (the largest function value over
	 *         all starts)
	 * 
	 * @throws IllegalStateException
	 *             if no start yields a function value larger than
	 *             <code>-Infinity</code> (e.g. all values are NaN, or there
	 *             are no starts); the scoring functions are then left as
	 *             they are after the last start
	 * @throws Exception
	 *             if something went wrong while the optimization
	 */
	protected double doOptimization( Sample[] reduced, double[][] newWeights ) throws Exception {
		// optimization settings, read by index from the parameter set (index 4 is not read here)
		byte algo = (Byte)params.getParameterAt( 0 ).getValue();
		double eps = (Double)params.getParameterAt( 1 ).getValue();
		double linEps = (Double)params.getParameterAt( 2 ).getValue();
		double startDist = (Double)params.getParameterAt( 3 ).getValue();
		KindOfParameter plugIn = (KindOfParameter)params.getParameterAt( 5 ).getValue();
		double[] best = null;
		OptimizableFunction f = getFunction( reduced, newWeights );

		// f is maximized (see the comparison with max below); the Optimizer is handed the
		// negated function (presumably because it minimizes - confirm in Optimizer docs)
		DifferentiableFunction g = new NegativeDifferentiableFunction( f );
		double max = Double.NEGATIVE_INFINITY, current;

		int iterations = f.getNumberOfStarts();
		double[] start;
		sostream.writeln( getInstanceName() );
		ScoringFunction[] bestSF = null, secure;
		if( iterations > 1 ) {
			// untouched copy of the scoring functions, used to obtain fresh instances
			// after an improving start so that the best functions are not overwritten
			secure = ArrayHandler.clone( score );
		} else {
			secure = null;
		}
		for( int i = 0; i < iterations; ) {
			// create structure
			createStructure( reduced, newWeights );
			f.reset( score );
			if( i == 0 ) {
				sostream.writeln( "optimizing " + f.getDimensionOfScope() + " parameters" );
			}
			// i becomes the 1-based index of the current start
			sostream.writeln( "start " + ++i + " :" );
			start = f.getParameters( plugIn );
			Optimizer.optimize( algo,
					g,
					start,
					Optimizer.TerminationCondition.SMALL_DIFFERENCE_OF_FUNCTION_EVALUATIONS,
					eps,
					linEps,
					new ConstantStartDistance( startDist ),
					sostream );
			current = f.evaluateFunction( start );
			if( sostream.doesNothing() && iterations > 1 ) {
				System.out.println( "start " + i + ": " + current );
			}
			// note: NaN never compares greater, so NaN starts are never selected
			if( current > max ) {
				bestSF = score;
				if( i < iterations ) {
					// further starts follow: continue on fresh copies so bestSF stays intact
					score = ArrayHandler.clone( secure );
				}
				best = start;
				max = current;
				System.gc();
			}
		}
		sostream.writeln( "best = " + max );

		if( bestSF == null ) {
			// fail before touching the state instead of nulling the scoring functions
			// and running into a NullPointerException in getClassParams( null )
			throw new IllegalStateException( "None of the " + iterations
					+ " optimization start(s) yielded a score larger than -Infinity." );
		}
		score = bestSF;
		setClassWeights( false, f.getClassParams( best ) );
		hasBeenOptimized = true;
		return max;
	}

	/**
	 * Creates the structure that will be used in the optimization by
	 * initializing every {@link ScoringFunction} with the data of all classes.
	 * 
	 * @param data
	 *            the data
	 * @param weights
	 *            the weights of the data
	 * 
	 * @throws Exception
	 *             if something went wrong
	 */
	protected void createStructure( Sample[] data, double[][] weights ) throws Exception {
		boolean freeParams = params.useOnlyFreeParameter();
		for( int i = 0; i < score.length; i++ ) {
			score[i].initializeFunction( i, freeParams, data, weights );
		}
	}

	/**
	 * Restores the parameter set, the optimization flag, the last score and
	 * the scoring functions from XML. Restoring the parameter set also resets
	 * the output stream to {@link SafeOutputStream#DEFAULT_STREAM}.
	 * 
	 * @see de.jstacs.classifier.AbstractScoreBasedClassifier#extractFurtherClassifierInfosFromXML(java.lang.StringBuffer)
	 */
	@Override
	protected void extractFurtherClassifierInfosFromXML( StringBuffer xml ) throws NonParsableException {
		super.extractFurtherClassifierInfosFromXML( xml );
		set( (ScoreClassifierParameterSet)XMLParser.extractStorableForTag( xml, "params" ) );
		hasBeenOptimized = XMLParser.extractBooleanForTag( xml, "hasBeenOptimized" );
		lastScore = XMLParser.extractDoubleForTag( xml, "lastScore" );
		score = (ScoringFunction[])ArrayHandler.cast( XMLParser.extractStorableArrayForTag( xml, "score" ) );
	}

	/**
	 * Returns the function that should be optimized.
	 * 
	 * @param data
	 *            the samples
	 * @param weights
	 *            the weights of the sequences of the samples
	 * 
	 * @return the function that should be optimized
	 * 
	 * @throws Exception
	 *             if something went wrong
	 */
	protected abstract OptimizableFunction getFunction( Sample[] data, double[][] weights ) throws Exception;

	/**
	 * Appends the parameter set, the optimization flag, the last score and
	 * the scoring functions to the XML of the super class.
	 * 
	 * @see de.jstacs.classifier.AbstractScoreBasedClassifier#getFurtherClassifierInfos()
	 */
	@Override
	protected StringBuffer getFurtherClassifierInfos() {
		StringBuffer xml = super.getFurtherClassifierInfos();
		XMLParser.appendStorableWithTags( xml, params, "params" );
		XMLParser.appendBooleanWithTags( xml, hasBeenOptimized, "hasBeenOptimized" );
		XMLParser.appendDoubleWithTags( xml, lastScore, "lastScore" );
		XMLParser.appendStorableArrayWithTags( xml, score, "score" );
		return xml;
	}

	/**
	 * Returns the class weight of class <code>i</code> plus the log-score of
	 * <code>seq</code> (from position 0) under the {@link ScoringFunction} of
	 * that class.
	 * 
	 * @see de.jstacs.classifier.AbstractScoreBasedClassifier#getScore(de.jstacs.data.Sequence, int, boolean)
	 */
	@Override
	protected double getScore( Sequence seq, int i, boolean check ) throws IllegalArgumentException, NotTrainedException, Exception {
		if( check ) {
			check( seq );
		}
		return getClassWeight( i ) + score[i].getLogScore( seq, 0 );
	}

	/**
	 * Returns the score that was computed in the last optimization of the
	 * parameters.
	 * 
	 * @return score from the last parameter optimization
	 */
	public double getLastScore() {
		return lastScore;
	}

	/**
	 * Returns a clone of the internally used {@link ScoringFunction} with
	 * index <code>i</code>.
	 * 
	 * @param i
	 *            the internal index of the {@link ScoringFunction}
	 * 
	 * @return a clone of the internally used {@link ScoringFunction} with
	 *         index <code>i</code>.
	 * 
	 * @throws CloneNotSupportedException
	 *             if the {@link ScoringFunction} could not be cloned
	 */
	public ScoringFunction getScoringFunction( int i ) throws CloneNotSupportedException {
		return score[i].clone();
	}

	/**
	 * Returns clones of all internally used {@link ScoringFunction}s in the
	 * internal order.
	 * 
	 * @return clones of the internally used {@link ScoringFunction}s in the
	 *         internal order.
	 * 
	 * @throws CloneNotSupportedException
	 *             if a {@link ScoringFunction} could not be cloned
	 */
	public ScoringFunction[] getScoringFunctions() throws CloneNotSupportedException {
		return ArrayHandler.clone( score );
	}

	/* (non-Javadoc)
	 * @see de.jstacs.classifier.AbstractClassifier#getXMLTag()
	 */
	@Override
	protected abstract String getXMLTag();

	/**
	 * Stores the given parameter set (without cloning it) and resets the
	 * output stream to {@link SafeOutputStream#DEFAULT_STREAM}.
	 * 
	 * @param params
	 *            the parameter set to use
	 */
	private void set( ScoreClassifierParameterSet params ) {
		this.params = params;
		setOutputStream( SafeOutputStream.DEFAULT_STREAM );
	}
}
