package de.jstacs.models.discrete.homogeneous;

import java.util.Random;

import de.jstacs.NonParsableException;
import de.jstacs.NotTrainedException;
import de.jstacs.WrongAlphabetException;
import de.jstacs.data.EmptySampleException;
import de.jstacs.data.Sample;
import de.jstacs.data.Sequence;
import de.jstacs.data.sequences.WrongSequenceTypeException;
import de.jstacs.models.discrete.Constraint;
import de.jstacs.models.discrete.DGMParameterSet;
import de.jstacs.models.discrete.DiscreteGraphicalModel;
import de.jstacs.models.discrete.homogeneous.parameters.HomogeneousModelParameterSet;
import de.jstacs.results.NumericalResultSet;

/**
 * This class implements homogeneous models.
 * 
 * <br><br>
 * 
 * It stores the Markov order and the powers of the alphabet length, offers the generation of
 * random samples from a trained model, and provides the inner class {@link HomCondProb} that
 * holds (conditional) probabilities together with their logarithms.
 * 
 * @author Jens Keilwagen
 * 
 * @see HomogeneousModelParameterSet
 */
public abstract class HomogeneousModel extends DiscreteGraphicalModel
{
	/**
	 * The powers of the alphabet length, i.e. <code>powers[i]</code> is the alphabet length to the power of
	 * <code>i</code>. The array is filled in {@link #set(DGMParameterSet, boolean)}.
	 */
	protected int[] powers;

	/**
	 * The Markov order of the model. The value is read from the parameter set in
	 * {@link #set(DGMParameterSet, boolean)}.
	 */
	protected byte order;

	/**
	 * Creates a homogeneous model from a parameter set.
	 * 
	 * @param params
	 *            the parameter set
	 * 
	 * @throws CloneNotSupportedException if the parameter set could not be cloned
	 * @throws IllegalArgumentException if the parameter set is not instantiated
	 * @throws NonParsableException if the parameter set is not parsable
	 */
	public HomogeneousModel( HomogeneousModelParameterSet params ) throws CloneNotSupportedException,
			IllegalArgumentException, NonParsableException
	{
		super( params );
	}

	/**
	 * Creates a homogeneous model from a StringBuffer.
	 * 
	 * @param stringBuff
	 *            the StringBuffer
	 * 
	 * @throws NonParsableException
	 *             if the buffer is not parsable
	 */
	public HomogeneousModel( StringBuffer stringBuff ) throws NonParsableException
	{
		super( stringBuff );
	}

	/**
	 * Creates a sample of <code>no</code> sequences.
	 * 
	 * <br><br>
	 * 
	 * All sequences of one call are drawn using the same random generator.
	 * 
	 * @param no
	 *            the number of sequences in the sample
	 * @param length
	 *            the length of all sequences or an array of lengths, then the sequence with index <code>i</code> has
	 *            length <code>length[i]</code>
	 * 
	 * @return the sample
	 * 
	 * @throws NotTrainedException
	 *             if the model was not trained
	 * @throws IllegalArgumentException
	 *             if the dimension of <code>length</code> is neither 1 nor <code>no</code>
	 * @throws EmptySampleException
	 *             if <code>no == 0</code>
	 * @throws WrongSequenceTypeException
	 *             forwarded, e.g. from {@link #getRandomSequence(Random, int)}
	 * @throws WrongAlphabetException
	 *             forwarded, e.g. from {@link #getRandomSequence(Random, int)}
	 */
	public final Sample emitSample( int no, int... length ) throws NotTrainedException, IllegalArgumentException,
			EmptySampleException, WrongAlphabetException, WrongSequenceTypeException
	{
		if( !trained )
		{
			throw new NotTrainedException();
		}
		Sequence[] seq = new Sequence[no];
		if( length.length != 1 && length.length != no )
		{
			throw new IllegalArgumentException( "The dimension of the array length is not correct." );
		}
		// one generator for the whole sample instead of a fresh one for each sequence
		Random r = new Random();
		// a single given length is used for all sequences, even if no == 1
		boolean sameLength = length.length == 1;
		for( int i = 0; i < no; i++ )
		{
			seq[i] = getRandomSequence( r, sameLength ? length[0] : length[i] );
		}
		return new Sample( "sampled from " + getInstanceName(), seq );
	}

	/**
	 * This method creates a sequence from a trained model.
	 * 
	 * @param r
	 *            the random generator
	 * @param length
	 *            the length of the sequence
	 * 
	 * @return the sequence
	 * 
	 * @throws WrongSequenceTypeException
	 *             if the sequence could not be created by the implementation
	 * @throws WrongAlphabetException
	 *             if the sequence could not be created by the implementation
	 */
	protected abstract Sequence getRandomSequence( Random r, int length ) throws WrongAlphabetException, WrongSequenceTypeException;

