From b3b236646863811579f8ab49c4310d7bb7fbfc33 Mon Sep 17 00:00:00 2001 From: Daniel Wolf Date: Wed, 27 Jul 2016 21:15:00 +0200 Subject: [PATCH] Re-written library code for parallel execution The new implementation correctly re-throws exceptions on the calling thread instead of terminating the application. --- CMakeLists.txt | 2 +- src/ThreadPool.cpp | 78 ---------------------- src/ThreadPool.h | 86 ------------------------- src/audio/voiceActivityDetection.cpp | 10 ++- src/parallel.h | 96 ++++++++++++++++++++++++++++ src/phoneExtraction.cpp | 9 +-- 6 files changed, 104 insertions(+), 177 deletions(-) delete mode 100644 src/ThreadPool.cpp delete mode 100644 src/ThreadPool.h create mode 100644 src/parallel.h diff --git a/CMakeLists.txt b/CMakeLists.txt index f7daf70..2b9390b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,7 +222,7 @@ set(SOURCE_FILES src/g2p.cpp src/g2p.h src/languageModels.cpp src/languageModels.h src/tupleHash.h - src/ThreadPool.cpp src/ThreadPool.h + src/parallel.h src/ObjectPool.h src/Lazy.h ) diff --git a/src/ThreadPool.cpp b/src/ThreadPool.cpp deleted file mode 100644 index 159fb8d..0000000 --- a/src/ThreadPool.cpp +++ /dev/null @@ -1,78 +0,0 @@ -#include "ThreadPool.h" - -int ThreadPool::getRecommendedThreadCount() { - int coreCount = std::thread::hardware_concurrency(); - - // If the number of cores cannot be determined, use a reasonable default - return coreCount != 0 ? 
coreCount : 4; -} - -ThreadPool::ThreadPool(int threadCount) : - threadCount(threadCount), - remainingJobCount(0), - bailout(false) { - for (int i = 0; i < threadCount; ++i) { - threads.push_back(std::thread([&] { - Task(); - })); - } -} - -ThreadPool::~ThreadPool() { - waitAll(); - - // Notify that we're done, and wake up any threads that are waiting for a new job - bailout = true; - jobAvailableCondition.notify_all(); - - for (auto& thread : threads) { - if (thread.joinable()) { - thread.join(); - } - } -} - -void ThreadPool::schedule(job_t job) { - std::lock_guard guard(queueMutex); - jobQueue.emplace_back(job); - ++remainingJobCount; - jobAvailableCondition.notify_one(); -} - -void ThreadPool::waitAll() { - if (remainingJobCount == 0) return; - - std::unique_lock lock(waitMutex); - waitCondition.wait(lock, [&] { - return remainingJobCount == 0; - }); - lock.unlock(); -} - -void ThreadPool::Task() { - while (!bailout) { - getNextJob()(); - --remainingJobCount; - waitCondition.notify_one(); - } -} - -ThreadPool::job_t ThreadPool::getNextJob() { - std::unique_lock jobLock(queueMutex); - - // Wait for a job if we don't have any - jobAvailableCondition.wait(jobLock, [&] { - return jobQueue.size() > 0 || bailout; - }); - - if (bailout) { - // Return a dummy job to keep remainingJobCount accurate - ++remainingJobCount; - return [] {}; - } - - // Get job from the queue - auto result = jobQueue.front(); - jobQueue.pop_front(); - return result; -} diff --git a/src/ThreadPool.h b/src/ThreadPool.h deleted file mode 100644 index f7ea6b6..0000000 --- a/src/ThreadPool.h +++ /dev/null @@ -1,86 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include "progressBar.h" - -// Thread pool based on https://github.com/nbsdx/ThreadPool, which is in the public domain. 
- -class ThreadPool { -public: - using job_t = std::function; - - static int getRecommendedThreadCount(); - - ThreadPool(int threadCount = getRecommendedThreadCount()); - - ~ThreadPool(); - - // Gets the number of threads in this pool - int getThreadCount() const { - return threadCount; - } - - // Gets the number of jobs left in the queue - int getRemainingJobCount() const { - return remainingJobCount; - } - - // Adds a new job to the pool. - // If there are no queued jobs, a thread is woken up to take the job. - // If all threads are busy, the job is added to the end of the queue. - void schedule(job_t job); - - // Asynchronously runs a function for every element of a collection. - template - void schedule( - TCollection& collection, - std::function processElement, - ProgressSink& progressSink, - std::function getElementProgressWeight = [](typename TCollection::reference) { return 1.0; }); - - // Blocks until all jobs have finshed executing - void waitAll(); - -private: - const int threadCount; - std::vector threads; - std::list jobQueue; - std::atomic_int remainingJobCount; // The number of queued or running jobs - std::atomic_bool bailout; - std::condition_variable jobAvailableCondition; - std::condition_variable waitCondition; - std::mutex waitMutex; - std::mutex queueMutex; - - // Takes the next job in the queue and run it. - // Notify the main thread that a job has completed. 
- void Task(); - - // Gets the next job; pop the first item in the queue, - // otherwise wait for a signal from the main thread - job_t getNextJob(); -}; - -template -void ThreadPool::schedule( - TCollection& collection, - std::function processElement, - ProgressSink& progressSink, - std::function getElementProgressWeight) -{ - // Use shared pointer to keep progress merger alive throughout execution - auto progressMerger = std::make_shared(progressSink); - - // Schedule all elements - for (auto& element : collection) { - ProgressSink& elementProgressSink = progressMerger->addSink(getElementProgressWeight(element)); - schedule([processElement, &element, &elementProgressSink, progressMerger /* Keep progressMerger alive! */] { - processElement(element, elementProgressSink); - }); - } -} diff --git a/src/audio/voiceActivityDetection.cpp b/src/audio/voiceActivityDetection.cpp index 1eb6c0e..b5d9905 100644 --- a/src/audio/voiceActivityDetection.cpp +++ b/src/audio/voiceActivityDetection.cpp @@ -7,7 +7,7 @@ #include #include "processing.h" #include -#include +#include #include "AudioStreamSegment.h" using std::vector; @@ -70,15 +70,14 @@ BoundedTimeline detectVoiceActivity(std::unique_ptr audioStre std::mutex activityMutex; // Split audio into segments and perform parallel VAD - ThreadPool threadPool; - int segmentCount = threadPool.getThreadCount(); + int segmentCount = getProcessorCoreCount(); centiseconds audioLength = audioStream->getTruncatedRange().getLength(); vector audioSegments; for (int i = 0; i < segmentCount; ++i) { TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount); audioSegments.push_back(segmentRange); } - threadPool.schedule(audioSegments, [&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) { + runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) { unique_ptr audioSegment = createSegment(audioStream->clone(false), segmentRange); BoundedTimeline 
activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink); @@ -87,8 +86,7 @@ BoundedTimeline detectVoiceActivity(std::unique_ptr audioStre activityRange.getTimeRange().shift(segmentRange.getStart()); activity.set(activityRange); } - }, progressSink); - threadPool.waitAll(); + }, audioSegments, segmentCount, progressSink); // Fill small gaps in activity const centiseconds maxGap(5); diff --git a/src/parallel.h b/src/parallel.h new file mode 100644 index 0000000..4a6daa8 --- /dev/null +++ b/src/parallel.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include "ProgressBar.h" +#include +#include + +template +void runParallel( + std::function processElement, + TCollection& collection, + int maxThreadCount) +{ + using future_type = std::future; + + std::mutex mutex; + int currentThreadCount = 0; + std::condition_variable elementFinished; + future_type finishedElement; + + // Before exiting, wait for all running tasks to finish, but don't re-throw exceptions. + // This only applies if one task already failed with an exception. 
+ auto finishRunning = gsl::finally([&]{ + std::unique_lock lock(mutex); + elementFinished.wait(lock, [&] { return currentThreadCount == 0; }); + }); + + // Asynchronously run all elements + for (auto it = collection.begin(); it != collection.end(); ++it) { + // This variable will later hold the future, but can be value-captured right now + auto future = std::make_shared(); + + // Notifies that an element is done processing + auto notifyElementDone = [&, future] { + std::lock_guard lock(mutex); + finishedElement = std::move(*future); + --currentThreadCount; + elementFinished.notify_one(); + }; + + // Processes the current element, then notifies + auto wrapperFunction = [processElement, &element = *it, notifyElementDone]() { + auto done = gsl::finally(notifyElementDone); + processElement(element); + }; + + // Asynchronously process element + { + std::lock_guard lock(mutex); + *future = std::async(std::launch::async, wrapperFunction); + ++currentThreadCount; + } + + // Wait for threads to finish, if necessary + { + std::unique_lock lock(mutex); + int targetThreadCount = it == collection.end() ? 
0 : maxThreadCount - 1; + while (currentThreadCount > targetThreadCount) { + elementFinished.wait(lock); + if (finishedElement.valid()) { + // Re-throw any exception + finishedElement.get(); + finishedElement = future_type(); + } + } + } + } + +} + +template +void runParallel( + std::function processElement, + TCollection& collection, + int maxThreadCount, + ProgressSink& progressSink, + std::function getElementProgressWeight = [](typename TCollection::reference) { return 1.0; }) +{ + // Create a collection of wrapper functions that take care of progress handling + ProgressMerger progressMerger(progressSink); + std::vector> functions; + for (auto& element : collection) { + auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element)); + functions.push_back([&]() { processElement(element, elementProgressSink); }); + } + + // Run wrapper function + runParallel([&](std::function function) { function(); }, functions, maxThreadCount); +} + +inline int getProcessorCoreCount() { + int coreCount = std::thread::hardware_concurrency(); + + // If the number of cores cannot be determined, use a reasonable default + return coreCount != 0 ? coreCount : 4; +} diff --git a/src/phoneExtraction.cpp b/src/phoneExtraction.cpp index 504adc2..43ecce6 100644 --- a/src/phoneExtraction.cpp +++ b/src/phoneExtraction.cpp @@ -17,8 +17,7 @@ #include "g2p.h" #include "ContinuousTimeline.h" #include "audio/processing.h" -#include "ThreadPool.h" -#include "ObjectPool.h" +#include "parallel.h" extern "C" { #include @@ -401,16 +400,14 @@ BoundedTimeline detectPhones( // Determine how many parallel threads to use int threadCount = std::min({ // Don't use more threads than there are CPU cores - ThreadPool::getRecommendedThreadCount(), + getProcessorCoreCount(), // Don't use more threads than there are utterances to be processed static_cast(utterances.size()), // Don't waste time creating additional threads (and decoders!) 
if the recording is short static_cast(duration_cast(audioStream->getTruncatedRange().getLength()).count() / 10) }); - ThreadPool threadPool(threadCount); logging::debug("Speech recognition -- start"); - threadPool.schedule(utterances, processUtterance, dialogProgressSink, getUtteranceProgressWeight); - threadPool.waitAll(); + runParallel(processUtterance, utterances, threadCount, dialogProgressSink, getUtteranceProgressWeight); logging::debug("Speech recognition -- end"); return result;