#include "stdafx.h"
#include <winbase.h>
#include <stdio.h>
#include "FunctionDataInfo.h"
#include <TlHelp32.h>
#include <vector>
#include <stack>

// variables to communicate with GUI application
// Each flag starts as false and is switched on in CProfiler::Initialize when the
// matching environment variable (same name, upper case) has the value "1".
bool currmodule_ = false;		// CURRMODULE_: filter on whether the caller's module name contains ".exe"
bool system_ = false;			// SYSTEM_: skip functions whose module is "CommonLanguageRuntimeLibrary"
bool antidebug_ = false;		// ANTIDEBUG_: remove the profiler related environment variables after start-up
bool classnames_ = false;		// CLASSNAMES_: forwarded to FunctionDataInfo::setClassName
bool modulenames_ = false;		// MODULENAMES_: forwarded to FunctionDataInfo::setModuleName
bool varnames_ = false;			// VARNAMES_: forwarded to FunctionDataInfo::setParamName and setVarName
bool varvalues_ = false;		// VARVALUES_: forwarded to FunctionDataInfo::setVarValue
bool typenames_ = false;		// TYPENAMES_: forwarded to FunctionDataInfo::setTypeName
bool inline_ = false;			// INLINE_: when false, COR_PRF_DISABLE_INLINING is added to the event mask
bool consoleEnabled_ = false;	// CONSOLE_: passed as second argument of logData for enter/leave lines
bool funcid_ = false;			// FUNCID_: prefix enter/leave lines with the numeric FunctionID
bool exceptions_ = false;		// EXCEPTIONS_: skip functions whose class name contains "Exception"

// Value of the MODULE_ environment variable and the part after its last backslash.
char moduleFullName[bufSize];
char moduleName[bufSize];

// One entry of the call stack kept by the enter/leave hooks.
struct Caller{
	char name[bufSize];	// module name of the function that was entered
};

// we keep caller in our stack
// (pushed in FunctionEnter2Impl, popped in FunctionLeave2Impl)
stack<Caller> currentCaller;

// Records created in FunctionEnter2Impl, keyed by FunctionID; a multimap because the
// same function can be entered again before it has been left.
// NOTE(review): neither container is protected by a lock in this file - confirm the
// hooks cannot run concurrently on several threads.
std::unordered_multimap <FunctionID, FunctionDataInfo*> functions;

// Removes one stored record of a function which returned.
//
// funcId - id of the function whose first stored record is dropped from `functions`.
//
// Only the map entry is removed; the FunctionDataInfo it points to is NOT freed here,
// the caller has to release/delete it. Does nothing when no record exists for funcId
// (erasing the end iterator would be undefined behaviour).
void eraseCaller(FunctionID funcId){
	auto range = functions.equal_range(funcId);

	// empty range: nothing stored under this id
	if (range.first == range.second) return;

	functions.erase(range.first);
}



// Closes the log file opened in Initialize. Always returns S_OK.
STDMETHODIMP CProfiler::Shutdown(){
	// CreateFile reports failure with INVALID_HANDLE_VALUE; never hand that (or NULL) to CloseHandle
	if (hFile != NULL && hFile != INVALID_HANDLE_VALUE) {
		CloseHandle(hFile);
		// forget the closed handle so that later use cannot hit a recycled handle value
		hFile = INVALID_HANDLE_VALUE;
	}
	return S_OK;
}

