package de.jstacs.models.discrete.homogeneous;

import java.util.Random;

import de.jstacs.NonParsableException;
import de.jstacs.NotTrainedException;
import de.jstacs.WrongAlphabetException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.Sample.ElementEnumerator;
import de.jstacs.data.sequences.IntSequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.io.XMLParser;
import de.jstacs.models.discrete.DGMParameterSet;
import de.jstacs.models.discrete.homogeneous.parameters.HomMMParameterSet;
import de.jtem.numericalMethods.calculus.specialFunctions.Gamma;

/**
 * This class implements homogeneous Markov models (hMM) of arbitrary order.
 * 
 * <p>
 * The model keeps one conditional distribution per context length <code>0..order</code>. When a
 * (sub)sequence is scored, sampled or counted, its first symbol uses the distribution without context, the
 * following symbols use contexts of growing length, and every symbol that is preceded by at least
 * <code>order</code> symbols uses the full-order distribution.
 * 
 * <p>
 * NOTE(review): the inherited array <code>powers</code> is used throughout as if <code>powers[k]</code>
 * were the alphabet size to the power of <code>k</code>; this is defined in the superclass and should be
 * confirmed there.
 * 
 * @author Jens Keilwagen
 * 
 * @see HomMMParameterSet
 */
public class HomogeneousMM extends HomogeneousModel
{
	/**
	 * The conditional distributions; <code>condProb[i]</code> is used for a symbol that is preceded by
	 * <code>i</code> context symbols. {@link #set(DGMParameterSet, boolean)} creates
	 * <code>order + 1</code> entries.
	 */
	private HomCondProb[] condProb;
	
	/**
	 * Creates a homogeneous Markov model from a parameter set.
	 * 
	 * @param params
	 *            the parameter set
	 * 
	 * @throws CloneNotSupportedException if the parameter set could not be cloned
	 * @throws IllegalArgumentException if the parameter set is not instantiated
	 * @throws NonParsableException if the parameter set is not parsable
	 */
	public HomogeneousMM( HomMMParameterSet params ) throws CloneNotSupportedException, IllegalArgumentException,
			NonParsableException
	{
		super( params );
	}

	/**
	 * Creates a homogeneous Markov model from a StringBuffer.
	 * 
	 * @param stringBuff
	 *            the StringBuffer
	 * 
	 * @throws NonParsableException
	 *             if the buffer is not parsable
	 */
	public HomogeneousMM( StringBuffer stringBuff ) throws NonParsableException
	{
		super( stringBuff );
	}

	/**
	 * Returns a clone of this model. The conditional distributions of the clone are produced by the
	 * inherited <code>cloneHomProb</code>, so the clone does not reuse this instance's array reference.
	 * 
	 * @return the cloned model
	 * 
	 * @throws CloneNotSupportedException
	 *             if the superclass could not be cloned
	 */
	public HomogeneousMM clone() throws CloneNotSupportedException
	{
		HomogeneousMM clone = (HomogeneousMM) super.clone();
		clone.condProb = clone.cloneHomProb( condProb );
		return clone;
	}
	
	/**
	 * Draws a random sequence of the given length from the model.
	 * 
	 * @param r
	 *            the source of randomness; exactly one <code>nextDouble()</code> is consumed per position
	 * @param length
	 *            the length of the sequence to be drawn
	 * 
	 * @return the drawn sequence over the alphabets of this model
	 * 
	 * @throws WrongAlphabetException
	 *             if the sequence could not be created for the alphabets
	 * @throws WrongSequenceTypeException
	 *             if the sequence could not be created as <code>IntSequence</code>
	 */
	protected Sequence getRandomSequence( Random r, int length ) throws WrongAlphabetException, WrongSequenceTypeException
	{
		int[] seq = new int[length];
		// first index of the block belonging to the current context inside the conditional distribution
		int offset = 0;
		int pos = 0;
		// the first positions do not have a full context yet: use the distributions of growing context length
		for( ; pos < order && pos < length; pos++ )
		{
			// presumably chooseFromDistr returns the drawn symbol relative to the block start -- confirm in superclass
			seq[pos] = chooseFromDistr( condProb[pos], offset, offset + powers[1] - 1, r.nextDouble() );
			offset = ( offset + seq[pos] ) * powers[1];
		}
		// full context available: drop the oldest context symbol (modulo) before appending the new one
		for( ; pos < length; pos++ )
		{
			seq[pos] = chooseFromDistr( condProb[order], offset, offset + powers[1] - 1, r.nextDouble() );
			offset = ( ( offset + seq[pos] ) % powers[order] ) * powers[1];
		}
		return new IntSequence( alphabets, seq );
	}

