/*
 * This file is part of Jstacs.
 *
 * Jstacs is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * Jstacs is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with Jstacs.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * For more information on Jstacs, visit http://www.jstacs.de
 */

package de.jstacs.scoringFunctions;

import de.jstacs.NonParsableException;
import de.jstacs.Storable;
import de.jstacs.data.AlphabetContainer;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.io.ArrayHandler;
import de.jstacs.io.XMLParser;
import de.jstacs.utils.DoubleList;
import de.jstacs.utils.IntList;

/**
 * This class enables the user to model parts of a sequence independently of each other. The first part of the sequence
 * is modeled by the first NormalizableScoringFunction and, unless lengths are given explicitly, has the length of the
 * first NormalizableScoringFunction. The second part starts directly after the first part and is modeled by the second
 * NormalizableScoringFunction, and so on.
 * 
 * @author Jens Keilwagen
 */
public class IndependentProductScoringFunction extends AbstractNormalizableScoringFunction
{
	/**
	 * The (cloned) components. Component {@code i} models the part of the sequence starting at offset
	 * {@code start[i]} with length {@code partialLength[i]}.
	 */
	private NormalizableScoringFunction[] score;

	/**
	 * {@code start[i]} is the offset of the part modeled by component {@code i}, and {@code partialLength[i]} is its
	 * length. {@code params[i]} is the index of the first parameter of component {@code i} in the concatenated
	 * parameter vector, and {@code params[score.length]} is the total number of parameters. {@code params} is
	 * {@code null} if at least one component reports {@code UNKNOWN} parameters.
	 */
	private int[] start, partialLength, params;

	/**
	 * {@code isVariable[i]} is {@code true} if component {@code i} is a {@link VariableLengthScoringFunction}. Such a
	 * component is scored with the explicit length {@code partialLength[i]}.
	 */
	private boolean[] isVariable;

	/**
	 * Scratch list that is cleared and reused by every call of
	 * {@link #getLogScoreAndPartialDerivation(Sequence, int, IntList, DoubleList)}. Because it is shared across calls,
	 * a single instance should not compute derivatives concurrently.
	 */
	private IntList partIList;

	/**
	 * Checks that there is at least one component and exactly one length per component.
	 * 
	 * @param numberOfComponents the number of components
	 * @param length the lengths of the parts
	 * 
	 * @throws IllegalArgumentException if there is no component or the number of lengths differs from the number of
	 *             components
	 */
	private final static void checkLengths( int numberOfComponents, int[] length ) throws IllegalArgumentException
	{
		if( numberOfComponents == 0 )
		{
			throw new IllegalArgumentException( "At least one NormalizableScoringFunction is required." );
		}
		if( length.length != numberOfComponents )
		{
			throw new IllegalArgumentException( "The number of lengths (" + length.length
					+ ") does not match the number of NormalizableScoringFunctions (" + numberOfComponents + ")." );
		}
	}

	/**
	 * Builds the {@link AlphabetContainer} of the product. The container of component {@code i} is used for
	 * {@code length[i]} consecutive positions.
	 * 
	 * @param functions the components
	 * @param length the length of the part modeled by each component
	 * 
	 * @return the combined {@link AlphabetContainer}
	 * 
	 * @throws IllegalArgumentException if no component is given or the lengths do not match the components
	 */
	private final static AlphabetContainer getAlphabetContainer( NormalizableScoringFunction[] functions, int[] length )
			throws IllegalArgumentException
	{
		// This method is the first argument of the super(...) call in the constructor. Validating here rejects
		// inconsistent input with a meaningful exception instead of an ArrayIndexOutOfBoundsException.
		checkLengths( functions.length, length );
		AlphabetContainer[] cons = new AlphabetContainer[functions.length];
		for( int i = 0; i < functions.length; i++ )
		{
			cons[i] = functions[i].getAlphabetContainer();
		}
		// hand over a copy so that the container never shares the caller's array
		return new AlphabetContainer( cons, length.clone() );
	}

	/**
	 * Returns the sum of all lengths.
	 * 
	 * @param length the lengths
	 * 
	 * @return the sum of the lengths
	 * 
	 * @throws IllegalArgumentException if at least one length is not positive
	 */
	private final static int sum( int[] length ) throws IllegalArgumentException
	{
		int res = 0, i = 0;
		while( i < length.length && length[i] > 0 )
		{
			res += length[i++];
		}
		if( i != length.length )
		{
			throw new IllegalArgumentException( "The length with index " + i + " is not positive." );
		}
		return res;
	}