// Callback called when an assembly has been loaded; logs the assembly name and the
// name of its module ("Unknown" when a name cannot be obtained or converted).
//
// assemblyID - id of the loaded assembly
// hrStatus   - result of the load; nothing is logged when it indicates failure
//
// Always returns S_OK.
STDMETHODIMP CProfiler::AssemblyLoadFinished(AssemblyID assemblyID, HRESULT hrStatus){

	if (!SUCCEEDED(hrStatus)) return S_OK;

	ULONG length = 0;
	AppDomainID appdomId;
	ModuleID moduleID;
	WCHAR  szName[bufSize];
	WCHAR  szModName[bufSize];
	char assemblyText[bufSize];
	char moduleText[bufSize];
	// large enough for both names plus the fixed text of the message
	char logLine[bufSize + bufSize + 64];

	HRESULT hr = pInfo->GetAssemblyInfo(assemblyID, bufSize, &length, szName, &appdomId, &moduleID);
	if (!SUCCEEDED(hr)) return S_OK;

	strcpy(moduleText, "Unknown");

	// base address and owning assembly are not needed here, so no out pointers are passed
	HRESULT hrmod = pInfo->GetModuleInfo(moduleID, NULL, bufSize, &length, szModName, NULL);
	if (SUCCEEDED(hrmod)) {
		// wcstombs does not write a terminator when the limit is reached and returns
		// (size_t)-1 on an unconvertible character, so terminate / fall back explicitly
		size_t converted = wcstombs(moduleText, szModName, bufSize - 40);
		if (converted == (size_t)-1) strcpy(moduleText, "Unknown");
		else moduleText[converted] = '\0';
	}

	size_t converted = wcstombs(assemblyText, szName, bufSize - 30);
	if (converted == (size_t)-1) strcpy(assemblyText, "Unknown");
	else assemblyText[converted] = '\0';

	snprintf(logLine, sizeof(logLine), "Loading Assembly: %s at module:%s", assemblyText, moduleText);
	logData(logLine);

	return S_OK;
}




// Callback called when a module has been loaded; logs the module name and its base
// load address.
//
// moduleId - id of the loaded module
// hrStatus - result of the load (not inspected)
//
// Returns S_OK when the module information could be read and logged, E_FAIL otherwise.
STDMETHODIMP CProfiler::ModuleLoadFinished(ModuleID moduleId, HRESULT  hrStatus) {

	LPCBYTE baseLoadAddress = NULL;
	ULONG length = 0;
	WCHAR  szName[bufSize];
	char buffer[bufSize];
	char answer[largeBufSize];

	console = GetStdHandle(STD_OUTPUT_HANDLE);

	// pass the real capacity of szName and a real out variable for the base address
	// (previously a NULL pointer was passed and then printed, so BASE was always 0)
	HRESULT hr = pInfo->GetModuleInfo(moduleId, &baseLoadAddress, bufSize, &length, szName, NULL);
	if (!SUCCEEDED(hr)) return E_FAIL;

	// wcstombs neither terminates at the limit nor produces usable output on failure
	size_t converted = wcstombs(buffer, szName, bufSize - 30);
	if (converted == (size_t)-1) converted = 0;
	buffer[converted] = '\0';

	snprintf(answer, sizeof(answer), "Loading Module: %s  BASE:%p", buffer, (const void*)baseLoadAddress);
	logData(answer);

	return S_OK;
}