	/**
	 * Returns a short name of the model: the maximal Markov order followed by <code>ML</code> if the
	 * equivalent sample size is 0 and <code>MAP</code> otherwise.
	 * 
	 * @return the instance name
	 */
	public String getInstanceName()
	{
		return "hMM(" + getMaximalMarkovOrder() + ") " + (getESS()==0?"ML":"MAP");
	}

	/**
	 * Computes the logarithmic prior term of the current parameters.
	 * 
	 * <p>
	 * For an equivalent sample size (ess) of 0 the term is 0. Otherwise, for each conditional distribution
	 * with <code>n</code> parameters the hyperparameter <code>ess / n</code> is used for every parameter;
	 * the term consists of the log-Gamma normalization and the sum of hyperparameter times logarithmic
	 * parameter.
	 * 
	 * @return the logarithmic prior term
	 * 
	 * @throws Exception
	 *             a {@link NotTrainedException} if the model is not trained
	 */
	public double getLogPriorTerm() throws Exception
	{
		if( !trained )
		{
			throw new NotTrainedException();
		}
		double ess = getESS();
		if( ess == 0 )
		{
			// no prior knowledge
			return 0;
		}
		double logPrior = 0;
		for( int i = 0; i < condProb.length; i++ )
		{
			int n = condProb[i].getNumberOfSpecificConstraints();
			double hyper = ess / (double) n;
			// n / powers[1] contexts, each contributing logGamma( powers[1] * hyper ) - powers[1] * logGamma( hyper )
			logPrior += n * ( Gamma.logOfGamma( powers[1] * hyper ) / powers[1] - Gamma.logOfGamma( hyper ) );
			for( int c = 0; c < n; c++ )
			{
				logPrior += hyper * condProb[i].getLnFreq( c );
			}
		}
		return logPrior;
	}
	
	/**
	 * Computes the logarithm of the probability of the subsequence from <code>startpos</code> to
	 * <code>endpos</code> (both inclusive).
	 * 
	 * @param sequence
	 *            the sequence
	 * @param startpos
	 *            the first position to be scored
	 * @param endpos
	 *            the last position to be scored
	 * 
	 * @return the logarithmic probability; 0 (i.e. log 1) if the range is empty
	 */
	protected double logProbFor( Sequence sequence, int startpos, int endpos )
	{
		if( endpos < startpos )
		{
			return 0;
		}
		int idx = sequence.discreteVal( startpos++ );
		double erg = condProb[0].getLnFreq( idx );
		// growing context until the full order is reached
		for( int i = 1; i < order && startpos <= endpos; i++ )
		{
			idx = idx * powers[1] + sequence.discreteVal( startpos++ );
			erg += condProb[i].getLnFreq( idx );
		}
		// full context: the modulo drops the oldest context symbol before the new symbol is appended
		while( startpos <= endpos )
		{
			idx = (idx % powers[order]) * powers[1] + sequence.discreteVal( startpos++ );
			erg += condProb[order].getLnFreq( idx );
		}
		return erg;
	}

	/**
	 * Computes the probability of the subsequence from <code>startpos</code> to <code>endpos</code> (both
	 * inclusive).
	 * 
	 * @param sequence
	 *            the sequence
	 * @param startpos
	 *            the first position to be scored
	 * @param endpos
	 *            the last position to be scored
	 * 
	 * @return the probability; 1 (the empty product) if the range is empty, which is consistent with
	 *         {@link #logProbFor(Sequence, int, int)} returning 0 in that case
	 */
	protected double probFor( Sequence sequence, int startpos, int endpos )
	{
		if( endpos < startpos )
		{
			// empty product; returning 0 here would contradict logProbFor, which returns log 1 = 0
			return 1;
		}
		int idx = sequence.discreteVal( startpos++ );
		double erg = condProb[0].getFreq( idx );
		// growing context until the full order is reached
		for( int i = 1; i < order && startpos <= endpos; i++ )
		{
			idx = idx * powers[1] + sequence.discreteVal( startpos++ );
			erg *= condProb[i].getFreq( idx );
		}
		// full context: the modulo drops the oldest context symbol before the new symbol is appended
		while( startpos <= endpos )
		{
			idx = (idx % powers[order]) * powers[1] + sequence.discreteVal( startpos++ );
			erg *= condProb[order].getFreq( idx );
		}
		return erg;
	}


