diff --git a/CMakeLists.txt b/CMakeLists.txt
index 2f6983d..2ca91b2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -191,6 +191,7 @@ set(SOURCE_FILES
 	src/pairs.h
 	src/Exporter.cpp src/Exporter.h
 	src/ascii.cpp src/ascii.h
+	src/tokenization.cpp src/tokenization.h
 )
 add_executable(rhubarb ${SOURCE_FILES})
 target_link_libraries(rhubarb ${Boost_LIBRARIES} cppFormat sphinxbase pocketSphinx flite)
@@ -205,15 +206,17 @@ set(TEST_FILES
 	tests/ContinuousTimelineTests.cpp
 	tests/pairsTests.cpp
 	tests/asciiTests.cpp
+	tests/tokenizationTests.cpp
 	src/stringTools.cpp src/stringTools.h
 	src/Timeline.h
 	src/TimeRange.cpp src/TimeRange.h
 	src/centiseconds.cpp src/centiseconds.h
 	src/pairs.h
 	src/ascii.cpp src/ascii.h
+	src/tokenization.cpp src/tokenization.h
 )
 add_executable(runTests ${TEST_FILES})
-target_link_libraries(runTests gtest gmock gmock_main)
+target_link_libraries(runTests gtest gmock gmock_main flite)
 
 set(CPACK_PACKAGE_NAME ${appName})
 string(REPLACE " " "-" CPACK_PACKAGE_NAME "${CPACK_PACKAGE_NAME}")