// Callback called when a function is left. Uses the record stored for funcId by
// FunctionEnter2Impl to log
//   "[<id> ]Leaving Function [<module>->][<class>:]<method> (ReturnValue:<type>[=<value>])"
// and then removes and frees that record.
//
// funcId      - id of the function being left
// clientData  - unused
// frameInfo   - unused
// retvalRange - location of the return value; may be NULL, then no value is decoded
//
// Nothing is logged when the record has no return info or when one of the filters
// (currmodule_, exceptions_, system_) applies. When no record exists for funcId an
// error line is logged and the caller stack is left untouched.
void __stdcall FunctionLeave2Impl(FunctionID funcId, UINT_PTR clientData, COR_PRF_FRAME_INFO frameInfo, COR_PRF_FUNCTION_ARGUMENT_RANGE  *retvalRange){

	char answer[largeBufSize];
	char temp[bufSize];
	char buffer1[bufSize];
	bool param = false;

	answer[0] = '\0';
	temp[0] = '\0';

	// appends text to answer without ever writing past its end (silently truncates)
	auto append = [&answer](const char* text) {
		const size_t used = strlen(answer);
		if (used + 1 < sizeof(answer)) strncat(answer, text, sizeof(answer) - used - 1);
	};

	auto it = functions.find(funcId);

	// unknown function: must be checked even when the map is empty, otherwise the
	// end iterator would be dereferenced below
	if (it == functions.end()) {
		if (funcid_)  sprintf(buffer1, "%d Leaving Function ERROR", funcId);
		else sprintf(buffer1, "Leaving Function ERROR");
		logData(buffer1, false);
		return;
	}

	// The top of the stack is this function itself (pushed on enter). Remove it; what
	// is on top afterwards is the caller, the same entry the enter hook filtered on.
	if (!currentCaller.empty()) currentCaller.pop();

	Caller caller;
	caller.name[0] = '\0';
	if (!currentCaller.empty()) caller = currentCaller.top();

	// remove exactly the entry that is processed and freed here (this function is
	// returning, nothing will be called from it any more)
	FunctionDataInfo* finfo = it->second;
	functions.erase(it);

	// record without return info: drop silently
	bool skip = (finfo->pReturnInfo == NULL);

	// filter MainModule and exception
	if (!skip) {
		const bool callerFiltered = currmodule_ && (strstr(caller.name, ".exe") == NULL);
		const bool exceptionFiltered = exceptions_ && (strstr(finfo->className, "Exception") != NULL);
		skip = callerFiltered || exceptionFiltered;
	}

	// filter MCLRlib
	if (!skip) skip = system_ && strcmp(finfo->moduleName, "CommonLanguageRuntimeLibrary") == 0;

	if (!skip) {

		if (funcid_)  sprintf(buffer1, "%d Leaving Function ", funcId);
		else sprintf(buffer1, "Leaving Function ");

		append(buffer1);
		if (finfo->isModuleNameEnabled()) { append(finfo->moduleName); append("->"); }
		if (finfo->isClassNameEnabled()) { append(finfo->className); append(":"); }
		append(finfo->methodName);
		append(" (ReturnValue:");

		const char* typeName = finfo->pReturnInfo->valueTypeName;
		auto returns = [typeName](const char* name) { return strcmp(typeName, name) == 0; };

		// parse return value type (the names are mutually exclusive, so at most one branch runs)
		if (returns("void")) { temp[0] = '\0'; param = true; }
		else if (retvalRange == NULL) { /* no value available to decode */ }
		else if (returns("String")) param = finfo->parseString(retvalRange->startAddress, temp);
		else if (returns("byte[]")) param = finfo->parseByteArray(retvalRange->startAddress, temp);
		else if (returns("int[]")) param = finfo->parseIntArray(retvalRange->startAddress, temp);
		else if (returns("char[]")) param = finfo->parseCharArray(retvalRange->startAddress, temp);
		else if (returns("int*")) param = finfo->parseIntPointer(retvalRange->startAddress, temp);
		else if (returns("int")) param = finfo->parseInt(retvalRange->startAddress, temp);
		else if (returns("byte")) param = finfo->parseByte(retvalRange->startAddress, temp);
		else if (returns("char")) param = finfo->parseChar(retvalRange->startAddress, temp);
		else if (returns("bool")) param = finfo->parseBool(retvalRange->startAddress, temp);
		else if (returns("char*")) param = finfo->parseCharPointer(retvalRange->startAddress, temp);
		else if (returns("uint")) param = finfo->parseUint(retvalRange->startAddress, temp);
		else if (returns("int8")) param = finfo->parseInt8(retvalRange->startAddress, temp);
		else if (returns("short")) param = finfo->parseShort(retvalRange->startAddress, temp);
		else if (returns("ushort")) param = finfo->parseUshort(retvalRange->startAddress, temp);
		else if (returns("float")) param = finfo->parseFloat(retvalRange->startAddress, temp);
		else if (returns("double")) param = finfo->parseDouble(retvalRange->startAddress, temp);
		else if (returns("int64")) param = finfo->parseInt64(retvalRange->startAddress, temp);
		else if (returns("uint64")) param = finfo->parseUint64(retvalRange->startAddress, temp);
		else if (returns("uint*")) param = finfo->parseUintPointer(retvalRange->startAddress, temp);
		else if (returns("double*")) param = finfo->parseDoublePointer(retvalRange->startAddress, temp);

		if (param) {
			append(typeName);
			if (finfo->isVarValueEnabled()){
				if (!returns("void")) append("=");
				append(temp);
			}
		}

		append(")");
	}

	// everything needed has been copied into answer; the record can go now
	finfo->release();
	delete finfo;

	if (!skip) logData(answer, consoleEnabled_);

	return;
}