	/**
	 * Returns the description of the model. For a trained model of order 0 the symbol probabilities are
	 * appended on a second line, each preceded by a tab.
	 * 
	 * @return a textual representation of the model
	 */
	public String toString()
	{
		StringBuilder erg = new StringBuilder( "description: " + getDescription() );
		if( trained && getMaximalMarkovOrder() == 0 )
		{
			erg.append( "\n" );
			int n = condProb[0].getNumberOfSpecificConstraints();
			for( int i = 0; i < n; i++ )
			{
				erg.append( "\t" ).append( alphabets.getSymbol( condProb[0].getPosition( 0 ), i ) )
						.append( ": " ).append( condProb[0].getFreq( i ) );
			}
		}
		return erg.toString();
	}
	
	/**
	 * Trains the model on a single sample.
	 * 
	 * @param data
	 *            the sample
	 * @param weights
	 *            the weights of the sequences or <code>null</code> (each sequence then has weight 1)
	 * 
	 * @throws Exception
	 *             if the training fails, see {@link #train(Sample[], double[][])}
	 */
	public void train( Sample data, double[] weights ) throws Exception
	{
		train( new Sample[]{ data }, new double[][]{ weights } );
	}

	/**
	 * Trains the model on several samples: all counters are reset, the sequences of all
	 * non-<code>null</code> samples are counted and the parameters are estimated.
	 * 
	 * @param data
	 *            the samples; <code>null</code> entries are skipped
	 * @param weights
	 *            the weights for each sample; the array itself or single entries may be <code>null</code>,
	 *            in which case each affected sequence has weight 1
	 * 
	 * @throws IllegalArgumentException
	 *             if <code>weights</code> is given but does not have the same length as <code>data</code>,
	 *             or if the weights of a sample do not fit to its number of sequences
	 * @throws Exception
	 *             a {@link WrongAlphabetException} if the alphabets of a sample do not match the model
	 */
	public void train( Sample[] data, double[][] weights ) throws Exception
	{
		// check; a missing weights array is treated like "no weights for any sample" instead of failing with a NullPointerException
		if( weights != null && data.length != weights.length )
		{
			throw new IllegalArgumentException( "The constraint data.length == weights.length is not fulfilled." );
		}

		// reset container of counter
		for( int i = 0; i < condProb.length; i++ )
		{
			condProb[i].reset();
		}

		// count
		for( int i = 0; i < data.length; i++ )
		{
			if( data[i] != null )
			{
				countHomogeneous( data[i], weights == null ? null : weights[i] );
			}
		}
		// estimate
		estimate();
	}

	/**
	 * Returns the XML representation of the conditional distributions: the number of distributions
	 * followed by one <code>pos</code> element per distribution, all enclosed in <code>condProb</code> tags.
	 * 
	 * @return the XML representation or <code>null</code> if no conditional distributions exist
	 */
	protected StringBuffer getFurtherModelInfos()
	{
		if( condProb == null )
		{
			return null;
		}
		int l = condProb.length;
		StringBuffer source = new StringBuffer( 25 + l * 500 );
		XMLParser.appendIntWithTags( source, l, "length" );
		for( int j = 0; j < l; j++ )
		{
			XMLParser.appendStringWithTags( source, condProb[j].toXML().toString(), "pos val=\"" + j + "\"", "pos" );
		}
		XMLParser.addTags( source, "condProb" );
		return source;
	}

	/** The XML tag returned by {@link #getXMLTag()}. */
	private static final String XML_TAG = "HomogeneousMarkovModel";

	/**
	 * Returns the XML tag of this class.
	 * 
	 * @return the XML tag
	 */
	protected String getXMLTag()
	{
		return XML_TAG;
	}

