Re-written library code for parallel execution

The new implementation correctly re-throws exceptions on the calling thread
instead of terminating the application.
This commit is contained in:
Daniel Wolf 2016-07-27 21:15:00 +02:00
parent 5198ee9230
commit b3b2366468
6 changed files with 104 additions and 177 deletions

View File

@ -222,7 +222,7 @@ set(SOURCE_FILES
src/g2p.cpp src/g2p.h
src/languageModels.cpp src/languageModels.h
src/tupleHash.h
src/ThreadPool.cpp src/ThreadPool.h
src/parallel.h
src/ObjectPool.h
src/Lazy.h
)

View File

@ -1,78 +0,0 @@
#include "ThreadPool.h"
int ThreadPool::getRecommendedThreadCount() {
int coreCount = std::thread::hardware_concurrency();
// If the number of cores cannot be determined, use a reasonable default
return coreCount != 0 ? coreCount : 4;
}
ThreadPool::ThreadPool(int threadCount) :
threadCount(threadCount),
remainingJobCount(0),
bailout(false) {
for (int i = 0; i < threadCount; ++i) {
threads.push_back(std::thread([&] {
Task();
}));
}
}
ThreadPool::~ThreadPool() {
waitAll();
// Notify that we're done, and wake up any threads that are waiting for a new job
bailout = true;
jobAvailableCondition.notify_all();
for (auto& thread : threads) {
if (thread.joinable()) {
thread.join();
}
}
}
void ThreadPool::schedule(job_t job) {
std::lock_guard<std::mutex> guard(queueMutex);
jobQueue.emplace_back(job);
++remainingJobCount;
jobAvailableCondition.notify_one();
}
void ThreadPool::waitAll() {
if (remainingJobCount == 0) return;
std::unique_lock<std::mutex> lock(waitMutex);
waitCondition.wait(lock, [&] {
return remainingJobCount == 0;
});
lock.unlock();
}
void ThreadPool::Task() {
while (!bailout) {
getNextJob()();
--remainingJobCount;
waitCondition.notify_one();
}
}
ThreadPool::job_t ThreadPool::getNextJob() {
std::unique_lock<std::mutex> jobLock(queueMutex);
// Wait for a job if we don't have any
jobAvailableCondition.wait(jobLock, [&] {
return jobQueue.size() > 0 || bailout;
});
if (bailout) {
// Return a dummy job to keep remainingJobCount accurate
++remainingJobCount;
return [] {};
}
// Get job from the queue
auto result = jobQueue.front();
jobQueue.pop_front();
return result;
}

View File

@ -1,86 +0,0 @@
#pragma once
#include <atomic>
#include <thread>
#include <mutex>
#include <list>
#include <functional>
#include <vector>
#include "progressBar.h"
// Thread pool based on https://github.com/nbsdx/ThreadPool, which is in the public domain.
class ThreadPool {
public:
using job_t = std::function<void(void)>;
static int getRecommendedThreadCount();
ThreadPool(int threadCount = getRecommendedThreadCount());
~ThreadPool();
// Gets the number of threads in this pool
int getThreadCount() const {
return threadCount;
}
// Gets the number of jobs left in the queue
int getRemainingJobCount() const {
return remainingJobCount;
}
// Adds a new job to the pool.
// If there are no queued jobs, a thread is woken up to take the job.
// If all threads are busy, the job is added to the end of the queue.
void schedule(job_t job);
// Asynchronously runs a function for every element of a collection.
template<typename TCollection>
void schedule(
TCollection& collection,
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
ProgressSink& progressSink,
std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; });
// Blocks until all jobs have finshed executing
void waitAll();
private:
const int threadCount;
std::vector<std::thread> threads;
std::list<job_t> jobQueue;
std::atomic_int remainingJobCount; // The number of queued or running jobs
std::atomic_bool bailout;
std::condition_variable jobAvailableCondition;
std::condition_variable waitCondition;
std::mutex waitMutex;
std::mutex queueMutex;
// Takes the next job in the queue and run it.
// Notify the main thread that a job has completed.
void Task();
// Gets the next job; pop the first item in the queue,
// otherwise wait for a signal from the main thread
job_t getNextJob();
};
template <typename TCollection>
void ThreadPool::schedule(
TCollection& collection,
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
ProgressSink& progressSink,
std::function<double(const typename TCollection::reference)> getElementProgressWeight)
{
// Use shared pointer to keep progress merger alive throughout execution
auto progressMerger = std::make_shared<ProgressMerger>(progressSink);
// Schedule all elements
for (auto& element : collection) {
ProgressSink& elementProgressSink = progressMerger->addSink(getElementProgressWeight(element));
schedule([processElement, &element, &elementProgressSink, progressMerger /* Keep progressMerger alive! */] {
processElement(element, elementProgressSink);
});
}
}

View File

