diff --git a/rhubarb/src/main/kotlin/vad.kt b/rhubarb/src/main/kotlin/vad.kt index bc0d839..60eef60 100644 --- a/rhubarb/src/main/kotlin/vad.kt +++ b/rhubarb/src/main/kotlin/vad.kt @@ -1,32 +1,27 @@ import org.apache.commons.lang3.mutable.MutableInt import kotlin.math.absoluteValue -/////////////////////////////////////////////////////////////////////////////////////////////////////// -// webrtc/common_audio/signal_processing/include/signal_processing_library.h - -// Macros specific for the fixed point implementation -val WEBRTC_SPL_WORD16_MAX = 32767 - /////////////////////////////////////////////////////////////////////////////////////////////////////// // webrtc/common_audio/signal_processing/include/spl_inl.h // webrtc/common_audio/signal_processing/spl_inl.c -// Table used by CountLeadingZeros32_NotBuiltin. For each UInt n -// that's a sequence of 0 bits followed by a sequence of 1 bits, the entry at -// index (n * 0x8c0b2891) shr 26 in this table gives the number of zero bits in -// n. -val kCountLeadingZeros32_Table = intArrayOf( +/** + * Table used by getLeadingZeroCount. + * For each UInt n that's a sequence of 0 bits followed by a sequence of 1 bits, the entry at index + * (n * 0x8c0b2891) shr 26 in this table gives the number of zero bits in n. + */ +val leadingZerosTable = intArrayOf( 32, 8, 17, -1, -1, 14, -1, -1, -1, 20, -1, -1, -1, 28, -1, 18, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 26, 25, 24, 4, 11, 23, 31, 3, 7, 10, 16, 22, 30, -1, -1, 2, 6, 13, 9, -1, 15, -1, 21, -1, 29, 19, -1, -1, -1, -1, -1, 1, 27, 5, 12 ).apply { assert(size == 64) } -// Returns the number of leading zero bits in the argument. -fun CountLeadingZeros32(n: UInt): Int { - // Normalize n by rounding up to the nearest number that is a sequence of 0 - // bits followed by a sequence of 1 bits. This number has the same number of - // leading zeros as the original n. There are exactly 33 such values. +/** Returns the number of leading zero bits in the argument. */ +fun getLeadingZeroCount(n: UInt): Int { + // Normalize n by rounding up to the nearest number that is a sequence of 0 bits followed by a + // sequence of 1 bits. This number has the same number of leading zeros as the original n. + // There are exactly 33 such values. var normalized = n normalized = normalized or (normalized shr 1) normalized = normalized or (normalized shr 2) @@ -34,50 +29,37 @@ fun CountLeadingZeros32(n: UInt): Int { normalized = normalized or (normalized shr 8) normalized = normalized or (normalized shr 16) - // Multiply the modified n with a constant selected (by exhaustive search) - // such that each of the 33 possible values of n give a product whose 6 most - // significant bits are unique. Then look up the answer in the table. - return kCountLeadingZeros32_Table[((normalized * 0x8c0b2891u) shr 26).toInt()] + // Multiply the modified n with a constant selected (by exhaustive search) such that each of the + // 33 possible values of n give a product whose 6 most significant bits are unique. + // Then look up the answer in the table. + return leadingZerosTable[((normalized * 0x8c0b2891u) shr 26).toInt()] } -// Return the number of steps a signed int can be left-shifted without overflow, -// or 0 if a == 0. -inline fun NormW32(a: Int): Int { - return if (a == 0) +/** + * Returns the number of bits by which a signed int can be left-shifted without overflow, or 0 if + * a == 0. + */ +fun normSigned(a: Int): Int = + if (a == 0) 0 else - CountLeadingZeros32((if (a < 0) a.inv() else a).toUInt()) - 1 -} + getLeadingZeroCount((if (a < 0) a.inv() else a).toUInt()) - 1 -// Return the number of steps an unsigned int can be left-shifted without overflow, -// or 0 if a == 0. -inline fun NormU32(a: UInt): Int { - return if (a == 0u) 0 else CountLeadingZeros32(a) -} +/** + * Returns the number of bits by which an unsigned int can be left-shifted without overflow, or 0 if + * a == 0. + */ +fun normUnsigned(a: UInt): Int = if (a == 0u) 0 else getLeadingZeroCount(a) -inline fun GetSizeInBits(n: UInt): Int { - return 32 - CountLeadingZeros32(n) -} +/** Returns the number of bits needed to represent the specified value. */ +fun getBitCount(n: UInt): Int = 32 - getLeadingZeroCount(n) -/////////////////////////////////////////////////////////////////////////////////////////////////////// -// webrtc/common_audio/signal_processing/get_scaling_square.c - -// -// GetScalingSquare(...) -// -// Returns the # of bits required to scale the samples specified in the -// [in_vector] parameter so that, if the squares of the samples are added the -// # of times specified by the [times] parameter, the 32-bit addition will not -// overflow (result in Int). -// -// Input: -// - in_vector : Input vector to check scaling on -// - in_vector_length : Samples in [in_vector] -// - times : Number of additions to be performed -// -// Return value : Number of right bit shifts needed to avoid -// overflow in the addition calculation -fun GetScalingSquare(buffer: AudioBuffer, times: Int): Int { +/** + * Returns the number of right bit shifts that must be applied to each of the given samples so that, + * if the squares of the samples are added [times] times, the signed 32-bit addition will not + * overflow. +*/ +fun getScalingSquare(buffer: AudioBuffer, times: Int): Int { var maxAbsSample = -1 for (i in 0 until buffer.size) { val absSample = buffer[i].toInt().absoluteValue @@ -90,39 +72,25 @@ fun GetScalingSquare(buffer: AudioBuffer, times: Int): Int { return 0 // Since norm(0) returns 0 } - val t = NormW32(maxAbsSample * maxAbsSample) - val bitCount = GetSizeInBits(times.toUInt()) + val t = normSigned(maxAbsSample * maxAbsSample) + val bitCount = getBitCount(times.toUInt()) return if (t > bitCount) 0 else bitCount - t } -/////////////////////////////////////////////////////////////////////////////////////////////////////// -// webrtc/common_audio/signal_processing/energy.c - data class EnergyResult( - // Number of left bit shifts needed to get the physical energy value, i.e, to get the Q0 value + /** + * The number of left bit shifts needed to get the physical energy value, i.e, to get the Q0 + * value + */ val rightShifts: Int, - // Energy value in Q(-[scale_factor]) + /** The energy value in Q(-[scale_factor]) */ val energy: Int ) -// -// Energy(...) -// -// Calculates the energy of a vector -// -// Input: -// - vector : Vector which the energy should be calculated on -// - vector_length : Number of samples in vector -// -// Output: -// - scale_factor : Number of left bit shifts needed to get the physical -// energy value, i.e, to get the Q0 value -// -// Return value : Energy value in Q(-[scale_factor]) -// -fun Energy(buffer: AudioBuffer): EnergyResult { - val scaling = GetScalingSquare(buffer, buffer.size) +/** Calculates the energy of an audio buffer. */ +fun getEnergy(buffer: AudioBuffer): EnergyResult { + val scaling = getScalingSquare(buffer, buffer.size) var energy = 0 for (i in 0 until buffer.size) { @@ -132,118 +100,80 @@ fun Energy(buffer: AudioBuffer): EnergyResult { return EnergyResult(scaling, energy) } -/////////////////////////////////////////////////////////////////////////////////////////////////////// -// webrtc/common_audio/signal_processing/division_operations.c - -// -// DivW32W16(...) -// -// Divides a Int [num] by a Int [den]. -// -// If [den]==0, (Int)0x7FFFFFFF is returned. -// -// Input: -// - num : Numerator -// - den : Denominator -// -// Return value : Result of the division (as a Int), i.e., the -// integer part of num/den. -// -fun DivW32W16(num: Int, den: Int) = - if (den != 0) num / den else Int.MAX_VALUE - -/////////////////////////////////////////////////////////////////////////////////////////////////////// -// webrtc/common_audio/vad/vad_gmm.c +/** Performs a safe integer division, returning [Int.MAX_VALUE] if [denominator] = 0. */ +infix fun Int.safeDiv(denominator: Int) = if (denominator != 0) this / denominator else Int.MAX_VALUE data class GaussianProbabilityResult( - // (probability for [input]) = 1 / [std] * exp(-([input] - [mean])^2 / (2 * [std]^2)) + /** (probability for input) = 1 / std * exp(-(input - mean)^2 / (2 * std^2)) */ val probability: Int, - // Input used when updating the model, Q11. - // [delta] = ([input] - [mean]) / [std]^2. + + /** + * Input used when updating the model, Q11. + * delta = (input - mean) / std^2. + */ val delta: Int ) -val kCompVar = 22005 -val kLog2Exp = 5909 // log2(exp(1)) in Q12. - -// Calculates the probability for [input], given that [input] comes from a -// normal distribution with mean and standard deviation ([mean], [std]). -// -// Inputs: -// - input : input sample in Q4. -// - mean : mean input in the statistical model, Q7. -// - std : standard deviation, Q7. -// -// Output: -// -// - delta : input used when updating the model, Q11. -// [delta] = ([input] - [mean]) / [std]^2. -// -// Return: -// (probability for [input]) = -// 1 / [std] * exp(-([input] - [mean])^2 / (2 * [std]^2)); -//--------------------------------------------------------------------------------- -// For a normal distribution, the probability of [input] is calculated and -// returned (in Q20). The formula for normal distributed probability is -// -// 1 / s * exp(-(x - m)^2 / (2 * s^2)) -// -// where the parameters are given in the following Q domains: -// m = [mean] (Q7) -// s = [std] (Q7) -// x = [input] (Q4) -// in addition to the probability we output [delta] (in Q11) used when updating -// the noise/speech model. -fun GaussianProbability(input: Int, mean: Int, std: Int): GaussianProbabilityResult { +/** + * Calculates the probability for [input], given that [input] comes from a normal distribution with + * mean [mean] and standard deviation [std]. + * + * @param [input] Input sample in Q4. + * @param [mean] Mean input in the statistical model, Q7. + * @param [std] Standard deviation, Q7. +*/ +fun getGaussianProbability(input: Int, mean: Int, std: Int): GaussianProbabilityResult { var tmp16 = 0 - var inv_std = 0 - var inv_std2 = 0 - var exp_value = 0 + var invStd = 0 + var invStd2 = 0 + var expValue = 0 var tmp32 = 0 - // Calculate [inv_std] = 1 / s, in Q10. - // 131072 = 1 in Q17, and ([std] shr 1) is for rounding instead of truncation. - // Q-domain: Q17 / Q7 = Q10. + // Calculate invStd = 1 / s, in Q10. + // 131072 = 1 in Q17, and (std shr 1) is for rounding instead of truncation. + // Q-domain: Q17 / Q7 = Q10 tmp32 = 131072 + (std shr 1) - inv_std = DivW32W16(tmp32, std) + invStd = tmp32 safeDiv std - // Calculate [inv_std2] = 1 / s^2, in Q14. - tmp16 = inv_std shr 2 // Q10 -> Q8. - // Q-domain: (Q8 * Q8) shr 2 = Q14. - inv_std2 = (tmp16 * tmp16) shr 2 + // Calculate inv_std2 = 1 / s^2, in Q14 + tmp16 = invStd shr 2 // Q10 -> Q8. + // Q-domain: (Q8 * Q8) shr 2 = Q14 + invStd2 = (tmp16 * tmp16) shr 2 tmp16 = input shl 3 // Q4 -> Q7 tmp16 -= mean // Q7 - Q7 = Q7 // To be used later, when updating noise/speech model. - // [delta] = (x - m) / s^2, in Q11. - // Q-domain: (Q14 * Q7) shr 10 = Q11. - val delta = (inv_std2 * tmp16) shr 10 + // delta = (x - m) / s^2, in Q11. + // Q-domain: (Q14 * Q7) shr 10 = Q11 + val delta = (invStd2 * tmp16) shr 10 - // Calculate the exponent [tmp32] = (x - m)^2 / (2 * s^2), in Q10. Replacing - // division by two with one shift. + // Calculate the exponent [tmp32] = (x - m)^2 / (2 * s^2), in Q10. + // Replacing division by two with one shift. // Q-domain: (Q11 * Q7) shr 8 = Q10. tmp32 = (delta * tmp16) shr 9 - // If the exponent is small enough to give a non-zero probability we calculate - // [exp_value] ~= exp(-(x - m)^2 / (2 * s^2)) - // ~= exp2(-log2(exp(1)) * [tmp32]). + // If the exponent is small enough to give a non-zero probability, we calculate + // exp_value ~= exp(-(x - m)^2 / (2 * s^2)) + // ~= exp2(-log2(exp(1)) * tmp32) + val kCompVar = 22005 if (tmp32 < kCompVar) { // Calculate [tmp16] = log2(exp(1)) * [tmp32], in Q10. // Q-domain: (Q12 * Q10) shr 12 = Q10. + val kLog2Exp = 5909 // log2(exp(1)) in Q12. tmp16 = (kLog2Exp * tmp32) shr 12 tmp16 = -tmp16 - exp_value = 0x0400 or (tmp16 and 0x03FF) + expValue = 0x0400 or (tmp16 and 0x03FF) tmp16 = tmp16 xor 0xFFFF tmp16 = tmp16 shr 10 tmp16 += 1 // Get [exp_value] = exp(-[tmp32]) in Q10. - exp_value = exp_value shr tmp16 + expValue = expValue shr tmp16 } // Calculate and return (1 / s) * exp(-(x - m)^2 / (2 * s^2)), in Q20. // Q-domain: Q10 * Q10 = Q20. - val probability = inv_std * exp_value + val probability = invStd * expValue return GaussianProbabilityResult(probability, delta) } @@ -435,14 +365,14 @@ fun GmmProbability(self: VadInstT, features: List, total_power: Int, frame_ // Probability under H0, that is, probability of frame being noise. // Value given in Q27 = Q7 * Q20. - val pNoise = GaussianProbability(features[channel], self.noise_means[gaussian], self.noise_stds[gaussian]) + val pNoise = getGaussianProbability(features[channel], self.noise_means[gaussian], self.noise_stds[gaussian]) deltaN[gaussian] = pNoise.delta noise_probability[k] = kNoiseDataWeights[gaussian] * pNoise.probability h0_test += noise_probability[k] // Q27 // Probability under H1, that is, probability of frame being speech. // Value given in Q27 = Q7 * Q20. - val pSpeech = GaussianProbability(features[channel], self.speech_means[gaussian], self.speech_stds[gaussian]) + val pSpeech = getGaussianProbability(features[channel], self.speech_means[gaussian], self.speech_stds[gaussian]) speech_probability[k] = kSpeechDataWeights[gaussian] * pSpeech.probability deltaS[gaussian] = pSpeech.delta h1_test += speech_probability[k] // Q27 @@ -460,8 +390,8 @@ fun GmmProbability(self: VadInstT, features: List, total_power: Int, frame_ // // Note that b0 and b1 are values less than 1, hence, 0 <= log2(1+b0) < 1. // Further, b0 and b1 are independent and on the average the two terms cancel. - val shifts_h0 = if (h0_test != 0) NormW32(h0_test) else 31 - val shifts_h1 = if (h1_test != 0) NormW32(h1_test) else 31 + val shifts_h0 = if (h0_test != 0) normSigned(h0_test) else 31 + val shifts_h1 = if (h1_test != 0) normSigned(h1_test) else 31 val log_likelihood_ratio = shifts_h0 - shifts_h1 // Update [sum_log_likelihood_ratios] with spectrum weighting. This is @@ -481,7 +411,7 @@ fun GmmProbability(self: VadInstT, features: List, total_power: Int, frame_ // High probability of noise. Assign conditional probabilities for each // Gaussian in the GMM. val tmp = (noise_probability[0] and 0xFFFFF000u.toInt()) shl 2 // Q29 - ngprvec[channel] = DivW32W16(tmp, h0) // Q14 + ngprvec[channel] = tmp safeDiv h0 // Q14 ngprvec[channel + kNumChannels] = 16384 - ngprvec[channel] } else { // Low noise probability. Assign conditional probability 1 to the first @@ -495,7 +425,7 @@ fun GmmProbability(self: VadInstT, features: List, total_power: Int, frame_ // High probability of speech. Assign conditional probabilities for each // Gaussian in the GMM. Otherwise use the initialized values, i.e., 0. val tmp = (speech_probability[0] and 0xFFFFF000u.toInt()) shl 2 // Q29 - sgprvec[channel] = DivW32W16(tmp, h1) // Q14 + sgprvec[channel] = tmp safeDiv h1 // Q14 sgprvec[channel + kNumChannels] = 16384 - sgprvec[channel] } } @@ -589,9 +519,9 @@ fun GmmProbability(self: VadInstT, features: List, total_power: Int, frame_ // 0.1 * Q20 / Q7 = Q13. if (tmp2_s32 > 0) { - tmp_s16 = DivW32W16(tmp2_s32, ssk * 10) + tmp_s16 = tmp2_s32 safeDiv (ssk * 10) } else { - tmp_s16 = DivW32W16(-tmp2_s32, ssk * 10) + tmp_s16 = -tmp2_s32 safeDiv (ssk * 10) tmp_s16 = -tmp_s16 } // Divide by 4 giving an update factor of 0.025 (= 0.1 / 4). @@ -621,9 +551,9 @@ fun GmmProbability(self: VadInstT, features: List, total_power: Int, frame_ // Q20 / Q7 = Q13. if (tmp1_s32 > 0) { - tmp_s16 = DivW32W16(tmp1_s32, nsk) + tmp_s16 = tmp1_s32 safeDiv nsk } else { - tmp_s16 = DivW32W16(-tmp1_s32, nsk) + tmp_s16 = -tmp1_s32 safeDiv nsk tmp_s16 = -tmp_s16 } tmp_s16 += 32 // Rounding @@ -847,7 +777,7 @@ fun FindMinimum(self: VadInstT, feature_value: Int, channel: Int): Int { } } tmp32 = (alpha + 1) * self.mean_value[channel] - tmp32 += (WEBRTC_SPL_WORD16_MAX - alpha) * current_median + tmp32 += (Short.MAX_VALUE - alpha) * current_median tmp32 += 16384 self.mean_value[channel] = tmp32 shr 15 @@ -992,7 +922,7 @@ fun SplitFilter(input: AudioBuffer, upper_state: MutableInt, lower_state: Mutabl fun LogOfEnergy(input: AudioBuffer, offset: Int, total_energy: MutableInt): Int { assert(input.size > 0) - val energyResult = Energy(input) + val energyResult = getEnergy(input) // [tot_rshifts] accumulates the number of right shifts performed on [energy]. var tot_rshifts = energyResult.rightShifts // The [energy] will be normalized to 15 bits. We use unsigned integer because @@ -1005,7 +935,7 @@ fun LogOfEnergy(input: AudioBuffer, offset: Int, total_energy: MutableInt): Int // By construction, normalizing to 15 bits is equivalent with 17 leading // zeros of an unsigned 32 bit value. - val normalizing_rshifts = 17 - NormU32(energy) + val normalizing_rshifts = 17 - normUnsigned(energy) // In a 15 bit representation the leading bit is 2^14. log2(2^14) in Q10 is // (14 shl 10), which is what we initialize [log2_energy] with. For a more // detailed derivations, see below.