	/**
	 * Returns the Markov order of the model.
	 * 
	 * @return the value of {@link #order}
	 */
	public byte getMaximalMarkovOrder()
	{
		return order;
	}

	/**
	 * Returns the numerical characteristics of the model. This implementation does not provide any
	 * characteristics.
	 * 
	 * @return always <code>null</code>
	 * 
	 * @throws Exception
	 *             never thrown by this implementation
	 */
	public NumericalResultSet getNumericalCharacteristics() throws Exception
	{
		return null;
	}

	/**
	 * Returns the logarithm of the probability of the given sequence in the given interval. The arguments are
	 * validated by {@link #check(Sequence, int, int)} before {@link #logProbFor(Sequence, int, int)} is invoked.
	 * 
	 * @param sequence
	 *            the sequence
	 * @param startpos
	 *            the start position
	 * @param endpos
	 *            the end position, has to be smaller than the length of the sequence
	 * 
	 * @return the logarithm of the probability for the given subsequence
	 * 
	 * @throws NotTrainedException
	 *             if the model is not trained
	 * @throws Exception
	 *             if some arguments are wrong
	 */
	public final double getLogProbFor( Sequence sequence, int startpos, int endpos ) throws NotTrainedException,
			Exception
	{
		check( sequence, startpos, endpos );
		return logProbFor( sequence, startpos, endpos );
	}

	/**
	 * Returns the probability of the given sequence in the given interval. The arguments are validated by
	 * {@link #check(Sequence, int, int)} before {@link #probFor(Sequence, int, int)} is invoked.
	 * 
	 * @param sequence
	 *            the sequence
	 * @param startpos
	 *            the start position
	 * @param endpos
	 *            the end position, has to be smaller than the length of the sequence
	 * 
	 * @return the probability for the given subsequence
	 * 
	 * @throws NotTrainedException
	 *             if the model is not trained
	 * @throws Exception
	 *             if some arguments are wrong
	 */
	public final double getProbFor( Sequence sequence, int startpos, int endpos ) throws NotTrainedException, Exception
	{
		check( sequence, startpos, endpos );
		return probFor( sequence, startpos, endpos );
	}

	/**
	 * Trains the model on all given samples. This method delegates to {@link #train(Sample[], double[][])} and
	 * passes <code>null</code> as weights for each sample.
	 * 
	 * @param data
	 *            the data
	 * 
	 * @throws Exception
	 *             if something went wrong
	 */
	public void train( Sample[] data ) throws Exception
	{
		train( data, new double[data.length][] );
	}

	/**
	 * Trains the model using an array of weighted samples. The <code>weights[i]</code> are for <code>data[i]</code>.
	 * 
	 * @param data
	 *            the samples
	 * @param weights
	 *            the weights
	 * 
	 * @throws Exception
	 *             if something went wrong, furthermore <code>data.length</code> has to be <code>weights.length</code>
	 */
	public abstract void train( Sample[] data, double[][] weights ) throws Exception;

	/**
	 * Sets the parameters of the model. In addition to the work of the super class, the Markov order is read from
	 * the parameter at index 2 and the powers of the alphabet length are computed.
	 * 
	 * @param params
	 *            the parameter set
	 * @param trained
	 *            forwarded to the super class
	 * 
	 * @throws CloneNotSupportedException
	 *             forwarded from the super class
	 * @throws NonParsableException
	 *             forwarded from the super class
	 */
	protected void set( DGMParameterSet params, boolean trained ) throws CloneNotSupportedException,
			NonParsableException
	{
		super.set( params, trained );
		order = (Byte) params.getParameterAt( 2 ).getValue();
		// at least two entries, since powers[1] (the alphabet length) is needed for any order
		powers = new int[Math.max( order + 1, 2 )];
		powers[0] = 1;
		powers[1] = (int) alphabets.getAlphabetLengthAt( 0 );
		for( int i = 2; i < powers.length; i++ )
		{
			powers[i] = powers[1] * powers[i - 1];
		}
	}

	/**
	 * Checks some constraints. In addition to the checks of the super class, the end position has to be smaller
	 * than the length of the sequence.
	 * 
	 * @param sequence
	 *            the sequence
	 * @param startpos
	 *            the start position
	 * @param endpos
	 *            the end position
	 * 
	 * @throws NotTrainedException
	 *             if the model is not trained
	 * @throws IllegalArgumentException
	 *             if some arguments are wrong
	 */
	protected void check( Sequence sequence, int startpos, int endpos ) throws NotTrainedException,
			IllegalArgumentException
	{
		super.check( sequence, startpos, endpos );
		if( endpos >= sequence.getLength() )
		{
			throw new IllegalArgumentException( "This endposition is impossible. Try: endposistion < sequence.length" );
		}
	}

