diff --git a/src/main.cpp b/src/main.cpp
index 0ffaa95..b4511eb 100644
--- a/src/main.cpp
+++ b/src/main.cpp
@@ -2,16 +2,42 @@
 #include "audio_input/WaveFileReader.h"
 #include "phone_extraction.h"
 
-int main(int argc, char *argv[]) {
-	// Create audio stream
-	std::unique_ptr<AudioStream> audioStream(
-		new WaveFileReader(R"(C:\Users\Daniel\Desktop\audio-test\test 16000Hz 1ch 16bit.wav)"));
+using std::exception;
+using std::string;
+using std::unique_ptr;
 
-	std::map<centiseconds, Phone> phones = detectPhones(std::move(audioStream));
+string getMessage(const exception& e) {
+	string result(e.what());
+	try {
+		std::rethrow_if_nested(e);
+	} catch(const exception& innerException) {
+		result += "\n" + getMessage(innerException);
+	} catch(...) {}
 
-	for (auto& pair : phones) {
-		std::cout << pair.first << ": " << phoneToString(pair.second) << "\n";
+	return result;
+}
+
+unique_ptr<AudioStream> createAudioStream(string fileName) {
+	try {
+		return unique_ptr<AudioStream>(new WaveFileReader(fileName));
+	} catch (...) {
+		std::throw_with_nested(std::runtime_error("Could not open sound file.") );
 	}
+}
 
-	return 0;
+int main(int argc, char *argv[]) {
+	try {
+		unique_ptr<AudioStream> audioStream = createAudioStream(R"(C:\Users\Daniel\Desktop\audio-test\test 16000Hz 1ch 16bit.wav)");
+
+		std::map<centiseconds, Phone> phones = detectPhones(std::move(audioStream));
+
+		for (auto &pair : phones) {
+			std::cout << pair.first << ": " << phoneToString(pair.second) << "\n";
+		}
+
+		return 0;
+	} catch (const exception& e) {
+		std::cout << "An error occurred. " << getMessage(e);
+		return 1;
+	}
 }
\ No newline at end of file
diff --git a/src/phone_extraction.cpp b/src/phone_extraction.cpp
index 272da62..7a95925 100644
--- a/src/phone_extraction.cpp
+++ b/src/phone_extraction.cpp
@@ -12,6 +12,9 @@ using std::string;
 using std::map;
 using boost::filesystem::path;
 
+template<typename T>
+using lambda_unique_ptr = std::unique_ptr<T, std::function<void(T*)>>;
+
 unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
 	// Downmix, if required
 	if (stream->getChannelCount() != 1) {
@@ -20,7 +23,7 @@ unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
 
 	// Downsample, if required
 	if (stream->getFrameRate() < 16000) {
-		throw runtime_error("Sample rate must not be below 16kHz.");
+		throw runtime_error("Audio sample rate must not be below 16kHz.");
 	}
 	if (stream->getFrameRate() != 16000) {
 		stream.reset(new SampleRateConverter(std::move(stream), 16000));
@@ -29,27 +32,14 @@ unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
 	return stream;
 }
 
-// Converts a float in the range -1..1 to a signed 16-bit int
-int16_t floatSampleToInt16(float sample) {
-	sample = std::max(sample, -1.0f);
-	sample = std::min(sample, 1.0f);
-	return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
-}
-
-map<centiseconds, Phone> detectPhones(unique_ptr<AudioStream> audioStream) {
-	// Convert audio stream to the exact format PocketSphinx requires
-	audioStream = to16kHzMono(std::move(audioStream));
-
-	// Create PocketSphinx configuration
-	path binDirectory(getBinDirectory());
-	path resDirectory(binDirectory.parent_path() / "res");
-	shared_ptr<cmd_ln_t> config(
+lambda_unique_ptr<cmd_ln_t> createConfig(path sphinxModelDirectory) {
+	lambda_unique_ptr<cmd_ln_t> config(
 		cmd_ln_init(
 			nullptr, ps_args(), true,
 			// Set acoustic model
-			"-hmm", (resDirectory / "sphinx/acoustic_model").string().c_str(),
+			"-hmm", (sphinxModelDirectory / "acoustic_model").string().c_str(),
 			// Set phonetic language model
-			"-allphone", (resDirectory / "sphinx/en-us-phone.lm.bin").string().c_str(),
+			"-allphone", (sphinxModelDirectory / "en-us-phone.lm.bin").string().c_str(),
 			"-allphone_ci", "yes",
 			// The following settings are taken from http://cmusphinx.sourceforge.net/wiki/phonemerecognition
 			// Set beam width applied to every frame in Viterbi search
@@ -62,14 +52,28 @@ map<centiseconds, Phone> detectPhones(unique_ptr<AudioStream> audioStream) {
 		[](cmd_ln_t* config) { cmd_ln_free_r(config); });
 	if (!config) throw runtime_error("Error creating configuration.");
 
-	// Create phone recognizer
-	shared_ptr<ps_decoder_t> recognizer(
-		ps_init(config.get()),
+	return config;
+}
+
+lambda_unique_ptr<ps_decoder_t> createPhoneRecognizer(cmd_ln_t& config) {
+	lambda_unique_ptr<ps_decoder_t> recognizer(
+		ps_init(&config),
 		[](ps_decoder_t* recognizer) { ps_free(recognizer); });
 	if (!recognizer) throw runtime_error("Error creating speech recognizer.");
 
+	return recognizer;
+}
+
+// Converts a float in the range -1..1 to a signed 16-bit int
+int16_t floatSampleToInt16(float sample) {
+	sample = std::max(sample, -1.0f);
+	sample = std::min(sample, 1.0f);
+	return static_cast<int16_t>(((sample + 1) / 2) * (INT16_MAX - INT16_MIN) + INT16_MIN);
+}
+
+void processAudioStream(AudioStream& audioStream16kHzMono, ps_decoder_t& recognizer) {
 	// Start recognition
-	int error = ps_start_utt(recognizer.get());
+	int error = ps_start_utt(&recognizer);
 	if (error) throw runtime_error("Error starting utterance processing.");
 
 	// Process entire sound file
@@ -82,25 +86,27 @@ map<centiseconds, Phone> detectPhones(unique_ptr<AudioStream> audioStream) {
 		buffer.clear();
 		while (buffer.size() < capacity) {
 			float sample;
-			if (!audioStream->getNextSample(sample)) break;
+			if (!audioStream16kHzMono.getNextSample(sample)) break;
 			buffer.push_back(floatSampleToInt16(sample));
 		}
 
 		// Analyze buffer
-		int searchedFrameCount = ps_process_raw(recognizer.get(), buffer.data(), buffer.size(), false, false);
-		if (searchedFrameCount < 0) throw runtime_error("Error decoding raw audio data.");
+		int searchedFrameCount = ps_process_raw(&recognizer, buffer.data(), buffer.size(), false, false);
+		if (searchedFrameCount < 0) throw runtime_error("Error analyzing raw audio data.");
 
 		sampleCount += buffer.size();
 	} while (buffer.size());
-	error = ps_end_utt(recognizer.get());
+	error = ps_end_utt(&recognizer);
 	if (error) throw runtime_error("Error ending utterance processing.");
 
-	// Collect results into map
+}
+
+map<centiseconds, Phone> getPhones(ps_decoder_t& recognizer) {
 	map<centiseconds, Phone> result;
 	ps_seg_t *segmentationIter;
 	int32 score;
 	int endFrame;
-	for (segmentationIter = ps_seg_iter(recognizer.get(), &score); segmentationIter; segmentationIter = ps_seg_next(segmentationIter)) {
+	for (segmentationIter = ps_seg_iter(&recognizer, &score); segmentationIter; segmentationIter = ps_seg_next(segmentationIter)) {
 		// Get phone
 		char const *phone = ps_seg_word(segmentationIter);
 
@@ -113,4 +119,26 @@ map<centiseconds, Phone> detectPhones(unique_ptr<AudioStream> audioStream) {
 	// Add dummy entry past the last phone
 	result[centiseconds(endFrame + 1)] = Phone::None;
 	return result;
+};
+
+map<centiseconds, Phone> detectPhones(unique_ptr<AudioStream> audioStream) {
+	try {
+		// Create PocketSphinx configuration
+		path sphinxModelDirectory(getBinDirectory().parent_path() / "res/sphinx");
+		auto config = createConfig(sphinxModelDirectory);
+
+		// Create phone recognizer
+		auto recognizer = createPhoneRecognizer(*config.get());
+
+		// Convert audio stream to the exact format PocketSphinx requires
+		audioStream = to16kHzMono(std::move(audioStream));
+
+		// Process data
+		processAudioStream(*audioStream.get(), *recognizer.get());
+
+		// Collect results into map
+		return getPhones(*recognizer.get());
+	} catch (...) {
+		std::throw_with_nested(runtime_error("Error detecting phones via Pocketsphinx."));
+	}
 }