diff --git a/src/audio/voiceActivityDetection.cpp b/src/audio/voiceActivityDetection.cpp index b795ed6..c03983b 100644 --- a/src/audio/voiceActivityDetection.cpp +++ b/src/audio/voiceActivityDetection.cpp @@ -66,7 +66,7 @@ BoundedTimeline webRtcDetectVoiceActivity(const AudioClip& audioClip, Prog return activity; } -BoundedTimeline detectVoiceActivity(const AudioClip& inputAudioClip, ProgressSink& progressSink) { +BoundedTimeline detectVoiceActivity(const AudioClip& inputAudioClip, int maxThreadCount, ProgressSink& progressSink) { // Prepare audio for VAD const unique_ptr audioClip = inputAudioClip.clone() | resample(16000) | removeDCOffset(); @@ -74,7 +74,7 @@ BoundedTimeline detectVoiceActivity(const AudioClip& inputAudioClip, Progr std::mutex activityMutex; // Split audio into segments and perform parallel VAD - int segmentCount = getProcessorCoreCount(); + const int segmentCount = maxThreadCount; centiseconds audioLength = audioClip->getTruncatedRange().getLength(); vector audioSegments; for (int i = 0; i < segmentCount; ++i) { diff --git a/src/audio/voiceActivityDetection.h b/src/audio/voiceActivityDetection.h index 381df23..df0ffd9 100644 --- a/src/audio/voiceActivityDetection.h +++ b/src/audio/voiceActivityDetection.h @@ -3,4 +3,4 @@ #include #include -BoundedTimeline detectVoiceActivity(const AudioClip& audioClip, ProgressSink& progressSink); +BoundedTimeline detectVoiceActivity(const AudioClip& audioClip, int maxThreadCount, ProgressSink& progressSink); diff --git a/src/main.cpp b/src/main.cpp index d7afc41..0f7045a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -16,6 +16,7 @@ #include "stringTools.h" #include #include +#include "parallel.h" using std::exception; using std::string; @@ -115,6 +116,7 @@ int main(int argc, char *argv[]) { tclap::ValuesConstraint logLevelConstraint(logLevels); tclap::ValueArg logLevel("", "logLevel", "The minimum log level to log", false, logging::Level::Debug, &logLevelConstraint, cmd); tclap::ValueArg logFileName("", "logFile", "The log file path.", false, string(), "string", cmd); + tclap::ValueArg maxThreadCount("", "threads", "The maximum number of worker threads to use.", false, getProcessorCoreCount(), "number", cmd); tclap::ValueArg dialogFile("d", "dialogFile", "A file containing the text of the dialog.", false, string(), "string", cmd); auto exportFormats = vector(ExportFormatConverter::get().getValues()); tclap::ValuesConstraint exportFormatConstraint(exportFormats); @@ -132,6 +134,9 @@ int main(int argc, char *argv[]) { // Parse command line cmd.parse(argc, argv); + if (maxThreadCount.getValue() < 1) { + throw std::runtime_error("Thread count must be 1 or higher."); + } // Set up log file if (logFileName.isSet()) { @@ -151,6 +156,7 @@ int main(int argc, char *argv[]) { phones = detectPhones( *createAudioClip(inputFileName.getValue()), dialogFile.isSet() ? readTextFile(path(dialogFile.getValue())) : boost::optional(), + maxThreadCount.getValue(), progressBar); } std::cerr << "Done" << std::endl; diff --git a/src/phoneExtraction.cpp b/src/phoneExtraction.cpp index bf48b7d..a40dadd 100644 --- a/src/phoneExtraction.cpp +++ b/src/phoneExtraction.cpp @@ -359,6 +359,7 @@ Timeline getUnknownSounds(const Timeline& utterances, const Timeline BoundedTimeline detectPhones( const AudioClip& inputAudioClip, optional dialog, + int maxThreadCount, ProgressSink& progressSink) { ProgressMerger totalProgressMerger(progressSink); @@ -371,7 +372,7 @@ BoundedTimeline detectPhones( // Split audio into utterances BoundedTimeline utterances; try { - utterances = detectVoiceActivity(*audioClip, voiceActivationProgressSink); + utterances = detectVoiceActivity(*audioClip, maxThreadCount, voiceActivationProgressSink); } catch (...) { std::throw_with_nested(runtime_error("Error detecting segments of speech.")); @@ -437,8 +438,7 @@ BoundedTimeline detectPhones( try { // Determine how many parallel threads to use int threadCount = std::min({ - // Don't use more threads than there are CPU cores - getProcessorCoreCount(), + maxThreadCount, // Don't use more threads than there are utterances to be processed static_cast(utterances.size()), // Don't waste time creating additional threads (and decoders!) if the recording is short diff --git a/src/phoneExtraction.h b/src/phoneExtraction.h index f4b9f91..1f5d14d 100644 --- a/src/phoneExtraction.h +++ b/src/phoneExtraction.h @@ -8,4 +8,5 @@ BoundedTimeline detectPhones( const AudioClip& audioClip, boost::optional dialog, + int maxThreadCount, ProgressSink& progressSink);