diff dep/animia/src/win32.cpp @ 62:4c6dd5999b39

*: update 1. updated animia 2. use widestrings for filesystem on Windows
author Paper <mrpapersonic@gmail.com>
date Sun, 01 Oct 2023 06:16:06 -0400
parents 6ff7aabeb9d7
children eab9e623eb84
line wrap: on
line diff
--- a/dep/animia/src/win32.cpp	Fri Sep 29 15:52:31 2023 -0400
+++ b/dep/animia/src/win32.cpp	Sun Oct 01 06:16:06 2023 -0400
@@ -1,21 +1,29 @@
+/**
+ * win32.cpp
+ *  - provides support for Windows clients
+ *
+ **/
 #include "win32.h"
-#include <windows.h>
-#include <winternl.h>
+#include <fileapi.h>
+#include <handleapi.h>
+#include <iostream>
 #include <libloaderapi.h>
 #include <ntdef.h>
 #include <psapi.h>
-#include <tlhelp32.h>
-#include <fileapi.h>
-#include <handleapi.h>
-#include <vector>
-#include <iostream>
+#include <shlobj.h>
+#include <stdexcept>
 #include <string>
+#include <stringapiset.h>
+#include <tlhelp32.h>
 #include <unordered_map>
-#include <stdexcept>
-#include <locale>
-#include <codecvt>
+#include <vector>
+#include <windows.h>
+#include <winternl.h>
+
 /* This file is noticably more complex than Unix and Linux, and that's because
-   there is no "simple" way to get the paths of a file. */
+   there is no "simple" way to get the paths of a file. In fact, this thing requires
+   you to use *internal functions* that can't even be linked to, hence why we have to
+   use GetProcAddress and such. What a mess. */
 
 #define SystemExtendedHandleInformation ((SYSTEM_INFORMATION_CLASS)0x40)
 constexpr NTSTATUS STATUS_INFO_LENGTH_MISMATCH = 0xC0000004UL;
@@ -24,74 +32,40 @@
 static unsigned short file_type_index = 0;
 
 struct SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX {
-	PVOID Object;
-	ULONG_PTR UniqueProcessId;
-	HANDLE HandleValue;
-	ACCESS_MASK GrantedAccess;
-	USHORT CreatorBackTraceIndex;
-	USHORT ObjectTypeIndex;
-	ULONG HandleAttributes;
-	ULONG Reserved;
+		PVOID Object;
+		ULONG_PTR UniqueProcessId; /* compared against the requested pid in get_open_files */
+		HANDLE HandleValue; /* handle as seen by the owning process; duplicated before use */
+		ACCESS_MASK GrantedAccess; /* access rights, checked by IsFileMaskOk */
+		USHORT CreatorBackTraceIndex;
+		USHORT ObjectTypeIndex; /* object type, passed to IsFileHandle */
+		ULONG HandleAttributes;
+		ULONG Reserved;
 };
 
 struct SYSTEM_HANDLE_INFORMATION_EX {
-	ULONG_PTR NumberOfHandles;
-	ULONG_PTR Reserved;
-	SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX Handles[1];
+		ULONG_PTR NumberOfHandles; /* how many entries follow in Handles */
+		ULONG_PTR Reserved;
+		SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX Handles[1]; /* first entry; readers walk past index 0 up to NumberOfHandles */
 };
 
 namespace Animia::Windows {
 
-std::vector<int> get_all_pids() {
-	std::vector<int> ret;
-    HANDLE hProcessSnap = CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
-    PROCESSENTRY32 pe32;
-	pe32.dwSize = sizeof(PROCESSENTRY32);
-
-    if (hProcessSnap == INVALID_HANDLE_VALUE)
-        return std::vector<int>();
-
-	if (!Process32First(hProcessSnap, &pe32))
-		return std::vector<int>();
-
-	ret.push_back(pe32.th32ProcessID);
-	while (Process32Next(hProcessSnap, &pe32)) {
-		ret.push_back(pe32.th32ProcessID);
-	}
-	// clean the snapshot object
-	CloseHandle(hProcessSnap);
-
-    return ret;
-}
-
-std::string get_process_name(int pid) {
-	HANDLE handle = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, FALSE, pid);
-	if (!handle)
-		return "";
-
-	std::string ret(MAX_PATH, '\0');
-	if (!GetModuleBaseNameA(handle, 0, &ret.front(), ret.size()))
-		throw std::runtime_error("GetModuleBaseNameA failed: " + std::to_string(GetLastError()));
-	CloseHandle(handle);
-
-	return ret;
-}
-
 /* All of this BS is required on Windows. Why? */
 
 HANDLE DuplicateHandle(HANDLE process_handle, HANDLE handle) {
 	HANDLE dup_handle = nullptr;
-	const bool result = ::DuplicateHandle(process_handle, handle,
-		::GetCurrentProcess(), &dup_handle, 0, false, DUPLICATE_SAME_ACCESS);
+	const bool result = /* copy `handle` from process_handle into the current process: same access, not inheritable */
+	    ::DuplicateHandle(process_handle, handle, ::GetCurrentProcess(), &dup_handle, 0, false, DUPLICATE_SAME_ACCESS);
 	return result ? dup_handle : nullptr;
 }
 
 PVOID GetNTDLLAddress(LPCSTR proc_name) {
-	return reinterpret_cast<PVOID>(GetProcAddress(GetModuleHandleA("ntdll.dll"), proc_name));
+	return reinterpret_cast<PVOID>(::GetProcAddress(::GetModuleHandleA("ntdll.dll"), proc_name)); /* address of the named ntdll.dll export; the lookup result is returned unchecked */
 }
 
 NTSTATUS QuerySystemInformation(SYSTEM_INFORMATION_CLASS cls, PVOID sysinfo, ULONG len, PULONG retlen) {
-	static const auto func = reinterpret_cast<decltype(::NtQuerySystemInformation)*>(GetNTDLLAddress("NtQuerySystemInformation"));
+	static const auto func = /* resolved once, on the first call; NOTE(review): not checked for null before it is called below */
+	    reinterpret_cast<decltype(::NtQuerySystemInformation)*>(GetNTDLLAddress("NtQuerySystemInformation"));
 	return func(cls, sysinfo, len, retlen);
 }
 