// Callback called when a function is entered. Creates a FunctionDataInfo record for
// the call, stores it in `functions` (FunctionLeave2Impl consumes it), pushes the
// function's module name on the caller stack and logs
//   "[<id> ]Entering [<module>->][<class>:]<method><arguments>"
//
// funcId       - id of the function being entered
// clientData   - forwarded to FunctionDataInfo
// frameInfo    - forwarded to FunctionDataInfo
// argumentInfo - forwarded to FunctionDataInfo (source of the argument values)
//
// The record and the stack entry are created even when the line is filtered out
// (currmodule_, exceptions_, system_), because the leave hook expects both.
void __stdcall FunctionEnter2Impl(FunctionID funcId, UINT_PTR clientData, COR_PRF_FRAME_INFO frameInfo, COR_PRF_FUNCTION_ARGUMENT_INFO  *argumentInfo) {

	HRESULT hr = S_OK;

	char buffer1[bufSize];
	char methodName[bufSize];
	char methodModuleName[bufSize];
	char buffer[bufSize];
	char answer[largeBufSize];

	// clear the whole buffers (sizeof of the array, not of the size constant);
	// methodModuleName is used below even when getModuleName fails
	memset(methodName, 0, sizeof(methodName));
	memset(methodModuleName, 0, sizeof(methodModuleName));
	memset(buffer, 0, sizeof(buffer));
	memset(answer, 0, sizeof(answer));

	// appends text to answer without ever writing past its end (silently truncates)
	auto append = [&answer](const char* text) {
		const size_t used = strlen(answer);
		if (used + 1 < sizeof(answer)) strncat(answer, text, sizeof(answer) - used - 1);
	};

	FunctionDataInfo *finfo = new FunctionDataInfo(funcId, clientData, frameInfo, argumentInfo);

	// must be tested before the pointer is stored or used
	if (finfo == NULL) {

		logData("Error function data");
		return;
	}

	// Save function info to hash map
	functions.insert({ funcId, finfo });

	finfo->setClassName(classnames_);
	finfo->setModuleName(modulenames_);
	finfo->setParamName(varnames_);
	finfo->setTypeName(typenames_);
	finfo->setVarName(varnames_);
	finfo->setVarValue(varvalues_);

	if (funcid_)  sprintf(buffer1, "%d Entering ", funcId);
	else sprintf(buffer1, "Entering ");

	append(buffer1);

	hr = finfo->getModuleName(methodModuleName);

	if (SUCCEEDED(hr)){
		if (finfo->isModuleNameEnabled()) {
			append(methodModuleName);
			append("->");
		}
	}
	else { logData("Error retreiving module name"); }

	hr = finfo->getMethodNameAndClassType(methodName);

	if (!SUCCEEDED(hr)) { logData("Error retreiving method name and class type"); }

	if (finfo->isClassNameEnabled()) {
		hr = finfo->getClassName(buffer);

		if (SUCCEEDED(hr)){
			append(buffer);
			append(":");
		}
		else { logData("Error retreiving class name"); }
	}

	append(methodName);

	// get my caller (empty name when the stack is empty)
	Caller caller;
	caller.name[0] = '\0';
	if (!currentCaller.empty()) caller = currentCaller.top();

	// save next caller (myself)
	Caller nextCaller;
	strncpy(nextCaller.name, methodModuleName, sizeof(nextCaller.name) - 1);
	nextCaller.name[sizeof(nextCaller.name) - 1] = '\0';
	currentCaller.push(nextCaller);

	// filter MainModule and Exception
	const bool callerFiltered = currmodule_ && (strstr(caller.name, ".exe") == NULL);
	const bool exceptionFiltered = exceptions_ && (strstr(finfo->className, "Exception") != NULL);
	if (callerFiltered || exceptionFiltered) return;

	// filter MCLRlib
	if (system_ && strcmp(finfo->moduleName, "CommonLanguageRuntimeLibrary") == 0) return;

	// parse function arguments
	hr = finfo->getArguments(buffer);

	if (SUCCEEDED(hr)){
		append(buffer);
	}
	else { logData("Error retreiving method arguments"); }

	// log information about entered function
	logData(answer, consoleEnabled_);

	return;
}


