/**
 * win32.cpp
 *  - provides support for Windows clients
 *
 **/
#include "animia/fd/win32.h"

#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <fileapi.h>
#include <handleapi.h>
#include <libloaderapi.h>
#include <ntdef.h>
#include <psapi.h>
#include <shlobj.h>
#include <stringapiset.h>
#include <tlhelp32.h>
#include <windows.h>
#include <winternl.h>

/* This file is noticeably more complex than the Unix and Linux implementations because
   Windows has no "simple" way to get the paths of the files a process has open. In fact,
   it requires *internal functions* that can't even be linked against, which is why we
   have to use GetProcAddress and such. What a mess. */

/* Information class passed to NtQuerySystemInformation by GetSystemHandleInformation()
   to retrieve a SYSTEM_HANDLE_INFORMATION_EX. Defined here as a cast of the raw
   value 0x40. */
#define SystemExtendedHandleInformation ((SYSTEM_INFORMATION_CLASS)0x40)
/* NTSTATUS values this file compares against. */
constexpr NTSTATUS STATUS_INFO_LENGTH_MISMATCH = 0xC0000004UL;
constexpr NTSTATUS STATUS_SUCCESS = 0x00000000UL;

/* Object type index recognised as "File"; 0 means it has not been determined yet.
   Only IsFileHandle() reads and writes it. */
static unsigned short file_type_index = 0;

/* One entry of the handle table returned for SystemExtendedHandleInformation.
   NOTE(review): the layout presumably mirrors the structure filled in by ntdll, so
   fields must not be reordered or resized -- confirm against a trusted definition. */
struct SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX {
		PVOID Object;
		ULONG_PTR UniqueProcessId;     /* compared against the requested pids */
		HANDLE HandleValue;            /* handle value inside the owning process */
		ACCESS_MASK GrantedAccess;     /* checked by IsFileMaskOk() */
		USHORT CreatorBackTraceIndex;
		USHORT ObjectTypeIndex;        /* checked by IsFileHandle() */
		ULONG HandleAttributes;
		ULONG Reserved;
};

/* Header of the buffer filled in for SystemExtendedHandleInformation. `Handles` is
   declared with one element but is read as an array of `NumberOfHandles` entries
   (see GetSystemHandleInformation()). */
struct SYSTEM_HANDLE_INFORMATION_EX {
		ULONG_PTR NumberOfHandles;
		ULONG_PTR Reserved;
		SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX Handles[1];
};

