diff --git a/modules/audio_processing/test/py_quality_assessment/BUILD.gn b/modules/audio_processing/test/py_quality_assessment/BUILD.gn index eee58da2bb..64e3a30bd3 100644 --- a/modules/audio_processing/test/py_quality_assessment/BUILD.gn +++ b/modules/audio_processing/test/py_quality_assessment/BUILD.gn @@ -66,6 +66,7 @@ copy("lib") { "quality_assessment/exceptions.py", "quality_assessment/export.py", "quality_assessment/export_unittest.py", + "quality_assessment/external_vad.py", "quality_assessment/input_mixer.py", "quality_assessment/input_signal_creator.py", "quality_assessment/results.css", @@ -149,6 +150,7 @@ copy("lib_unit_tests") { "quality_assessment/annotations_unittest.py", "quality_assessment/echo_path_simulation_unittest.py", "quality_assessment/eval_scores_unittest.py", + "quality_assessment/fake_external_vad.py", "quality_assessment/input_mixer_unittest.py", "quality_assessment/signal_processing_unittest.py", "quality_assessment/simulation_unittest.py", diff --git a/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py b/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py index 78ff5e93e6..a4cc5f037f 100755 --- a/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py +++ b/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py @@ -27,6 +27,7 @@ import quality_assessment.echo_path_simulation as echo_path_simulation import quality_assessment.eval_scores as eval_scores import quality_assessment.evaluation as evaluation import quality_assessment.eval_scores_factory as eval_scores_factory +import quality_assessment.external_vad as external_vad import quality_assessment.test_data_generation as test_data_generation import quality_assessment.test_data_generation_factory as \ test_data_generation_factory @@ -113,6 +114,14 @@ def _InstanceArgumentsParser(): 'copy of the clean speech input file.'), default=False) + parser.add_argument('--external_vad_paths', 
nargs='+', required=False, + help=('Paths to external VAD programs. Each must take' + '\'-i -o \' inputs'), default=[]) + + parser.add_argument('--external_vad_names', nargs='+', required=False, + help=('Keys to the vad paths. Must be different and ' + 'as many as the paths.'), default=[]) + return parser @@ -128,6 +137,12 @@ def _ValidateArguments(args, parser): 'also required') sys.exit(1) + if len(args.external_vad_names) != len(args.external_vad_paths): + parser.error('If provided, --external_vad_paths and ' + '--external_vad_names must ' + 'have the same number of arguments.') + sys.exit(1) + def main(): # TODO(alessiob): level = logging.INFO once debugged. @@ -145,7 +160,9 @@ def main(): evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory( polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME)), ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path), - evaluator=evaluation.ApmModuleEvaluator()) + evaluator=evaluation.ApmModuleEvaluator(), + external_vads=external_vad.ExternalVad.ConstructVadDict( + args.external_vad_paths, args.external_vad_names)) simulator.Run( config_filepaths=args.config_files, capture_input_filepaths=args.capture_input_files, diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py index 2f5daf1f23..a4e9097320 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py @@ -24,6 +24,7 @@ except ImportError: logging.critical('Cannot import the third-party Python package numpy') sys.exit(1) +from . import external_vad from . import exceptions from . 
import signal_processing @@ -76,7 +77,7 @@ class AudioAnnotationsExtractor(object): _VAD_WEBRTC_APM_PATH = os.path.join( _VAD_WEBRTC_PATH, 'apm_vad') - def __init__(self, vad_type): + def __init__(self, vad_type, external_vads=None): self._signal = None self._level = None self._level_frame_size = None @@ -92,6 +93,19 @@ class AudioAnnotationsExtractor(object): self._vad_type = self.VadType(vad_type) logging.info('VADs used for annotations: ' + str(self._vad_type)) + if external_vads is None: + external_vads = {} + self._external_vads = external_vads + + assert len(self._external_vads) == len(external_vads), ( + 'The external VAD names must be unique.') + for vad in external_vads.values(): + if not isinstance(vad, external_vad.ExternalVad): + raise exceptions.InitializationException( + 'Invalid vad type: ' + str(type(vad))) + logging.info('External VAD used for annotation: ' + + str(vad.name)) + assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \ self._VAD_WEBRTC_COMMON_AUDIO_PATH assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \ @@ -113,9 +127,9 @@ class AudioAnnotationsExtractor(object): def GetVadOutput(self, vad_type): if vad_type == self.VadType.ENERGY_THRESHOLD: - return (self._energy_vad, ) + return self._energy_vad elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO: - return (self._common_audio_vad, ) + return self._common_audio_vad elif vad_type == self.VadType.WEBRTC_APM: return (self._apm_vad_probs, self._apm_vad_rms) else: @@ -132,7 +146,7 @@ class AudioAnnotationsExtractor(object): # Load signal. self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath) if self._signal.channels != 1: - raise NotImplementedError('multiple-channel annotations not implemented') + raise NotImplementedError('Multiple-channel annotations not implemented') # Level estimation params. 
self._level_frame_size = int(self._signal.frame_rate / 1000 * ( @@ -160,8 +174,14 @@ class AudioAnnotationsExtractor(object): if self._vad_type.Contains(self.VadType.WEBRTC_APM): # WebRTC modules/audio_processing/ VAD. self._RunWebRtcApmVad(filepath) + for extvad_name in self._external_vads: + self._external_vads[extvad_name].Run(filepath) def Save(self, output_path): + ext_kwargs = {'extvad_conf-' + ext_vad: + self._external_vads[ext_vad].GetVadOutput() + for ext_vad in self._external_vads} + # pylint: disable=star-args np.savez_compressed( file=os.path.join(output_path, self._OUTPUT_FILENAME), level=self._level, @@ -172,7 +192,8 @@ class AudioAnnotationsExtractor(object): vad_frame_size=self._vad_frame_size, vad_frame_size_ms=self._vad_frame_size_ms, vad_probs=self._apm_vad_probs, - vad_rms=self._apm_vad_rms + vad_rms=self._apm_vad_rms, + **ext_kwargs ) def _LevelEstimation(self): diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py index 3f44edfb84..5fe5f5de9e 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py @@ -19,6 +19,7 @@ import unittest import numpy as np from . import annotations +from . import external_vad from . import input_signal_creator from . 
import signal_processing @@ -29,6 +30,11 @@ class TestAnnotationsExtraction(unittest.TestCase): _CLEAN_TMP_OUTPUT = True _DEBUG_PLOT_VAD = False + _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType + _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD | + _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO | + _VAD_TYPE_CLASS.WEBRTC_APM) + def setUp(self): """Create temporary folder.""" @@ -49,11 +55,7 @@ class TestAnnotationsExtraction(unittest.TestCase): self._tmp_path)) def testFrameSizes(self): - vad_type_class = annotations.AudioAnnotationsExtractor.VadType - vad_type = (vad_type_class.ENERGY_THRESHOLD | - vad_type_class.WEBRTC_COMMON_AUDIO | - vad_type_class.WEBRTC_APM) - e = annotations.AudioAnnotationsExtractor(vad_type=vad_type) + e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) e.Extract(self._wav_file_path) samples_to_ms = lambda n, sr: 1000 * n // sr self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate), @@ -62,35 +64,31 @@ class TestAnnotationsExtraction(unittest.TestCase): e.GetVadFrameSizeMs()) def testVoiceActivityDetectors(self): - vad_type_class = annotations.AudioAnnotationsExtractor.VadType - max_vad_type = (vad_type_class.ENERGY_THRESHOLD | - vad_type_class.WEBRTC_COMMON_AUDIO | - vad_type_class.WEBRTC_APM) - for vad_type_value in range(0, max_vad_type+1): - vad_type = vad_type_class(vad_type_value) + for vad_type_value in range(0, self._ALL_VAD_TYPES+1): + vad_type = self._VAD_TYPE_CLASS(vad_type_value) e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value) e.Extract(self._wav_file_path) - if vad_type.Contains(vad_type_class.ENERGY_THRESHOLD): - # pylint: disable=unbalanced-tuple-unpacking - (vad_output, ) = e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD) + if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD): + # pylint: disable=unpacking-non-sequence + vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD) self.assertGreater(len(vad_output), 0) 
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 0.95) - if vad_type.Contains(vad_type_class.WEBRTC_COMMON_AUDIO): - # pylint: disable=unbalanced-tuple-unpacking - (vad_output,) = e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO) + if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO): + # pylint: disable=unpacking-non-sequence + vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO) self.assertGreater(len(vad_output), 0) self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 0.95) - if vad_type.Contains(vad_type_class.WEBRTC_APM): - # pylint: disable=unbalanced-tuple-unpacking - (vad_probs, vad_rms) = e.GetVadOutput(vad_type_class.WEBRTC_APM) + if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM): + # pylint: disable=unpacking-non-sequence + (vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM) self.assertGreater(len(vad_probs), 0) self.assertGreater(len(vad_rms), 0) self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs), - 0.95) + 0.5) self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000) if self._DEBUG_PLOT_VAD: @@ -111,11 +109,7 @@ class TestAnnotationsExtraction(unittest.TestCase): plt.show() def testSaveLoad(self): - vad_type_class = annotations.AudioAnnotationsExtractor.VadType - vad_type = (vad_type_class.ENERGY_THRESHOLD | - vad_type_class.WEBRTC_COMMON_AUDIO | - vad_type_class.WEBRTC_APM) - e = annotations.AudioAnnotationsExtractor(vad_type) + e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) e.Extract(self._wav_file_path) e.Save(self._tmp_path) @@ -123,14 +117,37 @@ class TestAnnotationsExtraction(unittest.TestCase): np.testing.assert_array_equal(e.GetLevel(), data['level']) self.assertEqual(np.float32, data['level'].dtype) np.testing.assert_array_equal( - e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD), + e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD), data['vad_energy_output']) np.testing.assert_array_equal( - 
e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO), data['vad_output']) + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO), + data['vad_output']) np.testing.assert_array_equal( - e.GetVadOutput(vad_type_class.WEBRTC_APM)[0], data['vad_probs']) + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs']) np.testing.assert_array_equal( - e.GetVadOutput(vad_type_class.WEBRTC_APM)[1], data['vad_rms']) + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms']) self.assertEqual(np.uint8, data['vad_energy_output'].dtype) self.assertEqual(np.float64, data['vad_probs'].dtype) self.assertEqual(np.float64, data['vad_rms'].dtype) + + def testEmptyExternalShouldNotCrash(self): + for vad_type_value in range(0, self._ALL_VAD_TYPES+1): + annotations.AudioAnnotationsExtractor(vad_type_value, {}) + + def testFakeExternalSaveLoad(self): + def FakeExternalFactory(): + return external_vad.ExternalVad( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'), + 'fake' + ) + for vad_type_value in range(0, self._ALL_VAD_TYPES+1): + e = annotations.AudioAnnotationsExtractor( + vad_type_value, + {'fake': FakeExternalFactory()}) + e.Extract(self._wav_file_path) + e.Save(self._tmp_path) + data = np.load(os.path.join(self._tmp_path, e.GetOutputFileName())) + self.assertEqual(np.float32, data['extvad_conf-fake'].dtype) + np.testing.assert_almost_equal(np.arange(100, dtype=np.float32), + data['extvad_conf-fake']) diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py new file mode 100644 index 0000000000..01418d84fe --- /dev/null +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py @@ -0,0 +1,77 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 
#
# Use of this source code is governed by a BSD-style license
# that can be found in the LICENSE file in the root of the source
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS.  All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.

from __future__ import division

import logging
import os
import subprocess
import shutil
import sys
import tempfile

try:
  import numpy as np
except ImportError:
  logging.critical('Cannot import the third-party Python package numpy')
  sys.exit(1)


class ExternalVad(object):
  """Runs an external voice activity detector (VAD) program on a WAV file.

  The program is invoked as
    <path_to_binary> -i <wav file path> -o <output file path>
  and the file it writes is read back as a flat array of float32 values.
  """

  def __init__(self, path_to_binary, name):
    """Args:
       path_to_binary: path to a program that accepts '-i <wav file path>'
                       and '-o <output file path>'. The output file is read
                       as raw float32 values; there must be one float value
                       per 10ms audio.
       name: a name to identify the external VAD. Used for saving
             the output as extvad_conf-<name>.

    Raises:
      AssertionError: if path_to_binary does not exist.
    """
    self._path_to_binary = path_to_binary
    self.name = name
    assert os.path.exists(self._path_to_binary), (
        self._path_to_binary)
    self._vad_output = None

  def Run(self, wav_file_path):
    """Runs the external VAD on a mono, 48 kHz WAV file.

    A failure of the external program (or of reading its output) is logged
    and leaves no output available, so that GetVadOutput() fails afterwards
    instead of returning the result of an earlier run.

    Args:
      wav_file_path: path to the WAV file to analyze.

    Raises:
      NotImplementedError: if the file is not mono or not sampled at 48 kHz.
    """
    # Only Run() needs the WAV loader, hence the function-scope import.
    from . import signal_processing

    # Drop the output of any previous run: it belongs to another file.
    self._vad_output = None

    signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path)
    if signal.channels != 1:
      raise NotImplementedError('Multiple-channel'
                                ' annotations not implemented')
    if signal.frame_rate != 48000:
      raise NotImplementedError('Frame rates '
                                'other than 48000 not implemented')

    tmp_path = tempfile.mkdtemp()
    try:
      output_file_path = os.path.join(
          tmp_path, self.name + '_vad.tmp')
      # check_call() raises on a non-zero exit status, so the output of a
      # failed VAD run is never parsed as if it were valid.
      subprocess.check_call([
          self._path_to_binary,
          '-i', wav_file_path,
          '-o', output_file_path
      ])
      self._vad_output = np.fromfile(output_file_path, np.float32)
    except Exception as e:  # pylint: disable=broad-except
      # Deliberately best-effort: report the problem here; the missing output
      # is detected by GetVadOutput(). str(e) is used (through %s) because
      # exceptions have no 'message' attribute in Python 3.
      logging.error('Error while running the %s VAD (%s)', self.name, e)
    finally:
      shutil.rmtree(tmp_path, ignore_errors=True)

  def GetVadOutput(self):
    """Returns the float32 array produced by the last successful Run().

    Raises:
      AssertionError: if Run() was not called or its last call failed.
    """
    assert self._vad_output is not None, (
        'No output available for the ' + str(self.name) + ' VAD')
    return self._vad_output

  @classmethod
  def ConstructVadDict(cls, vad_paths, vad_names):
    """Builds a {name: ExternalVad} dictionary from parallel lists.

    Args:
      vad_paths: paths to the external VAD programs.
      vad_names: names of the external VADs, one per path, all different.

    Returns:
      A dict mapping each name to an instance of cls built on the
      corresponding path.

    Raises:
      ValueError: if the two lists differ in length or a name is repeated.
    """
    vad_paths = list(vad_paths)
    vad_names = list(vad_names)
    # zip() would silently drop the unmatched items.
    if len(vad_paths) != len(vad_names):
      raise ValueError(
          'The number of external VAD paths (%d) and names (%d) must match.' %
          (len(vad_paths), len(vad_names)))

    external_vads = {}
    for path, name in zip(vad_paths, vad_names):
      # A repeated name would silently replace the VAD registered before.
      if name in external_vads:
        raise ValueError(
            'The external VAD names must be unique (duplicate: %s).' % name)
      external_vads[name] = cls(path, name)
    return external_vads
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1]) - mix_filepath = os.path.join(output_path, 'mix_capture_{}.wav'.format( - echo_file_name)) + capture_input_file_name, _ = os.path.splitext( + os.path.split(capture_input_filepath)[1]) + mix_filepath = os.path.join(output_path, 'mix_capture_{}_{}.wav'.format( + capture_input_file_name, echo_file_name)) # Create the mix if not done yet. mix = None diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py index 8e672916c5..f791ddda6a 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py @@ -41,7 +41,9 @@ class ApmModuleSimulator(object): _PREFIX_SCORE = 'score-' def __init__(self, test_data_generator_factory, evaluation_score_factory, - ap_wrapper, evaluator): + ap_wrapper, evaluator, external_vads=None): + if external_vads is None: + external_vads = {} self._test_data_generator_factory = test_data_generator_factory self._evaluation_score_factory = evaluation_score_factory self._audioproc_wrapper = ap_wrapper @@ -49,7 +51,9 @@ class ApmModuleSimulator(object): self._annotator = annotations.AudioAnnotationsExtractor( annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD | annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO | - annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM) + annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM, + external_vads + ) # Init. 
self._test_data_generator_factory.SetOutputDirectoryPrefix( diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py index cf9aac8da9..c7ebcbc87a 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py @@ -26,6 +26,7 @@ import pydub from . import audioproc_wrapper from . import eval_scores_factory from . import evaluation +from . import external_vad from . import signal_processing from . import simulation from . import test_data_generation_factory @@ -75,7 +76,10 @@ class TestApmModuleSimulator(unittest.TestCase): test_data_generator_factory=test_data_generator_factory, evaluation_score_factory=evaluation_score_factory, ap_wrapper=ap_wrapper, - evaluator=evaluator) + evaluator=evaluator, + external_vads={'fake': external_vad.ExternalVad(os.path.join( + os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')} + ) # What to simulate. config_files = ['apm_configs/default.json']