// Called when a function does not return but jumps to another function (tail call).
//
// funcId     - id of the function performing the tail call
// clientData - unused
// frameInfo  - unused
//
// The tail-calling function is finished from the profiler's point of view and, per
// the profiling API documentation for FunctionTailcall2, no leave notification
// follows for it. It is therefore handled like a leave without a return value: its
// entry is removed from the caller stack and its stored record is freed.
// (Previously a second record was inserted here that nothing ever removed.)
void __stdcall FunctionTailCall2Impl(FunctionID funcId, UINT_PTR clientData, COR_PRF_FRAME_INFO frameInfo) {

	// the top entry was pushed by FunctionEnter2Impl for this function
	if (!currentCaller.empty()) currentCaller.pop();

	auto it = functions.find(funcId);
	if (it == functions.end()) return;

	FunctionDataInfo *finfo = it->second;
	functions.erase(it);

	if (finfo != NULL) {
		finfo->release();
		delete finfo;
	}
}

// Enter hook registered with SetEnterLeaveFunctionHooks2 in CProfiler::Initialize.
// It saves the general purpose registers, forwards its four stack arguments to
// FunctionEnter2Impl and restores the registers again. EFLAGS and floating point
// state are not saved by this thunk.
void __declspec(naked)  FunctionEnter2_(FunctionID funcId, UINT_PTR clientData, COR_PRF_FRAME_INFO func, COR_PRF_FUNCTION_ARGUMENT_INFO  *argumentInfo){

	      // function is defined as naked, so the compiler emits no prolog/epilog; both are written by hand below

	__asm{
		    push ebp		// save caller's ebp
			mov ebp, esp	// ebp now addresses the frame: [ebp+4] = return address, [ebp+8] = first argument
			pushad			// save all general purpose registers
			mov eax, [ebp + 0x14] // argumentInfo (4th argument, pushed first)
			push eax
			mov eax, [ebp + 0x10]	// func (frame info)
			push eax
			mov eax, [ebp + 0xC]	// client data
			push eax
			mov eax, [ebp + 0x8]	// funcId
			push eax
			call FunctionEnter2Impl	// declared __stdcall: removes its four arguments itself
			popad	// recover saved registers
			pop ebp // recover ebp
			ret 0x10	// return and remove our own four 4-byte arguments
	}
}




// Leave hook registered with SetEnterLeaveFunctionHooks2 in CProfiler::Initialize.
// Naked function: prolog and epilog are written by hand. It saves the general
// purpose registers, forwards its four stack arguments to FunctionLeave2Impl and
// restores the registers again. EFLAGS and floating point state are not saved.
void __declspec(naked)  FunctionLeave2_(FunctionID funcId, UINT_PTR clientData, COR_PRF_FRAME_INFO func, COR_PRF_FUNCTION_ARGUMENT_RANGE  *retvalRange){

	

	__asm{
		    push ebp		// save caller's ebp
			mov ebp, esp	// [ebp+4] = return address, [ebp+8] = first argument
			pushad		// save all general purpose registers
			mov eax, [ebp + 0x14]	// retvalRange
			push eax
			mov eax, [ebp + 0x10]	// func
			push eax
			mov eax, [ebp + 0xC]	// client data
			push eax
			mov eax, [ebp + 0x8]	// funcId
			push eax
			call FunctionLeave2Impl	// declared __stdcall: removes its four arguments itself
			popad		// recover saved registers
			pop ebp		// recover ebp
			ret 0x10	// return and remove our own four 4-byte arguments
	}
}