namespace animia::internal::win32 {

/* Duplicates `handle`, owned by the process behind `process_handle`, into the current
   process with the same access rights. Returns nullptr when duplication fails. */
static HANDLE DuplicateHandle(HANDLE process_handle, HANDLE handle) {
	HANDLE duplicate = nullptr;

	const bool ok = ::DuplicateHandle(process_handle, handle, ::GetCurrentProcess(), &duplicate, 0, false,
	                                  DUPLICATE_SAME_ACCESS);
	if (!ok)
		return nullptr;

	return duplicate;
}

/* Looks up the export `proc_name` in the already loaded ntdll.dll and returns its
   address as an untyped pointer. */
static PVOID GetNTDLLAddress(LPCSTR proc_name) {
	const auto ntdll = ::GetModuleHandleA("ntdll.dll");
	const auto address = ::GetProcAddress(ntdll, proc_name);
	return reinterpret_cast<PVOID>(address);
}

/* Forwards to ntdll's NtQuerySystemInformation. The entry point is resolved through
   GetNTDLLAddress() on the first call and reused afterwards. */
static NTSTATUS QuerySystemInformation(SYSTEM_INFORMATION_CLASS cls, PVOID sysinfo, ULONG len, PULONG retlen) {
	using QueryFn = decltype(::NtQuerySystemInformation)*;

	static const QueryFn query = reinterpret_cast<QueryFn>(GetNTDLLAddress("NtQuerySystemInformation"));

	return query(cls, sysinfo, len, retlen);
}

/* Forwards to ntdll's NtQueryObject. The entry point is resolved through
   GetNTDLLAddress() on the first call and reused afterwards. */
static NTSTATUS QueryObject(HANDLE handle, OBJECT_INFORMATION_CLASS cls, PVOID objinf, ULONG objinflen, PULONG retlen) {
	using QueryFn = decltype(::NtQueryObject)*;

	static const QueryFn query = reinterpret_cast<QueryFn>(GetNTDLLAddress("NtQueryObject"));

	return query(handle, cls, objinf, objinflen, retlen);
}

static std::vector<SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX> GetSystemHandleInformation() {
	std::vector<SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX> res;
	/* we should really put a cap on this */
	ULONG cb = 1 << 19;
	NTSTATUS status = STATUS_SUCCESS;
	SYSTEM_HANDLE_INFORMATION_EX* info;

	do {
		status = STATUS_NO_MEMORY;

		if (!(info = (SYSTEM_HANDLE_INFORMATION_EX*)malloc(cb *= 2)))
			continue;

		res.reserve(cb);

		if (0 <= (status = QuerySystemInformation(SystemExtendedHandleInformation, info, cb, &cb))) {
			if (ULONG_PTR handles = info->NumberOfHandles) {
				SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX* entry = info->Handles;
				do {
					if (entry)
						res.push_back(*entry);
				} while (entry++, --handles);
			}
		}
		free(info);
	} while (status == STATUS_INFO_LENGTH_MISMATCH);

	return res;
}

/* Queries ObjectTypeInformation for `handle`.

   Returns a zero-initialised structure (null, zero-length TypeName) if the query
   fails, so callers never see uninitialised data.

   NOTE(review): NtQueryObject appears to store the type-name characters directly
   after the fixed-size structure, so a buffer of exactly sizeof(info) is likely
   always too small and this call likely always fails -- confirm, and query with a
   larger buffer if the type name is actually needed. */
static OBJECT_TYPE_INFORMATION QueryObjectTypeInfo(HANDLE handle) {
	OBJECT_TYPE_INFORMATION info = {};

	const NTSTATUS status = QueryObject(handle, ObjectTypeInformation, &info, sizeof(info), nullptr);

	/* negative NTSTATUS values are errors; discard whatever was partially written */
	if (status < 0)
		info = {};

	return info;
}

/* Strings are exchanged as UTF-8. The ANSI versions of the Windows functions were
   used originally, but they proved inadequate; converting to and from UTF-16 gives
   us Unicode the same way every other OS does it. */

/* Converts a UTF-16 string to UTF-8. Returns an empty string for empty input or if
   the conversion fails. */
static std::string UnicodeStringToUtf8(const std::wstring& string) {
	/* WideCharToMultiByte rejects a zero-length input, and there would be no
	   output buffer to hand over anyway */
	if (string.empty())
		return std::string();

	const int length = static_cast<int>(string.length());

	/* first call only computes the required number of bytes */
	const int size = ::WideCharToMultiByte(CP_UTF8, 0, string.data(), length, nullptr, 0, nullptr, nullptr);
	if (size <= 0)
		return std::string();

	std::string ret(static_cast<std::string::size_type>(size), '\0');
	::WideCharToMultiByte(CP_UTF8, 0, string.data(), length, &ret[0], size, nullptr, nullptr);
	return ret;
}

/* Converts the contents of an NT UNICODE_STRING to UTF-8. Returns an empty string
   when the buffer is null, the string is empty, or the conversion fails. */
static std::string UnicodeStringToUtf8(UNICODE_STRING string) {
	/* UNICODE_STRING::Length counts bytes, WideCharToMultiByte wants characters;
	   passing the byte count would read past the end of the buffer */
	const int length = static_cast<int>(string.Length / sizeof(WCHAR));
	if (!string.Buffer || length <= 0)
		return std::string();

	/* first call only computes the required number of bytes */
	const int size = ::WideCharToMultiByte(CP_UTF8, 0, string.Buffer, length, nullptr, 0, nullptr, nullptr);
	if (size <= 0)
		return std::string();

	std::string ret(static_cast<std::string::size_type>(size), '\0');
	::WideCharToMultiByte(CP_UTF8, 0, string.Buffer, length, &ret[0], size, nullptr, nullptr);
	return ret;
}

/* Converts a UTF-8 string to UTF-16. Returns an empty string for empty input or if
   the conversion fails. */
static std::wstring Utf8StringToUnicode(const std::string& string) {
	/* MultiByteToWideChar rejects a zero-length input, and there would be no
	   output buffer to hand over anyway */
	if (string.empty())
		return std::wstring();

	const int length = static_cast<int>(string.length());

	/* first call only computes the required number of UTF-16 code units */
	const int size = ::MultiByteToWideChar(CP_UTF8, 0, string.data(), length, nullptr, 0);
	if (size <= 0)
		return std::wstring();

	std::wstring ret(static_cast<std::wstring::size_type>(size), L'\0');
	::MultiByteToWideChar(CP_UTF8, 0, string.data(), length, &ret[0], size);
	return ret;
}

/* Returns the object type name of `handle` as UTF-8, built from the TypeName that
   QueryObjectTypeInfo() reports. */
static std::string GetHandleType(HANDLE handle) {
	const OBJECT_TYPE_INFORMATION type_info = QueryObjectTypeInfo(handle);

	return UnicodeStringToUtf8(type_info.TypeName);
}

/* Returns the normalised DOS-style path of the file behind `handle` as UTF-8, or an
   empty string if the path cannot be retrieved. */
static std::string GetFinalPathNameByHandle(HANDLE handle) {
	const DWORD flags = FILE_NAME_NORMALIZED | VOLUME_NAME_DOS;

	/* With an empty buffer the call reports the required size, terminator
	   included; 0 means failure. */
	const DWORD required = ::GetFinalPathNameByHandleW(handle, nullptr, 0, flags);
	if (!required)
		return std::string();

	std::wstring buffer(required, L'\0');

	/* On success the return value is the length written, terminator excluded. A
	   value that does not fit means the path changed in between; treat that as a
	   failure rather than returning a truncated path. */
	const DWORD written = ::GetFinalPathNameByHandleW(handle, &buffer[0], required, flags);
	if (!written || written >= required)
		return std::string();

	buffer.resize(written);

	return UnicodeStringToUtf8(buffer);
}

/* Decides whether a handle-table entry refers to a file object.

   - Once the "File" type index is known, `object_type_index` is compared to it.
   - Otherwise a null `handle` is accepted without inspection.
   - Otherwise the type name of `handle` is queried; on "File" the index is
     remembered in `file_type_index` for later calls. */
static bool IsFileHandle(HANDLE handle, unsigned short object_type_index) {
	if (file_type_index != 0)
		return file_type_index == object_type_index;

	if (handle == nullptr)
		return true;

	if (GetHandleType(handle) != "File")
		return false;

	file_type_index = object_type_index;
	return true;
}

/* Accepts an access mask only if it grants FILE_READ_DATA and none of
   FILE_APPEND_DATA, FILE_WRITE_EA or FILE_WRITE_ATTRIBUTES. */
static bool IsFileMaskOk(ACCESS_MASK access_mask) {
	const ACCESS_MASK rejected_bits = FILE_APPEND_DATA | FILE_WRITE_EA | FILE_WRITE_ATTRIBUTES;

	const bool can_read = (access_mask & FILE_READ_DATA) != 0;
	const bool has_rejected_bit = (access_mask & rejected_bits) != 0;

	return can_read && !has_rejected_bit;
}

static std::string GetSystemDirectory() {
	PWSTR path_wch;
	SHGetKnownFolderPath(FOLDERID_Windows, 0, NULL, &path_wch);
	std::wstring path_wstr(path_wch);
	CoTaskMemFree(path_wch);
	return UnicodeStringToUtf8(path_wstr);
}

static bool IsSystemFile(const std::string& path) {
	std::wstring path_w = Utf8StringToUnicode(path);
	CharUpperBuffW(&path_w.front(), path_w.length());
	std::wstring windir_w = Utf8StringToUnicode(GetSystemDirectory());
	CharUpperBuffW(&windir_w.front(), windir_w.length());
	return path_w.find(windir_w) == 4;
}

static bool IsFilePathOk(const std::string& path) {
	if (path.empty())
		return false;

	if (IsSystemFile(path))
		return false;

	const auto file_attributes = GetFileAttributesA(path.c_str());
	if ((file_attributes == INVALID_FILE_ATTRIBUTES) || (file_attributes & FILE_ATTRIBUTE_DIRECTORY))
		return false;

	return true;
}

bool Win32FdTools::GetAllPids(std::set<pid_t>& pids) {
	HANDLE hProcessSnap = ::CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
	if (hProcessSnap == INVALID_HANDLE_VALUE)
		return false;

	PROCESSENTRY32 pe32;
	pe32.dwSize = sizeof(PROCESSENTRY32);

	if (!::Process32First(hProcessSnap, &pe32))
		return false;

	pids.insert(pe32.th32ProcessID);

	while (::Process32Next(hProcessSnap, &pe32))
		pids.insert(pe32.th32ProcessID);

	::CloseHandle(hProcessSnap);

	return true;
}

bool Win32FdTools::GetProcessName(pid_t pid, std::string& result) {
	unsigned long ret_size = 0; // size given by GetModuleBaseNameW
	Handle handle(::OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, FALSE, pid));
	if (handle.get() == INVALID_HANDLE_VALUE)
		return false;

