/*
 * fd/win32.cc: support for windows
 *
 * this file is noticeably more complex than the *nix one, and that's because
 * there is no "simple" way to get the paths of a file. In fact, this requires
 * you to use *internal functions* that can't even be linked to, hence why we have to
 * use GetProcAddress and such. what a mess.
 *
 * Speaking of which, because this file uses internal functions of the OS, it is not
 * even guaranteed to work far into the future. however, just like with macOS, these
 * things have stayed the same since Vista, so if Microsoft *really* wants compatibility
 * then they're pretty much forced to keep this the same anyway.
 */
#include "animone/fd/win32.h"
#include "animone.h"
#include "animone/util/win32.h"

#include <cstdlib>
#include <memory>
#include <optional>
#include <set>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <fileapi.h>
#include <handleapi.h>
#include <libloaderapi.h>
#include <ntdef.h>
#include <psapi.h>
#include <shlobj.h>
#include <stringapiset.h>
#include <tlhelp32.h>
#include <windows.h>
#include <winternl.h>

/* SystemExtendedHandleInformation is only available in NT 5.1+ (XP and higher) and provides information for
 * 32-bit PIDs, unlike SystemHandleInformation.
 *
 * it is spelled out here as a raw value (0x40) cast to SYSTEM_INFORMATION_CLASS; presumably the
 * SDK enum does not expose it -- TODO confirm against the headers in use. */
static constexpr SYSTEM_INFORMATION_CLASS SystemExtendedHandleInformation = static_cast<SYSTEM_INFORMATION_CLASS>(0x40);
/* status that GetSystemHandleInformation() treats as "buffer too small, try again with a bigger one" */
static constexpr NTSTATUS STATUS_INFO_LENGTH_MISMATCH = 0xC0000004UL;

/* one element of SYSTEM_HANDLE_INFORMATION_EX::Handles, i.e. one open handle somewhere on the system.
 *
 * this is filled in by NtQuerySystemInformation(SystemExtendedHandleInformation), so the field
 * order and types presumably have to mirror the (undocumented) NT layout exactly -- do not
 * reorder or resize anything in here.
 *
 * the fields this file actually reads are UniqueProcessId, HandleValue, GrantedAccess and
 * ObjectTypeIndex. */
struct SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX {
	PVOID Object;
	ULONG_PTR UniqueProcessId; /* compared against the pids given to EnumerateOpenFiles() */
	HANDLE HandleValue;        /* handle value as seen by the owning process; must be duplicated before use */
	ACCESS_MASK GrantedAccess; /* checked by IsFileMaskOk() */
	USHORT CreatorBackTraceIndex;
	USHORT ObjectTypeIndex;    /* checked by IsFileHandle() */
	ULONG HandleAttributes;
	ULONG Reserved;
};

/* header of the buffer that GetSystemHandleInformation() hands to NtQuerySystemInformation.
 *
 * Handles is declared with one element but is indexed up to NumberOfHandles entries; the
 * entries live in the same allocation, directly after this header. */
struct SYSTEM_HANDLE_INFORMATION_EX {
	ULONG_PTR NumberOfHandles; /* how many entries of Handles are valid */
	ULONG_PTR Reserved;
	SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX Handles[1];
};

namespace animone::internal::win32 {

/* thin wrapper around the two ntdll.dll functions this file needs.
 *
 * the functions are looked up at runtime with GetProcAddress in the constructor. if either
 * the module or a function cannot be found, the matching Query* method does not crash on a
 * null function pointer; it returns a failure NTSTATUS instead, so callers' NT_SUCCESS()
 * checks handle it like any other error. */
class Ntdll {
public:
	Ntdll() {
		ntdll = ::GetModuleHandleW(L"ntdll.dll");
		if (!ntdll)
			return; /* leave both function pointers null */

		nt_query_system_information = reinterpret_cast<decltype(::NtQuerySystemInformation)*>(
		    ::GetProcAddress(ntdll, "NtQuerySystemInformation"));
		nt_query_object = reinterpret_cast<decltype(::NtQueryObject)*>(::GetProcAddress(ntdll, "NtQueryObject"));
	}

	/* forwards to NtQuerySystemInformation; returns a failure status if it could not be resolved */
	NTSTATUS QuerySystemInformation(SYSTEM_INFORMATION_CLASS cls, PVOID sysinfo, ULONG len,
	                                PULONG retlen) {
		if (!nt_query_system_information)
			return kStatusUnsuccessful;

		return nt_query_system_information(cls, sysinfo, len, retlen);
	}

