Rewrote library code for parallel execution
The new implementation correctly re-throws exceptions on the calling thread instead of terminating the application.
Parent: 5198ee9230 · Commit: b3b2366468
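
Below is a minimal usage sketch, not part of the commit, showing how a caller might use the new runParallel function from src/parallel.h (added further down). The collection, worker body, and output are made-up example data; per the commit message above, an exception thrown inside a worker is meant to surface on the calling thread, so the call can be wrapped in an ordinary try/catch instead of the process being terminated.

#include <functional>
#include <iostream>
#include <stdexcept>
#include <vector>
#include "parallel.h" // new header introduced by this commit

int main() {
    std::vector<int> items{ 1, 2, 3, 4 }; // hypothetical example data

    // Worker that processes one element; the explicit std::function type matches
    // the signature expected by the first runParallel overload.
    std::function<void(int&)> process = [](int& item) {
        if (item < 0) throw std::runtime_error("negative input"); // example failure case
        item *= 2;
    };

    try {
        runParallel(process, items, getProcessorCoreCount());
    } catch (const std::exception& e) {
        // With the old ThreadPool, an exception escaping a worker thread would have
        // terminated the application; the new code is meant to re-throw it here.
        std::cerr << "Processing failed: " << e.what() << "\n";
        return 1;
    }

    for (int item : items) std::cout << item << "\n";
    return 0;
}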
CMakeLists.txt (modified):
@@ -222,7 +222,7 @@ set(SOURCE_FILES
     src/g2p.cpp src/g2p.h
     src/languageModels.cpp src/languageModels.h
     src/tupleHash.h
-    src/ThreadPool.cpp src/ThreadPool.h
+    src/parallel.h
     src/ObjectPool.h
     src/Lazy.h
 )

src/ThreadPool.cpp (deleted):
@@ -1,78 +0,0 @@
#include "ThreadPool.h"

int ThreadPool::getRecommendedThreadCount() {
    int coreCount = std::thread::hardware_concurrency();

    // If the number of cores cannot be determined, use a reasonable default
    return coreCount != 0 ? coreCount : 4;
}

ThreadPool::ThreadPool(int threadCount) :
    threadCount(threadCount),
    remainingJobCount(0),
    bailout(false) {
    for (int i = 0; i < threadCount; ++i) {
        threads.push_back(std::thread([&] {
            Task();
        }));
    }
}

ThreadPool::~ThreadPool() {
    waitAll();

    // Notify that we're done, and wake up any threads that are waiting for a new job
    bailout = true;
    jobAvailableCondition.notify_all();

    for (auto& thread : threads) {
        if (thread.joinable()) {
            thread.join();
        }
    }
}

void ThreadPool::schedule(job_t job) {
    std::lock_guard<std::mutex> guard(queueMutex);
    jobQueue.emplace_back(job);
    ++remainingJobCount;
    jobAvailableCondition.notify_one();
}

void ThreadPool::waitAll() {
    if (remainingJobCount == 0) return;

    std::unique_lock<std::mutex> lock(waitMutex);
    waitCondition.wait(lock, [&] {
        return remainingJobCount == 0;
    });
    lock.unlock();
}

void ThreadPool::Task() {
    while (!bailout) {
        getNextJob()();
        --remainingJobCount;
        waitCondition.notify_one();
    }
}

ThreadPool::job_t ThreadPool::getNextJob() {
    std::unique_lock<std::mutex> jobLock(queueMutex);

    // Wait for a job if we don't have any
    jobAvailableCondition.wait(jobLock, [&] {
        return jobQueue.size() > 0 || bailout;
    });

    if (bailout) {
        // Return a dummy job to keep remainingJobCount accurate
        ++remainingJobCount;
        return [] {};
    }

    // Get job from the queue
    auto result = jobQueue.front();
    jobQueue.pop_front();
    return result;
}

src/ThreadPool.h (deleted):
@@ -1,86 +0,0 @@
#pragma once

#include <atomic>
#include <thread>
#include <mutex>
#include <list>
#include <functional>
#include <vector>
#include "progressBar.h"

// Thread pool based on https://github.com/nbsdx/ThreadPool, which is in the public domain.

class ThreadPool {
public:
    using job_t = std::function<void(void)>;

    static int getRecommendedThreadCount();

    ThreadPool(int threadCount = getRecommendedThreadCount());

    ~ThreadPool();

    // Gets the number of threads in this pool
    int getThreadCount() const {
        return threadCount;
    }

    // Gets the number of jobs left in the queue
    int getRemainingJobCount() const {
        return remainingJobCount;
    }

    // Adds a new job to the pool.
    // If there are no queued jobs, a thread is woken up to take the job.
    // If all threads are busy, the job is added to the end of the queue.
    void schedule(job_t job);

    // Asynchronously runs a function for every element of a collection.
    template<typename TCollection>
    void schedule(
        TCollection& collection,
        std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
        ProgressSink& progressSink,
        std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; });

    // Blocks until all jobs have finished executing
    void waitAll();

private:
    const int threadCount;
    std::vector<std::thread> threads;
    std::list<job_t> jobQueue;
    std::atomic_int remainingJobCount; // The number of queued or running jobs
    std::atomic_bool bailout;
    std::condition_variable jobAvailableCondition;
    std::condition_variable waitCondition;
    std::mutex waitMutex;
    std::mutex queueMutex;

    // Takes the next job in the queue and runs it.
    // Notifies the main thread that a job has completed.
    void Task();

    // Gets the next job: pops the first item in the queue,
    // otherwise waits for a signal from the main thread.
    job_t getNextJob();
};

template <typename TCollection>
void ThreadPool::schedule(
    TCollection& collection,
    std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
    ProgressSink& progressSink,
    std::function<double(const typename TCollection::reference)> getElementProgressWeight)
{
    // Use shared pointer to keep progress merger alive throughout execution
    auto progressMerger = std::make_shared<ProgressMerger>(progressSink);

    // Schedule all elements
    for (auto& element : collection) {
        ProgressSink& elementProgressSink = progressMerger->addSink(getElementProgressWeight(element));
        schedule([processElement, &element, &elementProgressSink, progressMerger /* Keep progressMerger alive! */] {
            processElement(element, elementProgressSink);
        });
    }
}

Voice activity detection (modified):
@@ -7,7 +7,7 @@
 #include <webrtc/common_audio/vad/include/webrtc_vad.h>
 #include "processing.h"
 #include <gsl_util.h>
-#include <ThreadPool.h>
+#include <parallel.h>
 #include "AudioStreamSegment.h"
 
 using std::vector;
@@ -70,15 +70,14 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
     std::mutex activityMutex;
 
     // Split audio into segments and perform parallel VAD
-    ThreadPool threadPool;
-    int segmentCount = threadPool.getThreadCount();
+    int segmentCount = getProcessorCoreCount();
     centiseconds audioLength = audioStream->getTruncatedRange().getLength();
     vector<TimeRange> audioSegments;
     for (int i = 0; i < segmentCount; ++i) {
         TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount);
         audioSegments.push_back(segmentRange);
     }
-    threadPool.schedule(audioSegments, [&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
+    runParallel([&](const TimeRange& segmentRange, ProgressSink& segmentProgressSink) {
         unique_ptr<AudioStream> audioSegment = createSegment(audioStream->clone(false), segmentRange);
         BoundedTimeline<void> activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
 
@@ -87,8 +86,7 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
             activityRange.getTimeRange().shift(segmentRange.getStart());
             activity.set(activityRange);
         }
-    }, progressSink);
-    threadPool.waitAll();
+    }, audioSegments, segmentCount, progressSink);
 
     // Fill small gaps in activity
     const centiseconds maxGap(5);

src/parallel.h (new file):
@@ -0,0 +1,96 @@
#pragma once

#include <functional>
#include "ProgressBar.h"
#include <boost/optional/optional.hpp>
#include <gsl_util.h>

template<typename TCollection>
void runParallel(
    std::function<void(typename TCollection::reference)> processElement,
    TCollection& collection,
    int maxThreadCount)
{
    using future_type = std::future<void>;

    std::mutex mutex;
    int currentThreadCount = 0;
    std::condition_variable elementFinished;
    future_type finishedElement;

    // Before exiting, wait for all running tasks to finish, but don't re-throw exceptions.
    // This only applies if one task already failed with an exception.
    auto finishRunning = gsl::finally([&]{
        std::unique_lock<std::mutex> lock(mutex);
        elementFinished.wait(lock, [&] { return currentThreadCount == 0; });
    });

    // Asynchronously run all elements
    for (auto it = collection.begin(); it != collection.end(); ++it) {
        // This variable will later hold the future, but can be value-captured right now
        auto future = std::make_shared<future_type>();

        // Notifies that an element is done processing
        auto notifyElementDone = [&, future] {
            std::lock_guard<std::mutex> lock(mutex);
            finishedElement = std::move(*future);
            --currentThreadCount;
            elementFinished.notify_one();
        };

        // Processes the current element, then notifies
        auto wrapperFunction = [processElement, &element = *it, notifyElementDone]() {
            auto done = gsl::finally(notifyElementDone);
            processElement(element);
        };

        // Asynchronously process element
        {
            std::lock_guard<std::mutex> lock(mutex);
            *future = std::async(std::launch::async, wrapperFunction);
            ++currentThreadCount;
        }

        // Wait for threads to finish, if necessary
        {
            std::unique_lock<std::mutex> lock(mutex);
            int targetThreadCount = it == collection.end() ? 0 : maxThreadCount - 1;
            while (currentThreadCount > targetThreadCount) {
                elementFinished.wait(lock);
                if (finishedElement.valid()) {
                    // Re-throw any exception
                    finishedElement.get();
                    finishedElement = future_type();
                }
            }
        }
    }
}

template<typename TCollection>
void runParallel(
    std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
    TCollection& collection,
    int maxThreadCount,
    ProgressSink& progressSink,
    std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; })
{
    // Create a collection of wrapper functions that take care of progress handling
    ProgressMerger progressMerger(progressSink);
    std::vector<std::function<void()>> functions;
    for (auto& element : collection) {
        auto& elementProgressSink = progressMerger.addSink(getElementProgressWeight(element));
        functions.push_back([&]() { processElement(element, elementProgressSink); });
    }

    // Run wrapper functions
    runParallel([&](std::function<void()> function) { function(); }, functions, maxThreadCount);
}

inline int getProcessorCoreCount() {
    int coreCount = std::thread::hardware_concurrency();

    // If the number of cores cannot be determined, use a reasonable default
    return coreCount != 0 ? coreCount : 4;
}

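As a rough sketch of the second overload above, which both call sites changed in this commit use: the caller passes an overall ProgressSink plus an optional per-element weight function, and each worker receives its own merged sub-sink. The task list, the weighting by string length, and the helper name processTasks are assumptions made for illustration, not part of the commit.

#include <functional>
#include <string>
#include <vector>
#include "parallel.h"

// Hypothetical helper; 'tasks' and the length-based weighting are example assumptions.
inline void processTasks(std::vector<std::string>& tasks, ProgressSink& overallProgress) {
    runParallel<std::vector<std::string>>(
        [](std::string& task, ProgressSink& taskProgress) {
            // ... do the actual work for 'task', reporting progress into taskProgress ...
        },
        tasks,
        getProcessorCoreCount(),
        overallProgress,
        // Optional: longer tasks get a proportionally larger share of the overall progress.
        [](const std::string& task) { return static_cast<double>(task.size()); });
}
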
Phone detection (modified):
@@ -17,8 +17,7 @@
 #include "g2p.h"
 #include "ContinuousTimeline.h"
 #include "audio/processing.h"
-#include "ThreadPool.h"
-#include "ObjectPool.h"
+#include "parallel.h"
 
 extern "C" {
 #include <pocketsphinx.h>
@@ -401,16 +400,14 @@ BoundedTimeline<Phone> detectPhones(
     // Determine how many parallel threads to use
     int threadCount = std::min({
         // Don't use more threads than there are CPU cores
-        ThreadPool::getRecommendedThreadCount(),
+        getProcessorCoreCount(),
         // Don't use more threads than there are utterances to be processed
         static_cast<int>(utterances.size()),
         // Don't waste time creating additional threads (and decoders!) if the recording is short
         static_cast<int>(duration_cast<std::chrono::seconds>(audioStream->getTruncatedRange().getLength()).count() / 10)
     });
-    ThreadPool threadPool(threadCount);
     logging::debug("Speech recognition -- start");
-    threadPool.schedule(utterances, processUtterance, dialogProgressSink, getUtteranceProgressWeight);
-    threadPool.waitAll();
+    runParallel(processUtterance, utterances, threadCount, dialogProgressSink, getUtteranceProgressWeight);
     logging::debug("Speech recognition -- end");
 
     return result;