From 0d488e8de23658724e9e12814d16ee43922bd3fa Mon Sep 17 00:00:00 2001 From: Daniel Wolf Date: Fri, 3 Jun 2016 21:07:49 +0200 Subject: [PATCH] Restored dialog option, this time based on language model This approach should be more robust and error-tolerant. --- CMakeLists.txt | 4 +- src/languageModels.cpp | 183 ++++++++++++++++++++++++++++++++++++++++ src/languageModels.h | 6 ++ src/main.cpp | 24 ++++++ src/phoneExtraction.cpp | 140 ++++++++++++++++-------------- src/phoneExtraction.h | 1 + src/platformTools.cpp | 11 +++ src/platformTools.h | 1 + src/stringTools.cpp | 10 +++ src/stringTools.h | 5 +- src/tupleHash.h | 40 +++++++++ 11 files changed, 357 insertions(+), 68 deletions(-) create mode 100644 src/languageModels.cpp create mode 100644 src/languageModels.h create mode 100644 src/tupleHash.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 5dfbe13..d072222 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -64,7 +64,7 @@ target_compile_options(cppFormat PRIVATE ${disableWarningsFlags}) set_target_properties(cppFormat PROPERTIES FOLDER lib) # ... sphinxbase -include_directories(SYSTEM "lib/sphinxbase-5prealpha-2015-08-05/include") +include_directories(SYSTEM "lib/sphinxbase-5prealpha-2015-08-05/include" "lib/sphinxbase-5prealpha-2015-08-05/src") FILE(GLOB_RECURSE sphinxbaseFiles "lib/sphinxbase-5prealpha-2015-08-05/src/libsphinxbase/*.c") add_library(sphinxbase ${sphinxbaseFiles}) target_compile_options(sphinxbase PRIVATE ${disableWarningsFlags}) @@ -192,6 +192,8 @@ set(SOURCE_FILES src/Exporter.cpp src/Exporter.h src/tokenization.cpp src/tokenization.h src/g2p.cpp src/g2p.h + src/languageModels.cpp src/languageModels.h + src/tupleHash.h ) add_executable(rhubarb ${SOURCE_FILES}) target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite) diff --git a/src/languageModels.cpp b/src/languageModels.cpp new file mode 100644 index 0000000..1574a64 --- /dev/null +++ b/src/languageModels.cpp @@ -0,0 +1,183 @@ +#include "languageModels.h" +#include +#include +#include +#include +#include +#include "platformTools.h" +#include +#include "appInfo.h" +#include +#include + +using std::string; +using std::u32string; +using std::vector; +using std::regex; +using std::map; +using std::tuple; +using std::make_tuple; +using std::get; +using std::endl; +using boost::filesystem::path; + +using unigram_t = string; +using bigram_t = tuple; +using trigram_t = tuple; + +map getUnigramCounts(const vector& words) { + map unigramCounts; + for (const unigram_t& unigram : words) { + ++unigramCounts[unigram]; + } + return unigramCounts; +} + +map getBigramCounts(const vector& words) { + map bigramCounts; + for (auto it = words.begin(); it < words.end() - 1; ++it) { + ++bigramCounts[bigram_t(*it, *(it + 1))]; + } + return bigramCounts; +} + +map getTrigramCounts(const vector& words) { + map trigramCounts; + if (words.size() >= 3) { + for (auto it = words.begin(); it < words.end() - 2; ++it) { + ++trigramCounts[trigram_t(*it, *(it + 1), *(it + 2))]; + } + } + return trigramCounts; +} + +map getUnigramProbabilities(const vector& words, const map& unigramCounts, const double deflator) { + map unigramProbabilities; + for (const auto& pair : unigramCounts) { + unigram_t unigram = get<0>(pair); + int unigramCount = get<1>(pair); + unigramProbabilities[unigram] = double(unigramCount) / words.size() * deflator; + } + return unigramProbabilities; +} + +map getBigramProbabilities(const map& unigramCounts, const map& bigramCounts, const double deflator) { + map bigramProbabilities; + for (const auto& pair : bigramCounts) { + bigram_t bigram = get<0>(pair); + int bigramCount = get<1>(pair); + int unigramPrefixCount = unigramCounts.at(get<0>(bigram)); + bigramProbabilities[bigram] = double(bigramCount) / unigramPrefixCount * deflator; + } + return bigramProbabilities; +} + +map getTrigramProbabilities(const map& bigramCounts, const map& trigramCounts, const double deflator) { + map trigramProbabilities; + for (const auto& pair : trigramCounts) { + trigram_t trigram = get<0>(pair); + int trigramCount = get<1>(pair); + int bigramPrefixCount = bigramCounts.at(bigram_t(get<0>(trigram), get<1>(trigram))); + trigramProbabilities[trigram] = double(trigramCount) / bigramPrefixCount * deflator; + } + return trigramProbabilities; +} + +map getUnigramBackoffWeights( + const map& unigramCounts, + const map& unigramProbabilities, + const map& bigramCounts, + const double discountMass) +{ + map unigramBackoffWeights; + for (const unigram_t& unigram : unigramCounts | boost::adaptors::map_keys) { + double denominator = 1; + for (const bigram_t& bigram : bigramCounts | boost::adaptors::map_keys) { + if (get<0>(bigram) == unigram) { + denominator -= unigramProbabilities.at(get<1>(bigram)); + } + } + unigramBackoffWeights[unigram] = discountMass / denominator; + } + return unigramBackoffWeights; +} + +map getBigramBackoffWeights( + const map& bigramCounts, + const map& bigramProbabilities, + const map& trigramCounts, + const double discountMass) +{ + map bigramBackoffWeights; + for (const bigram_t& bigram : bigramCounts | boost::adaptors::map_keys) { + double denominator = 1; + for (const trigram_t& trigram : trigramCounts | boost::adaptors::map_keys) { + if (bigram_t(get<0>(trigram), get<1>(trigram)) == bigram) { + denominator -= bigramProbabilities.at(bigram_t(get<1>(trigram), get<2>(trigram))); + } + } + bigramBackoffWeights[bigram] = discountMass / denominator; + } + return bigramBackoffWeights; +} + +void createLanguageModelFile(const vector& words, path filePath) { + const double discountMass = 0.5; + const double deflator = 1.0 - discountMass; + + map unigramCounts = getUnigramCounts(words); + map bigramCounts = getBigramCounts(words); + map trigramCounts = getTrigramCounts(words); + + map unigramProbabilities = getUnigramProbabilities(words, unigramCounts, deflator); + map bigramProbabilities = getBigramProbabilities(unigramCounts, bigramCounts, deflator); + map trigramProbabilities = getTrigramProbabilities(bigramCounts, trigramCounts, deflator); + + map unigramBackoffWeights = getUnigramBackoffWeights(unigramCounts, unigramProbabilities, bigramCounts, discountMass); + map bigramBackoffWeights = getBigramBackoffWeights(bigramCounts, bigramProbabilities, trigramCounts, discountMass); + + boost::filesystem::ofstream file(filePath); + file << "Generated by " << appName << " " << appVersion << endl << endl; + + file << "\\data\\" << endl; + file << "ngram 1=" << unigramCounts.size() << endl; + file << "ngram 2=" << bigramCounts.size() << endl; + file << "ngram 3=" << trigramCounts.size() << endl << endl; + + file.setf(std::ios::fixed, std::ios::floatfield); + file.precision(4); + file << "\\1-grams:" << endl; + for (const unigram_t& unigram : unigramCounts | boost::adaptors::map_keys) { + file << log10(unigramProbabilities.at(unigram)) + << " " << unigram + << " " << log10(unigramBackoffWeights.at(unigram)) << endl; + } + file << endl; + + file << "\\2-grams:" << endl; + for (const bigram_t& bigram : bigramCounts | boost::adaptors::map_keys) { + file << log10(bigramProbabilities.at(bigram)) + << " " << get<0>(bigram) << " " << get<1>(bigram) + << " " << log10(bigramBackoffWeights.at(bigram)) << endl; + } + file << endl; + + file << "\\3-grams:" << endl; + for (const trigram_t& trigram : trigramCounts | boost::adaptors::map_keys) { + file << log10(trigramProbabilities.at(trigram)) + << " " << get<0>(trigram) << " " << get<1>(trigram) << " " << get<2>(trigram) << endl; + } + file << endl; + + file << "\\end\\" << endl; +} + +lambda_unique_ptr createLanguageModel(const vector& words, logmath_t& logMath) { + path tempFilePath = getTempFilePath(); + createLanguageModelFile(words, tempFilePath); + auto deleteTempFile = gsl::finally([&]() { boost::filesystem::remove(tempFilePath); }); + + return lambda_unique_ptr( + ngram_model_read(nullptr, tempFilePath.string().c_str(), NGRAM_ARPA, &logMath), + [](ngram_model_t* lm) { ngram_model_free(lm); }); +} diff --git a/src/languageModels.h b/src/languageModels.h new file mode 100644 index 0000000..f67d57f --- /dev/null +++ b/src/languageModels.h @@ -0,0 +1,6 @@ +#pragma once +#include +#include +#include "tools.h" + +lambda_unique_ptr createLanguageModel(const std::vector& words, logmath_t& logMath); diff --git a/src/main.cpp b/src/main.cpp index 9621f45..131cbdf 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -12,9 +12,12 @@ #include #include "Exporter.h" #include "ContinuousTimeline.h" +#include +#include "stringTools.h" using std::exception; using std::string; +using std::u32string; using std::vector; using std::unique_ptr; using std::make_unique; @@ -75,6 +78,25 @@ void addFileSink(path path, logging::Level minLevel) { logging::addSink(levelFilter); } +u32string readTextFile(path filePath) { + if (!exists(filePath)) { + throw std::invalid_argument(fmt::format("File {} does not exist.", filePath)); + } + try { + boost::filesystem::ifstream file; + file.exceptions(std::ifstream::failbit | std::ifstream::badbit); + file.open(filePath); + string utf8Text((std::istreambuf_iterator(file)), std::istreambuf_iterator()); + try { + return utf8ToUtf32(utf8Text); + } catch (...) { + std::throw_with_nested(std::runtime_error(fmt::format("File encoding is not ASCII or UTF-8.", filePath))); + } + } catch (...) { + std::throw_with_nested(std::runtime_error(fmt::format("Error reading file {0}.", filePath))); + } +} + int main(int argc, char *argv[]) { auto pausableStderrSink = addPausableStdErrSink(logging::Level::Warn); pausableStderrSink->pause(); @@ -88,6 +110,7 @@ int main(int argc, char *argv[]) { tclap::ValuesConstraint logLevelConstraint(logLevels); tclap::ValueArg logLevel("", "logLevel", "The minimum log level to log", false, logging::Level::Debug, &logLevelConstraint, cmd); tclap::ValueArg logFileName("", "logFile", "The log file path.", false, string(), "string", cmd); + tclap::ValueArg dialogFile("d", "dialogFile", "A file containing the text of the dialog.", false, string(), "string", cmd); auto exportFormats = vector(ExportFormatConverter::get().getValues()); tclap::ValuesConstraint exportFormatConstraint(exportFormats); tclap::ValueArg exportFormat("f", "exportFormat", "The export format.", false, ExportFormat::TSV, &exportFormatConstraint, cmd); @@ -117,6 +140,7 @@ int main(int argc, char *argv[]) { ProgressBar progressBar; phones = detectPhones( createAudioStream(inputFileName.getValue()), + dialogFile.isSet() ? readTextFile(path(dialogFile.getValue())) : boost::optional(), progressBar); } std::cerr << "Done" << std::endl; diff --git a/src/phoneExtraction.cpp b/src/phoneExtraction.cpp index caa2168..8201899 100644 --- a/src/phoneExtraction.cpp +++ b/src/phoneExtraction.cpp @@ -1,6 +1,5 @@ #include #include -#include #include "phoneExtraction.h" #include "audio/SampleRateConverter.h" #include "platformTools.h" @@ -14,6 +13,9 @@ #include #include