Added (primitive) option to explicitly supply the dialog
This commit is contained in:
parent
ee9e23a6d2
commit
05ef692706
15
src/main.cpp
15
src/main.cpp
|
@ -1,6 +1,7 @@
|
|||
#include <iostream>
|
||||
#include <boost/property_tree/ptree.hpp>
|
||||
#include <boost/property_tree/xml_parser.hpp>
|
||||
#include <boost/optional.hpp>
|
||||
#include <format.h>
|
||||
#include <tclap/CmdLine.h>
|
||||
#include "audioInput/WaveFileReader.h"
|
||||
|
@ -69,6 +70,14 @@ ptree createXmlTree(const path& filePath, const map<centiseconds, Phone>& phones
|
|||
return tree;
|
||||
}
|
||||
|
||||
// Tell TCLAP how to handle boost::optional
|
||||
namespace TCLAP {
|
||||
template<>
|
||||
struct ArgTraits<boost::optional<string>> {
|
||||
typedef TCLAP::StringLike ValueCategory;
|
||||
};
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
// Define command-line parameters
|
||||
const char argumentValueSeparator = ' ';
|
||||
|
@ -76,6 +85,7 @@ int main(int argc, char *argv[]) {
|
|||
cmd.setExceptionHandling(false);
|
||||
cmd.setOutput(new NiceCmdLineOutput());
|
||||
TCLAP::UnlabeledValueArg<string> inputFileName("inputFile", "The input file. Must be a sound file in WAVE format.", true, "", "string", cmd);
|
||||
TCLAP::ValueArg<boost::optional<string>> dialog("d", "dialog", "The text of the dialog.", false, boost::optional<string>(), "string", cmd);
|
||||
|
||||
try {
|
||||
// Parse command line
|
||||
|
@ -88,7 +98,10 @@ int main(int argc, char *argv[]) {
|
|||
map<centiseconds, Phone> phones;
|
||||
{
|
||||
ProgressBar progressBar;
|
||||
phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar);
|
||||
phones = detectPhones(
|
||||
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); },
|
||||
dialog.getValue(),
|
||||
progressBar);
|
||||
}
|
||||
std::cerr << "Done" << std::endl;
|
||||
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
#include "platformTools.h"
|
||||
#include "tools.h"
|
||||
#include <format.h>
|
||||
#include <s3types.h>
|
||||
#include <regex>
|
||||
|
||||
extern "C" {
|
||||
#include <pocketsphinx.h>
|
||||
|
@ -17,6 +19,7 @@ extern "C" {
|
|||
}
|
||||
|
||||
using std::runtime_error;
|
||||
using std::invalid_argument;
|
||||
using std::unique_ptr;
|
||||
using std::shared_ptr;
|
||||
using std::string;
|
||||
|
@ -24,6 +27,8 @@ using std::vector;
|
|||
using std::map;
|
||||
using boost::filesystem::path;
|
||||
using std::function;
|
||||
using std::regex;
|
||||
using std::regex_replace;
|
||||
|
||||
unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
|
||||
// Downmix, if required
|
||||
|
@ -33,7 +38,7 @@ unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
|
|||
|
||||
// Downsample, if required
|
||||
if (stream->getFrameRate() < 16000) {
|
||||
throw runtime_error("Audio sample rate must not be below 16kHz.");
|
||||
throw invalid_argument("Audio sample rate must not be below 16kHz.");
|
||||
}
|
||||
if (stream->getFrameRate() != 16000) {
|
||||
stream.reset(new SampleRateConverter(std::move(stream), 16000));
|
||||
|
@ -133,7 +138,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form
|
|||
*errorString += message;
|
||||
}
|
||||
|
||||
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||
vector<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||
// Convert audio stream to the exact format PocketSphinx requires
|
||||
audioStream = to16kHzMono(std::move(audioStream));
|
||||
|
||||
|
@ -153,19 +158,55 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
|
|||
if (error) throw runtime_error("Error ending utterance processing for word recognition.");
|
||||
|
||||
// Collect words
|
||||
vector<s3wid_t> result;
|
||||
vector<string> result;
|
||||
int32_t score;
|
||||
for (ps_seg_t* it = ps_seg_iter(&recognizer, &score); it; it = ps_seg_next(it)) {
|
||||
// Get word
|
||||
const char* word = ps_seg_word(it);
|
||||
s3wid_t wordId = dict_wordid(recognizer.dict, word);
|
||||
|
||||
result.push_back(wordId);
|
||||
result.push_back(word);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// Splits dialog into words, doing minimal preprocessing.
|
||||
// A robust solution should use TTS logic to cope with numbers, abbreviations, unknown words etc.
|
||||
vector<string> extractDialogWords(string dialog) {
|
||||
// Convert to lower case
|
||||
boost::algorithm::to_lower(dialog);
|
||||
|
||||
// Insert silences where appropriate
|
||||
dialog = regex_replace(dialog, regex("[,;.:!?] "), " <sil> ");
|
||||
|
||||
// Remove all undesired characters
|
||||
dialog = regex_replace(dialog, regex("[^a-z.'\\-0-9<>]"), " ");
|
||||
|
||||
// Collapse whitespace
|
||||
dialog = regex_replace(dialog, regex("\\s+"), " ");
|
||||
|
||||
// Trim
|
||||
boost::algorithm::trim(dialog);
|
||||
|
||||
// Ugly hack: Remove trailing period
|
||||
if (boost::algorithm::ends_with(dialog, ".")) {
|
||||
dialog.pop_back();
|
||||
}
|
||||
|
||||
// Split into words
|
||||
vector<string> result;
|
||||
boost::algorithm::split(result, dialog, boost::is_space());
|
||||
return result;
|
||||
}
|
||||
|
||||
vector<s3wid_t> getWordIds(const vector<string>& words, dict_t& dictionary) {
|
||||
vector<s3wid_t> result;
|
||||
for (const string& word : words) {
|
||||
s3wid_t wordId = dict_wordid(&dictionary, word.c_str());
|
||||
if (wordId == BAD_S3WID) throw invalid_argument(fmt::format("Unknown word '{}'.", word));
|
||||
result.push_back(wordId);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||
// Create alignment list
|
||||
lambda_unique_ptr<ps_alignment_t> alignment(
|
||||
|
@ -236,7 +277,11 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
|
|||
return result;
|
||||
}
|
||||
|
||||
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink) {
|
||||
map<centiseconds, Phone> detectPhones(
|
||||
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream,
|
||||
boost::optional<std::string> dialog,
|
||||
ProgressSink& progressSink)
|
||||
{
|
||||
// Discard Pocketsphinx output
|
||||
err_set_logfp(nullptr);
|
||||
|
||||
|
@ -253,11 +298,16 @@ map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>
|
|||
auto recognizer = createSpeechRecognizer(*config.get());
|
||||
|
||||
ProgressMerger progressMerger(progressSink);
|
||||
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0);
|
||||
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(dialog ? 0.0 : 1.0);
|
||||
ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5);
|
||||
|
||||
// Recognize words
|
||||
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
|
||||
// Get words
|
||||
vector<string> words = dialog
|
||||
? extractDialogWords(*dialog)
|
||||
: recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
|
||||
|
||||
// Look up words in dictionary
|
||||
vector<s3wid_t> wordIds = getWordIds(words, *recognizer->dict);
|
||||
|
||||
// Align the word's phones with speech
|
||||
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);
|
||||
|
|
|
@ -7,5 +7,9 @@
|
|||
#include "Phone.h"
|
||||
#include "centiseconds.h"
|
||||
#include "progressBar.h"
|
||||
#include <boost/optional/optional.hpp>
|
||||
|
||||
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink);
|
||||
std::map<centiseconds, Phone> detectPhones(
|
||||
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream,
|
||||
boost::optional<std::string> dialog,
|
||||
ProgressSink& progressSink);
|
||||
|
|
Loading…
Reference in New Issue