/*
 * This file is part of Jstacs.
 * 
 * Jstacs is free software: you can redistribute it and/or modify it under the
 * terms of the GNU General Public License as published by the Free Software
 * Foundation, either version 3 of the License, or (at your option) any later
 * version.
 * 
 * Jstacs is distributed in the hope that it will be useful, but WITHOUT ANY
 * WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
 * A PARTICULAR PURPOSE. See the GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License along with
 * Jstacs. If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.scoringFunctions.mix;

import java.util.Arrays;

import de.jstacs.NonParsableException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.scoringFunctions.NormalizableScoringFunction;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;
import de.jstacs.utils.Normalisation;
import de.jstacs.utils.random.DirichletMRG;
import de.jstacs.utils.random.DirichletMRGParams;

/**
 * This class implements a real mixture model.
 * 
 * @author Jens Keilwagen
 */
public class MixtureScoringFunction extends AbstractMixtureScoringFunction {

	/**
	 * This constructor creates a new {@link MixtureScoringFunction}.
	 * 
	 * <p>
	 * The length of the first component is used as the length of the mixture;
	 * all components have to have this length and an
	 * {@link de.jstacs.data.AlphabetContainer} that is consistent with the one
	 * of the mixture.
	 * 
	 * @param starts
	 *            the number of starts that should be done in an optimization
	 * @param plugIn
	 *            indicates whether the initial parameters for an optimization
	 *            should be related to the data or randomly drawn
	 * @param component
	 *            the {@link de.jstacs.scoringFunctions.ScoringFunction}s; at
	 *            least one component has to be given
	 * 
	 * @throws CloneNotSupportedException
	 *             if an element of <code>component</code> could not be cloned
	 * @throws IllegalArgumentException
	 *             if a component has a different length or an unsuitable
	 *             {@link de.jstacs.data.AlphabetContainer}
	 */
	public MixtureScoringFunction( int starts, boolean plugIn, NormalizableScoringFunction... component ) throws CloneNotSupportedException {
		super( component[0].getLength(), starts, component.length, true, plugIn, component );
		for( int i = 0; i < component.length; i++ ) {
			if( length != component[i].getLength() ) {
				throw new IllegalArgumentException( "The length of component " + i + " is not " + length + "." );
			}
			if( !alphabets.checkConsistency( component[i].getAlphabetContainer() ) ) {
				throw new IllegalArgumentException( "The AlphabetContainer of component " + i + " is not suitable." );
			}
		}
		computeLogGammaSum();
	}

	/**
	 * This is the constructor for the interface {@link de.jstacs.Storable}.
	 * Creates a new {@link MixtureScoringFunction} out of a
	 * {@link StringBuffer}.
	 * 
	 * @param xml
	 *            the XML representation as {@link StringBuffer}
	 * 
	 * @throws NonParsableException
	 *             if the XML representation could not be parsed
	 */
	public MixtureScoringFunction( StringBuffer xml ) throws NonParsableException {
		super( xml );
	}

	/*
	 * (non-Javadoc)
	 * 
	 * Delegates to the normalization constant of component i.
	 * 
	 * @see de.jstacs.scoringFunctions.mix.AbstractMixtureScoringFunction#
	 * getNormalizationConstantForComponent(int)
	 */
	@Override
	protected double getNormalizationConstantForComponent( int i ) {
		return function[i].getNormalizationConstant();
	}

	/*
	 * (non-Javadoc)
	 * 
	 * Returns 0 if the function is normalized. Otherwise the parameter index
	 * is resolved via getIndices into { block, index within block }: the block
	 * equal to function.length addresses the hidden parameters (value taken
	 * from partNorm), every other block addresses a component (value of the
	 * component weighted with its hidden potential).
	 * 
	 * @see de.jstacs.scoringFunctions.NormalizableScoringFunction#
	 * getPartialNormalizationConstant(int)
	 */
	public double getPartialNormalizationConstant( int parameterIndex ) throws Exception {
		if( isNormalized ) {
			return 0;
		}
		// a negative norm is used as marker that the norm has to be (re)computed
		if( norm < 0 ) {
			precomputeNorm();
		}
		int[] ind = getIndices( parameterIndex );
		if( ind[0] == function.length ) {
			// hidden parameter
			return partNorm[ind[1]];
		}
		// parameter of component ind[0]
		return hiddenPotential[ind[0]] * function[ind[0]].getPartialNormalizationConstant( ind[1] );
	}

	/*
	 * (non-Javadoc)
	 * 
	 * The hyperparameter of a hidden parameter is the ESS of the corresponding
	 * component.
	 * 
	 * @see de.jstacs.scoringFunctions.mix.AbstractMixtureScoringFunction#
	 * getHyperparameterForHiddenParameter(int)
	 */
	@Override
	public double getHyperparameterForHiddenParameter( int index ) {
		return function[index].getEss();
	}