// Tail call hook registered with SetEnterLeaveFunctionHooks2 in CProfiler::Initialize.
// Naked function: prolog and epilog are written by hand. It saves the general
// purpose registers, forwards its THREE stack arguments to FunctionTailCall2Impl
// and restores the registers again.
//
// Unlike the enter/leave hooks this one receives only three arguments, so exactly
// three values are pushed and 0xC bytes are removed on return. Pushing a fourth
// value would stay on the stack after the __stdcall callee removed its three
// arguments and make popad / pop ebp / ret read shifted data.
void __declspec(naked)  FunctionTailcall2_(FunctionID funcId, UINT_PTR clientData, COR_PRF_FRAME_INFO frameInfo) {

	__asm{
			push ebp		// save caller's ebp
			mov ebp, esp	// [ebp+4] = return address, [ebp+8] = first argument
			pushad			// save all general purpose registers
			mov eax, [ebp + 0x10]	// frameInfo
			push eax
			mov eax, [ebp + 0xC]	// client data
			push eax
			mov eax, [ebp + 0x8]	// funcId
			push eax
			call FunctionTailCall2Impl	// declared __stdcall: removes its three arguments itself
			popad			// recover saved registers
			pop ebp			// recover ebp
			ret 0xC			// return and remove our own three 4-byte arguments
	}


}


/* This method is called by the CLR before the profiled application starts;
   everything needed for profiling is set up here:
     - the log file "Output.txt" is created,
     - the settings passed by the GUI application through environment variables are read,
     - the ICorProfilerInfo2 interface is obtained,
     - a console is allocated and the standard streams are attached to it,
     - the event mask and the enter/leave/tailcall hooks are registered,
     - with ANTIDEBUG_ set, the profiler related environment variables are removed.

   pICorProfilerInfoUnk - interface pointer supplied by the CLR

   Returns S_OK on success and E_FAIL when the interface cannot be obtained or the
   event mask / hooks cannot be registered.
 */
