/*
 * This file is part of Jstacs.
 * 
 * Jstacs is free software: you can redistribute it and/or modify it under the
 * terms of the GNU General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any later
 * version.
 * 
 * Jstacs is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along with
 * Jstacs. If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.classifier.scoringFunctionBased.cll;

import java.util.Arrays;

import de.jstacs.WrongAlphabetException;
import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.algorithms.optimization.EvaluationException;
import de.jstacs.classifier.scoringFunctionBased.AbstractOptimizableFunction;
import de.jstacs.classifier.scoringFunctionBased.logPrior.DoesNothingLogPrior;
import de.jstacs.classifier.scoringFunctionBased.logPrior.LogPrior;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.scoringFunctions.ScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;

/**
 * This class implements the (optionally normalized) log conditional likelihood
 * of a set of {@link ScoringFunction}s (one per class) plus a {@link LogPrior}.
 * It can be used as objective function for maximizing the parameters.
 * 
 * <p>
 * Layout of the parameter vector used by {@link #evaluateFunction(double[])},
 * {@link #evaluateGradientOfFunction(double[])} and
 * {@link #getParameters(KindOfParameter, double[])}: the class parameters
 * occupy the indices {@code 0 .. shortcut[0] - 1}, the parameters of the
 * {@link ScoringFunction} of class {@code i} start at offset
 * {@code shortcut[i]}, and the total length is {@code shortcut[cl]}.
 * </p>
 * 
 * <p>
 * The scratch buffers used during evaluation are instance fields, hence one
 * instance must not be evaluated concurrently from several threads.
 * </p>
 * 
 * @author Jens Keilwagen
 */
public class NormConditionalLogLikelihood extends AbstractOptimizableFunction {

	/**
	 * This is used as a kind of very small pseudocount to avoid some problems
	 * while initialization (e.g. {@code log(0)} for classes without weight).
	 */
	private static final double EPS = 1E-6;

	/** The {@link ScoringFunction}s, one for each class. */
	private ScoringFunction[] score;

	/**
	 * Scratch array of length {@code cl}: per-class log values of
	 * (class parameter + log score), normalized in place in the gradient.
	 */
	private double[] helpArray;

	/** Per-class scratch lists receiving the partial derivations of a sequence. */
	private DoubleList[] dList;

	/** Per-class scratch lists receiving the parameter indices belonging to {@link #dList}. */
	private IntList[] iList;

	/**
	 * The prior; never {@code null}, a {@code null} argument is replaced by
	 * {@link DoesNothingLogPrior#defaultInstance}.
	 */
	private LogPrior prior;

	/**
	 * The constructor creates an instance of the
	 * {@link NormConditionalLogLikelihood} without a prior.
	 * 
	 * @param score
	 *            the {@link ScoringFunction}s, one for each class
	 * @param data
	 *            the data
	 * @param weights
	 *            the weights
	 * @param norm
	 *            the switch for using the normalization (division by the number
	 *            of sequences)
	 * @param freeParams
	 *            the switch for using only the free parameters or all
	 *            parameters in a {@link ScoringFunction}
	 * 
	 * @throws IllegalArgumentException
	 *             if the number of classes is smaller than 2 or does not match
	 *             the length of <code>score</code>
	 * @throws WrongAlphabetException
	 *             if different alphabets are used
	 * 
	 * @see NormConditionalLogLikelihood#NormConditionalLogLikelihood(ScoringFunction[],
	 *      Sample[], double[][], LogPrior, boolean, boolean)
	 */
	public NormConditionalLogLikelihood( ScoringFunction[] score, Sample[] data, double[][] weights, boolean norm, boolean freeParams )
																																		throws IllegalArgumentException,
																																		WrongAlphabetException {
		this( score, data, weights, null, norm, freeParams );
	}

