diff --git a/modules/audio_processing/test/py_quality_assessment/README.md b/modules/audio_processing/test/py_quality_assessment/README.md index e19a780236..79e1650f08 100644 --- a/modules/audio_processing/test/py_quality_assessment/README.md +++ b/modules/audio_processing/test/py_quality_assessment/README.md @@ -33,6 +33,12 @@ reference one used for evaluation. - Go to `out/Default/py_quality_assessment` and check that `apm_quality_assessment.py` exists +## Unit tests + + - Compile WebRTC + - Go to `out/Default/py_quality_assessment` + - Run `python -m unittest -p "*_unittest.py" discover` + ## First time setup - Deploy PolqaOem64 and set the `POLQA_PATH` environment variable diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py index 826a0899ab..c488859b89 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/data_access.py @@ -31,8 +31,33 @@ class Metadata(object): def __init__(self): pass + _GENERIC_METADATA_SUFFIX = '.mdata' _AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json' + @classmethod + def LoadFileMetadata(cls, filepath): + """Loads generic metadata linked to a file. + + Args: + filepath: path to the metadata file to read. + + Returns: + A dict. + """ + with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f: + return json.load(f) + + @classmethod + def SaveFileMetadata(cls, filepath, metadata): + """Saves generic metadata linked to a file. + + Args: + filepath: path to the metadata file to write. + metadata: a dict. + """ + with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f: + json.dump(metadata, f) + @classmethod def LoadAudioTestDataPaths(cls, metadata_path): """Loads the input and the reference audio track paths. diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py index 78d0c18558..420afd2243 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores.py @@ -14,6 +14,13 @@ import logging import os import re import subprocess +import sys + +try: + import numpy as np +except ImportError: + logging.critical('Cannot import the third-party Python package numpy') + sys.exit(1) from . import data_access from . import exceptions @@ -27,6 +34,7 @@ class EvaluationScore(object): def __init__(self, score_filename_prefix): self._score_filename_prefix = score_filename_prefix + self._input_signal_metadata = None self._reference_signal = None self._reference_signal_filepath = None self._tested_signal = None @@ -56,8 +64,16 @@ class EvaluationScore(object): def score(self): return self._score + def SetInputSignalMetadata(self, metadata): + """Sets input signal metadata. + + Args: + metadata: dict instance. + """ + self._input_signal_metadata = metadata + def SetReferenceSignalFilepath(self, filepath): - """ Sets the path to the audio track used as reference signal. + """Sets the path to the audio track used as reference signal. Args: filepath: path to the reference audio track. @@ -65,7 +81,7 @@ class EvaluationScore(object): self._reference_signal_filepath = filepath def SetTestedSignalFilepath(self, filepath): - """ Sets the path to the audio track used as test signal. + """Sets the path to the audio track used as test signal. Args: filepath: path to the test audio track. @@ -242,3 +258,84 @@ class PolqaScore(EvaluationScore): # Build and return a dictionary with field names (header) as keys and the # corresponding field values as values. return {data[0][index]: data[1][index] for index in range(number_of_fields)} + + +@EvaluationScore.RegisterClass +class TotalHarmonicDistorsionScore(EvaluationScore): + """Total harmonic distorsion plus noise score. + + Total harmonic distorsion plus noise score. + See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN". + + Unit: -. + Ideal: 0. + Worst case: +inf + """ + + NAME = 'thd' + + def __init__(self, score_filename_prefix): + EvaluationScore.__init__(self, score_filename_prefix) + self._input_frequency = None + + def _Run(self, output_path): + # TODO(aleloi): Integrate changes made locally. + self._CheckInputSignal() + + self._LoadTestedSignal() + if self._tested_signal.channels != 1: + raise exceptions.EvaluationScoreException( + 'unsupported number of channels') + samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData( + self._tested_signal) + + # Init. + num_samples = len(samples) + duration = len(self._tested_signal) / 1000.0 + scaling = 2.0 / num_samples + max_freq = self._tested_signal.frame_rate / 2 + f0_freq = float(self._input_frequency) + t = np.linspace(0, duration, num_samples) + + # Analyze harmonics. + b_terms = [] + n = 1 + while f0_freq * n < max_freq: + x_n = np.sum(samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling + y_n = np.sum(samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling + b_terms.append(np.sqrt(x_n**2 + y_n**2)) + n += 1 + + output_without_fundamental = samples - b_terms[0] * np.sin( + 2.0 * np.pi * f0_freq * t) + distortion_and_noise = np.sqrt(np.sum( + output_without_fundamental**2) * np.pi * scaling) + + # TODO(alessiob): Fix or remove if not needed. + # thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0] + + # TODO(alessiob): Check the range of |thd_plus_noise| and update the class + # docstring above if accordingly. + thd_plus_noise = distortion_and_noise / b_terms[0] + + self._score = thd_plus_noise + self._SaveScore() + + def _CheckInputSignal(self): + # Check input signal and get properties. + try: + if self._input_signal_metadata['signal'] != 'pure_tone': + raise exceptions.EvaluationScoreException( + 'The THD score requires a pure tone as input signal') + self._input_frequency = self._input_signal_metadata['frequency'] + if self._input_signal_metadata['test_data_gen_name'] != 'identity' or ( + self._input_signal_metadata['test_data_gen_config'] != 'default'): + raise exceptions.EvaluationScoreException( + 'The THD score cannot be used with any test data generator other ' + 'than "identity"') + except TypeError: + raise exceptions.EvaluationScoreException( + 'The THD score requires an input signal with associated metadata') + except KeyError: + raise exceptions.EvaluationScoreException( + 'Invalid input signal metadata to compute the THD score') diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py index b3bd4f9a9e..ce51051a91 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/eval_scores_unittest.py @@ -52,6 +52,9 @@ class TestEvalScores(unittest.TestCase): shutil.rmtree(self._output_path) def testRegisteredClasses(self): + # Evaluation score names to exclude (tested separately). + exceptions = ['thd'] + # Preliminary check. self.assertTrue(os.path.exists(self._output_path)) @@ -69,11 +72,14 @@ class TestEvalScores(unittest.TestCase): # Try each registered evaluation score worker. for eval_score_name in registered_classes: + if eval_score_name in exceptions: + continue + # Instance evaluation score worker. eval_score_worker = eval_score_workers_factory.GetInstance( registered_classes[eval_score_name]) - # Set reference and test, then run. + # Set fake input metadata and reference and test file paths, then run. eval_score_worker.SetReferenceSignalFilepath( self._fake_reference_signal_filepath) eval_score_worker.SetTestedSignalFilepath( @@ -83,3 +89,43 @@ class TestEvalScores(unittest.TestCase): # Check output. score = data_access.ScoreFile.Load(eval_score_worker.output_filepath) self.assertTrue(isinstance(score, float)) + + def testTotalHarmonicDistorsionScore(self): + # Init. + pure_tone_freq = 5000.0 + eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-') + eval_score_worker.SetInputSignalMetadata({ + 'signal': 'pure_tone', + 'frequency': pure_tone_freq, + 'test_data_gen_name': 'identity', + 'test_data_gen_config': 'default', + }) + template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000) + + # Create 3 test signals: pure tone, pure tone + white noise, white noise + # only. + pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone( + template, pure_tone_freq) + white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise( + template) + noisy_tone = signal_processing.SignalProcessingUtils.MixSignals( + pure_tone, white_noise) + + # Compute scores for increasingly distorted pure tone signals. + scores = [None, None, None] + for index, tested_signal in enumerate([pure_tone, noisy_tone, white_noise]): + # Save signal. + tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav') + signal_processing.SignalProcessingUtils.SaveWav( + tmp_filepath, tested_signal) + + # Compute score. + eval_score_worker.SetTestedSignalFilepath(tmp_filepath) + eval_score_worker.Run(self._output_path) + scores[index] = eval_score_worker.score + + # Remove output file to avoid caching. + os.remove(eval_score_worker.output_filepath) + + # Validate scores (lowest score with a pure tone). + self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)])) diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py index e18f193bb0..09ded4cbd5 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/evaluation.py @@ -20,14 +20,15 @@ class ApmModuleEvaluator(object): pass @classmethod - def Run(cls, evaluation_score_workers, apm_output_filepath, - reference_input_filepath, output_path): + def Run(cls, evaluation_score_workers, apm_input_metadata, + apm_output_filepath, reference_input_filepath, output_path): """Runs the evaluation. Iterates over the given evaluation score workers. Args: evaluation_score_workers: list of EvaluationScore instances. + apm_input_metadata: dictionary with metadata of the APM input. apm_output_filepath: path to the audio track file with the APM output. reference_input_filepath: path to the reference audio track file. output_path: output path. @@ -40,6 +41,7 @@ class ApmModuleEvaluator(object): for evaluation_score_worker in evaluation_score_workers: logging.info(' computing <%s> score', evaluation_score_worker.NAME) + evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata) evaluation_score_worker.SetReferenceSignalFilepath( reference_input_filepath) evaluation_score_worker.SetTestedSignalFilepath( diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py index 0f7716a8c8..b13b35bc89 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/exceptions.py @@ -32,3 +32,9 @@ class InputSignalCreatorException(Exception): """Input signal creator exeception. """ pass + + +class EvaluationScoreException(Exception): + """Evaluation score exeception. + """ + pass diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py index e2a720c796..5d97c3b2fc 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/input_signal_creator.py @@ -18,26 +18,36 @@ class InputSignalCreator(object): """ @classmethod - def Create(cls, name, params): - """Creates a input signal. + def Create(cls, name, raw_params): + """Creates a input signal and its metadata. Args: name: Input signal creator name. - params: Tuple of parameters to pass to the specific signal creator. + raw_params: Tuple of parameters to pass to the specific signal creator. Returns: - AudioSegment instance. + (AudioSegment, dict) tuple. """ try: + signal = {} + params = {} + if name == 'pure_tone': - return cls._CreatePureTone(float(params[0]), int(params[1])) + params['frequency'] = float(raw_params[0]) + params['duration'] = int(raw_params[1]) + signal = cls._CreatePureTone(params['frequency'], params['duration']) + else: + raise exceptions.InputSignalCreatorException( + 'Invalid input signal creator name') + + # Complete metadata. + params['signal'] = name + + return signal, params except (TypeError, AssertionError) as e: raise exceptions.InputSignalCreatorException( 'Invalid signal creator parameters: {}'.format(e)) - raise exceptions.InputSignalCreatorException( - 'Invalid input signal creator name') - @classmethod def _CreatePureTone(cls, frequency, duration): """ diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py index 9a1f27978b..5beb3fb307 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/signal_processing.py @@ -148,6 +148,13 @@ class SignalProcessingUtils(object): duration=len(template), volume=0.0) + @classmethod + def AudioSegmentToRawData(cls, signal): + samples = signal.get_array_of_samples() + if samples.typecode != 'h': + raise exceptions.SignalProcessingException('Unsupported samples type') + return np.array(signal.get_array_of_samples(), np.int16) + @classmethod def DetectHardClipping(cls, signal, threshold=2): """Detects hard clipping. @@ -169,13 +176,7 @@ class SignalProcessingUtils(object): if signal.sample_width != 2: # Note that signal.sample_width is in bytes. raise exceptions.SignalProcessingException( 'hard-clipping detection only supported for 16 bit samples') - - # Get raw samples, check type, cast. - samples = signal.get_array_of_samples() - if samples.typecode != 'h': - raise exceptions.SignalProcessingException( - 'hard-clipping detection only supported for 16 bit samples') - samples = np.array(signal.get_array_of_samples(), np.int16) + samples = cls.AudioSegmentToRawData(signal) # Detect adjacent clipped samples. samples_type_info = np.iinfo(samples.dtype) diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py index 7023b6a8c5..b25694025b 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation.py @@ -17,6 +17,7 @@ from . import echo_path_simulation from . import echo_path_simulation_factory from . import eval_scores from . import eval_scores_factory +from . import exceptions from . import input_mixer from . import test_data_generation from . import test_data_generation_factory @@ -248,9 +249,20 @@ class ApmModuleSimulator(object): test_data_cache_path=test_data_cache_path, base_output_path=output_path) + # Extract metadata linked to the clean input file (if any). + apm_input_metadata = None + try: + apm_input_metadata = data_access.Metadata.LoadFileMetadata( + clean_capture_input_filepath) + except IOError as e: + apm_input_metadata = {} + apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME + apm_input_metadata['test_data_gen_config'] = None + # For each test data pair, simulate a call and evaluate. for config_name in test_data_generators.config_names: logging.info(' - test data generator config: <%s>', config_name) + apm_input_metadata['test_data_gen_config'] = config_name # Paths to the test data generator output. # Note that the reference signal does not depend on the render input @@ -278,23 +290,28 @@ class ApmModuleSimulator(object): render_input_filepath=render_input_filepath, output_path=evaluation_output_path) - # Evaluate. - self._evaluator.Run( - evaluation_score_workers=self._evaluation_score_workers, - apm_output_filepath=self._audioproc_wrapper.output_filepath, - reference_input_filepath=reference_signal_filepath, - output_path=evaluation_output_path) + try: + # Evaluate. + self._evaluator.Run( + evaluation_score_workers=self._evaluation_score_workers, + apm_input_metadata=apm_input_metadata, + apm_output_filepath=self._audioproc_wrapper.output_filepath, + reference_input_filepath=reference_signal_filepath, + output_path=evaluation_output_path) - # Save simulation metadata. - data_access.Metadata.SaveAudioTestDataPaths( - output_path=evaluation_output_path, - clean_capture_input_filepath=clean_capture_input_filepath, - echo_free_capture_filepath=noisy_capture_input_filepath, - echo_filepath=echo_path_filepath, - render_filepath=render_input_filepath, - capture_filepath=apm_input_filepath, - apm_output_filepath=self._audioproc_wrapper.output_filepath, - apm_reference_filepath=reference_signal_filepath) + # Save simulation metadata. + data_access.Metadata.SaveAudioTestDataPaths( + output_path=evaluation_output_path, + clean_capture_input_filepath=clean_capture_input_filepath, + echo_free_capture_filepath=noisy_capture_input_filepath, + echo_filepath=echo_path_filepath, + render_filepath=render_input_filepath, + capture_filepath=apm_input_filepath, + apm_output_filepath=self._audioproc_wrapper.output_filepath, + apm_reference_filepath=reference_signal_filepath) + except exceptions.EvaluationScoreException as e: + logging.warning('the evaluation failed: %s', e.message) + continue def _SetTestInputSignalFilePaths(self, capture_input_filepaths, render_input_filepaths): diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py index 544ad97ffc..33ee92190c 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/simulation_unittest.py @@ -9,6 +9,7 @@ """Unit tests for the simulation module. """ +import logging import os import shutil import sys @@ -33,8 +34,9 @@ class TestApmModuleSimulator(unittest.TestCase): """ def setUp(self): - """Create temporary folder and fake audio track.""" + """Create temporary folders and fake audio track.""" self._output_path = tempfile.mkdtemp() + self._tmp_path = tempfile.mkdtemp() silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000) fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise( @@ -46,6 +48,7 @@ class TestApmModuleSimulator(unittest.TestCase): def tearDown(self): """Recursively delete temporary folders.""" shutil.rmtree(self._output_path) + shutil.rmtree(self._tmp_path) def testSimulation(self): # Instance dependencies to inject and mock. @@ -87,3 +90,39 @@ class TestApmModuleSimulator(unittest.TestCase): min_number_of_simulations) self.assertGreaterEqual(len(evaluator.Run.call_args_list), min_number_of_simulations) + + def testPureToneGenerationWithTotalHarmonicDistorsion(self): + logging.warning = mock.MagicMock(name='warning') + + # Instance simulator. + simulator = simulation.ApmModuleSimulator( + aechen_ir_database_path='', + polqa_tool_bin_path=os.path.join( + os.path.dirname(__file__), 'fake_polqa'), + ap_wrapper=audioproc_wrapper.AudioProcWrapper(), + evaluator=evaluation.ApmModuleEvaluator()) + + # What to simulate. + config_files = ['apm_configs/default.json'] + input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')] + eval_scores = ['thd'] + + # Should work. + simulator.Run( + config_filepaths=config_files, + capture_input_filepaths=input_files, + test_data_generator_names=['identity'], + eval_score_names=eval_scores, + output_dir=self._output_path) + self.assertFalse(logging.warning.called) + + # Warning expected. + simulator.Run( + config_filepaths=config_files, + capture_input_filepaths=input_files, + test_data_generator_names=['white_noise'], # Not allowed with THD. + eval_score_names=eval_scores, + output_dir=self._output_path) + logging.warning.assert_called_with('the evaluation failed: %s', ( + 'The THD score cannot be used with any test data generator other than ' + '"identity"')) diff --git a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py index 3d54da5fc2..4153f738ab 100644 --- a/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py +++ b/modules/audio_processing/test/py_quality_assessment/quality_assessment/test_data_generation.py @@ -147,11 +147,12 @@ class TestDataGenerator(object): raise exceptions.InputSignalCreatorException( 'Cannot parse input signal file name') - signal = input_signal_creator.InputSignalCreator.Create( + signal, metadata = input_signal_creator.InputSignalCreator.Create( filename_parts[0], filename_parts[1].split('_')) signal_processing.SignalProcessingUtils.SaveWav( input_signal_filepath, signal) + data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata) def _Generate( self, input_signal_filepath, test_data_cache_path, base_output_path):