From 10dd7ed81a2746886e233f9a66ef01d562a4459d Mon Sep 17 00:00:00 2001 From: Alex Loiko Date: Tue, 21 Nov 2017 13:21:28 +0100 Subject: [PATCH] Support for external VAD program in APM-QA There is now an 'ExternalVad' class in the AnnotationsExtractor. The Extractor takes an extra list of these in addition to the other VADs. The external VAD runs an external program to generate the annotations. Annotations are loaded and saved to a compressed Numpy format. Also made a small fix to name a mixed file in a way so that files will not be overwritten. Also did some minor changes to the unittests. TBR=alessiob@webrtc.org Bug: webrtc:7494 Change-Id: I7816b04466be16cd635ac6ceab18cd7aad5325a4 Reviewed-on: https://webrtc-review.googlesource.com/23623 Commit-Queue: Alex Loiko Reviewed-by: Alex Loiko Reviewed-by: Alessio Bazzica Cr-Commit-Position: refs/heads/master@{#20819} --- .../test/py_quality_assessment/BUILD.gn | 2 + .../apm_quality_assessment.py | 19 ++++- .../quality_assessment/annotations.py | 31 ++++++-- .../annotations_unittest.py | 77 +++++++++++-------- .../quality_assessment/external_vad.py | 77 +++++++++++++++++++ .../quality_assessment/fake_external_vad.py | 24 ++++++ .../quality_assessment/input_mixer.py | 6 +- .../quality_assessment/simulation.py | 8 +- .../quality_assessment/simulation_unittest.py | 6 +- 9 files changed, 209 insertions(+), 41 deletions(-) create mode 100644 modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py create mode 100755 modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py diff --git a/modules/audio_processing/test/py_quality_assessment/BUILD.gn b/modules/audio_processing/test/py_quality_assessment/BUILD.gn index eee58da2bb..64e3a30bd3 100644 --- a/modules/audio_processing/test/py_quality_assessment/BUILD.gn +++ b/modules/audio_processing/test/py_quality_assessment/BUILD.gn @@ -66,6 +66,7 @@ copy("lib") { "quality_assessment/exceptions.py", 
"quality_assessment/export.py", "quality_assessment/export_unittest.py", + "quality_assessment/external_vad.py", "quality_assessment/input_mixer.py", "quality_assessment/input_signal_creator.py", "quality_assessment/results.css", @@ -149,6 +150,7 @@ copy("lib_unit_tests") { "quality_assessment/annotations_unittest.py", "quality_assessment/echo_path_simulation_unittest.py", "quality_assessment/eval_scores_unittest.py", + "quality_assessment/fake_external_vad.py", "quality_assessment/input_mixer_unittest.py", "quality_assessment/signal_processing_unittest.py", "quality_assessment/simulation_unittest.py", diff --git a/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py b/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py index 78ff5e93e6..a4cc5f037f 100755 --- a/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py +++ b/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment.py @@ -27,6 +27,7 @@ import quality_assessment.echo_path_simulation as echo_path_simulation import quality_assessment.eval_scores as eval_scores import quality_assessment.evaluation as evaluation import quality_assessment.eval_scores_factory as eval_scores_factory +import quality_assessment.external_vad as external_vad import quality_assessment.test_data_generation as test_data_generation import quality_assessment.test_data_generation_factory as \ test_data_generation_factory @@ -113,6 +114,14 @@ def _InstanceArgumentsParser(): 'copy of the clean speech input file.'), default=False) + parser.add_argument('--external_vad_paths', nargs='+', required=False, + help=('Paths to external VAD programs. Each must take' + '\'-i -o \' inputs'), default=[]) + + parser.add_argument('--external_vad_names', nargs='+', required=False, + help=('Keys to the vad paths. 
Must be different and ' + 'as many as the paths.'), default=[]) + return parser @@ -128,6 +137,12 @@ def _ValidateArguments(args, parser): 'also required') sys.exit(1) + if len(args.external_vad_names) != len(args.external_vad_paths): + parser.error('If provided, --external_vad_paths and ' + '--external_vad_names must ' + 'have the same number of arguments.') + sys.exit(1) + def main(): # TODO(alessiob): level = logging.INFO once debugged. @@ -145,7 +160,9 @@ def main(): evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory( polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME)), ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path), - evaluator=evaluation.ApmModuleEvaluator()) + evaluator=evaluation.ApmModuleEvaluator(), + external_vads=external_vad.ExternalVad.ConstructVadDict( + args.external_vad_paths, args.external_vad_names)) simulator.Run( config_filepaths=args.config_files, capture_input_filepaths=args.capture_input_files, diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py index 2f5daf1f23..a4e9097320 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations.py @@ -24,6 +24,7 @@ except ImportError: logging.critical('Cannot import the third-party Python package numpy') sys.exit(1) +from . import external_vad from . import exceptions from . 
import signal_processing @@ -76,7 +77,7 @@ class AudioAnnotationsExtractor(object): _VAD_WEBRTC_APM_PATH = os.path.join( _VAD_WEBRTC_PATH, 'apm_vad') - def __init__(self, vad_type): + def __init__(self, vad_type, external_vads=None): self._signal = None self._level = None self._level_frame_size = None @@ -92,6 +93,19 @@ class AudioAnnotationsExtractor(object): self._vad_type = self.VadType(vad_type) logging.info('VADs used for annotations: ' + str(self._vad_type)) + if external_vads is None: + external_vads = {} + self._external_vads = external_vads + + assert len(self._external_vads) == len(external_vads), ( + 'The external VAD names must be unique.') + for vad in external_vads.values(): + if not isinstance(vad, external_vad.ExternalVad): + raise exceptions.InitializationException( + 'Invalid vad type: ' + str(type(vad))) + logging.info('External VAD used for annotation: ' + + str(vad.name)) + assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \ self._VAD_WEBRTC_COMMON_AUDIO_PATH assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \ @@ -113,9 +127,9 @@ class AudioAnnotationsExtractor(object): def GetVadOutput(self, vad_type): if vad_type == self.VadType.ENERGY_THRESHOLD: - return (self._energy_vad, ) + return self._energy_vad elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO: - return (self._common_audio_vad, ) + return self._common_audio_vad elif vad_type == self.VadType.WEBRTC_APM: return (self._apm_vad_probs, self._apm_vad_rms) else: @@ -132,7 +146,7 @@ class AudioAnnotationsExtractor(object): # Load signal. self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath) if self._signal.channels != 1: - raise NotImplementedError('multiple-channel annotations not implemented') + raise NotImplementedError('Multiple-channel annotations not implemented') # Level estimation params. 
self._level_frame_size = int(self._signal.frame_rate / 1000 * ( @@ -160,8 +174,14 @@ class AudioAnnotationsExtractor(object): if self._vad_type.Contains(self.VadType.WEBRTC_APM): # WebRTC modules/audio_processing/ VAD. self._RunWebRtcApmVad(filepath) + for extvad_name in self._external_vads: + self._external_vads[extvad_name].Run(filepath) def Save(self, output_path): + ext_kwargs = {'extvad_conf-' + ext_vad: + self._external_vads[ext_vad].GetVadOutput() + for ext_vad in self._external_vads} + # pylint: disable=star-args np.savez_compressed( file=os.path.join(output_path, self._OUTPUT_FILENAME), level=self._level, @@ -172,7 +192,8 @@ class AudioAnnotationsExtractor(object): vad_frame_size=self._vad_frame_size, vad_frame_size_ms=self._vad_frame_size_ms, vad_probs=self._apm_vad_probs, - vad_rms=self._apm_vad_rms + vad_rms=self._apm_vad_rms, + **ext_kwargs ) def _LevelEstimation(self): diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py index 3f44edfb84..5fe5f5de9e 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/annotations_unittest.py @@ -19,6 +19,7 @@ import unittest import numpy as np from . import annotations +from . import external_vad from . import input_signal_creator from . 
import signal_processing @@ -29,6 +30,11 @@ class TestAnnotationsExtraction(unittest.TestCase): _CLEAN_TMP_OUTPUT = True _DEBUG_PLOT_VAD = False + _VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType + _ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD | + _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO | + _VAD_TYPE_CLASS.WEBRTC_APM) + def setUp(self): """Create temporary folder.""" @@ -49,11 +55,7 @@ class TestAnnotationsExtraction(unittest.TestCase): self._tmp_path)) def testFrameSizes(self): - vad_type_class = annotations.AudioAnnotationsExtractor.VadType - vad_type = (vad_type_class.ENERGY_THRESHOLD | - vad_type_class.WEBRTC_COMMON_AUDIO | - vad_type_class.WEBRTC_APM) - e = annotations.AudioAnnotationsExtractor(vad_type=vad_type) + e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) e.Extract(self._wav_file_path) samples_to_ms = lambda n, sr: 1000 * n // sr self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate), @@ -62,35 +64,31 @@ class TestAnnotationsExtraction(unittest.TestCase): e.GetVadFrameSizeMs()) def testVoiceActivityDetectors(self): - vad_type_class = annotations.AudioAnnotationsExtractor.VadType - max_vad_type = (vad_type_class.ENERGY_THRESHOLD | - vad_type_class.WEBRTC_COMMON_AUDIO | - vad_type_class.WEBRTC_APM) - for vad_type_value in range(0, max_vad_type+1): - vad_type = vad_type_class(vad_type_value) + for vad_type_value in range(0, self._ALL_VAD_TYPES+1): + vad_type = self._VAD_TYPE_CLASS(vad_type_value) e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value) e.Extract(self._wav_file_path) - if vad_type.Contains(vad_type_class.ENERGY_THRESHOLD): - # pylint: disable=unbalanced-tuple-unpacking - (vad_output, ) = e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD) + if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD): + # pylint: disable=unpacking-non-sequence + vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD) self.assertGreater(len(vad_output), 0) 
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 0.95) - if vad_type.Contains(vad_type_class.WEBRTC_COMMON_AUDIO): - # pylint: disable=unbalanced-tuple-unpacking - (vad_output,) = e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO) + if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO): + # pylint: disable=unpacking-non-sequence + vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO) self.assertGreater(len(vad_output), 0) self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output), 0.95) - if vad_type.Contains(vad_type_class.WEBRTC_APM): - # pylint: disable=unbalanced-tuple-unpacking - (vad_probs, vad_rms) = e.GetVadOutput(vad_type_class.WEBRTC_APM) + if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM): + # pylint: disable=unpacking-non-sequence + (vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM) self.assertGreater(len(vad_probs), 0) self.assertGreater(len(vad_rms), 0) self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs), - 0.95) + 0.5) self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000) if self._DEBUG_PLOT_VAD: @@ -111,11 +109,7 @@ class TestAnnotationsExtraction(unittest.TestCase): plt.show() def testSaveLoad(self): - vad_type_class = annotations.AudioAnnotationsExtractor.VadType - vad_type = (vad_type_class.ENERGY_THRESHOLD | - vad_type_class.WEBRTC_COMMON_AUDIO | - vad_type_class.WEBRTC_APM) - e = annotations.AudioAnnotationsExtractor(vad_type) + e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES) e.Extract(self._wav_file_path) e.Save(self._tmp_path) @@ -123,14 +117,37 @@ class TestAnnotationsExtraction(unittest.TestCase): np.testing.assert_array_equal(e.GetLevel(), data['level']) self.assertEqual(np.float32, data['level'].dtype) np.testing.assert_array_equal( - e.GetVadOutput(vad_type_class.ENERGY_THRESHOLD), + e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD), data['vad_energy_output']) np.testing.assert_array_equal( - 
e.GetVadOutput(vad_type_class.WEBRTC_COMMON_AUDIO), data['vad_output']) + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO), + data['vad_output']) np.testing.assert_array_equal( - e.GetVadOutput(vad_type_class.WEBRTC_APM)[0], data['vad_probs']) + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs']) np.testing.assert_array_equal( - e.GetVadOutput(vad_type_class.WEBRTC_APM)[1], data['vad_rms']) + e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms']) self.assertEqual(np.uint8, data['vad_energy_output'].dtype) self.assertEqual(np.float64, data['vad_probs'].dtype) self.assertEqual(np.float64, data['vad_rms'].dtype) + + def testEmptyExternalShouldNotCrash(self): + for vad_type_value in range(0, self._ALL_VAD_TYPES+1): + annotations.AudioAnnotationsExtractor(vad_type_value, {}) + + def testFakeExternalSaveLoad(self): + def FakeExternalFactory(): + return external_vad.ExternalVad( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'), + 'fake' + ) + for vad_type_value in range(0, self._ALL_VAD_TYPES+1): + e = annotations.AudioAnnotationsExtractor( + vad_type_value, + {'fake': FakeExternalFactory()}) + e.Extract(self._wav_file_path) + e.Save(self._tmp_path) + data = np.load(os.path.join(self._tmp_path, e.GetOutputFileName())) + self.assertEqual(np.float32, data['extvad_conf-fake'].dtype) + np.testing.assert_almost_equal(np.arange(100, dtype=np.float32), + data['extvad_conf-fake']) diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py new file mode 100644 index 0000000000..01418d84fe --- /dev/null +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/external_vad.py @@ -0,0 +1,77 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. 
+# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +from __future__ import division + +import logging +import os +import subprocess +import shutil +import sys +import tempfile + +try: + import numpy as np +except ImportError: + logging.critical('Cannot import the third-party Python package numpy') + sys.exit(1) + +from . import signal_processing + +class ExternalVad(object): + + def __init__(self, path_to_binary, name): + """Args: + path_to_binary: path to binary that accepts '-i WAV_FILE', '-o + OUTPUT_FILE'. There must be one float32 value per + 10ms audio + name: a name to identify the external VAD. Used for saving + the output as extvad_conf-NAME. + """ + self._path_to_binary = path_to_binary + self.name = name + assert os.path.exists(self._path_to_binary), ( + self._path_to_binary) + self._vad_output = None + + def Run(self, wav_file_path): + _signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path) + if _signal.channels != 1: + raise NotImplementedError('Multiple-channel' + ' annotations not implemented') + if _signal.frame_rate != 48000: + raise NotImplementedError('Frame rates ' + 'other than 48000 not implemented') + + tmp_path = tempfile.mkdtemp() + try: + output_file_path = os.path.join( + tmp_path, self.name + '_vad.tmp') + subprocess.call([ + self._path_to_binary, + '-i', wav_file_path, + '-o', output_file_path + ]) + self._vad_output = np.fromfile(output_file_path, np.float32) + except Exception as e: + logging.error('Error while running the ' + self.name + + ' VAD (' + str(e) + ')') + finally: + if os.path.exists(tmp_path): + shutil.rmtree(tmp_path) + + def GetVadOutput(self): + assert self._vad_output is not None + return self._vad_output + + @classmethod + def 
ConstructVadDict(cls, vad_paths, vad_names): + external_vads = {} + for path, name in zip(vad_paths, vad_names): + external_vads[name] = cls(path, name) + return external_vads diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py new file mode 100755 index 0000000000..7c75e8f5c3 --- /dev/null +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/fake_external_vad.py @@ -0,0 +1,24 @@ +#!/usr/bin/python +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. +import argparse +import numpy as np + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('-i', required=True) + parser.add_argument('-o', required=True) + + args = parser.parse_args() + + array = np.arange(100, dtype=np.float32) + array.tofile(args.o) + + +if __name__ == '__main__': + main() diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py index 8f9e5422a7..b1afe14454 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_mixer.py @@ -65,8 +65,10 @@ class ApmInputMixer(object): # This ensures that if the internal parameters of the echo path simulator # change, no erroneous cache hit occurs. 
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1]) - mix_filepath = os.path.join(output_path, 'mix_capture_{}.wav'.format( - echo_file_name)) + capture_input_file_name, _ = os.path.splitext( + os.path.split(capture_input_filepath)[1]) + mix_filepath = os.path.join(output_path, 'mix_capture_{}_{}.wav'.format( + capture_input_file_name, echo_file_name)) # Create the mix if not done yet. mix = None diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py index 8e672916c5..f791ddda6a 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py @@ -41,7 +41,9 @@ class ApmModuleSimulator(object): _PREFIX_SCORE = 'score-' def __init__(self, test_data_generator_factory, evaluation_score_factory, - ap_wrapper, evaluator): + ap_wrapper, evaluator, external_vads=None): + if external_vads is None: + external_vads = {} self._test_data_generator_factory = test_data_generator_factory self._evaluation_score_factory = evaluation_score_factory self._audioproc_wrapper = ap_wrapper @@ -49,7 +51,9 @@ class ApmModuleSimulator(object): self._annotator = annotations.AudioAnnotationsExtractor( annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD | annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO | - annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM) + annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM, + external_vads + ) # Init. 
self._test_data_generator_factory.SetOutputDirectoryPrefix( diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py index cf9aac8da9..c7ebcbc87a 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py @@ -26,6 +26,7 @@ import pydub from . import audioproc_wrapper from . import eval_scores_factory from . import evaluation +from . import external_vad from . import signal_processing from . import simulation from . import test_data_generation_factory @@ -75,7 +76,10 @@ class TestApmModuleSimulator(unittest.TestCase): test_data_generator_factory=test_data_generator_factory, evaluation_score_factory=evaluation_score_factory, ap_wrapper=ap_wrapper, - evaluator=evaluator) + evaluator=evaluator, + external_vads={'fake': external_vad.ExternalVad(os.path.join( + os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')} + ) # What to simulate. config_files = ['apm_configs/default.json']