	/**
	 * Chooses a value in [0,<code>end-start</code>] according to the distribution encoded in the frequencies of <code>distr</code>
	 * between the indices <code>start</code> and <code>end</code> (both inclusive).
	 * 
	 * <br><br>
	 * 
	 * The frequencies are subtracted from <code>randNo</code> one after another until the remainder does not exceed
	 * the current frequency. If this does not happen before the index <code>end</code> is reached (e.g. due to
	 * rounding errors), the last value <code>end-start</code> is returned. Hence no frequency outside the given
	 * range is used as a probability and the result never leaves the stated interval.
	 * 
	 * <br><br>
	 * 
	 * The instance <code>distr</code> is not changed in the process.
	 * 
	 * @param distr the distribution 
	 * @param start the start index
	 * @param end the end index
	 * @param randNo a random number in [0,1]
	 * 
	 * @return the chosen value
	 * 
	 * @see Constraint#getFreq(int)
	 */
	protected final int chooseFromDistr( Constraint distr, int start, int end, double randNo )
	{
		int c = start;
		// the bound is tested first: the last index is the fall-back and its frequency is never needed
		while( c < end && randNo > distr.getFreq( c ) )
		{
			randNo -= distr.getFreq( c++ );
		}
		return c - start;
	}

	/**
	 * This method computes the logarithm of the probability of the given sequence in the given interval.
	 * The method is only used in {@link de.jstacs.models.Model#getLogProbFor(Sequence, int, int)} after the method
	 * {@link HomogeneousModel#check(Sequence, int, int)} has been invoked.
	 * 
	 * @param sequence the sequence
	 * @param startpos the start position
	 * @param endpos the end position
	 * 
	 * @return the logarithm of the probability for the given subsequence
	 * 
	 * @see HomogeneousModel#check(Sequence, int, int)
	 * @see de.jstacs.models.Model#getLogProbFor(Sequence, int, int)
	 */
	protected abstract double logProbFor( Sequence sequence, int startpos, int endpos );

	/**
	 * This method computes the probability of the given sequence in the given interval.
	 * The method is only used in {@link de.jstacs.models.Model#getProbFor(Sequence, int, int)} after the method
	 * {@link HomogeneousModel#check(Sequence, int, int)} has been invoked.
	 * 
	 * @param sequence the sequence
	 * @param startpos the start position
	 * @param endpos the end position
	 * 
	 * @return the probability for the given subsequence
	 * 
	 * @see HomogeneousModel#check(Sequence, int, int)
	 * @see de.jstacs.models.Model#getProbFor(Sequence, int, int)
	 */
	protected abstract double probFor( Sequence sequence, int startpos, int endpos );

	/**
	 * Clones the given conditional probabilities. Each element is copied using the copy constructor
	 * {@link HomCondProb#HomCondProb(HomCondProb)}, so the clones are inner instances of this model.
	 * 
	 * @param p
	 *            the original conditional probabilities
	 *            
	 * @return an array of clones
	 */
	protected HomCondProb[] cloneHomProb( HomCondProb[] p )
	{
		HomCondProb[] condProb = new HomCondProb[p.length];
		for( int i = 0; i < condProb.length; i++ )
		{
			condProb[i] = new HomCondProb( p[i] );
		}
		return condProb;
	}

	/**
	 * This class handles the (conditional) homogeneous probabilities in a fast way. In addition to the
	 * frequencies of the super class it keeps their logarithms. It is a non-static inner class, since it uses
	 * the {@link HomogeneousModel#powers} of the enclosing model.
	 * 
	 * @author Jens Keilwagen
	 */
	protected class HomCondProb extends Constraint
	{
		/**
		 * The logarithms of the frequencies; <code>null</code> until they are computed for the first time.
		 */
		private double[] lnFreq;

		/**
		 * The main constructor. Checks that each position is used maximally once.
		 * 
		 * @param pos
		 *            the used positions (will be cloned), has to be non-negative
		 * @param n
		 *            the number of specific constraints
		 */
		public HomCondProb( int[] pos, int n )
		{
			super( pos, n );
		}

		/**
		 * Creates an instance from a StringBuffer (see {@link de.jstacs.Storable}).
		 * 
		 * @param xml
		 *            the StringBuffer
		 * 
		 * @throws NonParsableException
		 *             if the buffer is not parsable
		 */
		public HomCondProb( StringBuffer xml ) throws NonParsableException
		{
			super( xml );
		}

		/**
		 * This constructor is used for cloning instances, since any instance is an inner instance of a {@link HomogeneousModel}.
		 * The frequencies are copied and the logarithms are recomputed if the old instance had some.
		 * 
		 * @param old the old instance to be cloned
		 */
		public HomCondProb( HomCondProb old )
		{
			this( old.usedPositions, old.freq.length );
			System.arraycopy( old.freq, 0, freq, 0, freq.length );
			if( old.lnFreq != null )
			{
				lnFreq( 0, freq.length );
			}
		}

