// proxy.h
// NOTE: This is a generic file. Actual unit tests are located in
//       unit_tests.cpp.
// By Jack Toole for CS 225 spring 2011

#ifndef MONAD_PROXY_H
#define MONAD_PROXY_H

#include <cmath>
#include <ctime>
#include <iostream>
#include <limits>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "pipestream.h"
#include "monad_shared.h"

// Sentinel value of MP_PART_NUMBER meaning "no specific MP part selected";
// in that case MP_PART(x) is true for every x.
// _mp_part_number.h is expected to define MP_PART_NUMBER (presumably a
// per-part generated file — confirm in the build scripts).
#define NO_MP_PART -1
#include "_mp_part_number.h"
// MP_PART(x): true when the build targets MP part x, or when no specific part
// is selected (MP_PART_NUMBER == NO_MP_PART).
// The entire expansion (and MP_PART_NUMBER itself) is parenthesized so that
// operators applied to the result bind correctly: previously !MP_PART(1)
// expanded to `!MP_PART_NUMBER == (1) || ...`, negating the part number
// instead of the test.
#define MP_PART(x) ((MP_PART_NUMBER) == (x) || (MP_PART_NUMBER) == NO_MP_PART)

namespace proxy
{
	// Bring the harness types shared with the test side (unit_test etc.)
	// into this namespace. NOTE(review): a using-directive in a header
	// affects every translation unit that includes it.
	using namespace monad_shared;

	class RunTests;

	// Signature of a named output checker: (output, expected) -> bool.
	// Presumably true means the output is acceptable — confirm in proxy.cpp.
	typedef bool (*output_check)(const std::string &, const std::string &);

	// Global registries, defined in another translation unit. They are held
	// through pointers, presumably so the registration objects created by the
	// test macros can allocate them on first use regardless of static
	// initialization order (cf. add_unit_test::lazy_init_global_tests) —
	// TODO confirm.
	extern std::vector<unit_test> * global_tests;
	// Output checks keyed by name; the util::ci_less comparator suggests the
	// name lookup is case-insensitive.
	typedef std::map<std::string, output_check, util::ci_less> output_check_map;
	extern output_check_map * global_output_checks;

	// Registration helper: each UNIT_TEST / VALGRIND_TEST expansion declares
	// one of these objects (named <func>_adder). The constructor is defined
	// out of line and presumably appends the test to global_tests.
	//   name            - stringized test function name
	//   func            - the test function
	//   points_in_part  - presumably points when only one MP part is run
	//   points_in_total - presumably points in the full run
	//   timeout         - time limit; units not visible here — confirm
	//   is_valgrind     - true for VALGRIND_TEST, false for UNIT_TEST
	class add_unit_test
	{
		public:
		add_unit_test(const char * name, unit_test::function func,
		              int32_t points_in_part, int32_t points_in_total, long timeout,
		              bool is_valgrind);

		private:
		void lazy_init_global_tests();
		// NOTE(review): parameter order (total, part) is the reverse of the
		// constructor's (part, total); double-check the definition.
		int32_t get_points(int32_t points_in_total, int32_t points_in_part);
	};

	// Registration helper used by OUTPUT_CHECK: the constructor (defined out
	// of line) associates `name` with checker `func`.
	class add_output_check
	{
		public:
		add_output_check(const char * name, output_check func);
	};

	// Which set of tests RunTests executes (presumably: one named test, the
	// tests of one MP part, or everything).
	// NOTE(review): inside namespace proxy this name hides POSIX ::mode_t.
	enum mode_t
	{
		SINGLE_TEST,
		MP_PART_TESTS,
		ALL_TESTS
	};

	// Run-wide configuration plus the test / output-check registries used for
	// this run. Copying is disabled via private copy members (the usual
	// pre-C++11 idiom; presumably they are never defined).
	struct RunTimeEnvironment
	{
		public:
		//!!const int itimer_number0;
		//!!const int itimer_number1;
		// Signals that mean "the test was stopped for exceeding its timeout"
		// (see is_timeout_signal).
		const int timeout_signum0;
		const int timeout_signum1;
		// Presumably the cap on captured test output that is kept — confirm.
		const size_t max_output_length;
		const char * single_test_passed_string;
		// Registries for this run, presumably received from the globals via
		// the constructor.
		std::vector<unit_test> * heap_tests;
		output_check_map * output_checks;
		int32_t cleanup_globals();

		// Takes references to the registry pointers, presumably so it can
		// take them over and reset the globals — confirm in proxy.cpp.
		RunTimeEnvironment(std::vector<unit_test> *& init_tests,
		                   output_check_map *& init_output_checks);
		
