fix race condition (i think)

This commit is contained in:
Akash Mozumdar 2019-01-06 02:57:52 -05:00
parent a9249111c0
commit f1e7b4dc70
3 changed files with 20 additions and 12 deletions

View File

@ -53,15 +53,17 @@ namespace
WinMutex viewMutex; WinMutex viewMutex;
}; };
ThreadSafe<std::unordered_map<ThreadParam, std::shared_ptr<TextThread>>> textThreadsByParams; ThreadSafe<std::unordered_map<ThreadParam, std::unique_ptr<TextThread>>, std::recursive_mutex> textThreadsByParams;
ThreadSafe<std::unordered_map<DWORD, ProcessRecord>> processRecordsByIds; ThreadSafe<std::unordered_map<DWORD, ProcessRecord>, std::recursive_mutex> processRecordsByIds;
ThreadParam CONSOLE{ 0, -1ULL, -1ULL, -1ULL }, CLIPBOARD{ 0, 0, -1ULL, -1ULL }; ThreadParam CONSOLE{ 0, -1ULL, -1ULL, -1ULL }, CLIPBOARD{ 0, 0, -1ULL, -1ULL };
void RemoveThreads(std::function<bool(ThreadParam)> removeIf) void RemoveThreads(std::function<bool(ThreadParam)> removeIf)
{ {
std::vector<std::unique_ptr<TextThread>> removedThreads;
auto[lock, textThreadsByParams] = ::textThreadsByParams.operator->(); auto[lock, textThreadsByParams] = ::textThreadsByParams.operator->();
for (auto it = textThreadsByParams->begin(); it != textThreadsByParams->end(); removeIf(it->first) ? it = textThreadsByParams->erase(it) : ++it); for (auto it = textThreadsByParams->begin(); it != textThreadsByParams->end(); removeIf(it->first) ? it = textThreadsByParams->erase(it) : ++it)
if (removeIf(it->first)) removedThreads.emplace_back(std::move(it->second));
} }
void CreatePipe() void CreatePipe()
@ -104,7 +106,7 @@ namespace
default: default:
{ {
auto tp = *(ThreadParam*)buffer; auto tp = *(ThreadParam*)buffer;
if (textThreadsByParams->count(tp) == 0) textThreadsByParams->insert({ tp, std::make_shared<TextThread>(tp, Host::GetHookParam(tp)) }); if (textThreadsByParams->count(tp) == 0) textThreadsByParams->insert({ tp, std::make_unique<TextThread>(tp, Host::GetHookParam(tp)) });
textThreadsByParams->at(tp)->Push(buffer + sizeof(tp), bytesRead - sizeof(tp)); textThreadsByParams->at(tp)->Push(buffer + sizeof(tp), bytesRead - sizeof(tp));
} }
break; break;
@ -136,8 +138,8 @@ namespace Host
TextThread::OnDestroy = OnDestroy; TextThread::OnDestroy = OnDestroy;
TextThread::Output = Output; TextThread::Output = Output;
processRecordsByIds->try_emplace(CONSOLE.processId, CONSOLE.processId, INVALID_HANDLE_VALUE); processRecordsByIds->try_emplace(CONSOLE.processId, CONSOLE.processId, INVALID_HANDLE_VALUE);
textThreadsByParams->insert({ CONSOLE, std::make_shared<TextThread>(CONSOLE, HookParam{}, L"Console") }); textThreadsByParams->insert({ CONSOLE, std::make_unique<TextThread>(CONSOLE, HookParam{}, L"Console") });
textThreadsByParams->insert({ CLIPBOARD, std::make_shared<TextThread>(CLIPBOARD, HookParam{}, L"Clipboard") }); textThreadsByParams->insert({ CLIPBOARD, std::make_unique<TextThread>(CLIPBOARD, HookParam{}, L"Clipboard") });
StartCapturingClipboard(); StartCapturingClipboard();
CreatePipe(); CreatePipe();
} }
@ -199,9 +201,9 @@ namespace Host
return processRecordsByIds->at(tp.processId).GetHook(tp.addr).hp; return processRecordsByIds->at(tp.processId).GetHook(tp.addr).hp;
} }
std::shared_ptr<TextThread> GetThread(ThreadParam tp) TextThread* GetThread(ThreadParam tp)
{ {
return textThreadsByParams->at(tp); return textThreadsByParams->at(tp).get();
} }
void AddConsoleOutput(std::wstring text) void AddConsoleOutput(std::wstring text)

View File

@ -14,7 +14,7 @@ namespace Host
HookParam GetHookParam(ThreadParam tp); HookParam GetHookParam(ThreadParam tp);
std::shared_ptr<TextThread> GetThread(ThreadParam tp); TextThread* GetThread(ThreadParam tp);
void AddConsoleOutput(std::wstring text); void AddConsoleOutput(std::wstring text);
inline int defaultCodepage = SHIFT_JIS; inline int defaultCodepage = SHIFT_JIS;

View File

@ -102,13 +102,19 @@ void MainWindow::ProcessConnected(DWORD processId)
void MainWindow::ProcessDisconnected(DWORD processId) void MainWindow::ProcessDisconnected(DWORD processId)
{ {
QMetaObject::invokeMethod(this, [this, processId] { ui->processCombo->removeItem(ui->processCombo->findText(QString::number(processId, 16).toUpper() + ":", Qt::MatchStartsWith)); }); QMetaObject::invokeMethod(this, [this, processId]
{
ui->processCombo->removeItem(ui->processCombo->findText(QString::number(processId, 16).toUpper() + ":", Qt::MatchStartsWith));
}, Qt::BlockingQueuedConnection);
} }
void MainWindow::ThreadAdded(TextThread* thread) void MainWindow::ThreadAdded(TextThread* thread)
{ {
QString ttString = TextThreadString(thread) + S(thread->name) + " (" + GenerateCode(thread->hp, thread->tp.processId) + ")"; QString ttString = TextThreadString(thread) + S(thread->name) + " (" + GenerateCode(thread->hp, thread->tp.processId) + ")";
QMetaObject::invokeMethod(this, [this, ttString] { ui->ttCombo->addItem(ttString); }); QMetaObject::invokeMethod(this, [this, ttString]
{
ui->ttCombo->addItem(ttString);
});
} }
void MainWindow::ThreadRemoved(TextThread* thread) void MainWindow::ThreadRemoved(TextThread* thread)
@ -123,7 +129,7 @@ void MainWindow::ThreadRemoved(TextThread* thread)
ViewThread(0); ViewThread(0);
} }
ui->ttCombo->removeItem(threadIndex); ui->ttCombo->removeItem(threadIndex);
}); }, Qt::BlockingQueuedConnection);
} }
bool MainWindow::SentenceReceived(TextThread* thread, std::wstring& sentence) bool MainWindow::SentenceReceived(TextThread* thread, std::wstring& sentence)