diff --git a/src/main.cpp b/src/main.cpp index 2eab1c2..622ea70 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,6 +1,7 @@ #include #include #include +#include #include #include #include "audioInput/WaveFileReader.h" @@ -69,6 +70,14 @@ ptree createXmlTree(const path& filePath, const map& phones return tree; } +// Tell TCLAP how to handle boost::optional +namespace TCLAP { + template<> + struct ArgTraits> { + typedef TCLAP::StringLike ValueCategory; + }; +} + int main(int argc, char *argv[]) { // Define command-line parameters const char argumentValueSeparator = ' '; @@ -76,6 +85,7 @@ int main(int argc, char *argv[]) { cmd.setExceptionHandling(false); cmd.setOutput(new NiceCmdLineOutput()); TCLAP::UnlabeledValueArg inputFileName("inputFile", "The input file. Must be a sound file in WAVE format.", true, "", "string", cmd); + TCLAP::ValueArg> dialog("d", "dialog", "The text of the dialog.", false, boost::optional(), "string", cmd); try { // Parse command line @@ -88,7 +98,10 @@ int main(int argc, char *argv[]) { map phones; { ProgressBar progressBar; - phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar); + phones = detectPhones( + [&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, + dialog.getValue(), + progressBar); } std::cerr << "Done" << std::endl; diff --git a/src/phoneExtraction.cpp b/src/phoneExtraction.cpp index bf2bf70..e8ccfa9 100644 --- a/src/phoneExtraction.cpp +++ b/src/phoneExtraction.cpp @@ -7,6 +7,8 @@ #include "platformTools.h" #include "tools.h" #include +#include +#include extern "C" { #include @@ -17,6 +19,7 @@ extern "C" { } using std::runtime_error; +using std::invalid_argument; using std::unique_ptr; using std::shared_ptr; using std::string; @@ -24,6 +27,8 @@ using std::vector; using std::map; using boost::filesystem::path; using std::function; +using std::regex; +using std::regex_replace; unique_ptr to16kHzMono(unique_ptr stream) { // Downmix, if required @@ -33,7 +38,7 @@ unique_ptr to16kHzMono(unique_ptr stream) { // Downsample, if required if (stream->getFrameRate() < 16000) { - throw runtime_error("Audio sample rate must not be below 16kHz."); + throw invalid_argument("Audio sample rate must not be below 16kHz."); } if (stream->getFrameRate() != 16000) { stream.reset(new SampleRateConverter(std::move(stream), 16000)); @@ -133,7 +138,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form *errorString += message; } -vector recognizeWords(unique_ptr audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) { +vector recognizeWords(unique_ptr audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) { // Convert audio stream to the exact format PocketSphinx requires audioStream = to16kHzMono(std::move(audioStream)); @@ -153,19 +158,55 @@ vector recognizeWords(unique_ptr audioStream, ps_decoder_t if (error) throw runtime_error("Error ending utterance processing for word recognition."); // Collect words - vector result; + vector result; int32_t score; for (ps_seg_t* it = ps_seg_iter(&recognizer, &score); it; it = ps_seg_next(it)) { - // Get word const char* word = ps_seg_word(it); - s3wid_t wordId = dict_wordid(recognizer.dict, word); - - result.push_back(wordId); + result.push_back(word); } return result; } +// Splits dialog into words, doing minimal preprocessing. +// A robust solution should use TTS logic to cope with numbers, abbreviations, unknown words etc. +vector extractDialogWords(string dialog) { + // Convert to lower case + boost::algorithm::to_lower(dialog); + + // Insert silences where appropriate + dialog = regex_replace(dialog, regex("[,;.:!?] "), " "); + + // Remove all undesired characters + dialog = regex_replace(dialog, regex("[^a-z.'\\-0-9<>]"), " "); + + // Collapse whitespace + dialog = regex_replace(dialog, regex("\\s+"), " "); + + // Trim + boost::algorithm::trim(dialog); + + // Ugly hack: Remove trailing period + if (boost::algorithm::ends_with(dialog, ".")) { + dialog.pop_back(); + } + + // Split into words + vector result; + boost::algorithm::split(result, dialog, boost::is_space()); + return result; +} + +vector getWordIds(const vector& words, dict_t& dictionary) { + vector result; + for (const string& word : words) { + s3wid_t wordId = dict_wordid(&dictionary, word.c_str()); + if (wordId == BAD_S3WID) throw invalid_argument(fmt::format("Unknown word '{}'.", word)); + result.push_back(wordId); + } + return result; +} + map getPhoneAlignment(const vector& wordIds, unique_ptr audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) { // Create alignment list lambda_unique_ptr alignment( @@ -236,7 +277,11 @@ map getPhoneAlignment(const vector& wordIds, uniqu return result; } -map detectPhones(std::function(void)> createAudioStream, ProgressSink& progressSink) { +map detectPhones( + std::function(void)> createAudioStream, + boost::optional dialog, + ProgressSink& progressSink) +{ // Discard Pocketsphinx output err_set_logfp(nullptr); @@ -253,11 +298,16 @@ map detectPhones(std::function auto recognizer = createSpeechRecognizer(*config.get()); ProgressMerger progressMerger(progressSink); - ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0); + ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(dialog ? 0.0 : 1.0); ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5); - // Recognize words - vector wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink); + // Get words + vector words = dialog + ? extractDialogWords(*dialog) + : recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink); + + // Look up words in dictionary + vector wordIds = getWordIds(words, *recognizer->dict); // Align the word's phones with speech map result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink); diff --git a/src/phoneExtraction.h b/src/phoneExtraction.h index c2fee5f..a01f3b2 100644 --- a/src/phoneExtraction.h +++ b/src/phoneExtraction.h @@ -7,5 +7,9 @@ #include "Phone.h" #include "centiseconds.h" #include "progressBar.h" +#include -std::map detectPhones(std::function(void)> createAudioStream, ProgressSink& progressSink); +std::map detectPhones( + std::function(void)> createAudioStream, + boost::optional dialog, + ProgressSink& progressSink);