diff dep/animia/src/fd/win32.cc @ 152:8700806c2cc2

dep/animia: awesome new breaking changes! I'm so tired
author Paper <mrpapersonic@gmail.com>
date Wed, 15 Nov 2023 02:34:59 -0500
parents 54744a48a7d7
children cdf79282d647
line wrap: on
line diff
--- a/dep/animia/src/fd/win32.cc	Tue Nov 14 16:31:21 2023 -0500
+++ b/dep/animia/src/fd/win32.cc	Wed Nov 15 02:34:59 2023 -0500
@@ -1,9 +1,11 @@
-/**
- * win32.cpp
- *  - provides support for Windows clients
- *
- **/
+/*
+** win32.cpp
+**  - provides support for Windows clients
+**
+*/
 #include "animia/fd/win32.h"
+#include "animia/util/win32.h"
+#include "animia.h"
 
 #include <stdexcept>
 #include <string>
@@ -26,10 +28,15 @@
    you to use *internal functions* that can't even be linked to, hence why we have to
    use GetProcAddress and such. What a mess. */
 
-#define SystemExtendedHandleInformation ((SYSTEM_INFORMATION_CLASS)0x40)
+/* SystemExtendedHandleInformation is only available in NT 5.1+ (XP and higher) and provides information for
+   32-bit PIDs, unlike SystemHandleInformation */
+constexpr SYSTEM_INFORMATION_CLASS SystemExtendedHandleInformation = static_cast<SYSTEM_INFORMATION_CLASS>(0x40);
+
+/* more constants not in winternl.h */
 constexpr NTSTATUS STATUS_INFO_LENGTH_MISMATCH = 0xC0000004UL;
-constexpr NTSTATUS STATUS_SUCCESS = 0x00000000UL;
 