		/**
		 * Estimates the frequencies and their logarithms. The equivalent sample size is distributed uniformly
		 * over all specific constraints and used as pseudocount.
		 * 
		 * @param ess
		 *            the equivalent sample size
		 */
		public void estimate( double ess )
		{
			double pc = ess / (double) getNumberOfSpecificConstraints();
			if( usedPositions.length == 1 )
			{
				// unconditional: one distribution over all entries
				estimateUnConditional( 0, freq.length, pc, false );
			}
			else
			{
				// conditional: one distribution for each block of alphabet length many entries
				for( int counter1 = 0; counter1 < freq.length; counter1 += powers[1] )
				{
					estimateUnConditional( counter1, counter1 + powers[1], pc, false );
				}
			}
		}

		/**
		 * Returns the logarithmic frequency. The logarithms are only available after the frequencies have been
		 * estimated, parsed or copied from an instance that had logarithms.
		 * 
		 * @param index
		 *            the index
		 * 
		 * @return the logarithmic frequency
		 */
		public double getLnFreq( int index )
		{
			return lnFreq[index];
		}

		/**
		 * Returns the index of the specific constraint that is fulfilled by the sequence. The discrete values at
		 * the used positions (relative to <code>start</code>) are combined like the digits of a number whose base
		 * is the alphabet length; the first used position is the most significant one.
		 * 
		 * @param seq
		 *            the sequence
		 * @param start
		 *            the offset that is added to each used position
		 * 
		 * @return the index of the fulfilled specific constraint
		 */
		public int satisfiesSpecificConstraint( Sequence seq, int start )
		{
			int erg = 0, counter = 0, p = usedPositions.length - 1;
			for( ; counter < usedPositions.length; counter++, p-- )
			{
				erg += powers[p] * seq.discreteVal( start + usedPositions[counter] );
			}
			return erg;
		}

		/**
		 * Returns a representation of the used positions: the last position on its own or, if there is more than
		 * one position, the preceding positions followed by an arrow and the last position, e.g.
		 * <code>0, 1 -> 2</code>.
		 * 
		 * @return the String representation
		 */
		public String toString()
		{
			int last = usedPositions.length - 1;
			StringBuilder erg = new StringBuilder();
			if( last > 0 )
			{
				erg.append( usedPositions[0] );
				for( int i = 1; i < last; i++ )
				{
					erg.append( ", " ).append( usedPositions[i] );
				}
				erg.append( " -> " );
			}
			return erg.append( usedPositions[last] ).toString();
		}

		/**
		 * Appends nothing, since the logarithms are recomputed in {@link #extractAdditionalInfo(StringBuffer)}.
		 * 
		 * @param xml
		 *            the StringBuffer, is not changed
		 */
		protected void appendAdditionalInfo( StringBuffer xml )
		{
		}

		private static final String XML_TAG = "HomCondProb";

		/**
		 * Returns the XML tag of this class.
		 * 
		 * @return the XML tag
		 */
		protected String getXMLTag()
		{
			return XML_TAG;
		}

		/**
		 * Estimates the frequencies between <code>start</code> (inclusive) and <code>end</code> (exclusive) using
		 * the super class and updates the logarithms for the same range.
		 * 
		 * @param start
		 *            the start index (inclusive)
		 * @param end
		 *            the end index (exclusive)
		 * @param pc
		 *            the pseudocount
		 * @param exceptionWhenNoData
		 *            forwarded to the super class
		 */
		protected void estimateUnConditional( int start, int end, double pc, boolean exceptionWhenNoData )
		{
			super.estimateUnConditional( start, end, pc, exceptionWhenNoData );
			lnFreq( start, end );
		}

		/**
		 * Computes the logarithm for all frequencies between <code>start</code> (inclusive) and <code>end</code>
		 * (exclusive). The array for the logarithms is created on first use.
		 * 
		 * @param start
		 *            the start index (inclusive)
		 * @param end
		 *            the end index (exclusive)
		 */
		private void lnFreq( int start, int end )
		{
			if( lnFreq == null )
			{
				lnFreq = new double[freq.length];
			}
			for( int i = start; i < end; i++ )
			{
				lnFreq[i] = Math.log( freq[i] );
			}
		}

		/**
		 * Reads nothing from the StringBuffer, but recomputes the logarithms of all frequencies.
		 * 
		 * @param xml
		 *            the StringBuffer, is not read
		 * 
		 * @throws NonParsableException
		 *             never thrown by this implementation
		 */
		protected void extractAdditionalInfo( StringBuffer xml ) throws NonParsableException
		{
			lnFreq( 0, freq.length );
		}
	}
}