	/* agh... */
	std::wstring ret(256, L'\0');
	for (; ret.length() < 32768; ret.resize(ret.length() * 2)) {
		if (!(ret_size = ::GetModuleBaseNameW(handle.get(), 0, &ret.front(), ret.length()))) {
			return false;
		} else if (ret.length() > ret_size) {
			ret.resize(ret.find(L'\0'));
			result = UnicodeStringToUtf8(ret);
			break;
		}
	}

	return true;
}

/* this could be changed to being a callback, but... I'm too lazy right now :) */
bool Win32FdTools::EnumerateOpenFiles(const std::set<pid_t>& pids, std::vector<OpenFile>& files) {
	std::unordered_map<pid_t, Handle> proc_handles;

	for (const pid_t& pid : pids) {
		const HANDLE handle = ::OpenProcess(PROCESS_DUP_HANDLE, false, pid);
		if (handle != INVALID_HANDLE_VALUE)
			proc_handles[pid] = Handle(handle);
	}

	if (proc_handles.empty())
		return false;

	std::vector<SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX> info = GetSystemHandleInformation();

	for (const auto& h : info) {
		const pid_t pid = h.UniqueProcessId;
		if (!pids.count(pid))
			continue;

		if (!IsFileHandle(nullptr, h.ObjectTypeIndex))
			continue;

		if (!IsFileMaskOk(h.GrantedAccess))
			continue;

		Handle handle(DuplicateHandle(proc_handles[pid].get(), h.HandleValue));
		if (handle.get() == INVALID_HANDLE_VALUE)
			continue;

		if (GetFileType(handle.get()) != FILE_TYPE_DISK)
			continue;

		const std::string path = GetFinalPathNameByHandle(handle.get());
		if (!IsFilePathOk(path))
			continue;

		OpenFile file;
		file.pid = pid;
		file.path = path;

		files.push_back(file);
	}

	return true;
}

} // namespace animia::internal::win32