@ -7,7 +7,7 @@
#include <webrtc/common_audio/vad/include/webrtc_vad.h>
#include "processing.h"
#include <gsl_util.h>
#include <ThreadPool.h>
#include <parallel.h>
#include "AudioStreamSegment.h"
using std::vector;
@ -70,15 +70,14 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
std::mutex activityMutex;
// Split audio into segments and perform parallel VAD
ThreadPool threadPool;
int segmentCount = threadPool.getThreadCount();
int segmentCount = getProcessorCoreCount();
centiseconds audioLength = audioStream->getTruncatedRange().getLength();
vector<TimeRange> audioSegments;
for (int i = 0; i < segmentCount; ++i) {
TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount);
audioSegments.push_back(segmentRange);
}
threadPool.schedule(audioSegments, [&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
unique_ptr<AudioStream> audioSegment = createSegment(audioStream->clone(false), segmentRange);
BoundedTimeline<void> activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
@ -87,8 +86,7 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
activityRange.getTimeRange().shift(segmentRange.getStart());
activity.set(activityRange);
}
}, progressSink);
threadPool.waitAll();
}, audioSegments, segmentCount, progressSink);
// Fill small gaps in activity
const centiseconds maxGap(5);

96
src/parallel.h Normal file
View File

@ -0,0 +1,96 @@
#pragma once
#include <functional>
#include "ProgressBar.h"
#include <boost/optional/optional.hpp>
#include <gsl_util.h>
template<typename TCollection>
void runParallel(
std::function<void(typename TCollection::reference)> processElement,
TCollection& collection,
int maxThreadCount)
{
using future_type = std::future<void>;
std::mutex mutex;
int currentThreadCount = 0;
std::condition_variable elementFinished;
future_type finishedElement;
// Before exiting, wait for all running tasks to finish, but don't re-throw exceptions.
// This only applies if one task already failed with an exception.
auto finishRunning = gsl::finally([&]{
std::unique_lock<std::mutex> lock(mutex);
elementFinished.wait(lock, [&] { return currentThreadCount == 0; });
});
// Asyncronously run all elements
for (auto it = collection.begin(); it != collection.end(); ++it) {
// This variable will later hold the future, but can be value-captured right now
auto future = std::make_shared<future_type>();
// Notifies that an element is done processing
auto notifyElementDone = [&, future] {
std::lock_guard<std::mutex> lock(mutex);
finishedElement = std::move(*future);
--currentThreadCount;
elementFinished.notify_one();
};
// Processes the current element, then notifies
auto wrapperFunction = [processElement, &element = *it, notifyElementDone]() {
auto done = gsl::finally(notifyElementDone);
processElement(element);
};
// Asynchronously process element
{
std::lock_guard<std::mutex> lock(mutex);
*future = std::async(std::launch::async, wrapperFunction);
++currentThreadCount;
}
// Wait for threads to finish, if necessary
{
std::unique_lock<std::mutex> lock(mutex);
int targetThreadCount = it == collection.end() ? 0 : maxThreadCount - 1;
while (currentThreadCount > targetThreadCount) {
elementFinished.wait(lock);
if (finishedElement.valid()) {
// Re-throw any exception
finishedElement.get();
finishedElement = future_type();
}
}
}
}
}
template<typename TCollection>
void runParallel(
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
TCollection& collection,
int maxThreadCount,
ProgressSink& progressSink,
std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; })
{
// Create a collection of wrapper functions that take care of progress handling
ProgressMerger progressMerger(progressSink);
std::vector<std::function<void()>> functions;
for (auto& element : collection) {
auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element));
functions.push_back([&]() { processElement(element, elementProgressSink); });
}
// Run wrapper function
runParallel([&](std::function<void()> function) { function(); }, functions, maxThreadCount);
}
inline int getProcessorCoreCount() {
int coreCount = std::thread::hardware_concurrency();
// If the number of cores cannot be determined, use a reasonable default
return coreCount != 0 ? coreCount : 4;
}

View File

@ -17,8 +17,7 @@
#include "g2p.h"
#include "ContinuousTimeline.h"
#include "audio/processing.h"
#include "ThreadPool.h"
#include "ObjectPool.h"
#include "parallel.h"
extern "C" {
#include <pocketsphinx.h>
@ -401,16 +400,14 @@ BoundedTimeline<Phone> detectPhones(
// Determine how many parallel threads to use
int threadCount = std::min({
// Don't use more threads than there are CPU cores
ThreadPool::getRecommendedThreadCount(),
getProcessorCoreCount(),
// Don't use more threads than there are utterances to be processed
static_cast<int>(utterances.size()),
// Don't waste time creating additional threads (and decoders!) if the recording is short
static_cast<int>(duration_cast<std::chrono::seconds>(audioStream->getTruncatedRange().getLength()).count() / 10)
});
ThreadPool threadPool(threadCount);
logging::debug("Speech recognition -- start");
threadPool.schedule(utterances, processUtterance, dialogProgressSink, getUtteranceProgressWeight);
threadPool.waitAll();
runParallel(processUtterance, utterances, threadCount, dialogProgressSink, getUtteranceProgressWeight);
logging::debug("Speech recognition -- end");
return result;