Added (primitive) option to explicitly supply the dialog
This commit is contained in:
parent
ee9e23a6d2
commit
05ef692706
15
src/main.cpp
15
src/main.cpp
|
@ -1,6 +1,7 @@
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <boost/property_tree/ptree.hpp>
|
#include <boost/property_tree/ptree.hpp>
|
||||||
#include <boost/property_tree/xml_parser.hpp>
|
#include <boost/property_tree/xml_parser.hpp>
|
||||||
|
#include <boost/optional.hpp>
|
||||||
#include <format.h>
|
#include <format.h>
|
||||||
#include <tclap/CmdLine.h>
|
#include <tclap/CmdLine.h>
|
||||||
#include "audioInput/WaveFileReader.h"
|
#include "audioInput/WaveFileReader.h"
|
||||||
|
@ -69,6 +70,14 @@ ptree createXmlTree(const path& filePath, const map<centiseconds, Phone>& phones
|
||||||
return tree;
|
return tree;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tell TCLAP how to handle boost::optional
|
||||||
|
namespace TCLAP {
|
||||||
|
template<>
|
||||||
|
struct ArgTraits<boost::optional<string>> {
|
||||||
|
typedef TCLAP::StringLike ValueCategory;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char *argv[]) {
|
int main(int argc, char *argv[]) {
|
||||||
// Define command-line parameters
|
// Define command-line parameters
|
||||||
const char argumentValueSeparator = ' ';
|
const char argumentValueSeparator = ' ';
|
||||||
|
@ -76,6 +85,7 @@ int main(int argc, char *argv[]) {
|
||||||
cmd.setExceptionHandling(false);
|
cmd.setExceptionHandling(false);
|
||||||
cmd.setOutput(new NiceCmdLineOutput());
|
cmd.setOutput(new NiceCmdLineOutput());
|
||||||
TCLAP::UnlabeledValueArg<string> inputFileName("inputFile", "The input file. Must be a sound file in WAVE format.", true, "", "string", cmd);
|
TCLAP::UnlabeledValueArg<string> inputFileName("inputFile", "The input file. Must be a sound file in WAVE format.", true, "", "string", cmd);
|
||||||
|
TCLAP::ValueArg<boost::optional<string>> dialog("d", "dialog", "The text of the dialog.", false, boost::optional<string>(), "string", cmd);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// Parse command line
|
// Parse command line
|
||||||
|
@ -88,7 +98,10 @@ int main(int argc, char *argv[]) {
|
||||||
map<centiseconds, Phone> phones;
|
map<centiseconds, Phone> phones;
|
||||||
{
|
{
|
||||||
ProgressBar progressBar;
|
ProgressBar progressBar;
|
||||||
phones = detectPhones([&inputFileName]() { return createAudioStream(inputFileName.getValue()); }, progressBar);
|
phones = detectPhones(
|
||||||
|
[&inputFileName]() { return createAudioStream(inputFileName.getValue()); },
|
||||||
|
dialog.getValue(),
|
||||||
|
progressBar);
|
||||||
}
|
}
|
||||||
std::cerr << "Done" << std::endl;
|
std::cerr << "Done" << std::endl;
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,8 @@
|
||||||
#include "platformTools.h"
|
#include "platformTools.h"
|
||||||
#include "tools.h"
|
#include "tools.h"
|
||||||
#include <format.h>
|
#include <format.h>
|
||||||
|
#include <s3types.h>
|
||||||
|
#include <regex>
|
||||||
|
|
||||||
extern "C" {
|
extern "C" {
|
||||||
#include <pocketsphinx.h>
|
#include <pocketsphinx.h>
|
||||||
|
@ -17,6 +19,7 @@ extern "C" {
|
||||||
}
|
}
|
||||||
|
|
||||||
using std::runtime_error;
|
using std::runtime_error;
|
||||||
|
using std::invalid_argument;
|
||||||
using std::unique_ptr;
|
using std::unique_ptr;
|
||||||
using std::shared_ptr;
|
using std::shared_ptr;
|
||||||
using std::string;
|
using std::string;
|
||||||
|
@ -24,6 +27,8 @@ using std::vector;
|
||||||
using std::map;
|
using std::map;
|
||||||
using boost::filesystem::path;
|
using boost::filesystem::path;
|
||||||
using std::function;
|
using std::function;
|
||||||
|
using std::regex;
|
||||||
|
using std::regex_replace;
|
||||||
|
|
||||||
unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
|
unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
|
||||||
// Downmix, if required
|
// Downmix, if required
|
||||||
|
@ -33,7 +38,7 @@ unique_ptr<AudioStream> to16kHzMono(unique_ptr<AudioStream> stream) {
|
||||||
|
|
||||||
// Downsample, if required
|
// Downsample, if required
|
||||||
if (stream->getFrameRate() < 16000) {
|
if (stream->getFrameRate() < 16000) {
|
||||||
throw runtime_error("Audio sample rate must not be below 16kHz.");
|
throw invalid_argument("Audio sample rate must not be below 16kHz.");
|
||||||
}
|
}
|
||||||
if (stream->getFrameRate() != 16000) {
|
if (stream->getFrameRate() != 16000) {
|
||||||
stream.reset(new SampleRateConverter(std::move(stream), 16000));
|
stream.reset(new SampleRateConverter(std::move(stream), 16000));
|
||||||
|
@ -133,7 +138,7 @@ void sphinxErrorCallback(void* user_data, err_lvl_t errorLevel, const char* form
|
||||||
*errorString += message;
|
*errorString += message;
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
vector<string> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||||
// Convert audio stream to the exact format PocketSphinx requires
|
// Convert audio stream to the exact format PocketSphinx requires
|
||||||
audioStream = to16kHzMono(std::move(audioStream));
|
audioStream = to16kHzMono(std::move(audioStream));
|
||||||
|
|
||||||
|
@ -153,19 +158,55 @@ vector<s3wid_t> recognizeWords(unique_ptr<AudioStream> audioStream, ps_decoder_t
|
||||||
if (error) throw runtime_error("Error ending utterance processing for word recognition.");
|
if (error) throw runtime_error("Error ending utterance processing for word recognition.");
|
||||||
|
|
||||||
// Collect words
|
// Collect words
|
||||||
vector<s3wid_t> result;
|
vector<string> result;
|
||||||
int32_t score;
|
int32_t score;
|
||||||
for (ps_seg_t* it = ps_seg_iter(&recognizer, &score); it; it = ps_seg_next(it)) {
|
for (ps_seg_t* it = ps_seg_iter(&recognizer, &score); it; it = ps_seg_next(it)) {
|
||||||
// Get word
|
|
||||||
const char* word = ps_seg_word(it);
|
const char* word = ps_seg_word(it);
|
||||||
s3wid_t wordId = dict_wordid(recognizer.dict, word);
|
result.push_back(word);
|
||||||
|
|
||||||
result.push_back(wordId);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Splits dialog into words, doing minimal preprocessing.
|
||||||
|
// A robust solution should use TTS logic to cope with numbers, abbreviations, unknown words etc.
|
||||||
|
vector<string> extractDialogWords(string dialog) {
|
||||||
|
// Convert to lower case
|
||||||
|
boost::algorithm::to_lower(dialog);
|
||||||
|
|
||||||
|
// Insert silences where appropriate
|
||||||
|
dialog = regex_replace(dialog, regex("[,;.:!?] "), " <sil> ");
|
||||||
|
|
||||||
|
// Remove all undesired characters
|
||||||
|
dialog = regex_replace(dialog, regex("[^a-z.'\\-0-9<>]"), " ");
|
||||||
|
|
||||||
|
// Collapse whitespace
|
||||||
|
dialog = regex_replace(dialog, regex("\\s+"), " ");
|
||||||
|
|
||||||
|
// Trim
|
||||||
|
boost::algorithm::trim(dialog);
|
||||||
|
|
||||||
|
// Ugly hack: Remove trailing period
|
||||||
|
if (boost::algorithm::ends_with(dialog, ".")) {
|
||||||
|
dialog.pop_back();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split into words
|
||||||
|
vector<string> result;
|
||||||
|
boost::algorithm::split(result, dialog, boost::is_space());
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Looks up every word in the given dictionary and returns the word IDs in the same order.
// Throws invalid_argument as soon as a word is not found in the dictionary.
vector<s3wid_t> getWordIds(const vector<string>& words, dict_t& dictionary) {
	vector<s3wid_t> ids;
	for (const string& currentWord : words) {
		const s3wid_t id = dict_wordid(&dictionary, currentWord.c_str());
		if (id == BAD_S3WID) {
			throw invalid_argument(fmt::format("Unknown word '{}'.", currentWord));
		}
		ids.push_back(id);
	}
	return ids;
}
|
||||||
|
|
||||||
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, unique_ptr<AudioStream> audioStream, ps_decoder_t& recognizer, ProgressSink& progressSink) {
|
||||||
// Create alignment list
|
// Create alignment list
|
||||||
lambda_unique_ptr<ps_alignment_t> alignment(
|
lambda_unique_ptr<ps_alignment_t> alignment(
|
||||||
|
@ -236,7 +277,11 @@ map<centiseconds, Phone> getPhoneAlignment(const vector<s3wid_t>& wordIds, uniqu
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink) {
|
map<centiseconds, Phone> detectPhones(
|
||||||
|
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream,
|
||||||
|
boost::optional<std::string> dialog,
|
||||||
|
ProgressSink& progressSink)
|
||||||
|
{
|
||||||
// Discard Pocketsphinx output
|
// Discard Pocketsphinx output
|
||||||
err_set_logfp(nullptr);
|
err_set_logfp(nullptr);
|
||||||
|
|
||||||
|
@ -253,11 +298,16 @@ map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>
|
||||||
auto recognizer = createSpeechRecognizer(*config.get());
|
auto recognizer = createSpeechRecognizer(*config.get());
|
||||||
|
|
||||||
ProgressMerger progressMerger(progressSink);
|
ProgressMerger progressMerger(progressSink);
|
||||||
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(1.0);
|
ProgressSink& wordRecognitionProgressSink = progressMerger.addSink(dialog ? 0.0 : 1.0);
|
||||||
ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5);
|
ProgressSink& alignmentProgressSink = progressMerger.addSink(0.5);
|
||||||
|
|
||||||
// Recognize words
|
// Get words
|
||||||
vector<s3wid_t> wordIds = recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
|
vector<string> words = dialog
|
||||||
|
? extractDialogWords(*dialog)
|
||||||
|
: recognizeWords(createAudioStream(), *recognizer.get(), wordRecognitionProgressSink);
|
||||||
|
|
||||||
|
// Look up words in dictionary
|
||||||
|
vector<s3wid_t> wordIds = getWordIds(words, *recognizer->dict);
|
||||||
|
|
||||||
// Align the word's phones with speech
|
// Align the word's phones with speech
|
||||||
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);
|
map<centiseconds, Phone> result = getPhoneAlignment(wordIds, createAudioStream(), *recognizer.get(), alignmentProgressSink);
|
||||||
|
|
|
@ -7,5 +7,9 @@
|
||||||
#include "Phone.h"
|
#include "Phone.h"
|
||||||
#include "centiseconds.h"
|
#include "centiseconds.h"
|
||||||
#include "progressBar.h"
|
#include "progressBar.h"
|
||||||
|
#include <boost/optional/optional.hpp>
|
||||||
|
|
||||||
std::map<centiseconds, Phone> detectPhones(std::function<std::unique_ptr<AudioStream>(void)> createAudioStream, ProgressSink& progressSink);
|
std::map<centiseconds, Phone> detectPhones(
|
||||||
|
std::function<std::unique_ptr<AudioStream>(void)> createAudioStream,
|
||||||
|
boost::optional<std::string> dialog,
|
||||||
|
ProgressSink& progressSink);
|
||||||
|
|
Loading…
Reference in New Issue