using shared_ptr to improve thread safety

This commit is contained in:
Akash Mozumdar 2018-10-31 01:20:44 -04:00
parent 8e45b35ebe
commit 1915008d00
6 changed files with 36 additions and 36 deletions

View File

@ -21,7 +21,7 @@ namespace
ThreadEventCallback OnCreate, OnRemove; ThreadEventCallback OnCreate, OnRemove;
ProcessEventCallback OnAttach, OnDetach; ProcessEventCallback OnAttach, OnDetach;
std::unordered_map<ThreadParam, TextThread*> textThreadsByParams; std::unordered_map<ThreadParam, std::shared_ptr<TextThread>> textThreadsByParams;
std::unordered_map<DWORD, ProcessRecord> processRecordsByIds; std::unordered_map<DWORD, ProcessRecord> processRecordsByIds;
std::recursive_mutex hostMutex; std::recursive_mutex hostMutex;
@ -35,7 +35,7 @@ namespace
if (textThreadsByParams[tp] == nullptr) if (textThreadsByParams[tp] == nullptr)
{ {
if (textThreadsByParams.size() > MAX_THREAD_COUNT) return Host::AddConsoleOutput(L"too many text threads: can't create more"); if (textThreadsByParams.size() > MAX_THREAD_COUNT) return Host::AddConsoleOutput(L"too many text threads: can't create more");
OnCreate(textThreadsByParams[tp] = new TextThread(tp)); OnCreate(textThreadsByParams[tp] = std::make_shared<TextThread>(tp));
} }
textThreadsByParams[tp]->AddText(text, len); textThreadsByParams[tp]->AddText(text, len);
} }
@ -43,15 +43,12 @@ namespace
void RemoveThreads(std::function<bool(ThreadParam)> removeIf) void RemoveThreads(std::function<bool(ThreadParam)> removeIf)
{ {
LOCK(hostMutex); LOCK(hostMutex);
std::vector<ThreadParam> removedThreads; for (auto it = textThreadsByParams.begin(); it != textThreadsByParams.end();)
for (auto[tp, thread] : textThreadsByParams) if (auto curr = it++; removeIf(curr->first))
if (removeIf(tp))
{ {
OnRemove(thread); OnRemove(curr->second);
//delete i.second; // Artikash 7/24/2018: FIXME: Qt GUI updates on another thread, so I can't delete this yet. textThreadsByParams.erase(curr->first);
removedThreads.push_back(tp);
} }
for (auto thread : removedThreads) textThreadsByParams.erase(thread);
} }
void RegisterProcess(DWORD pid, HANDLE hostPipe) void RegisterProcess(DWORD pid, HANDLE hostPipe)
@ -143,7 +140,7 @@ namespace Host
void Start(ProcessEventCallback onAttach, ProcessEventCallback onDetach, ThreadEventCallback onCreate, ThreadEventCallback onRemove, TextThread::OutputCallback output) void Start(ProcessEventCallback onAttach, ProcessEventCallback onDetach, ThreadEventCallback onCreate, ThreadEventCallback onRemove, TextThread::OutputCallback output)
{ {
OnAttach = onAttach; OnDetach = onDetach; OnCreate = onCreate; OnRemove = onRemove; TextThread::Output = output; OnAttach = onAttach; OnDetach = onDetach; OnCreate = onCreate; OnRemove = onRemove; TextThread::Output = output;
OnCreate(textThreadsByParams[CONSOLE] = new TextThread(CONSOLE)); OnCreate(textThreadsByParams[CONSOLE] = std::make_shared<TextThread>(CONSOLE));
StartPipe(); StartPipe();
} }
@ -152,9 +149,8 @@ namespace Host
// Artikash 7/25/2018: This is only called when Textractor is closed, at which point Windows should free everything itself...right? // Artikash 7/25/2018: This is only called when Textractor is closed, at which point Windows should free everything itself...right?
#ifdef _DEBUG // Check memory leaks #ifdef _DEBUG // Check memory leaks
LOCK(hostMutex); LOCK(hostMutex);
OnRemove = [](TextThread* textThread) { delete textThread; };
for (auto[pid, pr] : processRecordsByIds) UnregisterProcess(pid); for (auto[pid, pr] : processRecordsByIds) UnregisterProcess(pid);
delete textThreadsByParams[CONSOLE]; textThreadsByParams.clear();
#endif #endif
} }
@ -239,8 +235,6 @@ namespace Host
return ret; return ret;
} }
HookParam GetHookParam(ThreadParam tp) { return GetHookParam(tp.pid, tp.hook); }
std::wstring GetHookName(DWORD pid, uint64_t addr) std::wstring GetHookName(DWORD pid, uint64_t addr)
{ {
if (pid == 0) return L"Console"; if (pid == 0) return L"Console";
@ -260,7 +254,7 @@ namespace Host
return StringToWideString(buffer, CP_UTF8); return StringToWideString(buffer, CP_UTF8);
} }
TextThread* GetThread(ThreadParam tp) std::shared_ptr<TextThread> GetThread(ThreadParam tp)
{ {
LOCK(hostMutex); LOCK(hostMutex);
return textThreadsByParams[tp]; return textThreadsByParams[tp];

View File

@ -8,7 +8,7 @@
#include "textthread.h" #include "textthread.h"
typedef std::function<void(DWORD)> ProcessEventCallback; typedef std::function<void(DWORD)> ProcessEventCallback;
typedef std::function<void(TextThread*)> ThreadEventCallback; typedef std::function<void(std::shared_ptr<TextThread>)> ThreadEventCallback;
namespace Host namespace Host
{ {
@ -22,10 +22,11 @@ namespace Host
void RemoveHook(DWORD pid, uint64_t addr); void RemoveHook(DWORD pid, uint64_t addr);
HookParam GetHookParam(DWORD pid, uint64_t addr); HookParam GetHookParam(DWORD pid, uint64_t addr);
HookParam GetHookParam(ThreadParam tp); inline HookParam GetHookParam(ThreadParam tp) { return GetHookParam(tp.pid, tp.hook); }
std::wstring GetHookName(DWORD pid, uint64_t addr); std::wstring GetHookName(DWORD pid, uint64_t addr);
inline std::wstring GetHookName(ThreadParam tp) { return GetHookName(tp.pid, tp.hook); }
TextThread* GetThread(ThreadParam tp); std::shared_ptr<TextThread> GetThread(ThreadParam tp);
void AddConsoleOutput(std::wstring text); void AddConsoleOutput(std::wstring text);
} }

View File

@ -8,7 +8,7 @@
#include <regex> #include <regex>
#include <algorithm> #include <algorithm>
TextThread::TextThread(ThreadParam tp) : handle(threadCounter++), name(Host::GetHookName(tp.pid, tp.hook)), tp(tp), hp(Host::GetHookParam(tp)) {} TextThread::TextThread(ThreadParam tp) : handle(threadCounter++), name(Host::GetHookName(tp)), tp(tp), hp(Host::GetHookParam(tp)) {}
TextThread::~TextThread() TextThread::~TextThread()
{ {

View File

@ -35,6 +35,8 @@ MainWindow::MainWindow(QWidget *parent) :
if (settings.contains("Flush_Delay")) TextThread::flushDelay = settings.value("Flush_Delay").toInt(); if (settings.contains("Flush_Delay")) TextThread::flushDelay = settings.value("Flush_Delay").toInt();
if (settings.contains("Max_Buffer_Size")) TextThread::maxBufferSize = settings.value("Max_Buffer_Size").toInt(); if (settings.contains("Max_Buffer_Size")) TextThread::maxBufferSize = settings.value("Max_Buffer_Size").toInt();
qRegisterMetaType<std::shared_ptr<TextThread>>();
connect(this, &MainWindow::SigAddProcess, this, &MainWindow::AddProcess); connect(this, &MainWindow::SigAddProcess, this, &MainWindow::AddProcess);
connect(this, &MainWindow::SigRemoveProcess, this, &MainWindow::RemoveProcess); connect(this, &MainWindow::SigRemoveProcess, this, &MainWindow::RemoveProcess);
connect(this, &MainWindow::SigAddThread, this, &MainWindow::AddThread); connect(this, &MainWindow::SigAddThread, this, &MainWindow::AddThread);
@ -44,8 +46,8 @@ MainWindow::MainWindow(QWidget *parent) :
Host::Start( Host::Start(
[&](DWORD processId) { emit SigAddProcess(processId); }, [&](DWORD processId) { emit SigAddProcess(processId); },
[&](DWORD processId) { emit SigRemoveProcess(processId); }, [&](DWORD processId) { emit SigRemoveProcess(processId); },
[&](TextThread* thread) { emit SigAddThread(thread); }, [&](std::shared_ptr<TextThread> thread) { emit SigAddThread(thread); },
[&](TextThread* thread) { emit SigRemoveThread(thread); }, [&](std::shared_ptr<TextThread> thread) { emit SigRemoveThread(thread); },
[&](TextThread* thread, std::wstring& output) { return ProcessThreadOutput(thread, output); } [&](TextThread* thread, std::wstring& output) { return ProcessThreadOutput(thread, output); }
); );
Host::AddConsoleOutput(L"Textractor beta v3.3.2 by Artikash\r\nSource code and more information available under GPLv3 at https://github.com/Artikash/Textractor"); Host::AddConsoleOutput(L"Textractor beta v3.3.2 by Artikash\r\nSource code and more information available under GPLv3 at https://github.com/Artikash/Textractor");
@ -82,32 +84,31 @@ void MainWindow::RemoveProcess(unsigned processId)
processCombo->removeItem(processCombo->findText(QString::number(processId, 16).toUpper() + ":", Qt::MatchStartsWith)); processCombo->removeItem(processCombo->findText(QString::number(processId, 16).toUpper() + ":", Qt::MatchStartsWith));
} }
void MainWindow::AddThread(TextThread* thread) void MainWindow::AddThread(std::shared_ptr<TextThread> thread)
{ {
ttCombo->addItem( ttCombo->addItem(
TextThreadString(thread) + TextThreadString(thread.get()) +
QString::fromStdWString(thread->name) + QString::fromStdWString(thread->name) +
" (" + " (" +
GenerateCode(Host::GetHookParam(thread->tp), thread->tp.pid) + GenerateCode(thread->hp, thread->tp.pid) +
")" ")"
); );
} }
void MainWindow::RemoveThread(TextThread* thread) void MainWindow::RemoveThread(std::shared_ptr<TextThread> thread)
{ {
int threadIndex = ttCombo->findText(TextThreadString(thread), Qt::MatchStartsWith); int threadIndex = ttCombo->findText(TextThreadString(thread.get()), Qt::MatchStartsWith);
if (threadIndex == ttCombo->currentIndex()) if (threadIndex == ttCombo->currentIndex())
{ {
ttCombo->setCurrentIndex(0); ttCombo->setCurrentIndex(0);
on_ttCombo_activated(0); on_ttCombo_activated(0);
} }
ttCombo->removeItem(threadIndex); ttCombo->removeItem(threadIndex);
delete thread;
} }
void MainWindow::ThreadOutput(TextThread* thread, QString output) void MainWindow::ThreadOutput(QString threadString, QString output)
{ {
if (ttCombo->currentText().startsWith(TextThreadString(thread))) if (ttCombo->currentText().startsWith(threadString))
{ {
textOutput->moveCursor(QTextCursor::End); textOutput->moveCursor(QTextCursor::End);
textOutput->insertPlainText(output); textOutput->insertPlainText(output);
@ -120,7 +121,7 @@ bool MainWindow::ProcessThreadOutput(TextThread* thread, std::wstring& output)
if (Extension::DispatchSentence(output, GetInfoForExtensions(thread))) if (Extension::DispatchSentence(output, GetInfoForExtensions(thread)))
{ {
output += L"\r\n"; output += L"\r\n";
emit SigThreadOutput(thread, QString::fromStdWString(output)); emit SigThreadOutput(TextThreadString(thread), QString::fromStdWString(output));
return true; return true;
} }
return false; return false;

View File

@ -13,6 +13,8 @@ namespace Ui
class MainWindow; class MainWindow;
} }
Q_DECLARE_METATYPE(std::shared_ptr<TextThread>);
class MainWindow : public QMainWindow class MainWindow : public QMainWindow
{ {
Q_OBJECT Q_OBJECT
@ -24,16 +26,16 @@ public:
signals: signals:
void SigAddProcess(unsigned processId); void SigAddProcess(unsigned processId);
void SigRemoveProcess(unsigned processId); void SigRemoveProcess(unsigned processId);
void SigAddThread(TextThread* thread); void SigAddThread(std::shared_ptr<TextThread>);
void SigRemoveThread(TextThread* thread); void SigRemoveThread(std::shared_ptr<TextThread>);
void SigThreadOutput(TextThread* thread, QString output); void SigThreadOutput(QString threadString, QString output);
private slots: private slots:
void AddProcess(unsigned processId); void AddProcess(unsigned processId);
void RemoveProcess(unsigned processId); void RemoveProcess(unsigned processId);
void AddThread(TextThread* thread); void AddThread(std::shared_ptr<TextThread> thread);
void RemoveThread(TextThread* thread); void RemoveThread(std::shared_ptr<TextThread> thread);
void ThreadOutput(TextThread* thread, QString output); void ThreadOutput(QString threadString, QString output); // this function doesn't take TextThread* because it might be destroyed on pipe thread
void on_attachButton_clicked(); void on_attachButton_clicked();
void on_detachButton_clicked(); void on_detachButton_clicked();
void on_ttCombo_activated(int index); void on_ttCombo_activated(int index);

View File

@ -6,6 +6,8 @@
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include <functional> #include <functional>
#include <algorithm>
#include <memory>
#include <optional> #include <optional>
#include <thread> #include <thread>
#include <mutex> #include <mutex>