diff --git a/modules/audio_processing/test/py_quality_assessment/BUILD.gn b/modules/audio_processing/test/py_quality_assessment/BUILD.gn
index dfae858276..59623e3183 100644
--- a/modules/audio_processing/test/py_quality_assessment/BUILD.gn
+++ b/modules/audio_processing/test/py_quality_assessment/BUILD.gn
@@ -102,6 +102,7 @@ group("unit_tests") {
     ":fake_polqa",
     ":lib_unit_tests",
     ":scripts_unit_tests",
+    ":vad",
   ]
 }
 
@@ -118,6 +119,17 @@ rtc_executable("fake_polqa") {
   ]
 }
 
+rtc_executable("vad") {
+  sources = [
+    "quality_assessment/vad.cc",
+  ]
+  deps = [
+    "../../../..:webrtc_common",
+    "../../../../common_audio",
+    "../../../../rtc_base:rtc_base_approved",
+  ]
+}
+
 copy("lib_unit_tests") {
   testonly = true
   sources = [
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py
index 55b3388d40..399beb7c0e 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py
@@ -10,9 +10,14 @@
 """
 
 from __future__ import division
+import enum
 import logging
 import os
+import shutil
+import struct
+import subprocess
 import sys
+import tempfile
 
 try:
   import numpy as np
@@ -20,6 +25,7 @@ except ImportError:
   logging.critical('Cannot import the third-party Python package numpy')
   sys.exit(1)
 
+from . import exceptions
 from . import signal_processing
 
 
@@ -27,9 +33,12 @@ class AudioAnnotationsExtractor(object):
   """Extracts annotations from audio files.
   """
 
-  _LEVEL_FILENAME = 'level.npy'
-  _VAD_FILENAME = 'vad.npy'
-  _SPEECH_LEVEL_FILENAME = 'speech_level.npy'
+  @enum.unique
+  class VadType(enum.Enum):
+    ENERGY_THRESHOLD = 0  # TODO(alessiob): Consider switching to P56 standard.
+    WEBRTC = 1
+
+  _OUTPUT_FILENAME = 'annotations.npz'
 
   # Level estimation params.
   _ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
@@ -41,36 +50,50 @@ class AudioAnnotationsExtractor(object):
 
   # VAD params.
   _VAD_THRESHOLD = 1
+  _VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
+      os.path.abspath(__file__)), os.pardir, os.pardir)
+  _VAD_WEBRTC_BIN_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
 
-  def __init__(self):
+  def __init__(self, vad_type):
     self._signal = None
     self._level = None
-    self._vad = None
-    self._speech_level = None
     self._level_frame_size = None
+    self._vad_output = None
+    self._vad_frame_size = None
+    self._vad_frame_size_ms = None
     self._c_attack = None
     self._c_decay = None
 
-  @classmethod
-  def GetLevelFileName(cls):
-    return cls._LEVEL_FILENAME
+    self._vad_type = vad_type
+    if self._vad_type not in self.VadType:
+      raise exceptions.InitializationException(
+          'Invalid vad type: ' + str(self._vad_type))
+    logging.info('VAD used for annotations: ' + str(self._vad_type))
+
+    assert os.path.exists(self._VAD_WEBRTC_BIN_PATH), self._VAD_WEBRTC_BIN_PATH
 
   @classmethod
-  def GetVadFileName(cls):
-    return cls._VAD_FILENAME
-
-  @classmethod
-  def GetSpeechLevelFileName(cls):
-    return cls._SPEECH_LEVEL_FILENAME
+  def GetOutputFileName(cls):
+    return cls._OUTPUT_FILENAME
 
   def GetLevel(self):
     return self._level
 
-  def GetVad(self):
-    return self._vad
+  def GetLevelFrameSize(self):
+    return self._level_frame_size
 
-  def GetSpeechLevel(self):
-    return self._speech_level
+  @classmethod
+  def GetLevelFrameSizeMs(cls):
+    return cls._LEVEL_FRAME_SIZE_MS
+
+  def GetVadOutput(self):
+    return self._vad_output
+
+  def GetVadFrameSize(self):
+    return self._vad_frame_size
+
+  def GetVadFrameSizeMs(self):
+    return self._vad_frame_size_ms
 
   def Extract(self, filepath):
     # Load signal.
@@ -78,7 +101,7 @@ class AudioAnnotationsExtractor(object):
     if self._signal.channels != 1:
       raise NotImplementedError('multiple-channel annotations not implemented')
 
-    # level estimation params.
+    # Level estimation params.
     self._level_frame_size = int(self._signal.frame_rate / 1000 * (
         self._LEVEL_FRAME_SIZE_MS))
     self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
@@ -91,26 +114,26 @@ class AudioAnnotationsExtractor(object):
     # Compute level.
     self._LevelEstimation()
 
-    # Naive VAD based on level thresholding. It assumes ideal clean speech
-    # with high SNR.
-    # TODO(alessiob): Maybe replace with a VAD based on stationary-noise
-    # detection.
-    vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
-    self._vad = np.uint8(self._level > vad_threshold)
-
-    # Speech level based on VAD output.
-    self._speech_level = self._level * self._vad
-
-    # Expand to one value per sample.
-    self._level = np.repeat(self._level, self._level_frame_size)
-    self._vad = np.repeat(self._vad, self._level_frame_size)
-    self._speech_level = np.repeat(self._speech_level, self._level_frame_size)
+    # Ideal VAD output; it requires clean speech with high SNR as input.
+    if self._vad_type == self.VadType.ENERGY_THRESHOLD:
+      # Naive VAD based on level thresholding.
+      vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
+      self._vad_output = np.uint8(self._level > vad_threshold)
+      self._vad_frame_size = self._level_frame_size
+      self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
+    elif self._vad_type == self.VadType.WEBRTC:
+      # WebRTC VAD.
+      self._RunWebRtcVad(filepath, self._signal.frame_rate)
 
   def Save(self, output_path):
-    np.save(os.path.join(output_path, self._LEVEL_FILENAME), self._level)
-    np.save(os.path.join(output_path, self._VAD_FILENAME), self._vad)
-    np.save(os.path.join(output_path, self._SPEECH_LEVEL_FILENAME),
-            self._speech_level)
+    np.savez_compressed(
+        file=os.path.join(output_path, self._OUTPUT_FILENAME),
+        level=self._level,
+        level_frame_size=self._level_frame_size,
+        level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
+        vad_output=self._vad_output,
+        vad_frame_size=self._vad_frame_size,
+        vad_frame_size_ms=self._vad_frame_size_ms)
 
   def _LevelEstimation(self):
     # Read samples.
@@ -132,4 +155,47 @@ class AudioAnnotationsExtractor(object):
           self._level[i], self._level[i - 1], self._c_attack if (
              self._level[i] > self._level[i - 1]) else self._c_decay)
 
-    return self._level
+  def _RunWebRtcVad(self, wav_file_path, sample_rate):
+    self._vad_output = None
+    self._vad_frame_size = None
+
+    # Create temporary output path.
+    tmp_path = tempfile.mkdtemp()
+    output_file_path = os.path.join(
+        tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp')
+
+    # Call WebRTC VAD.
+    try:
+      subprocess.call([
+          self._VAD_WEBRTC_BIN_PATH,
+          '-i', wav_file_path,
+          '-o', output_file_path
+      ], cwd=self._VAD_WEBRTC_PATH)
+
+      # Read bytes.
+      with open(output_file_path, 'rb') as f:
+        raw_data = f.read()
+
+      # Parse side information.
+      self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
+      self._vad_frame_size = self._vad_frame_size_ms * sample_rate // 1000
+      assert self._vad_frame_size_ms in [10, 20, 30]
+      extra_bits = struct.unpack('B', raw_data[-1])[0]
+      assert 0 <= extra_bits <= 8
+
+      # Init VAD vector.
+      num_bytes = len(raw_data)
+      num_frames = 8 * (num_bytes - 2) - extra_bits  # 8 frames per byte.
+      self._vad_output = np.zeros(num_frames, np.uint8)
+
+      # Read VAD decisions (one bit per frame, LSB first).
+      for i, byte in enumerate(raw_data[1:-1]):
+        byte = struct.unpack('B', byte)[0]
+        for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
+          self._vad_output[i * 8 + j] = int(byte & 1)
+          byte = byte >> 1
+    except Exception as e:
+      logging.error('Error while running the WebRTC VAD (' + e.message + ')')
+    finally:
+      if os.path.exists(tmp_path):
+        shutil.rmtree(tmp_path)
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py
index bac3d2174e..8cb0d048b3 100644
--- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py
@@ -9,6 +9,7 @@
 """Unit tests for the annotations module.
""" +from __future__ import division import logging import os import shutil @@ -27,6 +28,7 @@ class TestAnnotationsExtraction(unittest.TestCase): """ _CLEAN_TMP_OUTPUT = True + _DEBUG_PLOT_VAD = False def setUp(self): """Create temporary folder.""" @@ -36,6 +38,7 @@ class TestAnnotationsExtraction(unittest.TestCase): 'pure_tone', [440, 1000]) signal_processing.SignalProcessingUtils.SaveWav( self._wav_file_path, pure_tone) + self._sample_rate = pure_tone.frame_rate def tearDown(self): """Recursively delete temporary folder.""" @@ -45,27 +48,49 @@ class TestAnnotationsExtraction(unittest.TestCase): logging.warning(self.id() + ' did not clean the temporary path ' + ( self._tmp_path)) - def testExtraction(self): - e = annotations.AudioAnnotationsExtractor() - e.Extract(self._wav_file_path) - vad = e.GetVad() - assert len(vad) > 0 - self.assertGreaterEqual(float(np.sum(vad)) / len(vad), 0.95) + def testFrameSizes(self): + for vad_type in annotations.AudioAnnotationsExtractor.VadType: + e = annotations.AudioAnnotationsExtractor(vad_type=vad_type) + e.Extract(self._wav_file_path) + samples_to_ms = lambda n, sr: 1000 * n // sr + self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate), + e.GetLevelFrameSizeMs()) + self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate), + e.GetVadFrameSizeMs()) + + def testVoiceActivityDetectors(self): + for vad_type in annotations.AudioAnnotationsExtractor.VadType: + e = annotations.AudioAnnotationsExtractor(vad_type=vad_type) + e.Extract(self._wav_file_path) + vad_output = e.GetVadOutput() + self.assertGreater(len(vad_output), 0) + self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 0.95) + + if self._DEBUG_PLOT_VAD: + frame_times_s = lambda num_frames, frame_size_ms: np.arange( + num_frames).astype(np.float32) * frame_size_ms / 1000.0 + level = e.GetLevel() + t_level = frame_times_s( + num_frames=len(level), + frame_size_ms=e.GetLevelFrameSizeMs()) + t_vad = frame_times_s( + num_frames=len(vad_output), + frame_size_ms=e.GetVadFrameSizeMs()) + import matplotlib.pyplot as plt + plt.figure() + plt.hold(True) + plt.plot(t_level, level) + plt.plot(t_vad, vad_output * np.max(level), '.') + plt.show() def testSaveLoad(self): - e = annotations.AudioAnnotationsExtractor() + e = annotations.AudioAnnotationsExtractor( + vad_type=annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD) e.Extract(self._wav_file_path) e.Save(self._tmp_path) - level = np.load(os.path.join(self._tmp_path, e.GetLevelFileName())) - np.testing.assert_array_equal(e.GetLevel(), level) - self.assertEqual(np.float32, level.dtype) - - vad = np.load(os.path.join(self._tmp_path, e.GetVadFileName())) - np.testing.assert_array_equal(e.GetVad(), vad) - self.assertEqual(np.uint8, vad.dtype) - - speech_level = np.load(os.path.join( - self._tmp_path, e.GetSpeechLevelFileName())) - np.testing.assert_array_equal(e.GetSpeechLevel(), speech_level) - self.assertEqual(np.float32, speech_level.dtype) + data = np.load(os.path.join(self._tmp_path, e.GetOutputFileName())) + np.testing.assert_array_equal(e.GetLevel(), data['level']) + self.assertEqual(np.float32, data['level'].dtype) + np.testing.assert_array_equal(e.GetVadOutput(), data['vad_output']) + self.assertEqual(np.uint8, data['vad_output'].dtype) diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py index d62069fc8e..305487a030 100644 --- 
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py
@@ -46,7 +46,8 @@ class ApmModuleSimulator(object):
     self._evaluation_score_factory = evaluation_score_factory
     self._audioproc_wrapper = ap_wrapper
     self._evaluator = evaluator
-    self._annotator = annotations.AudioAnnotationsExtractor()
+    self._annotator = annotations.AudioAnnotationsExtractor(
+        vad_type=annotations.AudioAnnotationsExtractor.VadType.WEBRTC)
 
     # Init.
     self._test_data_generator_factory.SetOutputDirectoryPrefix(
diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/vad.cc b/modules/audio_processing/test/py_quality_assessment/quality_assessment/vad.cc
new file mode 100644
index 0000000000..3a2c2849cf
--- /dev/null
+++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/vad.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
+//
+// Use of this source code is governed by a BSD-style license
+// that can be found in the LICENSE file in the root of the source
+// tree. An additional intellectual property rights grant can be found
+// in the file PATENTS. All contributing project authors may
+// be found in the AUTHORS file in the root of the source tree.
+
+#include <array>
+#include <fstream>
+#include <memory>
+
+#include "common_audio/vad/include/vad.h"
+#include "common_audio/wav_file.h"
+#include "rtc_base/flags.h"
+#include "rtc_base/logging.h"
+
+namespace webrtc {
+namespace test {
+namespace {
+
+// The allowed values are 10, 20 or 30 ms.
+constexpr uint8_t kAudioFrameLengthMilliseconds = 30;
+constexpr int kMaxSampleRate = 48000;
+constexpr size_t kMaxFrameLen =
+    kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;
+
+constexpr uint8_t kBitmaskBuffSize = 8;  // Decisions per bit-packed byte.
+
+DEFINE_string(i, "", "Input wav file");
+DEFINE_string(o, "", "VAD output file");
+
+int main(int argc, char* argv[]) {
+  if (rtc::FlagList::SetFlagsFromCommandLine(&argc, argv, true))
+    return 1;
+
+  // Open wav input file and check properties.
+  WavReader wav_reader(FLAG_i);
+  if (wav_reader.num_channels() != 1) {
+    LOG(LS_ERROR) << "Only mono wav files supported";
+    return 1;
+  }
+  if (wav_reader.sample_rate() > kMaxSampleRate) {
+    LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate << ")";
+    return 1;
+  }
+  const size_t kAudioFrameLen = rtc::CheckedDivExact(
+      kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
+  if (kAudioFrameLen > kMaxFrameLen) {
+    LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
+    return 2;
+  }
+
+  // Create output file and write header.
+  std::ofstream out_file(FLAG_o, std::ofstream::binary);
+  const char audio_frame_length_ms = kAudioFrameLengthMilliseconds;
+  out_file.write(&audio_frame_length_ms, 1);  // Header.
+
+  // Run VAD and write decisions.
+  std::unique_ptr<Vad> vad = CreateVad(Vad::Aggressiveness::kVadNormal);
+  std::array<int16_t, kMaxFrameLen> samples;
+  char buff = 0;     // Buffer to write one bit per frame.
+  uint8_t next = 0;  // Points to the next bit to write in |buff|.
+  while (true) {
+    // Process frame.
+    const auto read_samples =
+        wav_reader.ReadSamples(kAudioFrameLen, samples.data());
+    if (read_samples < kAudioFrameLen)
+      break;
+    const auto is_speech = vad->VoiceActivity(samples.data(), kAudioFrameLen,
+                                              wav_reader.sample_rate());
+
+    // Write output.
+    buff = is_speech ? buff | (1 << next) : buff & ~(1 << next);
+    if (++next == kBitmaskBuffSize) {
+      out_file.write(&buff, 1);  // Flush.
+      buff = 0;  // Reset.
+      next = 0;
+    }
+  }
+
+  // Finalize: flush the last partial byte and append the unused-bit count.
+  char extra_bits = 0;
+  if (next > 0) {
+    extra_bits = kBitmaskBuffSize - next;
+    out_file.write(&buff, 1);  // Flush.
+  }
+  out_file.write(&extra_bits, 1);
+  out_file.close();
+
+  return 0;
+}
+
+}  // namespace
+}  // namespace test
+}  // namespace webrtc
+
+int main(int argc, char* argv[]) {
+  return webrtc::test::main(argc, argv);
+}
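
Note on the output format written by vad.cc: the file is a header byte holding the frame length in ms, followed by the VAD decisions bit-packed LSB-first (one bit per frame, 8 frames per payload byte), and a trailer byte holding the number of unused bits in the last payload byte. For reference, here is a minimal standalone decoding sketch; the helper name decode_vad_file is hypothetical and not part of this CL, and it simply mirrors the parsing done in annotations.py above:

import struct

import numpy as np


def decode_vad_file(path):
  """Decodes a bit-packed VAD file produced by the vad executable.

  Illustrative helper, not part of this CL. Returns a pair
  (frame_size_ms, vad_output) where vad_output is a np.uint8 array with
  one 0/1 decision per audio frame.
  """
  with open(path, 'rb') as f:
    raw_data = f.read()
  assert len(raw_data) >= 2, 'file too short'
  frame_size_ms = struct.unpack('B', raw_data[0:1])[0]  # Header byte.
  extra_bits = struct.unpack('B', raw_data[-1:])[0]  # Trailer byte.
  num_bytes = len(raw_data)
  num_frames = 8 * (num_bytes - 2) - extra_bits
  vad_output = np.zeros(num_frames, np.uint8)
  for i in range(num_bytes - 2):
    byte = struct.unpack('B', raw_data[i + 1:i + 2])[0]
    # The last payload byte only carries 8 - extra_bits valid decisions.
    num_valid_bits = 8 if i < num_bytes - 3 else 8 - extra_bits
    for j in range(num_valid_bits):
      vad_output[i * 8 + j] = byte & 1  # LSB first.
      byte >>= 1
  return frame_size_ms, vad_output

Worked example: 19 frames produce 2 full payload bytes plus a third byte with 5 unused bits, so the file is 5 bytes long, its trailer byte is 5, and 8 * (5 - 2) - 5 = 19 decisions are recovered.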