/*
    Copyright (C) 2011 Brendon J. Brewer
    This file is part of DNest, the Diffusive Nested Sampler.

    DNest is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    DNest is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with DNest.  If not, see <http://www.gnu.org/licenses/>.
*/

#include "DNestSampler.h"
#include "RandomNumberGenerator.h"
#include <cmath>
#include <iostream>
#include <fstream>
#include <algorithm>
#include "Utils.h"

using namespace std;
namespace DNest
{

	// Output files (re)created in the working directory by the sampler, and
	// the name of the options file passed to the Options member's constructor.
	const string DNestSampler::sampleFile("sample.txt");
	const string DNestSampler::sampleInfoFile("sample_info.txt");
	const string DNestSampler::levelsFile("levels.txt");
	const string DNestSampler::optionsFile("OPTIONS");

	DNestSampler::DNestSampler(const Model* exampleModel)
	:options(optionsFile.c_str())
	,theModels(options.numParticles)
	,theIndices(options.numParticles, 0)
	,theLevels(1, Level(0.0, -1E300, 0.0))
	,initialised(false)
	{
		// Build the particle set: every slot gets its own object produced by
		// exampleModel->factory(); these are freed in the destructor.
		// (theModels was sized to options.numParticles above.)
		for(size_t slot = 0; slot < theModels.size(); ++slot)
			theModels[slot] = exampleModel->factory();
		step = 0;

		// Truncate (or create) the three output files so that this run starts
		// from empty files. The streams close when they go out of scope.
		ofstream emptySamples(sampleFile.c_str());
		ofstream emptySampleInfo(sampleInfoFile.c_str());
		ofstream emptyLevels(levelsFile.c_str());
	}

	DNestSampler::~DNestSampler()
	{
		// Free every particle created by factory() in the constructor.
		// deleteModel() erases entries from theModels in lock-step with
		// decrementing options.numParticles, so walking the container itself
		// visits exactly the live particles, in their stored order.
		for(size_t k = 0; k < theModels.size(); ++k)
			delete theModels[k];
	}

	void DNestSampler::initialise()
	{
		for(unsigned int i=0; i<options.numParticles; i++)
		{
			theModels[i]->fromPrior();
			loglKeep.push_back(theModels[i]->getLogLikelihood());
		}
		initialised = true;
	}

	void DNestSampler::run()
	{
		// Initialise lazily, then iterate without end. This function never
		// returns normally: the process stops from saveSample(), which calls
		// exit(0) once options.maxNumSamples samples are saved (if nonzero).
		if(initialised == false)
			initialise();

		for(;;)
			update();
	}

	void DNestSampler::update()
	{
		// One iteration: pick a particle at random, move it in parameter space
		// (updateModel) and in level space (updateIndex), the order of the two
		// moves chosen by a coin flip on randomU().
		int which = randInt(options.numParticles);

		if(randomU() <= 0.5)
		{
			updateModel(which);
			updateIndex(which);
		}
		else
		{
			updateIndex(which);
			updateModel(which);
		}

		// Accumulate visits/exceeds: starting at the particle's current level,
		// record at each level whether the particle also exceeds the next level
		// up, climbing until it does not (or the top level is reached).
		int theIndex = theIndices[which];
		while(theIndex < (int)theLevels.size() - 1)
		{
			bool exceeds = theLevels[theIndex+1] < theModels[which];
			theLevels[theIndex].incrementVisits(exceeds);
			if(!exceeds)
				break;
			theIndex++;
		}

		// While not all levels exist, keep the log-likelihoods of particles above
		// the current top level; once enough are stored, create a new level.
		bool levelAdded = false;
		if(theLevels.size() < options.maxNumLevels)
		{
			if(theLevels.back() < theModels[which])
				loglKeep.push_back(theModels[which]->getLogLikelihood());
			if(loglKeep.size() >= options.newLevelInterval)
			{
				addLevel();
				levelAdded = true;
			}
		}

		if(step%options.saveInterval == 0 && step > 0)
		{
			Level::recalculateLogX(theLevels, options.regularisation);
			saveLevels();
			saveSample(which);
		}

		// Particle deletion is deferred until after saving. deleteModel() erases
		// entries from theModels/theIndices, which would leave `which` either out
		// of range (undefined behaviour in saveSample) or pointing at a different
		// particle than the one just updated.
		if(levelAdded && options.deleteParticles)
			deleteModel();

		step++;
	}

