/*
 * This file is part of Jstacs.
 * 
 * Jstacs is free software: you can redistribute it and/or modify it under the
 * terms of the GNU General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any later
 * version.
 * 
 * Jstacs is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along with
 * Jstacs. If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.classifier.scoringFunctionBased;

import de.jstacs.WrongAlphabetException;
import de.jstacs.algorithms.optimization.DimensionException;
import de.jstacs.data.Sample;
import de.jstacs.scoringFunctions.ScoringFunction;

/**
 * This class extends {@link OptimizableFunction} and implements some common
 * methods.
 * 
 * @author Jens Keilwagen
 */
public abstract class AbstractOptimizableFunction extends OptimizableFunction {

	/**
	 * These shortcuts indicate the beginning of a new part in the parameter
	 * vector. <code>shortcut[0]</code> is the number of class parameters
	 * (<code>cl - 1</code> if only the free parameters are used,
	 * <code>cl</code> otherwise) and <code>shortcut[cl]</code> is the total
	 * dimension of the parameter vector. The entries in between are not set
	 * in this class (presumably they are filled in by subclasses).
	 */
	protected int[] shortcut;

	/**
	 * The data that is used to evaluate this function.
	 */
	protected Sample[] data;

	/**
	 * The weights for the data.
	 * 
	 * @see AbstractOptimizableFunction#data
	 */
	protected double[][] weights;

	/**
	 * The class parameters, i.e. <code>clazz[i] = Math.exp( logClazz[i] )</code>.
	 */
	protected double[] clazz;

	/**
	 * The logarithm of the class parameters.
	 * 
	 * @see AbstractOptimizableFunction#clazz
	 */
	protected double[] logClazz;

	/**
	 * The sums of the weighted data per class (indices <code>0</code> to
	 * <code>cl - 1</code>) and additionally the total weight sum (index
	 * <code>cl</code>).
	 * 
	 * @see AbstractOptimizableFunction#data
	 * @see AbstractOptimizableFunction#weights
	 */
	protected double[] sum;

	/**
	 * The number of different classes.
	 */
	protected int cl;

	/**
	 * Whether a normalization should be done or not.
	 */
	protected boolean norm;

	/**
	 * Whether only the free parameters or all should be used.
	 */
	protected boolean freeParams;

	/**
	 * The constructor creates an instance using the given weighted data.
	 * 
	 * @param data
	 *            the data, one entry per class
	 * @param weights
	 *            the weights, one array per class
	 * @param norm
	 *            the switch for using the normalization (division by the number
	 *            of sequences)
	 * @param freeParams
	 *            the switch for using only the free parameters
	 * 
	 * @throws IllegalArgumentException
	 *             if <code>data</code>, <code>weights</code> or any
	 *             <code>weights[i]</code> is <code>null</code>, or if the
	 *             number of classes is not correct
	 * @throws WrongAlphabetException
	 *             if different alphabets are used (this constructor itself
	 *             performs no alphabet check; the exception is part of the
	 *             declared contract)
	 */
	protected AbstractOptimizableFunction( Sample[] data, double[][] weights, boolean norm, boolean freeParams ) throws IllegalArgumentException, WrongAlphabetException {
		if( data == null || weights == null ) {
			throw new IllegalArgumentException( "The data and the weights must not be null." );
		}
		cl = data.length;
		if( cl < 1 || cl != weights.length ) {
			throw new IllegalArgumentException( "The number of classes is not correct. Check the length of the data and weights array." );
		}
		this.norm = norm;
		this.freeParams = freeParams;
		this.data = data;
		this.weights = weights;

		shortcut = new int[cl + 1];
		// with free parameters the last class has no parameter of its own (its log-parameter is fixed to 0)
		shortcut[0] = freeParams ? cl - 1 : cl;

		logClazz = new double[cl];
		clazz = new double[cl];
		sum = new double[cl + 1];
		for( int i = 0; i < cl; i++ ) {
			if( weights[i] == null ) {
				throw new IllegalArgumentException( "The weights for class " + i + " must not be null." );
			}
			// logClazz[i] starts at 0, so clazz[i] must start at exp(0) = 1; this matters for the
			// last class in free-parameter mode, whose value is never overwritten by setParams
			clazz[i] = 1;
			for( double w : weights[i] ) {
				sum[i] += w;
			}
			sum[cl] += sum[i];
		}
	}

	/* (non-Javadoc)
	 * @see de.jstacs.algorithms.optimization.Function#getDimensionOfScope()
	 */
	public final int getDimensionOfScope() {
		return shortcut[cl];
	}