		// True when signal_number is either of the timeout signals.
		bool is_timeout_signal(int8_t signal_number)
		{
			return signal_number == timeout_signum0 ||
			       signal_number == timeout_signum1;
		}

		private:
		// NOTE(review): operator= takes a non-const reference, unlike the
		// other non-copyable classes in this header.
		RunTimeEnvironment(const RunTimeEnvironment & other);
		RunTimeEnvironment & operator=(RunTimeEnvironment & other);
	};

	// Command-line driver: process_args() turns argc/argv into a mode_t,
	// test_arg and mp_part; execute() then runs the selected tests
	// (execute_by_mode). Non-copyable.
	class RunTests
	{
		private:
		RunTimeEnvironment & environment;
		mode_t mode;
		const char * test_arg;
		int8_t mp_part;

		public:
		RunTests(int argc, char ** argv, RunTimeEnvironment & env);
		int execute();
		private:
		void redirect_glibc_to_stderr();
		void process_args(int argc, char ** argv);

		protected:
		int32_t execute_by_mode();
		// Single-test mode: run one test (by name or by object) and report.
		int32_t run_single_test(const char * testname);
		int32_t run_single_test(unit_test & curr_test);
		void    handle_single_test_output(const std::string & output);
		void    output_single_test_passfail(const unit_test & curr_test);
		
		// Multi-test mode: run every selected test, then print a summary.
		// The get_max_*_length helpers are presumably for column alignment.
		int32_t run_all_tests();
		int32_t get_sum_points();
		int32_t get_max_testname_length();
		int32_t get_max_points_length();
		void output_detailed_info_if_any_failed(int32_t score);
		void output_detailed_tests_info(int32_t score);

		bool execute_test(unit_test & test, bool enable_valgrind_call);

		private:
		RunTests(const RunTests & other);
		RunTests & operator=(const RunTests & other);
	};

	// Defined elsewhere. Presumably forks and invokes the executor's hooks
	// (before / parent / child / after_success / after_failure, as provided
	// by test_execution) — confirm in proxy.cpp.
	template <typename F>
	bool fork_execute(F & executor);
	
	// Per-test execution state; presumably the executor passed to
	// fork_execute. child() runs the test itself or under valgrind
	// (do_valgrind); results come back through the three pipes below.
	// NOTE(review): return codes are carried as int8_t, so exit statuses
	// above 127 would appear negative — confirm this is intended.
	class test_execution
	{
		private:
		util::pipestream fmsg_pipe; // For error messages
		util::pipestream cout_pipe; // For stdout/stderr
		util::pipestream nums_pipe; // for numbers: time, valgrind
		unit_test & test;
		RunTimeEnvironment & environment;
		bool do_valgrind;

		public:
		test_execution(unit_test & _test, RunTimeEnvironment & env, bool enable_valgrind_call);
		void before();
		void parent();
		void child();
		void after_success(int8_t return_code);
		void after_failure(int8_t signal_number);
		
		private:
		void child_test();
		void child_valgrind();
		void after_test_success();
		void after_valgrind_success(int8_t return_code);
		void start_timeout();
		long end_timeout();
		static bool prof_timeout_enabled();
		
		private:
		test_execution(const test_execution & other);
		test_execution & operator=(const test_execution & other);
	};


	// Valgrind result helpers (defined out of line).
	const char * get_valgrind_string(int32_t flags);
	int32_t get_valgrind_flags(bool test_failed);
	// bitflags packs up to five truth values (nonzero arguments) into bits
	// 0..4; bitflag tests bit `num` of `flags`. Both are defined inline later
	// in this header; the default arguments live on this declaration.
	int32_t bitflags(unsigned long a,     unsigned long b = 0, unsigned long c = 0,
					 unsigned long d = 0, unsigned long e = 0);
	bool bitflag(int32_t flags, int32_t num);

} // namespace proxy

// Convenience names for test code that includes this header.
// NOTE(review): global using-declarations in a header inject these names into
// every includer; kept because existing test files presumably rely on them.
using std::cout;
using std::cerr;
using std::endl;

// UNIT_TEST(func, pointsInPart, pointsInTotal, timeout) { ...body... }
// Declares `func` with the unit-test signature (parameter `this_test`),
// registers it via a proxy::add_unit_test object named func##_adder with
// is_valgrind = false, and then opens func's definition. The macro must
// therefore be followed directly by the function body.
#define UNIT_TEST(func,pointsInPart,pointsInTotal,timeout)             \
	monad_shared::unit_test::return_type                               \
	func(monad_shared::unit_test & this_test);                         \
	proxy::add_unit_test                                               \
		func##_adder(#func, func, pointsInPart,                        \
		             pointsInTotal, timeout, false);                   \
	monad_shared::unit_test::return_type                               \
	func(monad_shared::unit_test & this_test)

