Multithreaded recognition: refactoring and fixes
* Decoders are correctly released after use * Determining optimal thread count for multithreading
This commit is contained in:
parent
f13449f810
commit
2d314f4bc7
|
@ -11,6 +11,10 @@ public:
|
|||
using typename Timeline<T>::iterator;
|
||||
using Timeline<T>::end;
|
||||
|
||||
BoundedTimeline() :
|
||||
range(TimeRange::zero())
|
||||
{}
|
||||
|
||||
explicit BoundedTimeline(TimeRange range) :
|
||||
range(range)
|
||||
{}
|
||||
|
|
|
@ -32,7 +32,7 @@ ThreadPool::~ThreadPool() {
|
|||
}
|
||||
}
|
||||
|
||||
void ThreadPool::addJob(job_t job) {
|
||||
void ThreadPool::schedule(job_t job) {
|
||||
std::lock_guard<std::mutex> guard(queueMutex);
|
||||
jobQueue.emplace_back(job);
|
||||
++remainingJobCount;
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
#include <list>
|
||||
#include <functional>
|
||||
#include <vector>
|
||||
#include "progressBar.h"
|
||||
|
||||
// Thread pool based on https://github.com/nbsdx/ThreadPool, which is in the public domain.
|
||||
|
||||
|
@ -32,7 +33,15 @@ public:
|
|||
// Adds a new job to the pool.
|
||||
// If there are no queued jobs, a thread is woken up to take the job.
|
||||
// If all threads are busy, the job is added to the end of the queue.
|
||||
void addJob(job_t job);
|
||||
void schedule(job_t job);
|
||||
|
||||
// Asynchronously runs a function for every element of a collection.
|
||||
template<typename TCollection>
|
||||
void schedule(
|
||||
TCollection& collection,
|
||||
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
|
||||
ProgressSink& progressSink,
|
||||
std::function<double(const typename TCollection::reference)> getElementProgressWeight = [](typename TCollection::reference) { return 1.0; });
|
||||
|
||||
// Blocks until all jobs have finshed executing
|
||||
void waitAll();
|
||||
|
@ -57,3 +66,21 @@ private:
|
|||
job_t getNextJob();
|
||||
};
|
||||
|
||||
template <typename TCollection>
|
||||
void ThreadPool::schedule(
|
||||
TCollection& collection,
|
||||
std::function<void(typename TCollection::reference, ProgressSink&)> processElement,
|
||||
ProgressSink& progressSink,
|
||||
std::function<double(const typename TCollection::reference)> getElementProgressWeight)
|
||||
{
|
||||
// Use shared pointer to keep progress merger alive throughout execution
|
||||
auto progressMerger = std::make_shared<ProgressMerger>(progressSink);
|
||||
|
||||
// Schedule all elements
|
||||
for (auto& element : collection) {
|
||||
ProgressSink& elementProgressSink = progressMerger->addSink(getElementProgressWeight(element));
|
||||
schedule([processElement, &element, &elementProgressSink, progressMerger /* Keep progressMerger alive! */] {
|
||||
processElement(element, elementProgressSink);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -52,8 +52,9 @@ public:
|
|||
using reverse_iterator = typename set_type::reverse_iterator;
|
||||
using size_type = size_t;
|
||||
using value_type = Timed<T>;
|
||||
using reference = const value_type&;
|
||||
|
||||
class reference {
|
||||
class ReferenceWrapper {
|
||||
public:
|
||||
operator boost::optional<const T&>() const {
|
||||
auto optional = timeline.get(time);
|
||||
|
@ -71,7 +72,7 @@ public:
|
|||
return optional->getValue();
|
||||
}
|
||||
|
||||
reference& operator=(boost::optional<const T&> value) {
|
||||
ReferenceWrapper& operator=(boost::optional<const T&> value) {
|
||||
if (value) {
|
||||
timeline.set(time, time + time_type(1), *value);
|
||||
} else {
|
||||
|
@ -83,7 +84,7 @@ public:
|
|||
private:
|
||||
friend class Timeline;
|
||||
|
||||
reference(Timeline& timeline, time_type time) :
|
||||
ReferenceWrapper(Timeline& timeline, time_type time) :
|
||||
timeline(timeline),
|
||||
time(time)
|
||||
{}
|
||||
|
@ -232,13 +233,13 @@ public:
|
|||
return set(Timed<void>(start, end));
|
||||
}
|
||||
|
||||
reference operator[](time_type time) {
|
||||
return reference(*this, time);
|
||||
ReferenceWrapper operator[](time_type time) {
|
||||
return ReferenceWrapper(*this, time);
|
||||
}
|
||||
|
||||
// ReSharper disable once CppConstValueFunctionReturnType
|
||||
const reference operator[](time_type time) const {
|
||||
return reference(*this, time);
|
||||
const ReferenceWrapper operator[](time_type time) const {
|
||||
return ReferenceWrapper(*this, time);
|
||||
}
|
||||
|
||||
virtual void shift(time_type offset) {
|
||||
|
|
|
@ -59,7 +59,7 @@ BoundedTimeline<void> detectVoiceActivity(std::unique_ptr<AudioStream> audioStre
|
|||
for (int i = 0; i < segmentCount; ++i) {
|
||||
TimeRange segmentRange = TimeRange(i * audioLength / segmentCount, (i + 1) * audioLength / segmentCount);
|
||||
ProgressSink& segmentProgressSink = progressMerger.addSink(1.0);
|
||||
threadPool.addJob([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] {
|
||||
threadPool.schedule([segmentRange, &audioStream, &segmentProgressSink, &activityMutex, &activity] {
|
||||
std::unique_ptr<AudioStream> audioSegment = createSegment(audioStream->clone(false), segmentRange);
|
||||
BoundedTimeline<void> activitySegment = webRtcDetectVoiceActivity(*audioSegment, segmentProgressSink);
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@ using std::regex_replace;
|
|||
using std::chrono::duration;
|
||||
using boost::optional;
|
||||
using std::u32string;
|
||||
using std::chrono::duration_cast;
|
||||
|
||||
constexpr int sphinxSampleRate = 16000;
|
||||
|
||||
|
@ -134,7 +135,7 @@ s3wid_t getWordId(const string& word, dict_t& dictionary) {
|
|||
return wordId;
|
||||
}
|
||||
|
||||
optional<BoundedTimeline<Phone>> getPhoneAlignment(
|
||||
optional<Timeline<Phone>> getPhoneAlignment(
|
||||
const vector<s3wid_t>& wordIds,
|
||||
unique_ptr<AudioStream> audioStream,
|
||||
ps_decoder_t& decoder,
|
||||
|
@ -193,7 +194,7 @@ optional<BoundedTimeline<Phone>> getPhoneAlignment(
|
|||
|
||||
// Extract phones with timestamps
|
||||
char** phoneNames = decoder.dict->mdef->ciname;
|
||||
BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
|
||||
Timeline<Phone> result;
|
||||
for (ps_alignment_iter_t* it = ps_alignment_phones(alignment.get()); it; it = ps_alignment_iter_next(it)) {
|
||||
// Get phone
|
||||
ps_alignment_entry_t* phoneEntry = ps_alignment_iter_get(it);
|
||||
|
@ -276,82 +277,128 @@ lambda_unique_ptr<ps_decoder_t> createDecoder(optional<u32string> dialog) {
|
|||
return decoder;
|
||||
}
|
||||
|
||||
Timeline<Phone> utteranceToPhones(
|
||||
AudioStream& audioStream,
|
||||
TimeRange utterance,
|
||||
ps_decoder_t& decoder,
|
||||
ProgressSink& utteranceProgressSink)
|
||||
{
|
||||
ProgressMerger utteranceProgressMerger(utteranceProgressSink);
|
||||
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
|
||||
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
|
||||
|
||||
auto streamSegment = createSegment(audioStream.clone(true), utterance);
|
||||
|
||||
// Get words
|
||||
BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), decoder, wordRecognitionProgressSink);
|
||||
for (Timed<string> timedWord : words) {
|
||||
timedWord.getTimeRange().shift(utterance.getStart());
|
||||
logging::logTimedEvent("word", timedWord);
|
||||
}
|
||||
|
||||
// Look up words in dictionary
|
||||
vector<s3wid_t> wordIds;
|
||||
for (const auto& timedWord : words) {
|
||||
wordIds.push_back(getWordId(timedWord.getValue(), *decoder.dict));
|
||||
}
|
||||
if (wordIds.empty()) return Timeline<Phone>();
|
||||
|
||||
// Align the words' phones with speech
|
||||
Timeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), decoder, alignmentProgressSink)
|
||||
.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
|
||||
segmentPhones.shift(utterance.getStart());
|
||||
for (const auto& timedPhone : segmentPhones) {
|
||||
logging::logTimedEvent("phone", timedPhone);
|
||||
}
|
||||
|
||||
return segmentPhones;
|
||||
}
|
||||
|
||||
BoundedTimeline<Phone> detectPhones(
|
||||
unique_ptr<AudioStream> audioStream,
|
||||
optional<u32string> dialog,
|
||||
ProgressSink& progressSink)
|
||||
{
|
||||
ProgressMerger totalProgressMerger(progressSink);
|
||||
ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
|
||||
ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
|
||||
|
||||
// Make sure audio stream has no DC offset
|
||||
audioStream = removeDCOffset(std::move(audioStream));
|
||||
|
||||
// Split audio into utterances
|
||||
BoundedTimeline<void> utterances;
|
||||
try {
|
||||
utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink);
|
||||
}
|
||||
catch (...) {
|
||||
std::throw_with_nested(runtime_error("Error detecting segments of speech."));
|
||||
}
|
||||
|
||||
// Discard Pocketsphinx output
|
||||
err_set_logfp(nullptr);
|
||||
|
||||
// Redirect Pocketsphinx output to log
|
||||
err_set_callback(sphinxLogCallback, nullptr);
|
||||
|
||||
// Make sure audio stream has no DC offset
|
||||
audioStream = removeDCOffset(std::move(audioStream));
|
||||
// Prepare pool of decoders
|
||||
std::stack<lambda_unique_ptr<ps_decoder_t>> decoderPool;
|
||||
std::mutex decoderPoolMutex;
|
||||
auto getDecoder = [&] {
|
||||
std::lock_guard<std::mutex> lock(decoderPoolMutex);
|
||||
if (decoderPool.empty()) {
|
||||
decoderPool.push(createDecoder(dialog));
|
||||
}
|
||||
auto decoder = std::move(decoderPool.top());
|
||||
decoderPool.pop();
|
||||
return std::move(decoder);
|
||||
};
|
||||
auto returnDecoder = [&](lambda_unique_ptr<ps_decoder_t> decoder) {
|
||||
std::lock_guard<std::mutex> lock(decoderPoolMutex);
|
||||
decoderPool.push(std::move(decoder));
|
||||
};
|
||||
|
||||
ProgressMerger totalProgressMerger(progressSink);
|
||||
ProgressSink& voiceActivationProgressSink = totalProgressMerger.addSink(1.0);
|
||||
ProgressSink& dialogProgressSink = totalProgressMerger.addSink(15);
|
||||
BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
|
||||
std::mutex resultMutex, audioStreamMutex;
|
||||
auto processUtterance = [&](Timed<void> timedUtterance, ProgressSink& utteranceProgressSink) {
|
||||
logging::logTimedEvent("utterance", timedUtterance.getTimeRange(), string(""));
|
||||
|
||||
// Detect phones for utterance
|
||||
auto decoder = getDecoder();
|
||||
std::unique_ptr<AudioStream> audioStreamCopy;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(audioStreamMutex);
|
||||
audioStreamCopy = audioStream->clone(true);
|
||||
}
|
||||
Timeline<Phone> phones =
|
||||
utteranceToPhones(*audioStreamCopy, timedUtterance.getTimeRange(), *decoder, utteranceProgressSink);
|
||||
returnDecoder(std::move(decoder));
|
||||
|
||||
// Copy phones to result timeline
|
||||
std::lock_guard<std::mutex> lock(resultMutex);
|
||||
for (const auto& timedPhone : phones) {
|
||||
result.set(timedPhone);
|
||||
}
|
||||
};
|
||||
|
||||
auto getUtteranceProgressWeight = [](const Timed<void> timedUtterance) {
|
||||
return timedUtterance.getTimeRange().getLength().count();
|
||||
};
|
||||
|
||||
// Perform speech recognition
|
||||
try {
|
||||
// Split audio into utterances
|
||||
BoundedTimeline<void> utterances = detectVoiceActivity(audioStream->clone(true), voiceActivationProgressSink);
|
||||
|
||||
// For progress reporting: weigh utterances by length
|
||||
ProgressMerger dialogProgressMerger(dialogProgressSink);
|
||||
vector<ProgressSink*> utteranceProgressSinks;
|
||||
for (const auto& timedUtterance : utterances) {
|
||||
utteranceProgressSinks.push_back(&dialogProgressMerger.addSink(timedUtterance.getTimeRange().getLength().count()));
|
||||
}
|
||||
auto utteranceProgressSinkIt = utteranceProgressSinks.begin();
|
||||
|
||||
BoundedTimeline<Phone> result(audioStream->getTruncatedRange());
|
||||
std::mutex resultMutex;
|
||||
// Determine how many parallel threads to use
|
||||
int threadCount = std::min({
|
||||
// Don't use more threads than there are CPU cores
|
||||
ThreadPool::getRecommendedThreadCount(),
|
||||
// Don't use more threads than there are utterances to be processed
|
||||
static_cast<int>(utterances.size()),
|
||||
// Don't waste time creating additional threads (and decoders!) if the recording is short
|
||||
static_cast<int>(duration_cast<std::chrono::seconds>(audioStream->getTruncatedRange().getLength()).count() / 10)
|
||||
});
|
||||
ThreadPool threadPool(threadCount);
|
||||
logging::debug("Speech recognition -- start");
|
||||
ObjectPool<ps_decoder_t> decoderPool([&dialog] { return createDecoder(dialog).release(); });
|
||||
ThreadPool threadPool;
|
||||
for (const auto& timedUtterance : utterances) {
|
||||
threadPool.addJob([&] {
|
||||
auto decoder = decoderPool.acquire();
|
||||
ProgressMerger utteranceProgressMerger(**utteranceProgressSinkIt++);
|
||||
ProgressSink& wordRecognitionProgressSink = utteranceProgressMerger.addSink(1.0);
|
||||
ProgressSink& alignmentProgressSink = utteranceProgressMerger.addSink(0.5);
|
||||
|
||||
const TimeRange timeRange = timedUtterance.getTimeRange();
|
||||
logging::logTimedEvent("utterance", timeRange, string(""));
|
||||
|
||||
auto streamSegment = createSegment(audioStream->clone(true), timeRange);
|
||||
|
||||
// Get words
|
||||
BoundedTimeline<string> words = recognizeWords(streamSegment->clone(true), *decoder.get(), wordRecognitionProgressSink);
|
||||
for (Timed<string> timedWord : words) {
|
||||
timedWord.getTimeRange().shift(timedUtterance.getStart());
|
||||
logging::logTimedEvent("word", timedWord);
|
||||
}
|
||||
|
||||
// Look up words in dictionary
|
||||
vector<s3wid_t> wordIds;
|
||||
for (const auto& timedWord : words) {
|
||||
wordIds.push_back(getWordId(timedWord.getValue(), *decoder->dict));
|
||||
}
|
||||
if (wordIds.empty()) return;
|
||||
|
||||
// Align the words' phones with speech
|
||||
BoundedTimeline<Phone> segmentPhones = getPhoneAlignment(wordIds, std::move(streamSegment), *decoder.get(), alignmentProgressSink)
|
||||
.value_or(ContinuousTimeline<Phone>(streamSegment->getTruncatedRange(), Phone::Unknown));
|
||||
segmentPhones.shift(timedUtterance.getStart());
|
||||
for (const auto& timedPhone : segmentPhones) {
|
||||
logging::logTimedEvent("phone", timedPhone);
|
||||
}
|
||||
|
||||
// Fill result
|
||||
std::lock_guard<std::mutex> lock(resultMutex);
|
||||
for (const auto& timedPhone : segmentPhones) {
|
||||
result.set(timedPhone);
|
||||
}
|
||||
});
|
||||
}
|
||||
threadPool.schedule(utterances, processUtterance, dialogProgressSink, getUtteranceProgressWeight);
|
||||
threadPool.waitAll();
|
||||
logging::debug("Speech recognition -- end");
|
||||
|
||||
|
|
Loading…
Reference in New Issue