////////////////////////////////////////////////////////////////////////////////
//
// EnforceFilter.cpp 
//
// Author: Oleg Starodumov (www.debuginfo.com)
//


////////////////////////////////////////////////////////////////////////////////
//
// This example demonstrates how to ensure that nobody else can overwrite 
// our custom filter for unhandled exceptions. This is achieved by patching 
// the SetUnhandledExceptionFilter function. 
// 
//


////////////////////////////////////////////////////////////////////////////////
// Include files 
//

#include <windows.h>
#include <tchar.h>
#include <crtdbg.h>
#include <stdio.h>


////////////////////////////////////////////////////////////////////////////////
// Function declarations
//

	// Custom filters for unhandled exceptions
	// 
	// Each of these filters prints a message with the name of the filter function, 
	// which allows us to identify the filter that was registered at the moment 
	// when the exception was thrown. Each returns EXCEPTION_EXECUTE_HANDLER.
	// 
LONG __stdcall FirstFilter( EXCEPTION_POINTERS* pep ); 
LONG __stdcall SecondFilter( EXCEPTION_POINTERS* pep ); 
LONG __stdcall ThirdFilter( EXCEPTION_POINTERS* pep ); 

	// EnforceFilter function 
	// 
	// If bEnforce is "true", the function saves the first bytes of 
	// kernel32!SetUnhandledExceptionFilter into OriginalBytes and overwrites 
	// them with PatchBytes, a stub that returns 0 without registering 
	// anything, so all subsequent attempts to register a filter are rejected.
	// If bEnforce is "false", the bytes saved by EnforceFilter(true) are 
	// written back, restoring the original functionality of 
	// SetUnhandledExceptionFilter. 
	// 
	// Returns true on success, false on failure (e.g. the function cannot 
	// be located, is unreadable, or cannot be written).
	// 
bool EnforceFilter( bool bEnforce );

	// WriteMemory function 
	// 
	// This function copies Size bytes from pSource to pTarget. Before the 
	// copy it uses VirtualProtect to make the target memory writeable 
	// (PAGE_EXECUTE_READWRITE); afterwards it restores the previous protection.
	// 
	// Returns true on success, false on failure (e.g. a null pointer, 
	// a zero Size, an unreadable source, or a failed VirtualProtect call).
	// 
bool WriteMemory( BYTE* pTarget, const BYTE* pSource, DWORD Size );

	// Test function: prints a message and throws a C++ exception 
	// that no caller catches.
void FaultyFunc(); 


////////////////////////////////////////////////////////////////////////////////
// Global variables 
//

	// Patch for SetUnhandledExceptionFilter 
	// 
	// The patch makes SetUnhandledExceptionFilter return NULL ("no previous 
	// filter") without registering anything. The machine code must match 
	// the target architecture: bytes for the wrong CPU would corrupt the 
	// function instead of stubbing it out.
	// 
#if defined(_M_IX86)
	// xor eax, eax ; ret 4   (__stdcall: the callee pops its one pointer argument)
const BYTE PatchBytes[] = { 0x33, 0xC0, 0xC2, 0x04, 0x00 };
#elif defined(_M_X64)
	// xor eax, eax ; ret     (x64 convention: the callee does not clean the stack;
	//                         writing eax also zeroes the upper half of rax)
const BYTE PatchBytes[] = { 0x33, 0xC0, 0xC3 };
#elif defined(_M_ARM64)
	// mov x0, #0 ; ret       (little-endian A64 encodings 0xD2800000, 0xD65F03C0)
const BYTE PatchBytes[] = { 0x00, 0x00, 0x80, 0xD2, 0xC0, 0x03, 0x5F, 0xD6 };
#else
#error "No SetUnhandledExceptionFilter patch is defined for this architecture."
#endif

	// Original bytes at the beginning of SetUnhandledExceptionFilter.
	// Sized from the patch so that exactly the overwritten bytes are saved 
	// and later restored.
BYTE OriginalBytes[sizeof(PatchBytes)] = {0};


////////////////////////////////////////////////////////////////////////////////
// main() function
//