	/**
	 * Returns the lengths of the given components as reported by {@link ScoringFunction#getLength()}.
	 * 
	 * @param functions the components
	 * 
	 * @return the lengths of the components
	 * 
	 * @throws IllegalArgumentException if at least one component has a length that is not positive
	 */
	private final static int[] getLengthArray( NormalizableScoringFunction... functions ) throws IllegalArgumentException
	{
		int i = 0;
		int[] res = new int[functions.length];
		while( i < functions.length && functions[i].getLength() > 0 )
		{
			res[i] = functions[i].getLength();
			i++;
		}
		if( i != functions.length )
		{
			throw new IllegalArgumentException( "The NormalizableScoringFunction with index " + i
					+ " has a length 0." );
		}
		return res;
	}

	/**
	 * This constructor creates an instance of a given series of independent NormalizableScoringFunctions.
	 * The length that is modeled by each component is determined by {@link ScoringFunction#getLength()},
	 * so the length must not be 0.
	 * 
	 * @param functions the components
	 * 
	 * @throws CloneNotSupportedException if at least one component could not be cloned
	 * @throws IllegalArgumentException if no component is given, at least one component has length 0, or the
	 *             components do not have the same ess
	 * 
	 * @see IndependentProductScoringFunction#IndependentProductScoringFunction(NormalizableScoringFunction[], int[])
	 */
	public IndependentProductScoringFunction( NormalizableScoringFunction... functions )
			throws CloneNotSupportedException, IllegalArgumentException
	{
		this( functions, getLengthArray( functions ) );
	}

	/**
	 * This constructor creates an instance of a given series of independent NormalizableScoringFunctions and lengths.
	 * 
	 * @param functions the components
	 * @param length the length of each component
	 * 
	 * @throws CloneNotSupportedException if at least one component could not be cloned
	 * @throws IllegalArgumentException if no component is given, the lengths and the components are not matching,
	 *             or the components do not have the same ess
	 * 
	 * @see IndependentProductScoringFunction#IndependentProductScoringFunction(NormalizableScoringFunction[])
	 */
	public IndependentProductScoringFunction( NormalizableScoringFunction[] functions, int[] length )
			throws CloneNotSupportedException, IllegalArgumentException
	{
		super( getAlphabetContainer( functions, length ), sum( length ) );
		score = ArrayHandler.clone( functions );
		setStartsAndLengths( length );
		setParamsStarts();
		// all components have to share the ESS, since getEss() only reports the ESS of the first one
		double ess = score[0].getEss();
		int i = 1;
		while( i < score.length && score[i].getEss() == ess )
		{
			i++;
		}
		if( i != score.length )
		{
			throw new IllegalArgumentException(
					"All NormalizableScoringFunction have to use the same ESS. Violated at index " + i + "." );
		}
	}

	/**
	 * This is the constructor for {@link Storable}.
	 * 
	 * @param source the xml representation
	 * 
	 * @throws NonParsableException if the representation could not be parsed.
	 */
	public IndependentProductScoringFunction( StringBuffer source ) throws NonParsableException
	{
		super( source );
	}

	/**
	 * Computes the start offsets of the parts, copies the part lengths, and determines which components have a
	 * variable length. It also (re)creates the scratch list used for partial derivations.
	 * 
	 * @param length the length of each part
	 * 
	 * @throws IllegalArgumentException if the lengths do not match the components, or a component without variable
	 *             length cannot be used with its length
	 */
	private void setStartsAndLengths( int[] length ) throws IllegalArgumentException
	{
		checkLengths( score.length, length );
		int oldStart = 0;
		start = new int[score.length];
		partialLength = new int[score.length];
		isVariable = new boolean[score.length];
		for( int i = 0; i < score.length; i++ )
		{
			start[i] = oldStart;
			partialLength[i] = length[i];
			isVariable[i] = score[i] instanceof VariableLengthScoringFunction;
			if( !isVariable[i] && score[i].getLength() != partialLength[i] )
			{
				throw new IllegalArgumentException( "Could not use length " + partialLength[i] + " for component " + i + "." );
			}
			oldStart += length[i];
		}
		partIList = new IntList();
	}

	/**
	 * Recomputes the parameter offsets {@code params} from the current number of parameters of each component. Sets
	 * {@code params} to {@code null} if any component reports {@code UNKNOWN}.
	 */
	private void setParamsStarts()
	{
		params = new int[score.length + 1];
		for( int n, i = 0; i < score.length; i++ )
		{
			n = score[i].getNumberOfParameters();
			if( n == UNKNOWN )
			{
				params = null;
				break;
			}
			else
			{
				params[i + 1] = params[i] + n;
			}
		}
	}