	/* forwards to NtQueryObject; returns a failure status if it could not be resolved */
	NTSTATUS QueryObject(HANDLE handle, OBJECT_INFORMATION_CLASS cls, PVOID objinf, ULONG objinflen, PULONG retlen) {
		if (!nt_query_object)
			return kStatusUnsuccessful;

		return nt_query_object(handle, cls, objinf, objinflen, retlen);
	}

private:
	/* numeric value of STATUS_UNSUCCESSFUL; spelled out so we don't depend on <ntstatus.h> */
	static constexpr NTSTATUS kStatusUnsuccessful = static_cast<NTSTATUS>(0xC0000001UL);

	HMODULE ntdll = nullptr;
	decltype(::NtQuerySystemInformation)* nt_query_system_information = nullptr;
	decltype(::NtQueryObject)* nt_query_object = nullptr;
};

/* the one shared Ntdll instance; its constructor resolves the ntdll.dll entry points.
 * used by GetSystemHandleInformation() and GetHandleType().
 * NOTE(review): unlike the helper functions in this file this is not `static`, so it has
 * external linkage -- confirm whether that is intended. */
Ntdll ntdll;

/* copies `handle`, which belongs to the process behind `process_handle`, into the current
 * process with the same access rights (DUPLICATE_SAME_ACCESS).
 *
 * returns the new handle, or nullptr if ::DuplicateHandle fails. */
static HANDLE DuplicateHandle(HANDLE process_handle, HANDLE handle) {
	HANDLE duplicated = nullptr;

	const bool ok =
	    ::DuplicateHandle(process_handle, handle, ::GetCurrentProcess(), &duplicated, 0, false, DUPLICATE_SAME_ACCESS);
	if (!ok)
		return nullptr;

	return duplicated;
}

static std::vector<SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX> GetSystemHandleInformation() {
	/* we should really put a cap on this */
	ULONG cb = 1 << 19;
	NTSTATUS status = STATUS_NO_MEMORY;
	std::unique_ptr<SYSTEM_HANDLE_INFORMATION_EX> info;

	do {
		info.reset(reinterpret_cast<SYSTEM_HANDLE_INFORMATION_EX*>(malloc(cb *= 2)));
		if (!info)
			continue;

		status = ntdll.QuerySystemInformation(SystemExtendedHandleInformation, info.get(), cb, &cb);
	} while (status == STATUS_INFO_LENGTH_MISMATCH);

	if (!NT_SUCCESS(status))
		return {};

	std::vector<SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX> res;

	ULONG_PTR handles = info->NumberOfHandles;
	if (!handles)
		return {};

	res.reserve(handles);

	SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX* entry = info->Handles;
	do {
		if (entry)
			res.push_back(*(entry++));
	} while (--handles);

	return res;
}

/* returns the object type name of `handle` (e.g. L"File"), as reported by
 * NtQueryObject(ObjectTypeInformation), or an empty string if the query fails.
 *
 * the name string is written into the same buffer as the OBJECT_TYPE_INFORMATION header,
 * so the buffer has to be larger than the struct alone. */
static std::wstring GetHandleType(HANDLE handle) {
	/* room for the header plus a generous type name; grown once below if that isn't enough */
	std::vector<unsigned char> buffer(sizeof(OBJECT_TYPE_INFORMATION) + 128 * sizeof(WCHAR));

	NTSTATUS status = STATUS_INFO_LENGTH_MISMATCH;
	for (int attempt = 0; attempt < 2; ++attempt) {
		ULONG needed = 0;
		status = ntdll.QueryObject(handle, ObjectTypeInformation, buffer.data(), static_cast<ULONG>(buffer.size()),
		                           &needed);
		if (status != STATUS_INFO_LENGTH_MISMATCH || needed <= buffer.size())
			break;

		buffer.assign(needed, 0);
	}

	if (!NT_SUCCESS(status))
		return std::wstring();

	const auto* info = reinterpret_cast<const OBJECT_TYPE_INFORMATION*>(buffer.data());
	if (!info->TypeName.Buffer)
		return std::wstring();

	/* UNICODE_STRING::Length is in bytes, std::wstring wants a character count */
	return std::wstring(info->TypeName.Buffer, info->TypeName.Length / sizeof(WCHAR));
}

/* returns the normalized DOS-style path ("\\?\C:\...") of the file behind `handle`,
 * or an empty string if the path cannot be retrieved.
 *
 * the returned string has exactly the length of the path; it carries no trailing NUL. */
static std::wstring GetFinalPathNameByHandle(HANDLE handle) {
	constexpr DWORD flags = FILE_NAME_NORMALIZED | VOLUME_NAME_DOS;

	/* with a null buffer this returns the required size *including* the terminator, or 0 on failure */
	const DWORD required = ::GetFinalPathNameByHandleW(handle, nullptr, 0, flags);
	if (!required)
		return std::wstring();

	std::wstring buffer(required, L'\0');

	/* on success this returns the length *without* the terminator; a value >= required means
	 * the path grew between the two calls and did not fit */
	const DWORD written = ::GetFinalPathNameByHandleW(handle, &buffer[0], required, flags);
	if (!written || written >= required)
		return std::wstring();

	buffer.resize(written);
	return buffer;
}

/* ------------------------------------------------------------------- */

static bool GetSystemDirectory(std::wstring& str) {
	PWSTR path_wch;

	if (FAILED(::SHGetFolderPathW(NULL, CSIDL_WINDOWS, NULL, SHGFP_TYPE_CURRENT, &path_wch)))
		return false;

	str.assign(path_wch);

	::CoTaskMemFree(path_wch);
	return true;
}

static bool IsSystemDirectory(const std::string& path) {
	return IsSystemDirectory(ToWstring(path));
}

static bool IsSystemDirectory(std::wstring path) {
	std::wstring windir;
	if (!GetSystemDirectory(windir))
		return false;

	::CharUpperBuffW(&path.front(), path.length());
	::CharUpperBuffW(&windir.front(), windir.length());

	// XXX wtf is 4?
	return path.find(windir) == 4;
}

/* decides whether a handle table entry with type index `object_type_index` is a file.
 *
 * the type index that means "File" is learned at runtime, because it's not guaranteed
 * to be (and isn't) constant between different versions of Windows:
 *   - once it is known, only the index is compared and `handle` is ignored
 *   - while it is unknown and `handle` is null, there is nothing to inspect, so the entry
 *     is given the benefit of the doubt and true is returned
 *   - while it is unknown and `handle` is valid, the type name of the handle is queried;
 *     if it is "File", the index is remembered for all later calls
 *
 * NOTE(review): the cached index is a plain function-local static with no locking --
 * confirm this is never called from several threads at once. */
static bool IsFileHandle(HANDLE handle, unsigned short object_type_index) {
	static std::optional<unsigned short> file_type_index;

	if (file_type_index.has_value())
		return object_type_index == file_type_index.value();

	if (!handle)
		return true;

	if (GetHandleType(handle) != L"File")
		return false;

	/* std::optional::reset() takes no arguments; assigning is how a value gets stored */
	file_type_index = object_type_index;
	return true;
}

/* filters handles by the access they were opened with: the mask must contain
 * FILE_READ_DATA and must not contain any of FILE_APPEND_DATA, FILE_WRITE_EA or
 * FILE_WRITE_ATTRIBUTES. */
static bool IsFileMaskOk(ACCESS_MASK access_mask) {
	constexpr ACCESS_MASK rejected_bits = FILE_APPEND_DATA | FILE_WRITE_EA | FILE_WRITE_ATTRIBUTES;

	const bool can_read = (access_mask & FILE_READ_DATA) != 0;
	const bool has_rejected_bit = (access_mask & rejected_bits) != 0;

	return can_read && !has_rejected_bit;
}

static bool IsFilePathOk(const std::wstring& path) {
	if (path.empty())
		return false;

	if (IsSystemDirectory(path))
		return false;

	const auto file_attributes = GetFileAttributesW(path.c_str());
	if ((file_attributes == INVALID_FILE_ATTRIBUTES) || (file_attributes & FILE_ATTRIBUTE_DIRECTORY))
		return false;

	return true;
}

/* ------------------------------------------------------------------- */

/* returns the full path of the executable of process `process_id` as UTF-8,
 * or an empty string if the process cannot be opened or queried. */
static std::string GetProcessPath(DWORD process_id) {
	// If we try to open a SYSTEM process, this function fails and the last error
	// code is ERROR_ACCESS_DENIED.
	//
	// Note that if we requested PROCESS_QUERY_INFORMATION access right instead
	// of PROCESS_QUERY_LIMITED_INFORMATION, this function would fail when used
	// to open an elevated process.
	Handle process_handle(::OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, process_id));