+/* this is filled in at runtime because it's not guaranteed to be (and isn't)
+   constant between different versions of Windows */
 static unsigned short file_type_index = 0;
 
 struct SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX {
@@ -77,19 +84,23 @@
 	std::vector<SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX> res;
 	/* we should really put a cap on this */
 	ULONG cb = 1 << 19;
-	NTSTATUS status = STATUS_SUCCESS;
-	SYSTEM_HANDLE_INFORMATION_EX* info;
 
-	do {
+	for (NTSTATUS status = STATUS_INFO_LENGTH_MISMATCH; status == STATUS_INFO_LENGTH_MISMATCH; ) {
+		/* why are we doing this? */
 		status = STATUS_NO_MEMORY;
 
-		if (!(info = (SYSTEM_HANDLE_INFORMATION_EX*)malloc(cb *= 2)))
+		SYSTEM_HANDLE_INFORMATION_EX* info = (SYSTEM_HANDLE_INFORMATION_EX*)malloc(cb *= 2);
+		if (!info)
 			continue;
 
 		res.reserve(cb);
 
-		if (0 <= (status = QuerySystemInformation(SystemExtendedHandleInformation, info, cb, &cb))) {
-			if (ULONG_PTR handles = info->NumberOfHandles) {
+		status = QuerySystemInformation(SystemExtendedHandleInformation, info, cb, &cb);
+		if (0 <= status) {
+			ULONG_PTR handles = info->NumberOfHandles;
+			if (handles) {
+				res.reserve(res.size() + handles);
+
 				SYSTEM_HANDLE_TABLE_ENTRY_INFO_EX* entry = info->Handles;
 				do {
 					if (entry)
@@ -97,8 +108,9 @@
 				} while (entry++, --handles);
 			}
 		}
+
 		free(info);
-	} while (status == STATUS_INFO_LENGTH_MISMATCH);
+	}
 
 	return res;
 }
@@ -109,32 +121,9 @@
 	return info;
 }
 
-/* we're using UTF-8. originally, I had used just the ANSI versions of functions, but that
-   sucks massive dick. this way we get unicode in the way every single other OS does it */
-static std::string UnicodeStringToUtf8(std::wstring string) {
-	unsigned long size = ::WideCharToMultiByte(CP_UTF8, 0, string.c_str(), string.length(), NULL, 0, NULL, NULL);
-	std::string ret = std::string(size, '\0');
-	::WideCharToMultiByte(CP_UTF8, 0, string.c_str(), string.length(), &ret.front(), ret.length(), NULL, NULL);
-	return ret;
-}
-
-static std::string UnicodeStringToUtf8(UNICODE_STRING string) {
-	unsigned long size = ::WideCharToMultiByte(CP_UTF8, 0, string.Buffer, string.Length, NULL, 0, NULL, NULL);
-	std::string ret = std::string(size, '\0');
-	::WideCharToMultiByte(CP_UTF8, 0, string.Buffer, string.Length, &ret.front(), ret.length(), NULL, NULL);
-	return ret;
-}
-
-static std::wstring Utf8StringToUnicode(std::string string) {
-	unsigned long size = ::MultiByteToWideChar(CP_UTF8, 0, string.c_str(), string.length(), NULL, 0);
-	std::wstring ret = std::wstring(size, L'\0');
-	::MultiByteToWideChar(CP_UTF8, 0, string.c_str(), string.length(), &ret.front(), ret.length());
-	return ret;
-}
-
 static std::string GetHandleType(HANDLE handle) {
 	OBJECT_TYPE_INFORMATION info = QueryObjectTypeInfo(handle);
-	return UnicodeStringToUtf8(info.TypeName);
+	return ToUtf8String(info.TypeName);
 }
 
 static std::string GetFinalPathNameByHandle(HANDLE handle) {
@@ -145,7 +134,7 @@
 	::GetFinalPathNameByHandleW(handle, &buffer.front(), buffer.size(), FILE_NAME_NORMALIZED | VOLUME_NAME_DOS);
 	buffer.resize(buffer.find('\0'));
 
-	return UnicodeStringToUtf8(buffer);
+	return ToUtf8String(buffer);
 }
 
 static bool IsFileHandle(HANDLE handle, unsigned short object_type_index) {
@@ -170,27 +159,11 @@
 	return true;
 }
 
-static std::string GetSystemDirectory() {
-	PWSTR path_wch;
-	SHGetKnownFolderPath(FOLDERID_Windows, 0, NULL, &path_wch);
-	std::wstring path_wstr(path_wch);
-	CoTaskMemFree(path_wch);
-	return UnicodeStringToUtf8(path_wstr);
-}
-
-static bool IsSystemFile(const std::string& path) {
-	std::wstring path_w = Utf8StringToUnicode(path);
-	CharUpperBuffW(&path_w.front(), path_w.length());
-	std::wstring windir_w = Utf8StringToUnicode(GetSystemDirectory());
-	CharUpperBuffW(&windir_w.front(), windir_w.length());
-	return path_w.find(windir_w) == 4;
-}
-
 static bool IsFilePathOk(const std::string& path) {
 	if (path.empty())
 		return false;
 
-	if (IsSystemFile(path))
+	if (IsSystemDirectory(path))
 		return false;
 
 	const auto file_attributes = GetFileAttributesA(path.c_str());
@@ -200,7 +173,7 @@
 	return true;
 }
 
-bool Win32FdTools::GetAllPids(std::set<pid_t>& pids) {
+bool Win32FdTools::EnumerateOpenProcesses(process_proc_t process_proc) {
 	HANDLE hProcessSnap = ::CreateToolhelp32Snapshot(TH32CS_SNAPPROCESS, 0);
 	if (hProcessSnap == INVALID_HANDLE_VALUE)
 		return false;
@@ -211,39 +184,23 @@
 	if (!::Process32First(hProcessSnap, &pe32))
 		return false;
 
-	pids.insert(pe32.th32ProcessID);
+	if (!process_proc({pe32.th32ProcessID, pe32.szExeFile}))
+		return false;
 
 	while (::Process32Next(hProcessSnap, &pe32))
-		pids.insert(pe32.th32ProcessID);
+		if (!process_proc({pe32.th32ProcessID, pe32.szExeFile}))
+			return false;
 
 	::CloseHandle(hProcessSnap);
 
 	return true;
 }
 
-bool Win32FdTools::GetProcessName(pid_t pid, std::string& result) {
-	unsigned long ret_size = 0; // size given by GetModuleBaseNameW
-	Handle handle(::OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, FALSE, pid));
-	if (handle.get() == INVALID_HANDLE_VALUE)
+/* this could be changed to being a callback, but... I'm too lazy right now :) */
+bool Win32FdTools::EnumerateOpenFiles(const std::set<pid_t>& pids, open_file_proc_t open_file_proc) {
+	if (!open_file_proc)
 		return false;
 
-	/* agh... */
-	std::wstring ret(256, L'\0');
-	for (; ret.length() < 32768; ret.resize(ret.length() * 2)) {
-		if (!(ret_size = ::GetModuleBaseNameW(handle.get(), 0, &ret.front(), ret.length()))) {
-			return false;
-		} else if (ret.length() > ret_size) {
-			ret.resize(ret.find(L'\0'));
-			result = UnicodeStringToUtf8(ret);
-			break;
-		}
-	}
-
-	return true;
-}
-
-/* this could be changed to being a callback, but... I'm too lazy right now :) */
-bool Win32FdTools::EnumerateOpenFiles(const std::set<pid_t>& pids, std::vector<OpenFile>& files) {
 	std::unordered_map<pid_t, Handle> proc_handles;
 
 	for (const pid_t& pid : pids) {
@@ -279,7 +236,8 @@
 		if (!IsFilePathOk(path))
 			continue;
 
-		files.push_back({pid, path});
+		if (!open_file_proc({pid, path}))
+			return false;
 	}
 
 	return true;