int main() 
{
	// Register the first filter - the one we want to keep registered.

	SetUnhandledExceptionFilter( FirstFilter ); 


	// Patch the beginning of SetUnhandledExceptionFilter. 
	// It will ensure that nobody else can register its own filter 
	// for unhandled exceptions and overwrite our filter (FirstFilter).

	if( !EnforceFilter( true ) )
	{
		_tprintf( _T("EnforceFilter(true) failed.\n") );
		return 1; // report the failure through the exit code
	}


	// Register other filters. Since SetUnhandledExceptionFilter is patched, 
	// these registrations have no effect, and FirstFilter remains 
	// the currently registered filter.

	SetUnhandledExceptionFilter( SecondFilter ); 
	SetUnhandledExceptionFilter( ThirdFilter ); 


	// Load msvcrt.dll to make it register its own custom filter 
	// for unhandled exceptions (_CxxUnhandledExceptionFilter).

	HMODULE hLib = LoadLibrary( _T("msvcrt.dll") );

	if( hLib == NULL )
	{
		// GetLastError() returns a DWORD (unsigned long), hence %lu.
		_tprintf( _T("LoadLibrary(msvcrt.dll) failed. Error: %lu\n"), GetLastError() );
	}


	// Simulate an exception 

		// Note: If FaultyFunc throws an unhandled C++ exception, 
		// and msvcrt!_CxxUnhandledExceptionFilter is the currently registered 
		// filter, the application will terminate with CRT runtime error. 
		// This will not happen if we used EnforceFilter(true) to ensure 
		// that our filter is always registered.
		//
		// FaultyFunc's exception is never caught, and FirstFilter returns 
		// EXCEPTION_EXECUTE_HANDLER, so the process is expected to terminate 
		// here; the cleanup below only shows how the original state would 
		// be restored.

	FaultyFunc();


	// Restore the functionality of SetUnhandledExceptionFilter 

	if( !EnforceFilter( false ) )
	{
		_tprintf( _T("EnforceFilter(false) failed.\n") );
		return 1; // report the failure through the exit code
	}


	// Free msvcrt.dll

	if( hLib != NULL && !FreeLibrary( hLib ) )
	{
		_tprintf( _T("FreeLibrary(msvcrt.dll) failed. Error: %lu\n"), GetLastError() );
	}

	return 0; 

}


////////////////////////////////////////////////////////////////////////////////
// Custom filters for unhandled exceptions
// 
// Each of these filters prints its own name and returns 
// EXCEPTION_EXECUTE_HANDLER; none of them calls a previously registered filter.
//

LONG __stdcall FirstFilter( EXCEPTION_POINTERS* pep ) 
{
	// The exception details are not examined: this filter only announces 
	// itself and asks for the handler to be executed.
	UNREFERENCED_PARAMETER( pep );

	const LONG Verdict = EXCEPTION_EXECUTE_HANDLER;

	_tprintf( _T("FirstFilter()...\n") );

	return Verdict;
}

LONG __stdcall SecondFilter( EXCEPTION_POINTERS* pep ) 
{
	// The exception details are not examined: this filter only announces 
	// itself and asks for the handler to be executed.
	UNREFERENCED_PARAMETER( pep );

	const LONG Verdict = EXCEPTION_EXECUTE_HANDLER;

	_tprintf( _T("SecondFilter()...\n") );

	return Verdict;
}

LONG __stdcall ThirdFilter( EXCEPTION_POINTERS* pep ) 
{
	// The exception details are not examined: this filter only announces 
	// itself and asks for the handler to be executed.
	UNREFERENCED_PARAMETER( pep );

	const LONG Verdict = EXCEPTION_EXECUTE_HANDLER;

	_tprintf( _T("ThirdFilter()...\n") );

	return Verdict;
}


////////////////////////////////////////////////////////////////////////////////
// EnforceFilter function 
// 