diff --git a/src/tokenization.cpp b/src/tokenization.cpp
new file mode 100644
index 0000000..aa1d944
--- /dev/null
+++ b/src/tokenization.cpp
@@ -0,0 +1,84 @@
+#include "tokenization.h"
+#include "tools.h"
+#include "ascii.h"
+#include <regex>
+
+extern "C" {
+#include <cst_utt_utils.h>
+#include <lang/usenglish/usenglish.h>
+#include <lang/cmulex/cmu_lex.h>
+}
+
+using std::runtime_error;
+using std::u32string;
+using std::string;
+using std::vector;
+using std::regex;
+using std::pair;
+
+lambda_unique_ptr<cst_voice> createDummyVoice() {
+	lambda_unique_ptr<cst_voice> voice(new_voice(), [](cst_voice* voice) { delete_voice(voice); });
+	voice->name = "dummy_voice";
+	usenglish_init(voice.get());
+	cst_lexicon *lexicon = cmu_lex_init();
+	feat_set(voice->features, "lexicon", lexicon_val(lexicon));
+	return voice;
+}
+
+static const cst_synth_module synth_method_normalize[] = {
+	{ "tokenizer_func", default_tokenization },		// split text into tokens
+	{ "textanalysis_func", default_textanalysis },	// transform tokens into words
+	{ nullptr, nullptr }
+};
+
+vector<string> tokenizeViaFlite(const string& text) {
+	// Create utterance object with text
+	lambda_unique_ptr<cst_utterance> utterance(new_utterance(), [](cst_utterance* utterance) { delete_utterance(utterance); });
+	utt_set_input_text(utterance.get(), text.c_str());
+	lambda_unique_ptr<cst_voice> voice = createDummyVoice();
+	utt_init(utterance.get(), voice.get());
+
+	// Perform tokenization and text normalization
+	if (!apply_synth_method(utterance.get(), synth_method_normalize)) {
+		throw runtime_error("Error normalizing text using Flite.");
+	}
+
+	vector<string> result;
+	for (cst_item* item = relation_head(utt_relation(utterance.get(), "Word")); item; item = item_next(item)) {
+		const char* word = item_feat_string(item, "name");
+		result.push_back(word);
+	}
+	return result;
+}
+
+vector<string> tokenizeText(const u32string& text) {
+	vector<string> words = tokenizeViaFlite(toASCII(text));
+
+	// Join words separated by apostophes
+	for (int i = words.size() - 1; i > 0; --i) {
+		if (words[i].size() > 0 && words[i][0] == '\'') {
+			words[i - 1].append(words[i]);
+			words.erase(words.begin() + i);
+		}
+	}
+
+	// Turn some symbols into words, remove the rest
+	vector<pair<regex, string>> replacements {
+		{ regex("&"), "and" },
+		{ regex("\\*"), "times" },
+		{ regex("\\+"), "plus" },
+		{ regex("="), "equals" },
+		{ regex("@"), "at" },
+		{ regex("[^a-z']"), "" }
+	};
+	for (size_t i = 0; i < words.size(); ++i) {
+		for (const auto& replacement : replacements) {
+			words[i] = std::regex_replace(words[i], replacement.first, replacement.second);
+		}
+	}
+
+	// Remove empty words
+	words.erase(std::remove_if(words.begin(), words.end(), [](const string& s) { return s.empty(); }), words.end());
+
+	return words;
+}
diff --git a/src/tokenization.h b/src/tokenization.h
new file mode 100644
index 0000000..55fcfb9
--- /dev/null
+++ b/src/tokenization.h
@@ -0,0 +1,5 @@
+#pragma once
+
+#include <vector>
+
+std::vector<std::string> tokenizeText(const std::u32string& text);
\ No newline at end of file
diff --git a/tests/tokenizationTests.cpp b/tests/tokenizationTests.cpp
new file mode 100644
index 0000000..14973ad
--- /dev/null
+++ b/tests/tokenizationTests.cpp
@@ -0,0 +1,71 @@
+#include <gmock/gmock.h>
+#include "tokenization.h"
+#include <regex>
+
+using namespace testing;
+using std::string;
+using std::u32string;
+using std::vector;
+using std::regex;
+
+TEST(tokenizeText, simpleCases) {
+	EXPECT_THAT(tokenizeText(U""), IsEmpty());
+	EXPECT_THAT(tokenizeText(U" \t\n\r\n "), IsEmpty());
+	EXPECT_THAT(
+		tokenizeText(U"Wit is educated insolence."),
+		ElementsAre("wit", "is", "educated", "insolence")
+	);
+}
+
+TEST(tokenizeText, numbers) {
+	EXPECT_THAT(
+		tokenizeText(U"Henry V died at 36."),
+		ElementsAre("henry", "the", "fifth", "died", "at", "thirty", "six")
+	);
+	EXPECT_THAT(
+		tokenizeText(U"I spent $4.50 on gum."),
+		ElementsAre("i", "spent", "four", "dollars", "fifty", "cents", "on", "gum")
+	);
+	EXPECT_THAT(
+		tokenizeText(U"I was born in 1982."),
+		ElementsAre("i", "was", "born", "in", "nineteen", "eighty", "two")
+	);
+}
+
+TEST(tokenizeText, abbreviations) {
+	EXPECT_THAT(
+		tokenizeText(U"I live on Dr. Dolittle Dr."),
+		ElementsAre("i", "live", "on", "doctor", "dolittle", "drive")
+	);
+}
+
+TEST(tokenizeText, apostrophes) {
+	// HACK: "wouldn't" really should not become "wouldnt"!
+	EXPECT_THAT(
+		tokenizeText(U"'Tis said he'd wish'd for a 'bus 'cause he wouldn't walk."),
+		ElementsAreArray(vector<string>{ "tis", "said", "he'd", "wish'd", "for", "a", "bus", "cause", "he", "wouldnt", "walk" })
+	);
+}
+
+TEST(tokenizeText, math) {
+	EXPECT_THAT(
+		tokenizeText(U"'1+2*3=7"),
+		ElementsAre("one", "plus", "two", "times", "three", "equals", "seven")
+	);
+}
+
+// Checks that each word contains only the characters a-z and the apostrophe
+TEST(tokenizeText, wordsUseLimitedCharacters) {
+	// Create string containing lots of undesirable characters
+	u32string input = U"A naïve man called 晨 was having piña colada and crème brûlée.";
+	for (char32_t c = 0; c <= 1000; ++c) {
+		input.append(U" ");
+		input.append(1, c);
+	}
+
+	regex legal("^[a-z']+$");
+	auto words = tokenizeText(input);
+	for (const string& word : words) {
+		EXPECT_TRUE(std::regex_match(word, legal)) << word;
+	}
+}