// VALGRIND_TEST(func, pointsInPart, pointsInTotal, timeout) { ...body... }
// Identical to UNIT_TEST except that the test is registered with
// is_valgrind = true, presumably so the harness runs it under valgrind.
#define VALGRIND_TEST(func,pointsInPart,pointsInTotal,timeout)         \
	monad_shared::unit_test::return_type                               \
	func(monad_shared::unit_test & this_test);                         \
	proxy::add_unit_test                                               \
		func##_adder(#func, func, pointsInPart,                        \
		             pointsInTotal, timeout, true);                    \
	monad_shared::unit_test::return_type                               \
	func(monad_shared::unit_test & this_test)

// HELPER_TEST(func, param-declarations...) { ...body... }
// Opens the definition of a helper with the unit-test return type that takes
// the current test as `this_test` plus the given parameters. FAIL / ASSERT /
// PASS (which `return` from the enclosing function) can be used inside it.
// Call it through CALL_HELPER.
// NOTE(review): before C++20 at least one extra parameter is required; an
// empty __VA_ARGS__ would leave a dangling comma.
#define HELPER_TEST(func, ...)                                         \
	monad_shared::unit_test::return_type                               \
	func(monad_shared::unit_test & this_test, __VA_ARGS__)

// CALL_HELPER(func, args...): call a HELPER_TEST helper with this_test and
// `args`. If its result differs from unit_test::pass_string, the enclosing
// test fails with the helper's message (prefixed by FAIL with this call
// site's file:line). The do/while(0) wrapper makes it a single statement.
#define CALL_HELPER(func, ...)                                         \
	do {                                                               \
		monad_shared::unit_test::return_type helperval =               \
			func(this_test, __VA_ARGS__);                              \
		if (helperval != monad_shared::unit_test::pass_string)         \
			FAIL(helperval);                                           \
	} while (0)

// OUTPUT_CHECK(name) { ...body... }
// Declares output_check_<name>(output, expected), registers it under the
// string "name" via a proxy::add_output_check object, and opens its
// definition. The macro must be followed directly by the function body.
#define OUTPUT_CHECK(func)                                                              \
	bool output_check_##func(const std::string & output, const std::string & expected); \
	proxy::add_output_check                                                             \
		output_check_##func##_adder(#func, output_check_##func);                        \
	bool output_check_##func(const std::string & output, const std::string & expected)

// Two-level stringification: STR(p) macro-expands p before stringizing it,
// so STR(__LINE__) yields the line number rather than "__LINE__".
#define STRINGIFY1(p)   #p
#define STR(p)          STRINGIFY1(p)

// FAIL(error): return "<file>:<line>: " + error from the enclosing test.
// `error` may be a string literal or a std::string.
#define FAIL(error)     return std::string(__FILE__ ":" STR(__LINE__) ": ") + (error)

// PASS: return the pass marker from the enclosing test.
// NOTE(review): the replacement list ends with ';', so `if (c) PASS; else`
// does not compile; kept as-is because call sites may omit the semicolon.
#define PASS            return monad_shared::unit_test::pass_string;