bool EnforceFilter( bool bEnforce )
{
	// Whether the patch is currently installed. Without this guard:
	//  - a second EnforceFilter(true) would save the patch itself as the 
	//    "original" bytes, making the function impossible to restore;
	//  - EnforceFilter(false) without a prior patch would write the still 
	//    zero-initialised OriginalBytes over the function's code.
	static bool s_bPatched = false;

	if( bEnforce == s_bPatched )
		return true; // already in the requested state - nothing to do

	DWORD ErrCode = 0; // last error of a failed API call (for inspection in a debugger)


	// Obtain the address of SetUnhandledExceptionFilter 

	HMODULE hLib = GetModuleHandle( _T("kernel32.dll") );

	if( hLib == NULL )
	{
		ErrCode = GetLastError();
		_ASSERTE( !_T("GetModuleHandle(kernel32.dll) failed.") );
		return false;
	}

	BYTE* pTarget = (BYTE*)GetProcAddress( hLib, "SetUnhandledExceptionFilter" );

	if( pTarget == 0 )
	{
		ErrCode = GetLastError();
		_ASSERTE( !_T("GetProcAddress(SetUnhandledExceptionFilter) failed.") );
		return false;
	}

	if( IsBadReadPtr( pTarget, sizeof(OriginalBytes) ) )
	{
		_ASSERTE( !_T("Target is unreadable.") );
		return false;
	}


	if( bEnforce )
	{
		// Save the original contents of SetUnhandledExceptionFilter 
		// (safe: we are not patched yet, so these are the genuine bytes)

		memcpy( OriginalBytes, pTarget, sizeof(OriginalBytes) );


		// Patch SetUnhandledExceptionFilter 

		if( !WriteMemory( pTarget, PatchBytes, sizeof(PatchBytes) ) )
			return false;
	}
	else
	{
		// Restore the original behavior of SetUnhandledExceptionFilter 
		// using the bytes saved when the patch was installed

		if( !WriteMemory( pTarget, OriginalBytes, sizeof(OriginalBytes) ) )
			return false;
	}


	// Success: remember the new state only after the write went through

	s_bPatched = bEnforce;

	return true;

}


////////////////////////////////////////////////////////////////////////////////
// WriteMemory function 
// 

bool WriteMemory( BYTE* pTarget, const BYTE* pSource, DWORD Size )
{
	DWORD ErrCode = 0; // last error of a failed API call (for inspection in a debugger)


	// Check parameters 

	if( pTarget == 0 )
	{
		_ASSERTE( !_T("Target address is null.") );
		return false;
	}

	if( pSource == 0 )
	{
		_ASSERTE( !_T("Source address is null.") );
		return false;
	}

	if( Size == 0 )
	{
		_ASSERTE( !_T("Source size is null.") );
		return false;
	}

	if( IsBadReadPtr( pSource, Size ) )
	{
		_ASSERTE( !_T("Source is unreadable.") );
		return false;
	}


	// Modify protection attributes of the target memory page 

	DWORD OldProtect = 0;

	if( !VirtualProtect( pTarget, Size, PAGE_EXECUTE_READWRITE, &OldProtect ) )
	{
		ErrCode = GetLastError();
		_ASSERTE( !_T("VirtualProtect() failed.") );
		return false;
	}


	// Write memory 

	memcpy( pTarget, pSource, Size );


	// The target may be executable code (EnforceFilter writes over the entry 
	// of SetUnhandledExceptionFilter). Windows requires FlushInstructionCache 
	// after modifying code so the CPU does not run stale instructions; this 
	// matters on architectures without coherent instruction caches (e.g. ARM64).

	BOOL bFlushed = FlushInstructionCache( GetCurrentProcess(), pTarget, Size );

	if( !bFlushed )
	{
		ErrCode = GetLastError();
		_ASSERTE( !_T("FlushInstructionCache() failed.") );
		// Fall through: the page protection must still be restored.
	}


	// Restore memory protection attributes of the target memory page.
	// Note: if this fails, the target memory has already been modified.

	DWORD Temp = 0;

	if( !VirtualProtect( pTarget, Size, OldProtect, &Temp ) )
	{
		ErrCode = GetLastError();
		_ASSERTE( !_T("VirtualProtect() failed.") );
		return false;
	}


	// Success only if the instruction cache was flushed as well 

	return bFlushed != FALSE;

}


////////////////////////////////////////////////////////////////////////////////
// Faulty function
//

void FaultyFunc()
{
	// Announce the crash, then raise a C++ exception carrying the int 1. 
	// No caller catches it, so it reaches the unhandled-exception filter.
	const int Payload = 1;

	_tprintf( _T("We will crash now...\n") );

	throw Payload;
}