	/**
	 * Sets the parameters via the superclass and, if the model is not trained, creates fresh conditional
	 * distributions for the context lengths <code>0..order</code>. For a trained model the distributions
	 * are not touched here; they are read in {@link #setFurtherModelInfos(StringBuffer)}.
	 * 
	 * @param params
	 *            the parameter set
	 * @param trained
	 *            whether the model is trained
	 * 
	 * @throws CloneNotSupportedException
	 *             if the superclass could not clone the parameter set
	 * @throws NonParsableException
	 *             if the superclass could not parse the parameter set
	 */
	protected void set( DGMParameterSet params, boolean trained ) throws CloneNotSupportedException,
			NonParsableException
	{
		super.set( params, trained );
		if( !trained )
		{
			condProb = new HomCondProb[order+1];
			for( int i = 0; i < condProb.length; i++ )
			{
				// the distribution for context length i covers the positions 0..i
				int[] pos = new int[i + 1];
				for( int j = 0; j <= i; j++ )
				{
					pos[j] = j;
				}
				// powers[i] contexts times powers[1] symbols per context
				int k = powers[i]*powers[1];
				condProb[i] = new HomCondProb( pos, k );
			}
		}
	}

	/**
	 * Restores the conditional distributions from the XML representation written by
	 * {@link #getFurtherModelInfos()}. Nothing is read if the model is not trained.
	 * 
	 * @param xml
	 *            the XML representation
	 * 
	 * @throws NonParsableException
	 *             if the XML representation is not parsable
	 */
	protected void setFurtherModelInfos( StringBuffer xml ) throws NonParsableException
	{
		if( trained )
		{
			StringBuffer help = XMLParser.extractForTag( xml, "condProb" );
			condProb = new HomCondProb[XMLParser.extractIntForTag( help, "length" )];
			for( int j = 0; j < condProb.length; j++ )
			{
				condProb[j] = new HomCondProb( XMLParser.extractForTag( help, "pos val=\"" + j + "\"", "pos" ) );
			}
		}
	}

	/**
	 * Counts homogeneously, i.e. adds the (weighted) occurrences of all symbols of all sequences of the
	 * sample to the conditional distributions.
	 * 
	 * @param data
	 *            the sample
	 * @param weights
	 *            the weight or <code>null</code>
	 * 
	 * @throws IllegalArgumentException
	 *             if the weights do not have the right dimension
	 * @throws WrongAlphabetException
	 *             if the alphabets do not match
	 */
	private void countHomogeneous( Sample data, double[] weights ) throws WrongAlphabetException
	{
		int d = data.getNumberOfElements();

		// check some constraints
		if( weights != null && d != weights.length )
		{
			throw new IllegalArgumentException( "The weights are not suitable for the data (wrong dimension)." );
		}
		if( !alphabets.checkConsistency( data.getAlphabetContainer() ) )
		{
			throw new WrongAlphabetException( "The alphabets of the model and the Sample are not suitable." );
		}

		// fill the constraints with the absolute frequency in the data
		ElementEnumerator ei = new ElementEnumerator( data );
		// without weights each sequence counts once
		double w = 1;
		for( int n = 0; n < d; n++ )
		{
			Sequence seq = ei.nextElement();
			int l = seq.getLength();
			int idx = 0;
			if( weights != null )
			{
				w = weights[n];
			}
			int pos = 0;
			// growing context until the full order is reached
			for( ; pos < order && pos < l; pos++ )
			{
				idx = idx * powers[1] + seq.discreteVal( pos );
				condProb[pos].add( idx, w );
			}
			// full context: the modulo drops the oldest context symbol before the new symbol is appended
			for( ; pos < l; pos++ )
			{
				idx = (idx % powers[order]) * powers[1] + seq.discreteVal( pos );
				condProb[order].add( idx, w );
			}
		}
	}

	/**
	 * Estimates the parameters of all conditional distributions from the current counts using the
	 * equivalent sample size of the model and marks the model as trained.
	 */
	private void estimate()
	{
		double ess = getESS();
		for( int i = 0; i < condProb.length; i++ )
		{
			condProb[i].estimate( ess );
		}
		trained = true;
	}
}