	/*
	 * (non-Javadoc)
	 * 
	 * The ESS of the mixture is the sum of the ESS of all components.
	 * 
	 * @see de.jstacs.scoringFunctions.NormalizableScoringFunction#getEss()
	 */
	public double getEss() {
		double ess = 0;
		for( int i = 0; i < function.length; i++ ) {
			ess += function[i].getEss();
		}
		return ess;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * Splits the weight of each sequence of class "index" between the
	 * components using a randomly generated vector, initializes each component
	 * with its share of the weights, and finally passes the accumulated
	 * weights per component to computeHiddenParameter.
	 * 
	 * The entry weights[index] of the caller's array is replaced temporarily
	 * while the components are initialized; it is restored before this method
	 * returns, even if the initialization of a component fails.
	 * 
	 * @see de.jstacs.scoringFunctions.mix.AbstractMixtureScoringFunction#
	 * initializeUsingPlugIn(int, boolean, de.jstacs.data.Sample[], double[][])
	 */
	@Override
	protected void initializeUsingPlugIn( int index, boolean freeParams, Sample[] data, double[][] weights ) throws Exception {
		Arrays.fill( hiddenParameter, 0 );
		double[] myWeights = weights[index];
		double[][] newWeights = new double[function.length][myWeights.length];
		int i, j;

		// hyperparameters for drawing the split: all 1 if no ESS is given,
		// otherwise the hyperparameters of the hidden parameters
		double[] h = new double[this.getNumberOfComponents()];
		if( getEss() == 0 ) {
			Arrays.fill( h, 1 );
		} else {
			for( j = 0; j < h.length; j++ ) {
				h[j] = getHyperparameterForHiddenParameter( j );
			}
		}
		DirichletMRGParams param = new DirichletMRGParams( h );

		// distribute the weight of each sequence over the components
		double[] p = new double[h.length];
		for( i = 0; i < myWeights.length; i++ ) {
			DirichletMRG.DEFAULT_INSTANCE.generate( p, 0, p.length, param );
			for( j = 0; j < p.length; j++ ) {
				newWeights[j][i] = myWeights[i] * p[j];
				hiddenParameter[j] += newWeights[j][i];
			}
		}

		// initialize each component with its own weights; the caller's array
		// must not be left modified if a component throws an exception
		try {
			for( i = 0; i < function.length; i++ ) {
				weights[index] = newWeights[i];
				function[i].initializeFunction( index, freeParams, data, weights );
			}
		} finally {
			weights[index] = myWeights;
		}
		computeHiddenParameter( hiddenParameter );
	}

	/*
	 * (non-Javadoc)
	 * 
	 * Returns "mixture(...)" containing the comma separated instance names of
	 * all components.
	 * 
	 * @see de.jstacs.scoringFunctions.ScoringFunction#getInstanceName()
	 */
	public String getInstanceName() {
		StringBuilder erg = new StringBuilder( "mixture(" );
		erg.append( function[0].getInstanceName() );
		for( int i = 1; i < function.length; i++ ) {
			erg.append( ", " ).append( function[i].getInstanceName() );
		}
		return erg.append( ")" ).toString();
	}

	/*
	 * (non-Javadoc)
	 * 
	 * The score of component i is its log hidden potential plus the log score
	 * of the component for the sequence.
	 * 
	 * @see de.jstacs.scoringFunctions.mix.AbstractMixtureScoringFunction#
	 * fillComponentScores(de.jstacs.data.Sequence, int)
	 */
	@Override
	protected void fillComponentScores( Sequence seq, int start ) {
		for( int i = 0; i < function.length; i++ ) {
			componentScore[i] = logHiddenPotential[i] + function[i].getLogScore( seq, start );
		}
	}

	/*
	 * (non-Javadoc)
	 * 
	 * Collects the partial derivations of all components (shifted by the
	 * offset of the component in paramRef) followed by those of the hidden
	 * parameters.
	 * 
	 * @see
	 * de.jstacs.scoringFunctions.ScoringFunction#getLogScoreAndPartialDerivation
	 * (de.jstacs.data.Sequence, int, de.jstacs.utils.IntList,
	 * de.jstacs.utils.DoubleList)
	 */
	public double getLogScoreAndPartialDerivation( Sequence seq, int start, IntList indices, DoubleList partialDer ) {
		int i, j;
		// the last block of paramRef belongs to the hidden parameters
		int last = paramRef.length - 1;
		int numberOfHiddenParameters = paramRef[last] - paramRef[last - 1];

		// score and partial derivations of each component
		for( i = 0; i < function.length; i++ ) {
			iList[i].clear();
			dList[i].clear();
			componentScore[i] = logHiddenPotential[i] + function[i].getLogScoreAndPartialDerivation( seq, start, iList[i], dList[i] );
		}
		// componentScore is source and destination of the normalisation, i.e.,
		// it is overwritten in place and used as weight of the components below
		double logScore = Normalisation.logSumNormalisation( componentScore, 0, function.length, componentScore, 0 );

		// parameters of the components
		for( i = 0; i < function.length; i++ ) {
			for( j = 0; j < iList[i].length(); j++ ) {
				indices.add( paramRef[i] + iList[i].get( j ) );
				partialDer.add( componentScore[i] * dList[i].get( j ) );
			}
		}
		// hidden parameters, located behind the parameters of all components
		for( j = 0; j < numberOfHiddenParameters; j++ ) {
			indices.add( paramRef[function.length] + j );
			partialDer.add( componentScore[j] - ( isNormalized ? hiddenPotential[j] : 0 ) );
		}
		return logScore;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * Lists for each component its probability followed by the String
	 * representation of the component.
	 * 
	 * @see java.lang.Object#toString()
	 */
	@Override
	public String toString() {
		// a negative norm is used as marker that the norm has to be (re)computed
		if( norm < 0 ) {
			precomputeNorm();
		}
		StringBuilder erg = new StringBuilder( function.length * 1000 );
		for( int i = 0; i < function.length; i++ ) {
			erg.append( "p(" ).append( i ).append( ") = " );
			erg.append( isNormalized ? hiddenPotential[i] : ( partNorm[i] / norm ) );
			erg.append( "\n" );
			erg.append( function[i].toString() );
			erg.append( "\n" );
		}
		return erg.toString();
	}
}
