diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn b/webrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn index 72ec187b9c..5b75153ffe 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/BUILD.gn @@ -62,6 +62,7 @@ copy("lib") { "quality_assessment/exceptions.py", "quality_assessment/export.py", "quality_assessment/input_mixer.py", + "quality_assessment/input_signal_creator.py", "quality_assessment/results.css", "quality_assessment/results.js", "quality_assessment/signal_processing.py", diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_unittest.py index dd0eb07383..f5240f8696 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_unittest.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/apm_quality_assessment_unittest.py @@ -9,8 +9,16 @@ """Unit tests for the apm_quality_assessment module. """ +import os +import sys import unittest +SRC = os.path.abspath(os.path.join( + os.path.dirname((__file__)), os.pardir, os.pardir, os.pardir)) +sys.path.append(os.path.join(SRC, 'third_party', 'pymock')) + +import mock + import apm_quality_assessment class TestSimulationScript(unittest.TestCase): @@ -19,6 +27,7 @@ class TestSimulationScript(unittest.TestCase): def testMain(self): # Exit with error code if no arguments are passed. - with self.assertRaises(SystemExit) as cm: + with self.assertRaises(SystemExit) as cm, mock.patch.object( + sys, 'argv', ['apm_quality_assessment.py']): apm_quality_assessment.main() self.assertGreater(cm.exception.code, 0) diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py index 943f2143b8..0f7716a8c8 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py @@ -26,3 +26,9 @@ class InputMixerException(Exception): """Input mixer exeception. """ pass + + +class InputSignalCreatorException(Exception): + """Input signal creator exeception. + """ + pass diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py new file mode 100644 index 0000000000..e2a720c796 --- /dev/null +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py @@ -0,0 +1,57 @@ +# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved. +# +# Use of this source code is governed by a BSD-style license +# that can be found in the LICENSE file in the root of the source +# tree. An additional intellectual property rights grant can be found +# in the file PATENTS. All contributing project authors may +# be found in the AUTHORS file in the root of the source tree. + +"""Input signal creator module. +""" + +from . import exceptions +from . import signal_processing + + +class InputSignalCreator(object): + """Input signal creator class. + """ + + @classmethod + def Create(cls, name, params): + """Creates a input signal. + + Args: + name: Input signal creator name. + params: Tuple of parameters to pass to the specific signal creator. + + Returns: + AudioSegment instance. + """ + try: + if name == 'pure_tone': + return cls._CreatePureTone(float(params[0]), int(params[1])) + except (TypeError, AssertionError) as e: + raise exceptions.InputSignalCreatorException( + 'Invalid signal creator parameters: {}'.format(e)) + + raise exceptions.InputSignalCreatorException( + 'Invalid input signal creator name') + + @classmethod + def _CreatePureTone(cls, frequency, duration): + """ + Generates a pure tone at 48000 Hz. + + Args: + frequency: Float in (0-24000] (Hz). + duration: Integer (milliseconds). + + Returns: + AudioSegment instance. + """ + assert 0 < frequency <= 24000 + assert 0 < duration + template = signal_processing.SignalProcessingUtils.GenerateSilence(duration) + return signal_processing.SignalProcessingUtils.GeneratePureTone( + template, frequency) diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py index 1d4789d6e9..544ad97ffc 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py @@ -66,7 +66,7 @@ class TestApmModuleSimulator(unittest.TestCase): config_files = ['apm_configs/default.json'] input_files = [self._fake_audio_track_path] test_data_generators = ['identity', 'white_noise'] - eval_scores = ['audio_level', 'polqa'] + eval_scores = ['audio_level_mean', 'polqa'] # Run all simulations. simulator.Run( diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py index 2fa49da6e3..3d54da5fc2 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py @@ -33,6 +33,7 @@ except ImportError: from . import data_access from . import exceptions +from . import input_signal_creator from . import signal_processing @@ -109,6 +110,12 @@ class TestDataGenerator(object): base_output_path: base path where output is written. """ self.Clear() + + # If the input signal file does not exist, try to create using the + # available input signal creators. + if not os.path.exists(input_signal_filepath): + self._CreateInputSignal(input_signal_filepath) + self._Generate( input_signal_filepath, test_data_cache_path, base_output_path) @@ -119,6 +126,33 @@ class TestDataGenerator(object): self._apm_output_paths = {} self._reference_signal_filepaths = {} + @classmethod + def _CreateInputSignal(cls, input_signal_filepath): + """Creates a missing input signal file. + + The file name is parsed to extract input signal creator and params. If a + creator is matched and the parameters are valid, a new signal is generated + and written in |input_signal_filepath|. + + Args: + input_signal_filepath: Path to the input signal audio file to write. + + Raises: + InputSignalCreatorException + """ + filename = os.path.splitext(os.path.split(input_signal_filepath)[-1])[0] + filename_parts = filename.split('-') + + if len(filename_parts) < 2: + raise exceptions.InputSignalCreatorException( + 'Cannot parse input signal file name') + + signal = input_signal_creator.InputSignalCreator.Create( + filename_parts[0], filename_parts[1].split('_')) + + signal_processing.SignalProcessingUtils.SaveWav( + input_signal_filepath, signal) + def _Generate( self, input_signal_filepath, test_data_cache_path, base_output_path): """Abstract method to be implemented in each concrete class. diff --git a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py index 0bf7e1ab98..6239d516b0 100644 --- a/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py +++ b/webrtc/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation_unittest.py @@ -52,6 +52,25 @@ class TestTestDataGenerators(unittest.TestCase): shutil.rmtree(self._test_data_cache_path) shutil.rmtree(self._fake_air_db_path) + def testInputSignalCreation(self): + # Init. + generator = test_data_generation.IdentityTestDataGenerator('tmp') + input_signal_filepath = os.path.join( + self._test_data_cache_path, 'pure_tone-440_1000.wav') + + # Check that the input signal is generated. + self.assertFalse(os.path.exists(input_signal_filepath)) + generator.Generate( + input_signal_filepath=input_signal_filepath, + test_data_cache_path=self._test_data_cache_path, + base_output_path=self._base_output_path) + self.assertTrue(os.path.exists(input_signal_filepath)) + + # Check input signal properties. + input_signal = signal_processing.SignalProcessingUtils.LoadWav( + input_signal_filepath) + self.assertEqual(1000, len(input_signal)) + def testTestDataGenerators(self): # Preliminary check. self.assertTrue(os.path.exists(self._base_output_path))