refactor, perf improvement

This commit is contained in:
Akash Mozumdar 2019-02-04 15:18:47 -05:00
parent c78747c228
commit e6805a2be3
7 changed files with 88 additions and 72 deletions

View File

@ -17,4 +17,4 @@ set(gui_src
add_executable(${PROJECT_NAME} WIN32 ${gui_src}) add_executable(${PROJECT_NAME} WIN32 ${gui_src})
target_link_libraries(${PROJECT_NAME} Qt5::Widgets winhttp) target_link_libraries(${PROJECT_NAME} Qt5::Widgets winhttp)
install_qt5_libs(${PROJECT_NAME}) # can be commented out for consecutive builds #install_qt5_libs(${PROJECT_NAME}) # can be commented out for consecutive builds

View File

@ -10,21 +10,16 @@ namespace
class ProcessRecord class ProcessRecord
{ {
public: public:
inline static Host::ProcessEventCallback OnConnect, OnDisconnect;
ProcessRecord(DWORD processId, HANDLE pipe) : ProcessRecord(DWORD processId, HANDLE pipe) :
processId(processId), processId(processId),
pipe(pipe), pipe(pipe),
mappedFile(OpenFileMappingW(FILE_MAP_READ, FALSE, (ITH_SECTION_ + std::to_wstring(processId)).c_str())), mappedFile(OpenFileMappingW(FILE_MAP_READ, FALSE, (ITH_SECTION_ + std::to_wstring(processId)).c_str())),
view(*(const TextHook(*)[MAX_HOOK])MapViewOfFile(mappedFile, FILE_MAP_READ, 0, 0, HOOK_SECTION_SIZE / 2)), // jichi 1/16/2015: Changed to half to hook section size view(*(const TextHook(*)[MAX_HOOK])MapViewOfFile(mappedFile, FILE_MAP_READ, 0, 0, HOOK_SECTION_SIZE / 2)), // jichi 1/16/2015: Changed to half to hook section size
viewMutex(ITH_HOOKMAN_MUTEX_ + std::to_wstring(processId)) viewMutex(ITH_HOOKMAN_MUTEX_ + std::to_wstring(processId))
{ {}
OnConnect(processId);
}
~ProcessRecord() ~ProcessRecord()
{ {
OnDisconnect(processId);
UnmapViewOfFile(view); UnmapViewOfFile(view);
} }
@ -59,15 +54,21 @@ namespace
{ {
return std::hash<int64_t>()(tp.processId + tp.addr) + std::hash<int64_t>()(tp.ctx + tp.ctx2); return std::hash<int64_t>()(tp.processId + tp.addr) + std::hash<int64_t>()(tp.ctx + tp.ctx2);
} }
ThreadSafe<std::unordered_map<ThreadParam, std::unique_ptr<TextThread>, Functor<HashThreadParam>>, std::recursive_mutex> textThreadsByParams; ThreadSafe<std::unordered_map<ThreadParam, TextThread, Functor<HashThreadParam>>, std::recursive_mutex> textThreadsByParams;
ThreadSafe<std::unordered_map<DWORD, ProcessRecord>, std::recursive_mutex> processRecordsByIds; ThreadSafe<std::unordered_map<DWORD, ProcessRecord>, std::recursive_mutex> processRecordsByIds;
Host::ProcessEventHandler OnConnect, OnDisconnect;
Host::ThreadEventHandler OnCreate, OnDestroy;
void RemoveThreads(std::function<bool(ThreadParam)> removeIf) void RemoveThreads(std::function<bool(ThreadParam)> removeIf)
{ {
std::vector<std::unique_ptr<TextThread>> removedThreads; // delay destruction until after lock is released std::vector<TextThread*> threadsToRemove;
auto[lock, textThreadsByParams] = ::textThreadsByParams.operator->(); std::for_each(textThreadsByParams->begin(), textThreadsByParams->end(), [&](auto& it) { if (removeIf(it.first)) threadsToRemove.push_back(&it.second); });
for (auto it = textThreadsByParams->begin(); it != textThreadsByParams->end(); removeIf(it->first) ? it = textThreadsByParams->erase(it) : ++it) for (auto thread : threadsToRemove)
if (removeIf(it->first)) removedThreads.emplace_back(std::move(it->second)); {
OnDestroy(*thread);
textThreadsByParams->erase(thread->tp);
}
} }
void CreatePipe() void CreatePipe()
@ -89,6 +90,7 @@ namespace
DWORD bytesRead, processId; DWORD bytesRead, processId;
ReadFile(hookPipe, &processId, sizeof(processId), &bytesRead, nullptr); ReadFile(hookPipe, &processId, sizeof(processId), &bytesRead, nullptr);
processRecordsByIds->try_emplace(processId, processId, hostPipe); processRecordsByIds->try_emplace(processId, processId, hostPipe);
OnConnect(processId);
CreatePipe(); CreatePipe();
@ -110,13 +112,18 @@ namespace
default: default:
{ {
auto tp = *(ThreadParam*)buffer; auto tp = *(ThreadParam*)buffer;
if (textThreadsByParams->count(tp) == 0) textThreadsByParams->insert({ tp, std::make_unique<TextThread>(tp, Host::GetHookParam(tp)) }); if (textThreadsByParams->count(tp) == 0)
textThreadsByParams->at(tp)->Push(buffer + sizeof(tp), bytesRead - sizeof(tp)); {
TextThread& created = textThreadsByParams->try_emplace(tp, tp, Host::GetHookParam(tp)).first->second;
OnCreate(created);
}
textThreadsByParams->find(tp)->second.Push(buffer + sizeof(tp), bytesRead - sizeof(tp));
} }
break; break;
} }
RemoveThreads([&](ThreadParam tp) { return tp.processId == processId; }); RemoveThreads([&](ThreadParam tp) { return tp.processId == processId; });
OnDisconnect(processId);
processRecordsByIds->erase(processId); processRecordsByIds->erase(processId);
}).detach(); }).detach();
} }
@ -124,22 +131,27 @@ namespace
namespace Host namespace Host
{ {
void Start(ProcessEventCallback OnConnect, ProcessEventCallback OnDisconnect, TextThread::EventCallback OnCreate, TextThread::EventCallback OnDestroy, TextThread::OutputCallback Output) void Start(ProcessEventHandler Connect, ProcessEventHandler Disconnect, ThreadEventHandler Create, ThreadEventHandler Destroy, TextThread::OutputCallback Output)
{ {
ProcessRecord::OnConnect = OnConnect; OnConnect = Connect;
ProcessRecord::OnDisconnect = OnDisconnect; OnDisconnect = Disconnect;
TextThread::OnCreate = OnCreate; OnCreate = [Create](TextThread& thread) { Create(thread); thread.Start(); };
TextThread::OnDestroy = OnDestroy; OnDestroy = [Destroy](TextThread& thread) { thread.Stop(); Destroy(thread); };
TextThread::Output = Output; TextThread::Output = Output;
processRecordsByIds->try_emplace(console.processId, console.processId, INVALID_HANDLE_VALUE); processRecordsByIds->try_emplace(console.processId, console.processId, INVALID_HANDLE_VALUE);
textThreadsByParams->insert({ console, std::make_unique<TextThread>(console, HookParam{}, CONSOLE) }); OnConnect(console.processId);
textThreadsByParams->insert({ Host::clipboard, std::make_unique<TextThread>(Host::clipboard, HookParam{}, CLIPBOARD) }); textThreadsByParams->try_emplace(console, console, HookParam{}, CONSOLE);
OnCreate(GetThread(console));
textThreadsByParams->try_emplace(clipboard, clipboard, HookParam{}, CLIPBOARD);
OnCreate(GetThread(clipboard));
CreatePipe(); CreatePipe();
SetWindowsHookExW(WH_GETMESSAGE, [](int statusCode, WPARAM wParam, LPARAM lParam) SetWindowsHookExW(WH_GETMESSAGE, [](int statusCode, WPARAM wParam, LPARAM lParam)
{ {
if (statusCode == HC_ACTION && wParam == PM_REMOVE && ((MSG*)lParam)->message == WM_CLIPBOARDUPDATE) if (statusCode == HC_ACTION && wParam == PM_REMOVE && ((MSG*)lParam)->message == WM_CLIPBOARDUPDATE)
if (auto text = Util::GetClipboardText()) Host::GetThread(Host::clipboard)->AddSentence(text.value()); if (auto text = Util::GetClipboardText()) GetThread(clipboard).AddSentence(text.value());
return CallNextHookEx(NULL, statusCode, wParam, lParam); return CallNextHookEx(NULL, statusCode, wParam, lParam);
}, NULL, GetCurrentThreadId()); }, NULL, GetCurrentThreadId());
} }
@ -200,13 +212,13 @@ namespace Host
return processRecordsByIds->at(tp.processId).GetHook(tp.addr).hp; return processRecordsByIds->at(tp.processId).GetHook(tp.addr).hp;
} }
TextThread* GetThread(ThreadParam tp) TextThread& GetThread(ThreadParam tp)
{ {
return textThreadsByParams->at(tp).get(); return textThreadsByParams->at(tp);
} }
void AddConsoleOutput(std::wstring text) void AddConsoleOutput(std::wstring text)
{ {
GetThread(console)->AddSentence(text); GetThread(console).AddSentence(text);
} }
} }