	/**
	 * The constructor creates an instance of the
	 * {@link NormConditionalLogLikelihood} using the given prior.
	 * 
	 * @param score
	 *            the {@link ScoringFunction}s, one for each class
	 * @param data
	 *            the data
	 * @param weights
	 *            the weights
	 * @param prior
	 *            the prior, may be <code>null</code> (no prior)
	 * @param norm
	 *            the switch for using the normalization (division by the number
	 *            of sequences)
	 * @param freeParams
	 *            the switch for using only the free parameters or all
	 *            parameters in a {@link ScoringFunction}
	 * 
	 * @throws IllegalArgumentException
	 *             if the number of classes is smaller than 2 or does not match
	 *             the length of <code>score</code>
	 * @throws WrongAlphabetException
	 *             if different alphabets are used
	 * 
	 * @see AbstractOptimizableFunction#AbstractOptimizableFunction(Sample[],
	 *      double[][], boolean, boolean)
	 */
	public NormConditionalLogLikelihood( ScoringFunction[] score, Sample[] data, double[][] weights, LogPrior prior, boolean norm,
											boolean freeParams ) throws IllegalArgumentException, WrongAlphabetException {
		super( data, weights, norm, freeParams );
		if( cl < 2 ) {
			throw new IllegalArgumentException( "The number of classes is not correct. Has to be at least 2." );
		}
		// exactly one scoring function per class is required; a shorter array would
		// otherwise only fail later with an ArrayIndexOutOfBoundsException in setParams
		if( score.length != cl ) {
			throw new IllegalArgumentException( "The number of classes is not correct. Check the length of the score array." );
		}
		this.prior = ( prior == null ) ? DoesNothingLogPrior.defaultInstance : prior;
		helpArray = new double[cl];
		dList = new DoubleList[cl];
		iList = new IntList[cl];
		this.score = score;
		for( int i = 0; i < cl; i++ ) {
			dList[i] = new DoubleList();
			iList[i] = new IntList();
		}
	}

	/**
	 * Evaluates the gradient of the (optionally normalized) log conditional
	 * likelihood plus prior at <code>x</code>.
	 * 
	 * <p>
	 * For a sequence <code>s</code> of class <code>c</code> with weight
	 * <code>w</code> and posterior <code>p_k = P(k|s)</code>, the contribution
	 * to the derivative with respect to the class parameter <code>k</code> is
	 * <code>w * (delta(k,c) - p_k)</code>, and with respect to a parameter of
	 * the scoring function of class <code>k</code> it is
	 * <code>w * (delta(k,c) - p_k) * d(log score_k)/d(param)</code>.
	 * </p>
	 * 
	 * @see de.jstacs.algorithms.optimization.DifferentiableFunction#evaluateGradientOfFunction(double[])
	 */
	@Override
	public double[] evaluateGradientOfFunction( double[] x ) throws DimensionException, EvaluationException {
		setParams( x );
		double[] grad = new double[shortcut[cl]];
		double weight;
		int counter1, counter2, counter3, counter4;
		Sequence s;
		for( counter3 = 0; counter3 < cl; counter3++ ) {
			for( counter2 = 0; counter2 < data[counter3].getNumberOfElements(); counter2++ ) {
				s = data[counter3].getElementAt( counter2 );
				weight = weights[counter3][counter2];

				// log(class weight * class score) and the partial derivations of each class score
				for( counter1 = 0; counter1 < cl; counter1++ ) {
					iList[counter1].clear();
					dList[counter1].clear();
					helpArray[counter1] = logClazz[counter1] + score[counter1].getLogScoreAndPartialDerivation( s,
													0,
													iList[counter1],
													dList[counter1] );
				}

				// in place: turn the log values into the normalized class posteriors p_k
				Normalisation.logSumNormalisation( helpArray, 0, helpArray.length, helpArray, 0 );

				// class parameters: w * (delta(k,c) - p_k)
				for( counter1 = 0; counter1 < shortcut[0]; counter1++ ) {
					grad[counter1] += weight * ( ( counter1 == counter3 ? 1d : 0d ) - helpArray[counter1] );
				}

				// scoring function parameters: w * d(log score_k) * (delta(k,c) - p_k)
				for( counter1 = 0; counter1 < cl; counter1++ ) {
					double delta = ( counter1 == counter3 ? 1d : 0d ) - helpArray[counter1];
					int offset = shortcut[counter1];
					for( counter4 = 0; counter4 < iList[counter1].length(); counter4++ ) {
						grad[offset + iList[counter1].get( counter4 )] += weight * dList[counter1].get( counter4 ) * delta;
					}
				}
			}
		}

		// prior
		prior.addGradientFor( x, grad );

		// normalization
		if( norm ) {
			for( counter1 = 0; counter1 < grad.length; counter1++ ) {
				grad[counter1] /= sum[cl];
			}
		}

		return grad;
	}