	void DNestSampler::deleteModel()
	{
		// Remove every particle whose level index lies more than
		// (5*backTrackLength + 1) levels below the top level, but never remove
		// the last remaining particle. Survivors keep their relative order, and
		// theModels/theIndices/options.numParticles are kept in step.
		size_t pos = 0;
		while(pos < theModels.size())
		{
			const bool tooFarBehind = (int)theLevels.size() - 1 - theIndices[pos] > (5*options.backTrackLength+1);
			if(!(tooFarBehind && options.numParticles > 1))
			{
				++pos;
				continue;
			}

			delete theModels[pos];
			theModels.erase(theModels.begin() + pos);
			theIndices.erase(theIndices.begin() + pos);
			options.numParticles--;
			cout<<"# Deleted a particle. "<<options.numParticles<<" remaining."<<endl;
			// pos is not advanced: the following particle has shifted into it.
		}
	}

	void DNestSampler::updateModel(int which)
	{
		// Let the particle attempt a move subject to the likelihood cutoff of the
		// level it currently occupies, then record accept/reject in that level's
		// try statistics.
		const int level = theIndices[which];
		const bool accepted = theModels[which]->update(theLevels[level].getCutoff());
		theLevels[level].incrementTries(accepted);
	}

	void DNestSampler::updateIndex(int which)
	{
		int proposedIndex = theIndices[which] + (int)round(pow(10.0, 2.0*randomU())*randn());
		if(proposedIndex >= 0 && proposedIndex < (int)theLevels.size())
		{
			double logp = theLevels[theIndices[which]].getLogX() - theLevels[proposedIndex].getLogX() + logPush(proposedIndex) - logPush(theIndices[which]);

			// Enforce uniform exploration, if all levels created
			if(theLevels.size() == options.maxNumLevels)
				logp += options.enforceUniformityStrength*log((double)(theLevels[theIndices[which]].getTries() + 1)/(double)(theLevels[proposedIndex].getTries() + 1));

			if(logp > 0)
				logp = 0;
			if(randomU() <= exp(logp) && theLevels[proposedIndex] < theModels[which])
				theIndices[which] = proposedIndex;
		}

	}

	void DNestSampler::addLevel()
	{
		// No more levels once the configured maximum is reached.
		if(theLevels.size() >= options.maxNumLevels)
			return;

		// Nothing to build a level from (possible e.g. if newLevelInterval is 0);
		// bail out instead of reading loglKeep[0] from an empty vector.
		if(loglKeep.empty())
			return;

		// The new level's threshold is the stored log-likelihood at the 0.63212
		// quantile (0.63212 ~= 1 - exp(-1)). About exp(-1) of the stored values
		// lie above it, consistent with the new level's logX being one less than
		// the current top level's.
		sort(loglKeep.begin(), loglKeep.end());
		int ii = (int)floor(0.63212*loglKeep.size());
		theLevels.push_back(Level(theLevels.back().getLogX() - 1.0, loglKeep[ii]));
		cout<<"# Creating level "<<((int)theLevels.size()-1)<<" with logl = "<<loglKeep[ii].logl<<"."<<endl;

		// Keep only the values above the new threshold for the next level.
		loglKeep.erase(loglKeep.begin(), loglKeep.begin() + ii + 1);

		// Final level created: stop collecting and renormalise level statistics.
		if(theLevels.size() == options.maxNumLevels)
		{
			loglKeep.clear();
			Level::renormaliseVisits(theLevels, options.regularisation, options.deleteParticles);
		}

		Level::recalculateLogX(theLevels, options.regularisation);
		saveLevels();
	}

	double DNestSampler::logPush(int index) const
	{
		if(theLevels.size() == options.maxNumLevels)
			return 0;

		int distance = (int)theLevels.size() - 1 - index;
		double result = -distance/options.backTrackLength;

		if(!options.deleteParticles)
			if(result <= -5.0)
				result = -5.0;

		return result;
	}

	void DNestSampler::saveLevels() const
	{
		fstream fout(levelsFile.c_str(), ios::out);
		for(size_t i=0; i<theLevels.size(); i++)
			fout<<theLevels[i]<<endl;
		fout.close();
	}

	void DNestSampler::saveSample(int which) const
	{
		static unsigned int count = 0;
		count++;
		fstream fout(sampleFile.c_str(), ios::out | ios::app);
		cout<<"# Saving a sample. N = "<<count<<"."<<endl;
		theModels[which]->print(fout);
		fout<<endl;
		fout.close();

		fstream fout2(sampleInfoFile.c_str(), ios::out | ios::app);
		fout2<<theIndices[which]<<' '<<theModels[which]->getLogLikelihood().logl<<' '<<theModels[which]->getLogLikelihood().tieBreaker<<endl;
		fout2.close();

		if(count >= options.maxNumSamples && options.maxNumSamples != 0)
			exit(0);
	}
}