// ASSERT(expr): if `expr` is false, fail the enclosing test with
// "<file>:<line>: Assertion (<expr>) failed" (via FAIL, which returns).
// Wrapped in do { } while (0) so the macro is a single statement: the former
// bare `if` silently captured a following `else`, as in
//     if (c) ASSERT(x); else ...
// where the else then bound to the assertion rather than to `if (c)`.
#define ASSERT(expr)                                                   \
	do {                                                               \
		if (!(expr))                                                   \
			FAIL("Assertion (" #expr ") failed");                      \
	} while (0)

// ASSERT_OUTPUT(checkFunc, str): write the stringized checker name followed
// by `str` to the stream pointed to by this_test.checkstream, presumably for
// later comparison by the OUTPUT_CHECK named checkFunc — confirm in proxy.cpp.
// NOTE(review): nothing separates the name from `str`, and the trailing ';'
// in the replacement list makes the macro unsafe directly before an `else`.
#define ASSERT_OUTPUT(checkFunc, str)  \
	*this_test.checkstream << #checkFunc << str;

// Running-time classes accepted by ASSERT_TIME, ordered so that
// `class + 1` is the next slower class. Each enumerator indexes
// proxy::runtime_ratio and proxy::runtime_str; TIME_COUNT is their length.
enum proxy_runtime_t
{
	CONSTANT_TIME = 0,  // O(1)
	N_TIME        = 1,  // O(n)
	NLOGN_TIME    = 2,  // O(n log n)
	NROOTN_TIME   = 3,  // O(n sqrt(n))
	N2_TIME       = 4,  // O(n^2)
	INFINITE_TIME = 5,  // presumably "slower than every class above"
	TIME_COUNT    = 6   // number of classes (array length)
};

namespace proxy
{
	// Per-class reference values used by ASSERT_TIME, indexed by
	// proxy_runtime_t and defined in another translation unit:
	// runtime_ratio is compared against the measured time(200)/time(100)
	// ratio; runtime_str is the class name used in the failure message.
	extern double runtime_ratio[TIME_COUNT];
	extern const char * runtime_str[TIME_COUNT];
	// Declared here for ASSERT_TIME; the template is defined later in this
	// header.
	template <typename Generator, typename Timer>
	clock_t timeIterations(Generator gen, Timer timeFunctor, size_t input_size);
}

// ASSERT_TIME(gen, functor, expectedTime): empirical growth-rate check.
// Times `functor` on inputs of size 200 and of size 100 produced by `gen`
// (see proxy::timeIterations). The test fails if the observed time ratio is
// closer to the ratio of the next slower class (expectedTime + 1) than to the
// ratio of `expectedTime`.
//  - std::fabs is used explicitly: an unqualified abs() can resolve to
//    int abs(int) and truncate the differences to whole numbers.
//  - When expectedTime is the last class there is no slower class to compare
//    against, so the check is skipped instead of indexing past the end of
//    runtime_ratio / runtime_str.
//  - expectedTime is evaluated exactly once.
#define ASSERT_TIME(gen, functor, expectedTime)                                           \
	do {                                                                                  \
		const int proxy_expected_time_ = static_cast<int>(expectedTime);                  \
		if (proxy_expected_time_ + 1 < static_cast<int>(TIME_COUNT)) {                    \
			clock_t diff200 = proxy::timeIterations(gen, functor, 200);                   \
			clock_t diff100 = proxy::timeIterations(gen, functor, 100);                   \
			double ratio = static_cast<double>(diff200)/static_cast<double>(diff100);     \
			double diffFromExpected =                                                     \
				std::fabs(ratio - proxy::runtime_ratio[proxy_expected_time_]);            \
			double diffFromWrong =                                                        \
				std::fabs(ratio - proxy::runtime_ratio[proxy_expected_time_ + 1]);        \
			if (diffFromWrong < diffFromExpected)                                         \
				FAIL(std::string("Runtime was larger than ")                              \
				     + proxy::runtime_str[proxy_expected_time_]);                         \
		}                                                                                 \
	} while(0)

namespace proxy {

// timeIterations: processor time (std::clock ticks) spent calling
// `timeFunctor` once on each of 1000 inputs of size `input_size`.
//   gen         - called as gen(input_size); must expose a nested
//                 `result_type` typedef naming the value it returns.
//   timeFunctor - called once per generated input; only these calls are timed.
//   input_size  - size argument forwarded to gen.
// All inputs are generated before the clock starts, so generation cost is
// excluded. Divide the result by CLOCKS_PER_SEC to get seconds.
template <typename Generator, typename Timer>
clock_t timeIterations(Generator gen, Timer timeFunctor, size_t input_size)
{
	const size_t num_iterations = 1000;
	std::vector<typename Generator::result_type> inputs;
	inputs.reserve(num_iterations); // single allocation instead of regrowth
	for (size_t i = 0; i < num_iterations; i++)
		inputs.push_back(gen(input_size));

	clock_t starttime = clock();
	for (size_t i = 0; i < num_iterations; i++)
		timeFunctor(inputs[i]);
	clock_t endtime = clock();
	return endtime - starttime;
}

// bitflags: pack up to five truth values into a bitmask. Bit i (0..4) of the
// result is set iff the i-th argument is nonzero. (The default arguments of 0
// for b..e are on the declaration earlier in this header.)
inline int32_t bitflags(unsigned long a, unsigned long b, unsigned long c,
                        unsigned long d, unsigned long e)
{
	return  static_cast<int32_t>(a != 0)        | (static_cast<int32_t>(b != 0) << 1) |
	       (static_cast<int32_t>(c != 0) << 2) | (static_cast<int32_t>(d != 0) << 3) |
	       (static_cast<int32_t>(e != 0) << 4);
}

// bitflag: true iff bit `num` of `flags` is set.
// The shift is done on an unsigned value and `num` is range-checked first:
// `1 << num` on an int is undefined for num >= 31 or num < 0. Out-of-range
// bit indices now simply report "not set".
inline bool bitflag(int32_t flags, int32_t num)
{
	if (num < 0 || num >= std::numeric_limits<uint32_t>::digits)
		return false;
	return ((static_cast<uint32_t>(flags) >> num) & 1u) != 0;
}

}
#endif