diff --git a/CMakeLists.txt b/CMakeLists.txt index 7a589d7..8c6aeb4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,7 +64,7 @@ target_compile_options(sphinxbase PRIVATE ${disableWarningsFlags}) set_target_properties(sphinxbase PROPERTIES FOLDER lib) # ... PocketSphinx -include_directories(SYSTEM "lib/pocketsphinx-5prealpha-2015-08-05/include") +include_directories(SYSTEM "lib/pocketsphinx-5prealpha-2015-08-05/include" "lib/pocketsphinx-5prealpha-2015-08-05/src/libpocketsphinx") FILE(GLOB pocketSphinxFiles "lib/pocketsphinx-5prealpha-2015-08-05/src/libpocketsphinx/*.c") add_library(pocketSphinx ${pocketSphinxFiles}) target_link_libraries(pocketSphinx sphinxbase) @@ -130,8 +130,8 @@ set(CPACK_GENERATOR ZIP) # Copy resource files at build time; install them at package time include(tools.cmake) set(modelDir "${CMAKE_SOURCE_DIR}/lib/pocketsphinx-5prealpha-2015-08-05/model") -copy_and_install("${modelDir}/en-us/en-us-phone.lm.bin" "res/sphinx") -copy_and_install("${modelDir}/en-us/en-us/*" "res/sphinx/acoustic_model") +copy_and_install("${modelDir}/en-us/*" "res/sphinx") +copy_and_install("${modelDir}/en-us/en-us/*" "res/sphinx/acoustic-model") install( TARGETS rhubarb diff --git a/src/main.cpp b/src/main.cpp index ea67876..d793272 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -85,12 +85,11 @@ int main(int argc, char *argv[]) { const int columnWidth = 30; std::cerr << std::left; std::cerr << std::setw(columnWidth) << "Analyzing input file"; - unique_ptr audioStream = createAudioStream(inputFileName.getValue()); map phones; { ProgressBar progressBar; phones = detectPhones( - std::move(audioStream), + [&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, [&progressBar](double progress) { progressBar.reportProgress(progress); }); } std::cerr << "Done" << std::endl; diff --git a/src/phoneExtraction.cpp b/src/phoneExtraction.cpp index a8e847d..007c324 100644 --- a/src/phoneExtraction.cpp +++ b/src/phoneExtraction.cpp @@ -6,16 +6,21 @@ #include 
"audioInput/ChannelDownmixer.h" #include "platformTools.h" #include "tools.h" +#include extern "C" { #include #include +#include +#include +#include } using std::runtime_error; using std::unique_ptr; using std::shared_ptr; using std::string; +using std::vector; using std::map; using boost::filesystem::path; using std::function; @@ -42,17 +47,11 @@ lambda_unique_ptr createConfig(path sphinxModelDirectory) { cmd_ln_init( nullptr, ps_args(), true, // Set acoustic model - "-hmm", (sphinxModelDirectory / "acoustic_model").string().c_str(), - // Set phonetic language model - "-allphone", (sphinxModelDirectory / "en-us-phone.lm.bin").string().c_str(), - "-allphone_ci", "yes", - // The following settings are taken from http://cmusphinx.sourceforge.net/wiki/phonemerecognition - // Set beam width applied to every frame in Viterbi search - "-beam", "1e-20", - // Set beam width applied to phone transitions - "-pbeam", "1e-20", - // Set language model probability weight - "-lw", "2.0", + "-hmm", (sphinxModelDirectory / "acoustic-model").string().c_str(), + // Set language model + "-lm", (sphinxModelDirectory / "en-us.lm.bin").string().c_str(), + // Set pronunciation dictionary + "-dict", (sphinxModelDirectory / "cmudict-en-us.dict").string().c_str(), nullptr), [](cmd_ln_t* config) { cmd_ln_free_r(config); }); if (!config) throw runtime_error("Error creating configuration."); @@ -60,7 +59,7 @@ lambda_unique_ptr createConfig(path sphinxModelDirectory) { return config; } -lambda_unique_ptr createPhoneRecognizer(cmd_ln_t& config) { +lambda_unique_ptr createSpeechRecognizer(cmd_ln_t& config) { lambda_unique_ptr recognizer( ps_init(&config), [](ps_decoder_t* recognizer) { ps_free(recognizer); }); @@ -76,13 +75,9 @@ int16_t floatSampleToInt16(float sample) { return static_cast(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN); } -void processAudioStream(AudioStream& audioStream16kHzMono, ps_decoder_t& recognizer, function reportProgress) { - // Start recognition - int error = 
ps_start_utt(&recognizer); - if (error) throw runtime_error("Error starting utterance processing."); - +void processAudioStream(AudioStream& audioStream16kHzMono, function&)> processBuffer, function reportProgress) { // Process entire sound file - std::vector buffer; + vector buffer; const int capacity = 1600; // 0.1 second capacity buffer.reserve(capacity); int sampleCount = 0; @@ -106,38 +101,14 @@ void processAudioStream(AudioStream& audioStream16kHzMono, ps_decoder_t& recogni buffer.push_back(sample); } - // Analyze buffer - int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false); - if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data."); + // Process buffer + processBuffer(buffer); sampleCount += buffer.size(); reportProgress(static_cast(sampleCount) / audioStream16kHzMono.getFrameCount()); } while (buffer.size()); - error = ps_end_utt(&recognizer); - if (error) throw runtime_error("Error ending utterance processing."); - } -map getPhones(ps_decoder_t& recognizer) { - map result; - ps_seg_t *segmentationIter; - result[centiseconds(0)] = Phone::None; - int32 score; - int endFrame; - for (segmentationIter = ps_seg_iter(&recognizer, &score); segmentationIter; segmentationIter = ps_seg_next(segmentationIter)) { - // Get phone - char const *phone = ps_seg_word(segmentationIter); - - // Get timing - int startFrame; - ps_seg_frames(segmentationIter, &startFrame, &endFrame); - - result[centiseconds(startFrame)] = stringToPhone(phone); - result[centiseconds(endFrame + 1)] = Phone::None; - } - return result; -}; - void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* format, ...) 
{ if (errorLevel < ERR_WARN) return; @@ -148,7 +119,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form // Format message const int initialSize = 256; - std::vector chars(initialSize); + vector chars(initialSize); bool success = false; while (!success) { int charsWritten = vsnprintf(chars.data(), chars.size(), format, args); @@ -166,7 +137,110 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form *errorString += message; } -map detectPhones(unique_ptr audioStream, function reportProgress) { +vector recognizeWords(unique_ptr audioStream, ps_decoder_t& recognizer, function reportProgress) { + // Convert audio stream to the exact format PocketSphinx requires + audioStream = to16kHzMono(std::move(audioStream)); + + // Start recognition + int error = ps_start_utt(&recognizer); + if (error) throw runtime_error("Error starting utterance processing for word recognition."); + + // Process entire sound file + auto processBuffer = [&recognizer](const vector& buffer) { + int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false); + if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data for word recognition."); + }; + processAudioStream(*audioStream.get(), processBuffer, reportProgress); + + // End recognition + error = ps_end_utt(&recognizer); + if (error) throw runtime_error("Error ending utterance processing for word recognition."); + + // Collect words + vector result; + int32_t score; + for (ps_seg_t* it = ps_seg_iter(&recognizer, &score); it; it = ps_seg_next(it)) { + // Get word + const char* word = ps_seg_word(it); + s3wid_t wordId = dict_wordid(recognizer.dict, word); + + result.push_back(wordId); + } + + return result; +} + +map getPhoneAlignment(const vector& wordIds, unique_ptr audioStream, ps_decoder_t& recognizer, function reportProgress) { + // Create alignment list + lambda_unique_ptr alignment( + ps_alignment_init(recognizer.d2p), + 
[](ps_alignment_t* alignment) { ps_alignment_free(alignment); }); + if (!alignment) throw runtime_error("Error creating alignment."); + for (s3wid_t wordId : wordIds) { + // Add word. Initial value for duration is ignored. + ps_alignment_add_word(alignment.get(), wordId, 0); + } + int error = ps_alignment_populate(alignment.get()); + if (error) throw runtime_error("Error populating alignment struct."); + + // Convert audio stream to the exact format PocketSphinx requires + audioStream = to16kHzMono(std::move(audioStream)); + + // Create search structure + acmod_t* acousticModel = recognizer.acmod; + lambda_unique_ptr search( + state_align_search_init("state_align", recognizer.config, acousticModel, alignment.get()), + [](ps_search_t* search) { ps_search_free(search); }); + if (!search) throw runtime_error("Error creating search."); + + // Start recognition + error = acmod_start_utt(acousticModel); + if (error) throw runtime_error("Error starting utterance processing for alignment."); + + // Start search + ps_search_start(search.get()); + + // Process entire sound file + auto processBuffer = [&recognizer, &acousticModel, &search](const vector& buffer) { + const int16* nextSample = buffer.data(); + size_t remainingSamples = buffer.size(); + while (acmod_process_raw(acousticModel, &nextSample, &remainingSamples, false) > 0) { + while (acousticModel->n_feat_frame > 0) { + ps_search_step(search.get(), acousticModel->output_frame); + acmod_advance(acousticModel); + } + } + }; + processAudioStream(*audioStream.get(), processBuffer, reportProgress); + + // End search + ps_search_finish(search.get()); + + // End recognition + acmod_end_utt(acousticModel); + + // Extract phones with timestamps + char** phoneNames = recognizer.dict->mdef->ciname; + map result; + result[centiseconds(0)] = Phone::None; + for (ps_alignment_iter_t* it = ps_alignment_phones(alignment.get()); it; it = ps_alignment_iter_next(it)) { + // Get phone + ps_alignment_entry_t* phoneEntry = 
ps_alignment_iter_get(it); + s3cipid_t phoneId = phoneEntry->id.pid.cipid; + char* phoneName = phoneNames[phoneId]; + + // Get timing + int startFrame = phoneEntry->start; + int duration = phoneEntry->duration; + + // Add map entries + result[centiseconds(startFrame)] = stringToPhone(phoneName); + result[centiseconds(startFrame + duration)] = Phone::None; + } + return result; +} + +map detectPhones(std::function(void)> createAudioStream, function reportProgress) { // Discard Pocketsphinx output err_set_logfp(nullptr); @@ -179,19 +253,17 @@ map detectPhones(unique_ptr audioStream, funct path sphinxModelDirectory(getBinDirectory() / "res/sphinx"); auto config = createConfig(sphinxModelDirectory); - // Create phone recognizer - auto recognizer = createPhoneRecognizer(*config.get()); + // Create speech recognizer + auto recognizer = createSpeechRecognizer(*config.get()); - // Convert audio stream to the exact format PocketSphinx requires - audioStream = to16kHzMono(std::move(audioStream)); + // Recognize words + vector wordIds = recognizeWords(createAudioStream(), *recognizer.get(), reportProgress); - // Process data - processAudioStream(*audioStream.get(), *recognizer.get(), reportProgress); - - // Collect results into map - return getPhones(*recognizer.get()); + // Align the word's phones with speech + map result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), reportProgress); + return result; } catch (...) { - std::throw_with_nested(runtime_error("Error detecting phones via Pocketsphinx. " + errorMessage)); + std::throw_with_nested(runtime_error("Error performing speech recognition via Pocketsphinx. 
" + errorMessage)); } } diff --git a/src/phoneExtraction.h b/src/phoneExtraction.h index bfad07e..28969ab 100644 --- a/src/phoneExtraction.h +++ b/src/phoneExtraction.h @@ -2,8 +2,9 @@ #include #include +#include #include "audioInput/AudioStream.h" #include "Phone.h" #include "centiseconds.h" -std::map detectPhones(std::unique_ptr audioStream, std::function reportProgress); +std::map detectPhones(std::function(void)> createAudioStream, std::function reportProgress); diff --git a/tools.cmake b/tools.cmake index 9244f80..239b8b1 100644 --- a/tools.cmake +++ b/tools.cmake @@ -3,19 +3,21 @@ function(copy_and_install sourceGlob relativeTargetDirectory) file(GLOB sourcePaths "${sourceGlob}") foreach(sourcePath ${sourcePaths}) - # Set `fileName` - get_filename_component(fileName "${sourcePath}" NAME) - - # Copy file during build - add_custom_command(TARGET rhubarb POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy "${sourcePath}" "$/${relativeTargetDirectory}/${fileName}" - COMMENT "Creating '${relativeTargetDirectory}/${fileName}'" - ) + if(NOT IS_DIRECTORY ${sourcePath}) + # Set `fileName` + get_filename_component(fileName "${sourcePath}" NAME) - # Install file - install( - FILES "${sourcePath}" - DESTINATION "${relativeTargetDirectory}" - ) + # Copy file during build + add_custom_command(TARGET rhubarb POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy "${sourcePath}" "$/${relativeTargetDirectory}/${fileName}" + COMMENT "Creating '${relativeTargetDirectory}/${fileName}'" + ) + + # Install file + install( + FILES "${sourcePath}" + DESTINATION "${relativeTargetDirectory}" + ) + endif() endforeach() endfunction()