Added (primitive) option to explicitly supply the dialog

This commit is contained in:
Daniel Wolf 2016-02-09 22:08:11 +01:00
parent ee9e23a6d2
commit 05ef692706
3 changed files with 80 additions and 13 deletions

View File

@ -1,6 +1,7 @@
#include <iostream>
#include <boost/property_tree/ptree.hpp>
#include <boost/property_tree/xml_parser.hpp>
#include <boost/optional.hpp>
#include <format.h>
#include <tclap/CmdLine.h>
#include "audioInput/WaveFileReader.h"
@ -69,6 +70,14 @@ ptree createXmlTree(const path& filePath, const map<centiseconds, Phone>& phones
return tree;
}
// Tell TCLAP how to handle boost::optional
namespace TCLAP {
template<>
struct ArgTraits<boost::optional<string>> {
typedef TCLAP::StringLike ValueCategory;
};
}
int main(int argc, char *argv[]) {
// Define command-line parameters
const char argumentValueSeparator = ' ';
@ -76,6 +85,7 @@ int main(int argc, char *argv[]) {
cmd.setExceptionHandling(false);
cmd.setOutput(new NiceCmdLineOutput());
TCLAP::UnlabeledValueArg<string> inputFileName("inputFile", "The input file. Must be a sound file in WAVE format.", true, "", "string", cmd);
TCLAP::ValueArg<boost::optional<string>> dialog("d", "dialog", "The text of the dialog.", false, boost::optional<string>(), "string", cmd);
try {
// Parse command line
@ -88,7 +98,10 @@ int main(int argc, char *argv[]) {
map<centiseconds, Phone> phones;
{
ProgressBar progressBar;
phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar);
phones = detectPhones(
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); },
dialog.getValue(),
progressBar);
}
std::cerr << "Done" << std::endl;

View File

@ -7,6 +7,8 @@
#include "platformTools.h"
#include "tools.h"
#include <format.h>
#include <s3types.h>
#include <regex>
extern "C" {
#include <pocketsphinx.h>
@ -17,6 +19,7 @@ extern "C" {
}
using std::runtime_error;
using std::invalid_argument;
using std::unique_ptr;
using std::shared_ptr;
using std::string;
@ -24,6 +27,8 @@ using std::vector;
using std::map;
using boost::filesystem::path;
using std::function;
using std::regex;
using std::regex_replace;
unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
// Downmix, if required
@ -33,7 +38,7 @@ unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
// Downsample, if required
if (stream->getFrameRate() < 16000) {
throw runtime_error("Audio sample rate must not be below 16kHz.");
throw invalid_argument("Audio sample rate must not be below 16kHz.");
}
if (stream->getFrameRate() != 16000) {
stream.reset(new SampleRateConverter(std::move(stream), 16000));
@ -133,7 +138,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form
*errorString += message;
}
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
vector<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
// Convert audio stream to the exact format PocketSphinx requires
audioStream = to16kHzMono(std::move(audioStream));
@ -153,19 +158,55 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
if (error) throw runtime_error("Error ending utterance processing for word recognition.");
// Collect words
vector<s3wid_t> result;
vector<string> result;
int32_t score;
for (ps_seg_t* it = ps_seg_iter(&recognizer, &score); it; it = ps_seg_next(it)) {
// Get word
const char* word = ps_seg_word(it);
s3wid_t wordId = dict_wordid(recognizer.dict, word);
result.push_back(wordId);
result.push_back(word);
}
return result;
}
// Splits dialog into words, doing minimal preprocessing.
// A robust solution should use TTS logic to cope with numbers, abbreviations, unknown words etc.
//
// Processing steps:
//   1. Lower-case the text.
//   2. Replace punctuation followed by a space with the token "<sil>".
//   3. Replace every character other than a-z, 0-9, '.', '\'', '-', '<', '>' with a space.
//   4. Collapse whitespace runs, trim, and drop a single trailing period.
//   5. Split on whitespace.
//
// Returns the resulting tokens in order. The result never contains empty strings:
// a dialog that is empty or consists only of whitespace/punctuation yields an empty vector.
vector<string> extractDialogWords(string dialog) {
    // Convert to lower case
    boost::algorithm::to_lower(dialog);

    // Insert silences where appropriate
    dialog = regex_replace(dialog, regex("[,;.:!?] "), " <sil> ");

    // Remove all undesired characters
    dialog = regex_replace(dialog, regex("[^a-z.'\\-0-9<>]"), " ");

    // Collapse whitespace
    dialog = regex_replace(dialog, regex("\\s+"), " ");

    // Trim
    boost::algorithm::trim(dialog);

    // Ugly hack: Remove trailing period
    if (boost::algorithm::ends_with(dialog, ".")) {
        dialog.pop_back();
        // Removing the period may expose a trailing space (e.g. "hello ."),
        // which would otherwise produce an empty last token when splitting.
        boost::algorithm::trim(dialog);
    }

    // Split into words.
    // Splitting an empty string would yield a single empty token, so handle that case explicitly.
    vector<string> result;
    if (dialog.empty()) {
        return result;
    }
    boost::algorithm::split(result, dialog, boost::is_space());
    return result;
}
// Translates each word into its ID within the supplied dictionary.
// The returned IDs are in the same order as the input words.
// Throws invalid_argument as soon as a word is not found in the dictionary.
vector<s3wid_t> getWordIds(const vector<string>& words, dict_t& dictionary) {
    vector<s3wid_t> ids;
    for (auto it = words.begin(); it != words.end(); ++it) {
        const s3wid_t id = dict_wordid(&dictionary, it->c_str());
        if (id == BAD_S3WID) {
            throw invalid_argument(fmt::format("Unknown word '{}'.", *it));
        }
        ids.push_back(id);
    }
    return ids;
}
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
// Create alignment list
lambda_unique_ptr<ps_alignment_t> alignment(
@ -236,7 +277,11 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
return result;
}
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink) {
map<centiseconds, Phone> detectPhones(
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream,
boost::optional<std::string> dialog,
ProgressSink& progressSink)
{
// Discard Pocketsphinx output
err_set_logfp(nullptr);
@ -253,11 +298,16 @@ map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>
auto recognizer = createSpeechRecognizer(*config.get());
ProgressMerger progressMerger(progressSink);
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0);
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(dialog ? 0.0 : 1.0);
ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5);
// Recognize words
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
// Get words
vector<string> words = dialog
? extractDialogWords(*dialog)
: recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
// Look up words in dictionary
vector<s3wid_t> wordIds = getWordIds(words, *recognizer->dict);
// Align the words' phones with speech
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);

View File

@ -7,5 +7,9 @@
#include "Phone.h"
#include "centiseconds.h"
#include "progressBar.h"
#include <boost/optional/optional.hpp>
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink);
std::map<centiseconds, Phone> detectPhones(
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream,
boost::optional<std::string> dialog,
ProgressSink& progressSink);