Remove py_quality_assessment and old TODOs in conversational_speech
Bug: webrtc:379542219 Change-Id: I7a6c087ce42f854d9b440da018248323b2435b55 Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/368500 Reviewed-by: Sam Zackrisson <saza@webrtc.org> Commit-Queue: Alessio Bazzica <alessiob@webrtc.org> Cr-Commit-Position: refs/heads/main@{#43418}
This commit is contained in:
parent
7f775bc94c
commit
331ca30635
@ -310,7 +310,6 @@ if (rtc_include_tests) {
|
||||
":audioproc_unittest_proto",
|
||||
"aec_dump:aec_dump_unittests",
|
||||
"test/conversational_speech",
|
||||
"test/py_quality_assessment",
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@ -70,6 +70,7 @@ rtc_library("unittest") {
|
||||
"../../../../api:array_view",
|
||||
"../../../../common_audio",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../../../rtc_base:safe_conversions",
|
||||
"../../../../test:fileutils",
|
||||
"../../../../test:test_support",
|
||||
"//testing/gtest",
|
||||
|
||||
@ -53,6 +53,7 @@
|
||||
#include "modules/audio_processing/test/conversational_speech/timing.h"
|
||||
#include "modules/audio_processing/test/conversational_speech/wavreader_factory.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "test/gmock.h"
|
||||
#include "test/gtest.h"
|
||||
#include "test/testsupport/file_utils.h"
|
||||
@ -101,17 +102,15 @@ std::unique_ptr<MockWavReaderFactory> CreateMockWavReaderFactory() {
|
||||
|
||||
void CreateSineWavFile(absl::string_view filepath,
|
||||
const MockWavReaderFactory::Params& params,
|
||||
float frequency = 440.0f) {
|
||||
// Create samples.
|
||||
constexpr double two_pi = 2.0 * M_PI;
|
||||
float frequency_hz = 440.0f) {
|
||||
const double phase_step = 2 * M_PI * frequency_hz / params.sample_rate;
|
||||
double phase = 0.0;
|
||||
std::vector<int16_t> samples(params.num_samples);
|
||||
for (std::size_t i = 0; i < params.num_samples; ++i) {
|
||||
// TODO(alessiob): the produced tone is not pure, improve.
|
||||
samples[i] = std::lround(
|
||||
32767.0f * std::sin(two_pi * i * frequency / params.sample_rate));
|
||||
for (size_t i = 0; i < params.num_samples; ++i) {
|
||||
samples[i] = rtc::saturated_cast<int16_t>(32767.0f * std::sin(phase));
|
||||
phase += phase_step;
|
||||
}
|
||||
|
||||
// Write samples.
|
||||
WavWriter wav_writer(filepath, params.sample_rate, params.num_channels);
|
||||
wav_writer.WriteSamples(samples.data(), params.num_samples);
|
||||
}
|
||||
|
||||
@ -27,7 +27,6 @@ class MockWavReader : public WavReaderInterface {
|
||||
MockWavReader(int sample_rate, size_t num_channels, size_t num_samples);
|
||||
~MockWavReader();
|
||||
|
||||
// TODO(alessiob): use ON_CALL to return random samples if needed.
|
||||
MOCK_METHOD(size_t, ReadFloatSamples, (rtc::ArrayView<float>), (override));
|
||||
MOCK_METHOD(size_t, ReadInt16Samples, (rtc::ArrayView<int16_t>), (override));
|
||||
|
||||
|
||||
@ -187,13 +187,6 @@ std::unique_ptr<std::map<std::string, SpeakerOutputFilePaths>> Simulate(
|
||||
const auto& audiotrack_readers = multiend_call.audiotrack_readers();
|
||||
auto audiotracks = PreloadAudioTracks(audiotrack_readers);
|
||||
|
||||
// TODO(alessiob): When speaker_names.size() == 2, near-end and far-end
|
||||
// across the 2 speakers are symmetric; hence, the code below could be
|
||||
// replaced by only creating the near-end or the far-end. However, this would
|
||||
// require to split the unit tests and document the behavior in README.md.
|
||||
// In practice, it should not be an issue since the files are not expected to
|
||||
// be signinificant.
|
||||
|
||||
// Write near-end and far-end output tracks.
|
||||
for (const auto& speaking_turn : multiend_call.speaking_turns()) {
|
||||
const std::string& active_speaker_name = speaking_turn.speaker_name;
|
||||
|
||||
@ -1,170 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
import("../../../../webrtc.gni")
|
||||
|
||||
if (!build_with_chromium) {
|
||||
group("py_quality_assessment") {
|
||||
testonly = true
|
||||
deps = [
|
||||
":scripts",
|
||||
":unit_tests",
|
||||
]
|
||||
}
|
||||
|
||||
copy("scripts") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"README.md",
|
||||
"apm_quality_assessment.py",
|
||||
"apm_quality_assessment.sh",
|
||||
"apm_quality_assessment_boxplot.py",
|
||||
"apm_quality_assessment_export.py",
|
||||
"apm_quality_assessment_gencfgs.py",
|
||||
"apm_quality_assessment_optimize.py",
|
||||
]
|
||||
outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ]
|
||||
deps = [
|
||||
":apm_configs",
|
||||
":lib",
|
||||
":output",
|
||||
"../../../../resources/audio_processing/test/py_quality_assessment:probing_signals",
|
||||
"../../../../rtc_tools:audioproc_f",
|
||||
]
|
||||
}
|
||||
|
||||
copy("apm_configs") {
|
||||
testonly = true
|
||||
sources = [ "apm_configs/default.json" ]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs = [
|
||||
"$root_build_dir/py_quality_assessment/apm_configs/{{source_file_part}}",
|
||||
]
|
||||
} # apm_configs
|
||||
|
||||
copy("lib") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"quality_assessment/__init__.py",
|
||||
"quality_assessment/annotations.py",
|
||||
"quality_assessment/audioproc_wrapper.py",
|
||||
"quality_assessment/collect_data.py",
|
||||
"quality_assessment/data_access.py",
|
||||
"quality_assessment/echo_path_simulation.py",
|
||||
"quality_assessment/echo_path_simulation_factory.py",
|
||||
"quality_assessment/eval_scores.py",
|
||||
"quality_assessment/eval_scores_factory.py",
|
||||
"quality_assessment/evaluation.py",
|
||||
"quality_assessment/exceptions.py",
|
||||
"quality_assessment/export.py",
|
||||
"quality_assessment/export_unittest.py",
|
||||
"quality_assessment/external_vad.py",
|
||||
"quality_assessment/input_mixer.py",
|
||||
"quality_assessment/input_signal_creator.py",
|
||||
"quality_assessment/results.css",
|
||||
"quality_assessment/results.js",
|
||||
"quality_assessment/signal_processing.py",
|
||||
"quality_assessment/simulation.py",
|
||||
"quality_assessment/test_data_generation.py",
|
||||
"quality_assessment/test_data_generation_factory.py",
|
||||
]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ]
|
||||
deps = [ "../../../../resources/audio_processing/test/py_quality_assessment:noise_tracks" ]
|
||||
}
|
||||
|
||||
copy("output") {
|
||||
testonly = true
|
||||
sources = [ "output/README.md" ]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs =
|
||||
[ "$root_build_dir/py_quality_assessment/output/{{source_file_part}}" ]
|
||||
}
|
||||
|
||||
group("unit_tests") {
|
||||
testonly = true
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
deps = [
|
||||
":apm_vad",
|
||||
":fake_polqa",
|
||||
":lib_unit_tests",
|
||||
":scripts_unit_tests",
|
||||
":vad",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_executable("fake_polqa") {
|
||||
testonly = true
|
||||
sources = [ "quality_assessment/fake_polqa.cc" ]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
output_dir = "${root_out_dir}/py_quality_assessment/quality_assessment"
|
||||
deps = [
|
||||
"../../../../rtc_base:checks",
|
||||
"//third_party/abseil-cpp/absl/strings",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_executable("vad") {
|
||||
testonly = true
|
||||
sources = [ "quality_assessment/vad.cc" ]
|
||||
deps = [
|
||||
"../../../../common_audio",
|
||||
"../../../../rtc_base:logging",
|
||||
"//third_party/abseil-cpp/absl/flags:flag",
|
||||
"//third_party/abseil-cpp/absl/flags:parse",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_executable("apm_vad") {
|
||||
testonly = true
|
||||
sources = [ "quality_assessment/apm_vad.cc" ]
|
||||
deps = [
|
||||
"../..",
|
||||
"../../../../common_audio",
|
||||
"../../../../rtc_base:logging",
|
||||
"../../vad",
|
||||
"//third_party/abseil-cpp/absl/flags:flag",
|
||||
"//third_party/abseil-cpp/absl/flags:parse",
|
||||
]
|
||||
}
|
||||
|
||||
rtc_executable("sound_level") {
|
||||
testonly = true
|
||||
sources = [ "quality_assessment/sound_level.cc" ]
|
||||
deps = [
|
||||
"../..",
|
||||
"../../../../common_audio",
|
||||
"../../../../rtc_base:logging",
|
||||
"//third_party/abseil-cpp/absl/flags:flag",
|
||||
"//third_party/abseil-cpp/absl/flags:parse",
|
||||
]
|
||||
}
|
||||
|
||||
copy("lib_unit_tests") {
|
||||
testonly = true
|
||||
sources = [
|
||||
"quality_assessment/annotations_unittest.py",
|
||||
"quality_assessment/echo_path_simulation_unittest.py",
|
||||
"quality_assessment/eval_scores_unittest.py",
|
||||
"quality_assessment/fake_external_vad.py",
|
||||
"quality_assessment/input_mixer_unittest.py",
|
||||
"quality_assessment/signal_processing_unittest.py",
|
||||
"quality_assessment/simulation_unittest.py",
|
||||
"quality_assessment/test_data_generation_unittest.py",
|
||||
]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs = [ "$root_build_dir/py_quality_assessment/quality_assessment/{{source_file_part}}" ]
|
||||
}
|
||||
|
||||
copy("scripts_unit_tests") {
|
||||
testonly = true
|
||||
sources = [ "apm_quality_assessment_unittest.py" ]
|
||||
visibility = [ ":*" ] # Only targets in this file can depend on this.
|
||||
outputs = [ "$root_build_dir/py_quality_assessment/{{source_file_part}}" ]
|
||||
}
|
||||
}
|
||||
@ -1,5 +0,0 @@
|
||||
aleloi@webrtc.org
|
||||
alessiob@webrtc.org
|
||||
henrik.lundin@webrtc.org
|
||||
ivoc@webrtc.org
|
||||
peah@webrtc.org
|
||||
@ -1,125 +0,0 @@
|
||||
# APM Quality Assessment tool
|
||||
|
||||
Python wrapper of APM simulators (e.g., `audioproc_f`) with which quality
|
||||
assessment can be automatized. The tool allows to simulate different noise
|
||||
conditions, input signals, APM configurations and it computes different scores.
|
||||
Once the scores are computed, the results can be easily exported to an HTML page
|
||||
which allows to listen to the APM input and output signals and also the
|
||||
reference one used for evaluation.
|
||||
|
||||
## Dependencies
|
||||
- OS: Linux
|
||||
- Python 2.7
|
||||
- Python libraries: enum34, numpy, scipy, pydub (0.17.0+), pandas (0.20.1+),
|
||||
pyquery (1.2+), jsmin (2.2+), csscompressor (0.9.4)
|
||||
- It is recommended that a dedicated Python environment is used
|
||||
- install `virtualenv`
|
||||
- `$ sudo apt-get install python-virtualenv`
|
||||
- setup a new Python environment (e.g., `my_env`)
|
||||
- `$ cd ~ && virtualenv my_env`
|
||||
- activate the new Python environment
|
||||
- `$ source ~/my_env/bin/activate`
|
||||
- add dependcies via `pip`
|
||||
- `(my_env)$ pip install enum34 numpy pydub scipy pandas pyquery jsmin \`
|
||||
`csscompressor`
|
||||
- PolqaOem64 (see http://www.polqa.info/)
|
||||
- Tested with POLQA Library v1.180 / P863 v2.400
|
||||
- Aachen Impulse Response (AIR) Database
|
||||
- Download https://www2.iks.rwth-aachen.de/air/air_database_release_1_4.zip
|
||||
- Input probing signals and noise tracks (you can make your own dataset - *1)
|
||||
|
||||
## Build
|
||||
- Compile WebRTC
|
||||
- Go to `out/Default/py_quality_assessment` and check that
|
||||
`apm_quality_assessment.py` exists
|
||||
|
||||
## Unit tests
|
||||
- Compile WebRTC
|
||||
- Go to `out/Default/py_quality_assessment`
|
||||
- Run `python -m unittest discover -p "*_unittest.py"`
|
||||
|
||||
## First time setup
|
||||
- Deploy PolqaOem64 and set the `POLQA_PATH` environment variable
|
||||
- e.g., `$ export POLQA_PATH=/var/opt/PolqaOem64`
|
||||
- Deploy the AIR Database and set the `AECHEN_IR_DATABASE_PATH` environment
|
||||
variable
|
||||
- e.g., `$ export AECHEN_IR_DATABASE_PATH=/var/opt/AIR_1_4`
|
||||
- Deploy probing signal tracks into
|
||||
- `out/Default/py_quality_assessment/probing_signals` (*1)
|
||||
- Deploy noise tracks into
|
||||
- `out/Default/py_quality_assessment/noise_tracks` (*1, *2)
|
||||
|
||||
(*1) You can use custom files as long as they are mono tracks sampled at 48kHz
|
||||
encoded in the 16 bit signed format (it is recommended that the tracks are
|
||||
converted and exported with Audacity).
|
||||
|
||||
## Usage (scores computation)
|
||||
- Go to `out/Default/py_quality_assessment`
|
||||
- Check the `apm_quality_assessment.sh` as an example script to parallelize the
|
||||
experiments
|
||||
- Adjust the script according to your preferences (e.g., output path)
|
||||
- Run `apm_quality_assessment.sh`
|
||||
- The script will end by opening the browser and showing ALL the computed
|
||||
scores
|
||||
|
||||
## Usage (export reports)
|
||||
Showing all the results at once can be confusing. You therefore may want to
|
||||
export separate reports. In this case, you can use the
|
||||
`apm_quality_assessment_export.py` script as follows:
|
||||
|
||||
- Set `--output_dir, -o` to the same value used in `apm_quality_assessment.sh`
|
||||
- Use regular expressions to select/filter out scores by
|
||||
- APM configurations: `--config_names, -c`
|
||||
- capture signals: `--capture_names, -i`
|
||||
- render signals: `--render_names, -r`
|
||||
- echo simulator: `--echo_simulator_names, -e`
|
||||
- test data generators: `--test_data_generators, -t`
|
||||
- scores: `--eval_scores, -s`
|
||||
- Assign a suffix to the report name using `-f <suffix>`
|
||||
|
||||
For instance:
|
||||
|
||||
```
|
||||
$ ./apm_quality_assessment_export.py \
|
||||
-o output/ \
|
||||
-c "(^default$)|(.*AE.*)" \
|
||||
-t \(white_noise\) \
|
||||
-s \(polqa\) \
|
||||
-f echo
|
||||
```
|
||||
|
||||
## Usage (boxplot)
|
||||
After generating stats, it can help to visualize how a score depends on a
|
||||
certain APM simulator parameter. The `apm_quality_assessment_boxplot.py` script
|
||||
helps with that, producing plots similar to [this
|
||||
one](https://matplotlib.org/mpl_examples/pylab_examples/boxplot_demo_06.png).
|
||||
|
||||
Suppose some scores come from running the APM simulator `audioproc_f` with
|
||||
or without the level controller: `--lc=1` or `--lc=0`. Then two boxplots
|
||||
side by side can be generated with
|
||||
|
||||
```
|
||||
$ ./apm_quality_assessment_boxplot.py \
|
||||
-o /path/to/output
|
||||
-v <score_name>
|
||||
-n /path/to/dir/with/apm_configs
|
||||
-z lc
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
The input wav file must be:
|
||||
- sampled at a sample rate that is a multiple of 100 (required by POLQA)
|
||||
- in the 16 bit format (required by `audioproc_f`)
|
||||
- encoded in the Microsoft WAV signed 16 bit PCM format (Audacity default
|
||||
when exporting)
|
||||
|
||||
Depending on the license, the POLQA tool may take “breaks” as a way to limit the
|
||||
throughput. When this happens, the APM Quality Assessment tool is slowed down.
|
||||
For more details about this limitation, check Section 10.9.1 in the POLQA manual
|
||||
v.1.18.
|
||||
|
||||
In case of issues with the POLQA score computation, check
|
||||
`py_quality_assessment/eval_scores.py` and adapt
|
||||
`PolqaScore._parse_output_file()`.
|
||||
The code can be also fixed directly into the build directory (namely,
|
||||
`out/Default/py_quality_assessment/eval_scores.py`).
|
||||
@ -1 +0,0 @@
|
||||
{"-all_default": null}
|
||||
@ -1,217 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Perform APM module quality assessment on one or more input files using one or
|
||||
more APM simulator configuration files and one or more test data generators.
|
||||
|
||||
Usage: apm_quality_assessment.py -i audio1.wav [audio2.wav ...]
|
||||
-c cfg1.json [cfg2.json ...]
|
||||
-n white [echo ...]
|
||||
-e audio_level [polqa ...]
|
||||
-o /path/to/output
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import quality_assessment.audioproc_wrapper as audioproc_wrapper
|
||||
import quality_assessment.echo_path_simulation as echo_path_simulation
|
||||
import quality_assessment.eval_scores as eval_scores
|
||||
import quality_assessment.evaluation as evaluation
|
||||
import quality_assessment.eval_scores_factory as eval_scores_factory
|
||||
import quality_assessment.external_vad as external_vad
|
||||
import quality_assessment.test_data_generation as test_data_generation
|
||||
import quality_assessment.test_data_generation_factory as \
|
||||
test_data_generation_factory
|
||||
import quality_assessment.simulation as simulation
|
||||
|
||||
_ECHO_PATH_SIMULATOR_NAMES = (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
|
||||
_TEST_DATA_GENERATOR_CLASSES = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
_TEST_DATA_GENERATORS_NAMES = _TEST_DATA_GENERATOR_CLASSES.keys()
|
||||
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
_EVAL_SCORE_WORKER_NAMES = _EVAL_SCORE_WORKER_CLASSES.keys()
|
||||
|
||||
_DEFAULT_CONFIG_FILE = 'apm_configs/default.json'
|
||||
|
||||
_POLQA_BIN_NAME = 'PolqaOem64'
|
||||
|
||||
|
||||
def _InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
'Perform APM module quality assessment on one or more input files using '
|
||||
'one or more APM simulator configuration files and one or more '
|
||||
'test data generators.'))
|
||||
|
||||
parser.add_argument('-c',
|
||||
'--config_files',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('path to the configuration files defining the '
|
||||
'arguments with which the APM simulator tool is '
|
||||
'called'),
|
||||
default=[_DEFAULT_CONFIG_FILE])
|
||||
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--capture_input_files',
|
||||
nargs='+',
|
||||
required=True,
|
||||
help='path to the capture input wav files (one or more)')
|
||||
|
||||
parser.add_argument('-r',
|
||||
'--render_input_files',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('path to the render input wav files; either '
|
||||
'omitted or one file for each file in '
|
||||
'--capture_input_files (files will be paired by '
|
||||
'index)'),
|
||||
default=None)
|
||||
|
||||
parser.add_argument('-p',
|
||||
'--echo_path_simulator',
|
||||
required=False,
|
||||
help=('custom echo path simulator name; required if '
|
||||
'--render_input_files is specified'),
|
||||
choices=_ECHO_PATH_SIMULATOR_NAMES,
|
||||
default=echo_path_simulation.NoEchoPathSimulator.NAME)
|
||||
|
||||
parser.add_argument('-t',
|
||||
'--test_data_generators',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help='custom list of test data generators to use',
|
||||
choices=_TEST_DATA_GENERATORS_NAMES,
|
||||
default=_TEST_DATA_GENERATORS_NAMES)
|
||||
|
||||
parser.add_argument('--additive_noise_tracks_path', required=False,
|
||||
help='path to the wav files for the additive',
|
||||
default=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH)
|
||||
|
||||
parser.add_argument('-e',
|
||||
'--eval_scores',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help='custom list of evaluation scores to use',
|
||||
choices=_EVAL_SCORE_WORKER_NAMES,
|
||||
default=_EVAL_SCORE_WORKER_NAMES)
|
||||
|
||||
parser.add_argument('-o',
|
||||
'--output_dir',
|
||||
required=False,
|
||||
help=('base path to the output directory in which the '
|
||||
'output wav files and the evaluation outcomes '
|
||||
'are saved'),
|
||||
default='output')
|
||||
|
||||
parser.add_argument('--polqa_path',
|
||||
required=True,
|
||||
help='path to the POLQA tool')
|
||||
|
||||
parser.add_argument('--air_db_path',
|
||||
required=True,
|
||||
help='path to the Aechen IR database')
|
||||
|
||||
parser.add_argument('--apm_sim_path', required=False,
|
||||
help='path to the APM simulator tool',
|
||||
default=audioproc_wrapper. \
|
||||
AudioProcWrapper. \
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
|
||||
parser.add_argument('--echo_metric_tool_bin_path',
|
||||
required=False,
|
||||
help=('path to the echo metric binary '
|
||||
'(required for the echo eval score)'),
|
||||
default=None)
|
||||
|
||||
parser.add_argument(
|
||||
'--copy_with_identity_generator',
|
||||
required=False,
|
||||
help=('If true, the identity test data generator makes a '
|
||||
'copy of the clean speech input file.'),
|
||||
default=False)
|
||||
|
||||
parser.add_argument('--external_vad_paths',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('Paths to external VAD programs. Each must take'
|
||||
'\'-i <wav file> -o <output>\' inputs'),
|
||||
default=[])
|
||||
|
||||
parser.add_argument('--external_vad_names',
|
||||
nargs='+',
|
||||
required=False,
|
||||
help=('Keys to the vad paths. Must be different and '
|
||||
'as many as the paths.'),
|
||||
default=[])
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _ValidateArguments(args, parser):
|
||||
if args.capture_input_files and args.render_input_files and (len(
|
||||
args.capture_input_files) != len(args.render_input_files)):
|
||||
parser.error(
|
||||
'--render_input_files and --capture_input_files must be lists '
|
||||
'having the same length')
|
||||
sys.exit(1)
|
||||
|
||||
if args.render_input_files and not args.echo_path_simulator:
|
||||
parser.error(
|
||||
'when --render_input_files is set, --echo_path_simulator is '
|
||||
'also required')
|
||||
sys.exit(1)
|
||||
|
||||
if len(args.external_vad_names) != len(args.external_vad_paths):
|
||||
parser.error('If provided, --external_vad_paths and '
|
||||
'--external_vad_names must '
|
||||
'have the same number of arguments.')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def main():
|
||||
# TODO(alessiob): level = logging.INFO once debugged.
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
parser = _InstanceArgumentsParser()
|
||||
args = parser.parse_args()
|
||||
_ValidateArguments(args, parser)
|
||||
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=args.air_db_path,
|
||||
noise_tracks_path=args.additive_noise_tracks_path,
|
||||
copy_with_identity=args.copy_with_identity_generator)),
|
||||
evaluation_score_factory=eval_scores_factory.
|
||||
EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
|
||||
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
|
||||
evaluator=evaluation.ApmModuleEvaluator(),
|
||||
external_vads=external_vad.ExternalVad.ConstructVadDict(
|
||||
args.external_vad_paths, args.external_vad_names))
|
||||
simulator.Run(config_filepaths=args.config_files,
|
||||
capture_input_filepaths=args.capture_input_files,
|
||||
render_input_filepaths=args.render_input_files,
|
||||
echo_path_simulator_name=args.echo_path_simulator,
|
||||
test_data_generator_names=args.test_data_generators,
|
||||
eval_score_names=args.eval_scores,
|
||||
output_dir=args.output_dir)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,91 +0,0 @@
|
||||
#!/bin/bash
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
# Path to the POLQA tool.
|
||||
if [ -z ${POLQA_PATH} ]; then # Check if defined.
|
||||
# Default location.
|
||||
export POLQA_PATH='/var/opt/PolqaOem64'
|
||||
fi
|
||||
if [ -d "${POLQA_PATH}" ]; then
|
||||
echo "POLQA found in ${POLQA_PATH}"
|
||||
else
|
||||
echo "POLQA not found in ${POLQA_PATH}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Path to the Aechen IR database.
|
||||
if [ -z ${AECHEN_IR_DATABASE_PATH} ]; then # Check if defined.
|
||||
# Default location.
|
||||
export AECHEN_IR_DATABASE_PATH='/var/opt/AIR_1_4'
|
||||
fi
|
||||
if [ -d "${AECHEN_IR_DATABASE_PATH}" ]; then
|
||||
echo "AIR database found in ${AECHEN_IR_DATABASE_PATH}"
|
||||
else
|
||||
echo "AIR database not found in ${AECHEN_IR_DATABASE_PATH}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Customize probing signals, test data generators and scores if needed.
|
||||
CAPTURE_SIGNALS=(probing_signals/*.wav)
|
||||
TEST_DATA_GENERATORS=( \
|
||||
"identity" \
|
||||
"white_noise" \
|
||||
# "environmental_noise" \
|
||||
# "reverberation" \
|
||||
)
|
||||
SCORES=( \
|
||||
# "polqa" \
|
||||
"audio_level_peak" \
|
||||
"audio_level_mean" \
|
||||
)
|
||||
OUTPUT_PATH=output
|
||||
|
||||
# Generate standard APM config files.
|
||||
chmod +x apm_quality_assessment_gencfgs.py
|
||||
./apm_quality_assessment_gencfgs.py
|
||||
|
||||
# Customize APM configurations if needed.
|
||||
APM_CONFIGS=(apm_configs/*.json)
|
||||
|
||||
# Add output path if missing.
|
||||
if [ ! -d ${OUTPUT_PATH} ]; then
|
||||
mkdir ${OUTPUT_PATH}
|
||||
fi
|
||||
|
||||
# Start one process for each "probing signal"-"test data source" pair.
|
||||
chmod +x apm_quality_assessment.py
|
||||
for capture_signal_filepath in "${CAPTURE_SIGNALS[@]}" ; do
|
||||
probing_signal_name="$(basename $capture_signal_filepath)"
|
||||
probing_signal_name="${probing_signal_name%.*}"
|
||||
for test_data_gen_name in "${TEST_DATA_GENERATORS[@]}" ; do
|
||||
LOG_FILE="${OUTPUT_PATH}/apm_qa-${probing_signal_name}-"`
|
||||
`"${test_data_gen_name}.log"
|
||||
echo "Starting ${probing_signal_name} ${test_data_gen_name} "`
|
||||
`"(see ${LOG_FILE})"
|
||||
./apm_quality_assessment.py \
|
||||
--polqa_path ${POLQA_PATH}\
|
||||
--air_db_path ${AECHEN_IR_DATABASE_PATH}\
|
||||
-i ${capture_signal_filepath} \
|
||||
-o ${OUTPUT_PATH} \
|
||||
-t ${test_data_gen_name} \
|
||||
-c "${APM_CONFIGS[@]}" \
|
||||
-e "${SCORES[@]}" > $LOG_FILE 2>&1 &
|
||||
done
|
||||
done
|
||||
|
||||
# Join Python processes running apm_quality_assessment.py.
|
||||
wait
|
||||
|
||||
# Export results.
|
||||
chmod +x ./apm_quality_assessment_export.py
|
||||
./apm_quality_assessment_export.py -o ${OUTPUT_PATH}
|
||||
|
||||
# Show results in the browser.
|
||||
RESULTS_FILE="$(realpath ${OUTPUT_PATH}/results.html)"
|
||||
sensible-browser "file://${RESULTS_FILE}" > /dev/null 2>&1 &
|
||||
@ -1,154 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Shows boxplots of given score for different values of selected
|
||||
parameters. Can be used to compare scores by audioproc_f flag.
|
||||
|
||||
Usage: apm_quality_assessment_boxplot.py -o /path/to/output
|
||||
-v polqa
|
||||
-n /path/to/dir/with/apm_configs
|
||||
-z audioproc_f_arg1 [arg2 ...]
|
||||
|
||||
Arguments --config_names, --render_names, --echo_simulator_names,
|
||||
--test_data_generators, --eval_scores can be used to filter the data
|
||||
used for plotting.
|
||||
"""
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
|
||||
import quality_assessment.data_access as data_access
|
||||
import quality_assessment.collect_data as collect_data
|
||||
|
||||
|
||||
def InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = collect_data.InstanceArgumentsParser()
|
||||
parser.description = (
|
||||
'Shows boxplot of given score for different values of selected'
|
||||
'parameters. Can be used to compare scores by audioproc_f flag')
|
||||
|
||||
parser.add_argument('-v',
|
||||
'--eval_score',
|
||||
required=True,
|
||||
help=('Score name for constructing boxplots'))
|
||||
|
||||
parser.add_argument(
|
||||
'-n',
|
||||
'--config_dir',
|
||||
required=False,
|
||||
help=('path to the folder with the configuration files'),
|
||||
default='apm_configs')
|
||||
|
||||
parser.add_argument('-z',
|
||||
'--params_to_plot',
|
||||
required=True,
|
||||
nargs='+',
|
||||
help=('audioproc_f parameter values'
|
||||
'by which to group scores (no leading dash)'))
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
|
||||
"""Filters data on the values of one or more parameters.
|
||||
|
||||
Args:
|
||||
data_frame: pandas.DataFrame of all used input data.
|
||||
|
||||
filter_params: each config of the input data is assumed to have
|
||||
exactly one parameter from `filter_params` defined. Every value
|
||||
of the parameters in `filter_params` is a key in the returned
|
||||
dict; the associated value is all cells of the data with that
|
||||
value of the parameter.
|
||||
|
||||
score_name: Name of score which value is boxplotted. Currently cannot do
|
||||
more than one value.
|
||||
|
||||
config_dir: path to dir with APM configs.
|
||||
|
||||
Returns: dictionary, key is a param value, result is all scores for
|
||||
that param value (see `filter_params` for explanation).
|
||||
"""
|
||||
results = collections.defaultdict(dict)
|
||||
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
|
||||
|
||||
for config_name in config_names:
|
||||
config_json = data_access.AudioProcConfigFile.Load(
|
||||
os.path.join(config_dir, config_name + '.json'))
|
||||
data_with_config = data_frame[data_frame.apm_config == config_name]
|
||||
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
|
||||
score_name]
|
||||
|
||||
# Exactly one of `params_to_plot` must match:
|
||||
(matching_param, ) = [
|
||||
x for x in filter_params if '-' + x in config_json
|
||||
]
|
||||
|
||||
# Add scores for every track to the result.
|
||||
for capture_name in data_cell_scores.capture:
|
||||
result_score = float(data_cell_scores[data_cell_scores.capture ==
|
||||
capture_name].score)
|
||||
config_dict = results[config_json['-' + matching_param]]
|
||||
if capture_name not in config_dict:
|
||||
config_dict[capture_name] = {}
|
||||
|
||||
config_dict[capture_name][matching_param] = result_score
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def _FlattenToScoresList(config_param_score_dict):
|
||||
"""Extracts a list of scores from input data structure.
|
||||
|
||||
Args:
|
||||
config_param_score_dict: of the form {'capture_name':
|
||||
{'param_name' : score_value,.. } ..}
|
||||
|
||||
Returns: Plain list of all score value present in input data
|
||||
structure
|
||||
"""
|
||||
result = []
|
||||
for capture_name in config_param_score_dict:
|
||||
result += list(config_param_score_dict[capture_name].values())
|
||||
return result
|
||||
|
||||
|
||||
def main():
    """Entry point: draws boxplots of scores grouped by APM parameter value."""
    # Logging kept at DEBUG while the script is being debugged.
    # TODO(alessiob): INFO once debugged.
    logging.basicConfig(level=logging.DEBUG)
    args = InstanceArgumentsParser().parse_args()

    # Get the scores.
    src_path = collect_data.ConstructSrcPath(args)
    logging.debug(src_path)
    scores_data_frame = collect_data.FindScores(src_path, args)

    # Filter the data by `args.params_to_plot`
    scores_filtered = FilterScoresByParams(scores_data_frame,
                                           args.params_to_plot,
                                           args.eval_score, args.config_dir)

    # Sort by parameter value and split into parallel label/value lists.
    data_labels = []
    data_values = []
    for param_value, capture_scores in sorted(scores_filtered.items()):
        data_labels.append(param_value)
        data_values.append(_FlattenToScoresList(capture_scores))

    # One boxplot per parameter value.
    _, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
    axes.boxplot(data_values, labels=data_labels)
    axes.set_ylabel(args.eval_score)
    axes.set_xlabel('/'.join(args.params_to_plot))
    plt.show()


if __name__ == "__main__":
    main()
|
||||
@ -1,63 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Export the scores computed by the apm_quality_assessment.py script into an
|
||||
HTML file.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import quality_assessment.collect_data as collect_data
|
||||
import quality_assessment.export as export
|
||||
|
||||
|
||||
def _BuildOutputFilename(filename_suffix):
|
||||
"""Builds the filename for the exported file.
|
||||
|
||||
Args:
|
||||
filename_suffix: suffix for the output file name.
|
||||
|
||||
Returns:
|
||||
A string.
|
||||
"""
|
||||
if filename_suffix is None:
|
||||
return 'results.html'
|
||||
return 'results-{}.html'.format(filename_suffix)
|
||||
|
||||
|
||||
def main():
    """Entry point: exports precomputed APM quality scores to an HTML file."""
    # Logging kept at DEBUG while the script is being debugged.
    logging.basicConfig(
        level=logging.DEBUG)  # TODO(alessio): INFO once debugged.
    parser = collect_data.InstanceArgumentsParser()
    parser.description = ('Exports pre-computed APM module quality assessment '
                          'results into HTML tables')
    parser.add_argument('-f',
                        '--filename_suffix',
                        help=('suffix of the exported file'))
    args = parser.parse_args()

    # Get the scores.
    src_path = collect_data.ConstructSrcPath(args)
    logging.debug(src_path)
    scores_data_frame = collect_data.FindScores(src_path, args)

    # Export.
    output_filepath = os.path.join(args.output_dir,
                                   _BuildOutputFilename(args.filename_suffix))
    export.HtmlExport(output_filepath).Export(scores_data_frame)

    logging.info('output file successfully written in %s', output_filepath)
    sys.exit(0)


if __name__ == '__main__':
    main()
|
||||
@ -1,128 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Generate .json files with which the APM module can be tested using the
|
||||
apm_quality_assessment.py script and audioproc_f as APM simulator.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import quality_assessment.data_access as data_access
|
||||
|
||||
OUTPUT_PATH = os.path.abspath('apm_configs')
|
||||
|
||||
|
||||
def _GenerateDefaultOverridden(config_override):
    """Writes one APM configuration file per entry in `config_override`.

    Each generated configuration starts from the APM default settings
    (loaded via "-all_default"; check
    "src/modules/audio_processing/test/audioproc_float.cc" and search for
    "if (FLAG_all_default) {") and overrides them with the given
    audioproc_f flags.

    Args:
      config_override: dict of APM configuration file names as keys; the
                       values are dict instances encoding the audioproc_f
                       flags.
    """
    for config_filename, config in config_override.items():
        # `None` marks a value-less flag; audioproc_f loads its defaults
        # first and then applies the overrides below.
        config['-all_default'] = None

        config_filepath = os.path.join(
            OUTPUT_PATH, 'default-{}.json'.format(config_filename))
        logging.debug('config file <%s> | %s', config_filepath, config)

        data_access.AudioProcConfigFile.Save(config_filepath, config)
        logging.info('config file created: <%s>', config_filepath)
|
||||
|
||||
|
||||
def _GenerateAllDefaultButOne():
    """Disables the flags enabled by default one-by-one."""
    disabled_flag_sets = {
        'no_AEC': {'-aec': 0},
        'no_AGC': {'-agc': 0},
        'no_HP_filter': {'-hpf': 0},
        'no_level_estimator': {'-le': 0},
        'no_noise_suppressor': {'-ns': 0},
        'no_transient_suppressor': {'-ts': 0},
        'no_vad': {'-vad': 0},
    }
    _GenerateDefaultOverridden(disabled_flag_sets)
|
||||
|
||||
|
||||
def _GenerateAllDefaultPlusOne():
    """Enables the flags disabled by default one-by-one."""
    enabled_flag_sets = {
        'with_AECM': {
            '-aec': 0,
            '-aecm': 1,
        },  # AEC and AECM are exclusive.
        'with_AGC_limiter': {'-agc_limiter': 1},
        'with_AEC_delay_agnostic': {'-delay_agnostic': 1},
        'with_drift_compensation': {'-drift_compensation': 1},
        'with_residual_echo_detector': {'-ed': 1},
        'with_AEC_extended_filter': {'-extended_filter': 1},
        'with_LC': {'-lc': 1},
        'with_refined_adaptive_filter': {'-refined_adaptive_filter': 1},
    }
    _GenerateDefaultOverridden(enabled_flag_sets)
|
||||
|
||||
|
||||
def main():
    """Entry point: generates the APM configuration .json files."""
    logging.basicConfig(level=logging.INFO)
    # Order preserved: "plus one" configs are generated before "but one".
    _GenerateAllDefaultPlusOne()
    _GenerateAllDefaultButOne()


if __name__ == '__main__':
    main()
|
||||
@ -1,189 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Finds the APM configuration that maximizes a provided metric by
|
||||
parsing the output generated apm_quality_assessment.py.
|
||||
"""
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
|
||||
import quality_assessment.data_access as data_access
|
||||
import quality_assessment.collect_data as collect_data
|
||||
|
||||
|
||||
def _InstanceArgumentsParser():
    """Arguments parser factory.

    Extends the 'collect_data' parser with options selecting which
    parameters to optimize for.

    Returns:
      An argument parser instance.
    """
    parser = collect_data.InstanceArgumentsParser()
    parser.description = (
        'Rudimentary optimization of a function over different parameter'
        'combinations.')

    parser.add_argument('-n',
                        '--config_dir',
                        required=False,
                        default='apm_configs',
                        help=('path to the folder with the configuration '
                              'files'))

    parser.add_argument('-p',
                        '--params',
                        required=True,
                        nargs='+',
                        help=('parameters to parse from the config files in'
                              'config_dir'))

    parser.add_argument('-z',
                        '--params_not_to_optimize',
                        required=False,
                        nargs='+',
                        default=[],
                        help=('parameters from `params` not to be optimized '
                              'for'))

    return parser
|
||||
|
||||
|
||||
def _ConfigurationAndScores(data_frame, params, params_not_to_optimize,
                            config_dir):
    """Returns a list of all configurations and scores.

    Args:
      data_frame: A pandas data frame with the scores and config name
                  returned by _FindScores.
      params: The parameter names to parse from configs in the config
              directory.
      params_not_to_optimize: The parameter names which shouldn't affect
                              the optimal parameter selection. E.g., fixed
                              settings and not tunable parameters.
      config_dir: Path to folder with config files.

    Returns:
      Dictionary of the form
      {param_combination: [{params: {param1: value1, ...},
                            scores: {score1: value1, ...}}]}.

      The key `param_combination` runs over all combinations of the
      parameters in `params` and not in `params_not_to_optimize`. A
      corresponding value is a list of all param combinations for params
      in `params_not_to_optimize` and their scores.
    """
    results = collections.defaultdict(list)
    config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
    score_names = data_frame['eval_score_name'].drop_duplicates(
    ).values.tolist()

    # Normalize the scores: per score name, the maximum observed value.
    normalization_constants = {
        name: max(data_frame[data_frame.eval_score_name == name].score)
        for name in score_names
    }

    # One namedtuple type keyed by the parameters being optimized; it is
    # hashable and therefore usable as a dict key.
    params_to_optimize = [p for p in params if p not in params_not_to_optimize]
    param_combination = collections.namedtuple("ParamCombination",
                                               params_to_optimize)

    for config_name in config_names:
        config_json = data_access.AudioProcConfigFile.Load(
            os.path.join(config_dir, config_name + ".json"))

        # Average and normalize each score over the rows of this config.
        data_cell = data_frame[data_frame.apm_config == config_name]
        scores = {}
        for name in score_names:
            cell_scores = data_cell[data_cell.eval_score_name == name].score
            mean_score = sum(cell_scores) / len(cell_scores)
            scores[name] = mean_score / normalization_constants[name]

        # Split the config values between optimized and fixed parameters.
        result = {'scores': scores, 'params': {}}
        optimized_values = {}
        for param in params:
            if param in params_to_optimize:
                optimized_values[param] = config_json['-' + param]
            else:
                result['params'][param] = config_json['-' + param]

        results[param_combination(**optimized_values)].append(result)
    return results
|
||||
|
||||
|
||||
def _FindOptimalParameter(configs_and_scores, score_weighting):
|
||||
"""Finds the config producing the maximal score.
|
||||
|
||||
Args:
|
||||
configs_and_scores: structure of the form returned by
|
||||
_ConfigurationAndScores
|
||||
|
||||
score_weighting: a function to weight together all score values of
|
||||
the form [{params: {param1: value1, ...}, scores:
|
||||
{score1: value1, ...}}] into a numeric
|
||||
value
|
||||
Returns:
|
||||
the config that has the largest values of `score_weighting` applied
|
||||
to its scores.
|
||||
"""
|
||||
|
||||
min_score = float('+inf')
|
||||
best_params = None
|
||||
for config in configs_and_scores:
|
||||
scores_and_params = configs_and_scores[config]
|
||||
current_score = score_weighting(scores_and_params)
|
||||
if current_score < min_score:
|
||||
min_score = current_score
|
||||
best_params = config
|
||||
logging.debug("Score: %f", current_score)
|
||||
logging.debug("Config: %s", str(config))
|
||||
return best_params
|
||||
|
||||
|
||||
def _ExampleWeighting(scores_and_configs):
|
||||
"""Example argument to `_FindOptimalParameter`
|
||||
Args:
|
||||
scores_and_configs: a list of configs and scores, in the form
|
||||
described in _FindOptimalParameter
|
||||
Returns:
|
||||
numeric value, the sum of all scores
|
||||
"""
|
||||
res = 0
|
||||
for score_config in scores_and_configs:
|
||||
res += sum(score_config['scores'].values())
|
||||
return res
|
||||
|
||||
|
||||
def main():
    """Entry point: finds the best APM parameter combination by score."""
    # Logging kept at DEBUG while the script is being debugged.
    # TODO(alessiob): INFO once debugged.
    logging.basicConfig(level=logging.DEBUG)
    args = _InstanceArgumentsParser().parse_args()

    # Get the scores.
    src_path = collect_data.ConstructSrcPath(args)
    logging.debug('Src path <%s>', src_path)
    scores_data_frame = collect_data.FindScores(src_path, args)
    all_scores = _ConfigurationAndScores(scores_data_frame, args.params,
                                         args.params_not_to_optimize,
                                         args.config_dir)

    # Pick and report the optimal combination.
    opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
    logging.info('Optimal parameter combination: <%s>', opt_param)
    logging.info('It\'s score values: <%s>', all_scores[opt_param])


if __name__ == "__main__":
    main()
|
||||
@ -1,28 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the apm_quality_assessment module.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
|
||||
import apm_quality_assessment
|
||||
|
||||
|
||||
class TestSimulationScript(unittest.TestCase):
    """Unit tests for the apm_quality_assessment module."""

    def testMain(self):
        # Running the script with no CLI arguments must exit with a
        # non-zero error code.
        fake_argv = ['apm_quality_assessment.py']
        with mock.patch.object(sys, 'argv', fake_argv):
            with self.assertRaises(SystemExit) as cm:
                apm_quality_assessment.main()
        self.assertGreater(cm.exception.code, 0)
|
||||
@ -1 +0,0 @@
|
||||
You can use this folder for the output generated by the apm_quality_assessment scripts.
|
||||
@ -1,7 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
@ -1,296 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Extraction of annotations from audio files.
|
||||
"""
|
||||
|
||||
from __future__ import division
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import struct
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import external_vad
|
||||
from . import exceptions
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class AudioAnnotationsExtractor(object):
|
||||
"""Extracts annotations from audio files.
|
||||
"""
|
||||
|
||||
class VadType(object):
|
||||
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
|
||||
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
|
||||
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
|
||||
|
||||
def __init__(self, value):
|
||||
if (not isinstance(value, int)) or not 0 <= value <= 7:
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
value)
|
||||
self._value = value
|
||||
|
||||
def Contains(self, vad_type):
|
||||
return self._value | vad_type == self._value
|
||||
|
||||
def __str__(self):
|
||||
vads = []
|
||||
if self.Contains(self.ENERGY_THRESHOLD):
|
||||
vads.append("energy")
|
||||
if self.Contains(self.WEBRTC_COMMON_AUDIO):
|
||||
vads.append("common_audio")
|
||||
if self.Contains(self.WEBRTC_APM):
|
||||
vads.append("apm")
|
||||
return "VadType({})".format(", ".join(vads))
|
||||
|
||||
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
|
||||
|
||||
# Level estimation params.
|
||||
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
|
||||
_LEVEL_FRAME_SIZE_MS = 1.0
|
||||
# The time constants in ms indicate the time it takes for the level estimate
|
||||
# to go down/up by 1 db if the signal is zero.
|
||||
_LEVEL_ATTACK_MS = 5.0
|
||||
_LEVEL_DECAY_MS = 20.0
|
||||
|
||||
# VAD params.
|
||||
_VAD_THRESHOLD = 1
|
||||
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
os.pardir, os.pardir)
|
||||
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
|
||||
|
||||
_VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad')
|
||||
|
||||
def __init__(self, vad_type, external_vads=None):
|
||||
self._signal = None
|
||||
self._level = None
|
||||
self._level_frame_size = None
|
||||
self._common_audio_vad = None
|
||||
self._energy_vad = None
|
||||
self._apm_vad_probs = None
|
||||
self._apm_vad_rms = None
|
||||
self._vad_frame_size = None
|
||||
self._vad_frame_size_ms = None
|
||||
self._c_attack = None
|
||||
self._c_decay = None
|
||||
|
||||
self._vad_type = self.VadType(vad_type)
|
||||
logging.info('VADs used for annotations: ' + str(self._vad_type))
|
||||
|
||||
if external_vads is None:
|
||||
external_vads = {}
|
||||
self._external_vads = external_vads
|
||||
|
||||
assert len(self._external_vads) == len(external_vads), (
|
||||
'The external VAD names must be unique.')
|
||||
for vad in external_vads.values():
|
||||
if not isinstance(vad, external_vad.ExternalVad):
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
str(type(vad)))
|
||||
logging.info('External VAD used for annotation: ' + str(vad.name))
|
||||
|
||||
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH
|
||||
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
|
||||
self._VAD_WEBRTC_APM_PATH
|
||||
|
||||
@classmethod
|
||||
def GetOutputFileNameTemplate(cls):
|
||||
return cls._OUTPUT_FILENAME_TEMPLATE
|
||||
|
||||
def GetLevel(self):
|
||||
return self._level
|
||||
|
||||
def GetLevelFrameSize(self):
|
||||
return self._level_frame_size
|
||||
|
||||
@classmethod
|
||||
def GetLevelFrameSizeMs(cls):
|
||||
return cls._LEVEL_FRAME_SIZE_MS
|
||||
|
||||
def GetVadOutput(self, vad_type):
|
||||
if vad_type == self.VadType.ENERGY_THRESHOLD:
|
||||
return self._energy_vad
|
||||
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
|
||||
return self._common_audio_vad
|
||||
elif vad_type == self.VadType.WEBRTC_APM:
|
||||
return (self._apm_vad_probs, self._apm_vad_rms)
|
||||
else:
|
||||
raise exceptions.InitializationException('Invalid vad type: ' +
|
||||
vad_type)
|
||||
|
||||
def GetVadFrameSize(self):
|
||||
return self._vad_frame_size
|
||||
|
||||
def GetVadFrameSizeMs(self):
|
||||
return self._vad_frame_size_ms
|
||||
|
||||
def Extract(self, filepath):
|
||||
# Load signal.
|
||||
self._signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
filepath)
|
||||
if self._signal.channels != 1:
|
||||
raise NotImplementedError(
|
||||
'Multiple-channel annotations not implemented')
|
||||
|
||||
# Level estimation params.
|
||||
self._level_frame_size = int(self._signal.frame_rate / 1000 *
|
||||
(self._LEVEL_FRAME_SIZE_MS))
|
||||
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
|
||||
self._LEVEL_ATTACK_MS))
|
||||
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
|
||||
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
|
||||
self._LEVEL_DECAY_MS))
|
||||
|
||||
# Compute level.
|
||||
self._LevelEstimation()
|
||||
|
||||
# Ideal VAD output, it requires clean speech with high SNR as input.
|
||||
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
|
||||
# Naive VAD based on level thresholding.
|
||||
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
|
||||
self._energy_vad = np.uint8(self._level > vad_threshold)
|
||||
self._vad_frame_size = self._level_frame_size
|
||||
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
|
||||
# WebRTC common_audio/ VAD.
|
||||
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
|
||||
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
|
||||
# WebRTC modules/audio_processing/ VAD.
|
||||
self._RunWebRtcApmVad(filepath)
|
||||
for extvad_name in self._external_vads:
|
||||
self._external_vads[extvad_name].Run(filepath)
|
||||
|
||||
def Save(self, output_path, annotation_name=""):
|
||||
ext_kwargs = {
|
||||
'extvad_conf-' + ext_vad:
|
||||
self._external_vads[ext_vad].GetVadOutput()
|
||||
for ext_vad in self._external_vads
|
||||
}
|
||||
np.savez_compressed(file=os.path.join(
|
||||
output_path,
|
||||
self.GetOutputFileNameTemplate().format(annotation_name)),
|
||||
level=self._level,
|
||||
level_frame_size=self._level_frame_size,
|
||||
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
|
||||
vad_output=self._common_audio_vad,
|
||||
vad_energy_output=self._energy_vad,
|
||||
vad_frame_size=self._vad_frame_size,
|
||||
vad_frame_size_ms=self._vad_frame_size_ms,
|
||||
vad_probs=self._apm_vad_probs,
|
||||
vad_rms=self._apm_vad_rms,
|
||||
**ext_kwargs)
|
||||
|
||||
def _LevelEstimation(self):
|
||||
# Read samples.
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._signal).astype(np.float32) / 32768.0
|
||||
num_frames = len(samples) // self._level_frame_size
|
||||
num_samples = num_frames * self._level_frame_size
|
||||
|
||||
# Envelope.
|
||||
self._level = np.max(np.reshape(np.abs(samples[:num_samples]),
|
||||
(num_frames, self._level_frame_size)),
|
||||
axis=1)
|
||||
assert len(self._level) == num_frames
|
||||
|
||||
# Envelope smoothing.
|
||||
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
|
||||
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
|
||||
for i in range(1, num_frames):
|
||||
self._level[i] = smooth(
|
||||
self._level[i], self._level[i - 1], self._c_attack if
|
||||
(self._level[i] > self._level[i - 1]) else self._c_decay)
|
||||
|
||||
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
|
||||
self._common_audio_vad = None
|
||||
self._vad_frame_size = None
|
||||
|
||||
# Create temporary output path.
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
output_file_path = os.path.join(
|
||||
tmp_path,
|
||||
os.path.split(wav_file_path)[1] + '_vad.tmp')
|
||||
|
||||
# Call WebRTC VAD.
|
||||
try:
|
||||
subprocess.call([
|
||||
self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o',
|
||||
output_file_path
|
||||
],
|
||||
cwd=self._VAD_WEBRTC_PATH)
|
||||
|
||||
# Read bytes.
|
||||
with open(output_file_path, 'rb') as f:
|
||||
raw_data = f.read()
|
||||
|
||||
# Parse side information.
|
||||
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
|
||||
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
|
||||
assert self._vad_frame_size_ms in [10, 20, 30]
|
||||
extra_bits = struct.unpack('B', raw_data[-1])[0]
|
||||
assert 0 <= extra_bits <= 8
|
||||
|
||||
# Init VAD vector.
|
||||
num_bytes = len(raw_data)
|
||||
num_frames = 8 * (num_bytes -
|
||||
2) - extra_bits # 8 frames for each byte.
|
||||
self._common_audio_vad = np.zeros(num_frames, np.uint8)
|
||||
|
||||
# Read VAD decisions.
|
||||
for i, byte in enumerate(raw_data[1:-1]):
|
||||
byte = struct.unpack('B', byte)[0]
|
||||
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
|
||||
self._common_audio_vad[i * 8 + j] = int(byte & 1)
|
||||
byte = byte >> 1
|
||||
except Exception as e:
|
||||
logging.error('Error while running the WebRTC VAD (' + e.message +
|
||||
')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
|
||||
def _RunWebRtcApmVad(self, wav_file_path):
|
||||
# Create temporary output path.
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
output_file_path_probs = os.path.join(
|
||||
tmp_path,
|
||||
os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
|
||||
output_file_path_rms = os.path.join(
|
||||
tmp_path,
|
||||
os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
|
||||
|
||||
# Call WebRTC VAD.
|
||||
try:
|
||||
subprocess.call([
|
||||
self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs',
|
||||
output_file_path_probs, '-o_rms', output_file_path_rms
|
||||
],
|
||||
cwd=self._VAD_WEBRTC_PATH)
|
||||
|
||||
# Parse annotations.
|
||||
self._apm_vad_probs = np.fromfile(output_file_path_probs,
|
||||
np.double)
|
||||
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
|
||||
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
|
||||
|
||||
except Exception as e:
|
||||
logging.error('Error while running the WebRTC APM VAD (' +
|
||||
e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
@ -1,160 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the annotations module.
|
||||
"""
|
||||
|
||||
from __future__ import division
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import annotations
|
||||
from . import external_vad
|
||||
from . import input_signal_creator
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class TestAnnotationsExtraction(unittest.TestCase):
|
||||
"""Unit tests for the annotations module.
|
||||
"""
|
||||
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
_DEBUG_PLOT_VAD = False
|
||||
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
|
||||
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD
|
||||
| _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO
|
||||
| _VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folder."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
|
||||
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
|
||||
'pure_tone', [440, 1000])
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._wav_file_path, pure_tone)
|
||||
self._sample_rate = pure_tone.frame_rate
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' +
|
||||
(self._tmp_path))
|
||||
|
||||
def testFrameSizes(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
samples_to_ms = lambda n, sr: 1000 * n // sr
|
||||
self.assertEqual(
|
||||
samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
|
||||
e.GetLevelFrameSizeMs())
|
||||
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
|
||||
e.GetVadFrameSizeMs())
|
||||
|
||||
def testVoiceActivityDetectors(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
|
||||
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
|
||||
e.Extract(self._wav_file_path)
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(
|
||||
self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_output)) / len(vad_output), 0.95)
|
||||
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
vad_output = e.GetVadOutput(
|
||||
self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
|
||||
self.assertGreater(len(vad_output), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_output)) / len(vad_output), 0.95)
|
||||
|
||||
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
|
||||
# pylint: disable=unpacking-non-sequence
|
||||
(vad_probs,
|
||||
vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
|
||||
self.assertGreater(len(vad_probs), 0)
|
||||
self.assertGreater(len(vad_rms), 0)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_probs)) / len(vad_probs), 0.5)
|
||||
self.assertGreaterEqual(
|
||||
float(np.sum(vad_rms)) / len(vad_rms), 20000)
|
||||
|
||||
if self._DEBUG_PLOT_VAD:
|
||||
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
|
||||
num_frames).astype(np.float32) * frame_size_ms / 1000.0
|
||||
level = e.GetLevel()
|
||||
t_level = frame_times_s(num_frames=len(level),
|
||||
frame_size_ms=e.GetLevelFrameSizeMs())
|
||||
t_vad = frame_times_s(num_frames=len(vad_output),
|
||||
frame_size_ms=e.GetVadFrameSizeMs())
|
||||
import matplotlib.pyplot as plt
|
||||
plt.figure()
|
||||
plt.hold(True)
|
||||
plt.plot(t_level, level)
|
||||
plt.plot(t_vad, vad_output * np.max(level), '.')
|
||||
plt.show()
|
||||
|
||||
def testSaveLoad(self):
|
||||
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, "fake-annotation")
|
||||
|
||||
data = np.load(
|
||||
os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
np.testing.assert_array_equal(e.GetLevel(), data['level'])
|
||||
self.assertEqual(np.float32, data['level'].dtype)
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
|
||||
data['vad_energy_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
|
||||
data['vad_output'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0],
|
||||
data['vad_probs'])
|
||||
np.testing.assert_array_equal(
|
||||
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1],
|
||||
data['vad_rms'])
|
||||
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_probs'].dtype)
|
||||
self.assertEqual(np.float64, data['vad_rms'].dtype)
|
||||
|
||||
def testEmptyExternalShouldNotCrash(self):
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
annotations.AudioAnnotationsExtractor(vad_type_value, {})
|
||||
|
||||
def testFakeExternalSaveLoad(self):
|
||||
def FakeExternalFactory():
|
||||
return external_vad.ExternalVad(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)),
|
||||
'fake_external_vad.py'), 'fake')
|
||||
|
||||
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
|
||||
e = annotations.AudioAnnotationsExtractor(
|
||||
vad_type_value, {'fake': FakeExternalFactory()})
|
||||
e.Extract(self._wav_file_path)
|
||||
e.Save(self._tmp_path, annotation_name="fake-annotation")
|
||||
data = np.load(
|
||||
os.path.join(
|
||||
self._tmp_path,
|
||||
e.GetOutputFileNameTemplate().format("fake-annotation")))
|
||||
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
|
||||
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
|
||||
data['extvad_conf-fake'])
|
||||
@ -1 +0,0 @@
|
||||
{"-all_default": null}
|
||||
@ -1,96 +0,0 @@
|
||||
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
//
|
||||
// Use of this source code is governed by a BSD-style license
|
||||
// that can be found in the LICENSE file in the root of the source
|
||||
// tree. An additional intellectual property rights grant can be found
|
||||
// in the file PATENTS. All contributing project authors may
|
||||
// be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
#include <array>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "common_audio/wav_file.h"
|
||||
#include "modules/audio_processing/vad/voice_activity_detector.h"
|
||||
#include "rtc_base/logging.h"
|
||||
|
||||
ABSL_FLAG(std::string, i, "", "Input wav file");
|
||||
ABSL_FLAG(std::string, o_probs, "", "VAD probabilities output file");
|
||||
ABSL_FLAG(std::string, o_rms, "", "VAD output file");
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr uint8_t kAudioFrameLengthMilliseconds = 10;
|
||||
constexpr int kMaxSampleRate = 48000;
|
||||
constexpr size_t kMaxFrameLen =
|
||||
kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
const std::string input_file = absl::GetFlag(FLAGS_i);
|
||||
const std::string output_probs_file = absl::GetFlag(FLAGS_o_probs);
|
||||
const std::string output_file = absl::GetFlag(FLAGS_o_rms);
|
||||
// Open wav input file and check properties.
|
||||
WavReader wav_reader(input_file);
|
||||
if (wav_reader.num_channels() != 1) {
|
||||
RTC_LOG(LS_ERROR) << "Only mono wav files supported";
|
||||
return 1;
|
||||
}
|
||||
if (wav_reader.sample_rate() > kMaxSampleRate) {
|
||||
RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
|
||||
<< ")";
|
||||
return 1;
|
||||
}
|
||||
const size_t audio_frame_len = rtc::CheckedDivExact(
|
||||
kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
|
||||
if (audio_frame_len > kMaxFrameLen) {
|
||||
RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Create output file and write header.
|
||||
std::ofstream out_probs_file(output_probs_file, std::ofstream::binary);
|
||||
std::ofstream out_rms_file(output_file, std::ofstream::binary);
|
||||
|
||||
// Run VAD and write decisions.
|
||||
VoiceActivityDetector vad;
|
||||
std::array<int16_t, kMaxFrameLen> samples;
|
||||
|
||||
while (true) {
|
||||
// Process frame.
|
||||
const auto read_samples =
|
||||
wav_reader.ReadSamples(audio_frame_len, samples.data());
|
||||
if (read_samples < audio_frame_len) {
|
||||
break;
|
||||
}
|
||||
vad.ProcessChunk(samples.data(), audio_frame_len, wav_reader.sample_rate());
|
||||
// Write output.
|
||||
auto probs = vad.chunkwise_voice_probabilities();
|
||||
auto rms = vad.chunkwise_rms();
|
||||
RTC_CHECK_EQ(probs.size(), rms.size());
|
||||
RTC_CHECK_EQ(sizeof(double), 8);
|
||||
|
||||
for (const auto& p : probs) {
|
||||
out_probs_file.write(reinterpret_cast<const char*>(&p), 8);
|
||||
}
|
||||
for (const auto& r : rms) {
|
||||
out_rms_file.write(reinterpret_cast<const char*>(&r), 8);
|
||||
}
|
||||
}
|
||||
|
||||
out_probs_file.close();
|
||||
out_rms_file.close();
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return webrtc::test::main(argc, argv);
|
||||
}
|
||||
@ -1,100 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Class implementing a wrapper for APM simulators.
|
||||
"""
|
||||
|
||||
import cProfile
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
from . import data_access
|
||||
from . import exceptions
|
||||
|
||||
|
||||
class AudioProcWrapper(object):
|
||||
"""Wrapper for APM simulators.
|
||||
"""
|
||||
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(
|
||||
os.path.join(os.pardir, 'audioproc_f'))
|
||||
OUTPUT_FILENAME = 'output.wav'
|
||||
|
||||
def __init__(self, simulator_bin_path):
|
||||
"""Ctor.
|
||||
|
||||
Args:
|
||||
simulator_bin_path: path to the APM simulator binary.
|
||||
"""
|
||||
self._simulator_bin_path = simulator_bin_path
|
||||
self._config = None
|
||||
self._output_signal_filepath = None
|
||||
|
||||
# Profiler instance to measure running time.
|
||||
self._profiler = cProfile.Profile()
|
||||
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_signal_filepath
|
||||
|
||||
def Run(self,
|
||||
config_filepath,
|
||||
capture_input_filepath,
|
||||
output_path,
|
||||
render_input_filepath=None):
|
||||
"""Runs APM simulator.
|
||||
|
||||
Args:
|
||||
config_filepath: path to the configuration file specifying the arguments
|
||||
for the APM simulator.
|
||||
capture_input_filepath: path to the capture audio track input file (aka
|
||||
forward or near-end).
|
||||
output_path: path of the audio track output file.
|
||||
render_input_filepath: path to the render audio track input file (aka
|
||||
reverse or far-end).
|
||||
"""
|
||||
# Init.
|
||||
self._output_signal_filepath = os.path.join(output_path,
|
||||
self.OUTPUT_FILENAME)
|
||||
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
|
||||
|
||||
# Skip if the output has already been generated.
|
||||
if os.path.exists(self._output_signal_filepath) and os.path.exists(
|
||||
profiling_stats_filepath):
|
||||
return
|
||||
|
||||
# Load configuration.
|
||||
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
|
||||
|
||||
# Set remaining parameters.
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
raise exceptions.FileNotFoundError(
|
||||
'cannot find capture input file')
|
||||
self._config['-i'] = capture_input_filepath
|
||||
self._config['-o'] = self._output_signal_filepath
|
||||
if render_input_filepath is not None:
|
||||
if not os.path.exists(render_input_filepath):
|
||||
raise exceptions.FileNotFoundError(
|
||||
'cannot find render input file')
|
||||
self._config['-ri'] = render_input_filepath
|
||||
|
||||
# Build arguments list.
|
||||
args = [self._simulator_bin_path]
|
||||
for param_name in self._config:
|
||||
args.append(param_name)
|
||||
if self._config[param_name] is not None:
|
||||
args.append(str(self._config[param_name]))
|
||||
logging.debug(' '.join(args))
|
||||
|
||||
# Run.
|
||||
self._profiler.enable()
|
||||
subprocess.call(args)
|
||||
self._profiler.disable()
|
||||
|
||||
# Save profiling stats.
|
||||
self._profiler.dump_stats(profiling_stats_filepath)
|
||||
@ -1,243 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Imports a filtered subset of the scores and configurations computed
|
||||
by apm_quality_assessment.py into a pandas data frame.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
try:
|
||||
import pandas as pd
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package pandas')
|
||||
sys.exit(1)
|
||||
|
||||
from . import data_access as data_access
|
||||
from . import simulation as sim
|
||||
|
||||
# Compiled regular expressions used to extract score descriptors.
|
||||
RE_CONFIG_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixApmConfig() +
|
||||
r'(.+)')
|
||||
RE_CAPTURE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixCapture() +
|
||||
r'(.+)')
|
||||
RE_RENDER_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
|
||||
RE_ECHO_SIM_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixEchoSimulator() +
|
||||
r'(.+)')
|
||||
RE_TEST_DATA_GEN_NAME = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + r'(.+)')
|
||||
RE_TEST_DATA_GEN_PARAMS = re.compile(
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + r'(.+)')
|
||||
RE_SCORE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixScore() +
|
||||
r'(.+)(\..+)')
|
||||
|
||||
|
||||
def InstanceArgumentsParser():
|
||||
"""Arguments parser factory.
|
||||
"""
|
||||
parser = argparse.ArgumentParser(
|
||||
description=('Override this description in a user script by changing'
|
||||
' `parser.description` of the returned parser.'))
|
||||
|
||||
parser.add_argument('-o',
|
||||
'--output_dir',
|
||||
required=True,
|
||||
help=('the same base path used with the '
|
||||
'apm_quality_assessment tool'))
|
||||
|
||||
parser.add_argument(
|
||||
'-c',
|
||||
'--config_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the APM configuration'
|
||||
' names'))
|
||||
|
||||
parser.add_argument(
|
||||
'-i',
|
||||
'--capture_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the capture signal '
|
||||
'names'))
|
||||
|
||||
parser.add_argument('-r',
|
||||
'--render_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the render signal '
|
||||
'names'))
|
||||
|
||||
parser.add_argument(
|
||||
'-e',
|
||||
'--echo_simulator_names',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the echo simulator '
|
||||
'names'))
|
||||
|
||||
parser.add_argument('-t',
|
||||
'--test_data_generators',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the test data '
|
||||
'generator names'))
|
||||
|
||||
parser.add_argument(
|
||||
'-s',
|
||||
'--eval_scores',
|
||||
type=re.compile,
|
||||
help=('regular expression to filter the evaluation score '
|
||||
'names'))
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _GetScoreDescriptors(score_filepath):
|
||||
"""Extracts a score descriptor from the given score file path.
|
||||
|
||||
Args:
|
||||
score_filepath: path to the score file.
|
||||
|
||||
Returns:
|
||||
A tuple of strings (APM configuration name, capture audio track name,
|
||||
render audio track name, echo simulator name, test data generator name,
|
||||
test data generator parameters as string, evaluation score name).
|
||||
"""
|
||||
fields = score_filepath.split(os.sep)[-7:]
|
||||
extract_name = lambda index, reg_expr: (reg_expr.match(fields[index]).
|
||||
groups(0)[0])
|
||||
return (
|
||||
extract_name(0, RE_CONFIG_NAME),
|
||||
extract_name(1, RE_CAPTURE_NAME),
|
||||
extract_name(2, RE_RENDER_NAME),
|
||||
extract_name(3, RE_ECHO_SIM_NAME),
|
||||
extract_name(4, RE_TEST_DATA_GEN_NAME),
|
||||
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
|
||||
extract_name(6, RE_SCORE_NAME),
|
||||
)
|
||||
|
||||
|
||||
def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, score_name, args):
|
||||
"""Decides whether excluding a score.
|
||||
|
||||
A set of optional regular expressions in args is used to determine if the
|
||||
score should be excluded (depending on its |*_name| descriptors).
|
||||
|
||||
Args:
|
||||
config_name: APM configuration name.
|
||||
capture_name: capture audio track name.
|
||||
render_name: render audio track name.
|
||||
echo_simulator_name: echo simulator name.
|
||||
test_data_gen_name: test data generator name.
|
||||
score_name: evaluation score name.
|
||||
args: parsed arguments.
|
||||
|
||||
Returns:
|
||||
A boolean.
|
||||
"""
|
||||
value_regexpr_pairs = [
|
||||
(config_name, args.config_names),
|
||||
(capture_name, args.capture_names),
|
||||
(render_name, args.render_names),
|
||||
(echo_simulator_name, args.echo_simulator_names),
|
||||
(test_data_gen_name, args.test_data_generators),
|
||||
(score_name, args.eval_scores),
|
||||
]
|
||||
|
||||
# Score accepted if each value matches the corresponding regular expression.
|
||||
for value, regexpr in value_regexpr_pairs:
|
||||
if regexpr is None:
|
||||
continue
|
||||
if not regexpr.match(value):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def FindScores(src_path, args):
|
||||
"""Given a search path, find scores and return a DataFrame object.
|
||||
|
||||
Args:
|
||||
src_path: Search path pattern.
|
||||
args: parsed arguments.
|
||||
|
||||
Returns:
|
||||
A DataFrame object.
|
||||
"""
|
||||
# Get scores.
|
||||
scores = []
|
||||
for score_filepath in glob.iglob(src_path):
|
||||
# Extract score descriptor fields from the path.
|
||||
(config_name, capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, test_data_gen_params,
|
||||
score_name) = _GetScoreDescriptors(score_filepath)
|
||||
|
||||
# Ignore the score if required.
|
||||
if _ExcludeScore(config_name, capture_name, render_name,
|
||||
echo_simulator_name, test_data_gen_name, score_name,
|
||||
args):
|
||||
logging.info('ignored score: %s %s %s %s %s %s', config_name,
|
||||
capture_name, render_name, echo_simulator_name,
|
||||
test_data_gen_name, score_name)
|
||||
continue
|
||||
|
||||
# Read metadata and score.
|
||||
metadata = data_access.Metadata.LoadAudioTestDataPaths(
|
||||
os.path.split(score_filepath)[0])
|
||||
score = data_access.ScoreFile.Load(score_filepath)
|
||||
|
||||
# Add a score with its descriptor fields.
|
||||
scores.append((
|
||||
metadata['clean_capture_input_filepath'],
|
||||
metadata['echo_free_capture_filepath'],
|
||||
metadata['echo_filepath'],
|
||||
metadata['render_filepath'],
|
||||
metadata['capture_filepath'],
|
||||
metadata['apm_output_filepath'],
|
||||
metadata['apm_reference_filepath'],
|
||||
config_name,
|
||||
capture_name,
|
||||
render_name,
|
||||
echo_simulator_name,
|
||||
test_data_gen_name,
|
||||
test_data_gen_params,
|
||||
score_name,
|
||||
score,
|
||||
))
|
||||
|
||||
return pd.DataFrame(data=scores,
|
||||
columns=(
|
||||
'clean_capture_input_filepath',
|
||||
'echo_free_capture_filepath',
|
||||
'echo_filepath',
|
||||
'render_filepath',
|
||||
'capture_filepath',
|
||||
'apm_output_filepath',
|
||||
'apm_reference_filepath',
|
||||
'apm_config',
|
||||
'capture',
|
||||
'render',
|
||||
'echo_simulator',
|
||||
'test_data_gen',
|
||||
'test_data_gen_params',
|
||||
'eval_score_name',
|
||||
'score',
|
||||
))
|
||||
|
||||
|
||||
def ConstructSrcPath(args):
|
||||
return os.path.join(
|
||||
args.output_dir,
|
||||
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixRender() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
|
||||
sim.ApmModuleSimulator.GetPrefixScore() + '*')
|
||||
@ -1,154 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Data access utility functions and classes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
|
||||
def MakeDirectory(path):
|
||||
"""Makes a directory recursively without rising exceptions if existing.
|
||||
|
||||
Args:
|
||||
path: path to the directory to be created.
|
||||
"""
|
||||
if os.path.exists(path):
|
||||
return
|
||||
os.makedirs(path)
|
||||
|
||||
|
||||
class Metadata(object):
|
||||
"""Data access class to save and load metadata.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
_GENERIC_METADATA_SUFFIX = '.mdata'
|
||||
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
|
||||
|
||||
@classmethod
|
||||
def LoadFileMetadata(cls, filepath):
|
||||
"""Loads generic metadata linked to a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the metadata file to read.
|
||||
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def SaveFileMetadata(cls, filepath, metadata):
|
||||
"""Saves generic metadata linked to a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the metadata file to write.
|
||||
metadata: a dict.
|
||||
"""
|
||||
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
|
||||
json.dump(metadata, f)
|
||||
|
||||
@classmethod
|
||||
def LoadAudioTestDataPaths(cls, metadata_path):
|
||||
"""Loads the input and the reference audio track paths.
|
||||
|
||||
Args:
|
||||
metadata_path: path to the directory containing the metadata file.
|
||||
|
||||
Returns:
|
||||
Tuple with the paths to the input and output audio tracks.
|
||||
"""
|
||||
metadata_filepath = os.path.join(metadata_path,
|
||||
cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(metadata_filepath) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
|
||||
"""Saves the input and the reference audio track paths.
|
||||
|
||||
Args:
|
||||
output_path: path to the directory containing the metadata file.
|
||||
|
||||
Keyword Args:
|
||||
filepaths: collection of audio track file paths to save.
|
||||
"""
|
||||
output_filepath = os.path.join(output_path,
|
||||
cls._AUDIO_TEST_DATA_FILENAME)
|
||||
with open(output_filepath, 'w') as f:
|
||||
json.dump(filepaths, f)
|
||||
|
||||
|
||||
class AudioProcConfigFile(object):
|
||||
"""Data access to load/save APM simulator argument lists.
|
||||
|
||||
The arguments stored in the config files are used to control the APM flags.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a configuration file for an APM simulator.
|
||||
|
||||
Args:
|
||||
filepath: path to the configuration file.
|
||||
|
||||
Returns:
|
||||
A dict containing the configuration.
|
||||
"""
|
||||
with open(filepath) as f:
|
||||
return json.load(f)
|
||||
|
||||
@classmethod
|
||||
def Save(cls, filepath, config):
|
||||
"""Saves a configuration file for an APM simulator.
|
||||
|
||||
Args:
|
||||
filepath: path to the configuration file.
|
||||
config: a dict containing the configuration.
|
||||
"""
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(config, f)
|
||||
|
||||
|
||||
class ScoreFile(object):
|
||||
"""Data access class to save and load float scalar scores.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Load(cls, filepath):
|
||||
"""Loads a score from file.
|
||||
|
||||
Args:
|
||||
filepath: path to the score file.
|
||||
|
||||
Returns:
|
||||
A float encoding the score.
|
||||
"""
|
||||
with open(filepath) as f:
|
||||
return float(f.readline().strip())
|
||||
|
||||
@classmethod
|
||||
def Save(cls, filepath, score):
|
||||
"""Saves a score into a file.
|
||||
|
||||
Args:
|
||||
filepath: path to the score file.
|
||||
score: float encoding the score.
|
||||
"""
|
||||
with open(filepath, 'w') as f:
|
||||
f.write('{0:f}\n'.format(score))
|
||||
@ -1,136 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Echo path simulation module.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class EchoPathSimulator(object):
|
||||
"""Abstract class for the echo path simulators.
|
||||
|
||||
In general, an echo path simulator is a function of the render signal and
|
||||
simulates the propagation of the latter into the microphone (e.g., due to
|
||||
mechanical or electrical paths).
|
||||
"""
|
||||
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Creates the echo signal and stores it in an audio file (abstract method).
|
||||
|
||||
Args:
|
||||
output_path: Path in which any output can be saved.
|
||||
|
||||
Returns:
|
||||
Path to the generated audio track file or None if no echo is present.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers an EchoPathSimulator implementation.
|
||||
|
||||
Decorator to automatically register the classes that extend
|
||||
EchoPathSimulator.
|
||||
Example usage:
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class NoEchoPathSimulator(EchoPathSimulator):
|
||||
pass
|
||||
"""
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class NoEchoPathSimulator(EchoPathSimulator):
|
||||
"""Simulates absence of echo."""
|
||||
|
||||
NAME = 'noecho'
|
||||
|
||||
def __init__(self):
|
||||
EchoPathSimulator.__init__(self)
|
||||
|
||||
def Simulate(self, output_path):
|
||||
return None
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class LinearEchoPathSimulator(EchoPathSimulator):
|
||||
"""Simulates linear echo path.
|
||||
|
||||
This class applies a given impulse response to the render input and then it
|
||||
sums the signal to the capture input signal.
|
||||
"""
|
||||
|
||||
NAME = 'linear'
|
||||
|
||||
def __init__(self, render_input_filepath, impulse_response):
|
||||
"""
|
||||
Args:
|
||||
render_input_filepath: Render audio track file.
|
||||
impulse_response: list or numpy vector of float values.
|
||||
"""
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
self._impulse_response = impulse_response
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Simulates linear echo path."""
|
||||
# Form the file name with a hash of the impulse response.
|
||||
impulse_response_hash = hashlib.sha256(
|
||||
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
|
||||
echo_filepath = os.path.join(
|
||||
output_path, 'linear_echo_{}.wav'.format(impulse_response_hash))
|
||||
|
||||
# If the simulated echo audio track file does not exists, create it.
|
||||
if not os.path.exists(echo_filepath):
|
||||
render = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._render_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
render, self._impulse_response)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
echo_filepath, echo)
|
||||
|
||||
return echo_filepath
|
||||
|
||||
|
||||
@EchoPathSimulator.RegisterClass
|
||||
class RecordedEchoPathSimulator(EchoPathSimulator):
|
||||
"""Uses recorded echo.
|
||||
|
||||
This class uses the clean capture input file name to build the file name of
|
||||
the corresponding recording containing echo (a predefined suffix is used).
|
||||
Such a file is expected to be already existing.
|
||||
"""
|
||||
|
||||
NAME = 'recorded'
|
||||
|
||||
_FILE_NAME_SUFFIX = '_echo'
|
||||
|
||||
def __init__(self, render_input_filepath):
|
||||
EchoPathSimulator.__init__(self)
|
||||
self._render_input_filepath = render_input_filepath
|
||||
|
||||
def Simulate(self, output_path):
|
||||
"""Uses recorded echo path."""
|
||||
path, file_name_ext = os.path.split(self._render_input_filepath)
|
||||
file_name, file_ext = os.path.splitext(file_name_ext)
|
||||
echo_filepath = os.path.join(
|
||||
path, '{}{}{}'.format(file_name, self._FILE_NAME_SUFFIX, file_ext))
|
||||
assert os.path.exists(echo_filepath), (
|
||||
'cannot find the echo audio track file {}'.format(echo_filepath))
|
||||
return echo_filepath
|
||||
@ -1,48 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Echo path simulation factory module.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
from . import echo_path_simulation
|
||||
|
||||
|
||||
class EchoPathSimulatorFactory(object):
|
||||
|
||||
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
|
||||
# realistic impulse response.
|
||||
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0] * (20 * 48) + [0.15])
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
|
||||
"""Creates an EchoPathSimulator instance given a class object.
|
||||
|
||||
Args:
|
||||
echo_path_simulator_class: EchoPathSimulator class object (not an
|
||||
instance).
|
||||
render_input_filepath: Path to the render audio track file.
|
||||
|
||||
Returns:
|
||||
An EchoPathSimulator instance.
|
||||
"""
|
||||
assert render_input_filepath is not None or (
|
||||
echo_path_simulator_class ==
|
||||
echo_path_simulation.NoEchoPathSimulator)
|
||||
|
||||
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
|
||||
return echo_path_simulation.NoEchoPathSimulator()
|
||||
elif echo_path_simulator_class == (
|
||||
echo_path_simulation.LinearEchoPathSimulator):
|
||||
return echo_path_simulation.LinearEchoPathSimulator(
|
||||
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
|
||||
else:
|
||||
return echo_path_simulator_class(render_input_filepath)
|
||||
@ -1,82 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the echo path simulation module.
|
||||
"""
|
||||
|
||||
import shutil
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pydub
|
||||
|
||||
from . import echo_path_simulation
|
||||
from . import echo_path_simulation_factory
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class TestEchoPathSimulators(unittest.TestCase):
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Create and save white noise.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._audio_track_num_samples = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
|
||||
self._audio_track_filepath = os.path.join(self._tmp_path,
|
||||
'white_noise.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._audio_track_filepath, white_noise)
|
||||
|
||||
# Make a copy the white noise audio track file; it will be used by
|
||||
# echo_path_simulation.RecordedEchoPathSimulator.
|
||||
shutil.copy(self._audio_track_filepath,
|
||||
os.path.join(self._tmp_path, 'white_noise_echo.wav'))
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
|
||||
def testRegisteredClasses(self):
|
||||
# Check that there is at least one registered echo path simulator.
|
||||
registered_classes = (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance factory.
|
||||
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
|
||||
|
||||
# Try each registered echo path simulator.
|
||||
for echo_path_simulator_name in registered_classes:
|
||||
simulator = factory.GetInstance(
|
||||
echo_path_simulator_class=registered_classes[
|
||||
echo_path_simulator_name],
|
||||
render_input_filepath=self._audio_track_filepath)
|
||||
|
||||
echo_filepath = simulator.Simulate(self._tmp_path)
|
||||
if echo_filepath is None:
|
||||
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
|
||||
echo_path_simulator_name)
|
||||
# No other tests in this case.
|
||||
continue
|
||||
|
||||
# Check that the echo audio track file exists and its length is greater or
|
||||
# equal to that of the render audio track.
|
||||
self.assertTrue(os.path.exists(echo_filepath))
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
echo_filepath)
|
||||
self.assertGreaterEqual(
|
||||
signal_processing.SignalProcessingUtils.CountSamples(echo),
|
||||
self._audio_track_num_samples)
|
||||
@ -1,427 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Evaluation score abstract class and implementations.
|
||||
"""
|
||||
|
||||
from __future__ import division
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import data_access
|
||||
from . import exceptions
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class EvaluationScore(object):
|
||||
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
self._score_filename_prefix = score_filename_prefix
|
||||
self._input_signal_metadata = None
|
||||
self._reference_signal = None
|
||||
self._reference_signal_filepath = None
|
||||
self._tested_signal = None
|
||||
self._tested_signal_filepath = None
|
||||
self._output_filepath = None
|
||||
self._score = None
|
||||
self._render_signal_filepath = None
|
||||
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers an EvaluationScore implementation.
|
||||
|
||||
Decorator to automatically register the classes that extend EvaluationScore.
|
||||
Example usage:
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class AudioLevelScore(EvaluationScore):
|
||||
pass
|
||||
"""
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
|
||||
@property
|
||||
def output_filepath(self):
|
||||
return self._output_filepath
|
||||
|
||||
@property
|
||||
def score(self):
|
||||
return self._score
|
||||
|
||||
def SetInputSignalMetadata(self, metadata):
|
||||
"""Sets input signal metadata.
|
||||
|
||||
Args:
|
||||
metadata: dict instance.
|
||||
"""
|
||||
self._input_signal_metadata = metadata
|
||||
|
||||
def SetReferenceSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as reference signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the reference audio track.
|
||||
"""
|
||||
self._reference_signal_filepath = filepath
|
||||
|
||||
def SetTestedSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as test signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the test audio track.
|
||||
"""
|
||||
self._tested_signal_filepath = filepath
|
||||
|
||||
def SetRenderSignalFilepath(self, filepath):
|
||||
"""Sets the path to the audio track used as render signal.
|
||||
|
||||
Args:
|
||||
filepath: path to the test audio track.
|
||||
"""
|
||||
self._render_signal_filepath = filepath
|
||||
|
||||
def Run(self, output_path):
|
||||
"""Extracts the score for the set test data pair.
|
||||
|
||||
Args:
|
||||
output_path: path to the directory where the output is written.
|
||||
"""
|
||||
self._output_filepath = os.path.join(
|
||||
output_path, self._score_filename_prefix + self.NAME + '.txt')
|
||||
try:
|
||||
# If the score has already been computed, load.
|
||||
self._LoadScore()
|
||||
logging.debug('score found and loaded')
|
||||
except IOError:
|
||||
# Compute the score.
|
||||
logging.debug('score not found, compute')
|
||||
self._Run(output_path)
|
||||
|
||||
def _Run(self, output_path):
|
||||
# Abstract method.
|
||||
raise NotImplementedError()
|
||||
|
||||
def _LoadReferenceSignal(self):
|
||||
assert self._reference_signal_filepath is not None
|
||||
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._reference_signal_filepath)
|
||||
|
||||
def _LoadTestedSignal(self):
|
||||
assert self._tested_signal_filepath is not None
|
||||
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
self._tested_signal_filepath)
|
||||
|
||||
def _LoadScore(self):
|
||||
return data_access.ScoreFile.Load(self._output_filepath)
|
||||
|
||||
def _SaveScore(self):
|
||||
return data_access.ScoreFile.Save(self._output_filepath, self._score)
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class AudioLevelPeakScore(EvaluationScore):
|
||||
"""Peak audio level score.
|
||||
|
||||
Defined as the difference between the peak audio level of the tested and
|
||||
the reference signals.
|
||||
|
||||
Unit: dB
|
||||
Ideal: 0 dB
|
||||
Worst case: +/-inf dB
|
||||
"""
|
||||
|
||||
NAME = 'audio_level_peak'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
|
||||
self._SaveScore()
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class MeanAudioLevelScore(EvaluationScore):
|
||||
"""Mean audio level score.
|
||||
|
||||
Defined as the difference between the mean audio level of the tested and
|
||||
the reference signals.
|
||||
|
||||
Unit: dB
|
||||
Ideal: 0 dB
|
||||
Worst case: +/-inf dB
|
||||
"""
|
||||
|
||||
NAME = 'audio_level_mean'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._LoadReferenceSignal()
|
||||
self._LoadTestedSignal()
|
||||
|
||||
dbfs_diffs_sum = 0.0
|
||||
seconds = min(len(self._tested_signal), len(
|
||||
self._reference_signal)) // 1000
|
||||
for t in range(seconds):
|
||||
t0 = t * seconds
|
||||
t1 = t0 + seconds
|
||||
dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS -
|
||||
self._reference_signal[t0:t1].dBFS)
|
||||
self._score = dbfs_diffs_sum / float(seconds)
|
||||
self._SaveScore()
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class EchoMetric(EvaluationScore):
|
||||
"""Echo score.
|
||||
|
||||
Proportion of detected echo.
|
||||
|
||||
Unit: ratio
|
||||
Ideal: 0
|
||||
Worst case: 1
|
||||
"""
|
||||
|
||||
NAME = 'echo_metric'
|
||||
|
||||
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
# POLQA binary file path.
|
||||
self._echo_detector_bin_filepath = echo_detector_bin_filepath
|
||||
if not os.path.exists(self._echo_detector_bin_filepath):
|
||||
logging.error('cannot find EchoMetric tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
self._echo_detector_bin_path, _ = os.path.split(
|
||||
self._echo_detector_bin_filepath)
|
||||
|
||||
def _Run(self, output_path):
|
||||
echo_detector_out_filepath = os.path.join(output_path,
|
||||
'echo_detector.out')
|
||||
if os.path.exists(echo_detector_out_filepath):
|
||||
os.unlink(echo_detector_out_filepath)
|
||||
|
||||
logging.debug("Render signal filepath: %s",
|
||||
self._render_signal_filepath)
|
||||
if not os.path.exists(self._render_signal_filepath):
|
||||
logging.error(
|
||||
"Render input required for evaluating the echo metric.")
|
||||
|
||||
args = [
|
||||
self._echo_detector_bin_filepath, '--output_file',
|
||||
echo_detector_out_filepath, '--', '-i',
|
||||
self._tested_signal_filepath, '-ri', self._render_signal_filepath
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._echo_detector_bin_path)
|
||||
|
||||
# Parse Echo detector tool output and extract the score.
|
||||
self._score = self._ParseOutputFile(echo_detector_out_filepath)
|
||||
self._SaveScore()
|
||||
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, echo_metric_file_path):
|
||||
"""
|
||||
Parses the POLQA tool output formatted as a table ('-t' option).
|
||||
|
||||
Args:
|
||||
polqa_out_filepath: path to the POLQA tool output file.
|
||||
|
||||
Returns:
|
||||
The score as a number in [0, 1].
|
||||
"""
|
||||
with open(echo_metric_file_path) as f:
|
||||
return float(f.read())
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class PolqaScore(EvaluationScore):
|
||||
"""POLQA score.
|
||||
|
||||
See http://www.polqa.info/.
|
||||
|
||||
Unit: MOS
|
||||
Ideal: 4.5
|
||||
Worst case: 1.0
|
||||
"""
|
||||
|
||||
NAME = 'polqa'
|
||||
|
||||
def __init__(self, score_filename_prefix, polqa_bin_filepath):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
|
||||
# POLQA binary file path.
|
||||
self._polqa_bin_filepath = polqa_bin_filepath
|
||||
if not os.path.exists(self._polqa_bin_filepath):
|
||||
logging.error('cannot find POLQA tool binary file')
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
# Path to the POLQA directory with binary and license files.
|
||||
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
|
||||
|
||||
def _Run(self, output_path):
|
||||
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
|
||||
if os.path.exists(polqa_out_filepath):
|
||||
os.unlink(polqa_out_filepath)
|
||||
|
||||
args = [
|
||||
self._polqa_bin_filepath,
|
||||
'-t',
|
||||
'-q',
|
||||
'-Overwrite',
|
||||
'-Ref',
|
||||
self._reference_signal_filepath,
|
||||
'-Test',
|
||||
self._tested_signal_filepath,
|
||||
'-LC',
|
||||
'NB',
|
||||
'-Out',
|
||||
polqa_out_filepath,
|
||||
]
|
||||
logging.debug(' '.join(args))
|
||||
subprocess.call(args, cwd=self._polqa_tool_path)
|
||||
|
||||
# Parse POLQA tool output and extract the score.
|
||||
polqa_output = self._ParseOutputFile(polqa_out_filepath)
|
||||
self._score = float(polqa_output['PolqaScore'])
|
||||
|
||||
self._SaveScore()
|
||||
|
||||
@classmethod
|
||||
def _ParseOutputFile(cls, polqa_out_filepath):
|
||||
"""
|
||||
Parses the POLQA tool output formatted as a table ('-t' option).
|
||||
|
||||
Args:
|
||||
polqa_out_filepath: path to the POLQA tool output file.
|
||||
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
data = []
|
||||
with open(polqa_out_filepath) as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if len(line) == 0 or line.startswith('*'):
|
||||
# Ignore comments.
|
||||
continue
|
||||
# Read fields.
|
||||
data.append(re.split(r'\t+', line))
|
||||
|
||||
# Two rows expected (header and values).
|
||||
assert len(data) == 2, 'Cannot parse POLQA output'
|
||||
number_of_fields = len(data[0])
|
||||
assert number_of_fields == len(data[1])
|
||||
|
||||
# Build and return a dictionary with field names (header) as keys and the
|
||||
# corresponding field values as values.
|
||||
return {
|
||||
data[0][index]: data[1][index]
|
||||
for index in range(number_of_fields)
|
||||
}
|
||||
|
||||
|
||||
@EvaluationScore.RegisterClass
|
||||
class TotalHarmonicDistorsionScore(EvaluationScore):
|
||||
"""Total harmonic distorsion plus noise score.
|
||||
|
||||
Total harmonic distorsion plus noise score.
|
||||
See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".
|
||||
|
||||
Unit: -.
|
||||
Ideal: 0.
|
||||
Worst case: +inf
|
||||
"""
|
||||
|
||||
NAME = 'thd'
|
||||
|
||||
def __init__(self, score_filename_prefix):
|
||||
EvaluationScore.__init__(self, score_filename_prefix)
|
||||
self._input_frequency = None
|
||||
|
||||
def _Run(self, output_path):
|
||||
self._CheckInputSignal()
|
||||
|
||||
self._LoadTestedSignal()
|
||||
if self._tested_signal.channels != 1:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'unsupported number of channels')
|
||||
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
|
||||
self._tested_signal)
|
||||
|
||||
# Init.
|
||||
num_samples = len(samples)
|
||||
duration = len(self._tested_signal) / 1000.0
|
||||
scaling = 2.0 / num_samples
|
||||
max_freq = self._tested_signal.frame_rate / 2
|
||||
f0_freq = float(self._input_frequency)
|
||||
t = np.linspace(0, duration, num_samples)
|
||||
|
||||
# Analyze harmonics.
|
||||
b_terms = []
|
||||
n = 1
|
||||
while f0_freq * n < max_freq:
|
||||
x_n = np.sum(
|
||||
samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
y_n = np.sum(
|
||||
samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
|
||||
b_terms.append(np.sqrt(x_n**2 + y_n**2))
|
||||
n += 1
|
||||
|
||||
output_without_fundamental = samples - b_terms[0] * np.sin(
|
||||
2.0 * np.pi * f0_freq * t)
|
||||
distortion_and_noise = np.sqrt(
|
||||
np.sum(output_without_fundamental**2) * np.pi * scaling)
|
||||
|
||||
# TODO(alessiob): Fix or remove if not needed.
|
||||
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
|
||||
|
||||
# TODO(alessiob): Check the range of `thd_plus_noise` and update the class
|
||||
# docstring above if accordingly.
|
||||
thd_plus_noise = distortion_and_noise / b_terms[0]
|
||||
|
||||
self._score = thd_plus_noise
|
||||
self._SaveScore()
|
||||
|
||||
def _CheckInputSignal(self):
|
||||
# Check input signal and get properties.
|
||||
try:
|
||||
if self._input_signal_metadata['signal'] != 'pure_tone':
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires a pure tone as input signal')
|
||||
self._input_frequency = self._input_signal_metadata['frequency']
|
||||
if self._input_signal_metadata[
|
||||
'test_data_gen_name'] != 'identity' or (
|
||||
self._input_signal_metadata['test_data_gen_config'] !=
|
||||
'default'):
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score cannot be used with any test data generator other '
|
||||
'than "identity"')
|
||||
except TypeError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'The THD score requires an input signal with associated metadata'
|
||||
)
|
||||
except KeyError:
|
||||
raise exceptions.EvaluationScoreException(
|
||||
'Invalid input signal metadata to compute the THD score')
|
||||
@ -1,55 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""EvaluationScore factory class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from . import exceptions
|
||||
from . import eval_scores
|
||||
|
||||
|
||||
class EvaluationScoreWorkerFactory(object):
|
||||
"""Factory class used to instantiate evaluation score workers.
|
||||
|
||||
The ctor gets the parametrs that are used to instatiate the evaluation score
|
||||
workers.
|
||||
"""
|
||||
|
||||
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
|
||||
self._score_filename_prefix = None
|
||||
self._polqa_tool_bin_path = polqa_tool_bin_path
|
||||
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
|
||||
|
||||
def SetScoreFilenamePrefix(self, prefix):
|
||||
self._score_filename_prefix = prefix
|
||||
|
||||
def GetInstance(self, evaluation_score_class):
|
||||
"""Creates an EvaluationScore instance given a class object.
|
||||
|
||||
Args:
|
||||
evaluation_score_class: EvaluationScore class object (not an instance).
|
||||
|
||||
Returns:
|
||||
An EvaluationScore instance.
|
||||
"""
|
||||
if self._score_filename_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The score file name prefix for evaluation score workers is not set'
|
||||
)
|
||||
logging.debug('factory producing a %s evaluation score',
|
||||
evaluation_score_class)
|
||||
|
||||
if evaluation_score_class == eval_scores.PolqaScore:
|
||||
return eval_scores.PolqaScore(self._score_filename_prefix,
|
||||
self._polqa_tool_bin_path)
|
||||
elif evaluation_score_class == eval_scores.EchoMetric:
|
||||
return eval_scores.EchoMetric(self._score_filename_prefix,
|
||||
self._echo_metric_tool_bin_path)
|
||||
else:
|
||||
return evaluation_score_class(self._score_filename_prefix)
|
||||
@ -1,137 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pydub
|
||||
|
||||
from . import data_access
|
||||
from . import eval_scores
|
||||
from . import eval_scores_factory
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class TestEvalScores(unittest.TestCase):
|
||||
"""Unit tests for the eval_scores module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary output folder and two audio track files."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
|
||||
# Create fake reference and tested (i.e., APM output) audio track files.
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_reference_signal = (signal_processing.SignalProcessingUtils.
|
||||
GenerateWhiteNoise(silence))
|
||||
fake_tested_signal = (signal_processing.SignalProcessingUtils.
|
||||
GenerateWhiteNoise(silence))
|
||||
|
||||
# Save fake audio tracks.
|
||||
self._fake_reference_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_ref.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_reference_signal_filepath, fake_reference_signal)
|
||||
self._fake_tested_signal_filepath = os.path.join(
|
||||
self._output_path, 'fake_test.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_tested_signal_filepath, fake_tested_signal)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folder."""
|
||||
shutil.rmtree(self._output_path)
|
||||
|
||||
def testRegisteredClasses(self):
|
||||
# Evaluation score names to exclude (tested separately).
|
||||
exceptions = ['thd', 'echo_metric']
|
||||
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._output_path))
|
||||
|
||||
# Check that there is at least one registered evaluation score worker.
|
||||
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance evaluation score workers factory with fake dependencies.
|
||||
eval_score_workers_factory = (
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None))
|
||||
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
|
||||
|
||||
# Try each registered evaluation score worker.
|
||||
for eval_score_name in registered_classes:
|
||||
if eval_score_name in exceptions:
|
||||
continue
|
||||
|
||||
# Instance evaluation score worker.
|
||||
eval_score_worker = eval_score_workers_factory.GetInstance(
|
||||
registered_classes[eval_score_name])
|
||||
|
||||
# Set fake input metadata and reference and test file paths, then run.
|
||||
eval_score_worker.SetReferenceSignalFilepath(
|
||||
self._fake_reference_signal_filepath)
|
||||
eval_score_worker.SetTestedSignalFilepath(
|
||||
self._fake_tested_signal_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
|
||||
# Check output.
|
||||
score = data_access.ScoreFile.Load(
|
||||
eval_score_worker.output_filepath)
|
||||
self.assertTrue(isinstance(score, float))
|
||||
|
||||
def testTotalHarmonicDistorsionScore(self):
|
||||
# Init.
|
||||
pure_tone_freq = 5000.0
|
||||
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
|
||||
eval_score_worker.SetInputSignalMetadata({
|
||||
'signal':
|
||||
'pure_tone',
|
||||
'frequency':
|
||||
pure_tone_freq,
|
||||
'test_data_gen_name':
|
||||
'identity',
|
||||
'test_data_gen_config':
|
||||
'default',
|
||||
})
|
||||
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
|
||||
# Create 3 test signals: pure tone, pure tone + white noise, white noise
|
||||
# only.
|
||||
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, pure_tone_freq)
|
||||
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
template)
|
||||
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
pure_tone, white_noise)
|
||||
|
||||
# Compute scores for increasingly distorted pure tone signals.
|
||||
scores = [None, None, None]
|
||||
for index, tested_signal in enumerate(
|
||||
[pure_tone, noisy_tone, white_noise]):
|
||||
# Save signal.
|
||||
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
tmp_filepath, tested_signal)
|
||||
|
||||
# Compute score.
|
||||
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
|
||||
eval_score_worker.Run(self._output_path)
|
||||
scores[index] = eval_score_worker.score
|
||||
|
||||
# Remove output file to avoid caching.
|
||||
os.remove(eval_score_worker.output_filepath)
|
||||
|
||||
# Validate scores (lowest score with a pure tone).
|
||||
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
|
||||
@ -1,57 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Evaluator of the APM module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class ApmModuleEvaluator(object):
|
||||
"""APM evaluator class.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def Run(cls, evaluation_score_workers, apm_input_metadata,
|
||||
apm_output_filepath, reference_input_filepath,
|
||||
render_input_filepath, output_path):
|
||||
"""Runs the evaluation.
|
||||
|
||||
Iterates over the given evaluation score workers.
|
||||
|
||||
Args:
|
||||
evaluation_score_workers: list of EvaluationScore instances.
|
||||
apm_input_metadata: dictionary with metadata of the APM input.
|
||||
apm_output_filepath: path to the audio track file with the APM output.
|
||||
reference_input_filepath: path to the reference audio track file.
|
||||
output_path: output path.
|
||||
|
||||
Returns:
|
||||
A dict of evaluation score name and score pairs.
|
||||
"""
|
||||
# Init.
|
||||
scores = {}
|
||||
|
||||
for evaluation_score_worker in evaluation_score_workers:
|
||||
logging.info(' computing <%s> score',
|
||||
evaluation_score_worker.NAME)
|
||||
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
|
||||
evaluation_score_worker.SetReferenceSignalFilepath(
|
||||
reference_input_filepath)
|
||||
evaluation_score_worker.SetTestedSignalFilepath(
|
||||
apm_output_filepath)
|
||||
evaluation_score_worker.SetRenderSignalFilepath(
|
||||
render_input_filepath)
|
||||
|
||||
evaluation_score_worker.Run(output_path)
|
||||
scores[
|
||||
evaluation_score_worker.NAME] = evaluation_score_worker.score
|
||||
|
||||
return scores
|
||||
@ -1,45 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Exception classes.
|
||||
"""
|
||||
|
||||
|
||||
class FileNotFoundError(Exception):
|
||||
"""File not found exception.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class SignalProcessingException(Exception):
|
||||
"""Signal processing exception.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InputMixerException(Exception):
|
||||
"""Input mixer exception.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InputSignalCreatorException(Exception):
|
||||
"""Input signal creator exception.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class EvaluationScoreException(Exception):
|
||||
"""Evaluation score exception.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class InitializationException(Exception):
|
||||
"""Initialization exception.
|
||||
"""
|
||||
pass
|
||||
@ -1,426 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
import functools
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
try:
|
||||
import csscompressor
|
||||
except ImportError:
|
||||
logging.critical(
|
||||
'Cannot import the third-party Python package csscompressor')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import jsmin
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package jsmin')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
class HtmlExport(object):
|
||||
"""HTML exporter class for APM quality scores."""
|
||||
|
||||
_NEW_LINE = '\n'
|
||||
|
||||
# CSS and JS file paths.
|
||||
_PATH = os.path.dirname(os.path.realpath(__file__))
|
||||
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
|
||||
_CSS_MINIFIED = True
|
||||
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
|
||||
_JS_MINIFIED = True
|
||||
|
||||
def __init__(self, output_filepath):
|
||||
self._scores_data_frame = None
|
||||
self._output_filepath = output_filepath
|
||||
|
||||
def Export(self, scores_data_frame):
|
||||
"""Exports scores into an HTML file.
|
||||
|
||||
Args:
|
||||
scores_data_frame: DataFrame instance.
|
||||
"""
|
||||
self._scores_data_frame = scores_data_frame
|
||||
html = [
|
||||
'<html>',
|
||||
self._BuildHeader(),
|
||||
('<script type="text/javascript">'
|
||||
'(function () {'
|
||||
'window.addEventListener(\'load\', function () {'
|
||||
'var inspector = new AudioInspector();'
|
||||
'});'
|
||||
'})();'
|
||||
'</script>'), '<body>',
|
||||
self._BuildBody(), '</body>', '</html>'
|
||||
]
|
||||
self._Save(self._output_filepath, self._NEW_LINE.join(html))
|
||||
|
||||
def _BuildHeader(self):
|
||||
"""Builds the <head> section of the HTML file.
|
||||
|
||||
The header contains the page title and either embedded or linked CSS and JS
|
||||
files.
|
||||
|
||||
Returns:
|
||||
A string with <head>...</head> HTML.
|
||||
"""
|
||||
html = ['<head>', '<title>Results</title>']
|
||||
|
||||
# Add Material Design hosted libs.
|
||||
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
|
||||
'css?family=Roboto:300,400,500,700" type="text/css">')
|
||||
html.append(
|
||||
'<link rel="stylesheet" href="https://fonts.googleapis.com/'
|
||||
'icon?family=Material+Icons">')
|
||||
html.append(
|
||||
'<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
|
||||
'material.indigo-pink.min.css">')
|
||||
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
|
||||
'material.min.js"></script>')
|
||||
|
||||
# Embed custom JavaScript and CSS files.
|
||||
html.append('<script>')
|
||||
with open(self._JS_FILEPATH) as f:
|
||||
html.append(
|
||||
jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</script>')
|
||||
html.append('<style>')
|
||||
with open(self._CSS_FILEPATH) as f:
|
||||
html.append(
|
||||
csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
|
||||
f.read().rstrip()))
|
||||
html.append('</style>')
|
||||
|
||||
html.append('</head>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildBody(self):
|
||||
"""Builds the content of the <body> section."""
|
||||
score_names = self._scores_data_frame[
|
||||
'eval_score_name'].drop_duplicates().values.tolist()
|
||||
|
||||
html = [
|
||||
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
|
||||
'mdl-layout--fixed-tabs">'),
|
||||
'<header class="mdl-layout__header">',
|
||||
'<div class="mdl-layout__header-row">',
|
||||
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
|
||||
self._output_filepath),
|
||||
'</div>',
|
||||
]
|
||||
|
||||
# Tab selectors.
|
||||
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
is_active = tab_index == 0
|
||||
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
|
||||
'{}</a>'.format(tab_index,
|
||||
' is-active' if is_active else '',
|
||||
self._FormatName(score_name)))
|
||||
html.append('</div>')
|
||||
|
||||
html.append('</header>')
|
||||
html.append(
|
||||
'<main class="mdl-layout__content" style="overflow-x: auto;">')
|
||||
|
||||
# Tabs content.
|
||||
for tab_index, score_name in enumerate(score_names):
|
||||
html.append('<section class="mdl-layout__tab-panel{}" '
|
||||
'id="score-tab-{}">'.format(
|
||||
' is-active' if is_active else '', tab_index))
|
||||
html.append('<div class="page-content">')
|
||||
html.append(
|
||||
self._BuildScoreTab(score_name, ('s{}'.format(tab_index), )))
|
||||
html.append('</div>')
|
||||
html.append('</section>')
|
||||
|
||||
html.append('</main>')
|
||||
html.append('</div>')
|
||||
|
||||
# Add snackbar for notifications.
|
||||
html.append(
|
||||
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
|
||||
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
|
||||
'<div class="mdl-snackbar__text"></div>'
|
||||
'<button type="button" class="mdl-snackbar__action"></button>'
|
||||
'</div>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreTab(self, score_name, anchor_data):
|
||||
"""Builds the content of a tab."""
|
||||
# Find unique values.
|
||||
scores = self._scores_data_frame[
|
||||
self._scores_data_frame.eval_score_name == score_name]
|
||||
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
|
||||
test_data_gen_configs = sorted(
|
||||
self._FindUniqueTuples(scores,
|
||||
['test_data_gen', 'test_data_gen_params']))
|
||||
|
||||
html = [
|
||||
'<div class="mdl-grid">',
|
||||
'<div class="mdl-layout-spacer"></div>',
|
||||
'<div class="mdl-cell mdl-cell--10-col">',
|
||||
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
|
||||
'style="width: 100%;">'),
|
||||
]
|
||||
|
||||
# Header.
|
||||
html.append('<thead><tr><th>APM config / Test data generator</th>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
html.append('<th>{} {}</th>'.format(
|
||||
self._FormatName(test_data_gen_info[0]),
|
||||
test_data_gen_info[1]))
|
||||
html.append('</tr></thead>')
|
||||
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for apm_config in apm_configs:
|
||||
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1])
|
||||
html.append(
|
||||
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.
|
||||
format(
|
||||
dialog_id,
|
||||
self._BuildScoreTableCell(score_name,
|
||||
test_data_gen_info[0],
|
||||
test_data_gen_info[1],
|
||||
apm_config[0])))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
|
||||
html.append(
|
||||
'</table></div><div class="mdl-layout-spacer"></div></div>')
|
||||
|
||||
html.append(
|
||||
self._BuildScoreStatsInspectorDialogs(score_name, apm_configs,
|
||||
test_data_gen_configs,
|
||||
anchor_data))
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreTableCell(self, score_name, test_data_gen,
|
||||
test_data_gen_params, apm_config):
|
||||
"""Builds the content of a table cell for a score table."""
|
||||
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
|
||||
test_data_gen,
|
||||
test_data_gen_params)
|
||||
stats = self._ComputeScoreStats(scores)
|
||||
|
||||
html = []
|
||||
items_id_prefix = (score_name + test_data_gen + test_data_gen_params +
|
||||
apm_config)
|
||||
if stats['count'] == 1:
|
||||
# Show the only available score.
|
||||
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
|
||||
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
|
||||
item_id, scores['score'].mean()))
|
||||
html.append(
|
||||
'<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
|
||||
'</div>'.format(item_id, 'single value'))
|
||||
else:
|
||||
# Show stats.
|
||||
for stat_name in ['min', 'max', 'mean', 'std dev']:
|
||||
item_id = hashlib.md5(
|
||||
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
|
||||
html.append('<div id="stats-{0}">{1:f}</div>'.format(
|
||||
item_id, stats[stat_name]))
|
||||
html.append(
|
||||
'<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
|
||||
'</div>'.format(item_id, stat_name))
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorDialogs(self, score_name, apm_configs,
|
||||
test_data_gen_configs, anchor_data):
|
||||
"""Builds a set of score stats inspector dialogs."""
|
||||
html = []
|
||||
for apm_config in apm_configs:
|
||||
for test_data_gen_info in test_data_gen_configs:
|
||||
dialog_id = self._ScoreStatsInspectorDialogId(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1])
|
||||
|
||||
html.append('<dialog class="mdl-dialog" id="{}" '
|
||||
'style="width: 40%;">'.format(dialog_id))
|
||||
|
||||
# Content.
|
||||
html.append('<div class="mdl-dialog__content">')
|
||||
html.append(
|
||||
'<h6><strong>APM config preset</strong>: {}<br/>'
|
||||
'<strong>Test data generator</strong>: {} ({})</h6>'.
|
||||
format(self._FormatName(apm_config[0]),
|
||||
self._FormatName(test_data_gen_info[0]),
|
||||
test_data_gen_info[1]))
|
||||
html.append(
|
||||
self._BuildScoreStatsInspectorDialog(
|
||||
score_name, apm_config[0], test_data_gen_info[0],
|
||||
test_data_gen_info[1], anchor_data + (dialog_id, )))
|
||||
html.append('</div>')
|
||||
|
||||
# Actions.
|
||||
html.append('<div class="mdl-dialog__actions">')
|
||||
html.append('<button type="button" class="mdl-button" '
|
||||
'onclick="closeScoreStatsInspector()">'
|
||||
'Close</button>')
|
||||
html.append('</div>')
|
||||
|
||||
html.append('</dialog>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorDialog(self, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params,
|
||||
anchor_data):
|
||||
"""Builds one score stats inspector dialog."""
|
||||
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
|
||||
test_data_gen,
|
||||
test_data_gen_params)
|
||||
|
||||
capture_render_pairs = sorted(
|
||||
self._FindUniqueTuples(scores, ['capture', 'render']))
|
||||
echo_simulators = sorted(
|
||||
self._FindUniqueTuples(scores, ['echo_simulator']))
|
||||
|
||||
html = [
|
||||
'<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">'
|
||||
]
|
||||
|
||||
# Header.
|
||||
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
|
||||
for echo_simulator in echo_simulators:
|
||||
html.append('<th>' + self._FormatName(echo_simulator[0]) + '</th>')
|
||||
html.append('</tr></thead>')
|
||||
|
||||
# Body.
|
||||
html.append('<tbody>')
|
||||
for row, (capture, render) in enumerate(capture_render_pairs):
|
||||
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
|
||||
capture, render))
|
||||
for col, echo_simulator in enumerate(echo_simulators):
|
||||
score_tuple = self._SliceDataForScoreStatsTableCell(
|
||||
scores, capture, render, echo_simulator[0])
|
||||
cell_class = 'r{}c{}'.format(row, col)
|
||||
html.append('<td class="single-score-cell {}">{}</td>'.format(
|
||||
cell_class,
|
||||
self._BuildScoreStatsInspectorTableCell(
|
||||
score_tuple, anchor_data + (cell_class, ))))
|
||||
html.append('</tr>')
|
||||
html.append('</tbody>')
|
||||
|
||||
html.append('</table>')
|
||||
|
||||
# Placeholder for the audio inspector.
|
||||
html.append('<div class="audio-inspector-placeholder"></div>')
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
|
||||
"""Builds the content of a cell of a score stats inspector."""
|
||||
anchor = '&'.join(anchor_data)
|
||||
html = [('<div class="v">{}</div>'
|
||||
'<button class="mdl-button mdl-js-button mdl-button--icon"'
|
||||
' data-anchor="{}">'
|
||||
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
|
||||
'</button>').format(score_tuple.score, anchor)]
|
||||
|
||||
# Add all the available file paths as hidden data.
|
||||
for field_name in score_tuple.keys():
|
||||
if field_name.endswith('_filepath'):
|
||||
html.append(
|
||||
'<input type="hidden" name="{}" value="{}">'.format(
|
||||
field_name, score_tuple[field_name]))
|
||||
|
||||
return self._NEW_LINE.join(html)
|
||||
|
||||
def _SliceDataForScoreTableCell(self, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params):
|
||||
"""Slices `self._scores_data_frame` to extract the data for a tab."""
|
||||
masks = []
|
||||
masks.append(self._scores_data_frame.eval_score_name == score_name)
|
||||
masks.append(self._scores_data_frame.apm_config == apm_config)
|
||||
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
|
||||
masks.append(self._scores_data_frame.test_data_gen_params ==
|
||||
test_data_gen_params)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
return self._scores_data_frame[mask]
|
||||
|
||||
@classmethod
|
||||
def _SliceDataForScoreStatsTableCell(cls, scores, capture, render,
|
||||
echo_simulator):
|
||||
"""Slices `scores` to extract the data for a tab."""
|
||||
masks = []
|
||||
|
||||
masks.append(scores.capture == capture)
|
||||
masks.append(scores.render == render)
|
||||
masks.append(scores.echo_simulator == echo_simulator)
|
||||
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
|
||||
del masks
|
||||
|
||||
sliced_data = scores[mask]
|
||||
assert len(sliced_data) == 1, 'single score is expected'
|
||||
return sliced_data.iloc[0]
|
||||
|
||||
@classmethod
|
||||
def _FindUniqueTuples(cls, data_frame, fields):
|
||||
"""Slices `data_frame` to a list of fields and finds unique tuples."""
|
||||
return data_frame[fields].drop_duplicates().values.tolist()
|
||||
|
||||
@classmethod
|
||||
def _ComputeScoreStats(cls, data_frame):
|
||||
"""Computes score stats."""
|
||||
scores = data_frame['score']
|
||||
return {
|
||||
'count': scores.count(),
|
||||
'min': scores.min(),
|
||||
'max': scores.max(),
|
||||
'mean': scores.mean(),
|
||||
'std dev': scores.std(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config,
|
||||
test_data_gen, test_data_gen_params):
|
||||
"""Assigns a unique name to a dialog."""
|
||||
return 'score-stats-dialog-' + hashlib.md5(
|
||||
'score-stats-inspector-{}-{}-{}-{}'.format(
|
||||
score_name, apm_config, test_data_gen,
|
||||
test_data_gen_params).encode('utf-8')).hexdigest()
|
||||
|
||||
@classmethod
|
||||
def _Save(cls, output_filepath, html):
|
||||
"""Writes the HTML file.
|
||||
|
||||
Args:
|
||||
output_filepath: output file path.
|
||||
html: string with the HTML content.
|
||||
"""
|
||||
with open(output_filepath, 'w') as f:
|
||||
f.write(html)
|
||||
|
||||
@classmethod
|
||||
def _FormatName(cls, name):
|
||||
"""Formats a name.
|
||||
|
||||
Args:
|
||||
name: a string.
|
||||
|
||||
Returns:
|
||||
A copy of name in which underscores and dashes are replaced with a space.
|
||||
"""
|
||||
return re.sub(r'[_\-]', ' ', name)
|
||||
@ -1,86 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the export module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pyquery as pq
|
||||
|
||||
from . import audioproc_wrapper
|
||||
from . import collect_data
|
||||
from . import eval_scores_factory
|
||||
from . import evaluation
|
||||
from . import export
|
||||
from . import simulation
|
||||
from . import test_data_generation_factory
|
||||
|
||||
|
||||
class TestExport(unittest.TestCase):
|
||||
"""Unit tests for the export module.
|
||||
"""
|
||||
|
||||
_CLEAN_TMP_OUTPUT = True
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data to export."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Run a fake experiment to produce data to export.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
simulator.Run(
|
||||
config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=[
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
|
||||
],
|
||||
test_data_generator_names=['identity', 'white_noise'],
|
||||
eval_score_names=['audio_level_peak', 'audio_level_mean'],
|
||||
output_dir=self._tmp_path)
|
||||
|
||||
# Export results.
|
||||
p = collect_data.InstanceArgumentsParser()
|
||||
args = p.parse_args(['--output_dir', self._tmp_path])
|
||||
src_path = collect_data.ConstructSrcPath(args)
|
||||
self._data_to_export = collect_data.FindScores(src_path, args)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
if self._CLEAN_TMP_OUTPUT:
|
||||
shutil.rmtree(self._tmp_path)
|
||||
else:
|
||||
logging.warning(self.id() + ' did not clean the temporary path ' +
|
||||
(self._tmp_path))
|
||||
|
||||
def testCreateHtmlReport(self):
|
||||
fn_out = os.path.join(self._tmp_path, 'results.html')
|
||||
exporter = export.HtmlExport(fn_out)
|
||||
exporter.Export(self._data_to_export)
|
||||
|
||||
document = pq.PyQuery(filename=fn_out)
|
||||
self.assertIsInstance(document, pq.PyQuery)
|
||||
# TODO(alessiob): Use PyQuery API to check the HTML file.
|
||||
@ -1,75 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
from __future__ import division
|
||||
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class ExternalVad(object):
|
||||
def __init__(self, path_to_binary, name):
|
||||
"""Args:
|
||||
path_to_binary: path to binary that accepts '-i <wav>', '-o
|
||||
<float probabilities>'. There must be one float value per
|
||||
10ms audio
|
||||
name: a name to identify the external VAD. Used for saving
|
||||
the output as extvad_output-<name>.
|
||||
"""
|
||||
self._path_to_binary = path_to_binary
|
||||
self.name = name
|
||||
assert os.path.exists(self._path_to_binary), (self._path_to_binary)
|
||||
self._vad_output = None
|
||||
|
||||
def Run(self, wav_file_path):
|
||||
_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
wav_file_path)
|
||||
if _signal.channels != 1:
|
||||
raise NotImplementedError('Multiple-channel'
|
||||
' annotations not implemented')
|
||||
if _signal.frame_rate != 48000:
|
||||
raise NotImplementedError('Frame rates '
|
||||
'other than 48000 not implemented')
|
||||
|
||||
tmp_path = tempfile.mkdtemp()
|
||||
try:
|
||||
output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
|
||||
subprocess.call([
|
||||
self._path_to_binary, '-i', wav_file_path, '-o',
|
||||
output_file_path
|
||||
])
|
||||
self._vad_output = np.fromfile(output_file_path, np.float32)
|
||||
except Exception as e:
|
||||
logging.error('Error while running the ' + self.name + ' VAD (' +
|
||||
e.message + ')')
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
shutil.rmtree(tmp_path)
|
||||
|
||||
def GetVadOutput(self):
|
||||
assert self._vad_output is not None
|
||||
return self._vad_output
|
||||
|
||||
@classmethod
|
||||
def ConstructVadDict(cls, vad_paths, vad_names):
|
||||
external_vads = {}
|
||||
for path, name in zip(vad_paths, vad_names):
|
||||
external_vads[name] = ExternalVad(path, name)
|
||||
return external_vads
|
||||
@ -1,25 +0,0 @@
|
||||
#!/usr/bin/python
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-i', required=True)
|
||||
parser.add_argument('-o', required=True)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
array = np.arange(100, dtype=np.float32)
|
||||
array.tofile(open(args.o, 'w'))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
@ -1,56 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
const char* const kErrorMessage = "-Out /path/to/output/file is mandatory";
|
||||
|
||||
// Writes fake output intended to be parsed by
|
||||
// quality_assessment.eval_scores.PolqaScore.
|
||||
void WriteOutputFile(absl::string_view output_file_path) {
|
||||
RTC_CHECK_NE(output_file_path, "");
|
||||
std::ofstream out(std::string{output_file_path});
|
||||
RTC_CHECK(!out.bad());
|
||||
out << "* Fake Polqa output" << std::endl;
|
||||
out << "FakeField1\tPolqaScore\tFakeField2" << std::endl;
|
||||
out << "FakeValue1\t3.25\tFakeValue2" << std::endl;
|
||||
out.close();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
// Find "-Out" and use its next argument as output file path.
|
||||
RTC_CHECK_GE(argc, 3) << kErrorMessage;
|
||||
const std::string kSoughtFlagName = "-Out";
|
||||
for (int i = 1; i < argc - 1; ++i) {
|
||||
if (kSoughtFlagName.compare(argv[i]) == 0) {
|
||||
WriteOutputFile(argv[i + 1]);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
RTC_FATAL() << kErrorMessage;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return webrtc::test::main(argc, argv);
|
||||
}
|
||||
@ -1,97 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Input mixer module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from . import exceptions
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class ApmInputMixer(object):
|
||||
"""Class to mix a set of audio segments down to the APM input."""
|
||||
|
||||
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def HardClippingLogMessage(cls):
|
||||
"""Returns the log message used when hard clipping is detected in the mix.
|
||||
|
||||
This method is mainly intended to be used by the unit tests.
|
||||
"""
|
||||
return cls._HARD_CLIPPING_LOG_MSG
|
||||
|
||||
@classmethod
|
||||
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
|
||||
"""Mixes capture and echo.
|
||||
|
||||
Creates the overall capture input for APM by mixing the "echo-free" capture
|
||||
signal with the echo signal (e.g., echo simulated via the
|
||||
echo_path_simulation module).
|
||||
|
||||
The echo signal cannot be shorter than the capture signal and the generated
|
||||
mix will have the same duration of the capture signal. The latter property
|
||||
is enforced in order to let the input of APM and the reference signal
|
||||
created by TestDataGenerator have the same length (required for the
|
||||
evaluation step).
|
||||
|
||||
Hard-clipping may occur in the mix; a warning is raised when this happens.
|
||||
|
||||
If `echo_filepath` is None, nothing is done and `capture_input_filepath` is
|
||||
returned.
|
||||
|
||||
Args:
|
||||
speech: AudioSegment instance.
|
||||
echo_path: AudioSegment instance or None.
|
||||
|
||||
Returns:
|
||||
Path to the mix audio track file.
|
||||
"""
|
||||
if echo_filepath is None:
|
||||
return capture_input_filepath
|
||||
|
||||
# Build the mix output file name as a function of the echo file name.
|
||||
# This ensures that if the internal parameters of the echo path simulator
|
||||
# change, no erroneous cache hit occurs.
|
||||
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
|
||||
capture_input_file_name, _ = os.path.splitext(
|
||||
os.path.split(capture_input_filepath)[1])
|
||||
mix_filepath = os.path.join(
|
||||
output_path,
|
||||
'mix_capture_{}_{}.wav'.format(capture_input_file_name,
|
||||
echo_file_name))
|
||||
|
||||
# Create the mix if not done yet.
|
||||
mix = None
|
||||
if not os.path.exists(mix_filepath):
|
||||
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
capture_input_filepath)
|
||||
echo = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
echo_filepath)
|
||||
|
||||
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(
|
||||
echo_free_capture)):
|
||||
raise exceptions.InputMixerException(
|
||||
'echo cannot be shorter than capture')
|
||||
|
||||
mix = echo_free_capture.overlay(echo)
|
||||
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
|
||||
|
||||
# Check if hard clipping occurs.
|
||||
if mix is None:
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
|
||||
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
|
||||
|
||||
return mix_filepath
|
||||
@ -1,140 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the input mixer module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
|
||||
from . import exceptions
|
||||
from . import input_mixer
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class TestApmInputMixer(unittest.TestCase):
|
||||
"""Unit tests for the ApmInputMixer class.
|
||||
"""
|
||||
|
||||
# Audio track file names created in setUp().
|
||||
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
|
||||
|
||||
# Target peak power level (dBFS) of each audio track file created in setUp().
|
||||
# These values are hand-crafted in order to make saturation happen when
|
||||
# capture and echo_2 are mixed and the contrary for capture and echo_1.
|
||||
# None means that the power is not changed.
|
||||
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
|
||||
|
||||
# Audio track file durations in milliseconds.
|
||||
_DURATIONS = [1000, 1000, 1000, 800, 1200]
|
||||
|
||||
_SAMPLE_RATE = 48000
|
||||
|
||||
def setUp(self):
|
||||
"""Creates temporary data."""
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
# Create audio track files.
|
||||
self._audio_tracks = {}
|
||||
for filename, peak_power, duration in zip(self._FILENAMES,
|
||||
self._MAX_PEAK_POWER_LEVELS,
|
||||
self._DURATIONS):
|
||||
audio_track_filepath = os.path.join(self._tmp_path,
|
||||
'{}.wav'.format(filename))
|
||||
|
||||
# Create a pure tone with the target peak power level.
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(
|
||||
duration=duration, sample_rate=self._SAMPLE_RATE)
|
||||
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template)
|
||||
if peak_power is not None:
|
||||
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
|
||||
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
audio_track_filepath, signal)
|
||||
self._audio_tracks[filename] = {
|
||||
'filepath':
|
||||
audio_track_filepath,
|
||||
'num_samples':
|
||||
signal_processing.SignalProcessingUtils.CountSamples(signal)
|
||||
}
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively deletes temporary folders."""
|
||||
shutil.rmtree(self._tmp_path)
|
||||
|
||||
def testCheckMixSameDuration(self):
|
||||
"""Checks the duration when mixing capture and echo with same duration."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(
|
||||
self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
|
||||
def testRejectShorterEcho(self):
|
||||
"""Rejects echo signals that are shorter than the capture signal."""
|
||||
try:
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['shorter']['filepath'])
|
||||
self.fail('no exception raised')
|
||||
except exceptions.InputMixerException:
|
||||
pass
|
||||
|
||||
def testCheckMixDurationWithLongerEcho(self):
|
||||
"""Checks the duration when mixing an echo longer than the capture."""
|
||||
mix_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['longer']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix_filepath))
|
||||
|
||||
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
|
||||
self.assertEqual(
|
||||
self._audio_tracks['capture']['num_samples'],
|
||||
signal_processing.SignalProcessingUtils.CountSamples(mix))
|
||||
|
||||
def testCheckOutputFileNamesConflict(self):
|
||||
"""Checks that different echo files lead to different output file names."""
|
||||
mix1_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix1_filepath))
|
||||
|
||||
mix2_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
self.assertTrue(os.path.exists(mix2_filepath))
|
||||
|
||||
self.assertNotEqual(mix1_filepath, mix2_filepath)
|
||||
|
||||
def testHardClippingLogExpected(self):
|
||||
"""Checks that hard clipping warning is raised when occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_2']['filepath'])
|
||||
logging.warning.assert_called_once_with(
|
||||
input_mixer.ApmInputMixer.HardClippingLogMessage())
|
||||
|
||||
def testHardClippingLogNotExpected(self):
|
||||
"""Checks that hard clipping warning is not raised when not occurring."""
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
_ = input_mixer.ApmInputMixer.Mix(
|
||||
self._tmp_path, self._audio_tracks['capture']['filepath'],
|
||||
self._audio_tracks['echo_1']['filepath'])
|
||||
self.assertNotIn(
|
||||
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
|
||||
logging.warning.call_args_list)
|
||||
@ -1,68 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Input signal creator module.
|
||||
"""
|
||||
|
||||
from . import exceptions
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class InputSignalCreator(object):
|
||||
"""Input signal creator class.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def Create(cls, name, raw_params):
|
||||
"""Creates a input signal and its metadata.
|
||||
|
||||
Args:
|
||||
name: Input signal creator name.
|
||||
raw_params: Tuple of parameters to pass to the specific signal creator.
|
||||
|
||||
Returns:
|
||||
(AudioSegment, dict) tuple.
|
||||
"""
|
||||
try:
|
||||
signal = {}
|
||||
params = {}
|
||||
|
||||
if name == 'pure_tone':
|
||||
params['frequency'] = float(raw_params[0])
|
||||
params['duration'] = int(raw_params[1])
|
||||
signal = cls._CreatePureTone(params['frequency'],
|
||||
params['duration'])
|
||||
else:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid input signal creator name')
|
||||
|
||||
# Complete metadata.
|
||||
params['signal'] = name
|
||||
|
||||
return signal, params
|
||||
except (TypeError, AssertionError) as e:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Invalid signal creator parameters: {}'.format(e))
|
||||
|
||||
@classmethod
|
||||
def _CreatePureTone(cls, frequency, duration):
|
||||
"""
|
||||
Generates a pure tone at 48000 Hz.
|
||||
|
||||
Args:
|
||||
frequency: Float in (0-24000] (Hz).
|
||||
duration: Integer (milliseconds).
|
||||
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
assert 0 < frequency <= 24000
|
||||
assert duration > 0
|
||||
template = signal_processing.SignalProcessingUtils.GenerateSilence(
|
||||
duration)
|
||||
return signal_processing.SignalProcessingUtils.GeneratePureTone(
|
||||
template, frequency)
|
||||
@ -1,32 +0,0 @@
|
||||
/* Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
td.selected-score {
|
||||
background-color: #DDD;
|
||||
}
|
||||
|
||||
td.single-score-cell{
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.audio-inspector {
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.audio-inspector div{
|
||||
margin-bottom: 0;
|
||||
padding-bottom: 0;
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.audio-inspector div div{
|
||||
margin-bottom: 0;
|
||||
padding-bottom: 0;
|
||||
padding-top: 0;
|
||||
}
|
||||
@ -1,376 +0,0 @@
|
||||
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
//
|
||||
// Use of this source code is governed by a BSD-style license
|
||||
// that can be found in the LICENSE file in the root of the source
|
||||
// tree. An additional intellectual property rights grant can be found
|
||||
// in the file PATENTS. All contributing project authors may
|
||||
// be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
/**
|
||||
* Opens the score stats inspector dialog.
|
||||
* @param {String} dialogId: identifier of the dialog to show.
|
||||
* @return {DOMElement} The dialog element that has been opened.
|
||||
*/
|
||||
function openScoreStatsInspector(dialogId) {
|
||||
var dialog = document.getElementById(dialogId);
|
||||
dialog.showModal();
|
||||
return dialog;
|
||||
}
|
||||
|
||||
/**
|
||||
* Closes the score stats inspector dialog.
|
||||
*/
|
||||
function closeScoreStatsInspector() {
|
||||
var dialog = document.querySelector('dialog[open]');
|
||||
if (dialog == null)
|
||||
return;
|
||||
dialog.close();
|
||||
}
|
||||
|
||||
/**
 * Audio inspector class.
 *
 * Provides a shared play/stop widget that is moved next to the currently
 * selected score cell and plays the audio files listed in that cell's
 * hidden metadata inputs.
 * @constructor
 */
function AudioInspector() {
  console.debug('Creating an AudioInspector instance.');
  // Single shared player used by every inspector button.
  this.audioPlayer_ = new Audio();
  // Metadata (name -> value, e.g. audio file paths) of the selected score.
  this.metadata_ = {};
  this.currentScore_ = null;
  this.audioInspector_ = null;
  this.snackbarContainer_ = document.querySelector('#snackbar');

  // Get base URL without anchors.
  this.baseUrl_ = window.location.href;
  var index = this.baseUrl_.indexOf('#');
  if (index > 0)
    this.baseUrl_ = this.baseUrl_.substr(0, index);
  // Log the stripped base URL (the original logged the full href, anchor
  // included, which was misleading).
  console.info('Base URL set to "' + this.baseUrl_ + '".');

  // `window.event` is only defined while an event handler is running; guard
  // the call so constructing outside a handler does not throw.
  if (window.event) {
    window.event.stopPropagation();
  }
  this.createTextAreasForCopy_();
  this.createAudioInspector_();
  this.initializeEventHandlers_();

  // When MDL is ready, parse the anchor (if any) to show the requested
  // experiment.
  var self = this;
  document.querySelectorAll('header a')[0].addEventListener(
      'mdl-componentupgraded', function() {
        if (!self.parseWindowAnchor()) {
          // If no experiment is requested, open the first section.
          console.info('No anchor parsing, opening the first section.');
          document.querySelectorAll('header a > span')[0].click();
        }
      });
}
|
||||
|
||||
/**
 * Parse the anchor in the window URL.
 *
 * The anchor is expected to have three '&'-separated fields: a section
 * index (with one leading character before the number), a dialog id and a
 * score-cell CSS class. On success the matching dialog is opened and the
 * cell is focused and clicked.
 * @return {bool} True if the parsing succeeded.
 */
AudioInspector.prototype.parseWindowAnchor = function() {
  var index = location.href.indexOf('#');
  if (index == -1) {
    console.debug('No # found in the URL.');
    return false;
  }

  // Negative substr() start counts from the end of the string, so
  // index - length + 1 is equivalent to starting right after the '#'.
  var anchor = location.href.substr(index - location.href.length + 1);
  console.info('Anchor changed: "' + anchor + '".');

  var parts = anchor.split('&');
  if (parts.length != 3) {
    console.info('Ignoring anchor with invalid number of fields.');
    return false;
  }

  var openDialog = document.querySelector('dialog[open]');
  try {
    // Open the requested dialog if not already open.
    if (!openDialog || openDialog.id != parts[1]) {
      // Close the currently open dialog first (short-circuit idiom:
      // close() only runs when openDialog is truthy).
      !openDialog || openDialog.close();
      // substr(1) drops the leading character of the first field; the
      // remaining digits index the section link to click.
      document.querySelectorAll('header a > span')[
          parseInt(parts[0].substr(1))].click();
      openDialog = openScoreStatsInspector(parts[1]);
    }

    // Trigger click on cell.
    var cell = openDialog.querySelector('td.' + parts[2]);
    cell.focus();
    cell.click();

    this.showNotification_('Experiment selected.');
    return true;
  } catch (e) {
    // Any DOM lookup failure above (e.g., stale anchor) lands here.
    this.showNotification_('Cannot select experiment :(');
    console.error('Exception caught while selecting experiment: "' + e + '".');
  }

  return false;
}
|
||||
|
||||
/**
 * Set up the inspector for a new score.
 * @param {DOMElement} element: Element linked to the selected score.
 */
AudioInspector.prototype.selectedScoreChange = function(element) {
  // No-op when the same score is clicked again.
  if (this.currentScore_ == element) { return; }
  // Un-highlight the previously selected score (if any).
  if (this.currentScore_ != null) {
    this.currentScore_.classList.remove('selected-score');
  }
  this.currentScore_ = element;
  this.currentScore_.classList.add('selected-score');
  this.stopAudio();

  // Read metadata from the hidden inputs nested inside the score element
  // (each input's name/value pair becomes a metadata entry).
  var matches = element.querySelectorAll('input[type=hidden]');
  this.metadata_ = {};
  for (var index = 0; index < matches.length; ++index) {
    this.metadata_[matches[index].name] = matches[index].value;
  }

  // Show the audio inspector interface by moving the shared widget into
  // this score's placeholder (four levels up to reach the container —
  // tied to the generated table markup).
  var container = element.parentNode.parentNode.parentNode.parentNode;
  var audioInspectorPlaceholder = container.querySelector(
      '.audio-inspector-placeholder');
  this.moveInspector_(audioInspectorPlaceholder);
};
|
||||
|
||||
/**
 * Stop playing audio.
 */
AudioInspector.prototype.stopAudio = function() {
  // Pausing the shared player is sufficient; playback position is not reset.
  this.audioPlayer_.pause();
  console.info('Pausing audio play out.');
};
|
||||
|
||||
/**
 * Show a text message using the snackbar.
 */
AudioInspector.prototype.showNotification_ = function(text) {
  try {
    // Preferred path: the MDL snackbar component.
    this.snackbarContainer_.MaterialSnackbar.showSnackbar(
        {message: text, timeout: 2000});
  } catch (e) {
    // Fallback to an alert when the snackbar is unavailable.
    alert(text);
    console.warn('Cannot use snackbar: "' + e + '"');
  }
}
|
||||
|
||||
/**
 * Move the audio inspector DOM node into the given parent.
 *
 * appendChild() detaches the node from its current parent first, so this
 * effectively relocates the single shared inspector widget.
 * @param {DOMElement} newParentNode: New parent for the inspector.
 */
AudioInspector.prototype.moveInspector_ = function(newParentNode) {
  newParentNode.appendChild(this.audioInspector_);
};
|
||||
|
||||
/**
 * Play audio file from url.
 * @param {string} metadataFieldName: Metadata field name.
 */
AudioInspector.prototype.playAudio = function(metadataFieldName) {
  // Silently ignore unknown metadata fields.
  if (this.metadata_[metadataFieldName] == undefined) { return; }
  // The literal string 'None' marks a stream that was not generated for
  // this experiment.
  if (this.metadata_[metadataFieldName] == 'None') {
    alert('The selected stream was not used during the experiment.');
    return;
  }
  // Stop any current playback before switching the source.
  this.stopAudio();
  this.audioPlayer_.src = this.metadata_[metadataFieldName];
  console.debug('Audio source URL: "' + this.audioPlayer_.src + '"');
  this.audioPlayer_.play();
  console.info('Playing out audio.');
};
|
||||
|
||||
/**
 * Create hidden text areas to copy URLs.
 *
 * For each dialog, one text area is created since it is not possible to select
 * text on a text area outside of the active dialog.
 */
AudioInspector.prototype.createTextAreasForCopy_ = function() {
  // Note: the unused `var self = this;` capture from the original has been
  // removed; no closure below references the instance.
  document.querySelectorAll('dialog.mdl-dialog').forEach(function(element) {
    var textArea = document.createElement("textarea");
    textArea.classList.add('url-copy');
    // Keep the text area technically visible (select()/execCommand('copy')
    // require it) but visually unobtrusive in a dialog corner.
    textArea.style.position = 'fixed';
    textArea.style.bottom = 0;
    textArea.style.left = 0;
    textArea.style.width = '2em';
    textArea.style.height = '2em';
    textArea.style.border = 'none';
    textArea.style.outline = 'none';
    textArea.style.boxShadow = 'none';
    textArea.style.background = 'transparent';
    textArea.style.fontSize = '6px';
    element.appendChild(textArea);
  });
}
|
||||
|
||||
/**
 * Create audio inspector.
 *
 * Builds the grid of play/stop buttons for every audio stream of an
 * experiment and parks the widget in a hidden container until a score is
 * selected (see moveInspector_).
 */
AudioInspector.prototype.createAudioInspector_ = function() {
  var buttonIndex = 0;
  // Returns the HTML for one button with an MDL tooltip and an optional
  // caption; the metadata field name is stored in a hidden input that
  // initializeEventHandlers_() reads back.
  function getButtonHtml(icon, toolTipText, caption, metadataFieldName) {
    var buttonId = 'audioInspectorButton' + buttonIndex++;
    // Declared with `var` (the original assigned an undeclared `html`,
    // leaking an implicit global).
    var html = caption == null ? '' : caption;
    html += '<button class="mdl-button mdl-js-button mdl-button--icon ' +
                'mdl-js-ripple-effect" id="' + buttonId + '">' +
              '<i class="material-icons">' + icon + '</i>' +
              '<div class="mdl-tooltip" data-mdl-for="' + buttonId + '">' +
                  toolTipText +
              '</div>';
    if (metadataFieldName != null) {
      html += '<input type="hidden" value="' + metadataFieldName + '">';
    }
    html += '</button>';

    return html;
  }

  // TODO(alessiob): Add timeline and highlight current track by changing icon
  // color.

  this.audioInspector_ = document.createElement('div');
  this.audioInspector_.classList.add('audio-inspector');
  this.audioInspector_.innerHTML =
      '<div class="mdl-grid">' +
        '<div class="mdl-layout-spacer"></div>' +
        '<div class="mdl-cell mdl-cell--2-col">' +
          getButtonHtml('play_arrow', 'Simulated echo', 'E<sub>in</sub>',
                        'echo_filepath') +
        '</div>' +
        '<div class="mdl-cell mdl-cell--2-col">' +
          getButtonHtml('stop', 'Stop playing [S]', null, '__stop__') +
        '</div>' +
        '<div class="mdl-cell mdl-cell--2-col">' +
          getButtonHtml('play_arrow', 'Render stream', 'R<sub>in</sub>',
                        'render_filepath') +
        '</div>' +
        '<div class="mdl-layout-spacer"></div>' +
      '</div>' +
      '<div class="mdl-grid">' +
        '<div class="mdl-layout-spacer"></div>' +
        '<div class="mdl-cell mdl-cell--2-col">' +
          getButtonHtml('play_arrow', 'Capture stream (APM input) [1]',
                        'Y\'<sub>in</sub>', 'capture_filepath') +
        '</div>' +
        '<div class="mdl-cell mdl-cell--2-col"><strong>APM</strong></div>' +
        '<div class="mdl-cell mdl-cell--2-col">' +
          getButtonHtml('play_arrow', 'APM output [2]', 'Y<sub>out</sub>',
                        'apm_output_filepath') +
        '</div>' +
        '<div class="mdl-layout-spacer"></div>' +
      '</div>' +
      '<div class="mdl-grid">' +
        '<div class="mdl-layout-spacer"></div>' +
        '<div class="mdl-cell mdl-cell--2-col">' +
          getButtonHtml('play_arrow', 'Echo-free capture stream',
                        'Y<sub>in</sub>', 'echo_free_capture_filepath') +
        '</div>' +
        '<div class="mdl-cell mdl-cell--2-col">' +
          getButtonHtml('play_arrow', 'Clean capture stream',
                        'Y<sub>clean</sub>', 'clean_capture_input_filepath') +
        '</div>' +
        '<div class="mdl-cell mdl-cell--2-col">' +
          getButtonHtml('play_arrow', 'APM reference [3]', 'Y<sub>ref</sub>',
                        'apm_reference_filepath') +
        '</div>' +
        '<div class="mdl-layout-spacer"></div>' +
      '</div>';

  // Add an invisible node as initial container for the audio inspector.
  var parent = document.createElement('div');
  parent.style.display = 'none';
  this.moveInspector_(parent);
  document.body.appendChild(parent);
};
|
||||
|
||||
/**
 * Initialize event handlers.
 *
 * Wires score-cell selection, URL-copy buttons, the inspector's play/stop
 * buttons, dialog close events, keyboard shortcuts and hash changes.
 */
AudioInspector.prototype.initializeEventHandlers_ = function() {
  var self = this;

  // Score cells: clicking a cell selects its score.
  document.querySelectorAll('td.single-score-cell').forEach(function(element) {
    element.onclick = function() {
      self.selectedScoreChange(this);
    }
  });

  // Copy anchor URLs icons.
  if (document.queryCommandSupported('copy')) {
    document.querySelectorAll('td.single-score-cell button').forEach(
        function(element) {
          element.onclick = function() {
            // Find the text area in the dialog.
            var textArea = element.closest('dialog').querySelector(
                'textarea.url-copy');

            // Copy the experiment anchor URL via the hidden text area.
            textArea.value = self.baseUrl_ + '#' + element.getAttribute(
                'data-anchor');
            textArea.select();
            try {
              if (!document.execCommand('copy'))
                throw 'Copy returned false';
              self.showNotification_('Experiment URL copied.');
            } catch (e) {
              self.showNotification_('Cannot copy experiment URL :(');
              console.error(e);
            }
          }
        });
  } else {
    self.showNotification_(
        'The copy command is disabled. URL copy is not enabled.');
  }

  // Audio inspector buttons: the hidden input carries the metadata field
  // name, or the '__stop__' sentinel for the stop button.
  this.audioInspector_.querySelectorAll('button').forEach(function(element) {
    var target = element.querySelector('input[type=hidden]');
    if (target == null) { return; }
    element.onclick = function() {
      if (target.value == '__stop__') {
        self.stopAudio();
      } else {
        self.playAudio(target.value);
      }
    };
  });

  // Dialog close handlers. (The original stored forEach's undefined return
  // in an unused `dialogs` variable; dropped.)
  document.querySelectorAll('dialog').forEach(function(element) {
    element.onclose = function() {
      self.stopAudio();
    }
  });

  // Keyboard shortcuts.
  window.onkeyup = function(e) {
    var key = e.keyCode ? e.keyCode : e.which;
    switch (key) {
      case 49:  // 1.
        self.playAudio('capture_filepath');
        break;
      case 50:  // 2.
        self.playAudio('apm_output_filepath');
        break;
      case 51:  // 3.
        self.playAudio('apm_reference_filepath');
        break;
      case 83:  // S.
      case 115:  // s.
        self.stopAudio();
        break;
    }
  };

  // Hash change: re-parse the anchor to select the requested experiment.
  window.onhashchange = function(e) {
    self.parseWindowAnchor();
  }
};
|
||||
@ -1,359 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Signal processing utility module.
|
||||
"""
|
||||
|
||||
import array
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import enum
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package numpy')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import pydub
|
||||
import pydub.generators
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package pydub')
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
import scipy.signal
|
||||
import scipy.fftpack
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package scipy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import exceptions
|
||||
|
||||
|
||||
class SignalProcessingUtils(object):
    """Collection of signal processing utilities.
    """

    @enum.unique
    class MixPadding(enum.Enum):
        # Padding policy applied in MixSignals() when the noise track is
        # shorter than the signal track.
        NO_PADDING = 0
        ZERO_PADDING = 1
        LOOP = 2

    def __init__(self):
        pass

    @classmethod
    def LoadWav(cls, filepath, channels=1):
        """Loads wav file.

        Args:
          filepath: path to the wav audio track file to load.
          channels: number of channels (downmixing to mono by default).

        Returns:
          AudioSegment instance.

        Raises:
          exceptions.FileNotFoundError: if `filepath` does not exist.
        """
        if not os.path.exists(filepath):
            logging.error('cannot find the <%s> audio track file', filepath)
            raise exceptions.FileNotFoundError()
        return pydub.AudioSegment.from_file(filepath,
                                            format='wav',
                                            channels=channels)

    @classmethod
    def SaveWav(cls, output_filepath, signal):
        """Saves wav file.

        Args:
          output_filepath: path to the wav audio track file to save.
          signal: AudioSegment instance.
        """
        return signal.export(output_filepath, format='wav')

    @classmethod
    def CountSamples(cls, signal):
        """Number of samples per channel.

        Args:
          signal: AudioSegment instance.

        Returns:
          An integer.
        """
        number_of_samples = len(signal.get_array_of_samples())
        assert signal.channels > 0
        assert number_of_samples % signal.channels == 0
        # Floor division so that an int is returned as documented (true
        # division would produce a float in Python 3).
        return number_of_samples // signal.channels

    @classmethod
    def GenerateSilence(cls, duration=1000, sample_rate=48000):
        """Generates silence.

        This method can also be used to create a template AudioSegment instance.
        A template can then be used with other Generate*() methods accepting an
        AudioSegment instance as argument.

        Args:
          duration: duration in ms.
          sample_rate: sample rate.

        Returns:
          AudioSegment instance.
        """
        return pydub.AudioSegment.silent(duration, sample_rate)

    @classmethod
    def GeneratePureTone(cls, template, frequency=440.0):
        """Generates a pure tone.

        The pure tone is generated with the same duration and in the same format
        of the given template signal.

        Args:
          template: AudioSegment instance.
          frequency: Frequency of the pure tone in Hz.

        Return:
          AudioSegment instance.

        Raises:
          exceptions.SignalProcessingException: if `frequency` is above the
            Nyquist frequency of the template.
        """
        if frequency > template.frame_rate >> 1:
            raise exceptions.SignalProcessingException('Invalid frequency')

        generator = pydub.generators.Sine(sample_rate=template.frame_rate,
                                          bit_depth=template.sample_width * 8,
                                          freq=frequency)

        return generator.to_audio_segment(duration=len(template), volume=0.0)

    @classmethod
    def GenerateWhiteNoise(cls, template):
        """Generates white noise.

        The white noise is generated with the same duration and in the same
        format of the given template signal.

        Args:
          template: AudioSegment instance.

        Return:
          AudioSegment instance.
        """
        generator = pydub.generators.WhiteNoise(
            sample_rate=template.frame_rate,
            bit_depth=template.sample_width * 8)
        return generator.to_audio_segment(duration=len(template), volume=0.0)

    @classmethod
    def AudioSegmentToRawData(cls, signal):
        """Returns the samples of `signal` as a numpy array of int16.

        Args:
          signal: AudioSegment instance with 16 bit samples.

        Returns:
          A numpy array with dtype numpy.int16.

        Raises:
          exceptions.SignalProcessingException: if the samples are not 16 bit.
        """
        samples = signal.get_array_of_samples()
        if samples.typecode != 'h':
            raise exceptions.SignalProcessingException(
                'Unsupported samples type')
        # Reuse `samples` instead of extracting the array a second time.
        return np.array(samples, np.int16)

    @classmethod
    def Fft(cls, signal, normalize=True):
        """Returns the positive-frequency half of the FFT of a mono signal.

        Args:
          signal: AudioSegment instance (mono, 16 bit samples).
          normalize: when True, scale the samples into [-1.0, 1.0] before the
            transform.

        Returns:
          A numpy array of complex values (first half of the FFT output).
        """
        if signal.channels != 1:
            raise NotImplementedError('multiple-channel FFT not implemented')
        x = cls.AudioSegmentToRawData(signal).astype(np.float32)
        if normalize:
            x /= max(abs(np.max(x)), 1.0)
        y = scipy.fftpack.fft(x)
        # Floor division: a float slice index raises TypeError in Python 3.
        return y[:len(y) // 2]

    @classmethod
    def DetectHardClipping(cls, signal, threshold=2):
        """Detects hard clipping.

        Hard clipping is simply detected by counting samples that touch either
        the lower or upper bound too many times in a row (according to
        `threshold`). The presence of a single sequence of samples meeting such
        property is enough to label the signal as hard clipped.

        Args:
          signal: AudioSegment instance.
          threshold: minimum number of samples at full-scale in a row.

        Returns:
          True if hard clipping is detect, False otherwise.
        """
        if signal.channels != 1:
            raise NotImplementedError(
                'multiple-channel clipping not implemented')
        if signal.sample_width != 2:  # Note that signal.sample_width is in bytes.
            raise exceptions.SignalProcessingException(
                'hard-clipping detection only supported for 16 bit samples')
        samples = cls.AudioSegmentToRawData(signal)

        # Detect adjacent clipped samples.
        samples_type_info = np.iinfo(samples.dtype)
        mask_min = samples == samples_type_info.min
        mask_max = samples == samples_type_info.max

        def HasLongSequence(vector, min_length=threshold):
            """Returns True if there are one or more long sequences of True
            flags."""
            seq_length = 0
            for b in vector:
                seq_length = seq_length + 1 if b else 0
                if seq_length >= min_length:
                    return True
            return False

        return HasLongSequence(mask_min) or HasLongSequence(mask_max)

    @classmethod
    def ApplyImpulseResponse(cls, signal, impulse_response):
        """Applies an impulse response to a signal.

        Args:
          signal: AudioSegment instance.
          impulse_response: list or numpy vector of float values.

        Returns:
          AudioSegment instance.
        """
        # Get samples.
        assert signal.channels == 1, (
            'multiple-channel recordings not supported')
        samples = signal.get_array_of_samples()

        # Convolve.
        logging.info(
            'applying %d order impulse response to a signal lasting %d ms',
            len(impulse_response), len(signal))
        convolved_samples = scipy.signal.fftconvolve(in1=samples,
                                                     in2=impulse_response,
                                                     mode='full').astype(
                                                         np.int16)
        logging.info('convolution computed')

        # Cast back to the sample array type expected by pydub.
        convolved_samples = array.array(signal.array_type, convolved_samples)

        # Verify: 'full' mode output is always longer than the input.
        logging.debug('signal length: %d samples', len(samples))
        logging.debug('convolved signal length: %d samples',
                      len(convolved_samples))
        assert len(convolved_samples) > len(samples)

        # Generate convolved signal AudioSegment instance.
        convolved_signal = pydub.AudioSegment(data=convolved_samples,
                                              metadata={
                                                  'sample_width':
                                                  signal.sample_width,
                                                  'frame_rate':
                                                  signal.frame_rate,
                                                  'frame_width':
                                                  signal.frame_width,
                                                  'channels': signal.channels,
                                              })
        assert len(convolved_signal) > len(signal)

        return convolved_signal

    @classmethod
    def Normalize(cls, signal):
        """Normalizes a signal.

        Applies the gain that brings the peak to 0 dBFS.

        Args:
          signal: AudioSegment instance.

        Returns:
          An AudioSegment instance.
        """
        return signal.apply_gain(-signal.max_dBFS)

    @classmethod
    def Copy(cls, signal):
        """Makes a copy of a signal.

        Args:
          signal: AudioSegment instance.

        Returns:
          An AudioSegment instance.
        """
        return pydub.AudioSegment(data=signal.get_array_of_samples(),
                                  metadata={
                                      'sample_width': signal.sample_width,
                                      'frame_rate': signal.frame_rate,
                                      'frame_width': signal.frame_width,
                                      'channels': signal.channels,
                                  })

    @classmethod
    def MixSignals(cls,
                   signal,
                   noise,
                   target_snr=0.0,
                   pad_noise=MixPadding.NO_PADDING):
        """Mixes `signal` and `noise` with a target SNR.

        Mix `signal` and `noise` with a desired SNR by scaling `noise`.
        If the target SNR is +/- infinite, a copy of signal/noise is returned.
        If `signal` is shorter than `noise`, the length of the mix equals that
        of `signal`. Otherwise, the mix length depends on whether padding is
        applied. When padding is not applied, that is `pad_noise` is set to
        NO_PADDING (default), the mix length equals that of `noise` - i.e.,
        `signal` is truncated. Otherwise, `noise` is extended and the resulting
        mix has the same length of `signal`.

        Args:
          signal: AudioSegment instance (signal).
          noise: AudioSegment instance (noise).
          target_snr: float, numpy.Inf or -numpy.Inf (dB).
          pad_noise: SignalProcessingUtils.MixPadding, default: NO_PADDING.

        Returns:
          An AudioSegment instance.
        """
        # Handle infinite target SNR.
        if target_snr == -np.inf:
            # Return a copy of noise.
            logging.warning('SNR = -Inf, returning noise')
            return cls.Copy(noise)
        elif target_snr == np.inf:
            # Return a copy of signal.
            logging.warning('SNR = +Inf, returning signal')
            return cls.Copy(signal)

        # Check signal and noise power.
        signal_power = float(signal.dBFS)
        noise_power = float(noise.dBFS)
        if signal_power == -np.inf:
            logging.error('signal has -Inf power, cannot mix')
            raise exceptions.SignalProcessingException(
                'cannot mix a signal with -Inf power')
        if noise_power == -np.inf:
            logging.error('noise has -Inf power, cannot mix')
            raise exceptions.SignalProcessingException(
                'cannot mix a signal with -Inf power')

        # Mix: scale the noise so that its power matches the target SNR.
        gain_db = signal_power - noise_power - target_snr
        signal_duration = len(signal)
        noise_duration = len(noise)
        if signal_duration <= noise_duration:
            # Ignore `pad_noise`, `noise` is truncated if longer that `signal`,
            # the mix will have the same length of `signal`.
            return signal.overlay(noise.apply_gain(gain_db))
        elif pad_noise == cls.MixPadding.NO_PADDING:
            # `signal` is longer than `noise`, but no padding is applied to
            # `noise`. Truncate `signal`.
            return noise.overlay(signal, gain_during_overlay=gain_db)
        elif pad_noise == cls.MixPadding.ZERO_PADDING:
            # `noise` is shorter: overlay() leaves the tail of `signal`
            # untouched, which is equivalent to zero-padding `noise`.
            return signal.overlay(noise.apply_gain(gain_db))
        elif pad_noise == cls.MixPadding.LOOP:
            # `signal` is longer than `noise`, extend `noise` by looping.
            return signal.overlay(noise.apply_gain(gain_db), loop=True)
        else:
            raise exceptions.SignalProcessingException('invalid padding type')
|
||||
@ -1,183 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the signal_processing module.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
|
||||
from . import exceptions
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class TestSignalProcessing(unittest.TestCase):
    """Unit tests for the signal_processing module.
    """

    def testMixSignals(self):
        # Generate a template signal with which white noise can be generated.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)

        # Generate two distinct AudioSegment instances with 1 second of white
        # noise.
        signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)
        noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)

        # Extract samples.
        signal_samples = signal.get_array_of_samples()
        noise_samples = noise.get_array_of_samples()

        # Test target SNR -inf (noise expected).
        mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, -np.inf)
        # assertEqual, not assertTrue: with two arguments assertTrue treats
        # the second one as a message and the check always passes.
        self.assertEqual(len(noise), len(mix_neg_inf))  # Check duration.
        mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
        self.assertTrue(  # Check samples.
            all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))

        # Test target SNR 0.0 (different data expected).
        mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, 0.0)
        self.assertEqual(len(signal), len(mix_0))  # Check duration.
        self.assertEqual(len(noise), len(mix_0))
        mix_0_samples = mix_0.get_array_of_samples()
        self.assertTrue(
            any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
        self.assertTrue(
            any([x != y for x, y in zip(noise_samples, mix_0_samples)]))

        # Test target SNR +inf (signal expected).
        mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
            signal, noise, np.inf)
        self.assertEqual(len(signal), len(mix_pos_inf))  # Check duration.
        mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
        self.assertTrue(  # Check samples.
            all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))

    def testMixSignalsMinInfPower(self):
        # Mixing with a silent (-Inf power) track must raise.
        silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
        signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            silence)

        with self.assertRaises(exceptions.SignalProcessingException):
            _ = signal_processing.SignalProcessingUtils.MixSignals(
                signal, silence, 0.0)

        with self.assertRaises(exceptions.SignalProcessingException):
            _ = signal_processing.SignalProcessingUtils.MixSignals(
                silence, signal, 0.0)

    def testMixSignalNoiseDifferentLengths(self):
        # Test signals.
        shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
        longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            pydub.AudioSegment.silent(duration=2000, frame_rate=8000))

        # When the signal is shorter than the noise, the mix length always
        # equals that of the signal regardless of whether padding is applied.
        # No noise padding, length of signal less than that of noise.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=shorter,
            noise=longer,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            NO_PADDING)
        self.assertEqual(len(shorter), len(mix))
        # With noise padding, length of signal less than that of noise.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=shorter,
            noise=longer,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            ZERO_PADDING)
        self.assertEqual(len(shorter), len(mix))

        # When the signal is longer than the noise, the mix length depends on
        # whether padding is applied.
        # No noise padding, length of signal greater than that of noise.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=longer,
            noise=shorter,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            NO_PADDING)
        self.assertEqual(len(shorter), len(mix))
        # With noise padding, length of signal greater than that of noise.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=longer,
            noise=shorter,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            ZERO_PADDING)
        self.assertEqual(len(longer), len(mix))

    def testMixSignalNoisePaddingTypes(self):
        # Test signals.
        shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
            pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
        longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
            pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)

        # Zero padding: expect pure tone only in 1-2s.
        mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
            signal=longer,
            noise=shorter,
            target_snr=-6,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
            ZERO_PADDING)

        # Loop: expect pure tone plus noise in 1-2s.
        mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
            signal=longer,
            noise=shorter,
            target_snr=-6,
            pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)

        def Energy(signal):
            """Sum of squared samples of `signal` (float32)."""
            samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
                signal).astype(np.float32)
            return np.sum(samples * samples)

        # The last second must carry strictly more energy in the looped mix
        # (tone + noise) than in the zero-padded mix (tone only).
        e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
        e_mix_loop = Energy(mix_loop[-1000:])
        self.assertLess(0, e_mix_zero_pad)
        self.assertLess(e_mix_zero_pad, e_mix_loop)

    def testMixSignalSnr(self):
        # Test signals.
        tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
            pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
        tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
            pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)

        def ToneAmplitudes(mix):
            """Returns the amplitude of the coefficients #16 and #192, which
            correspond to the tones at 250 and 3k Hz respectively."""
            mix_fft = np.absolute(
                signal_processing.SignalProcessingUtils.Fft(mix))
            return mix_fft[16], mix_fft[192]

        # Negative SNR: the noise tone must dominate.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=tone_low, noise=tone_high, target_snr=-6)
        ampl_low, ampl_high = ToneAmplitudes(mix)
        self.assertLess(ampl_low, ampl_high)

        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=tone_high, noise=tone_low, target_snr=-6)
        ampl_low, ampl_high = ToneAmplitudes(mix)
        self.assertLess(ampl_high, ampl_low)

        # Positive SNR: the signal tone must dominate.
        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=tone_low, noise=tone_high, target_snr=6)
        ampl_low, ampl_high = ToneAmplitudes(mix)
        self.assertLess(ampl_high, ampl_low)

        mix = signal_processing.SignalProcessingUtils.MixSignals(
            signal=tone_high, noise=tone_low, target_snr=6)
        ampl_low, ampl_high = ToneAmplitudes(mix)
        self.assertLess(ampl_low, ampl_high)
|
||||
@ -1,446 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""APM module simulator.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from . import annotations
|
||||
from . import data_access
|
||||
from . import echo_path_simulation
|
||||
from . import echo_path_simulation_factory
|
||||
from . import eval_scores
|
||||
from . import exceptions
|
||||
from . import input_mixer
|
||||
from . import input_signal_creator
|
||||
from . import signal_processing
|
||||
from . import test_data_generation
|
||||
|
||||
|
||||
class ApmModuleSimulator(object):
|
||||
"""Audio processing module (APM) simulator class.
|
||||
"""
|
||||
|
||||
_TEST_DATA_GENERATOR_CLASSES = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
|
||||
|
||||
_PREFIX_APM_CONFIG = 'apmcfg-'
|
||||
_PREFIX_CAPTURE = 'capture-'
|
||||
_PREFIX_RENDER = 'render-'
|
||||
_PREFIX_ECHO_SIMULATOR = 'echosim-'
|
||||
_PREFIX_TEST_DATA_GEN = 'datagen-'
|
||||
_PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
|
||||
_PREFIX_SCORE = 'score-'
|
||||
|
||||
def __init__(self,
|
||||
test_data_generator_factory,
|
||||
evaluation_score_factory,
|
||||
ap_wrapper,
|
||||
evaluator,
|
||||
external_vads=None):
|
||||
if external_vads is None:
|
||||
external_vads = {}
|
||||
self._test_data_generator_factory = test_data_generator_factory
|
||||
self._evaluation_score_factory = evaluation_score_factory
|
||||
self._audioproc_wrapper = ap_wrapper
|
||||
self._evaluator = evaluator
|
||||
self._annotator = annotations.AudioAnnotationsExtractor(
|
||||
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD
|
||||
| annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO
|
||||
| annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
|
||||
external_vads)
|
||||
|
||||
# Init.
|
||||
self._test_data_generator_factory.SetOutputDirectoryPrefix(
|
||||
self._PREFIX_TEST_DATA_GEN_PARAMS)
|
||||
self._evaluation_score_factory.SetScoreFilenamePrefix(
|
||||
self._PREFIX_SCORE)
|
||||
|
||||
# Properties for each run.
|
||||
self._base_output_path = None
|
||||
self._output_cache_path = None
|
||||
self._test_data_generators = None
|
||||
self._evaluation_score_workers = None
|
||||
self._config_filepaths = None
|
||||
self._capture_input_filepaths = None
|
||||
self._render_input_filepaths = None
|
||||
self._echo_path_simulator_class = None
|
||||
|
||||
@classmethod
|
||||
def GetPrefixApmConfig(cls):
|
||||
return cls._PREFIX_APM_CONFIG
|
||||
|
||||
@classmethod
|
||||
def GetPrefixCapture(cls):
|
||||
return cls._PREFIX_CAPTURE
|
||||
|
||||
@classmethod
|
||||
def GetPrefixRender(cls):
|
||||
return cls._PREFIX_RENDER
|
||||
|
||||
@classmethod
|
||||
def GetPrefixEchoSimulator(cls):
|
||||
return cls._PREFIX_ECHO_SIMULATOR
|
||||
|
||||
@classmethod
|
||||
def GetPrefixTestDataGenerator(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN
|
||||
|
||||
@classmethod
|
||||
def GetPrefixTestDataGeneratorParameters(cls):
|
||||
return cls._PREFIX_TEST_DATA_GEN_PARAMS
|
||||
|
||||
@classmethod
|
||||
def GetPrefixScore(cls):
|
||||
return cls._PREFIX_SCORE
|
||||
|
||||
def Run(self,
|
||||
config_filepaths,
|
||||
capture_input_filepaths,
|
||||
test_data_generator_names,
|
||||
eval_score_names,
|
||||
output_dir,
|
||||
render_input_filepaths=None,
|
||||
echo_path_simulator_name=(
|
||||
echo_path_simulation.NoEchoPathSimulator.NAME)):
|
||||
"""Runs the APM simulation.
|
||||
|
||||
Initializes paths and required instances, then runs all the simulations.
|
||||
The render input can be optionally added. If added, the number of capture
|
||||
input audio tracks and the number of render input audio tracks have to be
|
||||
equal. The two lists are used to form pairs of capture and render input.
|
||||
|
||||
Args:
|
||||
config_filepaths: set of APM configuration files to test.
|
||||
capture_input_filepaths: set of capture input audio track files to test.
|
||||
test_data_generator_names: set of test data generator names to test.
|
||||
eval_score_names: set of evaluation score names to test.
|
||||
output_dir: base path to the output directory for wav files and outcomes.
|
||||
render_input_filepaths: set of render input audio track files to test.
|
||||
echo_path_simulator_name: name of the echo path simulator to use when
|
||||
render input is provided.
|
||||
"""
|
||||
assert render_input_filepaths is None or (
|
||||
len(capture_input_filepaths) == len(render_input_filepaths)), (
|
||||
'render input set size not matching input set size')
|
||||
assert render_input_filepaths is None or echo_path_simulator_name in (
|
||||
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
|
||||
'invalid echo path simulator')
|
||||
self._base_output_path = os.path.abspath(output_dir)
|
||||
|
||||
# Output path used to cache the data shared across simulations.
|
||||
self._output_cache_path = os.path.join(self._base_output_path,
|
||||
'_cache')
|
||||
|
||||
# Instance test data generators.
|
||||
self._test_data_generators = [
|
||||
self._test_data_generator_factory.GetInstance(
|
||||
test_data_generators_class=(
|
||||
self._TEST_DATA_GENERATOR_CLASSES[name]))
|
||||
for name in (test_data_generator_names)
|
||||
]
|
||||
|
||||
# Instance evaluation score workers.
|
||||
self._evaluation_score_workers = [
|
||||
self._evaluation_score_factory.GetInstance(
|
||||
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name])
|
||||
for (name) in eval_score_names
|
||||
]
|
||||
|
||||
# Set APM configuration file paths.
|
||||
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
|
||||
|
||||
# Set probing signal file paths.
|
||||
if render_input_filepaths is None:
|
||||
# Capture input only.
|
||||
self._capture_input_filepaths = self._CreatePathsCollection(
|
||||
capture_input_filepaths)
|
||||
self._render_input_filepaths = None
|
||||
else:
|
||||
# Set both capture and render input signals.
|
||||
self._SetTestInputSignalFilePaths(capture_input_filepaths,
|
||||
render_input_filepaths)
|
||||
|
||||
# Set the echo path simulator class.
|
||||
self._echo_path_simulator_class = (
|
||||
echo_path_simulation.EchoPathSimulator.
|
||||
REGISTERED_CLASSES[echo_path_simulator_name])
|
||||
|
||||
self._SimulateAll()
|
||||
|
||||
def _SimulateAll(self):
|
||||
"""Runs all the simulations.
|
||||
|
||||
Iterates over the combinations of APM configurations, probing signals, and
|
||||
test data generators. This method is mainly responsible for the creation of
|
||||
the cache and output directories required in order to call _Simulate().
|
||||
"""
|
||||
without_render_input = self._render_input_filepaths is None
|
||||
|
||||
# Try different APM config files.
|
||||
for config_name in self._config_filepaths:
|
||||
config_filepath = self._config_filepaths[config_name]
|
||||
|
||||
# Try different capture-render pairs.
|
||||
for capture_input_name in self._capture_input_filepaths:
|
||||
# Output path for the capture signal annotations.
|
||||
capture_annotations_cache_path = os.path.join(
|
||||
self._output_cache_path,
|
||||
self._PREFIX_CAPTURE + capture_input_name)
|
||||
data_access.MakeDirectory(capture_annotations_cache_path)
|
||||
|
||||
# Capture.
|
||||
capture_input_filepath = self._capture_input_filepaths[
|
||||
capture_input_name]
|
||||
if not os.path.exists(capture_input_filepath):
|
||||
# If the input signal file does not exist, try to create using the
|
||||
# available input signal creators.
|
||||
self._CreateInputSignal(capture_input_filepath)
|
||||
assert os.path.exists(capture_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
capture_input_filepath, capture_annotations_cache_path)
|
||||
|
||||
# Render and simulated echo path (optional).
|
||||
render_input_filepath = None if without_render_input else (
|
||||
self._render_input_filepaths[capture_input_name])
|
||||
render_input_name = '(none)' if without_render_input else (
|
||||
self._ExtractFileName(render_input_filepath))
|
||||
echo_path_simulator = (echo_path_simulation_factory.
|
||||
EchoPathSimulatorFactory.GetInstance(
|
||||
self._echo_path_simulator_class,
|
||||
render_input_filepath))
|
||||
|
||||
# Try different test data generators.
|
||||
for test_data_generators in self._test_data_generators:
|
||||
logging.info(
|
||||
'APM config preset: <%s>, capture: <%s>, render: <%s>,'
|
||||
'test data generator: <%s>, echo simulator: <%s>',
|
||||
config_name, capture_input_name, render_input_name,
|
||||
test_data_generators.NAME, echo_path_simulator.NAME)
|
||||
|
||||
# Output path for the generated test data.
|
||||
test_data_cache_path = os.path.join(
|
||||
capture_annotations_cache_path,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(test_data_cache_path)
|
||||
logging.debug('test data cache path: <%s>',
|
||||
test_data_cache_path)
|
||||
|
||||
# Output path for the echo simulator and APM input mixer output.
|
||||
echo_test_data_cache_path = os.path.join(
|
||||
test_data_cache_path,
|
||||
'echosim-{}'.format(echo_path_simulator.NAME))
|
||||
data_access.MakeDirectory(echo_test_data_cache_path)
|
||||
logging.debug('echo test data cache path: <%s>',
|
||||
echo_test_data_cache_path)
|
||||
|
||||
# Full output path.
|
||||
output_path = os.path.join(
|
||||
self._base_output_path,
|
||||
self._PREFIX_APM_CONFIG + config_name,
|
||||
self._PREFIX_CAPTURE + capture_input_name,
|
||||
self._PREFIX_RENDER + render_input_name,
|
||||
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
|
||||
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
|
||||
data_access.MakeDirectory(output_path)
|
||||
logging.debug('output path: <%s>', output_path)
|
||||
|
||||
self._Simulate(test_data_generators,
|
||||
capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path,
|
||||
config_filepath, echo_path_simulator)
|
||||
|
||||
@staticmethod
|
||||
def _CreateInputSignal(input_signal_filepath):
|
||||
"""Creates a missing input signal file.
|
||||
|
||||
The file name is parsed to extract input signal creator and params. If a
|
||||
creator is matched and the parameters are valid, a new signal is generated
|
||||
and written in `input_signal_filepath`.
|
||||
|
||||
Args:
|
||||
input_signal_filepath: Path to the input signal audio file to write.
|
||||
|
||||
Raises:
|
||||
InputSignalCreatorException
|
||||
"""
|
||||
filename = os.path.splitext(
|
||||
os.path.split(input_signal_filepath)[-1])[0]
|
||||
filename_parts = filename.split('-')
|
||||
|
||||
if len(filename_parts) < 2:
|
||||
raise exceptions.InputSignalCreatorException(
|
||||
'Cannot parse input signal file name')
|
||||
|
||||
signal, metadata = input_signal_creator.InputSignalCreator.Create(
|
||||
filename_parts[0], filename_parts[1].split('_'))
|
||||
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
input_signal_filepath, signal)
|
||||
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
|
||||
|
||||
def _ExtractCaptureAnnotations(self,
|
||||
input_filepath,
|
||||
output_path,
|
||||
annotation_name=""):
|
||||
self._annotator.Extract(input_filepath)
|
||||
self._annotator.Save(output_path, annotation_name)
|
||||
|
||||
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
|
||||
render_input_filepath, test_data_cache_path,
|
||||
echo_test_data_cache_path, output_path, config_filepath,
|
||||
echo_path_simulator):
|
||||
"""Runs a single set of simulation.
|
||||
|
||||
Simulates a given combination of APM configuration, probing signal, and
|
||||
test data generator. It iterates over the test data generator
|
||||
internal configurations.
|
||||
|
||||
Args:
|
||||
test_data_generators: TestDataGenerator instance.
|
||||
clean_capture_input_filepath: capture input audio track file to be
|
||||
processed by a test data generator and
|
||||
not affected by echo.
|
||||
render_input_filepath: render input audio track file to test.
|
||||
test_data_cache_path: path for the generated test audio track files.
|
||||
echo_test_data_cache_path: path for the echo simulator.
|
||||
output_path: base output path for the test data generator.
|
||||
config_filepath: APM configuration file to test.
|
||||
echo_path_simulator: EchoPathSimulator instance.
|
||||
"""
|
||||
# Generate pairs of noisy input and reference signal files.
|
||||
test_data_generators.Generate(
|
||||
input_signal_filepath=clean_capture_input_filepath,
|
||||
test_data_cache_path=test_data_cache_path,
|
||||
base_output_path=output_path)
|
||||
|
||||
# Extract metadata linked to the clean input file (if any).
|
||||
apm_input_metadata = None
|
||||
try:
|
||||
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
|
||||
clean_capture_input_filepath)
|
||||
except IOError as e:
|
||||
apm_input_metadata = {}
|
||||
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
|
||||
apm_input_metadata['test_data_gen_config'] = None
|
||||
|
||||
# For each test data pair, simulate a call and evaluate.
|
||||
for config_name in test_data_generators.config_names:
|
||||
logging.info(' - test data generator config: <%s>', config_name)
|
||||
apm_input_metadata['test_data_gen_config'] = config_name
|
||||
|
||||
# Paths to the test data generator output.
|
||||
# Note that the reference signal does not depend on the render input
|
||||
# which is optional.
|
||||
noisy_capture_input_filepath = (
|
||||
test_data_generators.noisy_signal_filepaths[config_name])
|
||||
reference_signal_filepath = (
|
||||
test_data_generators.reference_signal_filepaths[config_name])
|
||||
|
||||
# Output path for the evaluation (e.g., APM output file).
|
||||
evaluation_output_path = test_data_generators.apm_output_paths[
|
||||
config_name]
|
||||
|
||||
# Paths to the APM input signals.
|
||||
echo_path_filepath = echo_path_simulator.Simulate(
|
||||
echo_test_data_cache_path)
|
||||
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
|
||||
echo_test_data_cache_path, noisy_capture_input_filepath,
|
||||
echo_path_filepath)
|
||||
|
||||
# Extract annotations for the APM input mix.
|
||||
apm_input_basepath, apm_input_filename = os.path.split(
|
||||
apm_input_filepath)
|
||||
self._ExtractCaptureAnnotations(
|
||||
apm_input_filepath, apm_input_basepath,
|
||||
os.path.splitext(apm_input_filename)[0] + '-')
|
||||
|
||||
# Simulate a call using APM.
|
||||
self._audioproc_wrapper.Run(
|
||||
config_filepath=config_filepath,
|
||||
capture_input_filepath=apm_input_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path)
|
||||
|
||||
try:
|
||||
# Evaluate.
|
||||
self._evaluator.Run(
|
||||
evaluation_score_workers=self._evaluation_score_workers,
|
||||
apm_input_metadata=apm_input_metadata,
|
||||
apm_output_filepath=self._audioproc_wrapper.
|
||||
output_filepath,
|
||||
reference_input_filepath=reference_signal_filepath,
|
||||
render_input_filepath=render_input_filepath,
|
||||
output_path=evaluation_output_path,
|
||||
)
|
||||
|
||||
# Save simulation metadata.
|
||||
data_access.Metadata.SaveAudioTestDataPaths(
|
||||
output_path=evaluation_output_path,
|
||||
clean_capture_input_filepath=clean_capture_input_filepath,
|
||||
echo_free_capture_filepath=noisy_capture_input_filepath,
|
||||
echo_filepath=echo_path_filepath,
|
||||
render_filepath=render_input_filepath,
|
||||
capture_filepath=apm_input_filepath,
|
||||
apm_output_filepath=self._audioproc_wrapper.
|
||||
output_filepath,
|
||||
apm_reference_filepath=reference_signal_filepath,
|
||||
apm_config_filepath=config_filepath,
|
||||
)
|
||||
except exceptions.EvaluationScoreException as e:
|
||||
logging.warning('the evaluation failed: %s', e.message)
|
||||
continue
|
||||
|
||||
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
|
||||
render_input_filepaths):
|
||||
"""Sets input and render input file paths collections.
|
||||
|
||||
Pairs the input and render input files by storing the file paths into two
|
||||
collections. The key is the file name of the input file.
|
||||
|
||||
Args:
|
||||
capture_input_filepaths: list of file paths.
|
||||
render_input_filepaths: list of file paths.
|
||||
"""
|
||||
self._capture_input_filepaths = {}
|
||||
self._render_input_filepaths = {}
|
||||
assert len(capture_input_filepaths) == len(render_input_filepaths)
|
||||
for capture_input_filepath, render_input_filepath in zip(
|
||||
capture_input_filepaths, render_input_filepaths):
|
||||
name = self._ExtractFileName(capture_input_filepath)
|
||||
self._capture_input_filepaths[name] = os.path.abspath(
|
||||
capture_input_filepath)
|
||||
self._render_input_filepaths[name] = os.path.abspath(
|
||||
render_input_filepath)
|
||||
|
||||
@classmethod
|
||||
def _CreatePathsCollection(cls, filepaths):
|
||||
"""Creates a collection of file paths.
|
||||
|
||||
Given a list of file paths, makes a collection with one item for each file
|
||||
path. The value is absolute path, the key is the file name without
|
||||
extenstion.
|
||||
|
||||
Args:
|
||||
filepaths: list of file paths.
|
||||
|
||||
Returns:
|
||||
A dict.
|
||||
"""
|
||||
filepaths_collection = {}
|
||||
for filepath in filepaths:
|
||||
name = cls._ExtractFileName(filepath)
|
||||
filepaths_collection[name] = os.path.abspath(filepath)
|
||||
return filepaths_collection
|
||||
|
||||
@classmethod
|
||||
def _ExtractFileName(cls, filepath):
|
||||
return os.path.splitext(os.path.split(filepath)[-1])[0]
|
||||
@ -1,203 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the simulation module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mock
|
||||
import pydub
|
||||
|
||||
from . import audioproc_wrapper
|
||||
from . import eval_scores_factory
|
||||
from . import evaluation
|
||||
from . import external_vad
|
||||
from . import signal_processing
|
||||
from . import simulation
|
||||
from . import test_data_generation_factory
|
||||
|
||||
|
||||
class TestApmModuleSimulator(unittest.TestCase):
|
||||
"""Unit tests for the ApmModuleSimulator class.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folders and fake audio track."""
|
||||
self._output_path = tempfile.mkdtemp()
|
||||
self._tmp_path = tempfile.mkdtemp()
|
||||
|
||||
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
|
||||
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
silence)
|
||||
self._fake_audio_track_path = os.path.join(self._output_path,
|
||||
'fake.wav')
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
self._fake_audio_track_path, fake_signal)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._output_path)
|
||||
shutil.rmtree(self._tmp_path)
|
||||
|
||||
def testSimulation(self):
|
||||
# Instance dependencies to mock and inject.
|
||||
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
|
||||
evaluator = evaluation.ApmModuleEvaluator()
|
||||
ap_wrapper.Run = mock.MagicMock(name='Run')
|
||||
evaluator.Run = mock.MagicMock(name='Run')
|
||||
|
||||
# Instance non-mocked dependencies.
|
||||
test_data_generator_factory = (
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False))
|
||||
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)
|
||||
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=test_data_generator_factory,
|
||||
evaluation_score_factory=evaluation_score_factory,
|
||||
ap_wrapper=ap_wrapper,
|
||||
evaluator=evaluator,
|
||||
external_vads={
|
||||
'fake':
|
||||
external_vad.ExternalVad(
|
||||
os.path.join(os.path.dirname(__file__),
|
||||
'fake_external_vad.py'), 'fake')
|
||||
})
|
||||
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [self._fake_audio_track_path]
|
||||
test_data_generators = ['identity', 'white_noise']
|
||||
eval_scores = ['audio_level_mean', 'polqa']
|
||||
|
||||
# Run all simulations.
|
||||
simulator.Run(config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=test_data_generators,
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
|
||||
# Check.
|
||||
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
|
||||
# the client code (e.g., number of SNR pairs for the white noise test data
|
||||
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
|
||||
# is known; use that with assertEqual.
|
||||
min_number_of_simulations = len(config_files) * len(input_files) * len(
|
||||
test_data_generators)
|
||||
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
|
||||
min_number_of_simulations)
|
||||
|
||||
def testInputSignalCreation(self):
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
|
||||
# Inexistent input files to be silently created.
|
||||
input_files = [
|
||||
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
|
||||
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
|
||||
]
|
||||
self.assertFalse(
|
||||
any([os.path.exists(input_file) for input_file in (input_files)]))
|
||||
|
||||
# The input files are created during the simulation.
|
||||
simulator.Run(config_filepaths=['apm_configs/default.json'],
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=['audio_level_peak'],
|
||||
output_dir=self._output_path)
|
||||
self.assertTrue(
|
||||
all([os.path.exists(input_file) for input_file in (input_files)]))
|
||||
|
||||
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
|
||||
logging.warning = mock.MagicMock(name='warning')
|
||||
|
||||
# Instance simulator.
|
||||
simulator = simulation.ApmModuleSimulator(
|
||||
test_data_generator_factory=(
|
||||
test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=False)),
|
||||
evaluation_score_factory=(
|
||||
eval_scores_factory.EvaluationScoreWorkerFactory(
|
||||
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
|
||||
'fake_polqa'),
|
||||
echo_metric_tool_bin_path=None)),
|
||||
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
|
||||
audioproc_wrapper.AudioProcWrapper.
|
||||
DEFAULT_APM_SIMULATOR_BIN_PATH),
|
||||
evaluator=evaluation.ApmModuleEvaluator())
|
||||
|
||||
# What to simulate.
|
||||
config_files = ['apm_configs/default.json']
|
||||
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
|
||||
eval_scores = ['thd']
|
||||
|
||||
# Should work.
|
||||
simulator.Run(config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['identity'],
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
self.assertFalse(logging.warning.called)
|
||||
|
||||
# Warning expected.
|
||||
simulator.Run(
|
||||
config_filepaths=config_files,
|
||||
capture_input_filepaths=input_files,
|
||||
test_data_generator_names=['white_noise'], # Not allowed with THD.
|
||||
eval_score_names=eval_scores,
|
||||
output_dir=self._output_path)
|
||||
logging.warning.assert_called_with('the evaluation failed: %s', (
|
||||
'The THD score cannot be used with any test data generator other than '
|
||||
'"identity"'))
|
||||
|
||||
# # Init.
|
||||
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
|
||||
# input_signal_filepath = os.path.join(
|
||||
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
|
||||
|
||||
# # Check that the input signal is generated.
|
||||
# self.assertFalse(os.path.exists(input_signal_filepath))
|
||||
# generator.Generate(
|
||||
# input_signal_filepath=input_signal_filepath,
|
||||
# test_data_cache_path=self._test_data_cache_path,
|
||||
# base_output_path=self._base_output_path)
|
||||
# self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
# # Check input signal properties.
|
||||
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
# input_signal_filepath)
|
||||
# self.assertEqual(1000, len(input_signal))
|
||||
@ -1,127 +0,0 @@
|
||||
// Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
//
|
||||
// Use of this source code is governed by a BSD-style license
|
||||
// that can be found in the LICENSE file in the root of the source
|
||||
// tree. An additional intellectual property rights grant can be found
|
||||
// in the file PATENTS. All contributing project authors may
|
||||
// be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "common_audio/wav_file.h"
|
||||
#include "rtc_base/logging.h"
|
||||
|
||||
ABSL_FLAG(std::string, i, "", "Input wav file");
|
||||
ABSL_FLAG(std::string, oc, "", "Config output file");
|
||||
ABSL_FLAG(std::string, ol, "", "Levels output file");
|
||||
ABSL_FLAG(float, a, 5.f, "Attack (ms)");
|
||||
ABSL_FLAG(float, d, 20.f, "Decay (ms)");
|
||||
ABSL_FLAG(int, f, 10, "Frame length (ms)");
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
constexpr int kMaxSampleRate = 48000;
|
||||
constexpr uint8_t kMaxFrameLenMs = 30;
|
||||
constexpr size_t kMaxFrameLen = kMaxFrameLenMs * kMaxSampleRate / 1000;
|
||||
|
||||
const double kOneDbReduction = DbToRatio(-1.0);
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
// Check parameters.
|
||||
if (absl::GetFlag(FLAGS_f) < 1 || absl::GetFlag(FLAGS_f) > kMaxFrameLenMs) {
|
||||
RTC_LOG(LS_ERROR) << "Invalid frame length (min: 1, max: " << kMaxFrameLenMs
|
||||
<< ")";
|
||||
return 1;
|
||||
}
|
||||
if (absl::GetFlag(FLAGS_a) < 0 || absl::GetFlag(FLAGS_d) < 0) {
|
||||
RTC_LOG(LS_ERROR) << "Attack and decay must be non-negative";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Open wav input file and check properties.
|
||||
const std::string input_file = absl::GetFlag(FLAGS_i);
|
||||
const std::string config_output_file = absl::GetFlag(FLAGS_oc);
|
||||
const std::string levels_output_file = absl::GetFlag(FLAGS_ol);
|
||||
WavReader wav_reader(input_file);
|
||||
if (wav_reader.num_channels() != 1) {
|
||||
RTC_LOG(LS_ERROR) << "Only mono wav files supported";
|
||||
return 1;
|
||||
}
|
||||
if (wav_reader.sample_rate() > kMaxSampleRate) {
|
||||
RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
|
||||
<< ")";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Map from milliseconds to samples.
|
||||
const size_t audio_frame_length = rtc::CheckedDivExact(
|
||||
absl::GetFlag(FLAGS_f) * wav_reader.sample_rate(), 1000);
|
||||
auto time_const = [](double c) {
|
||||
return std::pow(kOneDbReduction, absl::GetFlag(FLAGS_f) / c);
|
||||
};
|
||||
const float attack =
|
||||
absl::GetFlag(FLAGS_a) == 0.0 ? 0.0 : time_const(absl::GetFlag(FLAGS_a));
|
||||
const float decay =
|
||||
absl::GetFlag(FLAGS_d) == 0.0 ? 0.0 : time_const(absl::GetFlag(FLAGS_d));
|
||||
|
||||
// Write config to file.
|
||||
std::ofstream out_config(config_output_file);
|
||||
out_config << "{"
|
||||
"'frame_len_ms': "
|
||||
<< absl::GetFlag(FLAGS_f)
|
||||
<< ", "
|
||||
"'attack_ms': "
|
||||
<< absl::GetFlag(FLAGS_a)
|
||||
<< ", "
|
||||
"'decay_ms': "
|
||||
<< absl::GetFlag(FLAGS_d) << "}\n";
|
||||
out_config.close();
|
||||
|
||||
// Measure level frame-by-frame.
|
||||
std::ofstream out_levels(levels_output_file, std::ofstream::binary);
|
||||
std::array<int16_t, kMaxFrameLen> samples;
|
||||
float level_prev = 0.f;
|
||||
while (true) {
|
||||
// Process frame.
|
||||
const auto read_samples =
|
||||
wav_reader.ReadSamples(audio_frame_length, samples.data());
|
||||
if (read_samples < audio_frame_length)
|
||||
break; // EOF.
|
||||
|
||||
// Frame peak level.
|
||||
std::transform(samples.begin(), samples.begin() + audio_frame_length,
|
||||
samples.begin(), [](int16_t s) { return std::abs(s); });
|
||||
const int16_t peak_level = *std::max_element(
|
||||
samples.cbegin(), samples.cbegin() + audio_frame_length);
|
||||
const float level_curr = static_cast<float>(peak_level) / 32768.f;
|
||||
|
||||
// Temporal smoothing.
|
||||
auto smooth = [&level_prev, &level_curr](float c) {
|
||||
return (1.0 - c) * level_curr + c * level_prev;
|
||||
};
|
||||
level_prev = smooth(level_curr > level_prev ? attack : decay);
|
||||
|
||||
// Write output.
|
||||
out_levels.write(reinterpret_cast<const char*>(&level_prev), sizeof(float));
|
||||
}
|
||||
out_levels.close();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return webrtc::test::main(argc, argv);
|
||||
}
|
||||
@ -1,526 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Test data generators producing signals pairs intended to be used to
|
||||
test the APM module. Each pair consists of a noisy input and a reference signal.
|
||||
The former is used as APM input and it is generated by adding noise to a
|
||||
clean audio track. The reference is the expected APM output.
|
||||
|
||||
Throughout this file, the following naming convention is used:
|
||||
- input signal: the clean signal (e.g., speech),
|
||||
- noise signal: the noise to be summed up to the input signal (e.g., white
|
||||
noise, Gaussian noise),
|
||||
- noisy signal: input + noise.
|
||||
The noise signal may or may not be a function of the clean signal. For
|
||||
instance, white noise is independently generated, whereas reverberation is
|
||||
obtained by convolving the input signal with an impulse response.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
try:
|
||||
import scipy.io
|
||||
except ImportError:
|
||||
logging.critical('Cannot import the third-party Python package scipy')
|
||||
sys.exit(1)
|
||||
|
||||
from . import data_access
|
||||
from . import exceptions
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class TestDataGenerator(object):
|
||||
"""Abstract class responsible for the generation of noisy signals.
|
||||
|
||||
Given a clean signal, it generates two streams named noisy signal and
|
||||
reference. The former is the clean signal deteriorated by the noise source,
|
||||
the latter goes through the same deterioration process, but more "gently".
|
||||
Noisy signal and reference are produced so that the reference is the signal
|
||||
expected at the output of the APM module when the latter is fed with the noisy
|
||||
signal.
|
||||
|
||||
An test data generator generates one or more pairs.
|
||||
"""
|
||||
|
||||
NAME = None
|
||||
REGISTERED_CLASSES = {}
|
||||
|
||||
def __init__(self, output_directory_prefix):
|
||||
self._output_directory_prefix = output_directory_prefix
|
||||
# Init dictionaries with one entry for each test data generator
|
||||
# configuration (e.g., different SNRs).
|
||||
# Noisy audio track files (stored separately in a cache folder).
|
||||
self._noisy_signal_filepaths = None
|
||||
# Path to be used for the APM simulation output files.
|
||||
self._apm_output_paths = None
|
||||
# Reference audio track files (stored separately in a cache folder).
|
||||
self._reference_signal_filepaths = None
|
||||
self.Clear()
|
||||
|
||||
@classmethod
|
||||
def RegisterClass(cls, class_to_register):
|
||||
"""Registers a TestDataGenerator implementation.
|
||||
|
||||
Decorator to automatically register the classes that extend
|
||||
TestDataGenerator.
|
||||
Example usage:
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class IdentityGenerator(TestDataGenerator):
|
||||
pass
|
||||
"""
|
||||
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
|
||||
return class_to_register
|
||||
|
||||
@property
|
||||
def config_names(self):
|
||||
return self._noisy_signal_filepaths.keys()
|
||||
|
||||
@property
|
||||
def noisy_signal_filepaths(self):
|
||||
return self._noisy_signal_filepaths
|
||||
|
||||
@property
|
||||
def apm_output_paths(self):
|
||||
return self._apm_output_paths
|
||||
|
||||
@property
|
||||
def reference_signal_filepaths(self):
|
||||
return self._reference_signal_filepaths
|
||||
|
||||
def Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Generates a set of noisy input and reference audiotrack file pairs.
|
||||
|
||||
This method initializes an empty set of pairs and calls the _Generate()
|
||||
method implemented in a concrete class.
|
||||
|
||||
Args:
|
||||
input_signal_filepath: path to the clean input audio track file.
|
||||
test_data_cache_path: path to the cache of the generated audio track
|
||||
files.
|
||||
base_output_path: base path where output is written.
|
||||
"""
|
||||
self.Clear()
|
||||
self._Generate(input_signal_filepath, test_data_cache_path,
|
||||
base_output_path)
|
||||
|
||||
def Clear(self):
|
||||
"""Clears the generated output path dictionaries.
|
||||
"""
|
||||
self._noisy_signal_filepaths = {}
|
||||
self._apm_output_paths = {}
|
||||
self._reference_signal_filepaths = {}
|
||||
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Abstract method to be implemented in each concrete class.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
|
||||
snr_value_pairs):
|
||||
"""Adds noisy-reference signal pairs.
|
||||
|
||||
Args:
|
||||
base_output_path: noisy tracks base output path.
|
||||
noisy_mix_filepaths: nested dictionary of noisy signal paths organized
|
||||
by noisy track name and SNR level.
|
||||
snr_value_pairs: list of SNR pairs.
|
||||
"""
|
||||
for noise_track_name in noisy_mix_filepaths:
|
||||
for snr_noisy, snr_refence in snr_value_pairs:
|
||||
config_name = '{0}_{1:d}_{2:d}_SNR'.format(
|
||||
noise_track_name, snr_noisy, snr_refence)
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=noisy_mix_filepaths[noise_track_name]
|
||||
[snr_noisy],
|
||||
reference_signal_filepath=noisy_mix_filepaths[
|
||||
noise_track_name][snr_refence],
|
||||
output_path=output_path)
|
||||
|
||||
def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
|
||||
reference_signal_filepath, output_path):
|
||||
"""Adds one noisy-reference signal pair.
|
||||
|
||||
Args:
|
||||
config_name: name of the APM configuration.
|
||||
noisy_signal_filepath: path to noisy audio track file.
|
||||
reference_signal_filepath: path to reference audio track file.
|
||||
output_path: APM output path.
|
||||
"""
|
||||
assert config_name not in self._noisy_signal_filepaths
|
||||
self._noisy_signal_filepaths[config_name] = os.path.abspath(
|
||||
noisy_signal_filepath)
|
||||
self._apm_output_paths[config_name] = os.path.abspath(output_path)
|
||||
self._reference_signal_filepaths[config_name] = os.path.abspath(
|
||||
reference_signal_filepath)
|
||||
|
||||
def _MakeDir(self, base_output_path, test_data_generator_config_name):
|
||||
output_path = os.path.join(
|
||||
base_output_path,
|
||||
self._output_directory_prefix + test_data_generator_config_name)
|
||||
data_access.MakeDirectory(output_path)
|
||||
return output_path
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class IdentityTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds no noise.
|
||||
|
||||
Both the noisy and the reference signals are the input signal.
|
||||
"""
|
||||
|
||||
NAME = 'identity'
|
||||
|
||||
def __init__(self, output_directory_prefix, copy_with_identity):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._copy_with_identity = copy_with_identity
|
||||
|
||||
@property
|
||||
def copy_with_identity(self):
|
||||
return self._copy_with_identity
|
||||
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
config_name = 'default'
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
|
||||
if self._copy_with_identity:
|
||||
input_signal_filepath_new = os.path.join(
|
||||
test_data_cache_path,
|
||||
os.path.split(input_signal_filepath)[1])
|
||||
logging.info('copying ' + input_signal_filepath + ' to ' +
|
||||
(input_signal_filepath_new))
|
||||
shutil.copy(input_signal_filepath, input_signal_filepath_new)
|
||||
input_signal_filepath = input_signal_filepath_new
|
||||
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=input_signal_filepath,
|
||||
reference_signal_filepath=input_signal_filepath,
|
||||
output_path=output_path)
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class WhiteNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds white noise.
|
||||
"""
|
||||
|
||||
NAME = 'white_noise'
|
||||
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 10 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[20, 30], # Smallest noise.
|
||||
[10, 20],
|
||||
[5, 15],
|
||||
[0, 10], # Largest noise.
|
||||
]
|
||||
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
|
||||
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
# Create the noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
|
||||
input_signal)
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths = {}
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
|
||||
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noisy-reference signal pairs.
|
||||
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
|
||||
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
|
||||
output_path = self._MakeDir(base_output_path, config_name)
|
||||
self._AddNoiseReferenceFilesPair(
|
||||
config_name=config_name,
|
||||
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
|
||||
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
|
||||
output_path=output_path)
|
||||
|
||||
|
||||
# TODO(alessiob): remove comment when class implemented.
|
||||
# @TestDataGenerator.RegisterClass
|
||||
class NarrowBandNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds narrow-band noise.
|
||||
"""
|
||||
|
||||
NAME = 'narrow_band_noise'
|
||||
|
||||
def __init__(self, output_directory_prefix):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
# TODO(alessiob): implement.
|
||||
pass
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class AdditiveNoiseTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds noise loops.
|
||||
|
||||
This generator uses all the wav files in a given path (default: noise_tracks/)
|
||||
and mixes them to the clean speech with different target SNRs (hard-coded).
|
||||
"""
|
||||
|
||||
NAME = 'additive_noise'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
|
||||
DEFAULT_NOISE_TRACKS_PATH = os.path.join(os.path.dirname(__file__),
|
||||
os.pardir, 'noise_tracks')
|
||||
|
||||
# TODO(alessiob): Make the list of SNR pairs customizable.
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 10 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[20, 30], # Smallest noise.
|
||||
[10, 20],
|
||||
[5, 15],
|
||||
[0, 10], # Largest noise.
|
||||
]
|
||||
|
||||
def __init__(self, output_directory_prefix, noise_tracks_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._noise_tracks_file_names = [
|
||||
n for n in os.listdir(self._noise_tracks_path)
|
||||
if n.lower().endswith('.wav')
|
||||
]
|
||||
if len(self._noise_tracks_file_names) == 0:
|
||||
raise exceptions.InitializationException(
|
||||
'No wav files found in the noise tracks path %s' %
|
||||
(self._noise_tracks_path))
|
||||
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Generates test data pairs using environmental noise.
|
||||
|
||||
For each noise track and pair of SNR values, the following two audio tracks
|
||||
are created: the noisy signal and the reference signal. The former is
|
||||
obtained by mixing the (clean) input signal to the corresponding noise
|
||||
track enforcing the target SNR.
|
||||
"""
|
||||
# Init.
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
noisy_mix_filepaths = {}
|
||||
for noise_track_filename in self._noise_tracks_file_names:
|
||||
# Load the noise track.
|
||||
noise_track_name, _ = os.path.splitext(noise_track_filename)
|
||||
noise_track_filepath = os.path.join(self._noise_tracks_path,
|
||||
noise_track_filename)
|
||||
if not os.path.exists(noise_track_filepath):
|
||||
logging.error('cannot find the <%s> noise track',
|
||||
noise_track_filename)
|
||||
raise exceptions.FileNotFoundError()
|
||||
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[noise_track_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
|
||||
noise_track_name, snr))
|
||||
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal,
|
||||
noise_signal,
|
||||
snr,
|
||||
pad_noise=signal_processing.SignalProcessingUtils.
|
||||
MixPadding.LOOP)
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[noise_track_name][
|
||||
snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
|
||||
self._SNR_VALUE_PAIRS)
|
||||
|
||||
|
||||
@TestDataGenerator.RegisterClass
|
||||
class ReverberationTestDataGenerator(TestDataGenerator):
|
||||
"""Generator that adds reverberation noise.
|
||||
|
||||
TODO(alessiob): Make this class more generic since the impulse response can be
|
||||
anything (not just reverberation); call it e.g.,
|
||||
ConvolutionalNoiseTestDataGenerator.
|
||||
"""
|
||||
|
||||
NAME = 'reverberation'
|
||||
|
||||
_IMPULSE_RESPONSES = {
|
||||
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
|
||||
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
|
||||
}
|
||||
_MAX_IMPULSE_RESPONSE_LENGTH = None
|
||||
|
||||
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
|
||||
# The reference (second value of each pair) always has a lower amount of noise
|
||||
# - i.e., the SNR is 5 dB higher.
|
||||
_SNR_VALUE_PAIRS = [
|
||||
[3, 8], # Smallest noise.
|
||||
[-3, 2], # Largest noise.
|
||||
]
|
||||
|
||||
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
|
||||
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
|
||||
|
||||
def __init__(self, output_directory_prefix, aechen_ir_database_path):
|
||||
TestDataGenerator.__init__(self, output_directory_prefix)
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
|
||||
def _Generate(self, input_signal_filepath, test_data_cache_path,
|
||||
base_output_path):
|
||||
"""Generates test data pairs using reverberation noise.
|
||||
|
||||
For each impulse response, one noise track is created. For each impulse
|
||||
response and pair of SNR values, the following 2 audio tracks are
|
||||
created: the noisy signal and the reference signal. The former is
|
||||
obtained by mixing the (clean) input signal to the corresponding noise
|
||||
track enforcing the target SNR.
|
||||
"""
|
||||
# Init.
|
||||
snr_values = set(
|
||||
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
|
||||
|
||||
# Load the input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
noisy_mix_filepaths = {}
|
||||
for impulse_response_name in self._IMPULSE_RESPONSES:
|
||||
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name)
|
||||
noise_track_filepath = os.path.join(test_data_cache_path,
|
||||
noise_track_filename)
|
||||
noise_signal = None
|
||||
try:
|
||||
# Load noise track.
|
||||
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noise_track_filepath)
|
||||
except exceptions.FileNotFoundError:
|
||||
# Generate noise track by applying the impulse response.
|
||||
impulse_response_filepath = os.path.join(
|
||||
self._aechen_ir_database_path,
|
||||
self._IMPULSE_RESPONSES[impulse_response_name])
|
||||
noise_signal = self._GenerateNoiseTrack(
|
||||
noise_track_filepath, input_signal,
|
||||
impulse_response_filepath)
|
||||
assert noise_signal is not None
|
||||
|
||||
# Create the noisy mixes (once for each unique SNR value).
|
||||
noisy_mix_filepaths[impulse_response_name] = {}
|
||||
for snr in snr_values:
|
||||
noisy_signal_filepath = os.path.join(
|
||||
test_data_cache_path,
|
||||
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
|
||||
impulse_response_name, snr))
|
||||
|
||||
# Create and save if not done.
|
||||
if not os.path.exists(noisy_signal_filepath):
|
||||
# Create noisy signal.
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
|
||||
input_signal, noise_signal, snr)
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noisy_signal_filepath, noisy_signal)
|
||||
|
||||
# Add file to the collection of mixes.
|
||||
noisy_mix_filepaths[impulse_response_name][
|
||||
snr] = noisy_signal_filepath
|
||||
|
||||
# Add all the noise-SNR pairs.
|
||||
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
|
||||
self._SNR_VALUE_PAIRS)
|
||||
|
||||
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
|
||||
impulse_response_filepath):
|
||||
"""Generates noise track.
|
||||
|
||||
Generate a signal by convolving input_signal with the impulse response in
|
||||
impulse_response_filepath; then save to noise_track_filepath.
|
||||
|
||||
Args:
|
||||
noise_track_filepath: output file path for the noise track.
|
||||
input_signal: (clean) input signal samples.
|
||||
impulse_response_filepath: impulse response file path.
|
||||
|
||||
Returns:
|
||||
AudioSegment instance.
|
||||
"""
|
||||
# Load impulse response.
|
||||
data = scipy.io.loadmat(impulse_response_filepath)
|
||||
impulse_response = data['h_air'].flatten()
|
||||
if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
|
||||
logging.info('truncating impulse response from %d to %d samples',
|
||||
len(impulse_response),
|
||||
self._MAX_IMPULSE_RESPONSE_LENGTH)
|
||||
impulse_response = impulse_response[:self.
|
||||
_MAX_IMPULSE_RESPONSE_LENGTH]
|
||||
|
||||
# Apply impulse response.
|
||||
processed_signal = (
|
||||
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
|
||||
input_signal, impulse_response))
|
||||
|
||||
# Save.
|
||||
signal_processing.SignalProcessingUtils.SaveWav(
|
||||
noise_track_filepath, processed_signal)
|
||||
|
||||
return processed_signal
|
||||
@ -1,71 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""TestDataGenerator factory class.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from . import exceptions
|
||||
from . import test_data_generation
|
||||
|
||||
|
||||
class TestDataGeneratorFactory(object):
|
||||
"""Factory class used to create test data generators.
|
||||
|
||||
Usage: Create a factory passing parameters to the ctor with which the
|
||||
generators will be produced.
|
||||
"""
|
||||
|
||||
def __init__(self, aechen_ir_database_path, noise_tracks_path,
|
||||
copy_with_identity):
|
||||
"""Ctor.
|
||||
|
||||
Args:
|
||||
aechen_ir_database_path: Path to the Aechen Impulse Response database.
|
||||
noise_tracks_path: Path to the noise tracks to add.
|
||||
copy_with_identity: Flag indicating whether the identity generator has to
|
||||
make copies of the clean speech input files.
|
||||
"""
|
||||
self._output_directory_prefix = None
|
||||
self._aechen_ir_database_path = aechen_ir_database_path
|
||||
self._noise_tracks_path = noise_tracks_path
|
||||
self._copy_with_identity = copy_with_identity
|
||||
|
||||
def SetOutputDirectoryPrefix(self, prefix):
|
||||
self._output_directory_prefix = prefix
|
||||
|
||||
def GetInstance(self, test_data_generators_class):
|
||||
"""Creates an TestDataGenerator instance given a class object.
|
||||
|
||||
Args:
|
||||
test_data_generators_class: TestDataGenerator class object (not an
|
||||
instance).
|
||||
|
||||
Returns:
|
||||
TestDataGenerator instance.
|
||||
"""
|
||||
if self._output_directory_prefix is None:
|
||||
raise exceptions.InitializationException(
|
||||
'The output directory prefix for test data generators is not set'
|
||||
)
|
||||
logging.debug('factory producing %s', test_data_generators_class)
|
||||
|
||||
if test_data_generators_class == (
|
||||
test_data_generation.IdentityTestDataGenerator):
|
||||
return test_data_generation.IdentityTestDataGenerator(
|
||||
self._output_directory_prefix, self._copy_with_identity)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.ReverberationTestDataGenerator):
|
||||
return test_data_generation.ReverberationTestDataGenerator(
|
||||
self._output_directory_prefix, self._aechen_ir_database_path)
|
||||
elif test_data_generators_class == (
|
||||
test_data_generation.AdditiveNoiseTestDataGenerator):
|
||||
return test_data_generation.AdditiveNoiseTestDataGenerator(
|
||||
self._output_directory_prefix, self._noise_tracks_path)
|
||||
else:
|
||||
return test_data_generators_class(self._output_directory_prefix)
|
||||
@ -1,207 +0,0 @@
|
||||
# Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
#
|
||||
# Use of this source code is governed by a BSD-style license
|
||||
# that can be found in the LICENSE file in the root of the source
|
||||
# tree. An additional intellectual property rights grant can be found
|
||||
# in the file PATENTS. All contributing project authors may
|
||||
# be found in the AUTHORS file in the root of the source tree.
|
||||
"""Unit tests for the test_data_generation module.
|
||||
"""
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
|
||||
from . import test_data_generation
|
||||
from . import test_data_generation_factory
|
||||
from . import signal_processing
|
||||
|
||||
|
||||
class TestTestDataGenerators(unittest.TestCase):
|
||||
"""Unit tests for the test_data_generation module.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
"""Create temporary folders."""
|
||||
self._base_output_path = tempfile.mkdtemp()
|
||||
self._test_data_cache_path = tempfile.mkdtemp()
|
||||
self._fake_air_db_path = tempfile.mkdtemp()
|
||||
|
||||
# Fake AIR DB impulse responses.
|
||||
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
|
||||
# impulse responses. When changed, the coupling below between
|
||||
# impulse_response_mat_file_names and
|
||||
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
|
||||
impulse_response_mat_file_names = [
|
||||
'air_binaural_lecture_0_0_1.mat',
|
||||
'air_binaural_booth_0_0_1.mat',
|
||||
]
|
||||
for impulse_response_mat_file_name in impulse_response_mat_file_names:
|
||||
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
|
||||
scipy.io.savemat(
|
||||
os.path.join(self._fake_air_db_path,
|
||||
impulse_response_mat_file_name), data)
|
||||
|
||||
def tearDown(self):
|
||||
"""Recursively delete temporary folders."""
|
||||
shutil.rmtree(self._base_output_path)
|
||||
shutil.rmtree(self._test_data_cache_path)
|
||||
shutil.rmtree(self._fake_air_db_path)
|
||||
|
||||
def testTestDataGenerators(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
|
||||
# Check that there is at least one registered test data generator.
|
||||
registered_classes = (
|
||||
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
|
||||
self.assertIsInstance(registered_classes, dict)
|
||||
self.assertGreater(len(registered_classes), 0)
|
||||
|
||||
# Instance generators factory.
|
||||
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path=self._fake_air_db_path,
|
||||
noise_tracks_path=test_data_generation. \
|
||||
AdditiveNoiseTestDataGenerator. \
|
||||
DEFAULT_NOISE_TRACKS_PATH,
|
||||
copy_with_identity=False)
|
||||
generators_factory.SetOutputDirectoryPrefix('datagen-')
|
||||
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
|
||||
'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
# Load input signal.
|
||||
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
input_signal_filepath)
|
||||
|
||||
# Try each registered test data generator.
|
||||
for generator_name in registered_classes:
|
||||
# Instance test data generator.
|
||||
generator = generators_factory.GetInstance(
|
||||
registered_classes[generator_name])
|
||||
|
||||
# Generate the noisy input - reference pairs.
|
||||
generator.Generate(input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
|
||||
# Perform checks.
|
||||
self._CheckGeneratedPairsListSizes(generator)
|
||||
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
|
||||
self._CheckGeneratedPairsOutputPaths(generator)
|
||||
|
||||
def testTestidentityDataGenerator(self):
|
||||
# Preliminary check.
|
||||
self.assertTrue(os.path.exists(self._base_output_path))
|
||||
self.assertTrue(os.path.exists(self._test_data_cache_path))
|
||||
|
||||
# Use a simple input file as clean input signal.
|
||||
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
|
||||
'tone-880.wav')
|
||||
self.assertTrue(os.path.exists(input_signal_filepath))
|
||||
|
||||
def GetNoiseReferenceFilePaths(identity_generator):
|
||||
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
|
||||
reference_signal_filepaths = identity_generator.reference_signal_filepaths
|
||||
assert noisy_signal_filepaths.keys(
|
||||
) == reference_signal_filepaths.keys()
|
||||
assert len(noisy_signal_filepaths.keys()) == 1
|
||||
key = noisy_signal_filepaths.keys()[0]
|
||||
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
|
||||
|
||||
# Test the `copy_with_identity` flag.
|
||||
for copy_with_identity in [False, True]:
|
||||
# Instance the generator through the factory.
|
||||
factory = test_data_generation_factory.TestDataGeneratorFactory(
|
||||
aechen_ir_database_path='',
|
||||
noise_tracks_path='',
|
||||
copy_with_identity=copy_with_identity)
|
||||
factory.SetOutputDirectoryPrefix('datagen-')
|
||||
generator = factory.GetInstance(
|
||||
test_data_generation.IdentityTestDataGenerator)
|
||||
# Check `copy_with_identity` is set correctly.
|
||||
self.assertEqual(copy_with_identity, generator.copy_with_identity)
|
||||
|
||||
# Generate test data and extract the paths to the noise and the reference
|
||||
# files.
|
||||
generator.Generate(input_signal_filepath=input_signal_filepath,
|
||||
test_data_cache_path=self._test_data_cache_path,
|
||||
base_output_path=self._base_output_path)
|
||||
noisy_signal_filepath, reference_signal_filepath = (
|
||||
GetNoiseReferenceFilePaths(generator))
|
||||
|
||||
# Check that a copy is made if and only if `copy_with_identity` is True.
|
||||
if copy_with_identity:
|
||||
self.assertNotEqual(noisy_signal_filepath,
|
||||
input_signal_filepath)
|
||||
self.assertNotEqual(reference_signal_filepath,
|
||||
input_signal_filepath)
|
||||
else:
|
||||
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
|
||||
self.assertEqual(reference_signal_filepath,
|
||||
input_signal_filepath)
|
||||
|
||||
def _CheckGeneratedPairsListSizes(self, generator):
|
||||
config_names = generator.config_names
|
||||
number_of_pairs = len(config_names)
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.noisy_signal_filepaths))
|
||||
self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
|
||||
self.assertEqual(number_of_pairs,
|
||||
len(generator.reference_signal_filepaths))
|
||||
|
||||
def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
|
||||
"""Checks duration of the generated signals.
|
||||
|
||||
Checks that the noisy input and the reference tracks are audio files
|
||||
with duration equal to or greater than that of the input signal.
|
||||
|
||||
Args:
|
||||
generator: TestDataGenerator instance.
|
||||
input_signal: AudioSegment instance.
|
||||
"""
|
||||
input_signal_length = (
|
||||
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
|
||||
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
# Load the noisy input file.
|
||||
noisy_signal_filepath = generator.noisy_signal_filepaths[
|
||||
config_name]
|
||||
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
noisy_signal_filepath)
|
||||
|
||||
# Check noisy input signal length.
|
||||
noisy_signal_length = (signal_processing.SignalProcessingUtils.
|
||||
CountSamples(noisy_signal))
|
||||
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
|
||||
|
||||
# Load the reference file.
|
||||
reference_signal_filepath = generator.reference_signal_filepaths[
|
||||
config_name]
|
||||
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
|
||||
reference_signal_filepath)
|
||||
|
||||
# Check noisy input signal length.
|
||||
reference_signal_length = (signal_processing.SignalProcessingUtils.
|
||||
CountSamples(reference_signal))
|
||||
self.assertGreaterEqual(reference_signal_length,
|
||||
input_signal_length)
|
||||
|
||||
def _CheckGeneratedPairsOutputPaths(self, generator):
|
||||
"""Checks that the output path created by the generator exists.
|
||||
|
||||
Args:
|
||||
generator: TestDataGenerator instance.
|
||||
"""
|
||||
# Iterate over the noisy signal - reference pairs.
|
||||
for config_name in generator.config_names:
|
||||
output_path = generator.apm_output_paths[config_name]
|
||||
self.assertTrue(os.path.exists(output_path))
|
||||
@ -1,103 +0,0 @@
|
||||
// Copyright (c) 2017 The WebRTC project authors. All Rights Reserved.
|
||||
//
|
||||
// Use of this source code is governed by a BSD-style license
|
||||
// that can be found in the LICENSE file in the root of the source
|
||||
// tree. An additional intellectual property rights grant can be found
|
||||
// in the file PATENTS. All contributing project authors may
|
||||
// be found in the AUTHORS file in the root of the source tree.
|
||||
|
||||
#include "common_audio/vad/include/vad.h"
|
||||
|
||||
#include <array>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "common_audio/wav_file.h"
|
||||
#include "rtc_base/logging.h"
|
||||
|
||||
ABSL_FLAG(std::string, i, "", "Input wav file");
|
||||
ABSL_FLAG(std::string, o, "", "VAD output file");
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
namespace {
|
||||
|
||||
// Duration of each audio frame fed to the VAD.
// The allowed values are 10, 20 or 30 ms.
constexpr uint8_t kAudioFrameLengthMilliseconds = 30;
// Maximum supported input sample rate (Hz).
constexpr int kMaxSampleRate = 48000;
// Upper bound on the number of samples per frame; sizes the read buffer.
constexpr size_t kMaxFrameLen =
    kAudioFrameLengthMilliseconds * kMaxSampleRate / 1000;

// Number of VAD decisions packed into each output byte (one bit per frame).
constexpr uint8_t kBitmaskBuffSize = 8;
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
const std::string input_file = absl::GetFlag(FLAGS_i);
|
||||
const std::string output_file = absl::GetFlag(FLAGS_o);
|
||||
// Open wav input file and check properties.
|
||||
WavReader wav_reader(input_file);
|
||||
if (wav_reader.num_channels() != 1) {
|
||||
RTC_LOG(LS_ERROR) << "Only mono wav files supported";
|
||||
return 1;
|
||||
}
|
||||
if (wav_reader.sample_rate() > kMaxSampleRate) {
|
||||
RTC_LOG(LS_ERROR) << "Beyond maximum sample rate (" << kMaxSampleRate
|
||||
<< ")";
|
||||
return 1;
|
||||
}
|
||||
const size_t audio_frame_length = rtc::CheckedDivExact(
|
||||
kAudioFrameLengthMilliseconds * wav_reader.sample_rate(), 1000);
|
||||
if (audio_frame_length > kMaxFrameLen) {
|
||||
RTC_LOG(LS_ERROR) << "The frame size and/or the sample rate are too large.";
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Create output file and write header.
|
||||
std::ofstream out_file(output_file, std::ofstream::binary);
|
||||
const char audio_frame_length_ms = kAudioFrameLengthMilliseconds;
|
||||
out_file.write(&audio_frame_length_ms, 1); // Header.
|
||||
|
||||
// Run VAD and write decisions.
|
||||
std::unique_ptr<Vad> vad = CreateVad(Vad::Aggressiveness::kVadNormal);
|
||||
std::array<int16_t, kMaxFrameLen> samples;
|
||||
char buff = 0; // Buffer to write one bit per frame.
|
||||
uint8_t next = 0; // Points to the next bit to write in `buff`.
|
||||
while (true) {
|
||||
// Process frame.
|
||||
const auto read_samples =
|
||||
wav_reader.ReadSamples(audio_frame_length, samples.data());
|
||||
if (read_samples < audio_frame_length)
|
||||
break;
|
||||
const auto is_speech = vad->VoiceActivity(
|
||||
samples.data(), audio_frame_length, wav_reader.sample_rate());
|
||||
|
||||
// Write output.
|
||||
buff = is_speech ? buff | (1 << next) : buff & ~(1 << next);
|
||||
if (++next == kBitmaskBuffSize) {
|
||||
out_file.write(&buff, 1); // Flush.
|
||||
buff = 0; // Reset.
|
||||
next = 0;
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize.
|
||||
char extra_bits = 0;
|
||||
if (next > 0) {
|
||||
extra_bits = kBitmaskBuffSize - next;
|
||||
out_file.write(&buff, 1); // Flush.
|
||||
}
|
||||
out_file.write(&extra_bits, 1);
|
||||
out_file.close();
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
// Program entry point: all the work happens in webrtc::test::main.
int main(int argc, char* argv[]) {
  return webrtc::test::main(argc, argv);
}
|
||||
Loading…
x
Reference in New Issue
Block a user