	/**
	 * Evaluates the (optionally normalized) log conditional likelihood plus
	 * prior at <code>x</code>.
	 * 
	 * @throws EvaluationException
	 *             if the result is <code>NaN</code>
	 * 
	 * @see de.jstacs.algorithms.optimization.Function#evaluateFunction(double[])
	 */
	public double evaluateFunction( double[] x ) throws DimensionException, EvaluationException {
		setParams( x );

		double cll = 0, pr;
		int counter1, counter2, counter3;

		Sequence s;
		for( counter3 = 0; counter3 < cl; counter3++ ) {
			for( counter2 = 0; counter2 < data[counter3].getNumberOfElements(); counter2++ ) {
				s = data[counter3].getElementAt( counter2 );
				for( counter1 = 0; counter1 < cl; counter1++ ) {
					// class weight + class score
					helpArray[counter1] = logClazz[counter1] + score[counter1].getLogScore( s, 0 );
				}
				// weighted log posterior of the true class counter3
				cll += weights[counter3][counter2] * ( helpArray[counter3] - Normalisation.getLogSum( helpArray ) );
			}
		}

		pr = prior.evaluateFunction( x );

		if( Double.isNaN( cll + pr ) ) {
			System.out.println( "params " + Arrays.toString( x ) );
			System.out.flush();
			throw new EvaluationException( "Evaluating the function gives: " + cll + " + " + pr );
		} else if( norm ) {
			// normalization
			return ( cll + pr ) / sum[cl];
		} else {
			return cll + pr;
		}
	}

	/**
	 * Writes start parameters into <code>erg</code>. The class parameters
	 * (indices <code>0 .. shortcut[0]-1</code>) are chosen according to
	 * <code>kind</code>: plug-in estimates from the class sums
	 * (<code>PLUGIN</code>), the current values (<code>LAST</code>), or zeros
	 * (<code>ZEROS</code>). The parameters of the {@link ScoringFunction}s are
	 * always copied from their current values.
	 * 
	 * @see de.jstacs.classifier.scoringFunctionBased.AbstractOptimizableFunction#getParameters(de.jstacs.classifier.scoringFunctionBased.OptimizableFunction.KindOfParameter, double[])
	 */
	@Override
	public void getParameters( KindOfParameter kind, double[] erg ) throws Exception {
		switch( kind ) {
			case PLUGIN:
				// with free parameters the last class is the reference class
				double discount = Math.log( freeParams ? ( sum[cl - 1] + EPS ) : ( sum[cl] + cl * EPS ) );
				for( int i = 0; i < shortcut[0]; i++ ) {
					erg[i] = Math.log( sum[i] + EPS ) - discount;
				}
				// no fall-through: LAST would overwrite the plug-in estimates
				break;
			case LAST:
				for( int i = 0; i < shortcut[0]; i++ ) {
					erg[i] = logClazz[i];
				}
				break;
			case ZEROS:
				// do not rely on erg being freshly allocated
				Arrays.fill( erg, 0, shortcut[0], 0d );
				break;
			default:
				throw new IllegalArgumentException( "Unknown kind of parameter" );
		}
		for( int i = 0; i < cl; i++ ) {
			System.arraycopy( score[i].getCurrentParameterValues(), 0, erg, shortcut[i], score[i].getNumberOfParameters() );
		}
	}

	/**
	 * Sets the class parameters (via the super class) and the parameters of
	 * each {@link ScoringFunction}, which are read from <code>params</code>
	 * starting at offset <code>shortcut[i]</code>.
	 * 
	 * @see de.jstacs.classifier.scoringFunctionBased.AbstractOptimizableFunction#setParams(double[])
	 */
	@Override
	public void setParams( double[] params ) throws DimensionException {
		super.setParams( params );
		for( int counter1 = 0; counter1 < cl; counter1++ ) {
			score[counter1].setParameters( params, shortcut[counter1] );
		}
	}

	/**
	 * Returns the number of starts determined from the internal
	 * {@link ScoringFunction}s.
	 * 
	 * @see de.jstacs.classifier.scoringFunctionBased.OptimizableFunction#getNumberOfStarts()
	 */
	@Override
	public int getNumberOfStarts() {
		return getNumberOfStarts( score );
	}

	/**
	 * Replaces the {@link ScoringFunction}s by <code>funs</code>, recomputes
	 * the parameter offsets <code>shortcut[1..cl]</code> and re-initializes the
	 * prior with the new functions.
	 * 
	 * @throws IllegalArgumentException
	 *             if <code>funs.length</code> differs from the number of
	 *             classes
	 * 
	 * @see de.jstacs.classifier.scoringFunctionBased.OptimizableFunction#reset(de.jstacs.scoringFunctions.ScoringFunction[])
	 */
	@Override
	public void reset( ScoringFunction[] funs ) throws Exception {
		if( funs.length != cl ) {
			throw new IllegalArgumentException( "Could not reset." );
		}
		for( int i = 0; i < cl; i++ ) {
			score[i] = funs[i];
			shortcut[i + 1] = shortcut[i] + score[i].getNumberOfParameters();
		}
		if( prior != null ) {
			prior.set( freeParams, score );
		}
	}
}