	if (!process_handle.get())
		return std::string();

	// Paths are not limited to MAX_PATH, so start there and double the buffer
	// for as long as the API says it was too small, up to a sane upper bound.
	constexpr DWORD max_capacity = 1UL << 16;

	std::wstring buffer;
	for (DWORD capacity = MAX_PATH; capacity <= max_capacity; capacity *= 2) {
		buffer.assign(capacity, L'\0');
		DWORD length = capacity;

		// Note that this function requires Windows Vista or above. You may use
		// GetProcessImageFileName or GetModuleFileNameEx on earlier versions.
		if (::QueryFullProcessImageNameW(process_handle.get(), 0, &buffer[0], &length)) {
			// on success `length` is the number of characters written, without the terminator
			buffer.resize(length);
			return ToUtf8String(buffer);
		}

		if (::GetLastError() != ERROR_INSUFFICIENT_BUFFER)
			break;
	}

	return std::string();
}

/* returns everything after the last '/' or '\' in `path`.
 * if `path` contains no separator it is returned unchanged; if it ends in a separator
 * the result is empty. */
static std::string GetFilenameFromPath(const std::string& path) {
	/* narrow literal and std::string::npos: `path` is a std::string, not a std::wstring */
	const auto pos = path.find_last_of("/\\");
	return pos != std::string::npos ? path.substr(pos + 1) : path;
}

