/*
    Copyright (C) 2011 Brendon J. Brewer
    This file is part of DNest, the Diffusive Nested Sampler.

    DNest is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    DNest is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with DNest.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "Level.h"
#include "Model.h"
#include <cassert>
#include <cmath>

using namespace std;
namespace DNest
{

	// Builds a level from a logX value and the two components of its
	// likelihood cutoff. All four counters (accepts, tries, visits, exceeds)
	// start at zero.
	//
	// logX        Value stored in the level's logX member.
	// logl        Value stored in cutoff.logl.
	// tieBreaker  Value stored in cutoff.tieBreaker.
	Level::Level(double logX, double logl, double tieBreaker)
	: logX(logX)
	, accepts(0)
	, tries(0)
	, visits(0)
	, exceeds(0)
	{
		// cutoff is default-constructed above and filled in field by field.
		this->cutoff.logl = logl;
		this->cutoff.tieBreaker = tieBreaker;
	}

	// Builds a level from a logX value and an already-assembled likelihood
	// cutoff, which is copied into the level. All four counters (accepts,
	// tries, visits, exceeds) start at zero.
	//
	// logX    Value stored in the level's logX member.
	// cutoff  Likelihood cutoff copied into the level's cutoff member.
	Level::Level(double logX, const LikelihoodType& cutoff)
	: logX(logX), cutoff(cutoff)
	, accepts(0)
	, tries(0)
	, visits(0)
	, exceeds(0)
	{
		// Nothing further to do: every member is set in the initializer list.
	}

	// Returns a copy of this level's likelihood cutoff.
	LikelihoodType Level::getCutoff() const
	{
		return this->cutoff;
	}

	double Level::getLogX() const
	{
		return logX;
	}

	unsigned long Level::getTries() const
	{
		return tries;
	}

	// Recomputes logX for every level in `levels` from the visit statistics
	// of the level before it.
	//
	// levels[0].logX is set to 0.0. For each i >= 1:
	//   levels[i].logX = levels[i-1].logX
	//                  + log((exceeds + e^-1 * regularisation)
	//                        / (visits + regularisation))
	// where exceeds and visits are the counters of levels[i-1]. With zero
	// visits and exceeds the ratio therefore evaluates to e^-1.
	//
	// levels          Levels to update in place; expected to be non-empty.
	// regularisation  Amount added to the denominator and, scaled by e^-1,
	//                 to the numerator of the ratio above.
	void Level::recalculateLogX(vector<Level>& levels, int regularisation)
	{
		assert(levels.size() > 0);

		// assert() compiles away under NDEBUG; without this guard the write
		// to levels[0] below would be undefined behaviour on an empty vector.
		if(levels.empty())
			return;

		// Loop-invariant part of the numerator, computed once.
		const double numeratorOffset = exp(-1.0)*regularisation;

		levels[0].logX = 0.0;
		for(size_t i=1; i<levels.size(); i++)
		{
			const Level& previous = levels[i-1];
			const double numerator = static_cast<double>(previous.exceeds + numeratorOffset);
			const double denominator = static_cast<double>(previous.visits + regularisation);
			levels[i].logX = previous.logX + log(numerator/denominator);
		}
	}

	// Rescales the counters of every level in `levels`.
	//
	// For every level except the last:
	//   - accepts becomes ((accepts+1)/(tries+1)) * regularisation and tries
	//     becomes regularisation;
	//   - if doExceeds is true, exceeds becomes
	//     ((exceeds+1)/(visits+1)) * regularisation and visits becomes
	//     regularisation; otherwise exceeds and visits are left untouched.
	// The last level has all four counters reset to zero, regardless of
	// doExceeds.
	//
	// levels          Levels to update in place.
	// regularisation  Common value the tries/visits counters are reset to.
	// doExceeds       Whether the exceeds/visits pair is rescaled as well.
	void Level::renormaliseVisits(vector<Level>& levels, int regularisation, bool doExceeds)
	{
		const size_t count = levels.size();
		for(size_t idx = 0; idx < count; ++idx)
		{
			Level& current = levels[idx];

			// The final level is simply wiped.
			if(idx + 1 == count)
			{
				current.accepts = 0;
				current.exceeds = 0;
				current.tries = 0;
				current.visits = 0;
				continue;
			}

			// Ratios are taken from the old counters before any is overwritten.
			const double acceptRatio = (double)(current.accepts+1)/(double)(current.tries+1);
			current.accepts = acceptRatio*regularisation;
			current.tries = regularisation;

			if(doExceeds)
			{
				const double exceedRatio = (double)(current.exceeds+1)/(double)(current.visits+1);
				current.exceeds = exceedRatio*regularisation;
				current.visits = regularisation;
			}
		}
	}

	// Adds one to the visits counter and, when incrementExceeds is true,
	// also adds one to the exceeds counter.
	void Level::incrementVisits(bool incrementExceeds)
	{
		++visits;

		if(!incrementExceeds)
			return;

		++exceeds;
	}

	// Adds one to the tries counter and, when accepted is true, also adds
	// one to the accepts counter.
	void Level::incrementTries(bool accepted)
	{
		++tries;

		if(!accepted)
			return;

		++accepts;
	}

	// Compares a level against a model: returns the result of
	// `level.cutoff < model->logl`. `model` is dereferenced unconditionally,
	// so it must not be null.
	bool operator < (const Level& level, const Model* model)
	{
		const LikelihoodType& threshold = level.cutoff;
		return threshold < model->logl;
	}

	// Writes a level to `out` as seven fields, each followed by a single
	// space and with no trailing newline, in this order:
	//   logX, cutoff.logl, cutoff.tieBreaker, accepts, tries, exceeds, visits
	// Returns `out` so calls can be chained.
	ostream& operator << (ostream& out, const Level& level)
	{
		const LikelihoodType& threshold = level.cutoff;

		// Position and cutoff of the level.
		out << level.logX << ' ';
		out << threshold.logl << ' ';
		out << threshold.tieBreaker << ' ';

		// Counters.
		out << level.accepts << ' ';
		out << level.tries << ' ';
		out << level.exceeds << ' ';
		out << level.visits << ' ';

		return out;
	}
}