	/**
	 * Returns a deep copy that clones all components and has its own bookkeeping arrays and scratch list.
	 * 
	 * @return the clone
	 * 
	 * @throws CloneNotSupportedException if a component could not be cloned
	 */
	public IndependentProductScoringFunction clone() throws CloneNotSupportedException
	{
		IndependentProductScoringFunction clone = (IndependentProductScoringFunction) super.clone();
		clone.score = ArrayHandler.clone( score );
		clone.setStartsAndLengths( partialLength );
		clone.setParamsStarts();
		return clone;
	}

	/**
	 * Returns the index of the component that owns the parameter with the given global index. Requires
	 * {@code params != null}.
	 * 
	 * @param index the global parameter index
	 * 
	 * @return the index of the component
	 */
	private int getSFIndex( int index )
	{
		int i = 1;
		while( i < params.length && index >= params[i] )
		{
			i++;
		}
		return i - 1;
	}

	/**
	 * Delegates to the component that owns the parameter with the given global index.
	 */
	public int getSizeOfEventSpaceForRandomVariablesOfParameter( int index )
	{
		int i = getSFIndex( index );
		return score[i].getSizeOfEventSpaceForRandomVariablesOfParameter( index - params[i] );
	}

	/**
	 * Returns the product of the normalization constants of all components.
	 */
	public double getNormalizationConstant()
	{
		double norm = 1;
		for( int i = 0; i < score.length; i++ )
		{
			norm *= score[i].getNormalizationConstant();
		}
		return norm;
	}

	/**
	 * Returns a product over the components. The component owning the parameter contributes its partial
	 * normalization constant; every other component contributes its normalization constant.
	 */
	public double getPartialNormalizationConstant( int parameterIndex ) throws Exception
	{
		int j = getSFIndex( parameterIndex );
		double partNorm = 1;
		for( int i = 0; i < score.length; i++ )
		{
			if( i == j )
			{
				partNorm *= score[i].getPartialNormalizationConstant( parameterIndex - params[i] );
			}
			else
			{
				partNorm *= score[i].getNormalizationConstant();
			}
		}
		return partNorm;
	}

	/**
	 * Returns the ESS of the first component. The public constructors ensure that all components use the same ESS.
	 */
	public double getEss()
	{
		return score[0].getEss();
	}

	/**
	 * Initializes each component on the infix samples covering its part of the sequences. The arguments
	 * {@code index}, {@code freeParams} and {@code weights} are forwarded unchanged. The parameter offsets are
	 * recomputed afterwards.
	 */
	public void initializeFunction( int index, boolean freeParams, Sample[] data, double[][] weights ) throws Exception
	{
		for( int i = 0; i < score.length; i++ )
		{
			// a fresh array per component, so no component can observe a later overwrite of an array it received
			Sample[] part = new Sample[data.length];
			for( int j = 0; j < data.length; j++ )
			{
				part[j] = data[j].getInfixSample( start[i], partialLength[i] );
			}
			score[i].initializeFunction( index, freeParams, part, weights );
		}
		setParamsStarts();
	}

	/**
	 * Restores this instance from its XML representation as written by {@link #toXML()}.
	 */
	protected void fromXML( StringBuffer rep ) throws NonParsableException
	{
		StringBuffer xml = XMLParser.extractForTag( rep, getInstanceName() );
		alphabets = (AlphabetContainer) XMLParser.extractStorableForTag( xml, "AlphabetContainer" );
		length = XMLParser.extractIntForTag( xml, "length" );
		score = (NormalizableScoringFunction[]) ArrayHandler.cast( XMLParser.extractStorableArrayForTag( xml, "ScoringFunctions" ) );
		setStartsAndLengths( XMLParser.extractIntArrayForTag( xml, "partialLength" ) );
		setParamsStarts();
	}

	/**
	 * Returns the simple class name.
	 */
	public String getInstanceName()
	{
		return getClass().getSimpleName();
	}

	/**
	 * Returns the concatenation of the current parameter vectors of all components, in component order.
	 */
	public double[] getCurrentParameterValues() throws Exception
	{
		// Size the result from the component vectors themselves instead of getNumberOfParameters(). That value may be
		// UNKNOWN (negative array size) or stale, which would cause an out-of-bounds write or zero padding.
		double[][] parts = new double[score.length][];
		int total = 0;
		for( int i = 0; i < score.length; i++ )
		{
			parts[i] = score[i].getCurrentParameterValues();
			total += parts[i].length;
		}
		double[] pars = new double[total];
		for( int offset = 0, i = 0; i < parts.length; i++ )
		{
			System.arraycopy( parts[i], 0, pars, offset, parts[i].length );
			offset += parts[i].length;
		}
		return pars;
	}