	/**
	 * This method enables the user to get the parameters without creating a new
	 * array.
	 * 
	 * @param kind
	 *            the kind of the class parameters to be returned in
	 *            <code>erg</code>
	 * @param erg
	 *            the array for the start parameters
	 * 
	 * @throws Exception
	 *             if the array is <code>null</code> or does not have the
	 *             correct length
	 * 
	 * @see OptimizableFunction#getParameters(KindOfParameter)
	 */
	public abstract void getParameters( KindOfParameter kind, double[] erg ) throws Exception;

	/**
	 * Returns the parameters of the requested kind in a newly allocated array
	 * of length {@link #getDimensionOfScope()}.
	 * 
	 * @param kind
	 *            the kind of the parameters to be returned
	 * 
	 * @return a new array filled by {@link #getParameters(KindOfParameter, double[])}
	 * 
	 * @throws Exception
	 *             if {@link #getParameters(KindOfParameter, double[])} fails
	 */
	@Override
	public final double[] getParameters( KindOfParameter kind ) throws Exception {
		double[] temp = new double[getDimensionOfScope()];
		getParameters( kind, temp );
		return temp;
	}

	/**
	 * Checks the dimension and sets the class parameters. The first
	 * <code>shortcut[0]</code> entries of <code>params</code> are taken as the
	 * logarithmic class parameters.
	 * 
	 * @param params
	 *            the complete parameter vector
	 * 
	 * @throws DimensionException
	 *             if <code>params</code> is <code>null</code> or its length
	 *             differs from <code>shortcut[cl]</code>
	 */
	@Override
	public void setParams( double[] params ) throws DimensionException {
		if( params == null ) {
			throw new DimensionException( 0, shortcut[cl] );
		}
		if( params.length != shortcut[cl] ) {
			throw new DimensionException( params.length, shortcut[cl] );
		}
		for( int i = 0; i < shortcut[0]; i++ ) {
			logClazz[i] = params[i];
			clazz[i] = Math.exp( params[i] );
		}
	}

	/**
	 * Extracts the logarithmic class parameters from a parameter vector.
	 * 
	 * @param params
	 *            the parameter vector; its first <code>shortcut[0]</code>
	 *            entries are copied
	 * 
	 * @return a new array of length <code>cl</code>; in free-parameter mode the
	 *         last entry is the fixed value <code>0</code>
	 */
	@Override
	public final double[] getClassParams( double[] params ) {
		double[] res = new double[cl];
		System.arraycopy( params, 0, res, 0, shortcut[0] );
		if( freeParams ) {
			// the last class is the reference class with fixed log-parameter 0
			res[shortcut[0]] = 0;
		}
		return res;
	}

	/**
	 * Returns the number of recommended starts, i.e. the maximum of the
	 * recommended starts of the given scoring functions.
	 * 
	 * @param score
	 *            the underlying scoring functions; must contain at least one
	 *            element
	 * 
	 * @return the number of recommended starts
	 * 
	 * @see OptimizableFunction#getNumberOfStarts()
	 */
	protected final int getNumberOfStarts( ScoringFunction[] score ) {
		int starts = score[0].getNumberOfRecommendedStarts();
		for( int i = 1; i < score.length; i++ ) {
			starts = Math.max( starts, score[i].getNumberOfRecommendedStarts() );
		}
		return starts;
	}

	/**
	 * Adds <code>term</code> to the logarithmic parameter of class
	 * <code>classIndex</code> and updates the corresponding class parameter.
	 * In free-parameter mode the last class has no free parameter; adding to it
	 * is realized by subtracting <code>term</code> from all other class
	 * parameters instead.
	 * 
	 * @param classIndex
	 *            the index of the class
	 * @param term
	 *            the term to be added
	 * 
	 * @throws IndexOutOfBoundsException
	 *             if <code>classIndex</code> is not in <code>[0, cl)</code>
	 */
	@Override
	public final void addTermToClassParameter( int classIndex, double term ) {
		if( classIndex < 0 || classIndex >= cl ) {
			throw new IndexOutOfBoundsException( "check the class index" );
		}
		if( freeParams && classIndex == cl - 1 ) {
			// this parameter is not free so we have to change all other class parameters
			for( int i = 0; i < shortcut[0]; i++ ) {
				logClazz[i] -= term;
				clazz[i] = Math.exp( logClazz[i] );
			}
		} else {
			logClazz[classIndex] += term;
			clazz[classIndex] = Math.exp( logClazz[classIndex] );
		}
	}
}