View File

@ -5,8 +5,9 @@
namespace Host namespace Host
{ {
using ProcessEventCallback = std::function<void(DWORD)>; using ProcessEventHandler = std::function<void(DWORD)>;
void Start(ProcessEventCallback OnConnect, ProcessEventCallback OnDisconnect, TextThread::EventCallback OnCreate, TextThread::EventCallback OnDestroy, TextThread::OutputCallback Output); using ThreadEventHandler = std::function<void(TextThread&)>;
void Start(ProcessEventHandler Connect, ProcessEventHandler Disconnect, ThreadEventHandler Create, ThreadEventHandler Destroy, TextThread::OutputCallback Output);
bool InjectProcess(DWORD processId, DWORD timeout = 5000); bool InjectProcess(DWORD processId, DWORD timeout = 5000);
void DetachProcess(DWORD processId); void DetachProcess(DWORD processId);
@ -14,7 +15,7 @@ namespace Host
HookParam GetHookParam(ThreadParam tp); HookParam GetHookParam(ThreadParam tp);
TextThread* GetThread(ThreadParam tp); TextThread& GetThread(ThreadParam tp);
void AddConsoleOutput(std::wstring text); void AddConsoleOutput(std::wstring text);
inline int defaultCodepage = SHIFT_JIS; inline int defaultCodepage = SHIFT_JIS;

View File

@ -9,14 +9,16 @@ TextThread::TextThread(ThreadParam tp, HookParam hp, std::optional<std::wstring>
name(name.value_or(Util::StringToWideString(hp.name).value())), name(name.value_or(Util::StringToWideString(hp.name).value())),
tp(tp), tp(tp),
hp(hp) hp(hp)
{}
void TextThread::Start()
{ {
CreateTimerQueueTimer(&timer, NULL, [](void* This, BOOLEAN) { ((TextThread*)This)->Flush(); }, this, 10, 10, WT_EXECUTELONGFUNCTION); CreateTimerQueueTimer(&timer, NULL, [](void* This, BOOLEAN) { ((TextThread*)This)->Flush(); }, this, 10, 10, WT_EXECUTELONGFUNCTION);
OnCreate(this);
} }
TextThread::~TextThread() void TextThread::Stop()
{ {
OnDestroy(this); timer = NULL;
} }
void TextThread::AddSentence(const std::wstring& sentence) void TextThread::AddSentence(const std::wstring& sentence)
@ -40,7 +42,7 @@ void TextThread::Push(const BYTE* data, int len)
lastPushTime = GetTickCount(); lastPushTime = GetTickCount();
if (std::all_of(buffer.begin(), buffer.end(), [&](wchar_t c) { return repeatingChars.count(c) > 0; })) buffer.clear(); if (std::all_of(buffer.begin(), buffer.end(), [&](wchar_t c) { return repeatingChars.count(c) > 0; })) buffer.clear();
if (Util::RemoveRepetition(buffer)) // repetition detected, which means the entire sentence has already been received if (Util::RemoveRepetition(buffer)) // sentence repetition detected, which means the entire sentence has already been received
{ {
repeatingChars = std::unordered_set(buffer.begin(), buffer.end()); repeatingChars = std::unordered_set(buffer.begin(), buffer.end());
AddSentence(buffer); AddSentence(buffer);
@ -53,11 +55,13 @@ void TextThread::Flush()
std::vector<std::wstring> sentences; std::vector<std::wstring> sentences;
queuedSentences->swap(sentences); queuedSentences->swap(sentences);
for (auto& sentence : sentences) for (auto& sentence : sentences)
if (Output(this, sentence)) storage->append(sentence); if (Output(*this, sentence)) storage->append(sentence);
std::scoped_lock lock(bufferMutex); std::scoped_lock lock(bufferMutex);
if (buffer.empty()) return; if (buffer.empty()) return;
if (buffer.size() < maxBufferSize && GetTickCount() - lastPushTime < flushDelay) return; if (buffer.size() > maxBufferSize || GetTickCount() - lastPushTime > flushDelay)
{
AddSentence(buffer); AddSentence(buffer);
buffer.clear(); buffer.clear();
}
} }

View File

@ -6,17 +6,16 @@
class TextThread class TextThread
{ {
public: public:
using EventCallback = std::function<void(TextThread*)>; using OutputCallback = std::function<bool(TextThread&, std::wstring&)>;
using OutputCallback = std::function<bool(TextThread*, std::wstring&)>;
inline static EventCallback OnCreate, OnDestroy;
inline static OutputCallback Output; inline static OutputCallback Output;
inline static int flushDelay = 400; // flush every 400ms by default inline static int flushDelay = 400; // flush every 400ms by default
inline static int maxBufferSize = 1000; inline static int maxBufferSize = 1000;
TextThread(ThreadParam tp, HookParam hp, std::optional<std::wstring> name = {}); TextThread(ThreadParam tp, HookParam hp, std::optional<std::wstring> name = {});
~TextThread();
void Start();
void Stop();
void AddSentence(const std::wstring& sentence); void AddSentence(const std::wstring& sentence);
void Push(const BYTE* data, int len); void Push(const BYTE* data, int len);
@ -38,5 +37,5 @@ private:
DWORD lastPushTime = 0; DWORD lastPushTime = 0;
ThreadSafe<std::vector<std::wstring>> queuedSentences; ThreadSafe<std::vector<std::wstring>> queuedSentences;
struct TimerDeleter { void operator()(HANDLE h) { DeleteTimerQueueTimer(NULL, h, INVALID_HANDLE_VALUE); } }; struct TimerDeleter { void operator()(HANDLE h) { DeleteTimerQueueTimer(NULL, h, INVALID_HANDLE_VALUE); } };
AutoHandle<TimerDeleter> timer = NULL; // this needs to be last so it's destructed first AutoHandle<TimerDeleter> timer = NULL;
}; };

View File

@ -49,9 +49,9 @@ MainWindow::MainWindow(QWidget *parent) :
Host::Start( Host::Start(
[this](DWORD processId) { ProcessConnected(processId); }, [this](DWORD processId) { ProcessConnected(processId); },
[this](DWORD processId) { ProcessDisconnected(processId); }, [this](DWORD processId) { ProcessDisconnected(processId); },
[this](TextThread* thread) { ThreadAdded(thread); }, [this](TextThread& thread) { ThreadAdded(thread); },
[this](TextThread* thread) { ThreadRemoved(thread); }, [this](TextThread& thread) { ThreadRemoved(thread); },
[this](TextThread* thread, std::wstring& output) { return SentenceReceived(thread, output); } [this](TextThread& thread, std::wstring& output) { return SentenceReceived(thread, output); }
); );
Host::AddConsoleOutput(ABOUT); Host::AddConsoleOutput(ABOUT);
@ -104,7 +104,7 @@ void MainWindow::ProcessConnected(DWORD processId)
auto hookList = std::find_if(allProcesses.rbegin(), allProcesses.rend(), [&](QString hookList) { return hookList.contains(process); }); auto hookList = std::find_if(allProcesses.rbegin(), allProcesses.rend(), [&](QString hookList) { return hookList.contains(process); });
if (hookList != allProcesses.rend()) if (hookList != allProcesses.rend())
for (auto hookInfo : hookList->split(" , ")) for (auto hookInfo : hookList->split(" , "))
if (auto hp = Util::ParseCode(S(hookInfo))) QMetaObject::invokeMethod(this, [processId, hp] { Host::InsertHook(processId, hp.value()); }); if (auto hp = Util::ParseCode(S(hookInfo))) Host::InsertHook(processId, hp.value());
else swscanf_s(S(hookInfo).c_str(), L"|%I64d:%I64d:%[^\n]", &savedThreadCtx.first, &savedThreadCtx.second, savedThreadCode, ARRAYSIZE(savedThreadCode)); else swscanf_s(S(hookInfo).c_str(), L"|%I64d:%I64d:%[^\n]", &savedThreadCtx.first, &savedThreadCtx.second, savedThreadCode, ARRAYSIZE(savedThreadCode));
} }
@ -116,11 +116,11 @@ void MainWindow::ProcessDisconnected(DWORD processId)
}, Qt::BlockingQueuedConnection); }, Qt::BlockingQueuedConnection);
} }
void MainWindow::ThreadAdded(TextThread* thread) void MainWindow::ThreadAdded(TextThread& thread)
{ {
std::wstring threadCode = Util::GenerateCode(thread->hp, thread->tp.processId); std::wstring threadCode = Util::GenerateCode(thread.hp, thread.tp.processId);
QString ttString = TextThreadString(thread) + S(thread->name) + " (" + S(threadCode) + ")"; QString ttString = TextThreadString(thread) + S(thread.name) + " (" + S(threadCode) + ")";
bool savedMatch = savedThreadCtx.first == thread->tp.ctx && savedThreadCtx.second == thread->tp.ctx2 && savedThreadCode == threadCode; bool savedMatch = savedThreadCtx.first == thread.tp.ctx && savedThreadCtx.second == thread.tp.ctx2 && savedThreadCode == threadCode;
if (savedMatch) savedThreadCtx.first = savedThreadCtx.second = savedThreadCode[0] = 0; if (savedMatch) savedThreadCtx.first = savedThreadCtx.second = savedThreadCode[0] = 0;
QMetaObject::invokeMethod(this, [this, ttString, savedMatch] QMetaObject::invokeMethod(this, [this, ttString, savedMatch]
{ {
@ -129,7 +129,7 @@ void MainWindow::ThreadAdded(TextThread* thread)
}); });
} }
void MainWindow::ThreadRemoved(TextThread* thread) void MainWindow::ThreadRemoved(TextThread& thread)
{ {
QString ttString = TextThreadString(thread); QString ttString = TextThreadString(thread);
QMetaObject::invokeMethod(this, [this, ttString] QMetaObject::invokeMethod(this, [this, ttString]
@ -140,7 +140,7 @@ void MainWindow::ThreadRemoved(TextThread* thread)
}, Qt::BlockingQueuedConnection); }, Qt::BlockingQueuedConnection);
} }
bool MainWindow::SentenceReceived(TextThread* thread, std::wstring& sentence) bool MainWindow::SentenceReceived(TextThread& thread, std::wstring& sentence)
{ {
if (DispatchSentenceToExtensions(sentence, GetMiscInfo(thread))) if (DispatchSentenceToExtensions(sentence, GetMiscInfo(thread)))
{ {
@ -160,14 +160,14 @@ bool MainWindow::SentenceReceived(TextThread* thread, std::wstring& sentence)
return false; return false;
} }
QString MainWindow::TextThreadString(TextThread* thread) QString MainWindow::TextThreadString(TextThread& thread)
{ {
return QString("%1:%2:%3:%4:%5: ").arg( return QString("%1:%2:%3:%4:%5: ").arg(
QString::number(thread->handle, 16), QString::number(thread.handle, 16),
QString::number(thread->tp.processId, 16), QString::number(thread.tp.processId, 16),
QString::number(thread->tp.addr, 16), QString::number(thread.tp.addr, 16),
QString::number(thread->tp.ctx, 16), QString::number(thread.tp.ctx, 16),
QString::number(thread->tp.ctx2, 16) QString::number(thread.tp.ctx2, 16)
).toUpper(); ).toUpper();
} }
@ -182,16 +182,16 @@ DWORD MainWindow::GetSelectedProcessId()
return ui->processCombo->currentText().split(":")[0].toULong(nullptr, 16); return ui->processCombo->currentText().split(":")[0].toULong(nullptr, 16);
} }
std::unordered_map<const char*, int64_t> MainWindow::GetMiscInfo(TextThread* thread) std::unordered_map<const char*, int64_t> MainWindow::GetMiscInfo(TextThread& thread)
{ {
return return
{ {
{ "current select", ui->ttCombo->currentText().startsWith(TextThreadString(thread)) }, { "current select", ui->ttCombo->currentText().startsWith(TextThreadString(thread)) },
{ "text number", thread->handle }, { "text number", thread.handle },
{ "process id", thread->tp.processId }, { "process id", thread.tp.processId },
{ "hook address", thread->tp.addr }, { "hook address", thread.tp.addr },
{ "text handle", thread->handle }, { "text handle", thread.handle },
{ "text name", (int64_t)thread->name.c_str() } { "text name", (int64_t)thread.name.c_str() }
}; };
} }
@ -280,9 +280,9 @@ void MainWindow::SaveHooks()
} }
} }
auto hookInfo = QStringList() << S(processName.value()) << hookCodes.values(); auto hookInfo = QStringList() << S(processName.value()) << hookCodes.values();
TextThread* current = Host::GetThread(ParseTextThreadString(ui->ttCombo->currentText())); TextThread& current = Host::GetThread(ParseTextThreadString(ui->ttCombo->currentText()));
if (current->tp.processId == GetSelectedProcessId()) if (current.tp.processId == GetSelectedProcessId())
hookInfo << QString("|%1:%2:%3").arg(current->tp.ctx).arg(current->tp.ctx2).arg(S(Util::GenerateCode(Host::GetHookParam(current->tp), current->tp.processId))); hookInfo << QString("|%1:%2:%3").arg(current.tp.ctx).arg(current.tp.ctx2).arg(S(Util::GenerateCode(Host::GetHookParam(current.tp), current.tp.processId)));
QTextFile(HOOK_SAVE_FILE, QIODevice::WriteOnly | QIODevice::Append).write((hookInfo.join(" , ") + "\n").toUtf8()); QTextFile(HOOK_SAVE_FILE, QIODevice::WriteOnly | QIODevice::Append).write((hookInfo.join(" , ") + "\n").toUtf8());
} }
} }
@ -328,6 +328,6 @@ void MainWindow::Extensions()
void MainWindow::ViewThread(int index) void MainWindow::ViewThread(int index)
{ {
ui->ttCombo->setCurrentIndex(index); ui->ttCombo->setCurrentIndex(index);
ui->textOutput->setPlainText(S(Host::GetThread(ParseTextThreadString(ui->ttCombo->itemText(index)))->storage->c_str())); ui->textOutput->setPlainText(S(Host::GetThread(ParseTextThreadString(ui->ttCombo->itemText(index))).storage->c_str()));
ui->textOutput->moveCursor(QTextCursor::End); ui->textOutput->moveCursor(QTextCursor::End);
} }

View File

@ -20,13 +20,13 @@ private:
void closeEvent(QCloseEvent*) override; void closeEvent(QCloseEvent*) override;
void ProcessConnected(DWORD processId); void ProcessConnected(DWORD processId);
void ProcessDisconnected(DWORD processId); void ProcessDisconnected(DWORD processId);
void ThreadAdded(TextThread* thread); void ThreadAdded(TextThread& thread);
void ThreadRemoved(TextThread* thread); void ThreadRemoved(TextThread& thread);
bool SentenceReceived(TextThread* thread, std::wstring& sentence); bool SentenceReceived(TextThread& thread, std::wstring& sentence);
QString TextThreadString(TextThread* thread); QString TextThreadString(TextThread& thread);
ThreadParam ParseTextThreadString(QString ttString); ThreadParam ParseTextThreadString(QString ttString);
DWORD GetSelectedProcessId(); DWORD GetSelectedProcessId();
std::unordered_map<const char*, int64_t> GetMiscInfo(TextThread* thread); std::unordered_map<const char*, int64_t> GetMiscInfo(TextThread& thread);
void AttachProcess(); void AttachProcess();
void LaunchProcess(); void LaunchProcess();
void DetachProcess(); void DetachProcess();