Re-written library code for parallel execution
The new implementation correctly re-throws exceptions on the calling thread instead of terminating the application.
This commit is contained in:
parent
5198ee9230
commit
b3b2366468
|
@ -222,7 +222,7 @@ set(SOURCE_FILES
|
|||
src/g2p.cpp src/g2p.h
|
||||
src/languageModels.cpp src/languageModels.h
|
||||
src/tupleHash.h
|
||||
src/ThreadPool.cpp src/ThreadPool.h
|
||||
src/parallel.h
|
||||
src/ObjectPool.h
|
||||
src/Lazy.h
|
||||
)
|
||||
|
|
|
@ -1,78 +0,0 @@
|
|||
#include "ThreadPool.h"
|
||||
|
||||
// Suggests how many worker threads a pool should use: the hardware concurrency
// reported by the standard library, or 4 if that value is reported as 0.
int ThreadPool::getRecommendedThreadCount() {
	const int detectedCores = std::thread::hardware_concurrency();
	if (detectedCores == 0) {
		// The number of cores could not be determined; use a reasonable default
		return 4;
	}
	return detectedCores;
}
|
||||
|
||||
// Creates the pool and immediately starts `threadCount` worker threads.
// Every worker runs Task() on this pool object.
ThreadPool::ThreadPool(int threadCount) :
	threadCount(threadCount),
	remainingJobCount(0),
	bailout(false)
{
	for (int workerIndex = 0; workerIndex < threadCount; ++workerIndex) {
		threads.emplace_back(&ThreadPool::Task, this);
	}
}
|
||||
|
||||
// Waits for all outstanding jobs, then shuts down and joins every worker thread.
ThreadPool::~ThreadPool() {
	waitAll();

	// Signal shutdown while holding the queue mutex. Workers evaluate their wait
	// predicate under this mutex, so a worker either sees `bailout == true` or is
	// already blocked in wait() when the notification below is sent. Setting the
	// flag without the mutex could slip in between a worker's predicate check and
	// its wait, losing the wakeup and making join() hang forever.
	{
		std::lock_guard<std::mutex> guard(queueMutex);
		bailout = true;
	}

	// Wake up any threads that are waiting for a new job
	jobAvailableCondition.notify_all();

	for (auto& thread : threads) {
		if (thread.joinable()) {
			thread.join();
		}
	}
}
|
||||
|
||||
// Appends a job to the queue and wakes one worker to pick it up.
void ThreadPool::schedule(job_t job) {
	std::lock_guard<std::mutex> queueLock(queueMutex);

	// Enqueue first, then account for the job in the queued-or-running counter
	jobQueue.push_back(job);
	remainingJobCount.fetch_add(1);

	jobAvailableCondition.notify_one();
}
|
||||
|
||||
void ThreadPool::waitAll() {
|
||||
if (remainingJobCount == 0) return;
|
||||
|
||||
std::unique_lock<std::mutex> lock(waitMutex);
|
||||
waitCondition.wait(lock, [&] {
|
||||
return remainingJobCount == 0;
|
||||
});
|
||||
lock.unlock();
|
||||
}
|
||||
|
||||
void ThreadPool::Task() {
|
||||
while (!bailout) {
|
||||
getNextJob()();
|
||||
--remainingJobCount;
|
||||
waitCondition.notify_one();
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the job at the front of the queue, blocking while the queue is empty.
// When the pool is shutting down, returns an empty placeholder job instead.
ThreadPool::job_t ThreadPool::getNextJob() {
	std::unique_lock<std::mutex> queueLock(queueMutex);

	// Sleep until there is a job to run or the pool is shutting down
	while (jobQueue.empty() && !bailout) {
		jobAvailableCondition.wait(queueLock);
	}

	if (!bailout) {
		// Take the oldest job off the queue
		job_t nextJob = jobQueue.front();
		jobQueue.pop_front();
		return nextJob;
	}

	// The caller decrements remainingJobCount after running whatever it gets,
	// so count the dummy job to keep the counter accurate
	++remainingJobCount;
	return [] {};
}
|
|
@ -1,86 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <list>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "progressBar.h"
|
||||
|
||||
// Thread pool based on https://github.com/nbsdx/ThreadPool, which is in the public domain.
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
using job_t = std::function<void(void)>;
|
||||
|
||||
static int getRecommendedThreadCount();
|
||||
|
||||
ThreadPool(int threadCount = getRecommendedThreadCount());
|
||||
|
||||
~ThreadPool();
|
||||
|
||||
// Gets the number of threads in this pool
|
||||
int getThreadCount() const {
|
||||
return threadCount;
|
||||
}
|
||||
|
||||
// Gets the number of jobs left in the queue
|
||||
int getRemainingJobCount() const {
|
||||
return remainingJobCount;
|
||||
}
|
||||
|
||||
// Adds a new job to the pool.
|
||||
// If there are no queued jobs, a thread is woken up to take the job.
|
||||
// If all threads are busy, the job is added to the end of the queue.
|
||||
void schedule(job_t job);
|
||||
|
||||
// Asynchronously runs a function for every element of a collection.
|
||||
template<typename TCollection>
|
||||
void schedule(
|
||||
TCollection& collection,
|
||||
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
|
||||
ProgressSink& progressSink,
|
||||
std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; });
|
||||
|
||||
// Blocks until all jobs have finshed executing
|
||||
void waitAll();
|
||||
|
||||
private:
|
||||
const int threadCount;
|
||||
std::vector<std::thread> threads;
|
||||
std::list<job_t> jobQueue;
|
||||
std::atomic_int remainingJobCount; // The number of queued or running jobs
|
||||
std::atomic_bool bailout;
|
||||
std::condition_variable jobAvailableCondition;
|
||||
std::condition_variable waitCondition;
|
||||
std::mutex waitMutex;
|
||||
std::mutex queueMutex;
|
||||
|
||||
// Takes the next job in the queue and run it.
|
||||
// Notify the main thread that a job has completed.
|
||||
void Task();
|
||||
|
||||
// Gets the next job; pop the first item in the queue,
|
||||
// otherwise wait for a signal from the main thread
|
||||
job_t getNextJob();
|
||||
};
|
||||
|
||||
template <typename TCollection>
|
||||
void ThreadPool::schedule(
|
||||
TCollection& collection,
|
||||
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
|
||||
ProgressSink& progressSink,
|
||||
std::function<double(const typename TCollection::reference)> getElementProgressWeight)
|
||||
{
|
||||
// Use shared pointer to keep progress merger alive throughout execution
|
||||
auto progressMerger = std::make_shared<ProgressMerger>(progressSink);
|
||||
|
||||
// Schedule all elements
|
||||
for (auto& element : collection) {
|
||||
ProgressSink& elementProgressSink = progressMerger->addSink(getElementProgressWeight(element));
|
||||
schedule([processElement, &element, &elementProgressSink, progressMerger /* Keep progressMerger alive! */] {
|
||||
processElement(element, elementProgressSink);
|
||||
});
|
||||
}
|
||||
}
|
|
@ -7,7 +7,7 @@
|
|||
#include <webrtc/common_audio/vad/include/webrtc_vad.h>
|
||||
#include "processing.h"
|
||||
#include <gsl_util.h>
|
||||
#include <ThreadPool.h>
|
||||
#include <parallel.h>
|
||||
#include "AudioStreamSegment.h"
|
||||
|
||||
using std::vector;
|
||||
|
@ -70,15 +70,14 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
|
|||
std::mutex activityMutex;
|
||||
|
||||
// Split audio into segments and perform parallel VAD
|
||||
ThreadPool threadPool;
|
||||
int segmentCount = threadPool.getThreadCount();
|
||||
int segmentCount = getProcessorCoreCount();
|
||||
centiseconds audioLength = audioStream->getTruncatedRange().getLength();
|
||||
vector<TimeRange> audioSegments;
|
||||
for (int i = 0; i < segmentCount; ++i) {
|
||||
TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount);
|
||||
audioSegments.push_back(segmentRange);
|
||||
}
|
||||
threadPool.schedule(audioSegments, [&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
|
||||
runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
|
||||
unique_ptr<AudioStream> audioSegment = createSegment(audioStream->clone(false), segmentRange);
|
||||
BoundedTimeline<void> activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
|
||||
|
||||
|
@ -87,8 +86,7 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
|
|||
activityRange.getTimeRange().shift(segmentRange.getStart());
|
||||
activity.set(activityRange);
|
||||
}
|
||||
}, progressSink);
|
||||
threadPool.waitAll();
|
||||
}, audioSegments, segmentCount, progressSink);
|
||||
|
||||
// Fill small gaps in activity
|
||||
const centiseconds maxGap(5);
|
||||
|
|
|
@ -0,0 +1,96 @@
|
|||
#pragma once
|
||||
|
||||
#include <algorithm>
#include <condition_variable>
#include <exception>
#include <functional>
#include <future>
#include <mutex>
#include <thread>
#include <vector>
#include <boost/optional/optional.hpp>
#include <gsl_util.h>
#include "ProgressBar.h"
|
||||
|
||||
// Calls processElement once for every element of collection, running up to
// maxThreadCount calls concurrently, and returns when all started calls have finished.
//
// processElement: function to run; it is invoked on worker threads.
// collection:     elements to process; each call receives a reference to one element.
// maxThreadCount: upper limit of concurrent calls. Values below 1 are treated as 1.
//
// If a call throws, no further elements are started. After all calls that are already
// running have finished, the first recorded exception is re-thrown on the calling thread.
template<typename TCollection>
void runParallel(
	std::function<void(typename TCollection::reference)> processElement,
	TCollection& collection,
	int maxThreadCount)
{
	// A limit below 1 would make the wait condition below unsatisfiable
	const int threadLimit = std::max(maxThreadCount, 1);

	std::mutex mutex;
	std::condition_variable elementFinished;
	int runningCount = 0;
	std::exception_ptr firstError;

	// Declared last so that it is destroyed first. A future returned by std::async blocks in
	// its destructor until its task is done, so even if this function is left through an
	// exception, no task can outlive the local variables above that it refers to.
	std::vector<std::future<void>> futures;

	for (auto it = collection.begin(); it != collection.end(); ++it) {
		std::unique_lock<std::mutex> lock(mutex);

		// Wait for a free slot
		elementFinished.wait(lock, [&] { return runningCount < threadLimit; });

		// Don't start new work once an element has failed
		if (firstError) break;

		// Reserve the slot before launching, so a failed allocation cannot strand a running task
		futures.emplace_back();
		futures.back() = std::async(
			std::launch::async,
			[&mutex, &elementFinished, &runningCount, &firstError, processElement, &element = *it] {
				std::exception_ptr error;
				try {
					processElement(element);
				} catch (...) {
					error = std::current_exception();
				}

				// Record the outcome and free the slot
				std::lock_guard<std::mutex> taskLock(mutex);
				if (error && !firstError) {
					firstError = error;
				}
				--runningCount;
				elementFinished.notify_all();
			});
		++runningCount;
	}

	// Wait for everything that is still running, including the final elements
	for (auto& future : futures) {
		future.wait();
	}

	// Re-throw on the calling thread
	if (firstError) {
		std::rethrow_exception(firstError);
	}
}
|
||||
|
||||
template<typename TCollection>
|
||||
void runParallel(
|
||||
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
|
||||
TCollection& collection,
|
||||
int maxThreadCount,
|
||||
ProgressSink& progressSink,
|
||||
std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; })
|
||||
{
|
||||
// Create a collection of wrapper functions that take care of progress handling
|
||||
ProgressMerger progressMerger(progressSink);
|
||||
std::vector<std::function<void()>> functions;
|
||||
for (auto& element : collection) {
|
||||
auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element));
|
||||
functions.push_back([&]() { processElement(element, elementProgressSink); });
|
||||
}
|
||||
|
||||
// Run wrapper function
|
||||
runParallel([&](std::function<void()> function) { function(); }, functions, maxThreadCount);
|
||||
}
|
||||
|
||||
// Returns the hardware concurrency reported by the standard library,
// or 4 if that value is reported as 0.
inline int getProcessorCoreCount() {
	const int detectedCores = std::thread::hardware_concurrency();
	if (detectedCores != 0) {
		return detectedCores;
	}

	// The number of cores could not be determined; use a reasonable default
	return 4;
}
|
|
@ -17,8 +17,7 @@
|
|||
#include "g2p.h"
|
||||
#include "ContinuousTimeline.h"
|
||||
#include "audio/processing.h"
|
||||
#include "ThreadPool.h"
|
||||
#include "ObjectPool.h"
|
||||
#include "parallel.h"
|
||||
|
||||
extern "C" {
|
||||
#include <pocketsphinx.h>
|
||||
|
@ -401,16 +400,14 @@ BoundedTimeline<Phone> detectPhones(
|
|||
// Determine how many parallel threads to use
|
||||
int threadCount = std::min({
|
||||
// Don't use more threads than there are CPU cores
|
||||
ThreadPool::getRecommendedThreadCount(),
|
||||
getProcessorCoreCount(),
|
||||
// Don't use more threads than there are utterances to be processed
|
||||
static_cast<int>(utterances.size()),
|
||||
// Don't waste time creating additional threads (and decoders!) if the recording is short
|
||||
static_cast<int>(duration_cast<std::chrono::seconds>(audioStream->getTruncatedRange().getLength()).count() / 10)
|
||||
});
|
||||
ThreadPool threadPool(threadCount);
|
||||
logging::debug("Speech recognition -- start");
|
||||
threadPool.schedule(utterances, processUtterance, dialogProgressSink, getUtteranceProgressWeight);
|
||||
threadPool.waitAll();
|
||||
runParallel(processUtterance, utterances, threadCount, dialogProgressSink, getUtteranceProgressWeight);
|
||||
logging::debug("Speech recognition -- end");
|
||||
|
||||
return result;
|
||||
|
|
Loading…
Reference in New Issue