STDMETHODIMP CProfiler::Initialize(IUnknown *pICorProfilerInfoUnk) {

	// Log File
	LPCWSTR Filename = L"Output.txt";
	WCHAR wideModule[largeBufSize];

	// sizeof: the buffer holds WCHARs, so its byte size is larger than its element count
	memset(wideModule, 0, sizeof(wideModule));
	memset(moduleFullName, 0, sizeof(moduleFullName));
	memset(moduleName, 0, sizeof(moduleName));

	// Output Log File
	hFile = CreateFile(Filename, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);

	// Get Environment variables set by GUI application

	// Full path of the profiled module. GetEnvironmentVariable returns 0 when the
	// variable is missing and the required size (>= capacity) when the buffer is too
	// small; in both cases the buffer content must not be used.
	const DWORD moduleLength = GetEnvironmentVariable(L"MODULE_", wideModule, largeBufSize);
	if (moduleLength > 0 && moduleLength < static_cast<DWORD>(largeBufSize)) {

		// leave room for the terminator; the buffer was zeroed above
		if (wcstombs(moduleFullName, wideModule, sizeof(moduleFullName) - 1) == (size_t)-1) {
			memset(moduleFullName, 0, sizeof(moduleFullName));
		}

		// file name = part after the last backslash, or the whole value when there is none
		const char* lastSeparator = strrchr(moduleFullName, '\\');
		strcpy(moduleName, lastSeparator != NULL ? lastSeparator + 1 : moduleFullName);
	}

	// true when the named environment variable exists and is exactly "1"
	auto isSwitchedOn = [](LPCWSTR name) -> bool {
		WCHAR data[5];
		const DWORD copied = GetEnvironmentVariable(name, data, 5);
		if (copied == 0 || copied >= 5) return false;	// missing, or too long to be "1"
		return wcscmp(data, L"1") == 0;
	};

	struct Switch { LPCWSTR name; bool* flag; };
	const Switch switches[] = {
		{ L"ANTIDEBUG_", &antidebug_ },			// remove environment variables afterwards
		{ L"CLASSNAMES_", &classnames_ },		// class names in log
		{ L"CURRMODULE_", &currmodule_ },		// only functions called from .exe modules
		{ L"SYSTEM_", &system_ },				// filter CommonLanguageRuntimeLibrary
		{ L"VARNAMES_", &varnames_ },			// names of argument
		{ L"VARVALUES_", &varvalues_ },			// argument values
		{ L"MODULENAMES_", &modulenames_ },		// module names in log
		{ L"TYPENAMES_", &typenames_ },			// type names in log
		{ L"INLINE_", &inline_ },				// enable inlining
		{ L"CONSOLE_", &consoleEnabled_ },		// log to console
		{ L"EXCEPTIONS_", &exceptions_ },		// filter exceptions
		{ L"FUNCID_", &funcid_ },				// Function Id in log
	};

	// flags are only ever switched on here, never reset
	for (size_t i = 0; i < sizeof(switches) / sizeof(switches[0]); ++i) {
		if (isSwitchedOn(switches[i].name)) *switches[i].flag = true;
	}

	// get ICorProfilerInfo2 interface
	HRESULT hr = E_FAIL;
	if (pICorProfilerInfoUnk != NULL) {
		hr = pICorProfilerInfoUnk->QueryInterface(IID_ICorProfilerInfo2, (void**)&pInfo);
	}

	AllocConsole();

	freopen("CONIN$", "r", stdin);
	freopen("CONOUT$", "w", stdout);
	freopen("CONOUT$", "w", stderr);

	// the log file is optional for profiling itself, but do not fail silently
	if (hFile == INVALID_HANDLE_VALUE) {
		printf("ERROR opening Output.txt \r\n");
	}

	if (!SUCCEEDED(hr)) {
		int x;
		printf("ERROR Interface \r\n");
		scanf("%d", &x);	// keeps the console open until the user reacts
		return E_FAIL;

	}

	// set event mask telling CLR about which events profiler will be informed

	DWORD events = COR_PRF_MONITOR_ASSEMBLY_LOADS | COR_PRF_MONITOR_MODULE_LOADS | COR_PRF_MONITOR_ENTERLEAVE | COR_PRF_ENABLE_FUNCTION_RETVAL | COR_PRF_ENABLE_FUNCTION_ARGS | COR_PRF_ENABLE_FRAME_INFO;

	// inlining stays enabled only when INLINE_ was set; otherwise ask the CLR to disable it
	if (inline_ == false) events |= COR_PRF_DISABLE_INLINING;

	// set mask
	hr = pInfo->SetEventMask(events);
	if (!SUCCEEDED(hr)) {
		printf("ERROR Event mask \r\n");
		return E_FAIL;
	}

	// set callbacks which will be called when event occures
	hr = pInfo->SetEnterLeaveFunctionHooks2((FunctionEnter2*)&FunctionEnter2_, (FunctionLeave2*)&FunctionLeave2_, (FunctionTailcall2*)&FunctionTailcall2_);
	if (!SUCCEEDED(hr)) {
		printf("ERROR Function hooks \r\n");
		return E_FAIL;
	}

	// destroy environment variables - antidebug

	if (antidebug_) {
		const LPCWSTR toRemove[] = {
			L"COR_PROFILER",
			L"COR_ENABLE_PROFILING",
			L"SYSTEM_",
			L"TYPENAMES_",
			L"CLASSNAMES_",
			L"MODULENAMES_",
			L"VARNAMES_",
			L"VARVALUES_",
			L"ANTIDEBUG_",
			L"INLINE_",
			L"CONSOLE_",
			L"MODULE_",
			L"CURRMODULE_",
			L"FUNCID_",
			L"EXCEPTIONS_",
			L"complus_profapi_profilercompatibilitysetting",
		};

		// a NULL value deletes the variable from this process' environment
		for (size_t i = 0; i < sizeof(toRemove) / sizeof(toRemove[0]); ++i) {
			SetEnvironmentVariable(toRemove[i], NULL);
		}
	}

	return  S_OK;
}