@@ -112,11 +86,14 @@
 		if (!(info = (SYSTEM_HANDLE_INFORMATION_EX*)malloc(cb *= 2)))
 			continue;
 
+		res.reserve(cb);
+
 		if (0 <= (status = QuerySystemInformation(SystemExtendedHandleInformation, info, cb, &cb))) {
 			if (ULONG_PTR handles = info->NumberOfHandles) {
 				SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX* entry = info->Handles;
 				do {
-					if (entry) res.push_back(*entry);
+					if (entry)
+						res.push_back(*entry);
 				} while (entry++, --handles);
 			}
 		}
@@ -132,51 +109,43 @@
 	return info;
 }
 
-std::string UnicodeStringToStdString(UNICODE_STRING string) {
-	ANSI_STRING result;
-	static const auto uc_to_ansi = reinterpret_cast<decltype(::RtlUnicodeStringToAnsiString)*>(GetNTDLLAddress("RtlUnicodeStringToAnsiString"));
-	uc_to_ansi(&result, &string, TRUE);
-	std::string ret = std::string(result.Buffer, result.Length);
-	static const auto free_ansi = reinterpret_cast<decltype(::RtlFreeAnsiString)*>(GetNTDLLAddress("RtlFreeAnsiString"));
-	free_ansi(&result);
+/* we're using UTF-8. originally this used just the ANSI versions of functions, but those
+   cannot represent every path. this way we get unicode in the way every single other OS does it */
+std::string UnicodeStringToUtf8(std::wstring string) {
+	const int size = ::WideCharToMultiByte(CP_UTF8, 0, string.data(), static_cast<int>(string.length()), NULL, 0, NULL, NULL);
+	std::string ret(size, '\0'); /* size is 0 for empty input, so use data() below: front() on an empty string is undefined */
+	::WideCharToMultiByte(CP_UTF8, 0, string.data(), static_cast<int>(string.length()), ret.data(), size, NULL, NULL);
+	return ret;
+}
+
+std::string UnicodeStringToUtf8(UNICODE_STRING string) {
+	/* UNICODE_STRING::Length is a byte count, so divide it down to UTF-16 code units before
+	   handing the text to the std::wstring overload; a null Buffer yields an empty string. */
+	const size_t chars = string.Buffer ? string.Length / sizeof(WCHAR) : 0;
+	return UnicodeStringToUtf8(chars ? std::wstring(string.Buffer, chars) : std::wstring());
+}
+
+std::wstring Utf8StringToUnicode(std::string string) { /* UTF-8 -> UTF-16; empty input gives an empty string */
+	const int size = ::MultiByteToWideChar(CP_UTF8, 0, string.data(), static_cast<int>(string.length()), NULL, 0);
+	std::wstring ret(size, L'\0'); /* may be empty, so write through data(): front() on an empty string is undefined */
+	::MultiByteToWideChar(CP_UTF8, 0, string.data(), static_cast<int>(string.length()), ret.data(), size);
 	return ret;
 }
 
 std::string GetHandleType(HANDLE handle) {
 	OBJECT_TYPE_INFORMATION info = QueryObjectTypeInfo(handle);
-	return UnicodeStringToStdString(info.TypeName);
+	return UnicodeStringToUtf8(info.TypeName); /* type name reported by QueryObjectTypeInfo, converted to UTF-8 */
 }
 
-/* GetFinalPathNameByHandleA literally just doesn't work */
 std::string GetFinalPathNameByHandle(HANDLE handle) {
 	std::wstring buffer;
 
 	int result = ::GetFinalPathNameByHandleW(handle, NULL, 0, FILE_NAME_NORMALIZED | VOLUME_NAME_DOS);
 	buffer.resize(result);
 	::GetFinalPathNameByHandleW(handle, &buffer.front(), buffer.size(), FILE_NAME_NORMALIZED | VOLUME_NAME_DOS);
-
-	std::wstring_convert<std::codecvt_utf8<wchar_t>, wchar_t> converter;
-
-	return converter.to_bytes(buffer);
-}
-
-std::string GetSystemDirectory() {
-	std::string windir = std::string(MAX_PATH, '\0');
-	::GetWindowsDirectoryA(&windir.front(), windir.length());
-	return "\\\\?\\" + windir;
-}
+	if (const auto end = buffer.find(L'\0'); end != std::wstring::npos) buffer.resize(end); /* trim at the terminator; a failed query leaves no terminator and resize(npos) would throw */
 
-/* This function is useless. I'm not exactly sure why, but whenever I try to compare the two
-   values, they both come up as different. I'm assuming it's just some Unicode BS I can't be bothered
-   to deal with. */
-bool IsSystemDirectory(const std::string& path) {
-	std::string path_l = path;
-	CharUpperBuffA(&path_l.front(), path_l.length());
-
-	std::string windir = GetSystemDirectory();
-	CharUpperBuffA(&windir.front(), windir.length());
-
-	return path_l.rfind(windir, 0) != std::string::npos;
+	return UnicodeStringToUtf8(buffer); /* UTF-8 path, or an empty string when the path could not be queried */
 }
 
 bool IsFileHandle(HANDLE handle, unsigned short object_type_index) {
@@ -192,37 +161,123 @@
 }
 
 bool IsFileMaskOk(ACCESS_MASK access_mask) {
-	/* this filters out any file handles that, legitimately,
-	   do not make sense (for what we're using it for)
-	
-	   shoutout to erengy for having these in Anisthesia */
-
 	if (!(access_mask & FILE_READ_DATA))
 		return false;
 
-	if ((access_mask & FILE_APPEND_DATA) ||
-	    (access_mask & FILE_WRITE_EA) ||
-	    (access_mask & FILE_WRITE_ATTRIBUTES))
+	if ((access_mask & FILE_APPEND_DATA) || (access_mask & FILE_WRITE_EA) || (access_mask & FILE_WRITE_ATTRIBUTES)) /* any of these write-style rights disqualifies the handle */
 		return false;
 
 	return true;
 }
 
 bool IsFilePathOk(const std::string& path) {
-	if (path.empty() || IsSystemDirectory(path))
+	if (path.empty()) /* NOTE(review): callers pass UTF-8 paths, but GetFileAttributesA below takes the ANSI code page; non-ASCII paths likely fail, confirm */
 		return false;
 
 	const auto file_attributes = GetFileAttributesA(path.c_str());
-	if ((file_attributes == INVALID_FILE_ATTRIBUTES) ||
-	    (file_attributes & FILE_ATTRIBUTE_DIRECTORY))
+	if ((file_attributes == INVALID_FILE_ATTRIBUTES) || (file_attributes & FILE_ATTRIBUTE_DIRECTORY)) /* reject paths whose attributes cannot be read, and directories */
 		return false;
 
 	return true;
 }
 
+std::string GetSystemDirectory() { /* UTF-8 path of the Windows folder, or "" if it cannot be queried */
+	PWSTR path_wch = nullptr;
+	const HRESULT hr = ::SHGetKnownFolderPath(FOLDERID_Windows, 0, NULL, &path_wch);
+	const std::wstring path_wstr(SUCCEEDED(hr) && path_wch ? path_wch : L""); /* never read the pointer after a failed call */
+	::CoTaskMemFree(path_wch); /* required whether or not the call succeeded; null is accepted */
+	return UnicodeStringToUtf8(path_wstr);
+}
+
+bool IsSystemFile(const std::string& path) { /* true if `path`, from offset 4 on, starts with the Windows folder path (case-insensitive) */
+	std::wstring path_w = Utf8StringToUnicode(path);
+	std::wstring windir_w = Utf8StringToUnicode(GetSystemDirectory());
+	::CharUpperBuffW(path_w.data(), static_cast<DWORD>(path_w.length())); /* data() is valid even when the string is empty */
+	::CharUpperBuffW(windir_w.data(), static_cast<DWORD>(windir_w.length()));
+	return !windir_w.empty() && path_w.length() >= 4 && path_w.compare(4, windir_w.length(), windir_w) == 0; /* offset 4 presumably skips a "\\?\" prefix; TODO confirm */
+}
+
+std::vector<int> get_all_pids() {
+	/* Returns the id of every process in a toolhelp snapshot; empty if no snapshot could be taken. */
+	std::vector<int> ret;
+	HANDLE hProcessSnap = ::CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
+	if (hProcessSnap == INVALID_HANDLE_VALUE)
+		return ret;
+
+	PROCESSENTRY32 pe32;
+	pe32.dwSize = sizeof(PROCESSENTRY32);
+
+	/* a failed Process32First just skips the loop, so the snapshot is closed on every path */
+	if (::Process32First(hProcessSnap, &pe32)) {
+		do {
+			ret.push_back(static_cast<int>(pe32.th32ProcessID));
+		} while (::Process32Next(hProcessSnap, &pe32));
+	}
+
+	// clean the snapshot object
+	::CloseHandle(hProcessSnap);
+	return ret;
+}
+
+std::string get_process_name(int pid) {
+	/* Returns the UTF-8 base name of the process's main module, or "" if it cannot be read. */
+	HANDLE handle = ::OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, FALSE, pid);
+	if (!handle)
+		return "";
+
+	std::wstring ret;
+	/* grow the buffer until the name fits, and stop as soon as it does (or the query fails)
+	   instead of asking again at every size up to the limit */
+	for (unsigned long size = 256; size < 32768; size *= 2) {
+		ret.assign(size, L'\0');
+		const unsigned long ret_size = ::GetModuleBaseNameW(handle, 0, ret.data(), size);
+		/* 0 means failure (empty result); a full buffer means the name may be truncated */
+		ret.resize(ret_size);
+		if (ret_size < size)
+			break;
+	}
+
+	::CloseHandle(handle);
+
+	return UnicodeStringToUtf8(ret);
+}
+
 std::vector<std::string> get_open_files(int pid) {
-	std::unordered_map<int, std::vector<std::string>> map = get_all_open_files();
-	return map[pid];
+	/* Lists the paths of the disk files that process `pid` currently has open for reading. */
+	std::vector<std::string> ret;
+	const std::vector<SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX> info = GetSystemHandleInformation();
+	/* open the target once and release it at the end, rather than once per handle */
+	const HANDLE proc = ::OpenProcess(PROCESS_DUP_HANDLE, false, pid);
+	if (!proc)
+		return ret;
+
+	for (const auto& h : info) {
+		if (h.UniqueProcessId != static_cast<ULONG_PTR>(pid))
+			continue;
+		if (!IsFileHandle(nullptr, h.ObjectTypeIndex) || !IsFileMaskOk(h.GrantedAccess))
+			continue;
+
+		const HANDLE handle = DuplicateHandle(proc, h.HandleValue); /* our own copy of the handle */
+		if (!handle)
+			continue;
+		if (::GetFileType(handle) == FILE_TYPE_DISK) {
+			const std::string path = GetFinalPathNameByHandle(handle);
+			if (IsFilePathOk(path))
+				ret.push_back(path);
+		}
+		::CloseHandle(handle); /* the duplicate is ours to close on every path */
+	}
+	::CloseHandle(proc);
+	return ret;
+}
+
+std::vector<std::string> filter_system_files(const std::vector<std::string>& source) {
+	/* keep, in order, every path that IsSystemFile does not flag */
+	std::vector<std::string> kept;
+	for (const auto& candidate : source)
+		if (!IsSystemFile(candidate))
+			kept.push_back(candidate);
+
+	return kept;
 }
 
 std::unordered_map<int, std::vector<std::string>> get_all_open_files() {
@@ -255,4 +310,4 @@
 	return map;
 }
 
-}
+} // namespace Animia::Windows