/* a process path is acceptable if it is non-empty and IsSystemDirectory() does not claim it */
static bool VerifyProcessPath(const std::string& path) {
	if (path.empty())
		return false;

	return !IsSystemDirectory(path);
}

/* returns false for an empty name and for any name that util::EqualStrings() considers
 * equal to an entry of the list of well-known system executables below; true otherwise. */
static bool VerifyProcessFilename(const std::string& name) {
	static const std::set<std::string> invalid_names = {
	    // System files
	    "explorer.exe",   // Windows Explorer
	    "taskeng.exe",    // Task Scheduler Engine
	    "taskhost.exe",   // Host Process for Windows Tasks
	    "taskhostex.exe", // Host Process for Windows Tasks
	    "taskmgr.exe",    // Task Manager
	    "services.exe",   // Service Control Manager
	};

	if (name.empty())
		return false;

	/* walk the list by hand; the first match rejects the name */
	auto it = invalid_names.begin();
	while (it != invalid_names.end()) {
		if (util::EqualStrings(name, *it))
			return false;
		++it;
	}

	return true;
}

/* ------------------------------------------------------------------- */
/* extern functions */

/* looks up the executable filename of process `pid` and stores it in `name`.
 *
 * returns false if the path cannot be retrieved or is rejected by VerifyProcessPath()
 * (`name` is left untouched in that case), or if the filename is rejected by
 * VerifyProcessFilename() (`name` has already been overwritten in that case). */
bool GetProcessName(pid_t pid, std::string& name) {
	const std::string full_path = GetProcessPath(pid);

	if (full_path.empty())
		return false;

	if (!VerifyProcessPath(full_path))
		return false;

	name = GetFilenameFromPath(full_path);
	return VerifyProcessFilename(name);
}

/* calls `process_proc` once for every running process that GetProcessName() accepts.
 *
 * returns false if `process_proc` is empty, if the process snapshot cannot be taken or
 * read, or as soon as `process_proc` returns false (which stops the enumeration).
 * returns true once every process has been visited. */
bool EnumerateOpenProcesses(process_proc_t process_proc) {
	/* same guard as EnumerateOpenFiles(): there is no point in enumerating for nobody */
	if (!process_proc)
		return false;

	Handle process_snap(::CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0));
	if (process_snap.get() == INVALID_HANDLE_VALUE)
		return false;

	/* dwSize has to be set before the first call; zero everything else */
	PROCESSENTRY32 pe32 = {};
	pe32.dwSize = sizeof(PROCESSENTRY32);

	if (!::Process32First(process_snap.get(), &pe32))
		return false;

	do {
		std::string name;
		if (!GetProcessName(pe32.th32ProcessID, name))
			continue; /* skipped process; `continue` still advances via Process32Next below */

		if (!process_proc({.platform = ExecutablePlatform::Win32, .pid = pe32.th32ProcessID, .comm = name}))
			return false;
	} while (::Process32Next(process_snap.get(), &pe32));

	return true;
}

bool EnumerateOpenFiles(const std::set<pid_t>& pids, open_file_proc_t open_file_proc) {
	if (!open_file_proc)
		return false;

	std::unordered_map<pid_t, Handle> proc_handles;

	for (const pid_t& pid : pids) {
		const HANDLE handle = ::OpenProcess(PROCESS_DUP_HANDLE, false, pid);
		if (handle != INVALID_HANDLE_VALUE)
			proc_handles[pid] = Handle(handle);
	}

	if (proc_handles.empty())
		return false;

	std::vector<SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX> info = GetSystemHandleInformation();

	for (const auto& h : info) {
		const pid_t pid = h.UniqueProcessId;
		if (!pids.count(pid))
			continue;

		if (!IsFileHandle(nullptr, h.ObjectTypeIndex))
			continue;

		if (!IsFileMaskOk(h.GrantedAccess))
			continue;

		Handle handle(DuplicateHandle(proc_handles[pid].get(), h.HandleValue));
		if (handle.get() == INVALID_HANDLE_VALUE)
			continue;

		if (::GetFileType(handle.get()) != FILE_TYPE_DISK)
			continue;

		const std::wstring path = GetFinalPathNameByHandle(handle.get());
		if (!IsFilePathOk(path))
			continue;

		if (!open_file_proc({pid, ToUtf8String(path)}))
			return false;
	}

	return true;
}

} // namespace animone::internal::win32