	/**
	 * Returns the sum of the log scores of all components, each evaluated on its own part of the sequence.
	 */
	public double getLogScore( Sequence seq, int start )
	{
		double s = 0;
		for( int i = 0; i < score.length; i++ )
		{
			if( isVariable[i] )
			{
				s += ((VariableLengthScoringFunction) score[i]).getLogScore( seq, start + this.start[i], partialLength[i] );
			}
			else
			{
				s += score[i].getLogScore( seq, start + this.start[i] );
			}
		}
		return s;
	}

	/**
	 * Returns the sum of the component log scores. The indices of the partial derivations reported by each component
	 * are shifted by that component's parameter offset before being appended to {@code indices}.
	 */
	public double getLogScoreAndPartialDerivation( Sequence seq, int start, IntList indices, DoubleList partialDer )
	{
		double s = 0;
		for( int j, i = 0; i < score.length; i++ )
		{
			partIList.clear();
			if( isVariable[i] )
			{
				s += ((VariableLengthScoringFunction) score[i]).getLogScoreAndPartialDerivation( seq, start + this.start[i], partialLength[i], partIList, partialDer );
			}
			else
			{
				s += score[i].getLogScoreAndPartialDerivation( seq, start + this.start[i], partIList, partialDer );
			}
			for( j = 0; j < partIList.length(); j++ )
			{
				indices.add( partIList.get( j ) + this.params[i] );
			}
		}
		return s;
	}

	/**
	 * Returns the total number of parameters, or {@code UNKNOWN} if at least one component reports {@code UNKNOWN}.
	 */
	public int getNumberOfParameters()
	{
		if( params == null )
		{
			return UNKNOWN;
		}
		else
		{
			return params[score.length];
		}
	}

	/**
	 * Returns the maximum number of recommended starts over all components.
	 */
	public int getNumberOfRecommendedStarts()
	{
		int max = score[0].getNumberOfRecommendedStarts();
		for( int i = 1; i < score.length; i++ )
		{
			max = Math.max( max, score[i].getNumberOfRecommendedStarts() );
		}
		return max;
	}

	/**
	 * Sets the parameters of each component from its slice of {@code params}, beginning at
	 * {@code start + this.params[i]}.
	 */
	public void setParameters( double[] params, int start )
	{
		for( int i = 0; i < score.length; i++ )
		{
			score[i].setParameters( params, start + this.params[i] );
		}
	}

	/**
	 * Returns the XML representation of this instance.
	 */
	public StringBuffer toXML()
	{
		StringBuffer xml = new StringBuffer( 10000 );
		XMLParser.appendStorableWithTags( xml, alphabets, "AlphabetContainer" );
		XMLParser.appendIntWithTags( xml, length, "length" );
		XMLParser.appendStorableArrayWithTags( xml, score, "ScoringFunctions" );
		XMLParser.appendIntArrayWithTags( xml, partialLength, "partialLength" );
		XMLParser.addTags( xml, getInstanceName() );
		return xml;
	}

	/**
	 * Returns, for each component, a header line describing its sequence part followed by the component's own
	 * {@code toString()}.
	 */
	public String toString()
	{
		// local, unshared builder: no need for StringBuffer's synchronization or a 100000-char preallocation
		StringBuilder sb = new StringBuilder();
		for( int i = 0; i < score.length; i++ )
		{
			sb.append( "sequence part beginning at " + start[i] + " with length " + partialLength[i] + "\n" );
			sb.append( score[i].toString() + "\n" );
		}
		return sb.toString();
	}

	/**
	 * Returns the sum of the log prior terms of all components.
	 */
	public double getLogPriorTerm()
	{
		double val = 0;
		for( int i = 0; i < score.length; i++ )
		{
			val += score[i].getLogPriorTerm();
		}
		return val;
	}

	/**
	 * Lets each component add the gradient of its log prior term at its parameter offset.
	 */
	public void addGradientOfLogPriorTerm( double[] grad, int start ) throws Exception
	{
		for( int i = 0; i < score.length; i++ )
		{
			score[i].addGradientOfLogPriorTerm( grad, start + params[i] );
		}
	}

	/**
	 * Returns {@code true} if and only if all components are initialized.
	 */
	public boolean isInitialized()
	{
		int i = 0;
		while( i < score.length && score[i].isInitialized() )
		{
			i++;
		}
		return i == score.length;
	}

	/**
	 * Initializes all components randomly and recomputes the parameter offsets.
	 */
	public void initializeFunctionRandomly( boolean freeParams ) throws Exception
	{
		for( int i = 0; i < score.length; i++ )
		{
			score[i].initializeFunctionRandomly( freeParams );
		}
		// As in initializeFunction: the components' numbers of parameters may have changed. Stale offsets would
		// corrupt setParameters, the partial derivations and the prior gradient.
		setParamsStarts();
	}
}
