Reformat python files checked by pylint (part 1/2).

After recently changing .pylintrc (see [1]) we discovered that
the presubmit check always checks all the python files when just
one python file gets updated.

This CL moves all these files one step closer to what the linter
wants.

Autogenerated with:

# Added all the files under pylint control to ~/Desktop/to-reformat
cat ~/Desktop/to-reformat | xargs sed -i '1i\\'
git cl format --python --full

This is part 1 out of 2. The second part will fix function names and
will not be automated.

[1] - https://webrtc-review.googlesource.com/c/src/+/186664

No-Presubmit: True
Bug: webrtc:12114
Change-Id: Idfec4d759f209a2090440d0af2413a1ddc01b841
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/190980
Commit-Queue: Mirko Bonadei <mbonadei@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32530}
This commit is contained in:
Mirko Bonadei 2020-10-30 10:13:45 +01:00 committed by Commit Bot
parent d3a3e9ef36
commit 8cc6695652
93 changed files with 9936 additions and 9285 deletions

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""
This script is the wrapper that runs the low-bandwidth audio test.
@ -23,315 +22,352 @@ import shutil
import subprocess
import sys
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
NO_TOOLS_ERROR_MESSAGE = (
'Could not find PESQ or POLQA at %s.\n'
'\n'
'To fix this run:\n'
' python %s %s\n'
'\n'
'Note that these tools are Google-internal due to licensing, so in order to '
'use them you will have to get your own license and manually put them in the '
'right location.\n'
'See https://cs.chromium.org/chromium/src/third_party/webrtc/tools_webrtc/'
'download_tools.py?rcl=bbceb76f540159e2dba0701ac03c514f01624130&l=13')
'Could not find PESQ or POLQA at %s.\n'
'\n'
'To fix this run:\n'
' python %s %s\n'
'\n'
'Note that these tools are Google-internal due to licensing, so in order to '
'use them you will have to get your own license and manually put them in the '
'right location.\n'
'See https://cs.chromium.org/chromium/src/third_party/webrtc/tools_webrtc/'
'download_tools.py?rcl=bbceb76f540159e2dba0701ac03c514f01624130&l=13')
def _LogCommand(command):
logging.info('Running %r', command)
return command
logging.info('Running %r', command)
return command
def _ParseArgs():
parser = argparse.ArgumentParser(description='Run low-bandwidth audio tests.')
parser.add_argument('build_dir',
help='Path to the build directory (e.g. out/Release).')
parser.add_argument('--remove', action='store_true',
help='Remove output audio files after testing.')
parser.add_argument('--android', action='store_true',
help='Perform the test on a connected Android device instead.')
parser.add_argument('--adb-path', help='Path to adb binary.', default='adb')
parser.add_argument('--num-retries', default='0',
help='Number of times to retry the test on Android.')
parser.add_argument('--isolated-script-test-perf-output', default=None,
help='Path to store perf results in histogram proto format.')
parser.add_argument('--extra-test-args', default=[], action='append',
help='Extra args to path to the test binary.')
parser = argparse.ArgumentParser(
description='Run low-bandwidth audio tests.')
parser.add_argument('build_dir',
help='Path to the build directory (e.g. out/Release).')
parser.add_argument('--remove',
action='store_true',
help='Remove output audio files after testing.')
parser.add_argument(
'--android',
action='store_true',
help='Perform the test on a connected Android device instead.')
parser.add_argument('--adb-path',
help='Path to adb binary.',
default='adb')
parser.add_argument('--num-retries',
default='0',
help='Number of times to retry the test on Android.')
parser.add_argument(
'--isolated-script-test-perf-output',
default=None,
help='Path to store perf results in histogram proto format.')
parser.add_argument('--extra-test-args',
default=[],
action='append',
help='Extra args to path to the test binary.')
# Ignore Chromium-specific flags
parser.add_argument('--test-launcher-summary-output',
type=str, default=None)
args = parser.parse_args()
# Ignore Chromium-specific flags
parser.add_argument('--test-launcher-summary-output',
type=str,
default=None)
args = parser.parse_args()
return args
return args
def _GetPlatform():
if sys.platform == 'win32':
return 'win'
elif sys.platform == 'darwin':
return 'mac'
elif sys.platform.startswith('linux'):
return 'linux'
if sys.platform == 'win32':
return 'win'
elif sys.platform == 'darwin':
return 'mac'
elif sys.platform.startswith('linux'):
return 'linux'
def _GetExtension():
return '.exe' if sys.platform == 'win32' else ''
return '.exe' if sys.platform == 'win32' else ''
def _GetPathToTools():
tools_dir = os.path.join(SRC_DIR, 'tools_webrtc')
toolchain_dir = os.path.join(tools_dir, 'audio_quality')
tools_dir = os.path.join(SRC_DIR, 'tools_webrtc')
toolchain_dir = os.path.join(tools_dir, 'audio_quality')
platform = _GetPlatform()
ext = _GetExtension()
platform = _GetPlatform()
ext = _GetExtension()
pesq_path = os.path.join(toolchain_dir, platform, 'pesq' + ext)
if not os.path.isfile(pesq_path):
pesq_path = None
pesq_path = os.path.join(toolchain_dir, platform, 'pesq' + ext)
if not os.path.isfile(pesq_path):
pesq_path = None
polqa_path = os.path.join(toolchain_dir, platform, 'PolqaOem64' + ext)
if not os.path.isfile(polqa_path):
polqa_path = None
polqa_path = os.path.join(toolchain_dir, platform, 'PolqaOem64' + ext)
if not os.path.isfile(polqa_path):
polqa_path = None
if (platform != 'mac' and not polqa_path) or not pesq_path:
logging.error(NO_TOOLS_ERROR_MESSAGE,
toolchain_dir,
os.path.join(tools_dir, 'download_tools.py'),
toolchain_dir)
if (platform != 'mac' and not polqa_path) or not pesq_path:
logging.error(NO_TOOLS_ERROR_MESSAGE, toolchain_dir,
os.path.join(tools_dir, 'download_tools.py'),
toolchain_dir)
return pesq_path, polqa_path
return pesq_path, polqa_path
def ExtractTestRuns(lines, echo=False):
"""Extracts information about tests from the output of a test runner.
"""Extracts information about tests from the output of a test runner.
Produces tuples
(android_device, test_name, reference_file, degraded_file, cur_perf_results).
"""
for line in lines:
if echo:
sys.stdout.write(line)
for line in lines:
if echo:
sys.stdout.write(line)
# Output from Android has a prefix with the device name.
android_prefix_re = r'(?:I\b.+\brun_tests_on_device\((.+?)\)\s*)?'
test_re = r'^' + android_prefix_re + (r'TEST (\w+) ([^ ]+?) ([^\s]+)'
r' ?([^\s]+)?\s*$')
# Output from Android has a prefix with the device name.
android_prefix_re = r'(?:I\b.+\brun_tests_on_device\((.+?)\)\s*)?'
test_re = r'^' + android_prefix_re + (r'TEST (\w+) ([^ ]+?) ([^\s]+)'
r' ?([^\s]+)?\s*$')
match = re.search(test_re, line)
if match:
yield match.groups()
match = re.search(test_re, line)
if match:
yield match.groups()
def _GetFile(file_path, out_dir, move=False,
android=False, adb_prefix=('adb',)):
out_file_name = os.path.basename(file_path)
out_file_path = os.path.join(out_dir, out_file_name)
def _GetFile(file_path,
out_dir,
move=False,
android=False,
adb_prefix=('adb', )):
out_file_name = os.path.basename(file_path)
out_file_path = os.path.join(out_dir, out_file_name)
if android:
# Pull the file from the connected Android device.
adb_command = adb_prefix + ('pull', file_path, out_dir)
subprocess.check_call(_LogCommand(adb_command))
if move:
# Remove that file.
adb_command = adb_prefix + ('shell', 'rm', file_path)
subprocess.check_call(_LogCommand(adb_command))
elif os.path.abspath(file_path) != os.path.abspath(out_file_path):
if move:
shutil.move(file_path, out_file_path)
else:
shutil.copy(file_path, out_file_path)
if android:
# Pull the file from the connected Android device.
adb_command = adb_prefix + ('pull', file_path, out_dir)
subprocess.check_call(_LogCommand(adb_command))
if move:
# Remove that file.
adb_command = adb_prefix + ('shell', 'rm', file_path)
subprocess.check_call(_LogCommand(adb_command))
elif os.path.abspath(file_path) != os.path.abspath(out_file_path):
if move:
shutil.move(file_path, out_file_path)
else:
shutil.copy(file_path, out_file_path)
return out_file_path
return out_file_path
def _RunPesq(executable_path, reference_file, degraded_file,
def _RunPesq(executable_path,
reference_file,
degraded_file,
sample_rate_hz=16000):
directory = os.path.dirname(reference_file)
assert os.path.dirname(degraded_file) == directory
directory = os.path.dirname(reference_file)
assert os.path.dirname(degraded_file) == directory
# Analyze audio.
command = [executable_path, '+%d' % sample_rate_hz,
os.path.basename(reference_file),
os.path.basename(degraded_file)]
# Need to provide paths in the current directory due to a bug in PESQ:
# On Mac, for some 'path/to/file.wav', if 'file.wav' is longer than
# 'path/to', PESQ crashes.
out = subprocess.check_output(_LogCommand(command),
cwd=directory, stderr=subprocess.STDOUT)
# Analyze audio.
command = [
executable_path,
'+%d' % sample_rate_hz,
os.path.basename(reference_file),
os.path.basename(degraded_file)
]
# Need to provide paths in the current directory due to a bug in PESQ:
# On Mac, for some 'path/to/file.wav', if 'file.wav' is longer than
# 'path/to', PESQ crashes.
out = subprocess.check_output(_LogCommand(command),
cwd=directory,
stderr=subprocess.STDOUT)
# Find the scores in stdout of PESQ.
match = re.search(
r'Prediction \(Raw MOS, MOS-LQO\):\s+=\s+([\d.]+)\s+([\d.]+)', out)
if match:
raw_mos, _ = match.groups()
# Find the scores in stdout of PESQ.
match = re.search(
r'Prediction \(Raw MOS, MOS-LQO\):\s+=\s+([\d.]+)\s+([\d.]+)', out)
if match:
raw_mos, _ = match.groups()
return {'pesq_mos': (raw_mos, 'unitless')}
else:
logging.error('PESQ: %s', out.splitlines()[-1])
return {}
return {'pesq_mos': (raw_mos, 'unitless')}
else:
logging.error('PESQ: %s', out.splitlines()[-1])
return {}
def _RunPolqa(executable_path, reference_file, degraded_file):
# Analyze audio.
command = [executable_path, '-q', '-LC', 'NB',
'-Ref', reference_file, '-Test', degraded_file]
process = subprocess.Popen(_LogCommand(command),
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, err = process.communicate()
# Analyze audio.
command = [
executable_path, '-q', '-LC', 'NB', '-Ref', reference_file, '-Test',
degraded_file
]
process = subprocess.Popen(_LogCommand(command),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
out, err = process.communicate()
# Find the scores in stdout of POLQA.
match = re.search(r'\bMOS-LQO:\s+([\d.]+)', out)
# Find the scores in stdout of POLQA.
match = re.search(r'\bMOS-LQO:\s+([\d.]+)', out)
if process.returncode != 0 or not match:
if process.returncode == 2:
logging.warning('%s (2)', err.strip())
logging.warning('POLQA license error, skipping test.')
else:
logging.error('%s (%d)', err.strip(), process.returncode)
return {}
if process.returncode != 0 or not match:
if process.returncode == 2:
logging.warning('%s (2)', err.strip())
logging.warning('POLQA license error, skipping test.')
else:
logging.error('%s (%d)', err.strip(), process.returncode)
return {}
mos_lqo, = match.groups()
return {'polqa_mos_lqo': (mos_lqo, 'unitless')}
mos_lqo, = match.groups()
return {'polqa_mos_lqo': (mos_lqo, 'unitless')}
def _MergeInPerfResultsFromCcTests(histograms, run_perf_results_file):
from tracing.value import histogram_set
from tracing.value import histogram_set
cc_histograms = histogram_set.HistogramSet()
with open(run_perf_results_file, 'rb') as f:
contents = f.read()
if not contents:
return
cc_histograms = histogram_set.HistogramSet()
with open(run_perf_results_file, 'rb') as f:
contents = f.read()
if not contents:
return
cc_histograms.ImportProto(contents)
cc_histograms.ImportProto(contents)
histograms.Merge(cc_histograms)
histograms.Merge(cc_histograms)
Analyzer = collections.namedtuple('Analyzer', ['name', 'func', 'executable',
'sample_rate_hz'])
Analyzer = collections.namedtuple(
'Analyzer', ['name', 'func', 'executable', 'sample_rate_hz'])
def _ConfigurePythonPath(args):
script_dir = os.path.dirname(os.path.realpath(__file__))
checkout_root = os.path.abspath(
os.path.join(script_dir, os.pardir, os.pardir))
script_dir = os.path.dirname(os.path.realpath(__file__))
checkout_root = os.path.abspath(
os.path.join(script_dir, os.pardir, os.pardir))
# TODO(https://crbug.com/1029452): Use a copy rule and add these from the out
# dir like for the third_party/protobuf code.
sys.path.insert(0, os.path.join(checkout_root, 'third_party', 'catapult',
'tracing'))
# TODO(https://crbug.com/1029452): Use a copy rule and add these from the out
# dir like for the third_party/protobuf code.
sys.path.insert(
0, os.path.join(checkout_root, 'third_party', 'catapult', 'tracing'))
# The low_bandwidth_audio_perf_test gn rule will build the protobuf stub for
# python, so put it in the path for this script before we attempt to import
# it.
histogram_proto_path = os.path.join(
os.path.abspath(args.build_dir), 'pyproto', 'tracing', 'tracing', 'proto')
sys.path.insert(0, histogram_proto_path)
proto_stub_path = os.path.join(os.path.abspath(args.build_dir), 'pyproto')
sys.path.insert(0, proto_stub_path)
# The low_bandwidth_audio_perf_test gn rule will build the protobuf stub for
# python, so put it in the path for this script before we attempt to import
# it.
histogram_proto_path = os.path.join(os.path.abspath(args.build_dir),
'pyproto', 'tracing', 'tracing',
'proto')
sys.path.insert(0, histogram_proto_path)
proto_stub_path = os.path.join(os.path.abspath(args.build_dir), 'pyproto')
sys.path.insert(0, proto_stub_path)
# Fail early in case the proto hasn't been built.
try:
import histogram_pb2
except ImportError as e:
logging.exception(e)
raise ImportError('Could not import histogram_pb2. You need to build the '
'low_bandwidth_audio_perf_test target before invoking '
'this script. Expected to find '
'histogram_pb2.py in %s.' % histogram_proto_path)
# Fail early in case the proto hasn't been built.
try:
import histogram_pb2
except ImportError as e:
logging.exception(e)
raise ImportError(
'Could not import histogram_pb2. You need to build the '
'low_bandwidth_audio_perf_test target before invoking '
'this script. Expected to find '
'histogram_pb2.py in %s.' % histogram_proto_path)
def main():
# pylint: disable=W0101
logging.basicConfig(level=logging.INFO)
logging.info('Invoked with %s', str(sys.argv))
# pylint: disable=W0101
logging.basicConfig(level=logging.INFO)
logging.info('Invoked with %s', str(sys.argv))
args = _ParseArgs()
args = _ParseArgs()
_ConfigurePythonPath(args)
_ConfigurePythonPath(args)
# Import catapult modules here after configuring the pythonpath.
from tracing.value import histogram_set
from tracing.value.diagnostics import reserved_infos
from tracing.value.diagnostics import generic_set
# Import catapult modules here after configuring the pythonpath.
from tracing.value import histogram_set
from tracing.value.diagnostics import reserved_infos
from tracing.value.diagnostics import generic_set
pesq_path, polqa_path = _GetPathToTools()
if pesq_path is None:
return 1
pesq_path, polqa_path = _GetPathToTools()
if pesq_path is None:
return 1
out_dir = os.path.join(args.build_dir, '..')
if args.android:
test_command = [os.path.join(args.build_dir, 'bin',
'run_low_bandwidth_audio_test'),
'-v', '--num-retries', args.num_retries]
else:
test_command = [os.path.join(args.build_dir, 'low_bandwidth_audio_test')]
out_dir = os.path.join(args.build_dir, '..')
if args.android:
test_command = [
os.path.join(args.build_dir, 'bin',
'run_low_bandwidth_audio_test'), '-v',
'--num-retries', args.num_retries
]
else:
test_command = [
os.path.join(args.build_dir, 'low_bandwidth_audio_test')
]
analyzers = [Analyzer('pesq', _RunPesq, pesq_path, 16000)]
# Check if POLQA can run at all, or skip the 48 kHz tests entirely.
example_path = os.path.join(SRC_DIR, 'resources',
'voice_engine', 'audio_tiny48.wav')
if polqa_path and _RunPolqa(polqa_path, example_path, example_path):
analyzers.append(Analyzer('polqa', _RunPolqa, polqa_path, 48000))
analyzers = [Analyzer('pesq', _RunPesq, pesq_path, 16000)]
# Check if POLQA can run at all, or skip the 48 kHz tests entirely.
example_path = os.path.join(SRC_DIR, 'resources', 'voice_engine',
'audio_tiny48.wav')
if polqa_path and _RunPolqa(polqa_path, example_path, example_path):
analyzers.append(Analyzer('polqa', _RunPolqa, polqa_path, 48000))
histograms = histogram_set.HistogramSet()
for analyzer in analyzers:
# Start the test executable that produces audio files.
test_process = subprocess.Popen(
_LogCommand(test_command + [
histograms = histogram_set.HistogramSet()
for analyzer in analyzers:
# Start the test executable that produces audio files.
test_process = subprocess.Popen(_LogCommand(test_command + [
'--sample_rate_hz=%d' % analyzer.sample_rate_hz,
'--test_case_prefix=%s' % analyzer.name,
] + args.extra_test_args),
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
perf_results_file = None
try:
lines = iter(test_process.stdout.readline, '')
for result in ExtractTestRuns(lines, echo=True):
(android_device, test_name, reference_file, degraded_file,
perf_results_file) = result
] + args.extra_test_args),
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
perf_results_file = None
try:
lines = iter(test_process.stdout.readline, '')
for result in ExtractTestRuns(lines, echo=True):
(android_device, test_name, reference_file, degraded_file,
perf_results_file) = result
adb_prefix = (args.adb_path,)
if android_device:
adb_prefix += ('-s', android_device)
adb_prefix = (args.adb_path, )
if android_device:
adb_prefix += ('-s', android_device)
reference_file = _GetFile(reference_file, out_dir,
android=args.android, adb_prefix=adb_prefix)
degraded_file = _GetFile(degraded_file, out_dir, move=True,
android=args.android, adb_prefix=adb_prefix)
reference_file = _GetFile(reference_file,
out_dir,
android=args.android,
adb_prefix=adb_prefix)
degraded_file = _GetFile(degraded_file,
out_dir,
move=True,
android=args.android,
adb_prefix=adb_prefix)
analyzer_results = analyzer.func(analyzer.executable,
reference_file, degraded_file)
for metric, (value, units) in analyzer_results.items():
hist = histograms.CreateHistogram(metric, units, [value])
user_story = generic_set.GenericSet([test_name])
hist.diagnostics[reserved_infos.STORIES.name] = user_story
analyzer_results = analyzer.func(analyzer.executable,
reference_file, degraded_file)
for metric, (value, units) in analyzer_results.items():
hist = histograms.CreateHistogram(metric, units, [value])
user_story = generic_set.GenericSet([test_name])
hist.diagnostics[reserved_infos.STORIES.name] = user_story
# Output human readable results.
print 'RESULT %s: %s= %s %s' % (metric, test_name, value, units)
# Output human readable results.
print 'RESULT %s: %s= %s %s' % (metric, test_name, value,
units)
if args.remove:
os.remove(reference_file)
os.remove(degraded_file)
finally:
test_process.terminate()
if perf_results_file:
perf_results_file = _GetFile(perf_results_file, out_dir, move=True,
android=args.android, adb_prefix=adb_prefix)
_MergeInPerfResultsFromCcTests(histograms, perf_results_file)
if args.remove:
os.remove(perf_results_file)
if args.remove:
os.remove(reference_file)
os.remove(degraded_file)
finally:
test_process.terminate()
if perf_results_file:
perf_results_file = _GetFile(perf_results_file,
out_dir,
move=True,
android=args.android,
adb_prefix=adb_prefix)
_MergeInPerfResultsFromCcTests(histograms, perf_results_file)
if args.remove:
os.remove(perf_results_file)
if args.isolated_script_test_perf_output:
with open(args.isolated_script_test_perf_output, 'wb') as f:
f.write(histograms.AsProto().SerializeToString())
if args.isolated_script_test_perf_output:
with open(args.isolated_script_test_perf_output, 'wb') as f:
f.write(histograms.AsProto().SerializeToString())
return test_process.wait()
return test_process.wait()
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -11,7 +11,6 @@ import os
import unittest
import sys
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.join(SCRIPT_DIR, os.pardir)
sys.path.append(PARENT_DIR)
@ -19,46 +18,51 @@ import low_bandwidth_audio_test
class TestExtractTestRuns(unittest.TestCase):
def _TestLog(self, log, *expected):
self.assertEqual(
tuple(low_bandwidth_audio_test.ExtractTestRuns(log.splitlines(True))),
expected)
def _TestLog(self, log, *expected):
self.assertEqual(
tuple(
low_bandwidth_audio_test.ExtractTestRuns(
log.splitlines(True))), expected)
def testLinux(self):
self._TestLog(LINUX_LOG,
(None, 'GoodNetworkHighBitrate',
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
'/webrtc/src/out/LowBandwidth_GoodNetworkHighBitrate.wav', None),
(None, 'Mobile2GNetwork',
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
'/webrtc/src/out/LowBandwidth_Mobile2GNetwork.wav', None),
(None, 'PCGoodNetworkHighBitrate',
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
'/webrtc/src/out/PCLowBandwidth_PCGoodNetworkHighBitrate.wav',
'/webrtc/src/out/PCLowBandwidth_perf_48.json'),
(None, 'PCMobile2GNetwork',
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
'/webrtc/src/out/PCLowBandwidth_PCMobile2GNetwork.wav',
'/webrtc/src/out/PCLowBandwidth_perf_48.json'))
def testLinux(self):
self._TestLog(
LINUX_LOG,
(None, 'GoodNetworkHighBitrate',
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
'/webrtc/src/out/LowBandwidth_GoodNetworkHighBitrate.wav', None),
(None, 'Mobile2GNetwork',
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
'/webrtc/src/out/LowBandwidth_Mobile2GNetwork.wav', None),
(None, 'PCGoodNetworkHighBitrate',
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
'/webrtc/src/out/PCLowBandwidth_PCGoodNetworkHighBitrate.wav',
'/webrtc/src/out/PCLowBandwidth_perf_48.json'),
(None, 'PCMobile2GNetwork',
'/webrtc/src/resources/voice_engine/audio_tiny16.wav',
'/webrtc/src/out/PCLowBandwidth_PCMobile2GNetwork.wav',
'/webrtc/src/out/PCLowBandwidth_perf_48.json'))
def testAndroid(self):
self._TestLog(ANDROID_LOG,
('ddfa6149', 'Mobile2GNetwork',
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
'/sdcard/chromium_tests_root/LowBandwidth_Mobile2GNetwork.wav', None),
('TA99205CNO', 'GoodNetworkHighBitrate',
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
'/sdcard/chromium_tests_root/LowBandwidth_GoodNetworkHighBitrate.wav',
None),
('ddfa6149', 'PCMobile2GNetwork',
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
'/sdcard/chromium_tests_root/PCLowBandwidth_PCMobile2GNetwork.wav',
'/sdcard/chromium_tests_root/PCLowBandwidth_perf_48.json'),
('TA99205CNO', 'PCGoodNetworkHighBitrate',
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
('/sdcard/chromium_tests_root/'
'PCLowBandwidth_PCGoodNetworkHighBitrate.wav'),
'/sdcard/chromium_tests_root/PCLowBandwidth_perf_48.json'))
def testAndroid(self):
self._TestLog(ANDROID_LOG, (
'ddfa6149', 'Mobile2GNetwork',
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
'/sdcard/chromium_tests_root/LowBandwidth_Mobile2GNetwork.wav',
None
), (
'TA99205CNO', 'GoodNetworkHighBitrate',
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
'/sdcard/chromium_tests_root/LowBandwidth_GoodNetworkHighBitrate.wav',
None
), (
'ddfa6149', 'PCMobile2GNetwork',
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
'/sdcard/chromium_tests_root/PCLowBandwidth_PCMobile2GNetwork.wav',
'/sdcard/chromium_tests_root/PCLowBandwidth_perf_48.json'
), ('TA99205CNO', 'PCGoodNetworkHighBitrate',
'/sdcard/chromium_tests_root/resources/voice_engine/audio_tiny16.wav',
('/sdcard/chromium_tests_root/'
'PCLowBandwidth_PCGoodNetworkHighBitrate.wav'),
'/sdcard/chromium_tests_root/PCLowBandwidth_perf_48.json'))
LINUX_LOG = r'''\
@ -233,6 +237,5 @@ I 16.608s tear_down_device(ddfa6149) Wrote device cache: /webrtc/src/out/debu
I 16.608s tear_down_device(TA99205CNO) Wrote device cache: /webrtc/src/out/debug-android/device_cache_TA99305CMO.json
'''
if __name__ == "__main__":
unittest.main()
unittest.main()

View File

@ -15,110 +15,113 @@ import time
from com.android.monkeyrunner import MonkeyRunner, MonkeyDevice
def main():
parser = OptionParser()
parser = OptionParser()
parser.add_option('--devname', dest='devname', help='The device id')
parser.add_option('--devname', dest='devname', help='The device id')
parser.add_option(
'--videooutsave',
dest='videooutsave',
help='The path where to save the video out file on local computer')
parser.add_option(
'--videooutsave',
dest='videooutsave',
help='The path where to save the video out file on local computer')
parser.add_option(
'--videoout',
dest='videoout',
help='The path where to put the video out file')
parser.add_option('--videoout',
dest='videoout',
help='The path where to put the video out file')
parser.add_option(
'--videoout_width',
dest='videoout_width',
type='int',
help='The width for the video out file')
parser.add_option('--videoout_width',
dest='videoout_width',
type='int',
help='The width for the video out file')
parser.add_option(
'--videoout_height',
dest='videoout_height',
type='int',
help='The height for the video out file')
parser.add_option('--videoout_height',
dest='videoout_height',
type='int',
help='The height for the video out file')
parser.add_option(
'--videoin',
dest='videoin',
help='The path where to read input file instead of camera')
parser.add_option(
'--videoin',
dest='videoin',
help='The path where to read input file instead of camera')
parser.add_option(
'--call_length',
dest='call_length',
type='int',
help='The length of the call')
parser.add_option('--call_length',
dest='call_length',
type='int',
help='The length of the call')
(options, args) = parser.parse_args()
(options, args) = parser.parse_args()
print (options, args)
print(options, args)
devname = options.devname
devname = options.devname
videoin = options.videoin
videoin = options.videoin
videoout = options.videoout
videoout_width = options.videoout_width
videoout_height = options.videoout_height
videoout = options.videoout
videoout_width = options.videoout_width
videoout_height = options.videoout_height
videooutsave = options.videooutsave
videooutsave = options.videooutsave
call_length = options.call_length or 10
call_length = options.call_length or 10
room = ''.join(random.choice(string.ascii_letters + string.digits)
for _ in range(8))
room = ''.join(
random.choice(string.ascii_letters + string.digits) for _ in range(8))
# Delete output video file.
if videoout:
subprocess.check_call(['adb', '-s', devname, 'shell', 'rm', '-f',
videoout])
# Delete output video file.
if videoout:
subprocess.check_call(
['adb', '-s', devname, 'shell', 'rm', '-f', videoout])
device = MonkeyRunner.waitForConnection(2, devname)
device = MonkeyRunner.waitForConnection(2, devname)
extras = {
'org.appspot.apprtc.USE_VALUES_FROM_INTENT': True,
'org.appspot.apprtc.AUDIOCODEC': 'OPUS',
'org.appspot.apprtc.LOOPBACK': True,
'org.appspot.apprtc.VIDEOCODEC': 'VP8',
'org.appspot.apprtc.CAPTURETOTEXTURE': False,
'org.appspot.apprtc.CAMERA2': False,
'org.appspot.apprtc.ROOMID': room}
extras = {
'org.appspot.apprtc.USE_VALUES_FROM_INTENT': True,
'org.appspot.apprtc.AUDIOCODEC': 'OPUS',
'org.appspot.apprtc.LOOPBACK': True,
'org.appspot.apprtc.VIDEOCODEC': 'VP8',
'org.appspot.apprtc.CAPTURETOTEXTURE': False,
'org.appspot.apprtc.CAMERA2': False,
'org.appspot.apprtc.ROOMID': room
}
if videoin:
extras.update({'org.appspot.apprtc.VIDEO_FILE_AS_CAMERA': videoin})
if videoin:
extras.update({'org.appspot.apprtc.VIDEO_FILE_AS_CAMERA': videoin})
if videoout:
extras.update({
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE': videoout,
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE_WIDTH': videoout_width,
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE_HEIGHT': videoout_height})
if videoout:
extras.update({
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE':
videoout,
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE_WIDTH':
videoout_width,
'org.appspot.apprtc.SAVE_REMOTE_VIDEO_TO_FILE_HEIGHT':
videoout_height
})
print extras
print extras
device.startActivity(data='https://appr.tc',
action='android.intent.action.VIEW',
component='org.appspot.apprtc/.ConnectActivity', extras=extras)
device.startActivity(data='https://appr.tc',
action='android.intent.action.VIEW',
component='org.appspot.apprtc/.ConnectActivity',
extras=extras)
print 'Running a call for %d seconds' % call_length
for _ in xrange(call_length):
sys.stdout.write('.')
sys.stdout.flush()
time.sleep(1)
print '\nEnding call.'
print 'Running a call for %d seconds' % call_length
for _ in xrange(call_length):
sys.stdout.write('.')
sys.stdout.flush()
time.sleep(1)
print '\nEnding call.'
# Press back to end the call. Will end on both sides.
device.press('KEYCODE_BACK', MonkeyDevice.DOWN_AND_UP)
# Press back to end the call. Will end on both sides.
device.press('KEYCODE_BACK', MonkeyDevice.DOWN_AND_UP)
if videooutsave:
time.sleep(2)
if videooutsave:
time.sleep(2)
subprocess.check_call(
['adb', '-s', devname, 'pull', videoout, videooutsave])
subprocess.check_call(['adb', '-s', devname, 'pull',
videoout, videooutsave])
if __name__ == '__main__':
main()
main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""
This scripts tests creating an Android Studio project using the
generate_gradle.py script and making a debug build using it.
@ -23,58 +22,59 @@ import subprocess
import sys
import tempfile
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
GENERATE_GRADLE_SCRIPT = os.path.join(SRC_DIR,
'build/android/gradle/generate_gradle.py')
GENERATE_GRADLE_SCRIPT = os.path.join(
SRC_DIR, 'build/android/gradle/generate_gradle.py')
GRADLEW_BIN = os.path.join(SCRIPT_DIR, 'third_party/gradle/gradlew')
def _RunCommand(argv, cwd=SRC_DIR, **kwargs):
logging.info('Running %r', argv)
subprocess.check_call(argv, cwd=cwd, **kwargs)
logging.info('Running %r', argv)
subprocess.check_call(argv, cwd=cwd, **kwargs)
def _ParseArgs():
parser = argparse.ArgumentParser(
description='Test generating Android gradle project.')
parser.add_argument('build_dir_android',
help='The path to the build directory for Android.')
parser.add_argument('--project_dir',
help='A temporary directory to put the output.')
parser = argparse.ArgumentParser(
description='Test generating Android gradle project.')
parser.add_argument('build_dir_android',
help='The path to the build directory for Android.')
parser.add_argument('--project_dir',
help='A temporary directory to put the output.')
args = parser.parse_args()
return args
args = parser.parse_args()
return args
def main():
logging.basicConfig(level=logging.INFO)
args = _ParseArgs()
logging.basicConfig(level=logging.INFO)
args = _ParseArgs()
project_dir = args.project_dir
if not project_dir:
project_dir = tempfile.mkdtemp()
project_dir = args.project_dir
if not project_dir:
project_dir = tempfile.mkdtemp()
output_dir = os.path.abspath(args.build_dir_android)
project_dir = os.path.abspath(project_dir)
output_dir = os.path.abspath(args.build_dir_android)
project_dir = os.path.abspath(project_dir)
try:
env = os.environ.copy()
env['PATH'] = os.pathsep.join([
os.path.join(SRC_DIR, 'third_party', 'depot_tools'), env.get('PATH', '')
])
_RunCommand([GENERATE_GRADLE_SCRIPT, '--output-directory', output_dir,
'--target', '//examples:AppRTCMobile',
'--project-dir', project_dir,
'--use-gradle-process-resources', '--split-projects'],
env=env)
_RunCommand([GRADLEW_BIN, 'assembleDebug'], project_dir)
finally:
# Do not delete temporary directory if user specified it manually.
if not args.project_dir:
shutil.rmtree(project_dir, True)
try:
env = os.environ.copy()
env['PATH'] = os.pathsep.join([
os.path.join(SRC_DIR, 'third_party', 'depot_tools'),
env.get('PATH', '')
])
_RunCommand([
GENERATE_GRADLE_SCRIPT, '--output-directory', output_dir,
'--target', '//examples:AppRTCMobile', '--project-dir',
project_dir, '--use-gradle-process-resources', '--split-projects'
],
env=env)
_RunCommand([GRADLEW_BIN, 'assembleDebug'], project_dir)
finally:
# Do not delete temporary directory if user specified it manually.
if not args.project_dir:
shutil.rmtree(project_dir, True)
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -24,124 +24,126 @@ import debug_dump_pb2
def GetNextMessageSize(file_to_parse):
data = file_to_parse.read(4)
if data == '':
return 0
return struct.unpack('<I', data)[0]
data = file_to_parse.read(4)
if data == '':
return 0
return struct.unpack('<I', data)[0]
def GetNextMessageFromFile(file_to_parse):
message_size = GetNextMessageSize(file_to_parse)
if message_size == 0:
return None
try:
event = debug_dump_pb2.Event()
event.ParseFromString(file_to_parse.read(message_size))
except IOError:
print 'Invalid message in file'
return None
return event
message_size = GetNextMessageSize(file_to_parse)
if message_size == 0:
return None
try:
event = debug_dump_pb2.Event()
event.ParseFromString(file_to_parse.read(message_size))
except IOError:
print 'Invalid message in file'
return None
return event
def InitMetrics():
metrics = {}
event = debug_dump_pb2.Event()
for metric in event.network_metrics.DESCRIPTOR.fields:
metrics[metric.name] = {'time': [], 'value': []}
return metrics
metrics = {}
event = debug_dump_pb2.Event()
for metric in event.network_metrics.DESCRIPTOR.fields:
metrics[metric.name] = {'time': [], 'value': []}
return metrics
def InitDecisions():
decisions = {}
event = debug_dump_pb2.Event()
for decision in event.encoder_runtime_config.DESCRIPTOR.fields:
decisions[decision.name] = {'time': [], 'value': []}
return decisions
decisions = {}
event = debug_dump_pb2.Event()
for decision in event.encoder_runtime_config.DESCRIPTOR.fields:
decisions[decision.name] = {'time': [], 'value': []}
return decisions
def ParseAnaDump(dump_file_to_parse):
with open(dump_file_to_parse, 'rb') as file_to_parse:
metrics = InitMetrics()
decisions = InitDecisions()
first_time_stamp = None
while True:
event = GetNextMessageFromFile(file_to_parse)
if event is None:
break
if first_time_stamp is None:
first_time_stamp = event.timestamp
if event.type == debug_dump_pb2.Event.ENCODER_RUNTIME_CONFIG:
for decision in event.encoder_runtime_config.DESCRIPTOR.fields:
if event.encoder_runtime_config.HasField(decision.name):
decisions[decision.name]['time'].append(event.timestamp -
first_time_stamp)
decisions[decision.name]['value'].append(
getattr(event.encoder_runtime_config, decision.name))
if event.type == debug_dump_pb2.Event.NETWORK_METRICS:
for metric in event.network_metrics.DESCRIPTOR.fields:
if event.network_metrics.HasField(metric.name):
metrics[metric.name]['time'].append(event.timestamp -
first_time_stamp)
metrics[metric.name]['value'].append(
getattr(event.network_metrics, metric.name))
return (metrics, decisions)
with open(dump_file_to_parse, 'rb') as file_to_parse:
metrics = InitMetrics()
decisions = InitDecisions()
first_time_stamp = None
while True:
event = GetNextMessageFromFile(file_to_parse)
if event is None:
break
if first_time_stamp is None:
first_time_stamp = event.timestamp
if event.type == debug_dump_pb2.Event.ENCODER_RUNTIME_CONFIG:
for decision in event.encoder_runtime_config.DESCRIPTOR.fields:
if event.encoder_runtime_config.HasField(decision.name):
decisions[decision.name]['time'].append(
event.timestamp - first_time_stamp)
decisions[decision.name]['value'].append(
getattr(event.encoder_runtime_config,
decision.name))
if event.type == debug_dump_pb2.Event.NETWORK_METRICS:
for metric in event.network_metrics.DESCRIPTOR.fields:
if event.network_metrics.HasField(metric.name):
metrics[metric.name]['time'].append(event.timestamp -
first_time_stamp)
metrics[metric.name]['value'].append(
getattr(event.network_metrics, metric.name))
return (metrics, decisions)
def main():
parser = OptionParser()
parser.add_option(
"-f", "--dump_file", dest="dump_file_to_parse", help="dump file to parse")
parser.add_option(
'-m',
'--metric_plot',
default=[],
type=str,
help='metric key (name of the metric) to plot',
dest='metric_keys',
action='append')
parser = OptionParser()
parser.add_option("-f",
"--dump_file",
dest="dump_file_to_parse",
help="dump file to parse")
parser.add_option('-m',
'--metric_plot',
default=[],
type=str,
help='metric key (name of the metric) to plot',
dest='metric_keys',
action='append')
parser.add_option(
'-d',
'--decision_plot',
default=[],
type=str,
help='decision key (name of the decision) to plot',
dest='decision_keys',
action='append')
parser.add_option('-d',
'--decision_plot',
default=[],
type=str,
help='decision key (name of the decision) to plot',
dest='decision_keys',
action='append')
options = parser.parse_args()[0]
if options.dump_file_to_parse is None:
print "No dump file to parse is set.\n"
parser.print_help()
exit()
(metrics, decisions) = ParseAnaDump(options.dump_file_to_parse)
metric_keys = options.metric_keys
decision_keys = options.decision_keys
plot_count = len(metric_keys) + len(decision_keys)
if plot_count == 0:
print "You have to set at least one metric or decision to plot.\n"
parser.print_help()
exit()
plots = []
if plot_count == 1:
f, mp_plot = plt.subplots()
plots.append(mp_plot)
else:
f, mp_plots = plt.subplots(plot_count, sharex=True)
plots.extend(mp_plots.tolist())
options = parser.parse_args()[0]
if options.dump_file_to_parse is None:
print "No dump file to parse is set.\n"
parser.print_help()
exit()
(metrics, decisions) = ParseAnaDump(options.dump_file_to_parse)
metric_keys = options.metric_keys
decision_keys = options.decision_keys
plot_count = len(metric_keys) + len(decision_keys)
if plot_count == 0:
print "You have to set at least one metric or decision to plot.\n"
parser.print_help()
exit()
plots = []
if plot_count == 1:
f, mp_plot = plt.subplots()
plots.append(mp_plot)
else:
f, mp_plots = plt.subplots(plot_count, sharex=True)
plots.extend(mp_plots.tolist())
for key in metric_keys:
plot = plots.pop()
plot.grid(True)
plot.set_title(key + " (metric)")
plot.plot(metrics[key]['time'], metrics[key]['value'])
for key in decision_keys:
plot = plots.pop()
plot.grid(True)
plot.set_title(key + " (decision)")
plot.plot(decisions[key]['time'], decisions[key]['value'])
f.subplots_adjust(hspace=0.3)
plt.show()
for key in metric_keys:
plot = plots.pop()
plot.grid(True)
plot.set_title(key + " (metric)")
plot.plot(metrics[key]['time'], metrics[key]['value'])
for key in decision_keys:
plot = plots.pop()
plot.grid(True)
plot.set_title(key + " (decision)")
plot.plot(decisions[key]['time'], decisions[key]['value'])
f.subplots_adjust(hspace=0.3)
plt.show()
if __name__ == "__main__":
main()
main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Perform APM module quality assessment on one or more input files using one or
more APM simulator configuration files and one or more test data generators.
@ -47,139 +46,172 @@ _POLQA_BIN_NAME = 'PolqaOem64'
def _InstanceArgumentsParser():
"""Arguments parser factory.
"""Arguments parser factory.
"""
parser = argparse.ArgumentParser(description=(
'Perform APM module quality assessment on one or more input files using '
'one or more APM simulator configuration files and one or more '
'test data generators.'))
parser = argparse.ArgumentParser(description=(
'Perform APM module quality assessment on one or more input files using '
'one or more APM simulator configuration files and one or more '
'test data generators.'))
parser.add_argument('-c', '--config_files', nargs='+', required=False,
help=('path to the configuration files defining the '
'arguments with which the APM simulator tool is '
'called'),
default=[_DEFAULT_CONFIG_FILE])
parser.add_argument('-c',
'--config_files',
nargs='+',
required=False,
help=('path to the configuration files defining the '
'arguments with which the APM simulator tool is '
'called'),
default=[_DEFAULT_CONFIG_FILE])
parser.add_argument('-i', '--capture_input_files', nargs='+', required=True,
help='path to the capture input wav files (one or more)')
parser.add_argument(
'-i',
'--capture_input_files',
nargs='+',
required=True,
help='path to the capture input wav files (one or more)')
parser.add_argument('-r', '--render_input_files', nargs='+', required=False,
help=('path to the render input wav files; either '
'omitted or one file for each file in '
'--capture_input_files (files will be paired by '
'index)'), default=None)
parser.add_argument('-r',
'--render_input_files',
nargs='+',
required=False,
help=('path to the render input wav files; either '
'omitted or one file for each file in '
'--capture_input_files (files will be paired by '
'index)'),
default=None)
parser.add_argument('-p', '--echo_path_simulator', required=False,
help=('custom echo path simulator name; required if '
'--render_input_files is specified'),
choices=_ECHO_PATH_SIMULATOR_NAMES,
default=echo_path_simulation.NoEchoPathSimulator.NAME)
parser.add_argument('-p',
'--echo_path_simulator',
required=False,
help=('custom echo path simulator name; required if '
'--render_input_files is specified'),
choices=_ECHO_PATH_SIMULATOR_NAMES,
default=echo_path_simulation.NoEchoPathSimulator.NAME)
parser.add_argument('-t', '--test_data_generators', nargs='+', required=False,
help='custom list of test data generators to use',
choices=_TEST_DATA_GENERATORS_NAMES,
default=_TEST_DATA_GENERATORS_NAMES)
parser.add_argument('-t',
'--test_data_generators',
nargs='+',
required=False,
help='custom list of test data generators to use',
choices=_TEST_DATA_GENERATORS_NAMES,
default=_TEST_DATA_GENERATORS_NAMES)
parser.add_argument('--additive_noise_tracks_path', required=False,
help='path to the wav files for the additive',
default=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH)
parser.add_argument('--additive_noise_tracks_path', required=False,
help='path to the wav files for the additive',
default=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH)
parser.add_argument('-e', '--eval_scores', nargs='+', required=False,
help='custom list of evaluation scores to use',
choices=_EVAL_SCORE_WORKER_NAMES,
default=_EVAL_SCORE_WORKER_NAMES)
parser.add_argument('-e',
'--eval_scores',
nargs='+',
required=False,
help='custom list of evaluation scores to use',
choices=_EVAL_SCORE_WORKER_NAMES,
default=_EVAL_SCORE_WORKER_NAMES)
parser.add_argument('-o', '--output_dir', required=False,
help=('base path to the output directory in which the '
'output wav files and the evaluation outcomes '
'are saved'),
default='output')
parser.add_argument('-o',
'--output_dir',
required=False,
help=('base path to the output directory in which the '
'output wav files and the evaluation outcomes '
'are saved'),
default='output')
parser.add_argument('--polqa_path', required=True,
help='path to the POLQA tool')
parser.add_argument('--polqa_path',
required=True,
help='path to the POLQA tool')
parser.add_argument('--air_db_path', required=True,
help='path to the Aechen IR database')
parser.add_argument('--air_db_path',
required=True,
help='path to the Aechen IR database')
parser.add_argument('--apm_sim_path', required=False,
help='path to the APM simulator tool',
default=audioproc_wrapper. \
AudioProcWrapper. \
DEFAULT_APM_SIMULATOR_BIN_PATH)
parser.add_argument('--apm_sim_path', required=False,
help='path to the APM simulator tool',
default=audioproc_wrapper. \
AudioProcWrapper. \
DEFAULT_APM_SIMULATOR_BIN_PATH)
parser.add_argument('--echo_metric_tool_bin_path', required=False,
help=('path to the echo metric binary '
'(required for the echo eval score)'),
default=None)
parser.add_argument('--echo_metric_tool_bin_path',
required=False,
help=('path to the echo metric binary '
'(required for the echo eval score)'),
default=None)
parser.add_argument('--copy_with_identity_generator', required=False,
help=('If true, the identity test data generator makes a '
'copy of the clean speech input file.'),
default=False)
parser.add_argument(
'--copy_with_identity_generator',
required=False,
help=('If true, the identity test data generator makes a '
'copy of the clean speech input file.'),
default=False)
parser.add_argument('--external_vad_paths', nargs='+', required=False,
help=('Paths to external VAD programs. Each must take'
'\'-i <wav file> -o <output>\' inputs'), default=[])
parser.add_argument('--external_vad_paths',
nargs='+',
required=False,
help=('Paths to external VAD programs. Each must take'
'\'-i <wav file> -o <output>\' inputs'),
default=[])
parser.add_argument('--external_vad_names', nargs='+', required=False,
help=('Keys to the vad paths. Must be different and '
'as many as the paths.'), default=[])
parser.add_argument('--external_vad_names',
nargs='+',
required=False,
help=('Keys to the vad paths. Must be different and '
'as many as the paths.'),
default=[])
return parser
return parser
def _ValidateArguments(args, parser):
if args.capture_input_files and args.render_input_files and (
len(args.capture_input_files) != len(args.render_input_files)):
parser.error('--render_input_files and --capture_input_files must be lists '
'having the same length')
sys.exit(1)
if args.capture_input_files and args.render_input_files and (len(
args.capture_input_files) != len(args.render_input_files)):
parser.error(
'--render_input_files and --capture_input_files must be lists '
'having the same length')
sys.exit(1)
if args.render_input_files and not args.echo_path_simulator:
parser.error('when --render_input_files is set, --echo_path_simulator is '
'also required')
sys.exit(1)
if args.render_input_files and not args.echo_path_simulator:
parser.error(
'when --render_input_files is set, --echo_path_simulator is '
'also required')
sys.exit(1)
if len(args.external_vad_names) != len(args.external_vad_paths):
parser.error('If provided, --external_vad_paths and '
'--external_vad_names must '
'have the same number of arguments.')
sys.exit(1)
if len(args.external_vad_names) != len(args.external_vad_paths):
parser.error('If provided, --external_vad_paths and '
'--external_vad_names must '
'have the same number of arguments.')
sys.exit(1)
def main():
# TODO(alessiob): level = logging.INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = _InstanceArgumentsParser()
args = parser.parse_args()
_ValidateArguments(args, parser)
# TODO(alessiob): level = logging.INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = _InstanceArgumentsParser()
args = parser.parse_args()
_ValidateArguments(args, parser)
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=args.air_db_path,
noise_tracks_path=args.additive_noise_tracks_path,
copy_with_identity=args.copy_with_identity_generator)),
evaluation_score_factory=eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path
),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
evaluator=evaluation.ApmModuleEvaluator(),
external_vads=external_vad.ExternalVad.ConstructVadDict(
args.external_vad_paths, args.external_vad_names))
simulator.Run(
config_filepaths=args.config_files,
capture_input_filepaths=args.capture_input_files,
render_input_filepaths=args.render_input_files,
echo_path_simulator_name=args.echo_path_simulator,
test_data_generator_names=args.test_data_generators,
eval_score_names=args.eval_scores,
output_dir=args.output_dir)
sys.exit(0)
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=args.air_db_path,
noise_tracks_path=args.additive_noise_tracks_path,
copy_with_identity=args.copy_with_identity_generator)),
evaluation_score_factory=eval_scores_factory.
EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(args.polqa_path, _POLQA_BIN_NAME),
echo_metric_tool_bin_path=args.echo_metric_tool_bin_path),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(args.apm_sim_path),
evaluator=evaluation.ApmModuleEvaluator(),
external_vads=external_vad.ExternalVad.ConstructVadDict(
args.external_vad_paths, args.external_vad_names))
simulator.Run(config_filepaths=args.config_files,
capture_input_filepaths=args.capture_input_files,
render_input_filepaths=args.render_input_files,
echo_path_simulator_name=args.echo_path_simulator,
test_data_generator_names=args.test_data_generators,
eval_score_names=args.eval_scores,
output_dir=args.output_dir)
sys.exit(0)
if __name__ == '__main__':
main()
main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Shows boxplots of given score for different values of selected
parameters. Can be used to compare scores by audioproc_f flag.
@ -30,29 +29,37 @@ import quality_assessment.collect_data as collect_data
def InstanceArgumentsParser():
"""Arguments parser factory.
"""Arguments parser factory.
"""
parser = collect_data.InstanceArgumentsParser()
parser.description = (
'Shows boxplot of given score for different values of selected'
'parameters. Can be used to compare scores by audioproc_f flag')
parser = collect_data.InstanceArgumentsParser()
parser.description = (
'Shows boxplot of given score for different values of selected'
'parameters. Can be used to compare scores by audioproc_f flag')
parser.add_argument('-v', '--eval_score', required=True,
help=('Score name for constructing boxplots'))
parser.add_argument('-v',
'--eval_score',
required=True,
help=('Score name for constructing boxplots'))
parser.add_argument('-n', '--config_dir', required=False,
help=('path to the folder with the configuration files'),
default='apm_configs')
parser.add_argument(
'-n',
'--config_dir',
required=False,
help=('path to the folder with the configuration files'),
default='apm_configs')
parser.add_argument('-z', '--params_to_plot', required=True,
nargs='+', help=('audioproc_f parameter values'
'by which to group scores (no leading dash)'))
parser.add_argument('-z',
'--params_to_plot',
required=True,
nargs='+',
help=('audioproc_f parameter values'
'by which to group scores (no leading dash)'))
return parser
return parser
def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
"""Filters data on the values of one or more parameters.
"""Filters data on the values of one or more parameters.
Args:
data_frame: pandas.DataFrame of all used input data.
@ -71,34 +78,36 @@ def FilterScoresByParams(data_frame, filter_params, score_name, config_dir):
Returns: dictionary, key is a param value, result is all scores for
that param value (see `filter_params` for explanation).
"""
results = collections.defaultdict(dict)
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
results = collections.defaultdict(dict)
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
for config_name in config_names:
config_json = data_access.AudioProcConfigFile.Load(
os.path.join(config_dir, config_name + '.json'))
data_with_config = data_frame[data_frame.apm_config == config_name]
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
score_name]
for config_name in config_names:
config_json = data_access.AudioProcConfigFile.Load(
os.path.join(config_dir, config_name + '.json'))
data_with_config = data_frame[data_frame.apm_config == config_name]
data_cell_scores = data_with_config[data_with_config.eval_score_name ==
score_name]
# Exactly one of |params_to_plot| must match:
(matching_param, ) = [x for x in filter_params if '-' + x in config_json]
# Exactly one of |params_to_plot| must match:
(matching_param, ) = [
x for x in filter_params if '-' + x in config_json
]
# Add scores for every track to the result.
for capture_name in data_cell_scores.capture:
result_score = float(data_cell_scores[data_cell_scores.capture ==
capture_name].score)
config_dict = results[config_json['-' + matching_param]]
if capture_name not in config_dict:
config_dict[capture_name] = {}
# Add scores for every track to the result.
for capture_name in data_cell_scores.capture:
result_score = float(data_cell_scores[data_cell_scores.capture ==
capture_name].score)
config_dict = results[config_json['-' + matching_param]]
if capture_name not in config_dict:
config_dict[capture_name] = {}
config_dict[capture_name][matching_param] = result_score
config_dict[capture_name][matching_param] = result_score
return results
return results
def _FlattenToScoresList(config_param_score_dict):
"""Extracts a list of scores from input data structure.
"""Extracts a list of scores from input data structure.
Args:
config_param_score_dict: of the form {'capture_name':
@ -107,40 +116,39 @@ def _FlattenToScoresList(config_param_score_dict):
Returns: Plain list of all score value present in input data
structure
"""
result = []
for capture_name in config_param_score_dict:
result += list(config_param_score_dict[capture_name].values())
return result
result = []
for capture_name in config_param_score_dict:
result += list(config_param_score_dict[capture_name].values())
return result
def main():
# Init.
# TODO(alessiob): INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = InstanceArgumentsParser()
args = parser.parse_args()
# Init.
# TODO(alessiob): INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = InstanceArgumentsParser()
args = parser.parse_args()
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug(src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug(src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
# Filter the data by `args.params_to_plot`
scores_filtered = FilterScoresByParams(scores_data_frame,
args.params_to_plot,
args.eval_score,
args.config_dir)
# Filter the data by `args.params_to_plot`
scores_filtered = FilterScoresByParams(scores_data_frame,
args.params_to_plot,
args.eval_score, args.config_dir)
data_list = sorted(scores_filtered.items())
data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
data_labels = [x for (x, _) in data_list]
data_list = sorted(scores_filtered.items())
data_values = [_FlattenToScoresList(x) for (_, x) in data_list]
data_labels = [x for (x, _) in data_list]
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
axes.boxplot(data_values, labels=data_labels)
axes.set_ylabel(args.eval_score)
axes.set_xlabel('/'.join(args.params_to_plot))
plt.show()
_, axes = plt.subplots(nrows=1, ncols=1, figsize=(6, 6))
axes.boxplot(data_values, labels=data_labels)
axes.set_ylabel(args.eval_score)
axes.set_xlabel('/'.join(args.params_to_plot))
plt.show()
if __name__ == "__main__":
main()
main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Export the scores computed by the apm_quality_assessment.py script into an
HTML file.
"""
@ -20,7 +19,7 @@ import quality_assessment.export as export
def _BuildOutputFilename(filename_suffix):
"""Builds the filename for the exported file.
"""Builds the filename for the exported file.
Args:
filename_suffix: suffix for the output file name.
@ -28,34 +27,37 @@ def _BuildOutputFilename(filename_suffix):
Returns:
A string.
"""
if filename_suffix is None:
return 'results.html'
return 'results-{}.html'.format(filename_suffix)
if filename_suffix is None:
return 'results.html'
return 'results-{}.html'.format(filename_suffix)
def main():
# Init.
logging.basicConfig(level=logging.DEBUG) # TODO(alessio): INFO once debugged.
parser = collect_data.InstanceArgumentsParser()
parser.add_argument('-f', '--filename_suffix',
help=('suffix of the exported file'))
parser.description = ('Exports pre-computed APM module quality assessment '
'results into HTML tables')
args = parser.parse_args()
# Init.
logging.basicConfig(
level=logging.DEBUG) # TODO(alessio): INFO once debugged.
parser = collect_data.InstanceArgumentsParser()
parser.add_argument('-f',
'--filename_suffix',
help=('suffix of the exported file'))
parser.description = ('Exports pre-computed APM module quality assessment '
'results into HTML tables')
args = parser.parse_args()
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug(src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug(src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
# Export.
output_filepath = os.path.join(args.output_dir, _BuildOutputFilename(
args.filename_suffix))
exporter = export.HtmlExport(output_filepath)
exporter.Export(scores_data_frame)
# Export.
output_filepath = os.path.join(args.output_dir,
_BuildOutputFilename(args.filename_suffix))
exporter = export.HtmlExport(output_filepath)
exporter.Export(scores_data_frame)
logging.info('output file successfully written in %s', output_filepath)
sys.exit(0)
logging.info('output file successfully written in %s', output_filepath)
sys.exit(0)
if __name__ == '__main__':
main()
main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Generate .json files with which the APM module can be tested using the
apm_quality_assessment.py script and audioproc_f as APM simulator.
"""
@ -20,7 +19,7 @@ OUTPUT_PATH = os.path.abspath('apm_configs')
def _GenerateDefaultOverridden(config_override):
"""Generates one or more APM overriden configurations.
"""Generates one or more APM overriden configurations.
For each item in config_override, it overrides the default configuration and
writes a new APM configuration file.
@ -45,54 +44,85 @@ def _GenerateDefaultOverridden(config_override):
config_override: dict of APM configuration file names as keys; the values
are dict instances encoding the audioproc_f flags.
"""
for config_filename in config_override:
config = config_override[config_filename]
config['-all_default'] = None
for config_filename in config_override:
config = config_override[config_filename]
config['-all_default'] = None
config_filepath = os.path.join(OUTPUT_PATH, 'default-{}.json'.format(
config_filename))
logging.debug('config file <%s> | %s', config_filepath, config)
config_filepath = os.path.join(
OUTPUT_PATH, 'default-{}.json'.format(config_filename))
logging.debug('config file <%s> | %s', config_filepath, config)
data_access.AudioProcConfigFile.Save(config_filepath, config)
logging.info('config file created: <%s>', config_filepath)
data_access.AudioProcConfigFile.Save(config_filepath, config)
logging.info('config file created: <%s>', config_filepath)
def _GenerateAllDefaultButOne():
"""Disables the flags enabled by default one-by-one.
"""Disables the flags enabled by default one-by-one.
"""
config_sets = {
'no_AEC': {'-aec': 0,},
'no_AGC': {'-agc': 0,},
'no_HP_filter': {'-hpf': 0,},
'no_level_estimator': {'-le': 0,},
'no_noise_suppressor': {'-ns': 0,},
'no_transient_suppressor': {'-ts': 0,},
'no_vad': {'-vad': 0,},
}
_GenerateDefaultOverridden(config_sets)
config_sets = {
'no_AEC': {
'-aec': 0,
},
'no_AGC': {
'-agc': 0,
},
'no_HP_filter': {
'-hpf': 0,
},
'no_level_estimator': {
'-le': 0,
},
'no_noise_suppressor': {
'-ns': 0,
},
'no_transient_suppressor': {
'-ts': 0,
},
'no_vad': {
'-vad': 0,
},
}
_GenerateDefaultOverridden(config_sets)
def _GenerateAllDefaultPlusOne():
"""Enables the flags disabled by default one-by-one.
"""Enables the flags disabled by default one-by-one.
"""
config_sets = {
'with_AECM': {'-aec': 0, '-aecm': 1,}, # AEC and AECM are exclusive.
'with_AGC_limiter': {'-agc_limiter': 1,},
'with_AEC_delay_agnostic': {'-delay_agnostic': 1,},
'with_drift_compensation': {'-drift_compensation': 1,},
'with_residual_echo_detector': {'-ed': 1,},
'with_AEC_extended_filter': {'-extended_filter': 1,},
'with_LC': {'-lc': 1,},
'with_refined_adaptive_filter': {'-refined_adaptive_filter': 1,},
}
_GenerateDefaultOverridden(config_sets)
config_sets = {
'with_AECM': {
'-aec': 0,
'-aecm': 1,
}, # AEC and AECM are exclusive.
'with_AGC_limiter': {
'-agc_limiter': 1,
},
'with_AEC_delay_agnostic': {
'-delay_agnostic': 1,
},
'with_drift_compensation': {
'-drift_compensation': 1,
},
'with_residual_echo_detector': {
'-ed': 1,
},
'with_AEC_extended_filter': {
'-extended_filter': 1,
},
'with_LC': {
'-lc': 1,
},
'with_refined_adaptive_filter': {
'-refined_adaptive_filter': 1,
},
}
_GenerateDefaultOverridden(config_sets)
def main():
logging.basicConfig(level=logging.INFO)
_GenerateAllDefaultPlusOne()
_GenerateAllDefaultButOne()
logging.basicConfig(level=logging.INFO)
_GenerateAllDefaultPlusOne()
_GenerateAllDefaultButOne()
if __name__ == '__main__':
main()
main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Finds the APM configuration that maximizes a provided metric by
parsing the output generated apm_quality_assessment.py.
"""
@ -20,33 +19,44 @@ import os
import quality_assessment.data_access as data_access
import quality_assessment.collect_data as collect_data
def _InstanceArgumentsParser():
"""Arguments parser factory. Extends the arguments from 'collect_data'
"""Arguments parser factory. Extends the arguments from 'collect_data'
with a few extra for selecting what parameters to optimize for.
"""
parser = collect_data.InstanceArgumentsParser()
parser.description = (
'Rudimentary optimization of a function over different parameter'
'combinations.')
parser = collect_data.InstanceArgumentsParser()
parser.description = (
'Rudimentary optimization of a function over different parameter'
'combinations.')
parser.add_argument('-n', '--config_dir', required=False,
help=('path to the folder with the configuration files'),
default='apm_configs')
parser.add_argument(
'-n',
'--config_dir',
required=False,
help=('path to the folder with the configuration files'),
default='apm_configs')
parser.add_argument('-p', '--params', required=True, nargs='+',
help=('parameters to parse from the config files in'
'config_dir'))
parser.add_argument('-p',
'--params',
required=True,
nargs='+',
help=('parameters to parse from the config files in'
'config_dir'))
parser.add_argument('-z', '--params_not_to_optimize', required=False,
nargs='+', default=[],
help=('parameters from `params` not to be optimized for'))
parser.add_argument(
'-z',
'--params_not_to_optimize',
required=False,
nargs='+',
default=[],
help=('parameters from `params` not to be optimized for'))
return parser
return parser
def _ConfigurationAndScores(data_frame, params,
params_not_to_optimize, config_dir):
"""Returns a list of all configurations and scores.
def _ConfigurationAndScores(data_frame, params, params_not_to_optimize,
config_dir):
"""Returns a list of all configurations and scores.
Args:
data_frame: A pandas data frame with the scores and config name
@ -72,47 +82,47 @@ def _ConfigurationAndScores(data_frame, params,
param combinations for params in `params_not_to_optimize` and
their scores.
"""
results = collections.defaultdict(list)
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
score_names = data_frame['eval_score_name'].drop_duplicates().values.tolist()
results = collections.defaultdict(list)
config_names = data_frame['apm_config'].drop_duplicates().values.tolist()
score_names = data_frame['eval_score_name'].drop_duplicates(
).values.tolist()
# Normalize the scores
normalization_constants = {}
for score_name in score_names:
scores = data_frame[data_frame.eval_score_name == score_name].score
normalization_constants[score_name] = max(scores)
params_to_optimize = [p for p in params if p not in params_not_to_optimize]
param_combination = collections.namedtuple("ParamCombination",
params_to_optimize)
for config_name in config_names:
config_json = data_access.AudioProcConfigFile.Load(
os.path.join(config_dir, config_name + ".json"))
scores = {}
data_cell = data_frame[data_frame.apm_config == config_name]
# Normalize the scores
normalization_constants = {}
for score_name in score_names:
data_cell_scores = data_cell[data_cell.eval_score_name ==
score_name].score
scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
scores[score_name] /= normalization_constants[score_name]
scores = data_frame[data_frame.eval_score_name == score_name].score
normalization_constants[score_name] = max(scores)
result = {'scores': scores, 'params': {}}
config_optimize_params = {}
for param in params:
if param in params_to_optimize:
config_optimize_params[param] = config_json['-' + param]
else:
result['params'][param] = config_json['-' + param]
params_to_optimize = [p for p in params if p not in params_not_to_optimize]
param_combination = collections.namedtuple("ParamCombination",
params_to_optimize)
current_param_combination = param_combination(
**config_optimize_params)
results[current_param_combination].append(result)
return results
for config_name in config_names:
config_json = data_access.AudioProcConfigFile.Load(
os.path.join(config_dir, config_name + ".json"))
scores = {}
data_cell = data_frame[data_frame.apm_config == config_name]
for score_name in score_names:
data_cell_scores = data_cell[data_cell.eval_score_name ==
score_name].score
scores[score_name] = sum(data_cell_scores) / len(data_cell_scores)
scores[score_name] /= normalization_constants[score_name]
result = {'scores': scores, 'params': {}}
config_optimize_params = {}
for param in params:
if param in params_to_optimize:
config_optimize_params[param] = config_json['-' + param]
else:
result['params'][param] = config_json['-' + param]
current_param_combination = param_combination(**config_optimize_params)
results[current_param_combination].append(result)
return results
def _FindOptimalParameter(configs_and_scores, score_weighting):
"""Finds the config producing the maximal score.
"""Finds the config producing the maximal score.
Args:
configs_and_scores: structure of the form returned by
@ -127,53 +137,53 @@ def _FindOptimalParameter(configs_and_scores, score_weighting):
to its scores.
"""
min_score = float('+inf')
best_params = None
for config in configs_and_scores:
scores_and_params = configs_and_scores[config]
current_score = score_weighting(scores_and_params)
if current_score < min_score:
min_score = current_score
best_params = config
logging.debug("Score: %f", current_score)
logging.debug("Config: %s", str(config))
return best_params
min_score = float('+inf')
best_params = None
for config in configs_and_scores:
scores_and_params = configs_and_scores[config]
current_score = score_weighting(scores_and_params)
if current_score < min_score:
min_score = current_score
best_params = config
logging.debug("Score: %f", current_score)
logging.debug("Config: %s", str(config))
return best_params
def _ExampleWeighting(scores_and_configs):
"""Example argument to `_FindOptimalParameter`
"""Example argument to `_FindOptimalParameter`
Args:
scores_and_configs: a list of configs and scores, in the form
described in _FindOptimalParameter
Returns:
numeric value, the sum of all scores
"""
res = 0
for score_config in scores_and_configs:
res += sum(score_config['scores'].values())
return res
res = 0
for score_config in scores_and_configs:
res += sum(score_config['scores'].values())
return res
def main():
# Init.
# TODO(alessiob): INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = _InstanceArgumentsParser()
args = parser.parse_args()
# Init.
# TODO(alessiob): INFO once debugged.
logging.basicConfig(level=logging.DEBUG)
parser = _InstanceArgumentsParser()
args = parser.parse_args()
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug('Src path <%s>', src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
all_scores = _ConfigurationAndScores(scores_data_frame,
args.params,
args.params_not_to_optimize,
args.config_dir)
# Get the scores.
src_path = collect_data.ConstructSrcPath(args)
logging.debug('Src path <%s>', src_path)
scores_data_frame = collect_data.FindScores(src_path, args)
all_scores = _ConfigurationAndScores(scores_data_frame, args.params,
args.params_not_to_optimize,
args.config_dir)
opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
opt_param = _FindOptimalParameter(all_scores, _ExampleWeighting)
logging.info('Optimal parameter combination: <%s>', opt_param)
logging.info('It\'s score values: <%s>', all_scores[opt_param])
logging.info('Optimal parameter combination: <%s>', opt_param)
logging.info('It\'s score values: <%s>', all_scores[opt_param])
if __name__ == "__main__":
main()
main()

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the apm_quality_assessment module.
"""
@ -16,13 +15,14 @@ import mock
import apm_quality_assessment
class TestSimulationScript(unittest.TestCase):
"""Unit tests for the apm_quality_assessment module.
"""Unit tests for the apm_quality_assessment module.
"""
def testMain(self):
# Exit with error code if no arguments are passed.
with self.assertRaises(SystemExit) as cm, mock.patch.object(
sys, 'argv', ['apm_quality_assessment.py']):
apm_quality_assessment.main()
self.assertGreater(cm.exception.code, 0)
def testMain(self):
# Exit with error code if no arguments are passed.
with self.assertRaises(SystemExit) as cm, mock.patch.object(
sys, 'argv', ['apm_quality_assessment.py']):
apm_quality_assessment.main()
self.assertGreater(cm.exception.code, 0)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Extraction of annotations from audio files.
"""
@ -19,10 +18,10 @@ import sys
import tempfile
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import external_vad
from . import exceptions
@ -30,262 +29,268 @@ from . import signal_processing
class AudioAnnotationsExtractor(object):
"""Extracts annotations from audio files.
"""Extracts annotations from audio files.
"""
class VadType(object):
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
class VadType(object):
ENERGY_THRESHOLD = 1 # TODO(alessiob): Consider switching to P56 standard.
WEBRTC_COMMON_AUDIO = 2 # common_audio/vad/include/vad.h
WEBRTC_APM = 4 # modules/audio_processing/vad/vad.h
def __init__(self, value):
if (not isinstance(value, int)) or not 0 <= value <= 7:
raise exceptions.InitializationException(
'Invalid vad type: ' + value)
self._value = value
def __init__(self, value):
if (not isinstance(value, int)) or not 0 <= value <= 7:
raise exceptions.InitializationException('Invalid vad type: ' +
value)
self._value = value
def Contains(self, vad_type):
return self._value | vad_type == self._value
def Contains(self, vad_type):
return self._value | vad_type == self._value
def __str__(self):
vads = []
if self.Contains(self.ENERGY_THRESHOLD):
vads.append("energy")
if self.Contains(self.WEBRTC_COMMON_AUDIO):
vads.append("common_audio")
if self.Contains(self.WEBRTC_APM):
vads.append("apm")
return "VadType({})".format(", ".join(vads))
def __str__(self):
vads = []
if self.Contains(self.ENERGY_THRESHOLD):
vads.append("energy")
if self.Contains(self.WEBRTC_COMMON_AUDIO):
vads.append("common_audio")
if self.Contains(self.WEBRTC_APM):
vads.append("apm")
return "VadType({})".format(", ".join(vads))
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
# Level estimation params.
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
_LEVEL_FRAME_SIZE_MS = 1.0
# The time constants in ms indicate the time it takes for the level estimate
# to go down/up by 1 db if the signal is zero.
_LEVEL_ATTACK_MS = 5.0
_LEVEL_DECAY_MS = 20.0
# VAD params.
_VAD_THRESHOLD = 1
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(
os.path.abspath(__file__)), os.pardir, os.pardir)
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
_VAD_WEBRTC_APM_PATH = os.path.join(
_VAD_WEBRTC_PATH, 'apm_vad')
def __init__(self, vad_type, external_vads=None):
self._signal = None
self._level = None
self._level_frame_size = None
self._common_audio_vad = None
self._energy_vad = None
self._apm_vad_probs = None
self._apm_vad_rms = None
self._vad_frame_size = None
self._vad_frame_size_ms = None
self._c_attack = None
self._c_decay = None
self._vad_type = self.VadType(vad_type)
logging.info('VADs used for annotations: ' + str(self._vad_type))
if external_vads is None:
external_vads = {}
self._external_vads = external_vads
assert len(self._external_vads) == len(external_vads), (
'The external VAD names must be unique.')
for vad in external_vads.values():
if not isinstance(vad, external_vad.ExternalVad):
raise exceptions.InitializationException(
'Invalid vad type: ' + str(type(vad)))
logging.info('External VAD used for annotation: ' +
str(vad.name))
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
self._VAD_WEBRTC_COMMON_AUDIO_PATH
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
self._VAD_WEBRTC_APM_PATH
@classmethod
def GetOutputFileNameTemplate(cls):
return cls._OUTPUT_FILENAME_TEMPLATE
def GetLevel(self):
return self._level
def GetLevelFrameSize(self):
return self._level_frame_size
@classmethod
def GetLevelFrameSizeMs(cls):
return cls._LEVEL_FRAME_SIZE_MS
def GetVadOutput(self, vad_type):
if vad_type == self.VadType.ENERGY_THRESHOLD:
return self._energy_vad
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
return self._common_audio_vad
elif vad_type == self.VadType.WEBRTC_APM:
return (self._apm_vad_probs, self._apm_vad_rms)
else:
raise exceptions.InitializationException(
'Invalid vad type: ' + vad_type)
def GetVadFrameSize(self):
return self._vad_frame_size
def GetVadFrameSizeMs(self):
return self._vad_frame_size_ms
def Extract(self, filepath):
# Load signal.
self._signal = signal_processing.SignalProcessingUtils.LoadWav(filepath)
if self._signal.channels != 1:
raise NotImplementedError('Multiple-channel annotations not implemented')
_OUTPUT_FILENAME_TEMPLATE = '{}annotations.npz'
# Level estimation params.
self._level_frame_size = int(self._signal.frame_rate / 1000 * (
self._LEVEL_FRAME_SIZE_MS))
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
self._ONE_DB_REDUCTION ** (
self._LEVEL_FRAME_SIZE_MS / self._LEVEL_ATTACK_MS))
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
self._ONE_DB_REDUCTION ** (
self._LEVEL_FRAME_SIZE_MS / self._LEVEL_DECAY_MS))
_ONE_DB_REDUCTION = np.power(10.0, -1.0 / 20.0)
_LEVEL_FRAME_SIZE_MS = 1.0
# The time constants in ms indicate the time it takes for the level estimate
# to go down/up by 1 db if the signal is zero.
_LEVEL_ATTACK_MS = 5.0
_LEVEL_DECAY_MS = 20.0
# Compute level.
self._LevelEstimation()
# VAD params.
_VAD_THRESHOLD = 1
_VAD_WEBRTC_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)),
os.pardir, os.pardir)
_VAD_WEBRTC_COMMON_AUDIO_PATH = os.path.join(_VAD_WEBRTC_PATH, 'vad')
# Ideal VAD output, it requires clean speech with high SNR as input.
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
# Naive VAD based on level thresholding.
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
self._energy_vad = np.uint8(self._level > vad_threshold)
self._vad_frame_size = self._level_frame_size
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
# WebRTC common_audio/ VAD.
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
# WebRTC modules/audio_processing/ VAD.
self._RunWebRtcApmVad(filepath)
for extvad_name in self._external_vads:
self._external_vads[extvad_name].Run(filepath)
_VAD_WEBRTC_APM_PATH = os.path.join(_VAD_WEBRTC_PATH, 'apm_vad')
def Save(self, output_path, annotation_name=""):
ext_kwargs = {'extvad_conf-' + ext_vad:
self._external_vads[ext_vad].GetVadOutput()
for ext_vad in self._external_vads}
np.savez_compressed(
file=os.path.join(
def __init__(self, vad_type, external_vads=None):
self._signal = None
self._level = None
self._level_frame_size = None
self._common_audio_vad = None
self._energy_vad = None
self._apm_vad_probs = None
self._apm_vad_rms = None
self._vad_frame_size = None
self._vad_frame_size_ms = None
self._c_attack = None
self._c_decay = None
self._vad_type = self.VadType(vad_type)
logging.info('VADs used for annotations: ' + str(self._vad_type))
if external_vads is None:
external_vads = {}
self._external_vads = external_vads
assert len(self._external_vads) == len(external_vads), (
'The external VAD names must be unique.')
for vad in external_vads.values():
if not isinstance(vad, external_vad.ExternalVad):
raise exceptions.InitializationException('Invalid vad type: ' +
str(type(vad)))
logging.info('External VAD used for annotation: ' + str(vad.name))
assert os.path.exists(self._VAD_WEBRTC_COMMON_AUDIO_PATH), \
self._VAD_WEBRTC_COMMON_AUDIO_PATH
assert os.path.exists(self._VAD_WEBRTC_APM_PATH), \
self._VAD_WEBRTC_APM_PATH
@classmethod
def GetOutputFileNameTemplate(cls):
return cls._OUTPUT_FILENAME_TEMPLATE
def GetLevel(self):
return self._level
def GetLevelFrameSize(self):
return self._level_frame_size
@classmethod
def GetLevelFrameSizeMs(cls):
return cls._LEVEL_FRAME_SIZE_MS
def GetVadOutput(self, vad_type):
if vad_type == self.VadType.ENERGY_THRESHOLD:
return self._energy_vad
elif vad_type == self.VadType.WEBRTC_COMMON_AUDIO:
return self._common_audio_vad
elif vad_type == self.VadType.WEBRTC_APM:
return (self._apm_vad_probs, self._apm_vad_rms)
else:
raise exceptions.InitializationException('Invalid vad type: ' +
vad_type)
def GetVadFrameSize(self):
return self._vad_frame_size
def GetVadFrameSizeMs(self):
return self._vad_frame_size_ms
def Extract(self, filepath):
# Load signal.
self._signal = signal_processing.SignalProcessingUtils.LoadWav(
filepath)
if self._signal.channels != 1:
raise NotImplementedError(
'Multiple-channel annotations not implemented')
# Level estimation params.
self._level_frame_size = int(self._signal.frame_rate / 1000 *
(self._LEVEL_FRAME_SIZE_MS))
self._c_attack = 0.0 if self._LEVEL_ATTACK_MS == 0 else (
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
self._LEVEL_ATTACK_MS))
self._c_decay = 0.0 if self._LEVEL_DECAY_MS == 0 else (
self._ONE_DB_REDUCTION**(self._LEVEL_FRAME_SIZE_MS /
self._LEVEL_DECAY_MS))
# Compute level.
self._LevelEstimation()
# Ideal VAD output, it requires clean speech with high SNR as input.
if self._vad_type.Contains(self.VadType.ENERGY_THRESHOLD):
# Naive VAD based on level thresholding.
vad_threshold = np.percentile(self._level, self._VAD_THRESHOLD)
self._energy_vad = np.uint8(self._level > vad_threshold)
self._vad_frame_size = self._level_frame_size
self._vad_frame_size_ms = self._LEVEL_FRAME_SIZE_MS
if self._vad_type.Contains(self.VadType.WEBRTC_COMMON_AUDIO):
# WebRTC common_audio/ VAD.
self._RunWebRtcCommonAudioVad(filepath, self._signal.frame_rate)
if self._vad_type.Contains(self.VadType.WEBRTC_APM):
# WebRTC modules/audio_processing/ VAD.
self._RunWebRtcApmVad(filepath)
for extvad_name in self._external_vads:
self._external_vads[extvad_name].Run(filepath)
def Save(self, output_path, annotation_name=""):
ext_kwargs = {
'extvad_conf-' + ext_vad:
self._external_vads[ext_vad].GetVadOutput()
for ext_vad in self._external_vads
}
np.savez_compressed(file=os.path.join(
output_path,
self.GetOutputFileNameTemplate().format(annotation_name)),
level=self._level,
level_frame_size=self._level_frame_size,
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
vad_output=self._common_audio_vad,
vad_energy_output=self._energy_vad,
vad_frame_size=self._vad_frame_size,
vad_frame_size_ms=self._vad_frame_size_ms,
vad_probs=self._apm_vad_probs,
vad_rms=self._apm_vad_rms,
**ext_kwargs
)
level=self._level,
level_frame_size=self._level_frame_size,
level_frame_size_ms=self._LEVEL_FRAME_SIZE_MS,
vad_output=self._common_audio_vad,
vad_energy_output=self._energy_vad,
vad_frame_size=self._vad_frame_size,
vad_frame_size_ms=self._vad_frame_size_ms,
vad_probs=self._apm_vad_probs,
vad_rms=self._apm_vad_rms,
**ext_kwargs)
def _LevelEstimation(self):
# Read samples.
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
self._signal).astype(np.float32) / 32768.0
num_frames = len(samples) // self._level_frame_size
num_samples = num_frames * self._level_frame_size
def _LevelEstimation(self):
# Read samples.
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
self._signal).astype(np.float32) / 32768.0
num_frames = len(samples) // self._level_frame_size
num_samples = num_frames * self._level_frame_size
# Envelope.
self._level = np.max(np.reshape(np.abs(samples[:num_samples]), (
num_frames, self._level_frame_size)), axis=1)
assert len(self._level) == num_frames
# Envelope.
self._level = np.max(np.reshape(np.abs(samples[:num_samples]),
(num_frames, self._level_frame_size)),
axis=1)
assert len(self._level) == num_frames
# Envelope smoothing.
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
for i in range(1, num_frames):
self._level[i] = smooth(
self._level[i], self._level[i - 1], self._c_attack if (
self._level[i] > self._level[i - 1]) else self._c_decay)
# Envelope smoothing.
smooth = lambda curr, prev, k: (1 - k) * curr + k * prev
self._level[0] = smooth(self._level[0], 0.0, self._c_attack)
for i in range(1, num_frames):
self._level[i] = smooth(
self._level[i], self._level[i - 1], self._c_attack if
(self._level[i] > self._level[i - 1]) else self._c_decay)
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
self._common_audio_vad = None
self._vad_frame_size = None
def _RunWebRtcCommonAudioVad(self, wav_file_path, sample_rate):
self._common_audio_vad = None
self._vad_frame_size = None
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad.tmp')
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path = os.path.join(
tmp_path,
os.path.split(wav_file_path)[1] + '_vad.tmp')
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_COMMON_AUDIO_PATH,
'-i', wav_file_path,
'-o', output_file_path
], cwd=self._VAD_WEBRTC_PATH)
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_COMMON_AUDIO_PATH, '-i', wav_file_path, '-o',
output_file_path
],
cwd=self._VAD_WEBRTC_PATH)
# Read bytes.
with open(output_file_path, 'rb') as f:
raw_data = f.read()
# Read bytes.
with open(output_file_path, 'rb') as f:
raw_data = f.read()
# Parse side information.
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
assert self._vad_frame_size_ms in [10, 20, 30]
extra_bits = struct.unpack('B', raw_data[-1])[0]
assert 0 <= extra_bits <= 8
# Parse side information.
self._vad_frame_size_ms = struct.unpack('B', raw_data[0])[0]
self._vad_frame_size = self._vad_frame_size_ms * sample_rate / 1000
assert self._vad_frame_size_ms in [10, 20, 30]
extra_bits = struct.unpack('B', raw_data[-1])[0]
assert 0 <= extra_bits <= 8
# Init VAD vector.
num_bytes = len(raw_data)
num_frames = 8 * (num_bytes - 2) - extra_bits # 8 frames for each byte.
self._common_audio_vad = np.zeros(num_frames, np.uint8)
# Init VAD vector.
num_bytes = len(raw_data)
num_frames = 8 * (num_bytes -
2) - extra_bits # 8 frames for each byte.
self._common_audio_vad = np.zeros(num_frames, np.uint8)
# Read VAD decisions.
for i, byte in enumerate(raw_data[1:-1]):
byte = struct.unpack('B', byte)[0]
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
self._common_audio_vad[i * 8 + j] = int(byte & 1)
byte = byte >> 1
except Exception as e:
logging.error('Error while running the WebRTC VAD (' + e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
# Read VAD decisions.
for i, byte in enumerate(raw_data[1:-1]):
byte = struct.unpack('B', byte)[0]
for j in range(8 if i < num_bytes - 3 else (8 - extra_bits)):
self._common_audio_vad[i * 8 + j] = int(byte & 1)
byte = byte >> 1
except Exception as e:
logging.error('Error while running the WebRTC VAD (' + e.message +
')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
def _RunWebRtcApmVad(self, wav_file_path):
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path_probs = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
output_file_path_rms = os.path.join(
tmp_path, os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
def _RunWebRtcApmVad(self, wav_file_path):
# Create temporary output path.
tmp_path = tempfile.mkdtemp()
output_file_path_probs = os.path.join(
tmp_path,
os.path.split(wav_file_path)[1] + '_vad_probs.tmp')
output_file_path_rms = os.path.join(
tmp_path,
os.path.split(wav_file_path)[1] + '_vad_rms.tmp')
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_APM_PATH,
'-i', wav_file_path,
'-o_probs', output_file_path_probs,
'-o_rms', output_file_path_rms
], cwd=self._VAD_WEBRTC_PATH)
# Call WebRTC VAD.
try:
subprocess.call([
self._VAD_WEBRTC_APM_PATH, '-i', wav_file_path, '-o_probs',
output_file_path_probs, '-o_rms', output_file_path_rms
],
cwd=self._VAD_WEBRTC_PATH)
# Parse annotations.
self._apm_vad_probs = np.fromfile(output_file_path_probs, np.double)
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
# Parse annotations.
self._apm_vad_probs = np.fromfile(output_file_path_probs,
np.double)
self._apm_vad_rms = np.fromfile(output_file_path_rms, np.double)
assert len(self._apm_vad_rms) == len(self._apm_vad_probs)
except Exception as e:
logging.error('Error while running the WebRTC APM VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
except Exception as e:
logging.error('Error while running the WebRTC APM VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the annotations module.
"""
@ -25,133 +24,137 @@ from . import signal_processing
class TestAnnotationsExtraction(unittest.TestCase):
"""Unit tests for the annotations module.
"""Unit tests for the annotations module.
"""
_CLEAN_TMP_OUTPUT = True
_DEBUG_PLOT_VAD = False
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD |
_VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO |
_VAD_TYPE_CLASS.WEBRTC_APM)
_CLEAN_TMP_OUTPUT = True
_DEBUG_PLOT_VAD = False
_VAD_TYPE_CLASS = annotations.AudioAnnotationsExtractor.VadType
_ALL_VAD_TYPES = (_VAD_TYPE_CLASS.ENERGY_THRESHOLD
| _VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO
| _VAD_TYPE_CLASS.WEBRTC_APM)
def setUp(self):
"""Create temporary folder."""
self._tmp_path = tempfile.mkdtemp()
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
'pure_tone', [440, 1000])
signal_processing.SignalProcessingUtils.SaveWav(
self._wav_file_path, pure_tone)
self._sample_rate = pure_tone.frame_rate
def setUp(self):
"""Create temporary folder."""
self._tmp_path = tempfile.mkdtemp()
self._wav_file_path = os.path.join(self._tmp_path, 'tone.wav')
pure_tone, _ = input_signal_creator.InputSignalCreator.Create(
'pure_tone', [440, 1000])
signal_processing.SignalProcessingUtils.SaveWav(
self._wav_file_path, pure_tone)
self._sample_rate = pure_tone.frame_rate
def tearDown(self):
"""Recursively delete temporary folder."""
if self._CLEAN_TMP_OUTPUT:
shutil.rmtree(self._tmp_path)
else:
logging.warning(self.id() + ' did not clean the temporary path ' +
(self._tmp_path))
def tearDown(self):
"""Recursively delete temporary folder."""
if self._CLEAN_TMP_OUTPUT:
shutil.rmtree(self._tmp_path)
else:
logging.warning(self.id() + ' did not clean the temporary path ' + (
self._tmp_path))
def testFrameSizes(self):
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
samples_to_ms = lambda n, sr: 1000 * n // sr
self.assertEqual(
samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
e.GetLevelFrameSizeMs())
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
e.GetVadFrameSizeMs())
def testFrameSizes(self):
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
samples_to_ms = lambda n, sr: 1000 * n // sr
self.assertEqual(samples_to_ms(e.GetLevelFrameSize(), self._sample_rate),
e.GetLevelFrameSizeMs())
self.assertEqual(samples_to_ms(e.GetVadFrameSize(), self._sample_rate),
e.GetVadFrameSizeMs())
def testVoiceActivityDetectors(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
e.Extract(self._wav_file_path)
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(
self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(
float(np.sum(vad_output)) / len(vad_output), 0.95)
def testVoiceActivityDetectors(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
vad_type = self._VAD_TYPE_CLASS(vad_type_value)
e = annotations.AudioAnnotationsExtractor(vad_type=vad_type_value)
e.Extract(self._wav_file_path)
if vad_type.Contains(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
0.95)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(
self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(
float(np.sum(vad_output)) / len(vad_output), 0.95)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO):
# pylint: disable=unpacking-non-sequence
vad_output = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO)
self.assertGreater(len(vad_output), 0)
self.assertGreaterEqual(float(np.sum(vad_output)) / len(vad_output),
0.95)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
# pylint: disable=unpacking-non-sequence
(vad_probs,
vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
self.assertGreater(len(vad_probs), 0)
self.assertGreater(len(vad_rms), 0)
self.assertGreaterEqual(
float(np.sum(vad_probs)) / len(vad_probs), 0.5)
self.assertGreaterEqual(
float(np.sum(vad_rms)) / len(vad_rms), 20000)
if vad_type.Contains(self._VAD_TYPE_CLASS.WEBRTC_APM):
# pylint: disable=unpacking-non-sequence
(vad_probs, vad_rms) = e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)
self.assertGreater(len(vad_probs), 0)
self.assertGreater(len(vad_rms), 0)
self.assertGreaterEqual(float(np.sum(vad_probs)) / len(vad_probs),
0.5)
self.assertGreaterEqual(float(np.sum(vad_rms)) / len(vad_rms), 20000)
if self._DEBUG_PLOT_VAD:
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
num_frames).astype(np.float32) * frame_size_ms / 1000.0
level = e.GetLevel()
t_level = frame_times_s(num_frames=len(level),
frame_size_ms=e.GetLevelFrameSizeMs())
t_vad = frame_times_s(num_frames=len(vad_output),
frame_size_ms=e.GetVadFrameSizeMs())
import matplotlib.pyplot as plt
plt.figure()
plt.hold(True)
plt.plot(t_level, level)
plt.plot(t_vad, vad_output * np.max(level), '.')
plt.show()
if self._DEBUG_PLOT_VAD:
frame_times_s = lambda num_frames, frame_size_ms: np.arange(
num_frames).astype(np.float32) * frame_size_ms / 1000.0
level = e.GetLevel()
t_level = frame_times_s(
num_frames=len(level),
frame_size_ms=e.GetLevelFrameSizeMs())
t_vad = frame_times_s(
num_frames=len(vad_output),
frame_size_ms=e.GetVadFrameSizeMs())
import matplotlib.pyplot as plt
plt.figure()
plt.hold(True)
plt.plot(t_level, level)
plt.plot(t_vad, vad_output * np.max(level), '.')
plt.show()
def testSaveLoad(self):
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
e.Save(self._tmp_path, "fake-annotation")
def testSaveLoad(self):
e = annotations.AudioAnnotationsExtractor(self._ALL_VAD_TYPES)
e.Extract(self._wav_file_path)
e.Save(self._tmp_path, "fake-annotation")
data = np.load(
os.path.join(
self._tmp_path,
e.GetOutputFileNameTemplate().format("fake-annotation")))
np.testing.assert_array_equal(e.GetLevel(), data['level'])
self.assertEqual(np.float32, data['level'].dtype)
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
data['vad_energy_output'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
data['vad_output'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0],
data['vad_probs'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1],
data['vad_rms'])
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
self.assertEqual(np.float64, data['vad_probs'].dtype)
self.assertEqual(np.float64, data['vad_rms'].dtype)
data = np.load(os.path.join(
self._tmp_path,
e.GetOutputFileNameTemplate().format("fake-annotation")))
np.testing.assert_array_equal(e.GetLevel(), data['level'])
self.assertEqual(np.float32, data['level'].dtype)
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.ENERGY_THRESHOLD),
data['vad_energy_output'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_COMMON_AUDIO),
data['vad_output'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[0], data['vad_probs'])
np.testing.assert_array_equal(
e.GetVadOutput(self._VAD_TYPE_CLASS.WEBRTC_APM)[1], data['vad_rms'])
self.assertEqual(np.uint8, data['vad_energy_output'].dtype)
self.assertEqual(np.float64, data['vad_probs'].dtype)
self.assertEqual(np.float64, data['vad_rms'].dtype)
def testEmptyExternalShouldNotCrash(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
annotations.AudioAnnotationsExtractor(vad_type_value, {})
def testEmptyExternalShouldNotCrash(self):
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
annotations.AudioAnnotationsExtractor(vad_type_value, {})
def testFakeExternalSaveLoad(self):
def FakeExternalFactory():
return external_vad.ExternalVad(
os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fake_external_vad.py'), 'fake')
def testFakeExternalSaveLoad(self):
def FakeExternalFactory():
return external_vad.ExternalVad(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_external_vad.py'),
'fake'
)
for vad_type_value in range(0, self._ALL_VAD_TYPES+1):
e = annotations.AudioAnnotationsExtractor(
vad_type_value,
{'fake': FakeExternalFactory()})
e.Extract(self._wav_file_path)
e.Save(self._tmp_path, annotation_name="fake-annotation")
data = np.load(os.path.join(
self._tmp_path,
e.GetOutputFileNameTemplate().format("fake-annotation")))
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
data['extvad_conf-fake'])
for vad_type_value in range(0, self._ALL_VAD_TYPES + 1):
e = annotations.AudioAnnotationsExtractor(
vad_type_value, {'fake': FakeExternalFactory()})
e.Extract(self._wav_file_path)
e.Save(self._tmp_path, annotation_name="fake-annotation")
data = np.load(
os.path.join(
self._tmp_path,
e.GetOutputFileNameTemplate().format("fake-annotation")))
self.assertEqual(np.float32, data['extvad_conf-fake'].dtype)
np.testing.assert_almost_equal(np.arange(100, dtype=np.float32),
data['extvad_conf-fake'])

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Class implementing a wrapper for APM simulators.
"""
@ -19,33 +18,36 @@ from . import exceptions
class AudioProcWrapper(object):
"""Wrapper for APM simulators.
"""Wrapper for APM simulators.
"""
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(os.path.join(
os.pardir, 'audioproc_f'))
OUTPUT_FILENAME = 'output.wav'
DEFAULT_APM_SIMULATOR_BIN_PATH = os.path.abspath(
os.path.join(os.pardir, 'audioproc_f'))
OUTPUT_FILENAME = 'output.wav'
def __init__(self, simulator_bin_path):
"""Ctor.
def __init__(self, simulator_bin_path):
"""Ctor.
Args:
simulator_bin_path: path to the APM simulator binary.
"""
self._simulator_bin_path = simulator_bin_path
self._config = None
self._output_signal_filepath = None
self._simulator_bin_path = simulator_bin_path
self._config = None
self._output_signal_filepath = None
# Profiler instance to measure running time.
self._profiler = cProfile.Profile()
# Profiler instance to measure running time.
self._profiler = cProfile.Profile()
@property
def output_filepath(self):
return self._output_signal_filepath
@property
def output_filepath(self):
return self._output_signal_filepath
def Run(self, config_filepath, capture_input_filepath, output_path,
render_input_filepath=None):
"""Runs APM simulator.
def Run(self,
config_filepath,
capture_input_filepath,
output_path,
render_input_filepath=None):
"""Runs APM simulator.
Args:
config_filepath: path to the configuration file specifying the arguments
@ -56,41 +58,43 @@ class AudioProcWrapper(object):
render_input_filepath: path to the render audio track input file (aka
reverse or far-end).
"""
# Init.
self._output_signal_filepath = os.path.join(
output_path, self.OUTPUT_FILENAME)
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
# Init.
self._output_signal_filepath = os.path.join(output_path,
self.OUTPUT_FILENAME)
profiling_stats_filepath = os.path.join(output_path, 'profiling.stats')
# Skip if the output has already been generated.
if os.path.exists(self._output_signal_filepath) and os.path.exists(
profiling_stats_filepath):
return
# Skip if the output has already been generated.
if os.path.exists(self._output_signal_filepath) and os.path.exists(
profiling_stats_filepath):
return
# Load configuration.
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
# Load configuration.
self._config = data_access.AudioProcConfigFile.Load(config_filepath)
# Set remaining parameters.
if not os.path.exists(capture_input_filepath):
raise exceptions.FileNotFoundError('cannot find capture input file')
self._config['-i'] = capture_input_filepath
self._config['-o'] = self._output_signal_filepath
if render_input_filepath is not None:
if not os.path.exists(render_input_filepath):
raise exceptions.FileNotFoundError('cannot find render input file')
self._config['-ri'] = render_input_filepath
# Set remaining parameters.
if not os.path.exists(capture_input_filepath):
raise exceptions.FileNotFoundError(
'cannot find capture input file')
self._config['-i'] = capture_input_filepath
self._config['-o'] = self._output_signal_filepath
if render_input_filepath is not None:
if not os.path.exists(render_input_filepath):
raise exceptions.FileNotFoundError(
'cannot find render input file')
self._config['-ri'] = render_input_filepath
# Build arguments list.
args = [self._simulator_bin_path]
for param_name in self._config:
args.append(param_name)
if self._config[param_name] is not None:
args.append(str(self._config[param_name]))
logging.debug(' '.join(args))
# Build arguments list.
args = [self._simulator_bin_path]
for param_name in self._config:
args.append(param_name)
if self._config[param_name] is not None:
args.append(str(self._config[param_name]))
logging.debug(' '.join(args))
# Run.
self._profiler.enable()
subprocess.call(args)
self._profiler.disable()
# Run.
self._profiler.enable()
subprocess.call(args)
self._profiler.disable()
# Save profiling stats.
self._profiler.dump_stats(profiling_stats_filepath)
# Save profiling stats.
self._profiler.dump_stats(profiling_stats_filepath)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Imports a filtered subset of the scores and configurations computed
by apm_quality_assessment.py into a pandas data frame.
"""
@ -18,71 +17,88 @@ import re
import sys
try:
import pandas as pd
import pandas as pd
except ImportError:
logging.critical('Cannot import the third-party Python package pandas')
sys.exit(1)
logging.critical('Cannot import the third-party Python package pandas')
sys.exit(1)
from . import data_access as data_access
from . import simulation as sim
# Compiled regular expressions used to extract score descriptors.
RE_CONFIG_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixApmConfig() + r'(.+)')
RE_CAPTURE_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixCapture() + r'(.+)')
RE_RENDER_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
RE_ECHO_SIM_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + r'(.+)')
RE_CONFIG_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixApmConfig() +
r'(.+)')
RE_CAPTURE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixCapture() +
r'(.+)')
RE_RENDER_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixRender() + r'(.+)')
RE_ECHO_SIM_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixEchoSimulator() +
r'(.+)')
RE_TEST_DATA_GEN_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + r'(.+)')
RE_TEST_DATA_GEN_PARAMS = re.compile(
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + r'(.+)')
RE_SCORE_NAME = re.compile(
sim.ApmModuleSimulator.GetPrefixScore() + r'(.+)(\..+)')
RE_SCORE_NAME = re.compile(sim.ApmModuleSimulator.GetPrefixScore() +
r'(.+)(\..+)')
def InstanceArgumentsParser():
"""Arguments parser factory.
"""Arguments parser factory.
"""
parser = argparse.ArgumentParser(description=(
'Override this description in a user script by changing'
' `parser.description` of the returned parser.'))
parser = argparse.ArgumentParser(
description=('Override this description in a user script by changing'
' `parser.description` of the returned parser.'))
parser.add_argument('-o', '--output_dir', required=True,
help=('the same base path used with the '
'apm_quality_assessment tool'))
parser.add_argument('-o',
'--output_dir',
required=True,
help=('the same base path used with the '
'apm_quality_assessment tool'))
parser.add_argument('-c', '--config_names', type=re.compile,
help=('regular expression to filter the APM configuration'
' names'))
parser.add_argument(
'-c',
'--config_names',
type=re.compile,
help=('regular expression to filter the APM configuration'
' names'))
parser.add_argument('-i', '--capture_names', type=re.compile,
help=('regular expression to filter the capture signal '
'names'))
parser.add_argument(
'-i',
'--capture_names',
type=re.compile,
help=('regular expression to filter the capture signal '
'names'))
parser.add_argument('-r', '--render_names', type=re.compile,
help=('regular expression to filter the render signal '
'names'))
parser.add_argument('-r',
'--render_names',
type=re.compile,
help=('regular expression to filter the render signal '
'names'))
parser.add_argument('-e', '--echo_simulator_names', type=re.compile,
help=('regular expression to filter the echo simulator '
'names'))
parser.add_argument(
'-e',
'--echo_simulator_names',
type=re.compile,
help=('regular expression to filter the echo simulator '
'names'))
parser.add_argument('-t', '--test_data_generators', type=re.compile,
help=('regular expression to filter the test data '
'generator names'))
parser.add_argument('-t',
'--test_data_generators',
type=re.compile,
help=('regular expression to filter the test data '
'generator names'))
parser.add_argument('-s', '--eval_scores', type=re.compile,
help=('regular expression to filter the evaluation score '
'names'))
parser.add_argument(
'-s',
'--eval_scores',
type=re.compile,
help=('regular expression to filter the evaluation score '
'names'))
return parser
return parser
def _GetScoreDescriptors(score_filepath):
"""Extracts a score descriptor from the given score file path.
"""Extracts a score descriptor from the given score file path.
Args:
score_filepath: path to the score file.
@ -92,23 +108,23 @@ def _GetScoreDescriptors(score_filepath):
render audio track name, echo simulator name, test data generator name,
test data generator parameters as string, evaluation score name).
"""
fields = score_filepath.split(os.sep)[-7:]
extract_name = lambda index, reg_expr: (
reg_expr.match(fields[index]).groups(0)[0])
return (
extract_name(0, RE_CONFIG_NAME),
extract_name(1, RE_CAPTURE_NAME),
extract_name(2, RE_RENDER_NAME),
extract_name(3, RE_ECHO_SIM_NAME),
extract_name(4, RE_TEST_DATA_GEN_NAME),
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
extract_name(6, RE_SCORE_NAME),
)
fields = score_filepath.split(os.sep)[-7:]
extract_name = lambda index, reg_expr: (reg_expr.match(fields[index]).
groups(0)[0])
return (
extract_name(0, RE_CONFIG_NAME),
extract_name(1, RE_CAPTURE_NAME),
extract_name(2, RE_RENDER_NAME),
extract_name(3, RE_ECHO_SIM_NAME),
extract_name(4, RE_TEST_DATA_GEN_NAME),
extract_name(5, RE_TEST_DATA_GEN_PARAMS),
extract_name(6, RE_SCORE_NAME),
)
def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
test_data_gen_name, score_name, args):
"""Decides whether excluding a score.
"""Decides whether excluding a score.
A set of optional regular expressions in args is used to determine if the
score should be excluded (depending on its |*_name| descriptors).
@ -125,27 +141,27 @@ def _ExcludeScore(config_name, capture_name, render_name, echo_simulator_name,
Returns:
A boolean.
"""
value_regexpr_pairs = [
(config_name, args.config_names),
(capture_name, args.capture_names),
(render_name, args.render_names),
(echo_simulator_name, args.echo_simulator_names),
(test_data_gen_name, args.test_data_generators),
(score_name, args.eval_scores),
]
value_regexpr_pairs = [
(config_name, args.config_names),
(capture_name, args.capture_names),
(render_name, args.render_names),
(echo_simulator_name, args.echo_simulator_names),
(test_data_gen_name, args.test_data_generators),
(score_name, args.eval_scores),
]
# Score accepted if each value matches the corresponding regular expression.
for value, regexpr in value_regexpr_pairs:
if regexpr is None:
continue
if not regexpr.match(value):
return True
# Score accepted if each value matches the corresponding regular expression.
for value, regexpr in value_regexpr_pairs:
if regexpr is None:
continue
if not regexpr.match(value):
return True
return False
return False
def FindScores(src_path, args):
"""Given a search path, find scores and return a DataFrame object.
"""Given a search path, find scores and return a DataFrame object.
Args:
src_path: Search path pattern.
@ -154,89 +170,74 @@ def FindScores(src_path, args):
Returns:
A DataFrame object.
"""
# Get scores.
scores = []
for score_filepath in glob.iglob(src_path):
# Extract score descriptor fields from the path.
(config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
test_data_gen_params,
score_name) = _GetScoreDescriptors(score_filepath)
# Get scores.
scores = []
for score_filepath in glob.iglob(src_path):
# Extract score descriptor fields from the path.
(config_name, capture_name, render_name, echo_simulator_name,
test_data_gen_name, test_data_gen_params,
score_name) = _GetScoreDescriptors(score_filepath)
# Ignore the score if required.
if _ExcludeScore(
config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
score_name,
args):
logging.info(
'ignored score: %s %s %s %s %s %s',
config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
score_name)
continue
# Ignore the score if required.
if _ExcludeScore(config_name, capture_name, render_name,
echo_simulator_name, test_data_gen_name, score_name,
args):
logging.info('ignored score: %s %s %s %s %s %s', config_name,
capture_name, render_name, echo_simulator_name,
test_data_gen_name, score_name)
continue
# Read metadata and score.
metadata = data_access.Metadata.LoadAudioTestDataPaths(
os.path.split(score_filepath)[0])
score = data_access.ScoreFile.Load(score_filepath)
# Read metadata and score.
metadata = data_access.Metadata.LoadAudioTestDataPaths(
os.path.split(score_filepath)[0])
score = data_access.ScoreFile.Load(score_filepath)
# Add a score with its descriptor fields.
scores.append((
metadata['clean_capture_input_filepath'],
metadata['echo_free_capture_filepath'],
metadata['echo_filepath'],
metadata['render_filepath'],
metadata['capture_filepath'],
metadata['apm_output_filepath'],
metadata['apm_reference_filepath'],
config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
test_data_gen_params,
score_name,
score,
))
# Add a score with its descriptor fields.
scores.append((
metadata['clean_capture_input_filepath'],
metadata['echo_free_capture_filepath'],
metadata['echo_filepath'],
metadata['render_filepath'],
metadata['capture_filepath'],
metadata['apm_output_filepath'],
metadata['apm_reference_filepath'],
config_name,
capture_name,
render_name,
echo_simulator_name,
test_data_gen_name,
test_data_gen_params,
score_name,
score,
))
return pd.DataFrame(
data=scores,
columns=(
'clean_capture_input_filepath',
'echo_free_capture_filepath',
'echo_filepath',
'render_filepath',
'capture_filepath',
'apm_output_filepath',
'apm_reference_filepath',
'apm_config',
'capture',
'render',
'echo_simulator',
'test_data_gen',
'test_data_gen_params',
'eval_score_name',
'score',
))
return pd.DataFrame(data=scores,
columns=(
'clean_capture_input_filepath',
'echo_free_capture_filepath',
'echo_filepath',
'render_filepath',
'capture_filepath',
'apm_output_filepath',
'apm_reference_filepath',
'apm_config',
'capture',
'render',
'echo_simulator',
'test_data_gen',
'test_data_gen_params',
'eval_score_name',
'score',
))
def ConstructSrcPath(args):
return os.path.join(
args.output_dir,
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
sim.ApmModuleSimulator.GetPrefixRender() + '*',
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
sim.ApmModuleSimulator.GetPrefixScore() + '*')
return os.path.join(
args.output_dir,
sim.ApmModuleSimulator.GetPrefixApmConfig() + '*',
sim.ApmModuleSimulator.GetPrefixCapture() + '*',
sim.ApmModuleSimulator.GetPrefixRender() + '*',
sim.ApmModuleSimulator.GetPrefixEchoSimulator() + '*',
sim.ApmModuleSimulator.GetPrefixTestDataGenerator() + '*',
sim.ApmModuleSimulator.GetPrefixTestDataGeneratorParameters() + '*',
sim.ApmModuleSimulator.GetPrefixScore() + '*')

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Data access utility functions and classes.
"""
@ -14,29 +13,29 @@ import os
def MakeDirectory(path):
"""Makes a directory recursively without rising exceptions if existing.
"""Makes a directory recursively without rising exceptions if existing.
Args:
path: path to the directory to be created.
"""
if os.path.exists(path):
return
os.makedirs(path)
if os.path.exists(path):
return
os.makedirs(path)
class Metadata(object):
"""Data access class to save and load metadata.
"""Data access class to save and load metadata.
"""
def __init__(self):
pass
def __init__(self):
pass
_GENERIC_METADATA_SUFFIX = '.mdata'
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
_GENERIC_METADATA_SUFFIX = '.mdata'
_AUDIO_TEST_DATA_FILENAME = 'audio_test_data.json'
@classmethod
def LoadFileMetadata(cls, filepath):
"""Loads generic metadata linked to a file.
@classmethod
def LoadFileMetadata(cls, filepath):
"""Loads generic metadata linked to a file.
Args:
filepath: path to the metadata file to read.
@ -44,23 +43,23 @@ class Metadata(object):
Returns:
A dict.
"""
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
return json.load(f)
with open(filepath + cls._GENERIC_METADATA_SUFFIX) as f:
return json.load(f)
@classmethod
def SaveFileMetadata(cls, filepath, metadata):
"""Saves generic metadata linked to a file.
@classmethod
def SaveFileMetadata(cls, filepath, metadata):
"""Saves generic metadata linked to a file.
Args:
filepath: path to the metadata file to write.
metadata: a dict.
"""
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
json.dump(metadata, f)
with open(filepath + cls._GENERIC_METADATA_SUFFIX, 'w') as f:
json.dump(metadata, f)
@classmethod
def LoadAudioTestDataPaths(cls, metadata_path):
"""Loads the input and the reference audio track paths.
@classmethod
def LoadAudioTestDataPaths(cls, metadata_path):
"""Loads the input and the reference audio track paths.
Args:
metadata_path: path to the directory containing the metadata file.
@ -68,14 +67,14 @@ class Metadata(object):
Returns:
Tuple with the paths to the input and output audio tracks.
"""
metadata_filepath = os.path.join(
metadata_path, cls._AUDIO_TEST_DATA_FILENAME)
with open(metadata_filepath) as f:
return json.load(f)
metadata_filepath = os.path.join(metadata_path,
cls._AUDIO_TEST_DATA_FILENAME)
with open(metadata_filepath) as f:
return json.load(f)
@classmethod
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
"""Saves the input and the reference audio track paths.
@classmethod
def SaveAudioTestDataPaths(cls, output_path, **filepaths):
"""Saves the input and the reference audio track paths.
Args:
output_path: path to the directory containing the metadata file.
@ -83,23 +82,24 @@ class Metadata(object):
Keyword Args:
filepaths: collection of audio track file paths to save.
"""
output_filepath = os.path.join(output_path, cls._AUDIO_TEST_DATA_FILENAME)
with open(output_filepath, 'w') as f:
json.dump(filepaths, f)
output_filepath = os.path.join(output_path,
cls._AUDIO_TEST_DATA_FILENAME)
with open(output_filepath, 'w') as f:
json.dump(filepaths, f)
class AudioProcConfigFile(object):
"""Data access to load/save APM simulator argument lists.
"""Data access to load/save APM simulator argument lists.
The arguments stored in the config files are used to control the APM flags.
"""
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def Load(cls, filepath):
"""Loads a configuration file for an APM simulator.
@classmethod
def Load(cls, filepath):
"""Loads a configuration file for an APM simulator.
Args:
filepath: path to the configuration file.
@ -107,31 +107,31 @@ class AudioProcConfigFile(object):
Returns:
A dict containing the configuration.
"""
with open(filepath) as f:
return json.load(f)
with open(filepath) as f:
return json.load(f)
@classmethod
def Save(cls, filepath, config):
"""Saves a configuration file for an APM simulator.
@classmethod
def Save(cls, filepath, config):
"""Saves a configuration file for an APM simulator.
Args:
filepath: path to the configuration file.
config: a dict containing the configuration.
"""
with open(filepath, 'w') as f:
json.dump(config, f)
with open(filepath, 'w') as f:
json.dump(config, f)
class ScoreFile(object):
"""Data access class to save and load float scalar scores.
"""Data access class to save and load float scalar scores.
"""
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def Load(cls, filepath):
"""Loads a score from file.
@classmethod
def Load(cls, filepath):
"""Loads a score from file.
Args:
filepath: path to the score file.
@ -139,16 +139,16 @@ class ScoreFile(object):
Returns:
A float encoding the score.
"""
with open(filepath) as f:
return float(f.readline().strip())
with open(filepath) as f:
return float(f.readline().strip())
@classmethod
def Save(cls, filepath, score):
"""Saves a score into a file.
@classmethod
def Save(cls, filepath, score):
"""Saves a score into a file.
Args:
filepath: path to the score file.
score: float encoding the score.
"""
with open(filepath, 'w') as f:
f.write('{0:f}\n'.format(score))
with open(filepath, 'w') as f:
f.write('{0:f}\n'.format(score))

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Echo path simulation module.
"""
@ -16,21 +15,21 @@ from . import signal_processing
class EchoPathSimulator(object):
"""Abstract class for the echo path simulators.
"""Abstract class for the echo path simulators.
In general, an echo path simulator is a function of the render signal and
simulates the propagation of the latter into the microphone (e.g., due to
mechanical or electrical paths).
"""
NAME = None
REGISTERED_CLASSES = {}
NAME = None
REGISTERED_CLASSES = {}
def __init__(self):
pass
def __init__(self):
pass
def Simulate(self, output_path):
"""Creates the echo signal and stores it in an audio file (abstract method).
def Simulate(self, output_path):
"""Creates the echo signal and stores it in an audio file (abstract method).
Args:
output_path: Path in which any output can be saved.
@ -38,11 +37,11 @@ class EchoPathSimulator(object):
Returns:
Path to the generated audio track file or None if no echo is present.
"""
raise NotImplementedError()
raise NotImplementedError()
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers an EchoPathSimulator implementation.
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers an EchoPathSimulator implementation.
Decorator to automatically register the classes that extend
EchoPathSimulator.
@ -52,85 +51,86 @@ class EchoPathSimulator(object):
class NoEchoPathSimulator(EchoPathSimulator):
pass
"""
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
@EchoPathSimulator.RegisterClass
class NoEchoPathSimulator(EchoPathSimulator):
"""Simulates absence of echo."""
"""Simulates absence of echo."""
NAME = 'noecho'
NAME = 'noecho'
def __init__(self):
EchoPathSimulator.__init__(self)
def __init__(self):
EchoPathSimulator.__init__(self)
def Simulate(self, output_path):
return None
def Simulate(self, output_path):
return None
@EchoPathSimulator.RegisterClass
class LinearEchoPathSimulator(EchoPathSimulator):
"""Simulates linear echo path.
"""Simulates linear echo path.
This class applies a given impulse response to the render input and then it
sums the signal to the capture input signal.
"""
NAME = 'linear'
NAME = 'linear'
def __init__(self, render_input_filepath, impulse_response):
"""
def __init__(self, render_input_filepath, impulse_response):
"""
Args:
render_input_filepath: Render audio track file.
impulse_response: list or numpy vector of float values.
"""
EchoPathSimulator.__init__(self)
self._render_input_filepath = render_input_filepath
self._impulse_response = impulse_response
EchoPathSimulator.__init__(self)
self._render_input_filepath = render_input_filepath
self._impulse_response = impulse_response
def Simulate(self, output_path):
"""Simulates linear echo path."""
# Form the file name with a hash of the impulse response.
impulse_response_hash = hashlib.sha256(
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
echo_filepath = os.path.join(output_path, 'linear_echo_{}.wav'.format(
impulse_response_hash))
def Simulate(self, output_path):
"""Simulates linear echo path."""
# Form the file name with a hash of the impulse response.
impulse_response_hash = hashlib.sha256(
str(self._impulse_response).encode('utf-8', 'ignore')).hexdigest()
echo_filepath = os.path.join(
output_path, 'linear_echo_{}.wav'.format(impulse_response_hash))
# If the simulated echo audio track file does not exists, create it.
if not os.path.exists(echo_filepath):
render = signal_processing.SignalProcessingUtils.LoadWav(
self._render_input_filepath)
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
render, self._impulse_response)
signal_processing.SignalProcessingUtils.SaveWav(echo_filepath, echo)
# If the simulated echo audio track file does not exists, create it.
if not os.path.exists(echo_filepath):
render = signal_processing.SignalProcessingUtils.LoadWav(
self._render_input_filepath)
echo = signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
render, self._impulse_response)
signal_processing.SignalProcessingUtils.SaveWav(
echo_filepath, echo)
return echo_filepath
return echo_filepath
@EchoPathSimulator.RegisterClass
class RecordedEchoPathSimulator(EchoPathSimulator):
"""Uses recorded echo.
"""Uses recorded echo.
This class uses the clean capture input file name to build the file name of
the corresponding recording containing echo (a predefined suffix is used).
Such a file is expected to be already existing.
"""
NAME = 'recorded'
NAME = 'recorded'
_FILE_NAME_SUFFIX = '_echo'
_FILE_NAME_SUFFIX = '_echo'
def __init__(self, render_input_filepath):
EchoPathSimulator.__init__(self)
self._render_input_filepath = render_input_filepath
def __init__(self, render_input_filepath):
EchoPathSimulator.__init__(self)
self._render_input_filepath = render_input_filepath
def Simulate(self, output_path):
"""Uses recorded echo path."""
path, file_name_ext = os.path.split(self._render_input_filepath)
file_name, file_ext = os.path.splitext(file_name_ext)
echo_filepath = os.path.join(path, '{}{}{}'.format(
file_name, self._FILE_NAME_SUFFIX, file_ext))
assert os.path.exists(echo_filepath), (
'cannot find the echo audio track file {}'.format(echo_filepath))
return echo_filepath
def Simulate(self, output_path):
"""Uses recorded echo path."""
path, file_name_ext = os.path.split(self._render_input_filepath)
file_name, file_ext = os.path.splitext(file_name_ext)
echo_filepath = os.path.join(
path, '{}{}{}'.format(file_name, self._FILE_NAME_SUFFIX, file_ext))
assert os.path.exists(echo_filepath), (
'cannot find the echo audio track file {}'.format(echo_filepath))
return echo_filepath

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Echo path simulation factory module.
"""
@ -16,16 +15,16 @@ from . import echo_path_simulation
class EchoPathSimulatorFactory(object):
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
# realistic impulse response.
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0]*(20 * 48) + [0.15])
# TODO(alessiob): Replace 20 ms delay (at 48 kHz sample rate) with a more
# realistic impulse response.
_LINEAR_ECHO_IMPULSE_RESPONSE = np.array([0.0] * (20 * 48) + [0.15])
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
"""Creates an EchoPathSimulator instance given a class object.
@classmethod
def GetInstance(cls, echo_path_simulator_class, render_input_filepath):
"""Creates an EchoPathSimulator instance given a class object.
Args:
echo_path_simulator_class: EchoPathSimulator class object (not an
@ -35,14 +34,15 @@ class EchoPathSimulatorFactory(object):
Returns:
An EchoPathSimulator instance.
"""
assert render_input_filepath is not None or (
echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator)
assert render_input_filepath is not None or (
echo_path_simulator_class ==
echo_path_simulation.NoEchoPathSimulator)
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
return echo_path_simulation.NoEchoPathSimulator()
elif echo_path_simulator_class == (
echo_path_simulation.LinearEchoPathSimulator):
return echo_path_simulation.LinearEchoPathSimulator(
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
else:
return echo_path_simulator_class(render_input_filepath)
if echo_path_simulator_class == echo_path_simulation.NoEchoPathSimulator:
return echo_path_simulation.NoEchoPathSimulator()
elif echo_path_simulator_class == (
echo_path_simulation.LinearEchoPathSimulator):
return echo_path_simulation.LinearEchoPathSimulator(
render_input_filepath, cls._LINEAR_ECHO_IMPULSE_RESPONSE)
else:
return echo_path_simulator_class(render_input_filepath)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the echo path simulation module.
"""
@ -22,60 +21,62 @@ from . import signal_processing
class TestEchoPathSimulators(unittest.TestCase):
"""Unit tests for the eval_scores module.
"""Unit tests for the eval_scores module.
"""
def setUp(self):
"""Creates temporary data."""
self._tmp_path = tempfile.mkdtemp()
def setUp(self):
"""Creates temporary data."""
self._tmp_path = tempfile.mkdtemp()
# Create and save white noise.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._audio_track_num_samples = (
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
self._audio_track_filepath = os.path.join(self._tmp_path, 'white_noise.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._audio_track_filepath, white_noise)
# Create and save white noise.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._audio_track_num_samples = (
signal_processing.SignalProcessingUtils.CountSamples(white_noise))
self._audio_track_filepath = os.path.join(self._tmp_path,
'white_noise.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._audio_track_filepath, white_noise)
# Make a copy the white noise audio track file; it will be used by
# echo_path_simulation.RecordedEchoPathSimulator.
shutil.copy(self._audio_track_filepath, os.path.join(
self._tmp_path, 'white_noise_echo.wav'))
# Make a copy the white noise audio track file; it will be used by
# echo_path_simulation.RecordedEchoPathSimulator.
shutil.copy(self._audio_track_filepath,
os.path.join(self._tmp_path, 'white_noise_echo.wav'))
def tearDown(self):
"""Recursively deletes temporary folders."""
shutil.rmtree(self._tmp_path)
def tearDown(self):
"""Recursively deletes temporary folders."""
shutil.rmtree(self._tmp_path)
def testRegisteredClasses(self):
# Check that there is at least one registered echo path simulator.
registered_classes = (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
def testRegisteredClasses(self):
# Check that there is at least one registered echo path simulator.
registered_classes = (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES)
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Instance factory.
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
# Instance factory.
factory = echo_path_simulation_factory.EchoPathSimulatorFactory()
# Try each registered echo path simulator.
for echo_path_simulator_name in registered_classes:
simulator = factory.GetInstance(
echo_path_simulator_class=registered_classes[
echo_path_simulator_name],
render_input_filepath=self._audio_track_filepath)
# Try each registered echo path simulator.
for echo_path_simulator_name in registered_classes:
simulator = factory.GetInstance(
echo_path_simulator_class=registered_classes[
echo_path_simulator_name],
render_input_filepath=self._audio_track_filepath)
echo_filepath = simulator.Simulate(self._tmp_path)
if echo_filepath is None:
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
echo_path_simulator_name)
# No other tests in this case.
continue
echo_filepath = simulator.Simulate(self._tmp_path)
if echo_filepath is None:
self.assertEqual(echo_path_simulation.NoEchoPathSimulator.NAME,
echo_path_simulator_name)
# No other tests in this case.
continue
# Check that the echo audio track file exists and its length is greater or
# equal to that of the render audio track.
self.assertTrue(os.path.exists(echo_filepath))
echo = signal_processing.SignalProcessingUtils.LoadWav(echo_filepath)
self.assertGreaterEqual(
signal_processing.SignalProcessingUtils.CountSamples(echo),
self._audio_track_num_samples)
# Check that the echo audio track file exists and its length is greater or
# equal to that of the render audio track.
self.assertTrue(os.path.exists(echo_filepath))
echo = signal_processing.SignalProcessingUtils.LoadWav(
echo_filepath)
self.assertGreaterEqual(
signal_processing.SignalProcessingUtils.CountSamples(echo),
self._audio_track_num_samples)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Evaluation score abstract class and implementations.
"""
@ -17,10 +16,10 @@ import subprocess
import sys
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import data_access
from . import exceptions
@ -29,23 +28,23 @@ from . import signal_processing
class EvaluationScore(object):
NAME = None
REGISTERED_CLASSES = {}
NAME = None
REGISTERED_CLASSES = {}
def __init__(self, score_filename_prefix):
self._score_filename_prefix = score_filename_prefix
self._input_signal_metadata = None
self._reference_signal = None
self._reference_signal_filepath = None
self._tested_signal = None
self._tested_signal_filepath = None
self._output_filepath = None
self._score = None
self._render_signal_filepath = None
def __init__(self, score_filename_prefix):
self._score_filename_prefix = score_filename_prefix
self._input_signal_metadata = None
self._reference_signal = None
self._reference_signal_filepath = None
self._tested_signal = None
self._tested_signal_filepath = None
self._output_filepath = None
self._score = None
self._render_signal_filepath = None
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers an EvaluationScore implementation.
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers an EvaluationScore implementation.
Decorator to automatically register the classes that extend EvaluationScore.
Example usage:
@ -54,91 +53,90 @@ class EvaluationScore(object):
class AudioLevelScore(EvaluationScore):
pass
"""
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
@property
def output_filepath(self):
return self._output_filepath
@property
def output_filepath(self):
return self._output_filepath
@property
def score(self):
return self._score
@property
def score(self):
return self._score
def SetInputSignalMetadata(self, metadata):
"""Sets input signal metadata.
def SetInputSignalMetadata(self, metadata):
"""Sets input signal metadata.
Args:
metadata: dict instance.
"""
self._input_signal_metadata = metadata
self._input_signal_metadata = metadata
def SetReferenceSignalFilepath(self, filepath):
"""Sets the path to the audio track used as reference signal.
def SetReferenceSignalFilepath(self, filepath):
"""Sets the path to the audio track used as reference signal.
Args:
filepath: path to the reference audio track.
"""
self._reference_signal_filepath = filepath
self._reference_signal_filepath = filepath
def SetTestedSignalFilepath(self, filepath):
"""Sets the path to the audio track used as test signal.
def SetTestedSignalFilepath(self, filepath):
"""Sets the path to the audio track used as test signal.
Args:
filepath: path to the test audio track.
"""
self._tested_signal_filepath = filepath
self._tested_signal_filepath = filepath
def SetRenderSignalFilepath(self, filepath):
"""Sets the path to the audio track used as render signal.
def SetRenderSignalFilepath(self, filepath):
"""Sets the path to the audio track used as render signal.
Args:
filepath: path to the test audio track.
"""
self._render_signal_filepath = filepath
self._render_signal_filepath = filepath
def Run(self, output_path):
"""Extracts the score for the set test data pair.
def Run(self, output_path):
"""Extracts the score for the set test data pair.
Args:
output_path: path to the directory where the output is written.
"""
self._output_filepath = os.path.join(
output_path, self._score_filename_prefix + self.NAME + '.txt')
try:
# If the score has already been computed, load.
self._LoadScore()
logging.debug('score found and loaded')
except IOError:
# Compute the score.
logging.debug('score not found, compute')
self._Run(output_path)
self._output_filepath = os.path.join(
output_path, self._score_filename_prefix + self.NAME + '.txt')
try:
# If the score has already been computed, load.
self._LoadScore()
logging.debug('score found and loaded')
except IOError:
# Compute the score.
logging.debug('score not found, compute')
self._Run(output_path)
def _Run(self, output_path):
# Abstract method.
raise NotImplementedError()
def _Run(self, output_path):
# Abstract method.
raise NotImplementedError()
def _LoadReferenceSignal(self):
assert self._reference_signal_filepath is not None
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
self._reference_signal_filepath)
def _LoadReferenceSignal(self):
assert self._reference_signal_filepath is not None
self._reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
self._reference_signal_filepath)
def _LoadTestedSignal(self):
assert self._tested_signal_filepath is not None
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
self._tested_signal_filepath)
def _LoadTestedSignal(self):
assert self._tested_signal_filepath is not None
self._tested_signal = signal_processing.SignalProcessingUtils.LoadWav(
self._tested_signal_filepath)
def _LoadScore(self):
return data_access.ScoreFile.Load(self._output_filepath)
def _LoadScore(self):
return data_access.ScoreFile.Load(self._output_filepath)
def _SaveScore(self):
return data_access.ScoreFile.Save(self._output_filepath, self._score)
def _SaveScore(self):
return data_access.ScoreFile.Save(self._output_filepath, self._score)
@EvaluationScore.RegisterClass
class AudioLevelPeakScore(EvaluationScore):
"""Peak audio level score.
"""Peak audio level score.
Defined as the difference between the peak audio level of the tested and
the reference signals.
@ -148,21 +146,21 @@ class AudioLevelPeakScore(EvaluationScore):
Worst case: +/-inf dB
"""
NAME = 'audio_level_peak'
NAME = 'audio_level_peak'
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
def _Run(self, output_path):
self._LoadReferenceSignal()
self._LoadTestedSignal()
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
self._SaveScore()
def _Run(self, output_path):
self._LoadReferenceSignal()
self._LoadTestedSignal()
self._score = self._tested_signal.dBFS - self._reference_signal.dBFS
self._SaveScore()
@EvaluationScore.RegisterClass
class MeanAudioLevelScore(EvaluationScore):
"""Mean audio level score.
"""Mean audio level score.
Defined as the difference between the mean audio level of the tested and
the reference signals.
@ -172,29 +170,30 @@ class MeanAudioLevelScore(EvaluationScore):
Worst case: +/-inf dB
"""
NAME = 'audio_level_mean'
NAME = 'audio_level_mean'
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
def _Run(self, output_path):
self._LoadReferenceSignal()
self._LoadTestedSignal()
def _Run(self, output_path):
self._LoadReferenceSignal()
self._LoadTestedSignal()
dbfs_diffs_sum = 0.0
seconds = min(len(self._tested_signal), len(self._reference_signal)) // 1000
for t in range(seconds):
t0 = t * seconds
t1 = t0 + seconds
dbfs_diffs_sum += (
self._tested_signal[t0:t1].dBFS - self._reference_signal[t0:t1].dBFS)
self._score = dbfs_diffs_sum / float(seconds)
self._SaveScore()
dbfs_diffs_sum = 0.0
seconds = min(len(self._tested_signal), len(
self._reference_signal)) // 1000
for t in range(seconds):
t0 = t * seconds
t1 = t0 + seconds
dbfs_diffs_sum += (self._tested_signal[t0:t1].dBFS -
self._reference_signal[t0:t1].dBFS)
self._score = dbfs_diffs_sum / float(seconds)
self._SaveScore()
@EvaluationScore.RegisterClass
class EchoMetric(EvaluationScore):
"""Echo score.
"""Echo score.
Proportion of detected echo.
@ -203,46 +202,47 @@ class EchoMetric(EvaluationScore):
Worst case: 1
"""
NAME = 'echo_metric'
NAME = 'echo_metric'
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
EvaluationScore.__init__(self, score_filename_prefix)
def __init__(self, score_filename_prefix, echo_detector_bin_filepath):
EvaluationScore.__init__(self, score_filename_prefix)
# POLQA binary file path.
self._echo_detector_bin_filepath = echo_detector_bin_filepath
if not os.path.exists(self._echo_detector_bin_filepath):
logging.error('cannot find EchoMetric tool binary file')
raise exceptions.FileNotFoundError()
# POLQA binary file path.
self._echo_detector_bin_filepath = echo_detector_bin_filepath
if not os.path.exists(self._echo_detector_bin_filepath):
logging.error('cannot find EchoMetric tool binary file')
raise exceptions.FileNotFoundError()
self._echo_detector_bin_path, _ = os.path.split(
self._echo_detector_bin_filepath)
self._echo_detector_bin_path, _ = os.path.split(
self._echo_detector_bin_filepath)
def _Run(self, output_path):
echo_detector_out_filepath = os.path.join(output_path, 'echo_detector.out')
if os.path.exists(echo_detector_out_filepath):
os.unlink(echo_detector_out_filepath)
def _Run(self, output_path):
echo_detector_out_filepath = os.path.join(output_path,
'echo_detector.out')
if os.path.exists(echo_detector_out_filepath):
os.unlink(echo_detector_out_filepath)
logging.debug("Render signal filepath: %s", self._render_signal_filepath)
if not os.path.exists(self._render_signal_filepath):
logging.error("Render input required for evaluating the echo metric.")
logging.debug("Render signal filepath: %s",
self._render_signal_filepath)
if not os.path.exists(self._render_signal_filepath):
logging.error(
"Render input required for evaluating the echo metric.")
args = [
self._echo_detector_bin_filepath,
'--output_file', echo_detector_out_filepath,
'--',
'-i', self._tested_signal_filepath,
'-ri', self._render_signal_filepath
]
logging.debug(' '.join(args))
subprocess.call(args, cwd=self._echo_detector_bin_path)
args = [
self._echo_detector_bin_filepath, '--output_file',
echo_detector_out_filepath, '--', '-i',
self._tested_signal_filepath, '-ri', self._render_signal_filepath
]
logging.debug(' '.join(args))
subprocess.call(args, cwd=self._echo_detector_bin_path)
# Parse Echo detector tool output and extract the score.
self._score = self._ParseOutputFile(echo_detector_out_filepath)
self._SaveScore()
# Parse Echo detector tool output and extract the score.
self._score = self._ParseOutputFile(echo_detector_out_filepath)
self._SaveScore()
@classmethod
def _ParseOutputFile(cls, echo_metric_file_path):
"""
@classmethod
def _ParseOutputFile(cls, echo_metric_file_path):
"""
Parses the POLQA tool output formatted as a table ('-t' option).
Args:
@ -251,12 +251,13 @@ class EchoMetric(EvaluationScore):
Returns:
The score as a number in [0, 1].
"""
with open(echo_metric_file_path) as f:
return float(f.read())
with open(echo_metric_file_path) as f:
return float(f.read())
@EvaluationScore.RegisterClass
class PolqaScore(EvaluationScore):
"""POLQA score.
"""POLQA score.
See http://www.polqa.info/.
@ -265,44 +266,51 @@ class PolqaScore(EvaluationScore):
Worst case: 1.0
"""
NAME = 'polqa'
NAME = 'polqa'
def __init__(self, score_filename_prefix, polqa_bin_filepath):
EvaluationScore.__init__(self, score_filename_prefix)
def __init__(self, score_filename_prefix, polqa_bin_filepath):
EvaluationScore.__init__(self, score_filename_prefix)
# POLQA binary file path.
self._polqa_bin_filepath = polqa_bin_filepath
if not os.path.exists(self._polqa_bin_filepath):
logging.error('cannot find POLQA tool binary file')
raise exceptions.FileNotFoundError()
# POLQA binary file path.
self._polqa_bin_filepath = polqa_bin_filepath
if not os.path.exists(self._polqa_bin_filepath):
logging.error('cannot find POLQA tool binary file')
raise exceptions.FileNotFoundError()
# Path to the POLQA directory with binary and license files.
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
# Path to the POLQA directory with binary and license files.
self._polqa_tool_path, _ = os.path.split(self._polqa_bin_filepath)
def _Run(self, output_path):
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
if os.path.exists(polqa_out_filepath):
os.unlink(polqa_out_filepath)
def _Run(self, output_path):
polqa_out_filepath = os.path.join(output_path, 'polqa.out')
if os.path.exists(polqa_out_filepath):
os.unlink(polqa_out_filepath)
args = [
self._polqa_bin_filepath, '-t', '-q', '-Overwrite',
'-Ref', self._reference_signal_filepath,
'-Test', self._tested_signal_filepath,
'-LC', 'NB',
'-Out', polqa_out_filepath,
]
logging.debug(' '.join(args))
subprocess.call(args, cwd=self._polqa_tool_path)
args = [
self._polqa_bin_filepath,
'-t',
'-q',
'-Overwrite',
'-Ref',
self._reference_signal_filepath,
'-Test',
self._tested_signal_filepath,
'-LC',
'NB',
'-Out',
polqa_out_filepath,
]
logging.debug(' '.join(args))
subprocess.call(args, cwd=self._polqa_tool_path)
# Parse POLQA tool output and extract the score.
polqa_output = self._ParseOutputFile(polqa_out_filepath)
self._score = float(polqa_output['PolqaScore'])
# Parse POLQA tool output and extract the score.
polqa_output = self._ParseOutputFile(polqa_out_filepath)
self._score = float(polqa_output['PolqaScore'])
self._SaveScore()
self._SaveScore()
@classmethod
def _ParseOutputFile(cls, polqa_out_filepath):
"""
@classmethod
def _ParseOutputFile(cls, polqa_out_filepath):
"""
Parses the POLQA tool output formatted as a table ('-t' option).
Args:
@ -311,29 +319,32 @@ class PolqaScore(EvaluationScore):
Returns:
A dict.
"""
data = []
with open(polqa_out_filepath) as f:
for line in f:
line = line.strip()
if len(line) == 0 or line.startswith('*'):
# Ignore comments.
continue
# Read fields.
data.append(re.split(r'\t+', line))
data = []
with open(polqa_out_filepath) as f:
for line in f:
line = line.strip()
if len(line) == 0 or line.startswith('*'):
# Ignore comments.
continue
# Read fields.
data.append(re.split(r'\t+', line))
# Two rows expected (header and values).
assert len(data) == 2, 'Cannot parse POLQA output'
number_of_fields = len(data[0])
assert number_of_fields == len(data[1])
# Two rows expected (header and values).
assert len(data) == 2, 'Cannot parse POLQA output'
number_of_fields = len(data[0])
assert number_of_fields == len(data[1])
# Build and return a dictionary with field names (header) as keys and the
# corresponding field values as values.
return {data[0][index]: data[1][index] for index in range(number_of_fields)}
# Build and return a dictionary with field names (header) as keys and the
# corresponding field values as values.
return {
data[0][index]: data[1][index]
for index in range(number_of_fields)
}
@EvaluationScore.RegisterClass
class TotalHarmonicDistorsionScore(EvaluationScore):
"""Total harmonic distorsion plus noise score.
"""Total harmonic distorsion plus noise score.
Total harmonic distorsion plus noise score.
See "https://en.wikipedia.org/wiki/Total_harmonic_distortion#THD.2BN".
@ -343,69 +354,74 @@ class TotalHarmonicDistorsionScore(EvaluationScore):
Worst case: +inf
"""
NAME = 'thd'
NAME = 'thd'
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
self._input_frequency = None
def __init__(self, score_filename_prefix):
EvaluationScore.__init__(self, score_filename_prefix)
self._input_frequency = None
def _Run(self, output_path):
self._CheckInputSignal()
def _Run(self, output_path):
self._CheckInputSignal()
self._LoadTestedSignal()
if self._tested_signal.channels != 1:
raise exceptions.EvaluationScoreException(
'unsupported number of channels')
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
self._tested_signal)
self._LoadTestedSignal()
if self._tested_signal.channels != 1:
raise exceptions.EvaluationScoreException(
'unsupported number of channels')
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
self._tested_signal)
# Init.
num_samples = len(samples)
duration = len(self._tested_signal) / 1000.0
scaling = 2.0 / num_samples
max_freq = self._tested_signal.frame_rate / 2
f0_freq = float(self._input_frequency)
t = np.linspace(0, duration, num_samples)
# Init.
num_samples = len(samples)
duration = len(self._tested_signal) / 1000.0
scaling = 2.0 / num_samples
max_freq = self._tested_signal.frame_rate / 2
f0_freq = float(self._input_frequency)
t = np.linspace(0, duration, num_samples)
# Analyze harmonics.
b_terms = []
n = 1
while f0_freq * n < max_freq:
x_n = np.sum(samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
y_n = np.sum(samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
b_terms.append(np.sqrt(x_n**2 + y_n**2))
n += 1
# Analyze harmonics.
b_terms = []
n = 1
while f0_freq * n < max_freq:
x_n = np.sum(
samples * np.sin(2.0 * np.pi * n * f0_freq * t)) * scaling
y_n = np.sum(
samples * np.cos(2.0 * np.pi * n * f0_freq * t)) * scaling
b_terms.append(np.sqrt(x_n**2 + y_n**2))
n += 1
output_without_fundamental = samples - b_terms[0] * np.sin(
2.0 * np.pi * f0_freq * t)
distortion_and_noise = np.sqrt(np.sum(
output_without_fundamental**2) * np.pi * scaling)
output_without_fundamental = samples - b_terms[0] * np.sin(
2.0 * np.pi * f0_freq * t)
distortion_and_noise = np.sqrt(
np.sum(output_without_fundamental**2) * np.pi * scaling)
# TODO(alessiob): Fix or remove if not needed.
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
# TODO(alessiob): Fix or remove if not needed.
# thd = np.sqrt(np.sum(b_terms[1:]**2)) / b_terms[0]
# TODO(alessiob): Check the range of |thd_plus_noise| and update the class
# docstring above if accordingly.
thd_plus_noise = distortion_and_noise / b_terms[0]
# TODO(alessiob): Check the range of |thd_plus_noise| and update the class
# docstring above if accordingly.
thd_plus_noise = distortion_and_noise / b_terms[0]
self._score = thd_plus_noise
self._SaveScore()
self._score = thd_plus_noise
self._SaveScore()
def _CheckInputSignal(self):
# Check input signal and get properties.
try:
if self._input_signal_metadata['signal'] != 'pure_tone':
raise exceptions.EvaluationScoreException(
'The THD score requires a pure tone as input signal')
self._input_frequency = self._input_signal_metadata['frequency']
if self._input_signal_metadata['test_data_gen_name'] != 'identity' or (
self._input_signal_metadata['test_data_gen_config'] != 'default'):
raise exceptions.EvaluationScoreException(
'The THD score cannot be used with any test data generator other '
'than "identity"')
except TypeError:
raise exceptions.EvaluationScoreException(
'The THD score requires an input signal with associated metadata')
except KeyError:
raise exceptions.EvaluationScoreException(
'Invalid input signal metadata to compute the THD score')
def _CheckInputSignal(self):
# Check input signal and get properties.
try:
if self._input_signal_metadata['signal'] != 'pure_tone':
raise exceptions.EvaluationScoreException(
'The THD score requires a pure tone as input signal')
self._input_frequency = self._input_signal_metadata['frequency']
if self._input_signal_metadata[
'test_data_gen_name'] != 'identity' or (
self._input_signal_metadata['test_data_gen_config'] !=
'default'):
raise exceptions.EvaluationScoreException(
'The THD score cannot be used with any test data generator other '
'than "identity"')
except TypeError:
raise exceptions.EvaluationScoreException(
'The THD score requires an input signal with associated metadata'
)
except KeyError:
raise exceptions.EvaluationScoreException(
'Invalid input signal metadata to compute the THD score')

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""EvaluationScore factory class.
"""
@ -16,22 +15,22 @@ from . import eval_scores
class EvaluationScoreWorkerFactory(object):
"""Factory class used to instantiate evaluation score workers.
"""Factory class used to instantiate evaluation score workers.
The ctor gets the parametrs that are used to instatiate the evaluation score
workers.
"""
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
self._score_filename_prefix = None
self._polqa_tool_bin_path = polqa_tool_bin_path
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
def __init__(self, polqa_tool_bin_path, echo_metric_tool_bin_path):
self._score_filename_prefix = None
self._polqa_tool_bin_path = polqa_tool_bin_path
self._echo_metric_tool_bin_path = echo_metric_tool_bin_path
def SetScoreFilenamePrefix(self, prefix):
self._score_filename_prefix = prefix
def SetScoreFilenamePrefix(self, prefix):
self._score_filename_prefix = prefix
def GetInstance(self, evaluation_score_class):
"""Creates an EvaluationScore instance given a class object.
def GetInstance(self, evaluation_score_class):
"""Creates an EvaluationScore instance given a class object.
Args:
evaluation_score_class: EvaluationScore class object (not an instance).
@ -39,17 +38,18 @@ class EvaluationScoreWorkerFactory(object):
Returns:
An EvaluationScore instance.
"""
if self._score_filename_prefix is None:
raise exceptions.InitializationException(
'The score file name prefix for evaluation score workers is not set')
logging.debug(
'factory producing a %s evaluation score', evaluation_score_class)
if self._score_filename_prefix is None:
raise exceptions.InitializationException(
'The score file name prefix for evaluation score workers is not set'
)
logging.debug('factory producing a %s evaluation score',
evaluation_score_class)
if evaluation_score_class == eval_scores.PolqaScore:
return eval_scores.PolqaScore(
self._score_filename_prefix, self._polqa_tool_bin_path)
elif evaluation_score_class == eval_scores.EchoMetric:
return eval_scores.EchoMetric(
self._score_filename_prefix, self._echo_metric_tool_bin_path)
else:
return evaluation_score_class(self._score_filename_prefix)
if evaluation_score_class == eval_scores.PolqaScore:
return eval_scores.PolqaScore(self._score_filename_prefix,
self._polqa_tool_bin_path)
elif evaluation_score_class == eval_scores.EchoMetric:
return eval_scores.EchoMetric(self._score_filename_prefix,
self._echo_metric_tool_bin_path)
else:
return evaluation_score_class(self._score_filename_prefix)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the eval_scores module.
"""
@ -23,111 +22,116 @@ from . import signal_processing
class TestEvalScores(unittest.TestCase):
"""Unit tests for the eval_scores module.
"""Unit tests for the eval_scores module.
"""
def setUp(self):
"""Create temporary output folder and two audio track files."""
self._output_path = tempfile.mkdtemp()
def setUp(self):
"""Create temporary output folder and two audio track files."""
self._output_path = tempfile.mkdtemp()
# Create fake reference and tested (i.e., APM output) audio track files.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_reference_signal = (
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
fake_tested_signal = (
signal_processing.SignalProcessingUtils.GenerateWhiteNoise(silence))
# Create fake reference and tested (i.e., APM output) audio track files.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_reference_signal = (signal_processing.SignalProcessingUtils.
GenerateWhiteNoise(silence))
fake_tested_signal = (signal_processing.SignalProcessingUtils.
GenerateWhiteNoise(silence))
# Save fake audio tracks.
self._fake_reference_signal_filepath = os.path.join(
self._output_path, 'fake_ref.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_reference_signal_filepath, fake_reference_signal)
self._fake_tested_signal_filepath = os.path.join(
self._output_path, 'fake_test.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_tested_signal_filepath, fake_tested_signal)
# Save fake audio tracks.
self._fake_reference_signal_filepath = os.path.join(
self._output_path, 'fake_ref.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_reference_signal_filepath, fake_reference_signal)
self._fake_tested_signal_filepath = os.path.join(
self._output_path, 'fake_test.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_tested_signal_filepath, fake_tested_signal)
def tearDown(self):
"""Recursively delete temporary folder."""
shutil.rmtree(self._output_path)
def tearDown(self):
"""Recursively delete temporary folder."""
shutil.rmtree(self._output_path)
def testRegisteredClasses(self):
# Evaluation score names to exclude (tested separately).
exceptions = ['thd', 'echo_metric']
def testRegisteredClasses(self):
# Evaluation score names to exclude (tested separately).
exceptions = ['thd', 'echo_metric']
# Preliminary check.
self.assertTrue(os.path.exists(self._output_path))
# Preliminary check.
self.assertTrue(os.path.exists(self._output_path))
# Check that there is at least one registered evaluation score worker.
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Check that there is at least one registered evaluation score worker.
registered_classes = eval_scores.EvaluationScore.REGISTERED_CLASSES
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Instance evaluation score workers factory with fake dependencies.
eval_score_workers_factory = (
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
echo_metric_tool_bin_path=None
))
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
# Instance evaluation score workers factory with fake dependencies.
eval_score_workers_factory = (
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
echo_metric_tool_bin_path=None))
eval_score_workers_factory.SetScoreFilenamePrefix('scores-')
# Try each registered evaluation score worker.
for eval_score_name in registered_classes:
if eval_score_name in exceptions:
continue
# Try each registered evaluation score worker.
for eval_score_name in registered_classes:
if eval_score_name in exceptions:
continue
# Instance evaluation score worker.
eval_score_worker = eval_score_workers_factory.GetInstance(
registered_classes[eval_score_name])
# Instance evaluation score worker.
eval_score_worker = eval_score_workers_factory.GetInstance(
registered_classes[eval_score_name])
# Set fake input metadata and reference and test file paths, then run.
eval_score_worker.SetReferenceSignalFilepath(
self._fake_reference_signal_filepath)
eval_score_worker.SetTestedSignalFilepath(
self._fake_tested_signal_filepath)
eval_score_worker.Run(self._output_path)
# Set fake input metadata and reference and test file paths, then run.
eval_score_worker.SetReferenceSignalFilepath(
self._fake_reference_signal_filepath)
eval_score_worker.SetTestedSignalFilepath(
self._fake_tested_signal_filepath)
eval_score_worker.Run(self._output_path)
# Check output.
score = data_access.ScoreFile.Load(eval_score_worker.output_filepath)
self.assertTrue(isinstance(score, float))
# Check output.
score = data_access.ScoreFile.Load(
eval_score_worker.output_filepath)
self.assertTrue(isinstance(score, float))
def testTotalHarmonicDistorsionScore(self):
# Init.
pure_tone_freq = 5000.0
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
eval_score_worker.SetInputSignalMetadata({
'signal': 'pure_tone',
'frequency': pure_tone_freq,
'test_data_gen_name': 'identity',
'test_data_gen_config': 'default',
})
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
def testTotalHarmonicDistorsionScore(self):
# Init.
pure_tone_freq = 5000.0
eval_score_worker = eval_scores.TotalHarmonicDistorsionScore('scores-')
eval_score_worker.SetInputSignalMetadata({
'signal':
'pure_tone',
'frequency':
pure_tone_freq,
'test_data_gen_name':
'identity',
'test_data_gen_config':
'default',
})
template = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
# Create 3 test signals: pure tone, pure tone + white noise, white noise
# only.
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
template, pure_tone_freq)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
template)
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
pure_tone, white_noise)
# Create 3 test signals: pure tone, pure tone + white noise, white noise
# only.
pure_tone = signal_processing.SignalProcessingUtils.GeneratePureTone(
template, pure_tone_freq)
white_noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
template)
noisy_tone = signal_processing.SignalProcessingUtils.MixSignals(
pure_tone, white_noise)
# Compute scores for increasingly distorted pure tone signals.
scores = [None, None, None]
for index, tested_signal in enumerate([pure_tone, noisy_tone, white_noise]):
# Save signal.
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
signal_processing.SignalProcessingUtils.SaveWav(
tmp_filepath, tested_signal)
# Compute scores for increasingly distorted pure tone signals.
scores = [None, None, None]
for index, tested_signal in enumerate(
[pure_tone, noisy_tone, white_noise]):
# Save signal.
tmp_filepath = os.path.join(self._output_path, 'tmp_thd.wav')
signal_processing.SignalProcessingUtils.SaveWav(
tmp_filepath, tested_signal)
# Compute score.
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
eval_score_worker.Run(self._output_path)
scores[index] = eval_score_worker.score
# Compute score.
eval_score_worker.SetTestedSignalFilepath(tmp_filepath)
eval_score_worker.Run(self._output_path)
scores[index] = eval_score_worker.score
# Remove output file to avoid caching.
os.remove(eval_score_worker.output_filepath)
# Remove output file to avoid caching.
os.remove(eval_score_worker.output_filepath)
# Validate scores (lowest score with a pure tone).
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))
# Validate scores (lowest score with a pure tone).
self.assertTrue(all([scores[i + 1] > scores[i] for i in range(2)]))

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Evaluator of the APM module.
"""
@ -13,17 +12,17 @@ import logging
class ApmModuleEvaluator(object):
"""APM evaluator class.
"""APM evaluator class.
"""
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def Run(cls, evaluation_score_workers, apm_input_metadata,
apm_output_filepath, reference_input_filepath,
render_input_filepath, output_path):
"""Runs the evaluation.
@classmethod
def Run(cls, evaluation_score_workers, apm_input_metadata,
apm_output_filepath, reference_input_filepath,
render_input_filepath, output_path):
"""Runs the evaluation.
Iterates over the given evaluation score workers.
@ -37,20 +36,22 @@ class ApmModuleEvaluator(object):
Returns:
A dict of evaluation score name and score pairs.
"""
# Init.
scores = {}
# Init.
scores = {}
for evaluation_score_worker in evaluation_score_workers:
logging.info(' computing <%s> score', evaluation_score_worker.NAME)
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
evaluation_score_worker.SetReferenceSignalFilepath(
reference_input_filepath)
evaluation_score_worker.SetTestedSignalFilepath(
apm_output_filepath)
evaluation_score_worker.SetRenderSignalFilepath(
render_input_filepath)
for evaluation_score_worker in evaluation_score_workers:
logging.info(' computing <%s> score',
evaluation_score_worker.NAME)
evaluation_score_worker.SetInputSignalMetadata(apm_input_metadata)
evaluation_score_worker.SetReferenceSignalFilepath(
reference_input_filepath)
evaluation_score_worker.SetTestedSignalFilepath(
apm_output_filepath)
evaluation_score_worker.SetRenderSignalFilepath(
render_input_filepath)
evaluation_score_worker.Run(output_path)
scores[evaluation_score_worker.NAME] = evaluation_score_worker.score
evaluation_score_worker.Run(output_path)
scores[
evaluation_score_worker.NAME] = evaluation_score_worker.score
return scores
return scores

View File

@ -5,42 +5,41 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Exception classes.
"""
class FileNotFoundError(Exception):
"""File not found exception.
"""File not found exception.
"""
pass
pass
class SignalProcessingException(Exception):
"""Signal processing exception.
"""Signal processing exception.
"""
pass
pass
class InputMixerException(Exception):
"""Input mixer exception.
"""Input mixer exception.
"""
pass
pass
class InputSignalCreatorException(Exception):
"""Input signal creator exception.
"""Input signal creator exception.
"""
pass
pass
class EvaluationScoreException(Exception):
"""Evaluation score exception.
"""Evaluation score exception.
"""
pass
pass
class InitializationException(Exception):
"""Initialization exception.
"""Initialization exception.
"""
pass
pass

View File

@ -14,58 +14,58 @@ import re
import sys
try:
import csscompressor
import csscompressor
except ImportError:
logging.critical('Cannot import the third-party Python package csscompressor')
sys.exit(1)
logging.critical(
'Cannot import the third-party Python package csscompressor')
sys.exit(1)
try:
import jsmin
import jsmin
except ImportError:
logging.critical('Cannot import the third-party Python package jsmin')
sys.exit(1)
logging.critical('Cannot import the third-party Python package jsmin')
sys.exit(1)
class HtmlExport(object):
"""HTML exporter class for APM quality scores."""
"""HTML exporter class for APM quality scores."""
_NEW_LINE = '\n'
_NEW_LINE = '\n'
# CSS and JS file paths.
_PATH = os.path.dirname(os.path.realpath(__file__))
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
_CSS_MINIFIED = True
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
_JS_MINIFIED = True
# CSS and JS file paths.
_PATH = os.path.dirname(os.path.realpath(__file__))
_CSS_FILEPATH = os.path.join(_PATH, 'results.css')
_CSS_MINIFIED = True
_JS_FILEPATH = os.path.join(_PATH, 'results.js')
_JS_MINIFIED = True
def __init__(self, output_filepath):
self._scores_data_frame = None
self._output_filepath = output_filepath
def __init__(self, output_filepath):
self._scores_data_frame = None
self._output_filepath = output_filepath
def Export(self, scores_data_frame):
"""Exports scores into an HTML file.
def Export(self, scores_data_frame):
"""Exports scores into an HTML file.
Args:
scores_data_frame: DataFrame instance.
"""
self._scores_data_frame = scores_data_frame
html = ['<html>',
self._scores_data_frame = scores_data_frame
html = [
'<html>',
self._BuildHeader(),
('<script type="text/javascript">'
'(function () {'
'window.addEventListener(\'load\', function () {'
'var inspector = new AudioInspector();'
'});'
'(function () {'
'window.addEventListener(\'load\', function () {'
'var inspector = new AudioInspector();'
'});'
'})();'
'</script>'),
'<body>',
self._BuildBody(),
'</body>',
'</html>']
self._Save(self._output_filepath, self._NEW_LINE.join(html))
'</script>'), '<body>',
self._BuildBody(), '</body>', '</html>'
]
self._Save(self._output_filepath, self._NEW_LINE.join(html))
def _BuildHeader(self):
"""Builds the <head> section of the HTML file.
def _BuildHeader(self):
"""Builds the <head> section of the HTML file.
The header contains the page title and either embedded or linked CSS and JS
files.
@ -73,325 +73,349 @@ class HtmlExport(object):
Returns:
A string with <head>...</head> HTML.
"""
html = ['<head>', '<title>Results</title>']
html = ['<head>', '<title>Results</title>']
# Add Material Design hosted libs.
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
'css?family=Roboto:300,400,500,700" type="text/css">')
html.append('<link rel="stylesheet" href="https://fonts.googleapis.com/'
'icon?family=Material+Icons">')
html.append('<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
'material.indigo-pink.min.css">')
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
'material.min.js"></script>')
# Embed custom JavaScript and CSS files.
html.append('<script>')
with open(self._JS_FILEPATH) as f:
html.append(jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
f.read().rstrip()))
html.append('</script>')
html.append('<style>')
with open(self._CSS_FILEPATH) as f:
html.append(csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
f.read().rstrip()))
html.append('</style>')
html.append('</head>')
return self._NEW_LINE.join(html)
def _BuildBody(self):
"""Builds the content of the <body> section."""
score_names = self._scores_data_frame['eval_score_name'].drop_duplicates(
).values.tolist()
html = [
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
'mdl-layout--fixed-tabs">'),
'<header class="mdl-layout__header">',
'<div class="mdl-layout__header-row">',
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
self._output_filepath),
'</div>',
]
# Tab selectors.
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
for tab_index, score_name in enumerate(score_names):
is_active = tab_index == 0
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
'{}</a>'.format(tab_index,
' is-active' if is_active else '',
self._FormatName(score_name)))
html.append('</div>')
html.append('</header>')
html.append('<main class="mdl-layout__content" style="overflow-x: auto;">')
# Tabs content.
for tab_index, score_name in enumerate(score_names):
html.append('<section class="mdl-layout__tab-panel{}" '
'id="score-tab-{}">'.format(
' is-active' if is_active else '', tab_index))
html.append('<div class="page-content">')
html.append(self._BuildScoreTab(score_name, ('s{}'.format(tab_index),)))
html.append('</div>')
html.append('</section>')
html.append('</main>')
html.append('</div>')
# Add snackbar for notifications.
html.append(
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
'<div class="mdl-snackbar__text"></div>'
'<button type="button" class="mdl-snackbar__action"></button>'
'</div>')
return self._NEW_LINE.join(html)
def _BuildScoreTab(self, score_name, anchor_data):
"""Builds the content of a tab."""
# Find unique values.
scores = self._scores_data_frame[
self._scores_data_frame.eval_score_name == score_name]
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
test_data_gen_configs = sorted(self._FindUniqueTuples(
scores, ['test_data_gen', 'test_data_gen_params']))
html = [
'<div class="mdl-grid">',
'<div class="mdl-layout-spacer"></div>',
'<div class="mdl-cell mdl-cell--10-col">',
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
'style="width: 100%;">'),
]
# Header.
html.append('<thead><tr><th>APM config / Test data generator</th>')
for test_data_gen_info in test_data_gen_configs:
html.append('<th>{} {}</th>'.format(
self._FormatName(test_data_gen_info[0]), test_data_gen_info[1]))
html.append('</tr></thead>')
# Body.
html.append('<tbody>')
for apm_config in apm_configs:
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
for test_data_gen_info in test_data_gen_configs:
dialog_id = self._ScoreStatsInspectorDialogId(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1])
# Add Material Design hosted libs.
html.append('<link rel="stylesheet" href="http://fonts.googleapis.com/'
'css?family=Roboto:300,400,500,700" type="text/css">')
html.append(
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.format(
dialog_id, self._BuildScoreTableCell(
score_name, test_data_gen_info[0], test_data_gen_info[1],
apm_config[0])))
html.append('</tr>')
html.append('</tbody>')
'<link rel="stylesheet" href="https://fonts.googleapis.com/'
'icon?family=Material+Icons">')
html.append(
'<link rel="stylesheet" href="https://code.getmdl.io/1.3.0/'
'material.indigo-pink.min.css">')
html.append('<script defer src="https://code.getmdl.io/1.3.0/'
'material.min.js"></script>')
html.append('</table></div><div class="mdl-layout-spacer"></div></div>')
# Embed custom JavaScript and CSS files.
html.append('<script>')
with open(self._JS_FILEPATH) as f:
html.append(
jsmin.jsmin(f.read()) if self._JS_MINIFIED else (
f.read().rstrip()))
html.append('</script>')
html.append('<style>')
with open(self._CSS_FILEPATH) as f:
html.append(
csscompressor.compress(f.read()) if self._CSS_MINIFIED else (
f.read().rstrip()))
html.append('</style>')
html.append(self._BuildScoreStatsInspectorDialogs(
score_name, apm_configs, test_data_gen_configs,
anchor_data))
html.append('</head>')
return self._NEW_LINE.join(html)
return self._NEW_LINE.join(html)
def _BuildScoreTableCell(self, score_name, test_data_gen,
test_data_gen_params, apm_config):
"""Builds the content of a table cell for a score table."""
scores = self._SliceDataForScoreTableCell(
score_name, apm_config, test_data_gen, test_data_gen_params)
stats = self._ComputeScoreStats(scores)
def _BuildBody(self):
"""Builds the content of the <body> section."""
score_names = self._scores_data_frame[
'eval_score_name'].drop_duplicates().values.tolist()
html = []
items_id_prefix = (
score_name + test_data_gen + test_data_gen_params + apm_config)
if stats['count'] == 1:
# Show the only available score.
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
item_id, scores['score'].mean()))
html.append('<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
'</div>'.format(item_id, 'single value'))
else:
# Show stats.
for stat_name in ['min', 'max', 'mean', 'std dev']:
item_id = hashlib.md5(
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
html.append('<div id="stats-{0}">{1:f}</div>'.format(
item_id, stats[stat_name]))
html.append('<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
'</div>'.format(item_id, stat_name))
html = [
('<div class="mdl-layout mdl-js-layout mdl-layout--fixed-header '
'mdl-layout--fixed-tabs">'),
'<header class="mdl-layout__header">',
'<div class="mdl-layout__header-row">',
'<span class="mdl-layout-title">APM QA results ({})</span>'.format(
self._output_filepath),
'</div>',
]
return self._NEW_LINE.join(html)
def _BuildScoreStatsInspectorDialogs(
self, score_name, apm_configs, test_data_gen_configs, anchor_data):
"""Builds a set of score stats inspector dialogs."""
html = []
for apm_config in apm_configs:
for test_data_gen_info in test_data_gen_configs:
dialog_id = self._ScoreStatsInspectorDialogId(
score_name, apm_config[0],
test_data_gen_info[0], test_data_gen_info[1])
html.append('<dialog class="mdl-dialog" id="{}" '
'style="width: 40%;">'.format(dialog_id))
# Content.
html.append('<div class="mdl-dialog__content">')
html.append('<h6><strong>APM config preset</strong>: {}<br/>'
'<strong>Test data generator</strong>: {} ({})</h6>'.format(
self._FormatName(apm_config[0]),
self._FormatName(test_data_gen_info[0]),
test_data_gen_info[1]))
html.append(self._BuildScoreStatsInspectorDialog(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1], anchor_data + (dialog_id,)))
# Tab selectors.
html.append('<div class="mdl-layout__tab-bar mdl-js-ripple-effect">')
for tab_index, score_name in enumerate(score_names):
is_active = tab_index == 0
html.append('<a href="#score-tab-{}" class="mdl-layout__tab{}">'
'{}</a>'.format(tab_index,
' is-active' if is_active else '',
self._FormatName(score_name)))
html.append('</div>')
# Actions.
html.append('<div class="mdl-dialog__actions">')
html.append('<button type="button" class="mdl-button" '
'onclick="closeScoreStatsInspector()">'
'Close</button>')
html.append('</header>')
html.append(
'<main class="mdl-layout__content" style="overflow-x: auto;">')
# Tabs content.
for tab_index, score_name in enumerate(score_names):
html.append('<section class="mdl-layout__tab-panel{}" '
'id="score-tab-{}">'.format(
' is-active' if is_active else '', tab_index))
html.append('<div class="page-content">')
html.append(
self._BuildScoreTab(score_name, ('s{}'.format(tab_index), )))
html.append('</div>')
html.append('</section>')
html.append('</main>')
html.append('</div>')
html.append('</dialog>')
# Add snackbar for notifications.
html.append(
'<div id="snackbar" aria-live="assertive" aria-atomic="true"'
' aria-relevant="text" class="mdl-snackbar mdl-js-snackbar">'
'<div class="mdl-snackbar__text"></div>'
'<button type="button" class="mdl-snackbar__action"></button>'
'</div>')
return self._NEW_LINE.join(html)
return self._NEW_LINE.join(html)
def _BuildScoreStatsInspectorDialog(
self, score_name, apm_config, test_data_gen, test_data_gen_params,
anchor_data):
"""Builds one score stats inspector dialog."""
scores = self._SliceDataForScoreTableCell(
score_name, apm_config, test_data_gen, test_data_gen_params)
def _BuildScoreTab(self, score_name, anchor_data):
"""Builds the content of a tab."""
# Find unique values.
scores = self._scores_data_frame[
self._scores_data_frame.eval_score_name == score_name]
apm_configs = sorted(self._FindUniqueTuples(scores, ['apm_config']))
test_data_gen_configs = sorted(
self._FindUniqueTuples(scores,
['test_data_gen', 'test_data_gen_params']))
capture_render_pairs = sorted(self._FindUniqueTuples(
scores, ['capture', 'render']))
echo_simulators = sorted(self._FindUniqueTuples(scores, ['echo_simulator']))
html = [
'<div class="mdl-grid">',
'<div class="mdl-layout-spacer"></div>',
'<div class="mdl-cell mdl-cell--10-col">',
('<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp" '
'style="width: 100%;">'),
]
html = ['<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">']
# Header.
html.append('<thead><tr><th>APM config / Test data generator</th>')
for test_data_gen_info in test_data_gen_configs:
html.append('<th>{} {}</th>'.format(
self._FormatName(test_data_gen_info[0]),
test_data_gen_info[1]))
html.append('</tr></thead>')
# Header.
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
for echo_simulator in echo_simulators:
html.append('<th>' + self._FormatName(echo_simulator[0]) +'</th>')
html.append('</tr></thead>')
# Body.
html.append('<tbody>')
for apm_config in apm_configs:
html.append('<tr><td>' + self._FormatName(apm_config[0]) + '</td>')
for test_data_gen_info in test_data_gen_configs:
dialog_id = self._ScoreStatsInspectorDialogId(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1])
html.append(
'<td onclick="openScoreStatsInspector(\'{}\')">{}</td>'.
format(
dialog_id,
self._BuildScoreTableCell(score_name,
test_data_gen_info[0],
test_data_gen_info[1],
apm_config[0])))
html.append('</tr>')
html.append('</tbody>')
# Body.
html.append('<tbody>')
for row, (capture, render) in enumerate(capture_render_pairs):
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
capture, render))
for col, echo_simulator in enumerate(echo_simulators):
score_tuple = self._SliceDataForScoreStatsTableCell(
scores, capture, render, echo_simulator[0])
cell_class = 'r{}c{}'.format(row, col)
html.append('<td class="single-score-cell {}">{}</td>'.format(
cell_class, self._BuildScoreStatsInspectorTableCell(
score_tuple, anchor_data + (cell_class,))))
html.append('</tr>')
html.append('</tbody>')
html.append(
'</table></div><div class="mdl-layout-spacer"></div></div>')
html.append('</table>')
html.append(
self._BuildScoreStatsInspectorDialogs(score_name, apm_configs,
test_data_gen_configs,
anchor_data))
# Placeholder for the audio inspector.
html.append('<div class="audio-inspector-placeholder"></div>')
return self._NEW_LINE.join(html)
return self._NEW_LINE.join(html)
def _BuildScoreTableCell(self, score_name, test_data_gen,
test_data_gen_params, apm_config):
"""Builds the content of a table cell for a score table."""
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
test_data_gen,
test_data_gen_params)
stats = self._ComputeScoreStats(scores)
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
"""Builds the content of a cell of a score stats inspector."""
anchor = '&'.join(anchor_data)
html = [('<div class="v">{}</div>'
'<button class="mdl-button mdl-js-button mdl-button--icon"'
' data-anchor="{}">'
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
'</button>').format(score_tuple.score, anchor)]
html = []
items_id_prefix = (score_name + test_data_gen + test_data_gen_params +
apm_config)
if stats['count'] == 1:
# Show the only available score.
item_id = hashlib.md5(items_id_prefix.encode('utf-8')).hexdigest()
html.append('<div id="single-value-{0}">{1:f}</div>'.format(
item_id, scores['score'].mean()))
html.append(
'<div class="mdl-tooltip" data-mdl-for="single-value-{}">{}'
'</div>'.format(item_id, 'single value'))
else:
# Show stats.
for stat_name in ['min', 'max', 'mean', 'std dev']:
item_id = hashlib.md5(
(items_id_prefix + stat_name).encode('utf-8')).hexdigest()
html.append('<div id="stats-{0}">{1:f}</div>'.format(
item_id, stats[stat_name]))
html.append(
'<div class="mdl-tooltip" data-mdl-for="stats-{}">{}'
'</div>'.format(item_id, stat_name))
# Add all the available file paths as hidden data.
for field_name in score_tuple.keys():
if field_name.endswith('_filepath'):
html.append('<input type="hidden" name="{}" value="{}">'.format(
field_name, score_tuple[field_name]))
return self._NEW_LINE.join(html)
return self._NEW_LINE.join(html)
def _BuildScoreStatsInspectorDialogs(self, score_name, apm_configs,
test_data_gen_configs, anchor_data):
"""Builds a set of score stats inspector dialogs."""
html = []
for apm_config in apm_configs:
for test_data_gen_info in test_data_gen_configs:
dialog_id = self._ScoreStatsInspectorDialogId(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1])
def _SliceDataForScoreTableCell(
self, score_name, apm_config, test_data_gen, test_data_gen_params):
"""Slices |self._scores_data_frame| to extract the data for a tab."""
masks = []
masks.append(self._scores_data_frame.eval_score_name == score_name)
masks.append(self._scores_data_frame.apm_config == apm_config)
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
masks.append(
self._scores_data_frame.test_data_gen_params == test_data_gen_params)
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
del masks
return self._scores_data_frame[mask]
html.append('<dialog class="mdl-dialog" id="{}" '
'style="width: 40%;">'.format(dialog_id))
@classmethod
def _SliceDataForScoreStatsTableCell(
cls, scores, capture, render, echo_simulator):
"""Slices |scores| to extract the data for a tab."""
masks = []
# Content.
html.append('<div class="mdl-dialog__content">')
html.append(
'<h6><strong>APM config preset</strong>: {}<br/>'
'<strong>Test data generator</strong>: {} ({})</h6>'.
format(self._FormatName(apm_config[0]),
self._FormatName(test_data_gen_info[0]),
test_data_gen_info[1]))
html.append(
self._BuildScoreStatsInspectorDialog(
score_name, apm_config[0], test_data_gen_info[0],
test_data_gen_info[1], anchor_data + (dialog_id, )))
html.append('</div>')
masks.append(scores.capture == capture)
masks.append(scores.render == render)
masks.append(scores.echo_simulator == echo_simulator)
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
del masks
# Actions.
html.append('<div class="mdl-dialog__actions">')
html.append('<button type="button" class="mdl-button" '
'onclick="closeScoreStatsInspector()">'
'Close</button>')
html.append('</div>')
sliced_data = scores[mask]
assert len(sliced_data) == 1, 'single score is expected'
return sliced_data.iloc[0]
html.append('</dialog>')
@classmethod
def _FindUniqueTuples(cls, data_frame, fields):
"""Slices |data_frame| to a list of fields and finds unique tuples."""
return data_frame[fields].drop_duplicates().values.tolist()
return self._NEW_LINE.join(html)
@classmethod
def _ComputeScoreStats(cls, data_frame):
"""Computes score stats."""
scores = data_frame['score']
return {
'count': scores.count(),
'min': scores.min(),
'max': scores.max(),
'mean': scores.mean(),
'std dev': scores.std(),
}
def _BuildScoreStatsInspectorDialog(self, score_name, apm_config,
test_data_gen, test_data_gen_params,
anchor_data):
"""Builds one score stats inspector dialog."""
scores = self._SliceDataForScoreTableCell(score_name, apm_config,
test_data_gen,
test_data_gen_params)
@classmethod
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config, test_data_gen,
test_data_gen_params):
"""Assigns a unique name to a dialog."""
return 'score-stats-dialog-' + hashlib.md5(
'score-stats-inspector-{}-{}-{}-{}'.format(
score_name, apm_config, test_data_gen,
test_data_gen_params).encode('utf-8')).hexdigest()
capture_render_pairs = sorted(
self._FindUniqueTuples(scores, ['capture', 'render']))
echo_simulators = sorted(
self._FindUniqueTuples(scores, ['echo_simulator']))
@classmethod
def _Save(cls, output_filepath, html):
"""Writes the HTML file.
html = [
'<table class="mdl-data-table mdl-js-data-table mdl-shadow--2dp">'
]
# Header.
html.append('<thead><tr><th>Capture-Render / Echo simulator</th>')
for echo_simulator in echo_simulators:
html.append('<th>' + self._FormatName(echo_simulator[0]) + '</th>')
html.append('</tr></thead>')
# Body.
html.append('<tbody>')
for row, (capture, render) in enumerate(capture_render_pairs):
html.append('<tr><td><div>{}</div><div>{}</div></td>'.format(
capture, render))
for col, echo_simulator in enumerate(echo_simulators):
score_tuple = self._SliceDataForScoreStatsTableCell(
scores, capture, render, echo_simulator[0])
cell_class = 'r{}c{}'.format(row, col)
html.append('<td class="single-score-cell {}">{}</td>'.format(
cell_class,
self._BuildScoreStatsInspectorTableCell(
score_tuple, anchor_data + (cell_class, ))))
html.append('</tr>')
html.append('</tbody>')
html.append('</table>')
# Placeholder for the audio inspector.
html.append('<div class="audio-inspector-placeholder"></div>')
return self._NEW_LINE.join(html)
def _BuildScoreStatsInspectorTableCell(self, score_tuple, anchor_data):
"""Builds the content of a cell of a score stats inspector."""
anchor = '&'.join(anchor_data)
html = [('<div class="v">{}</div>'
'<button class="mdl-button mdl-js-button mdl-button--icon"'
' data-anchor="{}">'
'<i class="material-icons mdl-color-text--blue-grey">link</i>'
'</button>').format(score_tuple.score, anchor)]
# Add all the available file paths as hidden data.
for field_name in score_tuple.keys():
if field_name.endswith('_filepath'):
html.append(
'<input type="hidden" name="{}" value="{}">'.format(
field_name, score_tuple[field_name]))
return self._NEW_LINE.join(html)
def _SliceDataForScoreTableCell(self, score_name, apm_config,
test_data_gen, test_data_gen_params):
"""Slices |self._scores_data_frame| to extract the data for a tab."""
masks = []
masks.append(self._scores_data_frame.eval_score_name == score_name)
masks.append(self._scores_data_frame.apm_config == apm_config)
masks.append(self._scores_data_frame.test_data_gen == test_data_gen)
masks.append(self._scores_data_frame.test_data_gen_params ==
test_data_gen_params)
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
del masks
return self._scores_data_frame[mask]
@classmethod
def _SliceDataForScoreStatsTableCell(cls, scores, capture, render,
echo_simulator):
"""Slices |scores| to extract the data for a tab."""
masks = []
masks.append(scores.capture == capture)
masks.append(scores.render == render)
masks.append(scores.echo_simulator == echo_simulator)
mask = functools.reduce((lambda i1, i2: i1 & i2), masks)
del masks
sliced_data = scores[mask]
assert len(sliced_data) == 1, 'single score is expected'
return sliced_data.iloc[0]
@classmethod
def _FindUniqueTuples(cls, data_frame, fields):
"""Slices |data_frame| to a list of fields and finds unique tuples."""
return data_frame[fields].drop_duplicates().values.tolist()
@classmethod
def _ComputeScoreStats(cls, data_frame):
"""Computes score stats."""
scores = data_frame['score']
return {
'count': scores.count(),
'min': scores.min(),
'max': scores.max(),
'mean': scores.mean(),
'std dev': scores.std(),
}
@classmethod
def _ScoreStatsInspectorDialogId(cls, score_name, apm_config,
test_data_gen, test_data_gen_params):
"""Assigns a unique name to a dialog."""
return 'score-stats-dialog-' + hashlib.md5(
'score-stats-inspector-{}-{}-{}-{}'.format(
score_name, apm_config, test_data_gen,
test_data_gen_params).encode('utf-8')).hexdigest()
@classmethod
def _Save(cls, output_filepath, html):
"""Writes the HTML file.
Args:
output_filepath: output file path.
html: string with the HTML content.
"""
with open(output_filepath, 'w') as f:
f.write(html)
with open(output_filepath, 'w') as f:
f.write(html)
@classmethod
def _FormatName(cls, name):
"""Formats a name.
@classmethod
def _FormatName(cls, name):
"""Formats a name.
Args:
name: a string.
@ -399,4 +423,4 @@ class HtmlExport(object):
Returns:
A copy of name in which underscores and dashes are replaced with a space.
"""
return re.sub(r'[_\-]', ' ', name)
return re.sub(r'[_\-]', ' ', name)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the export module.
"""
@ -27,60 +26,61 @@ from . import test_data_generation_factory
class TestExport(unittest.TestCase):
"""Unit tests for the export module.
"""Unit tests for the export module.
"""
_CLEAN_TMP_OUTPUT = True
_CLEAN_TMP_OUTPUT = True
def setUp(self):
"""Creates temporary data to export."""
self._tmp_path = tempfile.mkdtemp()
def setUp(self):
"""Creates temporary data to export."""
self._tmp_path = tempfile.mkdtemp()
# Run a fake experiment to produce data to export.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)), 'fake_polqa'),
echo_metric_tool_bin_path=None
)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
simulator.Run(
config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=[
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
],
test_data_generator_names=['identity', 'white_noise'],
eval_score_names=['audio_level_peak', 'audio_level_mean'],
output_dir=self._tmp_path)
# Run a fake experiment to produce data to export.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(os.path.abspath(__file__)),
'fake_polqa'),
echo_metric_tool_bin_path=None)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.
DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
simulator.Run(
config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=[
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-880_1000.wav'),
],
test_data_generator_names=['identity', 'white_noise'],
eval_score_names=['audio_level_peak', 'audio_level_mean'],
output_dir=self._tmp_path)
# Export results.
p = collect_data.InstanceArgumentsParser()
args = p.parse_args(['--output_dir', self._tmp_path])
src_path = collect_data.ConstructSrcPath(args)
self._data_to_export = collect_data.FindScores(src_path, args)
# Export results.
p = collect_data.InstanceArgumentsParser()
args = p.parse_args(['--output_dir', self._tmp_path])
src_path = collect_data.ConstructSrcPath(args)
self._data_to_export = collect_data.FindScores(src_path, args)
def tearDown(self):
"""Recursively deletes temporary folders."""
if self._CLEAN_TMP_OUTPUT:
shutil.rmtree(self._tmp_path)
else:
logging.warning(self.id() + ' did not clean the temporary path ' + (
self._tmp_path))
def tearDown(self):
"""Recursively deletes temporary folders."""
if self._CLEAN_TMP_OUTPUT:
shutil.rmtree(self._tmp_path)
else:
logging.warning(self.id() + ' did not clean the temporary path ' +
(self._tmp_path))
def testCreateHtmlReport(self):
fn_out = os.path.join(self._tmp_path, 'results.html')
exporter = export.HtmlExport(fn_out)
exporter.Export(self._data_to_export)
def testCreateHtmlReport(self):
fn_out = os.path.join(self._tmp_path, 'results.html')
exporter = export.HtmlExport(fn_out)
exporter.Export(self._data_to_export)
document = pq.PyQuery(filename=fn_out)
self.assertIsInstance(document, pq.PyQuery)
# TODO(alessiob): Use PyQuery API to check the HTML file.
document = pq.PyQuery(filename=fn_out)
self.assertIsInstance(document, pq.PyQuery)
# TODO(alessiob): Use PyQuery API to check the HTML file.

View File

@ -16,62 +16,60 @@ import sys
import tempfile
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
from . import signal_processing
class ExternalVad(object):
def __init__(self, path_to_binary, name):
"""Args:
class ExternalVad(object):
def __init__(self, path_to_binary, name):
"""Args:
path_to_binary: path to binary that accepts '-i <wav>', '-o
<float probabilities>'. There must be one float value per
10ms audio
name: a name to identify the external VAD. Used for saving
the output as extvad_output-<name>.
"""
self._path_to_binary = path_to_binary
self.name = name
assert os.path.exists(self._path_to_binary), (
self._path_to_binary)
self._vad_output = None
self._path_to_binary = path_to_binary
self.name = name
assert os.path.exists(self._path_to_binary), (self._path_to_binary)
self._vad_output = None
def Run(self, wav_file_path):
_signal = signal_processing.SignalProcessingUtils.LoadWav(wav_file_path)
if _signal.channels != 1:
raise NotImplementedError('Multiple-channel'
' annotations not implemented')
if _signal.frame_rate != 48000:
raise NotImplementedError('Frame rates '
'other than 48000 not implemented')
def Run(self, wav_file_path):
_signal = signal_processing.SignalProcessingUtils.LoadWav(
wav_file_path)
if _signal.channels != 1:
raise NotImplementedError('Multiple-channel'
' annotations not implemented')
if _signal.frame_rate != 48000:
raise NotImplementedError('Frame rates '
'other than 48000 not implemented')
tmp_path = tempfile.mkdtemp()
try:
output_file_path = os.path.join(
tmp_path, self.name + '_vad.tmp')
subprocess.call([
self._path_to_binary,
'-i', wav_file_path,
'-o', output_file_path
])
self._vad_output = np.fromfile(output_file_path, np.float32)
except Exception as e:
logging.error('Error while running the ' + self.name +
' VAD (' + e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
tmp_path = tempfile.mkdtemp()
try:
output_file_path = os.path.join(tmp_path, self.name + '_vad.tmp')
subprocess.call([
self._path_to_binary, '-i', wav_file_path, '-o',
output_file_path
])
self._vad_output = np.fromfile(output_file_path, np.float32)
except Exception as e:
logging.error('Error while running the ' + self.name + ' VAD (' +
e.message + ')')
finally:
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
def GetVadOutput(self):
assert self._vad_output is not None
return self._vad_output
def GetVadOutput(self):
assert self._vad_output is not None
return self._vad_output
@classmethod
def ConstructVadDict(cls, vad_paths, vad_names):
external_vads = {}
for path, name in zip(vad_paths, vad_names):
external_vads[name] = ExternalVad(path, name)
return external_vads
@classmethod
def ConstructVadDict(cls, vad_paths, vad_names):
external_vads = {}
for path, name in zip(vad_paths, vad_names):
external_vads[name] = ExternalVad(path, name)
return external_vads

View File

@ -9,16 +9,17 @@
import argparse
import numpy as np
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-i', required=True)
parser.add_argument('-o', required=True)
parser = argparse.ArgumentParser()
parser.add_argument('-i', required=True)
parser.add_argument('-o', required=True)
args = parser.parse_args()
args = parser.parse_args()
array = np.arange(100, dtype=np.float32)
array.tofile(open(args.o, 'w'))
array = np.arange(100, dtype=np.float32)
array.tofile(open(args.o, 'w'))
if __name__ == '__main__':
main()
main()

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Input mixer module.
"""
@ -17,24 +16,24 @@ from . import signal_processing
class ApmInputMixer(object):
"""Class to mix a set of audio segments down to the APM input."""
"""Class to mix a set of audio segments down to the APM input."""
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
_HARD_CLIPPING_LOG_MSG = 'hard clipping detected in the mixed signal'
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def HardClippingLogMessage(cls):
"""Returns the log message used when hard clipping is detected in the mix.
@classmethod
def HardClippingLogMessage(cls):
"""Returns the log message used when hard clipping is detected in the mix.
This method is mainly intended to be used by the unit tests.
"""
return cls._HARD_CLIPPING_LOG_MSG
return cls._HARD_CLIPPING_LOG_MSG
@classmethod
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
"""Mixes capture and echo.
@classmethod
def Mix(cls, output_path, capture_input_filepath, echo_filepath):
"""Mixes capture and echo.
Creates the overall capture input for APM by mixing the "echo-free" capture
signal with the echo signal (e.g., echo simulated via the
@ -58,38 +57,41 @@ class ApmInputMixer(object):
Returns:
Path to the mix audio track file.
"""
if echo_filepath is None:
return capture_input_filepath
if echo_filepath is None:
return capture_input_filepath
# Build the mix output file name as a function of the echo file name.
# This ensures that if the internal parameters of the echo path simulator
# change, no erroneous cache hit occurs.
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
capture_input_file_name, _ = os.path.splitext(
os.path.split(capture_input_filepath)[1])
mix_filepath = os.path.join(output_path, 'mix_capture_{}_{}.wav'.format(
capture_input_file_name, echo_file_name))
# Build the mix output file name as a function of the echo file name.
# This ensures that if the internal parameters of the echo path simulator
# change, no erroneous cache hit occurs.
echo_file_name, _ = os.path.splitext(os.path.split(echo_filepath)[1])
capture_input_file_name, _ = os.path.splitext(
os.path.split(capture_input_filepath)[1])
mix_filepath = os.path.join(
output_path,
'mix_capture_{}_{}.wav'.format(capture_input_file_name,
echo_file_name))
# Create the mix if not done yet.
mix = None
if not os.path.exists(mix_filepath):
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
capture_input_filepath)
echo = signal_processing.SignalProcessingUtils.LoadWav(echo_filepath)
# Create the mix if not done yet.
mix = None
if not os.path.exists(mix_filepath):
echo_free_capture = signal_processing.SignalProcessingUtils.LoadWav(
capture_input_filepath)
echo = signal_processing.SignalProcessingUtils.LoadWav(
echo_filepath)
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
signal_processing.SignalProcessingUtils.CountSamples(
echo_free_capture)):
raise exceptions.InputMixerException(
'echo cannot be shorter than capture')
if signal_processing.SignalProcessingUtils.CountSamples(echo) < (
signal_processing.SignalProcessingUtils.CountSamples(
echo_free_capture)):
raise exceptions.InputMixerException(
'echo cannot be shorter than capture')
mix = echo_free_capture.overlay(echo)
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
mix = echo_free_capture.overlay(echo)
signal_processing.SignalProcessingUtils.SaveWav(mix_filepath, mix)
# Check if hard clipping occurs.
if mix is None:
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
# Check if hard clipping occurs.
if mix is None:
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
if signal_processing.SignalProcessingUtils.DetectHardClipping(mix):
logging.warning(cls._HARD_CLIPPING_LOG_MSG)
return mix_filepath
return mix_filepath

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the input mixer module.
"""
@ -23,122 +22,119 @@ from . import signal_processing
class TestApmInputMixer(unittest.TestCase):
"""Unit tests for the ApmInputMixer class.
"""Unit tests for the ApmInputMixer class.
"""
# Audio track file names created in setUp().
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
# Audio track file names created in setUp().
_FILENAMES = ['capture', 'echo_1', 'echo_2', 'shorter', 'longer']
# Target peak power level (dBFS) of each audio track file created in setUp().
# These values are hand-crafted in order to make saturation happen when
# capture and echo_2 are mixed and the contrary for capture and echo_1.
# None means that the power is not changed.
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
# Target peak power level (dBFS) of each audio track file created in setUp().
# These values are hand-crafted in order to make saturation happen when
# capture and echo_2 are mixed and the contrary for capture and echo_1.
# None means that the power is not changed.
_MAX_PEAK_POWER_LEVELS = [-10.0, -5.0, 0.0, None, None]
# Audio track file durations in milliseconds.
_DURATIONS = [1000, 1000, 1000, 800, 1200]
# Audio track file durations in milliseconds.
_DURATIONS = [1000, 1000, 1000, 800, 1200]
_SAMPLE_RATE = 48000
_SAMPLE_RATE = 48000
def setUp(self):
"""Creates temporary data."""
self._tmp_path = tempfile.mkdtemp()
def setUp(self):
"""Creates temporary data."""
self._tmp_path = tempfile.mkdtemp()
# Create audio track files.
self._audio_tracks = {}
for filename, peak_power, duration in zip(
self._FILENAMES, self._MAX_PEAK_POWER_LEVELS, self._DURATIONS):
audio_track_filepath = os.path.join(self._tmp_path, '{}.wav'.format(
filename))
# Create audio track files.
self._audio_tracks = {}
for filename, peak_power, duration in zip(self._FILENAMES,
self._MAX_PEAK_POWER_LEVELS,
self._DURATIONS):
audio_track_filepath = os.path.join(self._tmp_path,
'{}.wav'.format(filename))
# Create a pure tone with the target peak power level.
template = signal_processing.SignalProcessingUtils.GenerateSilence(
duration=duration, sample_rate=self._SAMPLE_RATE)
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
template)
if peak_power is not None:
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
# Create a pure tone with the target peak power level.
template = signal_processing.SignalProcessingUtils.GenerateSilence(
duration=duration, sample_rate=self._SAMPLE_RATE)
signal = signal_processing.SignalProcessingUtils.GeneratePureTone(
template)
if peak_power is not None:
signal = signal.apply_gain(-signal.max_dBFS + peak_power)
signal_processing.SignalProcessingUtils.SaveWav(
audio_track_filepath, signal)
self._audio_tracks[filename] = {
'filepath': audio_track_filepath,
'num_samples': signal_processing.SignalProcessingUtils.CountSamples(
signal)
}
signal_processing.SignalProcessingUtils.SaveWav(
audio_track_filepath, signal)
self._audio_tracks[filename] = {
'filepath':
audio_track_filepath,
'num_samples':
signal_processing.SignalProcessingUtils.CountSamples(signal)
}
def tearDown(self):
"""Recursively deletes temporary folders."""
shutil.rmtree(self._tmp_path)
def tearDown(self):
"""Recursively deletes temporary folders."""
shutil.rmtree(self._tmp_path)
def testCheckMixSameDuration(self):
"""Checks the duration when mixing capture and echo with same duration."""
mix_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertTrue(os.path.exists(mix_filepath))
def testCheckMixSameDuration(self):
"""Checks the duration when mixing capture and echo with same duration."""
mix_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertTrue(os.path.exists(mix_filepath))
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
self.assertEqual(self._audio_tracks['capture']['num_samples'],
signal_processing.SignalProcessingUtils.CountSamples(mix))
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
self.assertEqual(
self._audio_tracks['capture']['num_samples'],
signal_processing.SignalProcessingUtils.CountSamples(mix))
def testRejectShorterEcho(self):
"""Rejects echo signals that are shorter than the capture signal."""
try:
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['shorter']['filepath'])
self.fail('no exception raised')
except exceptions.InputMixerException:
pass
def testRejectShorterEcho(self):
"""Rejects echo signals that are shorter than the capture signal."""
try:
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['shorter']['filepath'])
self.fail('no exception raised')
except exceptions.InputMixerException:
pass
def testCheckMixDurationWithLongerEcho(self):
"""Checks the duration when mixing an echo longer than the capture."""
mix_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['longer']['filepath'])
self.assertTrue(os.path.exists(mix_filepath))
def testCheckMixDurationWithLongerEcho(self):
"""Checks the duration when mixing an echo longer than the capture."""
mix_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['longer']['filepath'])
self.assertTrue(os.path.exists(mix_filepath))
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
self.assertEqual(self._audio_tracks['capture']['num_samples'],
signal_processing.SignalProcessingUtils.CountSamples(mix))
mix = signal_processing.SignalProcessingUtils.LoadWav(mix_filepath)
self.assertEqual(
self._audio_tracks['capture']['num_samples'],
signal_processing.SignalProcessingUtils.CountSamples(mix))
def testCheckOutputFileNamesConflict(self):
"""Checks that different echo files lead to different output file names."""
mix1_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertTrue(os.path.exists(mix1_filepath))
def testCheckOutputFileNamesConflict(self):
"""Checks that different echo files lead to different output file names."""
mix1_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertTrue(os.path.exists(mix1_filepath))
mix2_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_2']['filepath'])
self.assertTrue(os.path.exists(mix2_filepath))
mix2_filepath = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_2']['filepath'])
self.assertTrue(os.path.exists(mix2_filepath))
self.assertNotEqual(mix1_filepath, mix2_filepath)
self.assertNotEqual(mix1_filepath, mix2_filepath)
def testHardClippingLogExpected(self):
"""Checks that hard clipping warning is raised when occurring."""
logging.warning = mock.MagicMock(name='warning')
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_2']['filepath'])
logging.warning.assert_called_once_with(
input_mixer.ApmInputMixer.HardClippingLogMessage())
def testHardClippingLogExpected(self):
"""Checks that hard clipping warning is raised when occurring."""
logging.warning = mock.MagicMock(name='warning')
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_2']['filepath'])
logging.warning.assert_called_once_with(
input_mixer.ApmInputMixer.HardClippingLogMessage())
def testHardClippingLogNotExpected(self):
"""Checks that hard clipping warning is not raised when not occurring."""
logging.warning = mock.MagicMock(name='warning')
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path,
self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertNotIn(
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
logging.warning.call_args_list)
def testHardClippingLogNotExpected(self):
"""Checks that hard clipping warning is not raised when not occurring."""
logging.warning = mock.MagicMock(name='warning')
_ = input_mixer.ApmInputMixer.Mix(
self._tmp_path, self._audio_tracks['capture']['filepath'],
self._audio_tracks['echo_1']['filepath'])
self.assertNotIn(
mock.call(input_mixer.ApmInputMixer.HardClippingLogMessage()),
logging.warning.call_args_list)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Input signal creator module.
"""
@ -14,12 +13,12 @@ from . import signal_processing
class InputSignalCreator(object):
"""Input signal creator class.
"""Input signal creator class.
"""
@classmethod
def Create(cls, name, raw_params):
"""Creates a input signal and its metadata.
@classmethod
def Create(cls, name, raw_params):
"""Creates a input signal and its metadata.
Args:
name: Input signal creator name.
@ -28,29 +27,30 @@ class InputSignalCreator(object):
Returns:
(AudioSegment, dict) tuple.
"""
try:
signal = {}
params = {}
try:
signal = {}
params = {}
if name == 'pure_tone':
params['frequency'] = float(raw_params[0])
params['duration'] = int(raw_params[1])
signal = cls._CreatePureTone(params['frequency'], params['duration'])
else:
raise exceptions.InputSignalCreatorException(
'Invalid input signal creator name')
if name == 'pure_tone':
params['frequency'] = float(raw_params[0])
params['duration'] = int(raw_params[1])
signal = cls._CreatePureTone(params['frequency'],
params['duration'])
else:
raise exceptions.InputSignalCreatorException(
'Invalid input signal creator name')
# Complete metadata.
params['signal'] = name
# Complete metadata.
params['signal'] = name
return signal, params
except (TypeError, AssertionError) as e:
raise exceptions.InputSignalCreatorException(
'Invalid signal creator parameters: {}'.format(e))
return signal, params
except (TypeError, AssertionError) as e:
raise exceptions.InputSignalCreatorException(
'Invalid signal creator parameters: {}'.format(e))
@classmethod
def _CreatePureTone(cls, frequency, duration):
"""
@classmethod
def _CreatePureTone(cls, frequency, duration):
"""
Generates a pure tone at 48000 Hz.
Args:
@ -60,8 +60,9 @@ class InputSignalCreator(object):
Returns:
AudioSegment instance.
"""
assert 0 < frequency <= 24000
assert duration > 0
template = signal_processing.SignalProcessingUtils.GenerateSilence(duration)
return signal_processing.SignalProcessingUtils.GeneratePureTone(
template, frequency)
assert 0 < frequency <= 24000
assert duration > 0
template = signal_processing.SignalProcessingUtils.GenerateSilence(
duration)
return signal_processing.SignalProcessingUtils.GeneratePureTone(
template, frequency)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Signal processing utility module.
"""
@ -16,44 +15,44 @@ import sys
import enum
try:
import numpy as np
import numpy as np
except ImportError:
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package numpy')
sys.exit(1)
try:
import pydub
import pydub.generators
import pydub
import pydub.generators
except ImportError:
logging.critical('Cannot import the third-party Python package pydub')
sys.exit(1)
logging.critical('Cannot import the third-party Python package pydub')
sys.exit(1)
try:
import scipy.signal
import scipy.fftpack
import scipy.signal
import scipy.fftpack
except ImportError:
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
from . import exceptions
class SignalProcessingUtils(object):
"""Collection of signal processing utilities.
"""Collection of signal processing utilities.
"""
@enum.unique
class MixPadding(enum.Enum):
NO_PADDING = 0
ZERO_PADDING = 1
LOOP = 2
@enum.unique
class MixPadding(enum.Enum):
NO_PADDING = 0
ZERO_PADDING = 1
LOOP = 2
def __init__(self):
pass
def __init__(self):
pass
@classmethod
def LoadWav(cls, filepath, channels=1):
"""Loads wav file.
@classmethod
def LoadWav(cls, filepath, channels=1):
"""Loads wav file.
Args:
filepath: path to the wav audio track file to load.
@ -62,25 +61,26 @@ class SignalProcessingUtils(object):
Returns:
AudioSegment instance.
"""
if not os.path.exists(filepath):
logging.error('cannot find the <%s> audio track file', filepath)
raise exceptions.FileNotFoundError()
return pydub.AudioSegment.from_file(
filepath, format='wav', channels=channels)
if not os.path.exists(filepath):
logging.error('cannot find the <%s> audio track file', filepath)
raise exceptions.FileNotFoundError()
return pydub.AudioSegment.from_file(filepath,
format='wav',
channels=channels)
@classmethod
def SaveWav(cls, output_filepath, signal):
"""Saves wav file.
@classmethod
def SaveWav(cls, output_filepath, signal):
"""Saves wav file.
Args:
output_filepath: path to the wav audio track file to save.
signal: AudioSegment instance.
"""
return signal.export(output_filepath, format='wav')
return signal.export(output_filepath, format='wav')
@classmethod
def CountSamples(cls, signal):
"""Number of samples per channel.
@classmethod
def CountSamples(cls, signal):
"""Number of samples per channel.
Args:
signal: AudioSegment instance.
@ -88,14 +88,14 @@ class SignalProcessingUtils(object):
Returns:
An integer.
"""
number_of_samples = len(signal.get_array_of_samples())
assert signal.channels > 0
assert number_of_samples % signal.channels == 0
return number_of_samples / signal.channels
number_of_samples = len(signal.get_array_of_samples())
assert signal.channels > 0
assert number_of_samples % signal.channels == 0
return number_of_samples / signal.channels
@classmethod
def GenerateSilence(cls, duration=1000, sample_rate=48000):
"""Generates silence.
@classmethod
def GenerateSilence(cls, duration=1000, sample_rate=48000):
"""Generates silence.
This method can also be used to create a template AudioSegment instance.
A template can then be used with other Generate*() methods accepting an
@ -108,11 +108,11 @@ class SignalProcessingUtils(object):
Returns:
AudioSegment instance.
"""
return pydub.AudioSegment.silent(duration, sample_rate)
return pydub.AudioSegment.silent(duration, sample_rate)
@classmethod
def GeneratePureTone(cls, template, frequency=440.0):
"""Generates a pure tone.
@classmethod
def GeneratePureTone(cls, template, frequency=440.0):
"""Generates a pure tone.
The pure tone is generated with the same duration and in the same format of
the given template signal.
@ -124,21 +124,18 @@ class SignalProcessingUtils(object):
Return:
AudioSegment instance.
"""
if frequency > template.frame_rate >> 1:
raise exceptions.SignalProcessingException('Invalid frequency')
if frequency > template.frame_rate >> 1:
raise exceptions.SignalProcessingException('Invalid frequency')
generator = pydub.generators.Sine(
sample_rate=template.frame_rate,
bit_depth=template.sample_width * 8,
freq=frequency)
generator = pydub.generators.Sine(sample_rate=template.frame_rate,
bit_depth=template.sample_width * 8,
freq=frequency)
return generator.to_audio_segment(
duration=len(template),
volume=0.0)
return generator.to_audio_segment(duration=len(template), volume=0.0)
@classmethod
def GenerateWhiteNoise(cls, template):
"""Generates white noise.
@classmethod
def GenerateWhiteNoise(cls, template):
"""Generates white noise.
The white noise is generated with the same duration and in the same format
of the given template signal.
@ -149,33 +146,32 @@ class SignalProcessingUtils(object):
Return:
AudioSegment instance.
"""
generator = pydub.generators.WhiteNoise(
sample_rate=template.frame_rate,
bit_depth=template.sample_width * 8)
return generator.to_audio_segment(
duration=len(template),
volume=0.0)
generator = pydub.generators.WhiteNoise(
sample_rate=template.frame_rate,
bit_depth=template.sample_width * 8)
return generator.to_audio_segment(duration=len(template), volume=0.0)
@classmethod
def AudioSegmentToRawData(cls, signal):
samples = signal.get_array_of_samples()
if samples.typecode != 'h':
raise exceptions.SignalProcessingException('Unsupported samples type')
return np.array(signal.get_array_of_samples(), np.int16)
@classmethod
def AudioSegmentToRawData(cls, signal):
samples = signal.get_array_of_samples()
if samples.typecode != 'h':
raise exceptions.SignalProcessingException(
'Unsupported samples type')
return np.array(signal.get_array_of_samples(), np.int16)
@classmethod
def Fft(cls, signal, normalize=True):
if signal.channels != 1:
raise NotImplementedError('multiple-channel FFT not implemented')
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
if normalize:
x /= max(abs(np.max(x)), 1.0)
y = scipy.fftpack.fft(x)
return y[:len(y) / 2]
@classmethod
def Fft(cls, signal, normalize=True):
if signal.channels != 1:
raise NotImplementedError('multiple-channel FFT not implemented')
x = cls.AudioSegmentToRawData(signal).astype(np.float32)
if normalize:
x /= max(abs(np.max(x)), 1.0)
y = scipy.fftpack.fft(x)
return y[:len(y) / 2]
@classmethod
def DetectHardClipping(cls, signal, threshold=2):
"""Detects hard clipping.
@classmethod
def DetectHardClipping(cls, signal, threshold=2):
"""Detects hard clipping.
Hard clipping is simply detected by counting samples that touch either the
lower or upper bound too many times in a row (according to |threshold|).
@ -189,32 +185,33 @@ class SignalProcessingUtils(object):
Returns:
True if hard clipping is detect, False otherwise.
"""
if signal.channels != 1:
raise NotImplementedError('multiple-channel clipping not implemented')
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
raise exceptions.SignalProcessingException(
'hard-clipping detection only supported for 16 bit samples')
samples = cls.AudioSegmentToRawData(signal)
if signal.channels != 1:
raise NotImplementedError(
'multiple-channel clipping not implemented')
if signal.sample_width != 2: # Note that signal.sample_width is in bytes.
raise exceptions.SignalProcessingException(
'hard-clipping detection only supported for 16 bit samples')
samples = cls.AudioSegmentToRawData(signal)
# Detect adjacent clipped samples.
samples_type_info = np.iinfo(samples.dtype)
mask_min = samples == samples_type_info.min
mask_max = samples == samples_type_info.max
# Detect adjacent clipped samples.
samples_type_info = np.iinfo(samples.dtype)
mask_min = samples == samples_type_info.min
mask_max = samples == samples_type_info.max
def HasLongSequence(vector, min_legth=threshold):
"""Returns True if there are one or more long sequences of True flags."""
seq_length = 0
for b in vector:
seq_length = seq_length + 1 if b else 0
if seq_length >= min_legth:
return True
return False
def HasLongSequence(vector, min_legth=threshold):
"""Returns True if there are one or more long sequences of True flags."""
seq_length = 0
for b in vector:
seq_length = seq_length + 1 if b else 0
if seq_length >= min_legth:
return True
return False
return HasLongSequence(mask_min) or HasLongSequence(mask_max)
return HasLongSequence(mask_min) or HasLongSequence(mask_max)
@classmethod
def ApplyImpulseResponse(cls, signal, impulse_response):
"""Applies an impulse response to a signal.
@classmethod
def ApplyImpulseResponse(cls, signal, impulse_response):
"""Applies an impulse response to a signal.
Args:
signal: AudioSegment instance.
@ -223,44 +220,48 @@ class SignalProcessingUtils(object):
Returns:
AudioSegment instance.
"""
# Get samples.
assert signal.channels == 1, (
'multiple-channel recordings not supported')
samples = signal.get_array_of_samples()
# Get samples.
assert signal.channels == 1, (
'multiple-channel recordings not supported')
samples = signal.get_array_of_samples()
# Convolve.
logging.info('applying %d order impulse response to a signal lasting %d ms',
len(impulse_response), len(signal))
convolved_samples = scipy.signal.fftconvolve(
in1=samples,
in2=impulse_response,
mode='full').astype(np.int16)
logging.info('convolution computed')
# Convolve.
logging.info(
'applying %d order impulse response to a signal lasting %d ms',
len(impulse_response), len(signal))
convolved_samples = scipy.signal.fftconvolve(in1=samples,
in2=impulse_response,
mode='full').astype(
np.int16)
logging.info('convolution computed')
# Cast.
convolved_samples = array.array(signal.array_type, convolved_samples)
# Cast.
convolved_samples = array.array(signal.array_type, convolved_samples)
# Verify.
logging.debug('signal length: %d samples', len(samples))
logging.debug('convolved signal length: %d samples', len(convolved_samples))
assert len(convolved_samples) > len(samples)
# Verify.
logging.debug('signal length: %d samples', len(samples))
logging.debug('convolved signal length: %d samples',
len(convolved_samples))
assert len(convolved_samples) > len(samples)
# Generate convolved signal AudioSegment instance.
convolved_signal = pydub.AudioSegment(
data=convolved_samples,
metadata={
'sample_width': signal.sample_width,
'frame_rate': signal.frame_rate,
'frame_width': signal.frame_width,
'channels': signal.channels,
})
assert len(convolved_signal) > len(signal)
# Generate convolved signal AudioSegment instance.
convolved_signal = pydub.AudioSegment(data=convolved_samples,
metadata={
'sample_width':
signal.sample_width,
'frame_rate':
signal.frame_rate,
'frame_width':
signal.frame_width,
'channels': signal.channels,
})
assert len(convolved_signal) > len(signal)
return convolved_signal
return convolved_signal
@classmethod
def Normalize(cls, signal):
"""Normalizes a signal.
@classmethod
def Normalize(cls, signal):
"""Normalizes a signal.
Args:
signal: AudioSegment instance.
@ -268,11 +269,11 @@ class SignalProcessingUtils(object):
Returns:
An AudioSegment instance.
"""
return signal.apply_gain(-signal.max_dBFS)
return signal.apply_gain(-signal.max_dBFS)
@classmethod
def Copy(cls, signal):
"""Makes a copy os a signal.
@classmethod
def Copy(cls, signal):
"""Makes a copy os a signal.
Args:
signal: AudioSegment instance.
@ -280,19 +281,21 @@ class SignalProcessingUtils(object):
Returns:
An AudioSegment instance.
"""
return pydub.AudioSegment(
data=signal.get_array_of_samples(),
metadata={
'sample_width': signal.sample_width,
'frame_rate': signal.frame_rate,
'frame_width': signal.frame_width,
'channels': signal.channels,
})
return pydub.AudioSegment(data=signal.get_array_of_samples(),
metadata={
'sample_width': signal.sample_width,
'frame_rate': signal.frame_rate,
'frame_width': signal.frame_width,
'channels': signal.channels,
})
@classmethod
def MixSignals(cls, signal, noise, target_snr=0.0,
pad_noise=MixPadding.NO_PADDING):
"""Mixes |signal| and |noise| with a target SNR.
@classmethod
def MixSignals(cls,
signal,
noise,
target_snr=0.0,
pad_noise=MixPadding.NO_PADDING):
"""Mixes |signal| and |noise| with a target SNR.
Mix |signal| and |noise| with a desired SNR by scaling |noise|.
If the target SNR is +/- infinite, a copy of signal/noise is returned.
@ -312,45 +315,45 @@ class SignalProcessingUtils(object):
Returns:
An AudioSegment instance.
"""
# Handle infinite target SNR.
if target_snr == -np.Inf:
# Return a copy of noise.
logging.warning('SNR = -Inf, returning noise')
return cls.Copy(noise)
elif target_snr == np.Inf:
# Return a copy of signal.
logging.warning('SNR = +Inf, returning signal')
return cls.Copy(signal)
# Handle infinite target SNR.
if target_snr == -np.Inf:
# Return a copy of noise.
logging.warning('SNR = -Inf, returning noise')
return cls.Copy(noise)
elif target_snr == np.Inf:
# Return a copy of signal.
logging.warning('SNR = +Inf, returning signal')
return cls.Copy(signal)
# Check signal and noise power.
signal_power = float(signal.dBFS)
noise_power = float(noise.dBFS)
if signal_power == -np.Inf:
logging.error('signal has -Inf power, cannot mix')
raise exceptions.SignalProcessingException(
'cannot mix a signal with -Inf power')
if noise_power == -np.Inf:
logging.error('noise has -Inf power, cannot mix')
raise exceptions.SignalProcessingException(
'cannot mix a signal with -Inf power')
# Check signal and noise power.
signal_power = float(signal.dBFS)
noise_power = float(noise.dBFS)
if signal_power == -np.Inf:
logging.error('signal has -Inf power, cannot mix')
raise exceptions.SignalProcessingException(
'cannot mix a signal with -Inf power')
if noise_power == -np.Inf:
logging.error('noise has -Inf power, cannot mix')
raise exceptions.SignalProcessingException(
'cannot mix a signal with -Inf power')
# Mix.
gain_db = signal_power - noise_power - target_snr
signal_duration = len(signal)
noise_duration = len(noise)
if signal_duration <= noise_duration:
# Ignore |pad_noise|, |noise| is truncated if longer that |signal|, the
# mix will have the same length of |signal|.
return signal.overlay(noise.apply_gain(gain_db))
elif pad_noise == cls.MixPadding.NO_PADDING:
# |signal| is longer than |noise|, but no padding is applied to |noise|.
# Truncate |signal|.
return noise.overlay(signal, gain_during_overlay=gain_db)
elif pad_noise == cls.MixPadding.ZERO_PADDING:
# TODO(alessiob): Check that this works as expected.
return signal.overlay(noise.apply_gain(gain_db))
elif pad_noise == cls.MixPadding.LOOP:
# |signal| is longer than |noise|, extend |noise| by looping.
return signal.overlay(noise.apply_gain(gain_db), loop=True)
else:
raise exceptions.SignalProcessingException('invalid padding type')
# Mix.
gain_db = signal_power - noise_power - target_snr
signal_duration = len(signal)
noise_duration = len(noise)
if signal_duration <= noise_duration:
# Ignore |pad_noise|, |noise| is truncated if longer that |signal|, the
# mix will have the same length of |signal|.
return signal.overlay(noise.apply_gain(gain_db))
elif pad_noise == cls.MixPadding.NO_PADDING:
# |signal| is longer than |noise|, but no padding is applied to |noise|.
# Truncate |signal|.
return noise.overlay(signal, gain_during_overlay=gain_db)
elif pad_noise == cls.MixPadding.ZERO_PADDING:
# TODO(alessiob): Check that this works as expected.
return signal.overlay(noise.apply_gain(gain_db))
elif pad_noise == cls.MixPadding.LOOP:
# |signal| is longer than |noise|, extend |noise| by looping.
return signal.overlay(noise.apply_gain(gain_db), loop=True)
else:
raise exceptions.SignalProcessingException('invalid padding type')

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the signal_processing module.
"""
@ -19,168 +18,166 @@ from . import signal_processing
class TestSignalProcessing(unittest.TestCase):
"""Unit tests for the signal_processing module.
"""Unit tests for the signal_processing module.
"""
def testMixSignals(self):
# Generate a template signal with which white noise can be generated.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
def testMixSignals(self):
# Generate a template signal with which white noise can be generated.
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
# Generate two distinct AudioSegment instances with 1 second of white noise.
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
# Generate two distinct AudioSegment instances with 1 second of white noise.
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
noise = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
# Extract samples.
signal_samples = signal.get_array_of_samples()
noise_samples = noise.get_array_of_samples()
# Extract samples.
signal_samples = signal.get_array_of_samples()
noise_samples = noise.get_array_of_samples()
# Test target SNR -Inf (noise expected).
mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, -np.Inf)
self.assertTrue(len(noise), len(mix_neg_inf)) # Check duration.
mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
self.assertTrue( # Check samples.
all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))
# Test target SNR -Inf (noise expected).
mix_neg_inf = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, -np.Inf)
self.assertTrue(len(noise), len(mix_neg_inf)) # Check duration.
mix_neg_inf_samples = mix_neg_inf.get_array_of_samples()
self.assertTrue( # Check samples.
all([x == y for x, y in zip(noise_samples, mix_neg_inf_samples)]))
# Test target SNR 0.0 (different data expected).
mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, 0.0)
self.assertTrue(len(signal), len(mix_0)) # Check duration.
self.assertTrue(len(noise), len(mix_0))
mix_0_samples = mix_0.get_array_of_samples()
self.assertTrue(
any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
self.assertTrue(
any([x != y for x, y in zip(noise_samples, mix_0_samples)]))
# Test target SNR 0.0 (different data expected).
mix_0 = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, 0.0)
self.assertTrue(len(signal), len(mix_0)) # Check duration.
self.assertTrue(len(noise), len(mix_0))
mix_0_samples = mix_0.get_array_of_samples()
self.assertTrue(
any([x != y for x, y in zip(signal_samples, mix_0_samples)]))
self.assertTrue(
any([x != y for x, y in zip(noise_samples, mix_0_samples)]))
# Test target SNR +Inf (signal expected).
mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, np.Inf)
self.assertTrue(len(signal), len(mix_pos_inf)) # Check duration.
mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
self.assertTrue( # Check samples.
all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))
# Test target SNR +Inf (signal expected).
mix_pos_inf = signal_processing.SignalProcessingUtils.MixSignals(
signal, noise, np.Inf)
self.assertTrue(len(signal), len(mix_pos_inf)) # Check duration.
mix_pos_inf_samples = mix_pos_inf.get_array_of_samples()
self.assertTrue( # Check samples.
all([x == y for x, y in zip(signal_samples, mix_pos_inf_samples)]))
def testMixSignalsMinInfPower(self):
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
def testMixSignalsMinInfPower(self):
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
with self.assertRaises(exceptions.SignalProcessingException):
_ = signal_processing.SignalProcessingUtils.MixSignals(
signal, silence, 0.0)
with self.assertRaises(exceptions.SignalProcessingException):
_ = signal_processing.SignalProcessingUtils.MixSignals(
signal, silence, 0.0)
with self.assertRaises(exceptions.SignalProcessingException):
_ = signal_processing.SignalProcessingUtils.MixSignals(
silence, signal, 0.0)
with self.assertRaises(exceptions.SignalProcessingException):
_ = signal_processing.SignalProcessingUtils.MixSignals(
silence, signal, 0.0)
def testMixSignalNoiseDifferentLengths(self):
# Test signals.
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
def testMixSignalNoiseDifferentLengths(self):
# Test signals.
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
longer = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=2000, frame_rate=8000))
# When the signal is shorter than the noise, the mix length always equals
# that of the signal regardless of whether padding is applied.
# No noise padding, length of signal less than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=shorter,
noise=longer,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.NO_PADDING)
self.assertEqual(len(shorter), len(mix))
# With noise padding, length of signal less than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=shorter,
noise=longer,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
self.assertEqual(len(shorter), len(mix))
# When the signal is shorter than the noise, the mix length always equals
# that of the signal regardless of whether padding is applied.
# No noise padding, length of signal less than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=shorter,
noise=longer,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
NO_PADDING)
self.assertEqual(len(shorter), len(mix))
# With noise padding, length of signal less than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=shorter,
noise=longer,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
ZERO_PADDING)
self.assertEqual(len(shorter), len(mix))
# When the signal is longer than the noise, the mix length depends on
# whether padding is applied.
# No noise padding, length of signal greater than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.NO_PADDING)
self.assertEqual(len(shorter), len(mix))
# With noise padding, length of signal greater than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
self.assertEqual(len(longer), len(mix))
# When the signal is longer than the noise, the mix length depends on
# whether padding is applied.
# No noise padding, length of signal greater than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
NO_PADDING)
self.assertEqual(len(shorter), len(mix))
# With noise padding, length of signal greater than that of noise.
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
ZERO_PADDING)
self.assertEqual(len(longer), len(mix))
def testMixSignalNoisePaddingTypes(self):
# Test signals.
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
def testMixSignalNoisePaddingTypes(self):
# Test signals.
shorter = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
pydub.AudioSegment.silent(duration=1000, frame_rate=8000))
longer = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=2000, frame_rate=8000), 440.0)
# Zero padding: expect pure tone only in 1-2s.
mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
target_snr=-6,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.ZERO_PADDING)
# Zero padding: expect pure tone only in 1-2s.
mix_zero_pad = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
target_snr=-6,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.
ZERO_PADDING)
# Loop: expect pure tone plus noise in 1-2s.
mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
target_snr=-6,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
# Loop: expect pure tone plus noise in 1-2s.
mix_loop = signal_processing.SignalProcessingUtils.MixSignals(
signal=longer,
noise=shorter,
target_snr=-6,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
def Energy(signal):
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
signal).astype(np.float32)
return np.sum(samples * samples)
def Energy(signal):
samples = signal_processing.SignalProcessingUtils.AudioSegmentToRawData(
signal).astype(np.float32)
return np.sum(samples * samples)
e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
e_mix_loop = Energy(mix_loop[-1000:])
self.assertLess(0, e_mix_zero_pad)
self.assertLess(e_mix_zero_pad, e_mix_loop)
e_mix_zero_pad = Energy(mix_zero_pad[-1000:])
e_mix_loop = Energy(mix_loop[-1000:])
self.assertLess(0, e_mix_zero_pad)
self.assertLess(e_mix_zero_pad, e_mix_loop)
def testMixSignalSnr(self):
# Test signals.
tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)
def testMixSignalSnr(self):
# Test signals.
tone_low = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 250.0)
tone_high = signal_processing.SignalProcessingUtils.GeneratePureTone(
pydub.AudioSegment.silent(duration=64, frame_rate=8000), 3000.0)
def ToneAmplitudes(mix):
"""Returns the amplitude of the coefficients #16 and #192, which
def ToneAmplitudes(mix):
"""Returns the amplitude of the coefficients #16 and #192, which
correspond to the tones at 250 and 3k Hz respectively."""
mix_fft = np.absolute(signal_processing.SignalProcessingUtils.Fft(mix))
return mix_fft[16], mix_fft[192]
mix_fft = np.absolute(
signal_processing.SignalProcessingUtils.Fft(mix))
return mix_fft[16], mix_fft[192]
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_low,
noise=tone_high,
target_snr=-6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_low, ampl_high)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_low, noise=tone_high, target_snr=-6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_low, ampl_high)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_high,
noise=tone_low,
target_snr=-6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_high, ampl_low)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_high, noise=tone_low, target_snr=-6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_high, ampl_low)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_low,
noise=tone_high,
target_snr=6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_high, ampl_low)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_low, noise=tone_high, target_snr=6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_high, ampl_low)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_high,
noise=tone_low,
target_snr=6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_low, ampl_high)
mix = signal_processing.SignalProcessingUtils.MixSignals(
signal=tone_high, noise=tone_low, target_snr=6)
ampl_low, ampl_high = ToneAmplitudes(mix)
self.assertLess(ampl_low, ampl_high)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""APM module simulator.
"""
@ -25,85 +24,93 @@ from . import test_data_generation
class ApmModuleSimulator(object):
"""Audio processing module (APM) simulator class.
"""Audio processing module (APM) simulator class.
"""
_TEST_DATA_GENERATOR_CLASSES = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
_TEST_DATA_GENERATOR_CLASSES = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
_EVAL_SCORE_WORKER_CLASSES = eval_scores.EvaluationScore.REGISTERED_CLASSES
_PREFIX_APM_CONFIG = 'apmcfg-'
_PREFIX_CAPTURE = 'capture-'
_PREFIX_RENDER = 'render-'
_PREFIX_ECHO_SIMULATOR = 'echosim-'
_PREFIX_TEST_DATA_GEN = 'datagen-'
_PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
_PREFIX_SCORE = 'score-'
_PREFIX_APM_CONFIG = 'apmcfg-'
_PREFIX_CAPTURE = 'capture-'
_PREFIX_RENDER = 'render-'
_PREFIX_ECHO_SIMULATOR = 'echosim-'
_PREFIX_TEST_DATA_GEN = 'datagen-'
_PREFIX_TEST_DATA_GEN_PARAMS = 'datagen_params-'
_PREFIX_SCORE = 'score-'
def __init__(self, test_data_generator_factory, evaluation_score_factory,
ap_wrapper, evaluator, external_vads=None):
if external_vads is None:
external_vads = {}
self._test_data_generator_factory = test_data_generator_factory
self._evaluation_score_factory = evaluation_score_factory
self._audioproc_wrapper = ap_wrapper
self._evaluator = evaluator
self._annotator = annotations.AudioAnnotationsExtractor(
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD |
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO |
annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
external_vads
)
def __init__(self,
test_data_generator_factory,
evaluation_score_factory,
ap_wrapper,
evaluator,
external_vads=None):
if external_vads is None:
external_vads = {}
self._test_data_generator_factory = test_data_generator_factory
self._evaluation_score_factory = evaluation_score_factory
self._audioproc_wrapper = ap_wrapper
self._evaluator = evaluator
self._annotator = annotations.AudioAnnotationsExtractor(
annotations.AudioAnnotationsExtractor.VadType.ENERGY_THRESHOLD
| annotations.AudioAnnotationsExtractor.VadType.WEBRTC_COMMON_AUDIO
| annotations.AudioAnnotationsExtractor.VadType.WEBRTC_APM,
external_vads)
# Init.
self._test_data_generator_factory.SetOutputDirectoryPrefix(
self._PREFIX_TEST_DATA_GEN_PARAMS)
self._evaluation_score_factory.SetScoreFilenamePrefix(
self._PREFIX_SCORE)
# Init.
self._test_data_generator_factory.SetOutputDirectoryPrefix(
self._PREFIX_TEST_DATA_GEN_PARAMS)
self._evaluation_score_factory.SetScoreFilenamePrefix(
self._PREFIX_SCORE)
# Properties for each run.
self._base_output_path = None
self._output_cache_path = None
self._test_data_generators = None
self._evaluation_score_workers = None
self._config_filepaths = None
self._capture_input_filepaths = None
self._render_input_filepaths = None
self._echo_path_simulator_class = None
# Properties for each run.
self._base_output_path = None
self._output_cache_path = None
self._test_data_generators = None
self._evaluation_score_workers = None
self._config_filepaths = None
self._capture_input_filepaths = None
self._render_input_filepaths = None
self._echo_path_simulator_class = None
@classmethod
def GetPrefixApmConfig(cls):
return cls._PREFIX_APM_CONFIG
@classmethod
def GetPrefixApmConfig(cls):
return cls._PREFIX_APM_CONFIG
@classmethod
def GetPrefixCapture(cls):
return cls._PREFIX_CAPTURE
@classmethod
def GetPrefixCapture(cls):
return cls._PREFIX_CAPTURE
@classmethod
def GetPrefixRender(cls):
return cls._PREFIX_RENDER
@classmethod
def GetPrefixRender(cls):
return cls._PREFIX_RENDER
@classmethod
def GetPrefixEchoSimulator(cls):
return cls._PREFIX_ECHO_SIMULATOR
@classmethod
def GetPrefixEchoSimulator(cls):
return cls._PREFIX_ECHO_SIMULATOR
@classmethod
def GetPrefixTestDataGenerator(cls):
return cls._PREFIX_TEST_DATA_GEN
@classmethod
def GetPrefixTestDataGenerator(cls):
return cls._PREFIX_TEST_DATA_GEN
@classmethod
def GetPrefixTestDataGeneratorParameters(cls):
return cls._PREFIX_TEST_DATA_GEN_PARAMS
@classmethod
def GetPrefixTestDataGeneratorParameters(cls):
return cls._PREFIX_TEST_DATA_GEN_PARAMS
@classmethod
def GetPrefixScore(cls):
return cls._PREFIX_SCORE
@classmethod
def GetPrefixScore(cls):
return cls._PREFIX_SCORE
def Run(self, config_filepaths, capture_input_filepaths,
test_data_generator_names, eval_score_names, output_dir,
render_input_filepaths=None, echo_path_simulator_name=(
echo_path_simulation.NoEchoPathSimulator.NAME)):
"""Runs the APM simulation.
def Run(self,
config_filepaths,
capture_input_filepaths,
test_data_generator_names,
eval_score_names,
output_dir,
render_input_filepaths=None,
echo_path_simulator_name=(
echo_path_simulation.NoEchoPathSimulator.NAME)):
"""Runs the APM simulation.
Initializes paths and required instances, then runs all the simulations.
The render input can be optionally added. If added, the number of capture
@ -120,132 +127,140 @@ class ApmModuleSimulator(object):
echo_path_simulator_name: name of the echo path simulator to use when
render input is provided.
"""
assert render_input_filepaths is None or (
len(capture_input_filepaths) == len(render_input_filepaths)), (
'render input set size not matching input set size')
assert render_input_filepaths is None or echo_path_simulator_name in (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
'invalid echo path simulator')
self._base_output_path = os.path.abspath(output_dir)
assert render_input_filepaths is None or (
len(capture_input_filepaths) == len(render_input_filepaths)), (
'render input set size not matching input set size')
assert render_input_filepaths is None or echo_path_simulator_name in (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES), (
'invalid echo path simulator')
self._base_output_path = os.path.abspath(output_dir)
# Output path used to cache the data shared across simulations.
self._output_cache_path = os.path.join(self._base_output_path, '_cache')
# Output path used to cache the data shared across simulations.
self._output_cache_path = os.path.join(self._base_output_path,
'_cache')
# Instance test data generators.
self._test_data_generators = [self._test_data_generator_factory.GetInstance(
test_data_generators_class=(
self._TEST_DATA_GENERATOR_CLASSES[name])) for name in (
test_data_generator_names)]
# Instance test data generators.
self._test_data_generators = [
self._test_data_generator_factory.GetInstance(
test_data_generators_class=(
self._TEST_DATA_GENERATOR_CLASSES[name]))
for name in (test_data_generator_names)
]
# Instance evaluation score workers.
self._evaluation_score_workers = [
self._evaluation_score_factory.GetInstance(
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name]) for (
name) in eval_score_names]
# Instance evaluation score workers.
self._evaluation_score_workers = [
self._evaluation_score_factory.GetInstance(
evaluation_score_class=self._EVAL_SCORE_WORKER_CLASSES[name])
for (name) in eval_score_names
]
# Set APM configuration file paths.
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
# Set APM configuration file paths.
self._config_filepaths = self._CreatePathsCollection(config_filepaths)
# Set probing signal file paths.
if render_input_filepaths is None:
# Capture input only.
self._capture_input_filepaths = self._CreatePathsCollection(
capture_input_filepaths)
self._render_input_filepaths = None
else:
# Set both capture and render input signals.
self._SetTestInputSignalFilePaths(
capture_input_filepaths, render_input_filepaths)
# Set probing signal file paths.
if render_input_filepaths is None:
# Capture input only.
self._capture_input_filepaths = self._CreatePathsCollection(
capture_input_filepaths)
self._render_input_filepaths = None
else:
# Set both capture and render input signals.
self._SetTestInputSignalFilePaths(capture_input_filepaths,
render_input_filepaths)
# Set the echo path simulator class.
self._echo_path_simulator_class = (
echo_path_simulation.EchoPathSimulator.REGISTERED_CLASSES[
echo_path_simulator_name])
# Set the echo path simulator class.
self._echo_path_simulator_class = (
echo_path_simulation.EchoPathSimulator.
REGISTERED_CLASSES[echo_path_simulator_name])
self._SimulateAll()
self._SimulateAll()
def _SimulateAll(self):
"""Runs all the simulations.
def _SimulateAll(self):
"""Runs all the simulations.
Iterates over the combinations of APM configurations, probing signals, and
test data generators. This method is mainly responsible for the creation of
the cache and output directories required in order to call _Simulate().
"""
without_render_input = self._render_input_filepaths is None
without_render_input = self._render_input_filepaths is None
# Try different APM config files.
for config_name in self._config_filepaths:
config_filepath = self._config_filepaths[config_name]
# Try different APM config files.
for config_name in self._config_filepaths:
config_filepath = self._config_filepaths[config_name]
# Try different capture-render pairs.
for capture_input_name in self._capture_input_filepaths:
# Output path for the capture signal annotations.
capture_annotations_cache_path = os.path.join(
self._output_cache_path,
self._PREFIX_CAPTURE + capture_input_name)
data_access.MakeDirectory(capture_annotations_cache_path)
# Try different capture-render pairs.
for capture_input_name in self._capture_input_filepaths:
# Output path for the capture signal annotations.
capture_annotations_cache_path = os.path.join(
self._output_cache_path,
self._PREFIX_CAPTURE + capture_input_name)
data_access.MakeDirectory(capture_annotations_cache_path)
# Capture.
capture_input_filepath = self._capture_input_filepaths[
capture_input_name]
if not os.path.exists(capture_input_filepath):
# If the input signal file does not exist, try to create using the
# available input signal creators.
self._CreateInputSignal(capture_input_filepath)
assert os.path.exists(capture_input_filepath)
self._ExtractCaptureAnnotations(
capture_input_filepath, capture_annotations_cache_path)
# Capture.
capture_input_filepath = self._capture_input_filepaths[
capture_input_name]
if not os.path.exists(capture_input_filepath):
# If the input signal file does not exist, try to create using the
# available input signal creators.
self._CreateInputSignal(capture_input_filepath)
assert os.path.exists(capture_input_filepath)
self._ExtractCaptureAnnotations(
capture_input_filepath, capture_annotations_cache_path)
# Render and simulated echo path (optional).
render_input_filepath = None if without_render_input else (
self._render_input_filepaths[capture_input_name])
render_input_name = '(none)' if without_render_input else (
self._ExtractFileName(render_input_filepath))
echo_path_simulator = (
echo_path_simulation_factory.EchoPathSimulatorFactory.GetInstance(
self._echo_path_simulator_class, render_input_filepath))
# Render and simulated echo path (optional).
render_input_filepath = None if without_render_input else (
self._render_input_filepaths[capture_input_name])
render_input_name = '(none)' if without_render_input else (
self._ExtractFileName(render_input_filepath))
echo_path_simulator = (echo_path_simulation_factory.
EchoPathSimulatorFactory.GetInstance(
self._echo_path_simulator_class,
render_input_filepath))
# Try different test data generators.
for test_data_generators in self._test_data_generators:
logging.info('APM config preset: <%s>, capture: <%s>, render: <%s>,'
'test data generator: <%s>, echo simulator: <%s>',
config_name, capture_input_name, render_input_name,
test_data_generators.NAME, echo_path_simulator.NAME)
# Try different test data generators.
for test_data_generators in self._test_data_generators:
logging.info(
'APM config preset: <%s>, capture: <%s>, render: <%s>,'
'test data generator: <%s>, echo simulator: <%s>',
config_name, capture_input_name, render_input_name,
test_data_generators.NAME, echo_path_simulator.NAME)
# Output path for the generated test data.
test_data_cache_path = os.path.join(
capture_annotations_cache_path,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(test_data_cache_path)
logging.debug('test data cache path: <%s>', test_data_cache_path)
# Output path for the generated test data.
test_data_cache_path = os.path.join(
capture_annotations_cache_path,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(test_data_cache_path)
logging.debug('test data cache path: <%s>',
test_data_cache_path)
# Output path for the echo simulator and APM input mixer output.
echo_test_data_cache_path = os.path.join(
test_data_cache_path, 'echosim-{}'.format(
echo_path_simulator.NAME))
data_access.MakeDirectory(echo_test_data_cache_path)
logging.debug('echo test data cache path: <%s>',
echo_test_data_cache_path)
# Output path for the echo simulator and APM input mixer output.
echo_test_data_cache_path = os.path.join(
test_data_cache_path,
'echosim-{}'.format(echo_path_simulator.NAME))
data_access.MakeDirectory(echo_test_data_cache_path)
logging.debug('echo test data cache path: <%s>',
echo_test_data_cache_path)
# Full output path.
output_path = os.path.join(
self._base_output_path,
self._PREFIX_APM_CONFIG + config_name,
self._PREFIX_CAPTURE + capture_input_name,
self._PREFIX_RENDER + render_input_name,
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(output_path)
logging.debug('output path: <%s>', output_path)
# Full output path.
output_path = os.path.join(
self._base_output_path,
self._PREFIX_APM_CONFIG + config_name,
self._PREFIX_CAPTURE + capture_input_name,
self._PREFIX_RENDER + render_input_name,
self._PREFIX_ECHO_SIMULATOR + echo_path_simulator.NAME,
self._PREFIX_TEST_DATA_GEN + test_data_generators.NAME)
data_access.MakeDirectory(output_path)
logging.debug('output path: <%s>', output_path)
self._Simulate(test_data_generators, capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path,
config_filepath, echo_path_simulator)
self._Simulate(test_data_generators,
capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path,
config_filepath, echo_path_simulator)
@staticmethod
def _CreateInputSignal(input_signal_filepath):
"""Creates a missing input signal file.
@staticmethod
def _CreateInputSignal(input_signal_filepath):
"""Creates a missing input signal file.
The file name is parsed to extract input signal creator and params. If a
creator is matched and the parameters are valid, a new signal is generated
@ -257,30 +272,33 @@ class ApmModuleSimulator(object):
Raises:
InputSignalCreatorException
"""
filename = os.path.splitext(os.path.split(input_signal_filepath)[-1])[0]
filename_parts = filename.split('-')
filename = os.path.splitext(
os.path.split(input_signal_filepath)[-1])[0]
filename_parts = filename.split('-')
if len(filename_parts) < 2:
raise exceptions.InputSignalCreatorException(
'Cannot parse input signal file name')
if len(filename_parts) < 2:
raise exceptions.InputSignalCreatorException(
'Cannot parse input signal file name')
signal, metadata = input_signal_creator.InputSignalCreator.Create(
filename_parts[0], filename_parts[1].split('_'))
signal, metadata = input_signal_creator.InputSignalCreator.Create(
filename_parts[0], filename_parts[1].split('_'))
signal_processing.SignalProcessingUtils.SaveWav(
input_signal_filepath, signal)
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
signal_processing.SignalProcessingUtils.SaveWav(
input_signal_filepath, signal)
data_access.Metadata.SaveFileMetadata(input_signal_filepath, metadata)
def _ExtractCaptureAnnotations(self, input_filepath, output_path,
annotation_name=""):
self._annotator.Extract(input_filepath)
self._annotator.Save(output_path, annotation_name)
def _ExtractCaptureAnnotations(self,
input_filepath,
output_path,
annotation_name=""):
self._annotator.Extract(input_filepath)
self._annotator.Save(output_path, annotation_name)
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path, config_filepath,
echo_path_simulator):
"""Runs a single set of simulation.
def _Simulate(self, test_data_generators, clean_capture_input_filepath,
render_input_filepath, test_data_cache_path,
echo_test_data_cache_path, output_path, config_filepath,
echo_path_simulator):
"""Runs a single set of simulation.
Simulates a given combination of APM configuration, probing signal, and
test data generator. It iterates over the test data generator
@ -298,90 +316,92 @@ class ApmModuleSimulator(object):
config_filepath: APM configuration file to test.
echo_path_simulator: EchoPathSimulator instance.
"""
# Generate pairs of noisy input and reference signal files.
test_data_generators.Generate(
input_signal_filepath=clean_capture_input_filepath,
test_data_cache_path=test_data_cache_path,
base_output_path=output_path)
# Generate pairs of noisy input and reference signal files.
test_data_generators.Generate(
input_signal_filepath=clean_capture_input_filepath,
test_data_cache_path=test_data_cache_path,
base_output_path=output_path)
# Extract metadata linked to the clean input file (if any).
apm_input_metadata = None
try:
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
clean_capture_input_filepath)
except IOError as e:
apm_input_metadata = {}
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
apm_input_metadata['test_data_gen_config'] = None
# Extract metadata linked to the clean input file (if any).
apm_input_metadata = None
try:
apm_input_metadata = data_access.Metadata.LoadFileMetadata(
clean_capture_input_filepath)
except IOError as e:
apm_input_metadata = {}
apm_input_metadata['test_data_gen_name'] = test_data_generators.NAME
apm_input_metadata['test_data_gen_config'] = None
# For each test data pair, simulate a call and evaluate.
for config_name in test_data_generators.config_names:
logging.info(' - test data generator config: <%s>', config_name)
apm_input_metadata['test_data_gen_config'] = config_name
# For each test data pair, simulate a call and evaluate.
for config_name in test_data_generators.config_names:
logging.info(' - test data generator config: <%s>', config_name)
apm_input_metadata['test_data_gen_config'] = config_name
# Paths to the test data generator output.
# Note that the reference signal does not depend on the render input
# which is optional.
noisy_capture_input_filepath = (
test_data_generators.noisy_signal_filepaths[config_name])
reference_signal_filepath = (
test_data_generators.reference_signal_filepaths[config_name])
# Paths to the test data generator output.
# Note that the reference signal does not depend on the render input
# which is optional.
noisy_capture_input_filepath = (
test_data_generators.noisy_signal_filepaths[config_name])
reference_signal_filepath = (
test_data_generators.reference_signal_filepaths[config_name])
# Output path for the evaluation (e.g., APM output file).
evaluation_output_path = test_data_generators.apm_output_paths[
config_name]
# Output path for the evaluation (e.g., APM output file).
evaluation_output_path = test_data_generators.apm_output_paths[
config_name]
# Paths to the APM input signals.
echo_path_filepath = echo_path_simulator.Simulate(
echo_test_data_cache_path)
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
echo_test_data_cache_path, noisy_capture_input_filepath,
echo_path_filepath)
# Paths to the APM input signals.
echo_path_filepath = echo_path_simulator.Simulate(
echo_test_data_cache_path)
apm_input_filepath = input_mixer.ApmInputMixer.Mix(
echo_test_data_cache_path, noisy_capture_input_filepath,
echo_path_filepath)
# Extract annotations for the APM input mix.
apm_input_basepath, apm_input_filename = os.path.split(
apm_input_filepath)
self._ExtractCaptureAnnotations(
apm_input_filepath, apm_input_basepath,
os.path.splitext(apm_input_filename)[0] + '-')
# Extract annotations for the APM input mix.
apm_input_basepath, apm_input_filename = os.path.split(
apm_input_filepath)
self._ExtractCaptureAnnotations(
apm_input_filepath, apm_input_basepath,
os.path.splitext(apm_input_filename)[0] + '-')
# Simulate a call using APM.
self._audioproc_wrapper.Run(
config_filepath=config_filepath,
capture_input_filepath=apm_input_filepath,
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path)
# Simulate a call using APM.
self._audioproc_wrapper.Run(
config_filepath=config_filepath,
capture_input_filepath=apm_input_filepath,
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path)
try:
# Evaluate.
self._evaluator.Run(
evaluation_score_workers=self._evaluation_score_workers,
apm_input_metadata=apm_input_metadata,
apm_output_filepath=self._audioproc_wrapper.output_filepath,
reference_input_filepath=reference_signal_filepath,
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path,
)
try:
# Evaluate.
self._evaluator.Run(
evaluation_score_workers=self._evaluation_score_workers,
apm_input_metadata=apm_input_metadata,
apm_output_filepath=self._audioproc_wrapper.
output_filepath,
reference_input_filepath=reference_signal_filepath,
render_input_filepath=render_input_filepath,
output_path=evaluation_output_path,
)
# Save simulation metadata.
data_access.Metadata.SaveAudioTestDataPaths(
output_path=evaluation_output_path,
clean_capture_input_filepath=clean_capture_input_filepath,
echo_free_capture_filepath=noisy_capture_input_filepath,
echo_filepath=echo_path_filepath,
render_filepath=render_input_filepath,
capture_filepath=apm_input_filepath,
apm_output_filepath=self._audioproc_wrapper.output_filepath,
apm_reference_filepath=reference_signal_filepath,
apm_config_filepath=config_filepath,
)
except exceptions.EvaluationScoreException as e:
logging.warning('the evaluation failed: %s', e.message)
continue
# Save simulation metadata.
data_access.Metadata.SaveAudioTestDataPaths(
output_path=evaluation_output_path,
clean_capture_input_filepath=clean_capture_input_filepath,
echo_free_capture_filepath=noisy_capture_input_filepath,
echo_filepath=echo_path_filepath,
render_filepath=render_input_filepath,
capture_filepath=apm_input_filepath,
apm_output_filepath=self._audioproc_wrapper.
output_filepath,
apm_reference_filepath=reference_signal_filepath,
apm_config_filepath=config_filepath,
)
except exceptions.EvaluationScoreException as e:
logging.warning('the evaluation failed: %s', e.message)
continue
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
render_input_filepaths):
"""Sets input and render input file paths collections.
def _SetTestInputSignalFilePaths(self, capture_input_filepaths,
render_input_filepaths):
"""Sets input and render input file paths collections.
Pairs the input and render input files by storing the file paths into two
collections. The key is the file name of the input file.
@ -390,20 +410,20 @@ class ApmModuleSimulator(object):
capture_input_filepaths: list of file paths.
render_input_filepaths: list of file paths.
"""
self._capture_input_filepaths = {}
self._render_input_filepaths = {}
assert len(capture_input_filepaths) == len(render_input_filepaths)
for capture_input_filepath, render_input_filepath in zip(
capture_input_filepaths, render_input_filepaths):
name = self._ExtractFileName(capture_input_filepath)
self._capture_input_filepaths[name] = os.path.abspath(
capture_input_filepath)
self._render_input_filepaths[name] = os.path.abspath(
render_input_filepath)
self._capture_input_filepaths = {}
self._render_input_filepaths = {}
assert len(capture_input_filepaths) == len(render_input_filepaths)
for capture_input_filepath, render_input_filepath in zip(
capture_input_filepaths, render_input_filepaths):
name = self._ExtractFileName(capture_input_filepath)
self._capture_input_filepaths[name] = os.path.abspath(
capture_input_filepath)
self._render_input_filepaths[name] = os.path.abspath(
render_input_filepath)
@classmethod
def _CreatePathsCollection(cls, filepaths):
"""Creates a collection of file paths.
@classmethod
def _CreatePathsCollection(cls, filepaths):
"""Creates a collection of file paths.
Given a list of file paths, makes a collection with one item for each file
path. The value is absolute path, the key is the file name without
@ -415,12 +435,12 @@ class ApmModuleSimulator(object):
Returns:
A dict.
"""
filepaths_collection = {}
for filepath in filepaths:
name = cls._ExtractFileName(filepath)
filepaths_collection[name] = os.path.abspath(filepath)
return filepaths_collection
filepaths_collection = {}
for filepath in filepaths:
name = cls._ExtractFileName(filepath)
filepaths_collection[name] = os.path.abspath(filepath)
return filepaths_collection
@classmethod
def _ExtractFileName(cls, filepath):
return os.path.splitext(os.path.split(filepath)[-1])[0]
@classmethod
def _ExtractFileName(cls, filepath):
return os.path.splitext(os.path.split(filepath)[-1])[0]

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the simulation module.
"""
@ -28,177 +27,177 @@ from . import test_data_generation_factory
class TestApmModuleSimulator(unittest.TestCase):
"""Unit tests for the ApmModuleSimulator class.
"""Unit tests for the ApmModuleSimulator class.
"""
def setUp(self):
"""Create temporary folders and fake audio track."""
self._output_path = tempfile.mkdtemp()
self._tmp_path = tempfile.mkdtemp()
def setUp(self):
"""Create temporary folders and fake audio track."""
self._output_path = tempfile.mkdtemp()
self._tmp_path = tempfile.mkdtemp()
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._fake_audio_track_path = os.path.join(self._output_path, 'fake.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_audio_track_path, fake_signal)
silence = pydub.AudioSegment.silent(duration=1000, frame_rate=48000)
fake_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
silence)
self._fake_audio_track_path = os.path.join(self._output_path,
'fake.wav')
signal_processing.SignalProcessingUtils.SaveWav(
self._fake_audio_track_path, fake_signal)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._output_path)
shutil.rmtree(self._tmp_path)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._output_path)
shutil.rmtree(self._tmp_path)
def testSimulation(self):
# Instance dependencies to mock and inject.
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
evaluator = evaluation.ApmModuleEvaluator()
ap_wrapper.Run = mock.MagicMock(name='Run')
evaluator.Run = mock.MagicMock(name='Run')
def testSimulation(self):
# Instance dependencies to mock and inject.
ap_wrapper = audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH)
evaluator = evaluation.ApmModuleEvaluator()
ap_wrapper.Run = mock.MagicMock(name='Run')
evaluator.Run = mock.MagicMock(name='Run')
# Instance non-mocked dependencies.
test_data_generator_factory = (
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False))
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(__file__), 'fake_polqa'),
echo_metric_tool_bin_path=None
)
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=test_data_generator_factory,
evaluation_score_factory=evaluation_score_factory,
ap_wrapper=ap_wrapper,
evaluator=evaluator,
external_vads={'fake': external_vad.ExternalVad(os.path.join(
os.path.dirname(__file__), 'fake_external_vad.py'), 'fake')}
)
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [self._fake_audio_track_path]
test_data_generators = ['identity', 'white_noise']
eval_scores = ['audio_level_mean', 'polqa']
# Run all simulations.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=test_data_generators,
eval_score_names=eval_scores,
output_dir=self._output_path)
# Check.
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
# the client code (e.g., number of SNR pairs for the white noise test data
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
# is known; use that with assertEqual.
min_number_of_simulations = len(config_files) * len(input_files) * len(
test_data_generators)
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
min_number_of_simulations)
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
min_number_of_simulations)
def testInputSignalCreation(self):
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
# Instance non-mocked dependencies.
test_data_generator_factory = (
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(__file__), 'fake_polqa'),
echo_metric_tool_bin_path=None
)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
copy_with_identity=False))
evaluation_score_factory = eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)
# Inexistent input files to be silently created.
input_files = [
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
]
self.assertFalse(any([os.path.exists(input_file) for input_file in (
input_files)]))
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=test_data_generator_factory,
evaluation_score_factory=evaluation_score_factory,
ap_wrapper=ap_wrapper,
evaluator=evaluator,
external_vads={
'fake':
external_vad.ExternalVad(
os.path.join(os.path.dirname(__file__),
'fake_external_vad.py'), 'fake')
})
# The input files are created during the simulation.
simulator.Run(
config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=['audio_level_peak'],
output_dir=self._output_path)
self.assertTrue(all([os.path.exists(input_file) for input_file in (
input_files)]))
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [self._fake_audio_track_path]
test_data_generators = ['identity', 'white_noise']
eval_scores = ['audio_level_mean', 'polqa']
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
logging.warning = mock.MagicMock(name='warning')
# Run all simulations.
simulator.Run(config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=test_data_generators,
eval_score_names=eval_scores,
output_dir=self._output_path)
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(
os.path.dirname(__file__), 'fake_polqa'),
echo_metric_tool_bin_path=None
)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
# Check.
# TODO(alessiob): Once the TestDataGenerator classes can be configured by
# the client code (e.g., number of SNR pairs for the white noise test data
# generator), the exact number of calls to ap_wrapper.Run and evaluator.Run
# is known; use that with assertEqual.
min_number_of_simulations = len(config_files) * len(input_files) * len(
test_data_generators)
self.assertGreaterEqual(len(ap_wrapper.Run.call_args_list),
min_number_of_simulations)
self.assertGreaterEqual(len(evaluator.Run.call_args_list),
min_number_of_simulations)
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
eval_scores = ['thd']
def testInputSignalCreation(self):
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.
DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
# Should work.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=eval_scores,
output_dir=self._output_path)
self.assertFalse(logging.warning.called)
# Inexistent input files to be silently created.
input_files = [
os.path.join(self._tmp_path, 'pure_tone-440_1000.wav'),
os.path.join(self._tmp_path, 'pure_tone-1000_500.wav'),
]
self.assertFalse(
any([os.path.exists(input_file) for input_file in (input_files)]))
# Warning expected.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['white_noise'], # Not allowed with THD.
eval_score_names=eval_scores,
output_dir=self._output_path)
logging.warning.assert_called_with('the evaluation failed: %s', (
'The THD score cannot be used with any test data generator other than '
'"identity"'))
# The input files are created during the simulation.
simulator.Run(config_filepaths=['apm_configs/default.json'],
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=['audio_level_peak'],
output_dir=self._output_path)
self.assertTrue(
all([os.path.exists(input_file) for input_file in (input_files)]))
# # Init.
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
# input_signal_filepath = os.path.join(
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
def testPureToneGenerationWithTotalHarmonicDistorsion(self):
logging.warning = mock.MagicMock(name='warning')
# # Check that the input signal is generated.
# self.assertFalse(os.path.exists(input_signal_filepath))
# generator.Generate(
# input_signal_filepath=input_signal_filepath,
# test_data_cache_path=self._test_data_cache_path,
# base_output_path=self._base_output_path)
# self.assertTrue(os.path.exists(input_signal_filepath))
# Instance simulator.
simulator = simulation.ApmModuleSimulator(
test_data_generator_factory=(
test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=False)),
evaluation_score_factory=(
eval_scores_factory.EvaluationScoreWorkerFactory(
polqa_tool_bin_path=os.path.join(os.path.dirname(__file__),
'fake_polqa'),
echo_metric_tool_bin_path=None)),
ap_wrapper=audioproc_wrapper.AudioProcWrapper(
audioproc_wrapper.AudioProcWrapper.
DEFAULT_APM_SIMULATOR_BIN_PATH),
evaluator=evaluation.ApmModuleEvaluator())
# # Check input signal properties.
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
# input_signal_filepath)
# self.assertEqual(1000, len(input_signal))
# What to simulate.
config_files = ['apm_configs/default.json']
input_files = [os.path.join(self._tmp_path, 'pure_tone-440_1000.wav')]
eval_scores = ['thd']
# Should work.
simulator.Run(config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['identity'],
eval_score_names=eval_scores,
output_dir=self._output_path)
self.assertFalse(logging.warning.called)
# Warning expected.
simulator.Run(
config_filepaths=config_files,
capture_input_filepaths=input_files,
test_data_generator_names=['white_noise'], # Not allowed with THD.
eval_score_names=eval_scores,
output_dir=self._output_path)
logging.warning.assert_called_with('the evaluation failed: %s', (
'The THD score cannot be used with any test data generator other than '
'"identity"'))
# # Init.
# generator = test_data_generation.IdentityTestDataGenerator('tmp')
# input_signal_filepath = os.path.join(
# self._test_data_cache_path, 'pure_tone-440_1000.wav')
# # Check that the input signal is generated.
# self.assertFalse(os.path.exists(input_signal_filepath))
# generator.Generate(
# input_signal_filepath=input_signal_filepath,
# test_data_cache_path=self._test_data_cache_path,
# base_output_path=self._base_output_path)
# self.assertTrue(os.path.exists(input_signal_filepath))
# # Check input signal properties.
# input_signal = signal_processing.SignalProcessingUtils.LoadWav(
# input_signal_filepath)
# self.assertEqual(1000, len(input_signal))

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Test data generators producing signals pairs intended to be used to
test the APM module. Each pair consists of a noisy input and a reference signal.
The former is used as APM input and it is generated by adding noise to a
@ -27,10 +26,10 @@ import shutil
import sys
try:
import scipy.io
import scipy.io
except ImportError:
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
logging.critical('Cannot import the third-party Python package scipy')
sys.exit(1)
from . import data_access
from . import exceptions
@ -38,7 +37,7 @@ from . import signal_processing
class TestDataGenerator(object):
"""Abstract class responsible for the generation of noisy signals.
"""Abstract class responsible for the generation of noisy signals.
Given a clean signal, it generates two streams named noisy signal and
reference. The former is the clean signal deteriorated by the noise source,
@ -50,24 +49,24 @@ class TestDataGenerator(object):
An test data generator generates one or more pairs.
"""
NAME = None
REGISTERED_CLASSES = {}
NAME = None
REGISTERED_CLASSES = {}
def __init__(self, output_directory_prefix):
self._output_directory_prefix = output_directory_prefix
# Init dictionaries with one entry for each test data generator
# configuration (e.g., different SNRs).
# Noisy audio track files (stored separately in a cache folder).
self._noisy_signal_filepaths = None
# Path to be used for the APM simulation output files.
self._apm_output_paths = None
# Reference audio track files (stored separately in a cache folder).
self._reference_signal_filepaths = None
self.Clear()
def __init__(self, output_directory_prefix):
self._output_directory_prefix = output_directory_prefix
# Init dictionaries with one entry for each test data generator
# configuration (e.g., different SNRs).
# Noisy audio track files (stored separately in a cache folder).
self._noisy_signal_filepaths = None
# Path to be used for the APM simulation output files.
self._apm_output_paths = None
# Reference audio track files (stored separately in a cache folder).
self._reference_signal_filepaths = None
self.Clear()
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers a TestDataGenerator implementation.
@classmethod
def RegisterClass(cls, class_to_register):
"""Registers a TestDataGenerator implementation.
Decorator to automatically register the classes that extend
TestDataGenerator.
@ -77,28 +76,28 @@ class TestDataGenerator(object):
class IdentityGenerator(TestDataGenerator):
pass
"""
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
cls.REGISTERED_CLASSES[class_to_register.NAME] = class_to_register
return class_to_register
@property
def config_names(self):
return self._noisy_signal_filepaths.keys()
@property
def config_names(self):
return self._noisy_signal_filepaths.keys()
@property
def noisy_signal_filepaths(self):
return self._noisy_signal_filepaths
@property
def noisy_signal_filepaths(self):
return self._noisy_signal_filepaths
@property
def apm_output_paths(self):
return self._apm_output_paths
@property
def apm_output_paths(self):
return self._apm_output_paths
@property
def reference_signal_filepaths(self):
return self._reference_signal_filepaths
@property
def reference_signal_filepaths(self):
return self._reference_signal_filepaths
def Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Generates a set of noisy input and reference audiotrack file pairs.
def Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
"""Generates a set of noisy input and reference audiotrack file pairs.
This method initializes an empty set of pairs and calls the _Generate()
method implemented in a concrete class.
@ -109,26 +108,26 @@ class TestDataGenerator(object):
files.
base_output_path: base path where output is written.
"""
self.Clear()
self._Generate(
input_signal_filepath, test_data_cache_path, base_output_path)
self.Clear()
self._Generate(input_signal_filepath, test_data_cache_path,
base_output_path)
def Clear(self):
"""Clears the generated output path dictionaries.
def Clear(self):
"""Clears the generated output path dictionaries.
"""
self._noisy_signal_filepaths = {}
self._apm_output_paths = {}
self._reference_signal_filepaths = {}
self._noisy_signal_filepaths = {}
self._apm_output_paths = {}
self._reference_signal_filepaths = {}
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Abstract method to be implemented in each concrete class.
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
"""Abstract method to be implemented in each concrete class.
"""
raise NotImplementedError()
raise NotImplementedError()
def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
snr_value_pairs):
"""Adds noisy-reference signal pairs.
def _AddNoiseSnrPairs(self, base_output_path, noisy_mix_filepaths,
snr_value_pairs):
"""Adds noisy-reference signal pairs.
Args:
base_output_path: noisy tracks base output path.
@ -136,22 +135,22 @@ class TestDataGenerator(object):
by noisy track name and SNR level.
snr_value_pairs: list of SNR pairs.
"""
for noise_track_name in noisy_mix_filepaths:
for snr_noisy, snr_refence in snr_value_pairs:
config_name = '{0}_{1:d}_{2:d}_SNR'.format(
noise_track_name, snr_noisy, snr_refence)
output_path = self._MakeDir(base_output_path, config_name)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=noisy_mix_filepaths[
noise_track_name][snr_noisy],
reference_signal_filepath=noisy_mix_filepaths[
noise_track_name][snr_refence],
output_path=output_path)
for noise_track_name in noisy_mix_filepaths:
for snr_noisy, snr_refence in snr_value_pairs:
config_name = '{0}_{1:d}_{2:d}_SNR'.format(
noise_track_name, snr_noisy, snr_refence)
output_path = self._MakeDir(base_output_path, config_name)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=noisy_mix_filepaths[noise_track_name]
[snr_noisy],
reference_signal_filepath=noisy_mix_filepaths[
noise_track_name][snr_refence],
output_path=output_path)
def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
reference_signal_filepath, output_path):
"""Adds one noisy-reference signal pair.
def _AddNoiseReferenceFilesPair(self, config_name, noisy_signal_filepath,
reference_signal_filepath, output_path):
"""Adds one noisy-reference signal pair.
Args:
config_name: name of the APM configuration.
@ -159,264 +158,275 @@ class TestDataGenerator(object):
reference_signal_filepath: path to reference audio track file.
output_path: APM output path.
"""
assert config_name not in self._noisy_signal_filepaths
self._noisy_signal_filepaths[config_name] = os.path.abspath(
noisy_signal_filepath)
self._apm_output_paths[config_name] = os.path.abspath(output_path)
self._reference_signal_filepaths[config_name] = os.path.abspath(
reference_signal_filepath)
assert config_name not in self._noisy_signal_filepaths
self._noisy_signal_filepaths[config_name] = os.path.abspath(
noisy_signal_filepath)
self._apm_output_paths[config_name] = os.path.abspath(output_path)
self._reference_signal_filepaths[config_name] = os.path.abspath(
reference_signal_filepath)
def _MakeDir(self, base_output_path, test_data_generator_config_name):
output_path = os.path.join(
base_output_path,
self._output_directory_prefix + test_data_generator_config_name)
data_access.MakeDirectory(output_path)
return output_path
def _MakeDir(self, base_output_path, test_data_generator_config_name):
output_path = os.path.join(
base_output_path,
self._output_directory_prefix + test_data_generator_config_name)
data_access.MakeDirectory(output_path)
return output_path
@TestDataGenerator.RegisterClass
class IdentityTestDataGenerator(TestDataGenerator):
"""Generator that adds no noise.
"""Generator that adds no noise.
Both the noisy and the reference signals are the input signal.
"""
NAME = 'identity'
NAME = 'identity'
def __init__(self, output_directory_prefix, copy_with_identity):
TestDataGenerator.__init__(self, output_directory_prefix)
self._copy_with_identity = copy_with_identity
def __init__(self, output_directory_prefix, copy_with_identity):
TestDataGenerator.__init__(self, output_directory_prefix)
self._copy_with_identity = copy_with_identity
@property
def copy_with_identity(self):
return self._copy_with_identity
@property
def copy_with_identity(self):
return self._copy_with_identity
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
config_name = 'default'
output_path = self._MakeDir(base_output_path, config_name)
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
config_name = 'default'
output_path = self._MakeDir(base_output_path, config_name)
if self._copy_with_identity:
input_signal_filepath_new = os.path.join(
test_data_cache_path, os.path.split(input_signal_filepath)[1])
logging.info('copying ' + input_signal_filepath + ' to ' + (
input_signal_filepath_new))
shutil.copy(input_signal_filepath, input_signal_filepath_new)
input_signal_filepath = input_signal_filepath_new
if self._copy_with_identity:
input_signal_filepath_new = os.path.join(
test_data_cache_path,
os.path.split(input_signal_filepath)[1])
logging.info('copying ' + input_signal_filepath + ' to ' +
(input_signal_filepath_new))
shutil.copy(input_signal_filepath, input_signal_filepath_new)
input_signal_filepath = input_signal_filepath_new
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=input_signal_filepath,
reference_signal_filepath=input_signal_filepath,
output_path=output_path)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=input_signal_filepath,
reference_signal_filepath=input_signal_filepath,
output_path=output_path)
@TestDataGenerator.RegisterClass
class WhiteNoiseTestDataGenerator(TestDataGenerator):
"""Generator that adds white noise.
"""Generator that adds white noise.
"""
NAME = 'white_noise'
NAME = 'white_noise'
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 10 dB higher.
_SNR_VALUE_PAIRS = [
[20, 30], # Smallest noise.
[10, 20],
[5, 15],
[0, 10], # Largest noise.
]
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 10 dB higher.
_SNR_VALUE_PAIRS = [
[20, 30], # Smallest noise.
[10, 20],
[5, 15],
[0, 10], # Largest noise.
]
_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
_NOISY_SIGNAL_FILENAME_TEMPLATE = 'noise_{0:d}_SNR.wav'
def __init__(self, output_directory_prefix):
TestDataGenerator.__init__(self, output_directory_prefix)
def __init__(self, output_directory_prefix):
TestDataGenerator.__init__(self, output_directory_prefix)
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Create the noise track.
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
input_signal)
# Create the noise track.
noise_signal = signal_processing.SignalProcessingUtils.GenerateWhiteNoise(
input_signal)
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths = {}
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths = {}
snr_values = set(
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(snr))
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr)
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Add file to the collection of mixes.
noisy_mix_filepaths[snr] = noisy_signal_filepath
# Add file to the collection of mixes.
noisy_mix_filepaths[snr] = noisy_signal_filepath
# Add all the noisy-reference signal pairs.
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
output_path = self._MakeDir(base_output_path, config_name)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
output_path=output_path)
# Add all the noisy-reference signal pairs.
for snr_noisy, snr_refence in self._SNR_VALUE_PAIRS:
config_name = '{0:d}_{1:d}_SNR'.format(snr_noisy, snr_refence)
output_path = self._MakeDir(base_output_path, config_name)
self._AddNoiseReferenceFilesPair(
config_name=config_name,
noisy_signal_filepath=noisy_mix_filepaths[snr_noisy],
reference_signal_filepath=noisy_mix_filepaths[snr_refence],
output_path=output_path)
# TODO(alessiob): remove comment when class implemented.
# @TestDataGenerator.RegisterClass
class NarrowBandNoiseTestDataGenerator(TestDataGenerator):
"""Generator that adds narrow-band noise.
"""Generator that adds narrow-band noise.
"""
NAME = 'narrow_band_noise'
NAME = 'narrow_band_noise'
def __init__(self, output_directory_prefix):
TestDataGenerator.__init__(self, output_directory_prefix)
def __init__(self, output_directory_prefix):
TestDataGenerator.__init__(self, output_directory_prefix)
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
# TODO(alessiob): implement.
pass
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
# TODO(alessiob): implement.
pass
@TestDataGenerator.RegisterClass
class AdditiveNoiseTestDataGenerator(TestDataGenerator):
"""Generator that adds noise loops.
"""Generator that adds noise loops.
This generator uses all the wav files in a given path (default: noise_tracks/)
and mixes them to the clean speech with different target SNRs (hard-coded).
"""
NAME = 'additive_noise'
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
NAME = 'additive_noise'
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
DEFAULT_NOISE_TRACKS_PATH = os.path.join(
os.path.dirname(__file__), os.pardir, 'noise_tracks')
DEFAULT_NOISE_TRACKS_PATH = os.path.join(os.path.dirname(__file__),
os.pardir, 'noise_tracks')
# TODO(alessiob): Make the list of SNR pairs customizable.
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 10 dB higher.
_SNR_VALUE_PAIRS = [
[20, 30], # Smallest noise.
[10, 20],
[5, 15],
[0, 10], # Largest noise.
]
# TODO(alessiob): Make the list of SNR pairs customizable.
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 10 dB higher.
_SNR_VALUE_PAIRS = [
[20, 30], # Smallest noise.
[10, 20],
[5, 15],
[0, 10], # Largest noise.
]
def __init__(self, output_directory_prefix, noise_tracks_path):
TestDataGenerator.__init__(self, output_directory_prefix)
self._noise_tracks_path = noise_tracks_path
self._noise_tracks_file_names = [n for n in os.listdir(
self._noise_tracks_path) if n.lower().endswith('.wav')]
if len(self._noise_tracks_file_names) == 0:
raise exceptions.InitializationException(
'No wav files found in the noise tracks path %s' % (
self._noise_tracks_path))
def __init__(self, output_directory_prefix, noise_tracks_path):
TestDataGenerator.__init__(self, output_directory_prefix)
self._noise_tracks_path = noise_tracks_path
self._noise_tracks_file_names = [
n for n in os.listdir(self._noise_tracks_path)
if n.lower().endswith('.wav')
]
if len(self._noise_tracks_file_names) == 0:
raise exceptions.InitializationException(
'No wav files found in the noise tracks path %s' %
(self._noise_tracks_path))
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Generates test data pairs using environmental noise.
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
"""Generates test data pairs using environmental noise.
For each noise track and pair of SNR values, the following two audio tracks
are created: the noisy signal and the reference signal. The former is
obtained by mixing the (clean) input signal to the corresponding noise
track enforcing the target SNR.
"""
# Init.
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
# Init.
snr_values = set(
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
noisy_mix_filepaths = {}
for noise_track_filename in self._noise_tracks_file_names:
# Load the noise track.
noise_track_name, _ = os.path.splitext(noise_track_filename)
noise_track_filepath = os.path.join(
self._noise_tracks_path, noise_track_filename)
if not os.path.exists(noise_track_filepath):
logging.error('cannot find the <%s> noise track', noise_track_filename)
raise exceptions.FileNotFoundError()
noisy_mix_filepaths = {}
for noise_track_filename in self._noise_tracks_file_names:
# Load the noise track.
noise_track_name, _ = os.path.splitext(noise_track_filename)
noise_track_filepath = os.path.join(self._noise_tracks_path,
noise_track_filename)
if not os.path.exists(noise_track_filepath):
logging.error('cannot find the <%s> noise track',
noise_track_filename)
raise exceptions.FileNotFoundError()
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths[noise_track_name] = {}
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(noise_track_name, snr))
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths[noise_track_name] = {}
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
noise_track_name, snr))
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr,
pad_noise=signal_processing.SignalProcessingUtils.MixPadding.LOOP)
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal,
noise_signal,
snr,
pad_noise=signal_processing.SignalProcessingUtils.
MixPadding.LOOP)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Add file to the collection of mixes.
noisy_mix_filepaths[noise_track_name][snr] = noisy_signal_filepath
# Add file to the collection of mixes.
noisy_mix_filepaths[noise_track_name][
snr] = noisy_signal_filepath
# Add all the noise-SNR pairs.
self._AddNoiseSnrPairs(
base_output_path, noisy_mix_filepaths, self._SNR_VALUE_PAIRS)
# Add all the noise-SNR pairs.
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
self._SNR_VALUE_PAIRS)
@TestDataGenerator.RegisterClass
class ReverberationTestDataGenerator(TestDataGenerator):
"""Generator that adds reverberation noise.
"""Generator that adds reverberation noise.
TODO(alessiob): Make this class more generic since the impulse response can be
anything (not just reverberation); call it e.g.,
ConvolutionalNoiseTestDataGenerator.
"""
NAME = 'reverberation'
NAME = 'reverberation'
_IMPULSE_RESPONSES = {
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
}
_MAX_IMPULSE_RESPONSE_LENGTH = None
_IMPULSE_RESPONSES = {
'lecture': 'air_binaural_lecture_0_0_1.mat', # Long echo.
'booth': 'air_binaural_booth_0_0_1.mat', # Short echo.
}
_MAX_IMPULSE_RESPONSE_LENGTH = None
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 5 dB higher.
_SNR_VALUE_PAIRS = [
[3, 8], # Smallest noise.
[-3, 2], # Largest noise.
]
# Each pair indicates the clean vs. noisy and reference vs. noisy SNRs.
# The reference (second value of each pair) always has a lower amount of noise
# - i.e., the SNR is 5 dB higher.
_SNR_VALUE_PAIRS = [
[3, 8], # Smallest noise.
[-3, 2], # Largest noise.
]
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
_NOISE_TRACK_FILENAME_TEMPLATE = '{0}.wav'
_NOISY_SIGNAL_FILENAME_TEMPLATE = '{0}_{1:d}_SNR.wav'
def __init__(self, output_directory_prefix, aechen_ir_database_path):
TestDataGenerator.__init__(self, output_directory_prefix)
self._aechen_ir_database_path = aechen_ir_database_path
def __init__(self, output_directory_prefix, aechen_ir_database_path):
TestDataGenerator.__init__(self, output_directory_prefix)
self._aechen_ir_database_path = aechen_ir_database_path
def _Generate(
self, input_signal_filepath, test_data_cache_path, base_output_path):
"""Generates test data pairs using reverberation noise.
def _Generate(self, input_signal_filepath, test_data_cache_path,
base_output_path):
"""Generates test data pairs using reverberation noise.
For each impulse response, one noise track is created. For each impulse
response and pair of SNR values, the following 2 audio tracks are
@ -424,61 +434,64 @@ class ReverberationTestDataGenerator(TestDataGenerator):
obtained by mixing the (clean) input signal to the corresponding noise
track enforcing the target SNR.
"""
# Init.
snr_values = set([snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
# Init.
snr_values = set(
[snr for pair in self._SNR_VALUE_PAIRS for snr in pair])
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Load the input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
noisy_mix_filepaths = {}
for impulse_response_name in self._IMPULSE_RESPONSES:
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
impulse_response_name)
noise_track_filepath = os.path.join(
test_data_cache_path, noise_track_filename)
noise_signal = None
try:
# Load noise track.
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
except exceptions.FileNotFoundError:
# Generate noise track by applying the impulse response.
impulse_response_filepath = os.path.join(
self._aechen_ir_database_path,
self._IMPULSE_RESPONSES[impulse_response_name])
noise_signal = self._GenerateNoiseTrack(
noise_track_filepath, input_signal, impulse_response_filepath)
assert noise_signal is not None
noisy_mix_filepaths = {}
for impulse_response_name in self._IMPULSE_RESPONSES:
noise_track_filename = self._NOISE_TRACK_FILENAME_TEMPLATE.format(
impulse_response_name)
noise_track_filepath = os.path.join(test_data_cache_path,
noise_track_filename)
noise_signal = None
try:
# Load noise track.
noise_signal = signal_processing.SignalProcessingUtils.LoadWav(
noise_track_filepath)
except exceptions.FileNotFoundError:
# Generate noise track by applying the impulse response.
impulse_response_filepath = os.path.join(
self._aechen_ir_database_path,
self._IMPULSE_RESPONSES[impulse_response_name])
noise_signal = self._GenerateNoiseTrack(
noise_track_filepath, input_signal,
impulse_response_filepath)
assert noise_signal is not None
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths[impulse_response_name] = {}
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
impulse_response_name, snr))
# Create the noisy mixes (once for each unique SNR value).
noisy_mix_filepaths[impulse_response_name] = {}
for snr in snr_values:
noisy_signal_filepath = os.path.join(
test_data_cache_path,
self._NOISY_SIGNAL_FILENAME_TEMPLATE.format(
impulse_response_name, snr))
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr)
# Create and save if not done.
if not os.path.exists(noisy_signal_filepath):
# Create noisy signal.
noisy_signal = signal_processing.SignalProcessingUtils.MixSignals(
input_signal, noise_signal, snr)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noisy_signal_filepath, noisy_signal)
# Add file to the collection of mixes.
noisy_mix_filepaths[impulse_response_name][snr] = noisy_signal_filepath
# Add file to the collection of mixes.
noisy_mix_filepaths[impulse_response_name][
snr] = noisy_signal_filepath
# Add all the noise-SNR pairs.
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
self._SNR_VALUE_PAIRS)
# Add all the noise-SNR pairs.
self._AddNoiseSnrPairs(base_output_path, noisy_mix_filepaths,
self._SNR_VALUE_PAIRS)
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
def _GenerateNoiseTrack(self, noise_track_filepath, input_signal,
impulse_response_filepath):
"""Generates noise track.
"""Generates noise track.
Generate a signal by convolving input_signal with the impulse response in
impulse_response_filepath; then save to noise_track_filepath.
@ -491,21 +504,23 @@ class ReverberationTestDataGenerator(TestDataGenerator):
Returns:
AudioSegment instance.
"""
# Load impulse response.
data = scipy.io.loadmat(impulse_response_filepath)
impulse_response = data['h_air'].flatten()
if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
logging.info('truncating impulse response from %d to %d samples',
len(impulse_response), self._MAX_IMPULSE_RESPONSE_LENGTH)
impulse_response = impulse_response[:self._MAX_IMPULSE_RESPONSE_LENGTH]
# Load impulse response.
data = scipy.io.loadmat(impulse_response_filepath)
impulse_response = data['h_air'].flatten()
if self._MAX_IMPULSE_RESPONSE_LENGTH is not None:
logging.info('truncating impulse response from %d to %d samples',
len(impulse_response),
self._MAX_IMPULSE_RESPONSE_LENGTH)
impulse_response = impulse_response[:self.
_MAX_IMPULSE_RESPONSE_LENGTH]
# Apply impulse response.
processed_signal = (
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
input_signal, impulse_response))
# Apply impulse response.
processed_signal = (
signal_processing.SignalProcessingUtils.ApplyImpulseResponse(
input_signal, impulse_response))
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noise_track_filepath, processed_signal)
# Save.
signal_processing.SignalProcessingUtils.SaveWav(
noise_track_filepath, processed_signal)
return processed_signal
return processed_signal

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""TestDataGenerator factory class.
"""
@ -16,15 +15,15 @@ from . import test_data_generation
class TestDataGeneratorFactory(object):
"""Factory class used to create test data generators.
"""Factory class used to create test data generators.
Usage: Create a factory passing parameters to the ctor with which the
generators will be produced.
"""
def __init__(self, aechen_ir_database_path, noise_tracks_path,
copy_with_identity):
"""Ctor.
def __init__(self, aechen_ir_database_path, noise_tracks_path,
copy_with_identity):
"""Ctor.
Args:
aechen_ir_database_path: Path to the Aechen Impulse Response database.
@ -32,16 +31,16 @@ class TestDataGeneratorFactory(object):
copy_with_identity: Flag indicating whether the identity generator has to
make copies of the clean speech input files.
"""
self._output_directory_prefix = None
self._aechen_ir_database_path = aechen_ir_database_path
self._noise_tracks_path = noise_tracks_path
self._copy_with_identity = copy_with_identity
self._output_directory_prefix = None
self._aechen_ir_database_path = aechen_ir_database_path
self._noise_tracks_path = noise_tracks_path
self._copy_with_identity = copy_with_identity
def SetOutputDirectoryPrefix(self, prefix):
self._output_directory_prefix = prefix
def SetOutputDirectoryPrefix(self, prefix):
self._output_directory_prefix = prefix
def GetInstance(self, test_data_generators_class):
"""Creates an TestDataGenerator instance given a class object.
def GetInstance(self, test_data_generators_class):
"""Creates an TestDataGenerator instance given a class object.
Args:
test_data_generators_class: TestDataGenerator class object (not an
@ -50,22 +49,23 @@ class TestDataGeneratorFactory(object):
Returns:
TestDataGenerator instance.
"""
if self._output_directory_prefix is None:
raise exceptions.InitializationException(
'The output directory prefix for test data generators is not set')
logging.debug('factory producing %s', test_data_generators_class)
if self._output_directory_prefix is None:
raise exceptions.InitializationException(
'The output directory prefix for test data generators is not set'
)
logging.debug('factory producing %s', test_data_generators_class)
if test_data_generators_class == (
test_data_generation.IdentityTestDataGenerator):
return test_data_generation.IdentityTestDataGenerator(
self._output_directory_prefix, self._copy_with_identity)
elif test_data_generators_class == (
test_data_generation.ReverberationTestDataGenerator):
return test_data_generation.ReverberationTestDataGenerator(
self._output_directory_prefix, self._aechen_ir_database_path)
elif test_data_generators_class == (
test_data_generation.AdditiveNoiseTestDataGenerator):
return test_data_generation.AdditiveNoiseTestDataGenerator(
self._output_directory_prefix, self._noise_tracks_path)
else:
return test_data_generators_class(self._output_directory_prefix)
if test_data_generators_class == (
test_data_generation.IdentityTestDataGenerator):
return test_data_generation.IdentityTestDataGenerator(
self._output_directory_prefix, self._copy_with_identity)
elif test_data_generators_class == (
test_data_generation.ReverberationTestDataGenerator):
return test_data_generation.ReverberationTestDataGenerator(
self._output_directory_prefix, self._aechen_ir_database_path)
elif test_data_generators_class == (
test_data_generation.AdditiveNoiseTestDataGenerator):
return test_data_generation.AdditiveNoiseTestDataGenerator(
self._output_directory_prefix, self._noise_tracks_path)
else:
return test_data_generators_class(self._output_directory_prefix)

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Unit tests for the test_data_generation module.
"""
@ -23,141 +22,143 @@ from . import signal_processing
class TestTestDataGenerators(unittest.TestCase):
"""Unit tests for the test_data_generation module.
"""Unit tests for the test_data_generation module.
"""
def setUp(self):
"""Create temporary folders."""
self._base_output_path = tempfile.mkdtemp()
self._test_data_cache_path = tempfile.mkdtemp()
self._fake_air_db_path = tempfile.mkdtemp()
def setUp(self):
"""Create temporary folders."""
self._base_output_path = tempfile.mkdtemp()
self._test_data_cache_path = tempfile.mkdtemp()
self._fake_air_db_path = tempfile.mkdtemp()
# Fake AIR DB impulse responses.
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
# impulse responses. When changed, the coupling below between
# impulse_response_mat_file_names and
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
impulse_response_mat_file_names = [
'air_binaural_lecture_0_0_1.mat',
'air_binaural_booth_0_0_1.mat',
]
for impulse_response_mat_file_name in impulse_response_mat_file_names:
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
scipy.io.savemat(os.path.join(
self._fake_air_db_path, impulse_response_mat_file_name), data)
# Fake AIR DB impulse responses.
# TODO(alessiob): ReverberationTestDataGenerator will change to allow custom
# impulse responses. When changed, the coupling below between
# impulse_response_mat_file_names and
# ReverberationTestDataGenerator._IMPULSE_RESPONSES can be removed.
impulse_response_mat_file_names = [
'air_binaural_lecture_0_0_1.mat',
'air_binaural_booth_0_0_1.mat',
]
for impulse_response_mat_file_name in impulse_response_mat_file_names:
data = {'h_air': np.random.rand(1, 1000).astype('<f8')}
scipy.io.savemat(
os.path.join(self._fake_air_db_path,
impulse_response_mat_file_name), data)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._base_output_path)
shutil.rmtree(self._test_data_cache_path)
shutil.rmtree(self._fake_air_db_path)
def tearDown(self):
"""Recursively delete temporary folders."""
shutil.rmtree(self._base_output_path)
shutil.rmtree(self._test_data_cache_path)
shutil.rmtree(self._fake_air_db_path)
def testTestDataGenerators(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))
self.assertTrue(os.path.exists(self._test_data_cache_path))
def testTestDataGenerators(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))
self.assertTrue(os.path.exists(self._test_data_cache_path))
# Check that there is at least one registered test data generator.
registered_classes = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Check that there is at least one registered test data generator.
registered_classes = (
test_data_generation.TestDataGenerator.REGISTERED_CLASSES)
self.assertIsInstance(registered_classes, dict)
self.assertGreater(len(registered_classes), 0)
# Instance generators factory.
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=self._fake_air_db_path,
noise_tracks_path=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH,
copy_with_identity=False)
generators_factory.SetOutputDirectoryPrefix('datagen-')
# Instance generators factory.
generators_factory = test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path=self._fake_air_db_path,
noise_tracks_path=test_data_generation. \
AdditiveNoiseTestDataGenerator. \
DEFAULT_NOISE_TRACKS_PATH,
copy_with_identity=False)
generators_factory.SetOutputDirectoryPrefix('datagen-')
# Use a simple input file as clean input signal.
input_signal_filepath = os.path.join(
os.getcwd(), 'probing_signals', 'tone-880.wav')
self.assertTrue(os.path.exists(input_signal_filepath))
# Use a simple input file as clean input signal.
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
'tone-880.wav')
self.assertTrue(os.path.exists(input_signal_filepath))
# Load input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Load input signal.
input_signal = signal_processing.SignalProcessingUtils.LoadWav(
input_signal_filepath)
# Try each registered test data generator.
for generator_name in registered_classes:
# Instance test data generator.
generator = generators_factory.GetInstance(
registered_classes[generator_name])
# Try each registered test data generator.
for generator_name in registered_classes:
# Instance test data generator.
generator = generators_factory.GetInstance(
registered_classes[generator_name])
# Generate the noisy input - reference pairs.
generator.Generate(
input_signal_filepath=input_signal_filepath,
test_data_cache_path=self._test_data_cache_path,
base_output_path=self._base_output_path)
# Generate the noisy input - reference pairs.
generator.Generate(input_signal_filepath=input_signal_filepath,
test_data_cache_path=self._test_data_cache_path,
base_output_path=self._base_output_path)
# Perform checks.
self._CheckGeneratedPairsListSizes(generator)
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
self._CheckGeneratedPairsOutputPaths(generator)
# Perform checks.
self._CheckGeneratedPairsListSizes(generator)
self._CheckGeneratedPairsSignalDurations(generator, input_signal)
self._CheckGeneratedPairsOutputPaths(generator)
def testTestidentityDataGenerator(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))
self.assertTrue(os.path.exists(self._test_data_cache_path))
def testTestidentityDataGenerator(self):
# Preliminary check.
self.assertTrue(os.path.exists(self._base_output_path))
self.assertTrue(os.path.exists(self._test_data_cache_path))
# Use a simple input file as clean input signal.
input_signal_filepath = os.path.join(
os.getcwd(), 'probing_signals', 'tone-880.wav')
self.assertTrue(os.path.exists(input_signal_filepath))
# Use a simple input file as clean input signal.
input_signal_filepath = os.path.join(os.getcwd(), 'probing_signals',
'tone-880.wav')
self.assertTrue(os.path.exists(input_signal_filepath))
def GetNoiseReferenceFilePaths(identity_generator):
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
reference_signal_filepaths = identity_generator.reference_signal_filepaths
assert noisy_signal_filepaths.keys() == reference_signal_filepaths.keys()
assert len(noisy_signal_filepaths.keys()) == 1
key = noisy_signal_filepaths.keys()[0]
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
def GetNoiseReferenceFilePaths(identity_generator):
noisy_signal_filepaths = identity_generator.noisy_signal_filepaths
reference_signal_filepaths = identity_generator.reference_signal_filepaths
assert noisy_signal_filepaths.keys(
) == reference_signal_filepaths.keys()
assert len(noisy_signal_filepaths.keys()) == 1
key = noisy_signal_filepaths.keys()[0]
return noisy_signal_filepaths[key], reference_signal_filepaths[key]
# Test the |copy_with_identity| flag.
for copy_with_identity in [False, True]:
# Instance the generator through the factory.
factory = test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='', noise_tracks_path='',
copy_with_identity=copy_with_identity)
factory.SetOutputDirectoryPrefix('datagen-')
generator = factory.GetInstance(
test_data_generation.IdentityTestDataGenerator)
# Check |copy_with_identity| is set correctly.
self.assertEqual(copy_with_identity, generator.copy_with_identity)
# Test the |copy_with_identity| flag.
for copy_with_identity in [False, True]:
# Instance the generator through the factory.
factory = test_data_generation_factory.TestDataGeneratorFactory(
aechen_ir_database_path='',
noise_tracks_path='',
copy_with_identity=copy_with_identity)
factory.SetOutputDirectoryPrefix('datagen-')
generator = factory.GetInstance(
test_data_generation.IdentityTestDataGenerator)
# Check |copy_with_identity| is set correctly.
self.assertEqual(copy_with_identity, generator.copy_with_identity)
# Generate test data and extract the paths to the noise and the reference
# files.
generator.Generate(
input_signal_filepath=input_signal_filepath,
test_data_cache_path=self._test_data_cache_path,
base_output_path=self._base_output_path)
noisy_signal_filepath, reference_signal_filepath = (
GetNoiseReferenceFilePaths(generator))
# Generate test data and extract the paths to the noise and the reference
# files.
generator.Generate(input_signal_filepath=input_signal_filepath,
test_data_cache_path=self._test_data_cache_path,
base_output_path=self._base_output_path)
noisy_signal_filepath, reference_signal_filepath = (
GetNoiseReferenceFilePaths(generator))
# Check that a copy is made if and only if |copy_with_identity| is True.
if copy_with_identity:
self.assertNotEqual(noisy_signal_filepath, input_signal_filepath)
self.assertNotEqual(reference_signal_filepath, input_signal_filepath)
else:
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
self.assertEqual(reference_signal_filepath, input_signal_filepath)
# Check that a copy is made if and only if |copy_with_identity| is True.
if copy_with_identity:
self.assertNotEqual(noisy_signal_filepath,
input_signal_filepath)
self.assertNotEqual(reference_signal_filepath,
input_signal_filepath)
else:
self.assertEqual(noisy_signal_filepath, input_signal_filepath)
self.assertEqual(reference_signal_filepath,
input_signal_filepath)
def _CheckGeneratedPairsListSizes(self, generator):
config_names = generator.config_names
number_of_pairs = len(config_names)
self.assertEqual(number_of_pairs,
len(generator.noisy_signal_filepaths))
self.assertEqual(number_of_pairs,
len(generator.apm_output_paths))
self.assertEqual(number_of_pairs,
len(generator.reference_signal_filepaths))
def _CheckGeneratedPairsListSizes(self, generator):
config_names = generator.config_names
number_of_pairs = len(config_names)
self.assertEqual(number_of_pairs,
len(generator.noisy_signal_filepaths))
self.assertEqual(number_of_pairs, len(generator.apm_output_paths))
self.assertEqual(number_of_pairs,
len(generator.reference_signal_filepaths))
def _CheckGeneratedPairsSignalDurations(
self, generator, input_signal):
"""Checks duration of the generated signals.
def _CheckGeneratedPairsSignalDurations(self, generator, input_signal):
"""Checks duration of the generated signals.
Checks that the noisy input and the reference tracks are audio files
with duration equal to or greater than that of the input signal.
@ -166,41 +167,41 @@ class TestTestDataGenerators(unittest.TestCase):
generator: TestDataGenerator instance.
input_signal: AudioSegment instance.
"""
input_signal_length = (
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
input_signal_length = (
signal_processing.SignalProcessingUtils.CountSamples(input_signal))
# Iterate over the noisy signal - reference pairs.
for config_name in generator.config_names:
# Load the noisy input file.
noisy_signal_filepath = generator.noisy_signal_filepaths[
config_name]
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
noisy_signal_filepath)
# Iterate over the noisy signal - reference pairs.
for config_name in generator.config_names:
# Load the noisy input file.
noisy_signal_filepath = generator.noisy_signal_filepaths[
config_name]
noisy_signal = signal_processing.SignalProcessingUtils.LoadWav(
noisy_signal_filepath)
# Check noisy input signal length.
noisy_signal_length = (
signal_processing.SignalProcessingUtils.CountSamples(noisy_signal))
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
# Check noisy input signal length.
noisy_signal_length = (signal_processing.SignalProcessingUtils.
CountSamples(noisy_signal))
self.assertGreaterEqual(noisy_signal_length, input_signal_length)
# Load the reference file.
reference_signal_filepath = generator.reference_signal_filepaths[
config_name]
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
reference_signal_filepath)
# Load the reference file.
reference_signal_filepath = generator.reference_signal_filepaths[
config_name]
reference_signal = signal_processing.SignalProcessingUtils.LoadWav(
reference_signal_filepath)
# Check noisy input signal length.
reference_signal_length = (
signal_processing.SignalProcessingUtils.CountSamples(
reference_signal))
self.assertGreaterEqual(reference_signal_length, input_signal_length)
# Check noisy input signal length.
reference_signal_length = (signal_processing.SignalProcessingUtils.
CountSamples(reference_signal))
self.assertGreaterEqual(reference_signal_length,
input_signal_length)
def _CheckGeneratedPairsOutputPaths(self, generator):
"""Checks that the output path created by the generator exists.
def _CheckGeneratedPairsOutputPaths(self, generator):
"""Checks that the output path created by the generator exists.
Args:
generator: TestDataGenerator instance.
"""
# Iterate over the noisy signal - reference pairs.
for config_name in generator.config_names:
output_path = generator.apm_output_paths[config_name]
self.assertTrue(os.path.exists(output_path))
# Iterate over the noisy signal - reference pairs.
for config_name in generator.config_names:
output_path = generator.apm_output_paths[config_name]
self.assertTrue(os.path.exists(output_path))

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Plots statistics from WebRTC integration test logs.
Usage: $ python plot_webrtc_test_logs.py filename.txt
@ -52,43 +51,43 @@ AVG_DELTA_FRAME_SIZE = ('avg_delta_frame_size_bytes',
# Settings.
SETTINGS = [
WIDTH,
HEIGHT,
FILENAME,
NUM_FRAMES,
WIDTH,
HEIGHT,
FILENAME,
NUM_FRAMES,
]
# Settings, options for x-axis.
X_SETTINGS = [
CORES,
FRAMERATE,
DENOISING,
RESILIENCE,
ERROR_CONCEALMENT,
BITRATE, # TODO(asapersson): Needs to be last.
CORES,
FRAMERATE,
DENOISING,
RESILIENCE,
ERROR_CONCEALMENT,
BITRATE, # TODO(asapersson): Needs to be last.
]
# Settings, options for subplots.
SUBPLOT_SETTINGS = [
CODEC_TYPE,
ENCODER_IMPLEMENTATION_NAME,
DECODER_IMPLEMENTATION_NAME,
CODEC_IMPLEMENTATION_NAME,
CODEC_TYPE,
ENCODER_IMPLEMENTATION_NAME,
DECODER_IMPLEMENTATION_NAME,
CODEC_IMPLEMENTATION_NAME,
] + X_SETTINGS
# Results.
RESULTS = [
PSNR,
SSIM,
ENC_BITRATE,
NUM_DROPPED_FRAMES,
TIME_TO_TARGET,
ENCODE_SPEED_FPS,
DECODE_SPEED_FPS,
QP,
CPU_USAGE,
AVG_KEY_FRAME_SIZE,
AVG_DELTA_FRAME_SIZE,
PSNR,
SSIM,
ENC_BITRATE,
NUM_DROPPED_FRAMES,
TIME_TO_TARGET,
ENCODE_SPEED_FPS,
DECODE_SPEED_FPS,
QP,
CPU_USAGE,
AVG_KEY_FRAME_SIZE,
AVG_DELTA_FRAME_SIZE,
]
METRICS_TO_PARSE = SETTINGS + SUBPLOT_SETTINGS + RESULTS
@ -102,7 +101,7 @@ GRID_COLOR = [0.45, 0.45, 0.45]
def ParseSetting(filename, setting):
"""Parses setting from file.
"""Parses setting from file.
Args:
filename: The name of the file.
@ -111,36 +110,36 @@ def ParseSetting(filename, setting):
Returns:
A list holding parsed settings, e.g. ['width: 128.0', 'width: 160.0'] """
settings = []
settings = []
settings_file = open(filename)
while True:
line = settings_file.readline()
if not line:
break
if re.search(r'%s' % EVENT_START, line):
# Parse event.
parsed = {}
while True:
settings_file = open(filename)
while True:
line = settings_file.readline()
if not line:
break
if re.search(r'%s' % EVENT_END, line):
# Add parsed setting to list.
if setting in parsed:
s = setting + ': ' + str(parsed[setting])
if s not in settings:
settings.append(s)
break
break
if re.search(r'%s' % EVENT_START, line):
# Parse event.
parsed = {}
while True:
line = settings_file.readline()
if not line:
break
if re.search(r'%s' % EVENT_END, line):
# Add parsed setting to list.
if setting in parsed:
s = setting + ': ' + str(parsed[setting])
if s not in settings:
settings.append(s)
break
TryFindMetric(parsed, line)
TryFindMetric(parsed, line)
settings_file.close()
return settings
settings_file.close()
return settings
def ParseMetrics(filename, setting1, setting2):
"""Parses metrics from file.
"""Parses metrics from file.
Args:
filename: The name of the file.
@ -175,82 +174,82 @@ def ParseMetrics(filename, setting1, setting2):
}
} """
metrics = {}
metrics = {}
# Parse events.
settings_file = open(filename)
while True:
line = settings_file.readline()
if not line:
break
if re.search(r'%s' % EVENT_START, line):
# Parse event.
parsed = {}
while True:
# Parse events.
settings_file = open(filename)
while True:
line = settings_file.readline()
if not line:
break
if re.search(r'%s' % EVENT_END, line):
# Add parsed values to metrics.
key1 = setting1 + ': ' + str(parsed[setting1])
key2 = setting2 + ': ' + str(parsed[setting2])
if key1 not in metrics:
metrics[key1] = {}
if key2 not in metrics[key1]:
metrics[key1][key2] = {}
break
if re.search(r'%s' % EVENT_START, line):
# Parse event.
parsed = {}
while True:
line = settings_file.readline()
if not line:
break
if re.search(r'%s' % EVENT_END, line):
# Add parsed values to metrics.
key1 = setting1 + ': ' + str(parsed[setting1])
key2 = setting2 + ': ' + str(parsed[setting2])
if key1 not in metrics:
metrics[key1] = {}
if key2 not in metrics[key1]:
metrics[key1][key2] = {}
for label in parsed:
if label not in metrics[key1][key2]:
metrics[key1][key2][label] = []
metrics[key1][key2][label].append(parsed[label])
for label in parsed:
if label not in metrics[key1][key2]:
metrics[key1][key2][label] = []
metrics[key1][key2][label].append(parsed[label])
break
break
TryFindMetric(parsed, line)
TryFindMetric(parsed, line)
settings_file.close()
return metrics
settings_file.close()
return metrics
def TryFindMetric(parsed, line):
for metric in METRICS_TO_PARSE:
name = metric[0]
label = metric[1]
if re.search(r'%s' % name, line):
found, value = GetMetric(name, line)
if found:
parsed[label] = value
return
for metric in METRICS_TO_PARSE:
name = metric[0]
label = metric[1]
if re.search(r'%s' % name, line):
found, value = GetMetric(name, line)
if found:
parsed[label] = value
return
def GetMetric(name, string):
# Float (e.g. bitrate = 98.8253).
pattern = r'%s\s*[:=]\s*([+-]?\d+\.*\d*)' % name
m = re.search(r'%s' % pattern, string)
if m is not None:
return StringToFloat(m.group(1))
# Float (e.g. bitrate = 98.8253).
pattern = r'%s\s*[:=]\s*([+-]?\d+\.*\d*)' % name
m = re.search(r'%s' % pattern, string)
if m is not None:
return StringToFloat(m.group(1))
# Alphanumeric characters (e.g. codec type : VP8).
pattern = r'%s\s*[:=]\s*(\w+)' % name
m = re.search(r'%s' % pattern, string)
if m is not None:
return True, m.group(1)
# Alphanumeric characters (e.g. codec type : VP8).
pattern = r'%s\s*[:=]\s*(\w+)' % name
m = re.search(r'%s' % pattern, string)
if m is not None:
return True, m.group(1)
return False, -1
return False, -1
def StringToFloat(value):
try:
value = float(value)
except ValueError:
print "Not a float, skipped %s" % value
return False, -1
try:
value = float(value)
except ValueError:
print "Not a float, skipped %s" % value
return False, -1
return True, value
return True, value
def Plot(y_metric, x_metric, metrics):
"""Plots y_metric vs x_metric per key in metrics.
"""Plots y_metric vs x_metric per key in metrics.
For example:
y_metric = 'PSNR (dB)'
@ -266,26 +265,31 @@ def Plot(y_metric, x_metric, metrics):
},
}
"""
for key in sorted(metrics):
data = metrics[key]
if y_metric not in data:
print "Failed to find metric: %s" % y_metric
continue
for key in sorted(metrics):
data = metrics[key]
if y_metric not in data:
print "Failed to find metric: %s" % y_metric
continue
y = numpy.array(data[y_metric])
x = numpy.array(data[x_metric])
if len(y) != len(x):
print "Length mismatch for %s, %s" % (y, x)
continue
y = numpy.array(data[y_metric])
x = numpy.array(data[x_metric])
if len(y) != len(x):
print "Length mismatch for %s, %s" % (y, x)
continue
label = y_metric + ' - ' + str(key)
label = y_metric + ' - ' + str(key)
plt.plot(x, y, label=label, linewidth=1.5, marker='o', markersize=5,
markeredgewidth=0.0)
plt.plot(x,
y,
label=label,
linewidth=1.5,
marker='o',
markersize=5,
markeredgewidth=0.0)
def PlotFigure(settings, y_metrics, x_metric, metrics, title):
"""Plots metrics in y_metrics list. One figure is plotted and each entry
"""Plots metrics in y_metrics list. One figure is plotted and each entry
in the list is plotted in a subplot (and sorted per settings).
For example:
@ -295,136 +299,140 @@ def PlotFigure(settings, y_metrics, x_metric, metrics, title):
"""
plt.figure()
plt.suptitle(title, fontsize='large', fontweight='bold')
settings.sort()
rows = len(settings)
cols = 1
pos = 1
while pos <= rows:
plt.rc('grid', color=GRID_COLOR)
ax = plt.subplot(rows, cols, pos)
plt.grid()
plt.setp(ax.get_xticklabels(), visible=(pos == rows), fontsize='large')
plt.setp(ax.get_yticklabels(), fontsize='large')
setting = settings[pos - 1]
Plot(y_metrics[pos - 1], x_metric, metrics[setting])
if setting.startswith(WIDTH[1]):
plt.title(setting, fontsize='medium')
plt.legend(fontsize='large', loc='best')
pos += 1
plt.figure()
plt.suptitle(title, fontsize='large', fontweight='bold')
settings.sort()
rows = len(settings)
cols = 1
pos = 1
while pos <= rows:
plt.rc('grid', color=GRID_COLOR)
ax = plt.subplot(rows, cols, pos)
plt.grid()
plt.setp(ax.get_xticklabels(), visible=(pos == rows), fontsize='large')
plt.setp(ax.get_yticklabels(), fontsize='large')
setting = settings[pos - 1]
Plot(y_metrics[pos - 1], x_metric, metrics[setting])
if setting.startswith(WIDTH[1]):
plt.title(setting, fontsize='medium')
plt.legend(fontsize='large', loc='best')
pos += 1
plt.xlabel(x_metric, fontsize='large')
plt.subplots_adjust(left=0.06, right=0.98, bottom=0.05, top=0.94, hspace=0.08)
plt.xlabel(x_metric, fontsize='large')
plt.subplots_adjust(left=0.06,
right=0.98,
bottom=0.05,
top=0.94,
hspace=0.08)
def GetTitle(filename, setting):
title = ''
if setting != CODEC_IMPLEMENTATION_NAME[1] and setting != CODEC_TYPE[1]:
codec_types = ParseSetting(filename, CODEC_TYPE[1])
for i in range(0, len(codec_types)):
title += codec_types[i] + ', '
title = ''
if setting != CODEC_IMPLEMENTATION_NAME[1] and setting != CODEC_TYPE[1]:
codec_types = ParseSetting(filename, CODEC_TYPE[1])
for i in range(0, len(codec_types)):
title += codec_types[i] + ', '
if setting != CORES[1]:
cores = ParseSetting(filename, CORES[1])
for i in range(0, len(cores)):
title += cores[i].split('.')[0] + ', '
if setting != CORES[1]:
cores = ParseSetting(filename, CORES[1])
for i in range(0, len(cores)):
title += cores[i].split('.')[0] + ', '
if setting != FRAMERATE[1]:
framerate = ParseSetting(filename, FRAMERATE[1])
for i in range(0, len(framerate)):
title += framerate[i].split('.')[0] + ', '
if setting != FRAMERATE[1]:
framerate = ParseSetting(filename, FRAMERATE[1])
for i in range(0, len(framerate)):
title += framerate[i].split('.')[0] + ', '
if (setting != CODEC_IMPLEMENTATION_NAME[1] and
setting != ENCODER_IMPLEMENTATION_NAME[1]):
enc_names = ParseSetting(filename, ENCODER_IMPLEMENTATION_NAME[1])
for i in range(0, len(enc_names)):
title += enc_names[i] + ', '
if (setting != CODEC_IMPLEMENTATION_NAME[1]
and setting != ENCODER_IMPLEMENTATION_NAME[1]):
enc_names = ParseSetting(filename, ENCODER_IMPLEMENTATION_NAME[1])
for i in range(0, len(enc_names)):
title += enc_names[i] + ', '
if (setting != CODEC_IMPLEMENTATION_NAME[1] and
setting != DECODER_IMPLEMENTATION_NAME[1]):
dec_names = ParseSetting(filename, DECODER_IMPLEMENTATION_NAME[1])
for i in range(0, len(dec_names)):
title += dec_names[i] + ', '
if (setting != CODEC_IMPLEMENTATION_NAME[1]
and setting != DECODER_IMPLEMENTATION_NAME[1]):
dec_names = ParseSetting(filename, DECODER_IMPLEMENTATION_NAME[1])
for i in range(0, len(dec_names)):
title += dec_names[i] + ', '
filenames = ParseSetting(filename, FILENAME[1])
title += filenames[0].split('_')[0]
filenames = ParseSetting(filename, FILENAME[1])
title += filenames[0].split('_')[0]
num_frames = ParseSetting(filename, NUM_FRAMES[1])
for i in range(0, len(num_frames)):
title += ' (' + num_frames[i].split('.')[0] + ')'
num_frames = ParseSetting(filename, NUM_FRAMES[1])
for i in range(0, len(num_frames)):
title += ' (' + num_frames[i].split('.')[0] + ')'
return title
return title
def ToString(input_list):
return ToStringWithoutMetric(input_list, ('', ''))
return ToStringWithoutMetric(input_list, ('', ''))
def ToStringWithoutMetric(input_list, metric):
i = 1
output_str = ""
for m in input_list:
if m != metric:
output_str = output_str + ("%s. %s\n" % (i, m[1]))
i += 1
return output_str
i = 1
output_str = ""
for m in input_list:
if m != metric:
output_str = output_str + ("%s. %s\n" % (i, m[1]))
i += 1
return output_str
def GetIdx(text_list):
return int(raw_input(text_list)) - 1
return int(raw_input(text_list)) - 1
def main():
filename = sys.argv[1]
filename = sys.argv[1]
# Setup.
idx_metric = GetIdx("Choose metric:\n0. All\n%s" % ToString(RESULTS))
if idx_metric == -1:
# Plot all metrics. One subplot for each metric.
# Per subplot: metric vs bitrate (per resolution).
cores = ParseSetting(filename, CORES[1])
setting1 = CORES[1]
setting2 = WIDTH[1]
sub_keys = [cores[0]] * len(Y_METRICS)
y_metrics = Y_METRICS
x_metric = BITRATE[1]
else:
resolutions = ParseSetting(filename, WIDTH[1])
idx = GetIdx("Select metric for x-axis:\n%s" % ToString(X_SETTINGS))
if X_SETTINGS[idx] == BITRATE:
idx = GetIdx("Plot per:\n%s" % ToStringWithoutMetric(SUBPLOT_SETTINGS,
BITRATE))
idx_setting = METRICS_TO_PARSE.index(SUBPLOT_SETTINGS[idx])
# Plot one metric. One subplot for each resolution.
# Per subplot: metric vs bitrate (per setting).
setting1 = WIDTH[1]
setting2 = METRICS_TO_PARSE[idx_setting][1]
sub_keys = resolutions
y_metrics = [RESULTS[idx_metric][1]] * len(sub_keys)
x_metric = BITRATE[1]
# Setup.
idx_metric = GetIdx("Choose metric:\n0. All\n%s" % ToString(RESULTS))
if idx_metric == -1:
# Plot all metrics. One subplot for each metric.
# Per subplot: metric vs bitrate (per resolution).
cores = ParseSetting(filename, CORES[1])
setting1 = CORES[1]
setting2 = WIDTH[1]
sub_keys = [cores[0]] * len(Y_METRICS)
y_metrics = Y_METRICS
x_metric = BITRATE[1]
else:
# Plot one metric. One subplot for each resolution.
# Per subplot: metric vs setting (per bitrate).
setting1 = WIDTH[1]
setting2 = BITRATE[1]
sub_keys = resolutions
y_metrics = [RESULTS[idx_metric][1]] * len(sub_keys)
x_metric = X_SETTINGS[idx][1]
resolutions = ParseSetting(filename, WIDTH[1])
idx = GetIdx("Select metric for x-axis:\n%s" % ToString(X_SETTINGS))
if X_SETTINGS[idx] == BITRATE:
idx = GetIdx("Plot per:\n%s" %
ToStringWithoutMetric(SUBPLOT_SETTINGS, BITRATE))
idx_setting = METRICS_TO_PARSE.index(SUBPLOT_SETTINGS[idx])
# Plot one metric. One subplot for each resolution.
# Per subplot: metric vs bitrate (per setting).
setting1 = WIDTH[1]
setting2 = METRICS_TO_PARSE[idx_setting][1]
sub_keys = resolutions
y_metrics = [RESULTS[idx_metric][1]] * len(sub_keys)
x_metric = BITRATE[1]
else:
# Plot one metric. One subplot for each resolution.
# Per subplot: metric vs setting (per bitrate).
setting1 = WIDTH[1]
setting2 = BITRATE[1]
sub_keys = resolutions
y_metrics = [RESULTS[idx_metric][1]] * len(sub_keys)
x_metric = X_SETTINGS[idx][1]
metrics = ParseMetrics(filename, setting1, setting2)
metrics = ParseMetrics(filename, setting1, setting2)
# Stretch fig size.
figsize = plt.rcParams["figure.figsize"]
figsize[0] *= FIG_SIZE_SCALE_FACTOR_X
figsize[1] *= FIG_SIZE_SCALE_FACTOR_Y
plt.rcParams["figure.figsize"] = figsize
# Stretch fig size.
figsize = plt.rcParams["figure.figsize"]
figsize[0] *= FIG_SIZE_SCALE_FACTOR_X
figsize[1] *= FIG_SIZE_SCALE_FACTOR_Y
plt.rcParams["figure.figsize"] = figsize
PlotFigure(sub_keys, y_metrics, x_metric, metrics,
GetTitle(filename, setting2))
PlotFigure(sub_keys, y_metrics, x_metric, metrics,
GetTitle(filename, setting2))
plt.show()
plt.show()
if __name__ == '__main__':
main()
main()

View File

@ -20,146 +20,145 @@ from presubmit_test_mocks import MockInputApi, MockOutputApi, MockFile, MockChan
class CheckBugEntryFieldTest(unittest.TestCase):
def testCommitMessageBugEntryWithNoError(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change = MockChange([], ['webrtc:1234'])
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(0, len(errors))
def testCommitMessageBugEntryWithNoError(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change = MockChange([], ['webrtc:1234'])
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(0, len(errors))
def testCommitMessageBugEntryReturnError(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change = MockChange([], ['webrtc:1234', 'webrtc=4321'])
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(1, len(errors))
self.assertEqual(('Bogus Bug entry: webrtc=4321. Please specify'
' the issue tracker prefix and the issue number,'
' separated by a colon, e.g. webrtc:123 or'
' chromium:12345.'), str(errors[0]))
def testCommitMessageBugEntryReturnError(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change = MockChange([], ['webrtc:1234', 'webrtc=4321'])
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(1, len(errors))
self.assertEqual(('Bogus Bug entry: webrtc=4321. Please specify'
' the issue tracker prefix and the issue number,'
' separated by a colon, e.g. webrtc:123 or'
' chromium:12345.'), str(errors[0]))
def testCommitMessageBugEntryWithoutPrefix(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change = MockChange([], ['1234'])
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(1, len(errors))
self.assertEqual(('Bug entry requires issue tracker prefix, '
'e.g. webrtc:1234'), str(errors[0]))
def testCommitMessageBugEntryWithoutPrefix(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change = MockChange([], ['1234'])
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(1, len(errors))
self.assertEqual(('Bug entry requires issue tracker prefix, '
'e.g. webrtc:1234'), str(errors[0]))
def testCommitMessageBugEntryIsNone(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change = MockChange([], ['None'])
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(0, len(errors))
def testCommitMessageBugEntryIsNone(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change = MockChange([], ['None'])
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(0, len(errors))
def testCommitMessageBugEntrySupportInternalBugReference(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change.BUG = 'b/12345'
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(0, len(errors))
mock_input_api.change.BUG = 'b/12345, webrtc:1234'
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(0, len(errors))
def testCommitMessageBugEntrySupportInternalBugReference(self):
mock_input_api = MockInputApi()
mock_output_api = MockOutputApi()
mock_input_api.change.BUG = 'b/12345'
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(0, len(errors))
mock_input_api.change.BUG = 'b/12345, webrtc:1234'
errors = PRESUBMIT.CheckCommitMessageBugEntry(mock_input_api,
mock_output_api)
self.assertEqual(0, len(errors))
class CheckNewlineAtTheEndOfProtoFilesTest(unittest.TestCase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
self.proto_file_path = os.path.join(self.tmp_dir, 'foo.proto')
self.input_api = MockInputApi()
self.output_api = MockOutputApi()
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
self.proto_file_path = os.path.join(self.tmp_dir, 'foo.proto')
self.input_api = MockInputApi()
self.output_api = MockOutputApi()
def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
def testErrorIfProtoFileDoesNotEndWithNewline(self):
self._GenerateProtoWithoutNewlineAtTheEnd()
self.input_api.files = [MockFile(self.proto_file_path)]
errors = PRESUBMIT.CheckNewlineAtTheEndOfProtoFiles(
self.input_api, self.output_api, lambda x: True)
self.assertEqual(1, len(errors))
self.assertEqual(
'File %s must end with exactly one newline.' %
self.proto_file_path, str(errors[0]))
def testErrorIfProtoFileDoesNotEndWithNewline(self):
self._GenerateProtoWithoutNewlineAtTheEnd()
self.input_api.files = [MockFile(self.proto_file_path)]
errors = PRESUBMIT.CheckNewlineAtTheEndOfProtoFiles(self.input_api,
self.output_api,
lambda x: True)
self.assertEqual(1, len(errors))
self.assertEqual(
'File %s must end with exactly one newline.' % self.proto_file_path,
str(errors[0]))
def testNoErrorIfProtoFileEndsWithNewline(self):
self._GenerateProtoWithNewlineAtTheEnd()
self.input_api.files = [MockFile(self.proto_file_path)]
errors = PRESUBMIT.CheckNewlineAtTheEndOfProtoFiles(
self.input_api, self.output_api, lambda x: True)
self.assertEqual(0, len(errors))
def testNoErrorIfProtoFileEndsWithNewline(self):
self._GenerateProtoWithNewlineAtTheEnd()
self.input_api.files = [MockFile(self.proto_file_path)]
errors = PRESUBMIT.CheckNewlineAtTheEndOfProtoFiles(self.input_api,
self.output_api,
lambda x: True)
self.assertEqual(0, len(errors))
def _GenerateProtoWithNewlineAtTheEnd(self):
with open(self.proto_file_path, 'w') as f:
f.write(textwrap.dedent("""
def _GenerateProtoWithNewlineAtTheEnd(self):
with open(self.proto_file_path, 'w') as f:
f.write(
textwrap.dedent("""
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package webrtc.audioproc;
"""))
def _GenerateProtoWithoutNewlineAtTheEnd(self):
with open(self.proto_file_path, 'w') as f:
f.write(textwrap.dedent("""
def _GenerateProtoWithoutNewlineAtTheEnd(self):
with open(self.proto_file_path, 'w') as f:
f.write(
textwrap.dedent("""
syntax = "proto2";
option optimize_for = LITE_RUNTIME;
package webrtc.audioproc;"""))
class CheckNoMixingSourcesTest(unittest.TestCase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
self.file_path = os.path.join(self.tmp_dir, 'BUILD.gn')
self.input_api = MockInputApi()
self.output_api = MockOutputApi()
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
self.file_path = os.path.join(self.tmp_dir, 'BUILD.gn')
self.input_api = MockInputApi()
self.output_api = MockOutputApi()
def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
def testErrorIfCAndCppAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.cc', 'bar.h'])
def testErrorIfCAndCppAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.cc', 'bar.h'])
def testErrorIfCAndObjCAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.m', 'bar.h'])
def testErrorIfCAndObjCAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.m', 'bar.h'])
def testErrorIfCAndObjCppAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.mm', 'bar.h'])
def testErrorIfCAndObjCppAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.c', 'bar.mm', 'bar.h'])
def testErrorIfCppAndObjCAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.cc', 'bar.m', 'bar.h'])
def testErrorIfCppAndObjCAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.cc', 'bar.m', 'bar.h'])
def testErrorIfCppAndObjCppAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.cc', 'bar.mm', 'bar.h'])
def testErrorIfCppAndObjCppAreMixed(self):
self._AssertNumberOfErrorsWithSources(1, ['foo.cc', 'bar.mm', 'bar.h'])
def testNoErrorIfOnlyC(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.c', 'bar.c', 'bar.h'])
def testNoErrorIfOnlyC(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.c', 'bar.c', 'bar.h'])
def testNoErrorIfOnlyCpp(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.cc', 'bar.cc', 'bar.h'])
def testNoErrorIfOnlyCpp(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.cc', 'bar.cc', 'bar.h'])
def testNoErrorIfOnlyObjC(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.m', 'bar.m', 'bar.h'])
def testNoErrorIfOnlyObjC(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.m', 'bar.m', 'bar.h'])
def testNoErrorIfOnlyObjCpp(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.mm', 'bar.mm', 'bar.h'])
def testNoErrorIfOnlyObjCpp(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.mm', 'bar.mm', 'bar.h'])
def testNoErrorIfObjCAndObjCppAreMixed(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.m', 'bar.mm', 'bar.h'])
def testNoErrorIfObjCAndObjCppAreMixed(self):
self._AssertNumberOfErrorsWithSources(0, ['foo.m', 'bar.mm', 'bar.h'])
def testNoErrorIfSourcesAreInExclusiveIfBranches(self):
self._GenerateBuildFile(textwrap.dedent("""
def testNoErrorIfSourcesAreInExclusiveIfBranches(self):
self._GenerateBuildFile(
textwrap.dedent("""
rtc_library("bar_foo") {
if (is_win) {
sources = [
@ -185,14 +184,15 @@ class CheckNoMixingSourcesTest(unittest.TestCase):
}
}
"""))
self.input_api.files = [MockFile(self.file_path)]
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
[MockFile(self.file_path)],
self.output_api)
self.assertEqual(0, len(errors))
self.input_api.files = [MockFile(self.file_path)]
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
[MockFile(self.file_path)],
self.output_api)
self.assertEqual(0, len(errors))
def testErrorIfSourcesAreNotInExclusiveIfBranches(self):
self._GenerateBuildFile(textwrap.dedent("""
def testErrorIfSourcesAreNotInExclusiveIfBranches(self):
self._GenerateBuildFile(
textwrap.dedent("""
rtc_library("bar_foo") {
if (is_win) {
sources = [
@ -224,21 +224,23 @@ class CheckNoMixingSourcesTest(unittest.TestCase):
}
}
"""))
self.input_api.files = [MockFile(self.file_path)]
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
[MockFile(self.file_path)],
self.output_api)
self.assertEqual(1, len(errors))
self.assertTrue('bar.cc' in str(errors[0]))
self.assertTrue('bar.mm' in str(errors[0]))
self.assertTrue('foo.cc' in str(errors[0]))
self.assertTrue('foo.mm' in str(errors[0]))
self.assertTrue('bar.m' in str(errors[0]))
self.assertTrue('bar.c' in str(errors[0]))
self.input_api.files = [MockFile(self.file_path)]
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
[MockFile(self.file_path)],
self.output_api)
self.assertEqual(1, len(errors))
self.assertTrue('bar.cc' in str(errors[0]))
self.assertTrue('bar.mm' in str(errors[0]))
self.assertTrue('foo.cc' in str(errors[0]))
self.assertTrue('foo.mm' in str(errors[0]))
self.assertTrue('bar.m' in str(errors[0]))
self.assertTrue('bar.c' in str(errors[0]))
def _AssertNumberOfErrorsWithSources(self, number_of_errors, sources):
assert len(sources) == 3, 'This function accepts a list of 3 source files'
self._GenerateBuildFile(textwrap.dedent("""
def _AssertNumberOfErrorsWithSources(self, number_of_errors, sources):
assert len(
sources) == 3, 'This function accepts a list of 3 source files'
self._GenerateBuildFile(
textwrap.dedent("""
rtc_static_library("bar_foo") {
sources = [
"%s",
@ -254,20 +256,20 @@ class CheckNoMixingSourcesTest(unittest.TestCase):
],
}
""" % (tuple(sources) * 2)))
self.input_api.files = [MockFile(self.file_path)]
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
[MockFile(self.file_path)],
self.output_api)
self.assertEqual(number_of_errors, len(errors))
if number_of_errors == 1:
for source in sources:
if not source.endswith('.h'):
self.assertTrue(source in str(errors[0]))
self.input_api.files = [MockFile(self.file_path)]
errors = PRESUBMIT.CheckNoMixingSources(self.input_api,
[MockFile(self.file_path)],
self.output_api)
self.assertEqual(number_of_errors, len(errors))
if number_of_errors == 1:
for source in sources:
if not source.endswith('.h'):
self.assertTrue(source in str(errors[0]))
def _GenerateBuildFile(self, content):
with open(self.file_path, 'w') as f:
f.write(content)
def _GenerateBuildFile(self, content):
with open(self.file_path, 'w') as f:
f.write(content)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@ -14,118 +14,125 @@ import re
class MockInputApi(object):
"""Mock class for the InputApi class.
"""Mock class for the InputApi class.
This class can be used for unittests for presubmit by initializing the files
attribute as the list of changed files.
"""
def __init__(self):
self.change = MockChange([], [])
self.files = []
self.presubmit_local_path = os.path.dirname(__file__)
def __init__(self):
self.change = MockChange([], [])
self.files = []
self.presubmit_local_path = os.path.dirname(__file__)
def AffectedSourceFiles(self, file_filter=None):
return self.AffectedFiles(file_filter=file_filter)
def AffectedSourceFiles(self, file_filter=None):
return self.AffectedFiles(file_filter=file_filter)
def AffectedFiles(self, file_filter=None, include_deletes=False):
# pylint: disable=unused-argument
return self.files
def AffectedFiles(self, file_filter=None, include_deletes=False):
# pylint: disable=unused-argument
return self.files
@classmethod
def FilterSourceFile(cls, affected_file, files_to_check=(),
files_to_skip=()):
# pylint: disable=unused-argument
return True
@classmethod
def FilterSourceFile(cls,
affected_file,
files_to_check=(),
files_to_skip=()):
# pylint: disable=unused-argument
return True
def PresubmitLocalPath(self):
return self.presubmit_local_path
def PresubmitLocalPath(self):
return self.presubmit_local_path
def ReadFile(self, affected_file, mode='rU'):
filename = affected_file.AbsoluteLocalPath()
for f in self.files:
if f.LocalPath() == filename:
with open(filename, mode) as f:
return f.read()
# Otherwise, file is not in our mock API.
raise IOError, "No such file or directory: '%s'" % filename
def ReadFile(self, affected_file, mode='rU'):
filename = affected_file.AbsoluteLocalPath()
for f in self.files:
if f.LocalPath() == filename:
with open(filename, mode) as f:
return f.read()
# Otherwise, file is not in our mock API.
raise IOError, "No such file or directory: '%s'" % filename
class MockOutputApi(object):
"""Mock class for the OutputApi class.
"""Mock class for the OutputApi class.
An instance of this class can be passed to presubmit unittests for outputing
various types of results.
"""
class PresubmitResult(object):
def __init__(self, message, items=None, long_text=''):
self.message = message
self.items = items
self.long_text = long_text
class PresubmitResult(object):
def __init__(self, message, items=None, long_text=''):
self.message = message
self.items = items
self.long_text = long_text
def __repr__(self):
return self.message
def __repr__(self):
return self.message
class PresubmitError(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items, long_text)
self.type = 'error'
class PresubmitError(PresubmitResult):
def __init__(self, message, items=None, long_text=''):
MockOutputApi.PresubmitResult.__init__(self, message, items,
long_text)
self.type = 'error'
class MockChange(object):
"""Mock class for Change class.
"""Mock class for Change class.
This class can be used in presubmit unittests to mock the query of the
current change.
"""
def __init__(self, changed_files, bugs_from_description, tags=None):
self._changed_files = changed_files
self._bugs_from_description = bugs_from_description
self.tags = dict() if not tags else tags
def __init__(self, changed_files, bugs_from_description, tags=None):
self._changed_files = changed_files
self._bugs_from_description = bugs_from_description
self.tags = dict() if not tags else tags
def BugsFromDescription(self):
return self._bugs_from_description
def BugsFromDescription(self):
return self._bugs_from_description
def __getattr__(self, attr):
"""Return tags directly as attributes on the object."""
if not re.match(r"^[A-Z_]*$", attr):
raise AttributeError(self, attr)
return self.tags.get(attr)
def __getattr__(self, attr):
"""Return tags directly as attributes on the object."""
if not re.match(r"^[A-Z_]*$", attr):
raise AttributeError(self, attr)
return self.tags.get(attr)
class MockFile(object):
"""Mock class for the File class.
"""Mock class for the File class.
This class can be used to form the mock list of changed files in
MockInputApi for presubmit unittests.
"""
def __init__(self, local_path, new_contents=None, old_contents=None,
action='A'):
if new_contents is None:
new_contents = ["Data"]
self._local_path = local_path
self._new_contents = new_contents
self._changed_contents = [(i + 1, l) for i, l in enumerate(new_contents)]
self._action = action
self._old_contents = old_contents
def __init__(self,
local_path,
new_contents=None,
old_contents=None,
action='A'):
if new_contents is None:
new_contents = ["Data"]
self._local_path = local_path
self._new_contents = new_contents
self._changed_contents = [(i + 1, l)
for i, l in enumerate(new_contents)]
self._action = action
self._old_contents = old_contents
def Action(self):
return self._action
def Action(self):
return self._action
def ChangedContents(self):
return self._changed_contents
def ChangedContents(self):
return self._changed_contents
def NewContents(self):
return self._new_contents
def NewContents(self):
return self._new_contents
def LocalPath(self):
return self._local_path
def LocalPath(self):
return self._local_path
def AbsoluteLocalPath(self):
return self._local_path
def AbsoluteLocalPath(self):
return self._local_path
def OldContents(self):
return self._old_contents
def OldContents(self):
return self._old_contents

View File

@ -18,7 +18,6 @@ import subprocess
import sys
import tempfile
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Chrome browsertests will throw away stderr; avoid that output gets lost.
@ -26,131 +25,154 @@ sys.stderr = sys.stdout
def _ParseArgs():
"""Registers the command-line options."""
usage = 'usage: %prog [options]'
parser = optparse.OptionParser(usage=usage)
"""Registers the command-line options."""
usage = 'usage: %prog [options]'
parser = optparse.OptionParser(usage=usage)
parser.add_option('--label', type='string', default='MY_TEST',
help=('Label of the test, used to identify different '
'tests. Default: %default'))
parser.add_option('--ref_video', type='string',
help='Reference video to compare with (YUV).')
parser.add_option('--test_video', type='string',
help=('Test video to be compared with the reference '
'video (YUV).'))
parser.add_option('--frame_analyzer', type='string',
help='Path to the frame analyzer executable.')
parser.add_option('--aligned_output_file', type='string',
help='Path for output aligned YUV or Y4M file.')
parser.add_option('--vmaf', type='string',
help='Path to VMAF executable.')
parser.add_option('--vmaf_model', type='string',
help='Path to VMAF model.')
parser.add_option('--vmaf_phone_model', action='store_true',
help='Whether to use phone model in VMAF.')
parser.add_option('--yuv_frame_width', type='int', default=640,
help='Width of the YUV file\'s frames. Default: %default')
parser.add_option('--yuv_frame_height', type='int', default=480,
help='Height of the YUV file\'s frames. Default: %default')
parser.add_option('--chartjson_result_file', type='str', default=None,
help='Where to store perf results in chartjson format.')
options, _ = parser.parse_args()
parser.add_option('--label',
type='string',
default='MY_TEST',
help=('Label of the test, used to identify different '
'tests. Default: %default'))
parser.add_option('--ref_video',
type='string',
help='Reference video to compare with (YUV).')
parser.add_option('--test_video',
type='string',
help=('Test video to be compared with the reference '
'video (YUV).'))
parser.add_option('--frame_analyzer',
type='string',
help='Path to the frame analyzer executable.')
parser.add_option('--aligned_output_file',
type='string',
help='Path for output aligned YUV or Y4M file.')
parser.add_option('--vmaf', type='string', help='Path to VMAF executable.')
parser.add_option('--vmaf_model',
type='string',
help='Path to VMAF model.')
parser.add_option('--vmaf_phone_model',
action='store_true',
help='Whether to use phone model in VMAF.')
parser.add_option(
'--yuv_frame_width',
type='int',
default=640,
help='Width of the YUV file\'s frames. Default: %default')
parser.add_option(
'--yuv_frame_height',
type='int',
default=480,
help='Height of the YUV file\'s frames. Default: %default')
parser.add_option('--chartjson_result_file',
type='str',
default=None,
help='Where to store perf results in chartjson format.')
options, _ = parser.parse_args()
if not options.ref_video:
parser.error('You must provide a path to the reference video!')
if not os.path.exists(options.ref_video):
parser.error('Cannot find the reference video at %s' % options.ref_video)
if not options.ref_video:
parser.error('You must provide a path to the reference video!')
if not os.path.exists(options.ref_video):
parser.error('Cannot find the reference video at %s' %
options.ref_video)
if not options.test_video:
parser.error('You must provide a path to the test video!')
if not os.path.exists(options.test_video):
parser.error('Cannot find the test video at %s' % options.test_video)
if not options.test_video:
parser.error('You must provide a path to the test video!')
if not os.path.exists(options.test_video):
parser.error('Cannot find the test video at %s' % options.test_video)
if not options.frame_analyzer:
parser.error('You must provide the path to the frame analyzer executable!')
if not os.path.exists(options.frame_analyzer):
parser.error('Cannot find frame analyzer executable at %s!' %
options.frame_analyzer)
if not options.frame_analyzer:
parser.error(
'You must provide the path to the frame analyzer executable!')
if not os.path.exists(options.frame_analyzer):
parser.error('Cannot find frame analyzer executable at %s!' %
options.frame_analyzer)
if options.vmaf and not options.vmaf_model:
parser.error('You must provide a path to a VMAF model to use VMAF.')
if options.vmaf and not options.vmaf_model:
parser.error('You must provide a path to a VMAF model to use VMAF.')
return options
return options
def _DevNull():
"""On Windows, sometimes the inherited stdin handle from the parent process
"""On Windows, sometimes the inherited stdin handle from the parent process
fails. Workaround this by passing null to stdin to the subprocesses commands.
This function can be used to create the null file handler.
"""
return open(os.devnull, 'r')
return open(os.devnull, 'r')
def _RunFrameAnalyzer(options, yuv_directory=None):
"""Run frame analyzer to compare the videos and print output."""
cmd = [
options.frame_analyzer,
'--label=%s' % options.label,
'--reference_file=%s' % options.ref_video,
'--test_file=%s' % options.test_video,
'--width=%d' % options.yuv_frame_width,
'--height=%d' % options.yuv_frame_height,
]
if options.chartjson_result_file:
cmd.append('--chartjson_result_file=%s' % options.chartjson_result_file)
if options.aligned_output_file:
cmd.append('--aligned_output_file=%s' % options.aligned_output_file)
if yuv_directory:
cmd.append('--yuv_directory=%s' % yuv_directory)
frame_analyzer = subprocess.Popen(cmd, stdin=_DevNull(),
stdout=sys.stdout, stderr=sys.stderr)
frame_analyzer.wait()
if frame_analyzer.returncode != 0:
print('Failed to run frame analyzer.')
return frame_analyzer.returncode
"""Run frame analyzer to compare the videos and print output."""
cmd = [
options.frame_analyzer,
'--label=%s' % options.label,
'--reference_file=%s' % options.ref_video,
'--test_file=%s' % options.test_video,
'--width=%d' % options.yuv_frame_width,
'--height=%d' % options.yuv_frame_height,
]
if options.chartjson_result_file:
cmd.append('--chartjson_result_file=%s' %
options.chartjson_result_file)
if options.aligned_output_file:
cmd.append('--aligned_output_file=%s' % options.aligned_output_file)
if yuv_directory:
cmd.append('--yuv_directory=%s' % yuv_directory)
frame_analyzer = subprocess.Popen(cmd,
stdin=_DevNull(),
stdout=sys.stdout,
stderr=sys.stderr)
frame_analyzer.wait()
if frame_analyzer.returncode != 0:
print('Failed to run frame analyzer.')
return frame_analyzer.returncode
def _RunVmaf(options, yuv_directory, logfile):
""" Run VMAF to compare videos and print output.
""" Run VMAF to compare videos and print output.
The yuv_directory is assumed to have been populated with a reference and test
video in .yuv format, with names according to the label.
"""
cmd = [
options.vmaf,
'yuv420p',
str(options.yuv_frame_width),
str(options.yuv_frame_height),
os.path.join(yuv_directory, "ref.yuv"),
os.path.join(yuv_directory, "test.yuv"),
options.vmaf_model,
'--log',
logfile,
'--log-fmt',
'json',
]
if options.vmaf_phone_model:
cmd.append('--phone-model')
cmd = [
options.vmaf,
'yuv420p',
str(options.yuv_frame_width),
str(options.yuv_frame_height),
os.path.join(yuv_directory, "ref.yuv"),
os.path.join(yuv_directory, "test.yuv"),
options.vmaf_model,
'--log',
logfile,
'--log-fmt',
'json',
]
if options.vmaf_phone_model:
cmd.append('--phone-model')
vmaf = subprocess.Popen(cmd, stdin=_DevNull(),
stdout=sys.stdout, stderr=sys.stderr)
vmaf.wait()
if vmaf.returncode != 0:
print('Failed to run VMAF.')
return 1
vmaf = subprocess.Popen(cmd,
stdin=_DevNull(),
stdout=sys.stdout,
stderr=sys.stderr)
vmaf.wait()
if vmaf.returncode != 0:
print('Failed to run VMAF.')
return 1
# Read per-frame scores from VMAF output and print.
with open(logfile) as f:
vmaf_data = json.load(f)
vmaf_scores = []
for frame in vmaf_data['frames']:
vmaf_scores.append(frame['metrics']['vmaf'])
print('RESULT VMAF: %s=' % options.label, vmaf_scores)
# Read per-frame scores from VMAF output and print.
with open(logfile) as f:
vmaf_data = json.load(f)
vmaf_scores = []
for frame in vmaf_data['frames']:
vmaf_scores.append(frame['metrics']['vmaf'])
print('RESULT VMAF: %s=' % options.label, vmaf_scores)
return 0
return 0
def main():
"""The main function.
"""The main function.
A simple invocation is:
./webrtc/rtc_tools/compare_videos.py
@ -161,27 +183,28 @@ def main():
Running vmaf requires the following arguments:
--vmaf, --vmaf_model, --yuv_frame_width, --yuv_frame_height
"""
options = _ParseArgs()
options = _ParseArgs()
if options.vmaf:
try:
# Directory to save temporary YUV files for VMAF in frame_analyzer.
yuv_directory = tempfile.mkdtemp()
_, vmaf_logfile = tempfile.mkstemp()
if options.vmaf:
try:
# Directory to save temporary YUV files for VMAF in frame_analyzer.
yuv_directory = tempfile.mkdtemp()
_, vmaf_logfile = tempfile.mkstemp()
# Run frame analyzer to compare the videos and print output.
if _RunFrameAnalyzer(options, yuv_directory=yuv_directory) != 0:
return 1
# Run frame analyzer to compare the videos and print output.
if _RunFrameAnalyzer(options, yuv_directory=yuv_directory) != 0:
return 1
# Run VMAF for further video comparison and print output.
return _RunVmaf(options, yuv_directory, vmaf_logfile)
finally:
shutil.rmtree(yuv_directory)
os.remove(vmaf_logfile)
else:
return _RunFrameAnalyzer(options)
# Run VMAF for further video comparison and print output.
return _RunVmaf(options, yuv_directory, vmaf_logfile)
finally:
shutil.rmtree(yuv_directory)
os.remove(vmaf_logfile)
else:
return _RunFrameAnalyzer(options)
return 0
return 0
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -39,52 +39,59 @@ MICROSECONDS_IN_SECOND = 1e6
def main():
parser = argparse.ArgumentParser(
description='Plots metrics exported from WebRTC perf tests')
parser.add_argument('-m', '--metrics', type=str, nargs='*',
help='Metrics to plot. If nothing specified then will plot all available')
args = parser.parse_args()
parser = argparse.ArgumentParser(
description='Plots metrics exported from WebRTC perf tests')
parser.add_argument(
'-m',
'--metrics',
type=str,
nargs='*',
help=
'Metrics to plot. If nothing specified then will plot all available')
args = parser.parse_args()
metrics_to_plot = set()
if args.metrics:
for metric in args.metrics:
metrics_to_plot.add(metric)
metrics_to_plot = set()
if args.metrics:
for metric in args.metrics:
metrics_to_plot.add(metric)
metrics = []
for line in fileinput.input('-'):
line = line.strip()
if line.startswith(LINE_PREFIX):
line = line.replace(LINE_PREFIX, '')
metrics.append(json.loads(line))
else:
print line
metrics = []
for line in fileinput.input('-'):
line = line.strip()
if line.startswith(LINE_PREFIX):
line = line.replace(LINE_PREFIX, '')
metrics.append(json.loads(line))
else:
print line
for metric in metrics:
if len(metrics_to_plot) > 0 and metric[GRAPH_NAME] not in metrics_to_plot:
continue
for metric in metrics:
if len(metrics_to_plot
) > 0 and metric[GRAPH_NAME] not in metrics_to_plot:
continue
figure = plt.figure()
figure.canvas.set_window_title(metric[TRACE_NAME])
figure = plt.figure()
figure.canvas.set_window_title(metric[TRACE_NAME])
x_values = []
y_values = []
start_x = None
samples = metric['samples']
samples.sort(key=lambda x: x['time'])
for sample in samples:
if start_x is None:
start_x = sample['time']
# Time is us, we want to show it in seconds.
x_values.append((sample['time'] - start_x) / MICROSECONDS_IN_SECOND)
y_values.append(sample['value'])
x_values = []
y_values = []
start_x = None
samples = metric['samples']
samples.sort(key=lambda x: x['time'])
for sample in samples:
if start_x is None:
start_x = sample['time']
# Time is us, we want to show it in seconds.
x_values.append(
(sample['time'] - start_x) / MICROSECONDS_IN_SECOND)
y_values.append(sample['value'])
plt.ylabel('%s (%s)' % (metric[GRAPH_NAME], metric[UNITS]))
plt.xlabel('time (s)')
plt.title(metric[GRAPH_NAME])
plt.plot(x_values, y_values)
plt.ylabel('%s (%s)' % (metric[GRAPH_NAME], metric[UNITS]))
plt.xlabel('time (s)')
plt.title(metric[GRAPH_NAME])
plt.plot(x_values, y_values)
plt.show()
plt.show()
if __name__ == '__main__':
main()
main()

View File

@ -10,21 +10,21 @@
import network_tester_config_pb2
def AddConfig(all_configs,
packet_send_interval_ms,
packet_size,
def AddConfig(all_configs, packet_send_interval_ms, packet_size,
execution_time_ms):
config = all_configs.configs.add()
config.packet_send_interval_ms = packet_send_interval_ms
config.packet_size = packet_size
config.execution_time_ms = execution_time_ms
config = all_configs.configs.add()
config.packet_send_interval_ms = packet_send_interval_ms
config.packet_size = packet_size
config.execution_time_ms = execution_time_ms
def main():
all_configs = network_tester_config_pb2.NetworkTesterAllConfigs()
AddConfig(all_configs, 10, 50, 200)
AddConfig(all_configs, 10, 100, 200)
with open("network_tester_config.dat", 'wb') as f:
f.write(all_configs.SerializeToString())
all_configs = network_tester_config_pb2.NetworkTesterAllConfigs()
AddConfig(all_configs, 10, 50, 200)
AddConfig(all_configs, 10, 100, 200)
with open("network_tester_config.dat", 'wb') as f:
f.write(all_configs.SerializeToString())
if __name__ == "__main__":
main()
main()

View File

@ -20,128 +20,131 @@ import matplotlib.pyplot as plt
import network_tester_packet_pb2
def GetSize(file_to_parse):
data = file_to_parse.read(1)
if data == '':
return 0
return struct.unpack('<b', data)[0]
data = file_to_parse.read(1)
if data == '':
return 0
return struct.unpack('<b', data)[0]
def ParsePacketLog(packet_log_file_to_parse):
packets = []
with open(packet_log_file_to_parse, 'rb') as file_to_parse:
while True:
size = GetSize(file_to_parse)
if size == 0:
break
try:
packet = network_tester_packet_pb2.NetworkTesterPacket()
packet.ParseFromString(file_to_parse.read(size))
packets.append(packet)
except IOError:
break
return packets
packets = []
with open(packet_log_file_to_parse, 'rb') as file_to_parse:
while True:
size = GetSize(file_to_parse)
if size == 0:
break
try:
packet = network_tester_packet_pb2.NetworkTesterPacket()
packet.ParseFromString(file_to_parse.read(size))
packets.append(packet)
except IOError:
break
return packets
def GetTimeAxis(packets):
first_arrival_time = packets[0].arrival_timestamp
return [(packet.arrival_timestamp - first_arrival_time) / 1000000.0
for packet in packets]
first_arrival_time = packets[0].arrival_timestamp
return [(packet.arrival_timestamp - first_arrival_time) / 1000000.0
for packet in packets]
def CreateSendTimeDiffPlot(packets, plot):
first_send_time_diff = (
packets[0].arrival_timestamp - packets[0].send_timestamp)
y = [(packet.arrival_timestamp - packet.send_timestamp) - first_send_time_diff
for packet in packets]
plot.grid(True)
plot.set_title("SendTime difference [us]")
plot.plot(GetTimeAxis(packets), y)
first_send_time_diff = (packets[0].arrival_timestamp -
packets[0].send_timestamp)
y = [(packet.arrival_timestamp - packet.send_timestamp) -
first_send_time_diff for packet in packets]
plot.grid(True)
plot.set_title("SendTime difference [us]")
plot.plot(GetTimeAxis(packets), y)
class MovingAverageBitrate(object):
def __init__(self):
self.packet_window = []
self.window_time = 1000000
self.bytes = 0
self.latest_packet_time = 0
self.send_interval = 0
def __init__(self):
self.packet_window = []
self.window_time = 1000000
self.bytes = 0
self.latest_packet_time = 0
self.send_interval = 0
def RemoveOldPackets(self):
for packet in self.packet_window:
if (self.latest_packet_time - packet.arrival_timestamp >
self.window_time):
self.bytes = self.bytes - packet.packet_size
self.packet_window.remove(packet)
def RemoveOldPackets(self):
for packet in self.packet_window:
if (self.latest_packet_time - packet.arrival_timestamp >
self.window_time):
self.bytes = self.bytes - packet.packet_size
self.packet_window.remove(packet)
def AddPacket(self, packet):
"""This functions returns bits / second"""
self.send_interval = packet.arrival_timestamp - self.latest_packet_time
self.latest_packet_time = packet.arrival_timestamp
self.RemoveOldPackets()
self.packet_window.append(packet)
self.bytes = self.bytes + packet.packet_size
return self.bytes * 8
def AddPacket(self, packet):
"""This functions returns bits / second"""
self.send_interval = packet.arrival_timestamp - self.latest_packet_time
self.latest_packet_time = packet.arrival_timestamp
self.RemoveOldPackets()
self.packet_window.append(packet)
self.bytes = self.bytes + packet.packet_size
return self.bytes * 8
def CreateReceiveBiratePlot(packets, plot):
bitrate = MovingAverageBitrate()
y = [bitrate.AddPacket(packet) for packet in packets]
plot.grid(True)
plot.set_title("Receive birate [bps]")
plot.plot(GetTimeAxis(packets), y)
bitrate = MovingAverageBitrate()
y = [bitrate.AddPacket(packet) for packet in packets]
plot.grid(True)
plot.set_title("Receive birate [bps]")
plot.plot(GetTimeAxis(packets), y)
def CreatePacketlossPlot(packets, plot):
packets_look_up = {}
first_sequence_number = packets[0].sequence_number
last_sequence_number = packets[-1].sequence_number
for packet in packets:
packets_look_up[packet.sequence_number] = packet
y = []
x = []
first_arrival_time = 0
last_arrival_time = 0
last_arrival_time_diff = 0
for sequence_number in range(first_sequence_number, last_sequence_number + 1):
if sequence_number in packets_look_up:
y.append(0)
if first_arrival_time == 0:
first_arrival_time = packets_look_up[sequence_number].arrival_timestamp
x_time = (packets_look_up[sequence_number].arrival_timestamp -
first_arrival_time)
if last_arrival_time != 0:
last_arrival_time_diff = x_time - last_arrival_time
last_arrival_time = x_time
x.append(x_time / 1000000.0)
else:
if last_arrival_time != 0 and last_arrival_time_diff != 0:
x.append((last_arrival_time + last_arrival_time_diff) / 1000000.0)
y.append(1)
plot.grid(True)
plot.set_title("Lost packets [0/1]")
plot.plot(x, y)
packets_look_up = {}
first_sequence_number = packets[0].sequence_number
last_sequence_number = packets[-1].sequence_number
for packet in packets:
packets_look_up[packet.sequence_number] = packet
y = []
x = []
first_arrival_time = 0
last_arrival_time = 0
last_arrival_time_diff = 0
for sequence_number in range(first_sequence_number,
last_sequence_number + 1):
if sequence_number in packets_look_up:
y.append(0)
if first_arrival_time == 0:
first_arrival_time = packets_look_up[
sequence_number].arrival_timestamp
x_time = (packets_look_up[sequence_number].arrival_timestamp -
first_arrival_time)
if last_arrival_time != 0:
last_arrival_time_diff = x_time - last_arrival_time
last_arrival_time = x_time
x.append(x_time / 1000000.0)
else:
if last_arrival_time != 0 and last_arrival_time_diff != 0:
x.append(
(last_arrival_time + last_arrival_time_diff) / 1000000.0)
y.append(1)
plot.grid(True)
plot.set_title("Lost packets [0/1]")
plot.plot(x, y)
def main():
parser = OptionParser()
parser.add_option("-f",
"--packet_log_file",
dest="packet_log_file",
help="packet_log file to parse")
parser = OptionParser()
parser.add_option("-f",
"--packet_log_file",
dest="packet_log_file",
help="packet_log file to parse")
options = parser.parse_args()[0]
options = parser.parse_args()[0]
packets = ParsePacketLog(options.packet_log_file)
f, plots = plt.subplots(3, sharex=True)
plt.xlabel('time [sec]')
CreateSendTimeDiffPlot(packets, plots[0])
CreateReceiveBiratePlot(packets, plots[1])
CreatePacketlossPlot(packets, plots[2])
f.subplots_adjust(hspace=0.3)
plt.show()
packets = ParsePacketLog(options.packet_log_file)
f, plots = plt.subplots(3, sharex=True)
plt.xlabel('time [sec]')
CreateSendTimeDiffPlot(packets, plots[0])
CreateReceiveBiratePlot(packets, plots[1])
CreatePacketlossPlot(packets, plots[2])
f.subplots_adjust(hspace=0.3)
plt.show()
if __name__ == "__main__":
main()
main()

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Utility functions for calculating statistics.
"""
@ -15,18 +14,17 @@ import sys
def CountReordered(sequence_numbers):
"""Returns number of reordered indices.
"""Returns number of reordered indices.
A reordered index is an index `i` for which sequence_numbers[i] >=
sequence_numbers[i + 1]
"""
return sum(1 for (s1, s2) in zip(sequence_numbers,
sequence_numbers[1:]) if
s1 >= s2)
return sum(1 for (s1, s2) in zip(sequence_numbers, sequence_numbers[1:])
if s1 >= s2)
def SsrcNormalizedSizeTable(data_points):
"""Counts proportion of data for every SSRC.
"""Counts proportion of data for every SSRC.
Args:
data_points: list of pb_parse.DataPoint
@ -37,14 +35,14 @@ def SsrcNormalizedSizeTable(data_points):
SSRC `s` to the total size of all packets.
"""
mapping = collections.defaultdict(int)
for point in data_points:
mapping[point.ssrc] += point.size
return NormalizeCounter(mapping)
mapping = collections.defaultdict(int)
for point in data_points:
mapping[point.ssrc] += point.size
return NormalizeCounter(mapping)
def NormalizeCounter(counter):
"""Returns a normalized version of the dictionary `counter`.
"""Returns a normalized version of the dictionary `counter`.
Does not modify `counter`.
@ -52,12 +50,12 @@ def NormalizeCounter(counter):
A new dictionary, in which every value in `counter`
has been divided by the total to sum up to 1.
"""
total = sum(counter.values())
return {key: counter[key] / total for key in counter}
total = sum(counter.values())
return {key: counter[key] / total for key in counter}
def Unwrap(data, mod):
"""Returns `data` unwrapped modulo `mod`. Does not modify data.
"""Returns `data` unwrapped modulo `mod`. Does not modify data.
Adds integer multiples of mod to all elements of data except the
first, such that all pairs of consecutive elements (a, b) satisfy
@ -66,22 +64,22 @@ def Unwrap(data, mod):
E.g. Unwrap([0, 1, 2, 0, 1, 2, 7, 8], 3) -> [0, 1, 2, 3,
4, 5, 4, 5]
"""
lst = data[:]
for i in range(1, len(data)):
lst[i] = lst[i - 1] + (lst[i] - lst[i - 1] +
mod // 2) % mod - (mod // 2)
return lst
lst = data[:]
for i in range(1, len(data)):
lst[i] = lst[i - 1] + (lst[i] - lst[i - 1] + mod // 2) % mod - (mod //
2)
return lst
def SsrcDirections(data_points):
ssrc_is_incoming = {}
for point in data_points:
ssrc_is_incoming[point.ssrc] = point.incoming
return ssrc_is_incoming
ssrc_is_incoming = {}
for point in data_points:
ssrc_is_incoming[point.ssrc] = point.incoming
return ssrc_is_incoming
# Python 2/3-compatible input function
if sys.version_info[0] <= 2:
get_input = raw_input # pylint: disable=invalid-name
get_input = raw_input # pylint: disable=invalid-name
else:
get_input = input # pylint: disable=invalid-name
get_input = input # pylint: disable=invalid-name

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Run the tests with
python misc_test.py
@ -22,51 +21,52 @@ import misc
class TestMisc(unittest.TestCase):
def testUnwrapMod3(self):
data = [0, 1, 2, 0, -1, -2, -3, -4]
unwrapped_3 = misc.Unwrap(data, 3)
self.assertEqual([0, 1, 2, 3, 2, 1, 0, -1], unwrapped_3)
def testUnwrapMod3(self):
data = [0, 1, 2, 0, -1, -2, -3, -4]
unwrapped_3 = misc.Unwrap(data, 3)
self.assertEqual([0, 1, 2, 3, 2, 1, 0, -1], unwrapped_3)
def testUnwrapMod4(self):
data = [0, 1, 2, 0, -1, -2, -3, -4]
unwrapped_4 = misc.Unwrap(data, 4)
self.assertEqual([0, 1, 2, 0, -1, -2, -3, -4], unwrapped_4)
def testUnwrapMod4(self):
data = [0, 1, 2, 0, -1, -2, -3, -4]
unwrapped_4 = misc.Unwrap(data, 4)
self.assertEqual([0, 1, 2, 0, -1, -2, -3, -4], unwrapped_4)
def testDataShouldNotChangeAfterUnwrap(self):
data = [0, 1, 2, 0, -1, -2, -3, -4]
_ = misc.Unwrap(data, 4)
def testDataShouldNotChangeAfterUnwrap(self):
data = [0, 1, 2, 0, -1, -2, -3, -4]
_ = misc.Unwrap(data, 4)
self.assertEqual([0, 1, 2, 0, -1, -2, -3, -4], data)
self.assertEqual([0, 1, 2, 0, -1, -2, -3, -4], data)
def testRandomlyMultiplesOfModAdded(self):
# `unwrap` definition says only multiples of mod are added.
random_data = [random.randint(0, 9) for _ in range(100)]
def testRandomlyMultiplesOfModAdded(self):
# `unwrap` definition says only multiples of mod are added.
random_data = [random.randint(0, 9) for _ in range(100)]
for mod in range(1, 100):
random_data_unwrapped_mod = misc.Unwrap(random_data, mod)
for mod in range(1, 100):
random_data_unwrapped_mod = misc.Unwrap(random_data, mod)
for (old_a, a) in zip(random_data, random_data_unwrapped_mod):
self.assertEqual((old_a - a) % mod, 0)
for (old_a, a) in zip(random_data, random_data_unwrapped_mod):
self.assertEqual((old_a - a) % mod, 0)
def testRandomlyAgainstInequalityDefinition(self):
# Data has to satisfy -mod/2 <= difference < mod/2 for every
# difference between consecutive values after unwrap.
random_data = [random.randint(0, 9) for _ in range(100)]
def testRandomlyAgainstInequalityDefinition(self):
# Data has to satisfy -mod/2 <= difference < mod/2 for every
# difference between consecutive values after unwrap.
random_data = [random.randint(0, 9) for _ in range(100)]
for mod in range(1, 100):
random_data_unwrapped_mod = misc.Unwrap(random_data, mod)
for mod in range(1, 100):
random_data_unwrapped_mod = misc.Unwrap(random_data, mod)
for (a, b) in zip(random_data_unwrapped_mod,
random_data_unwrapped_mod[1:]):
self.assertTrue(-mod / 2 <= b - a < mod / 2)
for (a, b) in zip(random_data_unwrapped_mod,
random_data_unwrapped_mod[1:]):
self.assertTrue(-mod / 2 <= b - a < mod / 2)
def testRandomlyDataShouldNotChangeAfterUnwrap(self):
random_data = [random.randint(0, 9) for _ in range(100)]
random_data_copy = random_data[:]
for mod in range(1, 100):
_ = misc.Unwrap(random_data, mod)
def testRandomlyDataShouldNotChangeAfterUnwrap(self):
random_data = [random.randint(0, 9) for _ in range(100)]
random_data_copy = random_data[:]
for mod in range(1, 100):
_ = misc.Unwrap(random_data, mod)
self.assertEqual(random_data, random_data_copy)
self.assertEqual(random_data, random_data_copy)
if __name__ == "__main__":
unittest.main()
unittest.main()

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Parses protobuf RTC dumps."""
from __future__ import division
@ -14,26 +13,26 @@ import pyproto.logging.rtc_event_log.rtc_event_log_pb2 as rtc_pb
class DataPoint(object):
"""Simple container class for RTP events."""
"""Simple container class for RTP events."""
def __init__(self, rtp_header_str, packet_size,
arrival_timestamp_us, incoming):
"""Builds a data point by parsing an RTP header, size and arrival time.
def __init__(self, rtp_header_str, packet_size, arrival_timestamp_us,
incoming):
"""Builds a data point by parsing an RTP header, size and arrival time.
RTP header structure is defined in RFC 3550 section 5.1.
"""
self.size = packet_size
self.arrival_timestamp_ms = arrival_timestamp_us / 1000
self.incoming = incoming
header = struct.unpack_from("!HHII", rtp_header_str, 0)
(first2header_bytes, self.sequence_number, self.timestamp,
self.ssrc) = header
self.payload_type = first2header_bytes & 0b01111111
self.marker_bit = (first2header_bytes & 0b10000000) >> 7
self.size = packet_size
self.arrival_timestamp_ms = arrival_timestamp_us / 1000
self.incoming = incoming
header = struct.unpack_from("!HHII", rtp_header_str, 0)
(first2header_bytes, self.sequence_number, self.timestamp,
self.ssrc) = header
self.payload_type = first2header_bytes & 0b01111111
self.marker_bit = (first2header_bytes & 0b10000000) >> 7
def ParseProtobuf(file_path):
"""Parses RTC event log from protobuf file.
"""Parses RTC event log from protobuf file.
Args:
file_path: path to protobuf file of RTC event stream
@ -41,12 +40,12 @@ def ParseProtobuf(file_path):
Returns:
all RTP packet events from the event stream as a list of DataPoints
"""
event_stream = rtc_pb.EventStream()
with open(file_path, "rb") as f:
event_stream.ParseFromString(f.read())
event_stream = rtc_pb.EventStream()
with open(file_path, "rb") as f:
event_stream.ParseFromString(f.read())
return [DataPoint(event.rtp_packet.header,
event.rtp_packet.packet_length,
event.timestamp_us, event.rtp_packet.incoming)
for event in event_stream.stream
if event.HasField("rtp_packet")]
return [
DataPoint(event.rtp_packet.header, event.rtp_packet.packet_length,
event.timestamp_us, event.rtp_packet.incoming)
for event in event_stream.stream if event.HasField("rtp_packet")
]

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Displays statistics and plots graphs from RTC protobuf dump."""
from __future__ import division
@ -24,13 +23,13 @@ import pb_parse
class RTPStatistics(object):
"""Has methods for calculating and plotting RTP stream statistics."""
"""Has methods for calculating and plotting RTP stream statistics."""
BANDWIDTH_SMOOTHING_WINDOW_SIZE = 10
PLOT_RESOLUTION_MS = 50
BANDWIDTH_SMOOTHING_WINDOW_SIZE = 10
PLOT_RESOLUTION_MS = 50
def __init__(self, data_points):
"""Initializes object with data_points and computes simple statistics.
def __init__(self, data_points):
"""Initializes object with data_points and computes simple statistics.
Computes percentages of number of packets and packet sizes by
SSRC.
@ -41,238 +40,245 @@ class RTPStatistics(object):
"""
self.data_points = data_points
self.ssrc_frequencies = misc.NormalizeCounter(
collections.Counter([pt.ssrc for pt in self.data_points]))
self.ssrc_size_table = misc.SsrcNormalizedSizeTable(self.data_points)
self.bandwidth_kbps = None
self.smooth_bw_kbps = None
self.data_points = data_points
self.ssrc_frequencies = misc.NormalizeCounter(
collections.Counter([pt.ssrc for pt in self.data_points]))
self.ssrc_size_table = misc.SsrcNormalizedSizeTable(self.data_points)
self.bandwidth_kbps = None
self.smooth_bw_kbps = None
def PrintHeaderStatistics(self):
print("{:>6}{:>14}{:>14}{:>6}{:>6}{:>3}{:>11}".format(
"SeqNo", "TimeStamp", "SendTime", "Size", "PT", "M", "SSRC"))
for point in self.data_points:
print("{:>6}{:>14}{:>14}{:>6}{:>6}{:>3}{:>11}".format(
point.sequence_number, point.timestamp,
int(point.arrival_timestamp_ms), point.size, point.payload_type,
point.marker_bit, "0x{:x}".format(point.ssrc)))
def PrintHeaderStatistics(self):
print("{:>6}{:>14}{:>14}{:>6}{:>6}{:>3}{:>11}".format(
"SeqNo", "TimeStamp", "SendTime", "Size", "PT", "M", "SSRC"))
for point in self.data_points:
print("{:>6}{:>14}{:>14}{:>6}{:>6}{:>3}{:>11}".format(
point.sequence_number, point.timestamp,
int(point.arrival_timestamp_ms), point.size,
point.payload_type, point.marker_bit,
"0x{:x}".format(point.ssrc)))
def PrintSsrcInfo(self, ssrc_id, ssrc):
"""Prints packet and size statistics for a given SSRC.
def PrintSsrcInfo(self, ssrc_id, ssrc):
"""Prints packet and size statistics for a given SSRC.
Args:
ssrc_id: textual identifier of SSRC printed beside statistics for it.
ssrc: SSRC by which to filter data and display statistics
"""
filtered_ssrc = [point for point in self.data_points if point.ssrc
== ssrc]
payloads = misc.NormalizeCounter(
collections.Counter([point.payload_type for point in
filtered_ssrc]))
filtered_ssrc = [
point for point in self.data_points if point.ssrc == ssrc
]
payloads = misc.NormalizeCounter(
collections.Counter(
[point.payload_type for point in filtered_ssrc]))
payload_info = "payload type(s): {}".format(
", ".join(str(payload) for payload in payloads))
print("{} 0x{:x} {}, {:.2f}% packets, {:.2f}% data".format(
ssrc_id, ssrc, payload_info, self.ssrc_frequencies[ssrc] * 100,
self.ssrc_size_table[ssrc] * 100))
print(" packet sizes:")
(bin_counts, bin_bounds) = numpy.histogram([point.size for point in
filtered_ssrc], bins=5,
density=False)
bin_proportions = bin_counts / sum(bin_counts)
print("\n".join([
" {:.1f} - {:.1f}: {:.2f}%".format(bin_bounds[i], bin_bounds[i + 1],
bin_proportions[i] * 100)
for i in range(len(bin_proportions))
]))
payload_info = "payload type(s): {}".format(", ".join(
str(payload) for payload in payloads))
print("{} 0x{:x} {}, {:.2f}% packets, {:.2f}% data".format(
ssrc_id, ssrc, payload_info, self.ssrc_frequencies[ssrc] * 100,
self.ssrc_size_table[ssrc] * 100))
print(" packet sizes:")
(bin_counts,
bin_bounds) = numpy.histogram([point.size for point in filtered_ssrc],
bins=5,
density=False)
bin_proportions = bin_counts / sum(bin_counts)
print("\n".join([
" {:.1f} - {:.1f}: {:.2f}%".format(bin_bounds[i],
bin_bounds[i + 1],
bin_proportions[i] * 100)
for i in range(len(bin_proportions))
]))
def ChooseSsrc(self):
"""Queries user for SSRC."""
def ChooseSsrc(self):
"""Queries user for SSRC."""
if len(self.ssrc_frequencies) == 1:
chosen_ssrc = self.ssrc_frequencies.keys()[0]
self.PrintSsrcInfo("", chosen_ssrc)
return chosen_ssrc
if len(self.ssrc_frequencies) == 1:
chosen_ssrc = self.ssrc_frequencies.keys()[0]
self.PrintSsrcInfo("", chosen_ssrc)
return chosen_ssrc
ssrc_is_incoming = misc.SsrcDirections(self.data_points)
incoming = [ssrc for ssrc in ssrc_is_incoming if ssrc_is_incoming[ssrc]]
outgoing = [ssrc for ssrc in ssrc_is_incoming if not ssrc_is_incoming[ssrc]]
ssrc_is_incoming = misc.SsrcDirections(self.data_points)
incoming = [
ssrc for ssrc in ssrc_is_incoming if ssrc_is_incoming[ssrc]
]
outgoing = [
ssrc for ssrc in ssrc_is_incoming if not ssrc_is_incoming[ssrc]
]
print("\nIncoming:\n")
for (i, ssrc) in enumerate(incoming):
self.PrintSsrcInfo(i, ssrc)
print("\nIncoming:\n")
for (i, ssrc) in enumerate(incoming):
self.PrintSsrcInfo(i, ssrc)
print("\nOutgoing:\n")
for (i, ssrc) in enumerate(outgoing):
self.PrintSsrcInfo(i + len(incoming), ssrc)
print("\nOutgoing:\n")
for (i, ssrc) in enumerate(outgoing):
self.PrintSsrcInfo(i + len(incoming), ssrc)
while True:
chosen_index = int(misc.get_input("choose one> "))
if 0 <= chosen_index < len(self.ssrc_frequencies):
return (incoming + outgoing)[chosen_index]
else:
print("Invalid index!")
while True:
chosen_index = int(misc.get_input("choose one> "))
if 0 <= chosen_index < len(self.ssrc_frequencies):
return (incoming + outgoing)[chosen_index]
else:
print("Invalid index!")
def FilterSsrc(self, chosen_ssrc):
"""Filters and wraps data points.
def FilterSsrc(self, chosen_ssrc):
"""Filters and wraps data points.
Removes data points with `ssrc != chosen_ssrc`. Unwraps sequence
numbers and timestamps for the chosen selection.
"""
self.data_points = [point for point in self.data_points if
point.ssrc == chosen_ssrc]
unwrapped_sequence_numbers = misc.Unwrap(
[point.sequence_number for point in self.data_points], 2**16 - 1)
for (data_point, sequence_number) in zip(self.data_points,
unwrapped_sequence_numbers):
data_point.sequence_number = sequence_number
self.data_points = [
point for point in self.data_points if point.ssrc == chosen_ssrc
]
unwrapped_sequence_numbers = misc.Unwrap(
[point.sequence_number for point in self.data_points], 2**16 - 1)
for (data_point, sequence_number) in zip(self.data_points,
unwrapped_sequence_numbers):
data_point.sequence_number = sequence_number
unwrapped_timestamps = misc.Unwrap([point.timestamp for point in
self.data_points], 2**32 - 1)
unwrapped_timestamps = misc.Unwrap(
[point.timestamp for point in self.data_points], 2**32 - 1)
for (data_point, timestamp) in zip(self.data_points,
unwrapped_timestamps):
data_point.timestamp = timestamp
for (data_point, timestamp) in zip(self.data_points,
unwrapped_timestamps):
data_point.timestamp = timestamp
def PrintSequenceNumberStatistics(self):
seq_no_set = set(point.sequence_number for point in
self.data_points)
missing_sequence_numbers = max(seq_no_set) - min(seq_no_set) + (
1 - len(seq_no_set))
print("Missing sequence numbers: {} out of {} ({:.2f}%)".format(
missing_sequence_numbers,
len(seq_no_set),
100 * missing_sequence_numbers / len(seq_no_set)
))
print("Duplicated packets: {}".format(len(self.data_points) -
len(seq_no_set)))
print("Reordered packets: {}".format(
misc.CountReordered([point.sequence_number for point in
self.data_points])))
def PrintSequenceNumberStatistics(self):
seq_no_set = set(point.sequence_number for point in self.data_points)
missing_sequence_numbers = max(seq_no_set) - min(seq_no_set) + (
1 - len(seq_no_set))
print("Missing sequence numbers: {} out of {} ({:.2f}%)".format(
missing_sequence_numbers, len(seq_no_set),
100 * missing_sequence_numbers / len(seq_no_set)))
print("Duplicated packets: {}".format(
len(self.data_points) - len(seq_no_set)))
print("Reordered packets: {}".format(
misc.CountReordered(
[point.sequence_number for point in self.data_points])))
def EstimateFrequency(self, always_query_sample_rate):
"""Estimates frequency and updates data.
def EstimateFrequency(self, always_query_sample_rate):
"""Estimates frequency and updates data.
Guesses the most probable frequency by looking at changes in
timestamps (RFC 3550 section 5.1), calculates clock drifts and
sending time of packets. Updates `self.data_points` with changes
in delay and send time.
"""
delta_timestamp = (self.data_points[-1].timestamp -
self.data_points[0].timestamp)
delta_arr_timestamp = float((self.data_points[-1].arrival_timestamp_ms -
self.data_points[0].arrival_timestamp_ms))
freq_est = delta_timestamp / delta_arr_timestamp
delta_timestamp = (self.data_points[-1].timestamp -
self.data_points[0].timestamp)
delta_arr_timestamp = float(
(self.data_points[-1].arrival_timestamp_ms -
self.data_points[0].arrival_timestamp_ms))
freq_est = delta_timestamp / delta_arr_timestamp
freq_vec = [8, 16, 32, 48, 90]
freq = None
for f in freq_vec:
if abs((freq_est - f) / f) < 0.05:
freq = f
freq_vec = [8, 16, 32, 48, 90]
freq = None
for f in freq_vec:
if abs((freq_est - f) / f) < 0.05:
freq = f
print("Estimated frequency: {:.3f}kHz".format(freq_est))
if freq is None or always_query_sample_rate:
if not always_query_sample_rate:
print ("Frequency could not be guessed.", end=" ")
freq = int(misc.get_input("Input frequency (in kHz)> "))
else:
print("Guessed frequency: {}kHz".format(freq))
print("Estimated frequency: {:.3f}kHz".format(freq_est))
if freq is None or always_query_sample_rate:
if not always_query_sample_rate:
print("Frequency could not be guessed.", end=" ")
freq = int(misc.get_input("Input frequency (in kHz)> "))
else:
print("Guessed frequency: {}kHz".format(freq))
for point in self.data_points:
point.real_send_time_ms = (point.timestamp -
self.data_points[0].timestamp) / freq
point.delay = point.arrival_timestamp_ms - point.real_send_time_ms
for point in self.data_points:
point.real_send_time_ms = (point.timestamp -
self.data_points[0].timestamp) / freq
point.delay = point.arrival_timestamp_ms - point.real_send_time_ms
def PrintDurationStatistics(self):
"""Prints delay, clock drift and bitrate statistics."""
def PrintDurationStatistics(self):
"""Prints delay, clock drift and bitrate statistics."""
min_delay = min(point.delay for point in self.data_points)
min_delay = min(point.delay for point in self.data_points)
for point in self.data_points:
point.absdelay = point.delay - min_delay
for point in self.data_points:
point.absdelay = point.delay - min_delay
stream_duration_sender = self.data_points[-1].real_send_time_ms / 1000
print("Stream duration at sender: {:.1f} seconds".format(
stream_duration_sender
))
stream_duration_sender = self.data_points[-1].real_send_time_ms / 1000
print("Stream duration at sender: {:.1f} seconds".format(
stream_duration_sender))
arrival_timestamps_ms = [point.arrival_timestamp_ms for point in
self.data_points]
stream_duration_receiver = (max(arrival_timestamps_ms) -
min(arrival_timestamps_ms)) / 1000
print("Stream duration at receiver: {:.1f} seconds".format(
stream_duration_receiver
))
arrival_timestamps_ms = [
point.arrival_timestamp_ms for point in self.data_points
]
stream_duration_receiver = (max(arrival_timestamps_ms) -
min(arrival_timestamps_ms)) / 1000
print("Stream duration at receiver: {:.1f} seconds".format(
stream_duration_receiver))
print("Clock drift: {:.2f}%".format(
100 * (stream_duration_receiver / stream_duration_sender - 1)
))
print("Clock drift: {:.2f}%".format(
100 * (stream_duration_receiver / stream_duration_sender - 1)))
total_size = sum(point.size for point in self.data_points) * 8 / 1000
print("Send average bitrate: {:.2f} kbps".format(
total_size / stream_duration_sender))
total_size = sum(point.size for point in self.data_points) * 8 / 1000
print("Send average bitrate: {:.2f} kbps".format(
total_size / stream_duration_sender))
print("Receive average bitrate: {:.2f} kbps".format(
total_size / stream_duration_receiver))
print("Receive average bitrate: {:.2f} kbps".format(
total_size / stream_duration_receiver))
def RemoveReordered(self):
last = self.data_points[0]
data_points_ordered = [last]
for point in self.data_points[1:]:
if point.sequence_number > last.sequence_number and (
point.real_send_time_ms > last.real_send_time_ms):
data_points_ordered.append(point)
last = point
self.data_points = data_points_ordered
def RemoveReordered(self):
last = self.data_points[0]
data_points_ordered = [last]
for point in self.data_points[1:]:
if point.sequence_number > last.sequence_number and (
point.real_send_time_ms > last.real_send_time_ms):
data_points_ordered.append(point)
last = point
self.data_points = data_points_ordered
def ComputeBandwidth(self):
"""Computes bandwidth averaged over several consecutive packets.
def ComputeBandwidth(self):
"""Computes bandwidth averaged over several consecutive packets.
The number of consecutive packets used in the average is
BANDWIDTH_SMOOTHING_WINDOW_SIZE. Averaging is done with
numpy.correlate.
"""
start_ms = self.data_points[0].real_send_time_ms
stop_ms = self.data_points[-1].real_send_time_ms
(self.bandwidth_kbps, _) = numpy.histogram(
[point.real_send_time_ms for point in self.data_points],
bins=numpy.arange(start_ms, stop_ms,
RTPStatistics.PLOT_RESOLUTION_MS),
weights=[point.size * 8 / RTPStatistics.PLOT_RESOLUTION_MS
for point in self.data_points]
)
correlate_filter = (numpy.ones(
RTPStatistics.BANDWIDTH_SMOOTHING_WINDOW_SIZE) /
RTPStatistics.BANDWIDTH_SMOOTHING_WINDOW_SIZE)
self.smooth_bw_kbps = numpy.correlate(self.bandwidth_kbps, correlate_filter)
start_ms = self.data_points[0].real_send_time_ms
stop_ms = self.data_points[-1].real_send_time_ms
(self.bandwidth_kbps, _) = numpy.histogram(
[point.real_send_time_ms for point in self.data_points],
bins=numpy.arange(start_ms, stop_ms,
RTPStatistics.PLOT_RESOLUTION_MS),
weights=[
point.size * 8 / RTPStatistics.PLOT_RESOLUTION_MS
for point in self.data_points
])
correlate_filter = (
numpy.ones(RTPStatistics.BANDWIDTH_SMOOTHING_WINDOW_SIZE) /
RTPStatistics.BANDWIDTH_SMOOTHING_WINDOW_SIZE)
self.smooth_bw_kbps = numpy.correlate(self.bandwidth_kbps,
correlate_filter)
def PlotStatistics(self):
"""Plots changes in delay and average bandwidth."""
def PlotStatistics(self):
"""Plots changes in delay and average bandwidth."""
start_ms = self.data_points[0].real_send_time_ms
stop_ms = self.data_points[-1].real_send_time_ms
time_axis = numpy.arange(start_ms / 1000, stop_ms / 1000,
RTPStatistics.PLOT_RESOLUTION_MS / 1000)
start_ms = self.data_points[0].real_send_time_ms
stop_ms = self.data_points[-1].real_send_time_ms
time_axis = numpy.arange(start_ms / 1000, stop_ms / 1000,
RTPStatistics.PLOT_RESOLUTION_MS / 1000)
delay = CalculateDelay(start_ms, stop_ms,
RTPStatistics.PLOT_RESOLUTION_MS,
self.data_points)
delay = CalculateDelay(start_ms, stop_ms,
RTPStatistics.PLOT_RESOLUTION_MS,
self.data_points)
plt.figure(1)
plt.plot(time_axis, delay[:len(time_axis)])
plt.xlabel("Send time [s]")
plt.ylabel("Relative transport delay [ms]")
plt.figure(1)
plt.plot(time_axis, delay[:len(time_axis)])
plt.xlabel("Send time [s]")
plt.ylabel("Relative transport delay [ms]")
plt.figure(2)
plt.plot(time_axis[:len(self.smooth_bw_kbps)], self.smooth_bw_kbps)
plt.xlabel("Send time [s]")
plt.ylabel("Bandwidth [kbps]")
plt.figure(2)
plt.plot(time_axis[:len(self.smooth_bw_kbps)], self.smooth_bw_kbps)
plt.xlabel("Send time [s]")
plt.ylabel("Bandwidth [kbps]")
plt.show()
plt.show()
def CalculateDelay(start, stop, step, points):
"""Quantizes the time coordinates for the delay.
"""Quantizes the time coordinates for the delay.
Quantizes points by rounding the timestamps downwards to the nearest
point in the time sequence start, start+step, start+2*step... Takes
@ -280,61 +286,67 @@ def CalculateDelay(start, stop, step, points):
masked array, in which time points with no value are masked.
"""
grouped_delays = [[] for _ in numpy.arange(start, stop + step, step)]
rounded_value_index = lambda x: int((x - start) / step)
for point in points:
grouped_delays[rounded_value_index(point.real_send_time_ms)
].append(point.absdelay)
regularized_delays = [numpy.average(arr) if arr else -1 for arr in
grouped_delays]
return numpy.ma.masked_values(regularized_delays, -1)
grouped_delays = [[] for _ in numpy.arange(start, stop + step, step)]
rounded_value_index = lambda x: int((x - start) / step)
for point in points:
grouped_delays[rounded_value_index(point.real_send_time_ms)].append(
point.absdelay)
regularized_delays = [
numpy.average(arr) if arr else -1 for arr in grouped_delays
]
return numpy.ma.masked_values(regularized_delays, -1)
def main():
usage = "Usage: %prog [options] <filename of rtc event log>"
parser = optparse.OptionParser(usage=usage)
parser.add_option("--dump_header_to_stdout",
default=False, action="store_true",
help="print header info to stdout; similar to rtp_analyze")
parser.add_option("--query_sample_rate",
default=False, action="store_true",
help="always query user for real sample rate")
usage = "Usage: %prog [options] <filename of rtc event log>"
parser = optparse.OptionParser(usage=usage)
parser.add_option(
"--dump_header_to_stdout",
default=False,
action="store_true",
help="print header info to stdout; similar to rtp_analyze")
parser.add_option("--query_sample_rate",
default=False,
action="store_true",
help="always query user for real sample rate")
parser.add_option("--working_directory",
default=None, action="store",
help="directory in which to search for relative paths")
parser.add_option("--working_directory",
default=None,
action="store",
help="directory in which to search for relative paths")
(options, args) = parser.parse_args()
(options, args) = parser.parse_args()
if len(args) < 1:
parser.print_help()
sys.exit(0)
if len(args) < 1:
parser.print_help()
sys.exit(0)
input_file = args[0]
input_file = args[0]
if options.working_directory and not os.path.isabs(input_file):
input_file = os.path.join(options.working_directory, input_file)
if options.working_directory and not os.path.isabs(input_file):
input_file = os.path.join(options.working_directory, input_file)
data_points = pb_parse.ParseProtobuf(input_file)
rtp_stats = RTPStatistics(data_points)
data_points = pb_parse.ParseProtobuf(input_file)
rtp_stats = RTPStatistics(data_points)
if options.dump_header_to_stdout:
print("Printing header info to stdout.", file=sys.stderr)
rtp_stats.PrintHeaderStatistics()
sys.exit(0)
if options.dump_header_to_stdout:
print("Printing header info to stdout.", file=sys.stderr)
rtp_stats.PrintHeaderStatistics()
sys.exit(0)
chosen_ssrc = rtp_stats.ChooseSsrc()
print("Chosen SSRC: 0X{:X}".format(chosen_ssrc))
chosen_ssrc = rtp_stats.ChooseSsrc()
print("Chosen SSRC: 0X{:X}".format(chosen_ssrc))
rtp_stats.FilterSsrc(chosen_ssrc)
rtp_stats.FilterSsrc(chosen_ssrc)
print("Statistics:")
rtp_stats.PrintSequenceNumberStatistics()
rtp_stats.EstimateFrequency(options.query_sample_rate)
rtp_stats.PrintDurationStatistics()
rtp_stats.RemoveReordered()
rtp_stats.ComputeBandwidth()
rtp_stats.PlotStatistics()
print("Statistics:")
rtp_stats.PrintSequenceNumberStatistics()
rtp_stats.EstimateFrequency(options.query_sample_rate)
rtp_stats.PrintDurationStatistics()
rtp_stats.RemoveReordered()
rtp_stats.ComputeBandwidth()
rtp_stats.PlotStatistics()
if __name__ == "__main__":
main()
main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Run the tests with
python rtp_analyzer_test.py
@ -19,43 +18,43 @@ import unittest
MISSING_NUMPY = False # pylint: disable=invalid-name
try:
import numpy
import rtp_analyzer
import numpy
import rtp_analyzer
except ImportError:
MISSING_NUMPY = True
MISSING_NUMPY = True
FakePoint = collections.namedtuple("FakePoint",
["real_send_time_ms", "absdelay"])
class TestDelay(unittest.TestCase):
def AssertMaskEqual(self, masked_array, data, mask):
self.assertEqual(list(masked_array.data), data)
def AssertMaskEqual(self, masked_array, data, mask):
self.assertEqual(list(masked_array.data), data)
if isinstance(masked_array.mask, numpy.bool_):
array_mask = masked_array.mask
else:
array_mask = list(masked_array.mask)
self.assertEqual(array_mask, mask)
if isinstance(masked_array.mask, numpy.bool_):
array_mask = masked_array.mask
else:
array_mask = list(masked_array.mask)
self.assertEqual(array_mask, mask)
def testCalculateDelaySimple(self):
points = [FakePoint(0, 0), FakePoint(1, 0)]
mask = rtp_analyzer.CalculateDelay(0, 1, 1, points)
self.AssertMaskEqual(mask, [0, 0], False)
def testCalculateDelaySimple(self):
points = [FakePoint(0, 0), FakePoint(1, 0)]
mask = rtp_analyzer.CalculateDelay(0, 1, 1, points)
self.AssertMaskEqual(mask, [0, 0], False)
def testCalculateDelayMissing(self):
points = [FakePoint(0, 0), FakePoint(2, 0)]
mask = rtp_analyzer.CalculateDelay(0, 2, 1, points)
self.AssertMaskEqual(mask, [0, -1, 0], [False, True, False])
def testCalculateDelayMissing(self):
points = [FakePoint(0, 0), FakePoint(2, 0)]
mask = rtp_analyzer.CalculateDelay(0, 2, 1, points)
self.AssertMaskEqual(mask, [0, -1, 0], [False, True, False])
def testCalculateDelayBorders(self):
points = [FakePoint(0, 0), FakePoint(2, 0)]
mask = rtp_analyzer.CalculateDelay(0, 3, 2, points)
self.AssertMaskEqual(mask, [0, 0, -1], [False, False, True])
def testCalculateDelayBorders(self):
points = [FakePoint(0, 0), FakePoint(2, 0)]
mask = rtp_analyzer.CalculateDelay(0, 3, 2, points)
self.AssertMaskEqual(mask, [0, 0, -1], [False, False, True])
if __name__ == "__main__":
if MISSING_NUMPY:
print "Missing numpy, skipping test."
else:
unittest.main()
if MISSING_NUMPY:
print "Missing numpy, skipping test."
else:
unittest.main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Builds the AppRTC collider using the golang toolchain.
The golang toolchain is downloaded by download_apprtc.py. We use that here
@ -24,44 +23,44 @@ import sys
import utils
USAGE_STR = "Usage: {} <apprtc_dir> <go_dir> <output_dir>"
def _ConfigureApprtcServerToDeveloperMode(app_yaml_path):
for line in fileinput.input(app_yaml_path, inplace=True):
# We can't click past these in browser-based tests, so disable them.
line = line.replace('BYPASS_JOIN_CONFIRMATION: false',
'BYPASS_JOIN_CONFIRMATION: true')
sys.stdout.write(line)
for line in fileinput.input(app_yaml_path, inplace=True):
# We can't click past these in browser-based tests, so disable them.
line = line.replace('BYPASS_JOIN_CONFIRMATION: false',
'BYPASS_JOIN_CONFIRMATION: true')
sys.stdout.write(line)
def main(argv):
if len(argv) != 4:
return USAGE_STR.format(argv[0])
if len(argv) != 4:
return USAGE_STR.format(argv[0])
apprtc_dir = os.path.abspath(argv[1])
go_root_dir = os.path.abspath(argv[2])
golang_workspace = os.path.abspath(argv[3])
apprtc_dir = os.path.abspath(argv[1])
go_root_dir = os.path.abspath(argv[2])
golang_workspace = os.path.abspath(argv[3])
app_yaml_path = os.path.join(apprtc_dir, 'out', 'app_engine', 'app.yaml')
_ConfigureApprtcServerToDeveloperMode(app_yaml_path)
app_yaml_path = os.path.join(apprtc_dir, 'out', 'app_engine', 'app.yaml')
_ConfigureApprtcServerToDeveloperMode(app_yaml_path)
utils.RemoveDirectory(golang_workspace)
utils.RemoveDirectory(golang_workspace)
collider_dir = os.path.join(apprtc_dir, 'src', 'collider')
shutil.copytree(collider_dir, os.path.join(golang_workspace, 'src'))
collider_dir = os.path.join(apprtc_dir, 'src', 'collider')
shutil.copytree(collider_dir, os.path.join(golang_workspace, 'src'))
golang_path = os.path.join(go_root_dir, 'bin',
'go' + utils.GetExecutableExtension())
golang_env = os.environ.copy()
golang_env['GOROOT'] = go_root_dir
golang_env['GOPATH'] = golang_workspace
collider_out = os.path.join(golang_workspace,
'collidermain' + utils.GetExecutableExtension())
subprocess.check_call([golang_path, 'build', '-o', collider_out,
'collidermain'], env=golang_env)
golang_path = os.path.join(go_root_dir, 'bin',
'go' + utils.GetExecutableExtension())
golang_env = os.environ.copy()
golang_env['GOROOT'] = go_root_dir
golang_env['GOPATH'] = golang_workspace
collider_out = os.path.join(
golang_workspace, 'collidermain' + utils.GetExecutableExtension())
subprocess.check_call(
[golang_path, 'build', '-o', collider_out, 'collidermain'],
env=golang_env)
if __name__ == '__main__':
sys.exit(main(sys.argv))
sys.exit(main(sys.argv))

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Downloads prebuilt AppRTC and Go from WebRTC storage and unpacks it.
Requires that depot_tools is installed and in the PATH.
@ -21,38 +20,37 @@ import sys
import utils
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
def _GetGoArchivePathForPlatform():
archive_extension = 'zip' if utils.GetPlatform() == 'win' else 'tar.gz'
return os.path.join(utils.GetPlatform(), 'go.%s' % archive_extension)
archive_extension = 'zip' if utils.GetPlatform() == 'win' else 'tar.gz'
return os.path.join(utils.GetPlatform(), 'go.%s' % archive_extension)
def main(argv):
if len(argv) > 2:
return 'Usage: %s [output_dir]' % argv[0]
if len(argv) > 2:
return 'Usage: %s [output_dir]' % argv[0]
output_dir = os.path.abspath(argv[1]) if len(argv) > 1 else None
output_dir = os.path.abspath(argv[1]) if len(argv) > 1 else None
apprtc_zip_path = os.path.join(SCRIPT_DIR, 'prebuilt_apprtc.zip')
if os.path.isfile(apprtc_zip_path + '.sha1'):
utils.DownloadFilesFromGoogleStorage(SCRIPT_DIR, auto_platform=False)
apprtc_zip_path = os.path.join(SCRIPT_DIR, 'prebuilt_apprtc.zip')
if os.path.isfile(apprtc_zip_path + '.sha1'):
utils.DownloadFilesFromGoogleStorage(SCRIPT_DIR, auto_platform=False)
if output_dir is not None:
utils.RemoveDirectory(os.path.join(output_dir, 'apprtc'))
utils.UnpackArchiveTo(apprtc_zip_path, output_dir)
if output_dir is not None:
utils.RemoveDirectory(os.path.join(output_dir, 'apprtc'))
utils.UnpackArchiveTo(apprtc_zip_path, output_dir)
golang_path = os.path.join(SCRIPT_DIR, 'golang')
golang_zip_path = os.path.join(golang_path, _GetGoArchivePathForPlatform())
if os.path.isfile(golang_zip_path + '.sha1'):
utils.DownloadFilesFromGoogleStorage(golang_path)
golang_path = os.path.join(SCRIPT_DIR, 'golang')
golang_zip_path = os.path.join(golang_path, _GetGoArchivePathForPlatform())
if os.path.isfile(golang_zip_path + '.sha1'):
utils.DownloadFilesFromGoogleStorage(golang_path)
if output_dir is not None:
utils.RemoveDirectory(os.path.join(output_dir, 'go'))
utils.UnpackArchiveTo(golang_zip_path, output_dir)
if output_dir is not None:
utils.RemoveDirectory(os.path.join(output_dir, 'go'))
utils.UnpackArchiveTo(golang_zip_path, output_dir)
if __name__ == '__main__':
sys.exit(main(sys.argv))
sys.exit(main(sys.argv))

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""This script sets up AppRTC and its dependencies.
Requires that depot_tools is installed and in the PATH.
@ -19,27 +18,26 @@ import sys
import utils
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
def main(argv):
if len(argv) == 1:
return 'Usage %s <output_dir>' % argv[0]
if len(argv) == 1:
return 'Usage %s <output_dir>' % argv[0]
output_dir = os.path.abspath(argv[1])
output_dir = os.path.abspath(argv[1])
download_apprtc_path = os.path.join(SCRIPT_DIR, 'download_apprtc.py')
utils.RunSubprocessWithRetry([sys.executable, download_apprtc_path,
output_dir])
download_apprtc_path = os.path.join(SCRIPT_DIR, 'download_apprtc.py')
utils.RunSubprocessWithRetry(
[sys.executable, download_apprtc_path, output_dir])
build_apprtc_path = os.path.join(SCRIPT_DIR, 'build_apprtc.py')
apprtc_dir = os.path.join(output_dir, 'apprtc')
go_dir = os.path.join(output_dir, 'go')
collider_dir = os.path.join(output_dir, 'collider')
utils.RunSubprocessWithRetry([sys.executable, build_apprtc_path,
apprtc_dir, go_dir, collider_dir])
build_apprtc_path = os.path.join(SCRIPT_DIR, 'build_apprtc.py')
apprtc_dir = os.path.join(output_dir, 'apprtc')
go_dir = os.path.join(output_dir, 'go')
collider_dir = os.path.join(output_dir, 'collider')
utils.RunSubprocessWithRetry(
[sys.executable, build_apprtc_path, apprtc_dir, go_dir, collider_dir])
if __name__ == '__main__':
sys.exit(main(sys.argv))
sys.exit(main(sys.argv))

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Utilities for all our deps-management stuff."""
from __future__ import absolute_import
@ -23,36 +22,37 @@ import zipfile
def RunSubprocessWithRetry(cmd):
"""Invokes the subprocess and backs off exponentially on fail."""
for i in range(5):
try:
subprocess.check_call(cmd)
return
except subprocess.CalledProcessError as exception:
backoff = pow(2, i)
print('Got %s, retrying in %d seconds...' % (exception, backoff))
time.sleep(backoff)
"""Invokes the subprocess and backs off exponentially on fail."""
for i in range(5):
try:
subprocess.check_call(cmd)
return
except subprocess.CalledProcessError as exception:
backoff = pow(2, i)
print('Got %s, retrying in %d seconds...' % (exception, backoff))
time.sleep(backoff)
print('Giving up.')
raise exception
print('Giving up.')
raise exception
def DownloadFilesFromGoogleStorage(path, auto_platform=True):
print('Downloading files in %s...' % path)
print('Downloading files in %s...' % path)
extension = 'bat' if 'win32' in sys.platform else 'py'
cmd = ['download_from_google_storage.%s' % extension,
'--bucket=chromium-webrtc-resources',
'--directory', path]
if auto_platform:
cmd += ['--auto_platform', '--recursive']
subprocess.check_call(cmd)
extension = 'bat' if 'win32' in sys.platform else 'py'
cmd = [
'download_from_google_storage.%s' % extension,
'--bucket=chromium-webrtc-resources', '--directory', path
]
if auto_platform:
cmd += ['--auto_platform', '--recursive']
subprocess.check_call(cmd)
# Code partially copied from
# https://cs.chromium.org#chromium/build/scripts/common/chromium_utils.py
def RemoveDirectory(*path):
"""Recursively removes a directory, even if it's marked read-only.
"""Recursively removes a directory, even if it's marked read-only.
Remove the directory located at *path, if it exists.
@ -67,62 +67,63 @@ def RemoveDirectory(*path):
bit and try again, so we do that too. It's hand-waving, but sometimes it
works. :/
"""
file_path = os.path.join(*path)
print('Deleting `{}`.'.format(file_path))
if not os.path.exists(file_path):
print('`{}` does not exist.'.format(file_path))
return
file_path = os.path.join(*path)
print('Deleting `{}`.'.format(file_path))
if not os.path.exists(file_path):
print('`{}` does not exist.'.format(file_path))
return
if sys.platform == 'win32':
# Give up and use cmd.exe's rd command.
file_path = os.path.normcase(file_path)
for _ in range(3):
print('RemoveDirectory running %s' % (' '.join(
['cmd.exe', '/c', 'rd', '/q', '/s', file_path])))
if not subprocess.call(['cmd.exe', '/c', 'rd', '/q', '/s', file_path]):
break
print(' Failed')
time.sleep(3)
return
else:
shutil.rmtree(file_path, ignore_errors=True)
if sys.platform == 'win32':
# Give up and use cmd.exe's rd command.
file_path = os.path.normcase(file_path)
for _ in range(3):
print('RemoveDirectory running %s' %
(' '.join(['cmd.exe', '/c', 'rd', '/q', '/s', file_path])))
if not subprocess.call(
['cmd.exe', '/c', 'rd', '/q', '/s', file_path]):
break
print(' Failed')
time.sleep(3)
return
else:
shutil.rmtree(file_path, ignore_errors=True)
def UnpackArchiveTo(archive_path, output_dir):
extension = os.path.splitext(archive_path)[1]
if extension == '.zip':
_UnzipArchiveTo(archive_path, output_dir)
else:
_UntarArchiveTo(archive_path, output_dir)
extension = os.path.splitext(archive_path)[1]
if extension == '.zip':
_UnzipArchiveTo(archive_path, output_dir)
else:
_UntarArchiveTo(archive_path, output_dir)
def _UnzipArchiveTo(archive_path, output_dir):
print('Unzipping {} in {}.'.format(archive_path, output_dir))
zip_file = zipfile.ZipFile(archive_path)
try:
zip_file.extractall(output_dir)
finally:
zip_file.close()
print('Unzipping {} in {}.'.format(archive_path, output_dir))
zip_file = zipfile.ZipFile(archive_path)
try:
zip_file.extractall(output_dir)
finally:
zip_file.close()
def _UntarArchiveTo(archive_path, output_dir):
print('Untarring {} in {}.'.format(archive_path, output_dir))
tar_file = tarfile.open(archive_path, 'r:gz')
try:
tar_file.extractall(output_dir)
finally:
tar_file.close()
print('Untarring {} in {}.'.format(archive_path, output_dir))
tar_file = tarfile.open(archive_path, 'r:gz')
try:
tar_file.extractall(output_dir)
finally:
tar_file.close()
def GetPlatform():
if sys.platform.startswith('win'):
return 'win'
if sys.platform.startswith('linux'):
return 'linux'
if sys.platform.startswith('darwin'):
return 'mac'
raise Exception("Can't run on platform %s." % sys.platform)
if sys.platform.startswith('win'):
return 'win'
if sys.platform.startswith('linux'):
return 'linux'
if sys.platform.startswith('darwin'):
return 'mac'
raise Exception("Can't run on platform %s." % sys.platform)
def GetExecutableExtension():
return '.exe' if GetPlatform() == 'win' else ''
return '.exe' if GetPlatform() == 'win' else ''

View File

@ -6,23 +6,26 @@
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
def CheckChangeOnUpload(input_api, output_api):
results = []
results.extend(CheckPatchFormatted(input_api, output_api))
return results
results = []
results.extend(CheckPatchFormatted(input_api, output_api))
return results
def CheckPatchFormatted(input_api, output_api):
import git_cl
cmd = ['cl', 'format', '--dry-run', input_api.PresubmitLocalPath()]
code, _ = git_cl.RunGitWithCode(cmd, suppress_stderr=True)
if code == 2:
short_path = input_api.basename(input_api.PresubmitLocalPath())
full_path = input_api.os_path.relpath(input_api.PresubmitLocalPath(),
input_api.change.RepositoryRoot())
return [output_api.PresubmitPromptWarning(
'The %s directory requires source formatting. '
'Please run git cl format %s' %
(short_path, full_path))]
# As this is just a warning, ignore all other errors if the user
# happens to have a broken clang-format, doesn't use git, etc etc.
return []
import git_cl
cmd = ['cl', 'format', '--dry-run', input_api.PresubmitLocalPath()]
code, _ = git_cl.RunGitWithCode(cmd, suppress_stderr=True)
if code == 2:
short_path = input_api.basename(input_api.PresubmitLocalPath())
full_path = input_api.os_path.relpath(
input_api.PresubmitLocalPath(), input_api.change.RepositoryRoot())
return [
output_api.PresubmitPromptWarning(
'The %s directory requires source formatting. '
'Please run git cl format %s' % (short_path, full_path))
]
# As this is just a warning, ignore all other errors if the user
# happens to have a broken clang-format, doesn't use git, etc etc.
return []

View File

@ -8,39 +8,43 @@
def _LicenseHeader(input_api):
"""Returns the license header regexp."""
# Accept any year number from 2003 to the current year
current_year = int(input_api.time.strftime('%Y'))
allowed_years = (str(s) for s in reversed(xrange(2003, current_year + 1)))
years_re = '(' + '|'.join(allowed_years) + ')'
license_header = (
r'.*? Copyright( \(c\))? %(year)s The WebRTC [Pp]roject [Aa]uthors\. '
"""Returns the license header regexp."""
# Accept any year number from 2003 to the current year
current_year = int(input_api.time.strftime('%Y'))
allowed_years = (str(s) for s in reversed(xrange(2003, current_year + 1)))
years_re = '(' + '|'.join(allowed_years) + ')'
license_header = (
r'.*? Copyright( \(c\))? %(year)s The WebRTC [Pp]roject [Aa]uthors\. '
r'All [Rr]ights [Rr]eserved\.\n'
r'.*?\n'
r'.*? Use of this source code is governed by a BSD-style license\n'
r'.*? that can be found in the LICENSE file in the root of the source\n'
r'.*? tree\. An additional intellectual property rights grant can be '
r'.*?\n'
r'.*? Use of this source code is governed by a BSD-style license\n'
r'.*? that can be found in the LICENSE file in the root of the source\n'
r'.*? tree\. An additional intellectual property rights grant can be '
r'found\n'
r'.*? in the file PATENTS\. All contributing project authors may\n'
r'.*? be found in the AUTHORS file in the root of the source tree\.\n'
) % {
'year': years_re,
}
return license_header
r'.*? in the file PATENTS\. All contributing project authors may\n'
r'.*? be found in the AUTHORS file in the root of the source tree\.\n'
) % {
'year': years_re,
}
return license_header
def _CommonChecks(input_api, output_api):
"""Checks common to both upload and commit."""
results = []
results.extend(input_api.canned_checks.CheckLicense(
input_api, output_api, _LicenseHeader(input_api)))
return results
"""Checks common to both upload and commit."""
results = []
results.extend(
input_api.canned_checks.CheckLicense(input_api, output_api,
_LicenseHeader(input_api)))
return results
def CheckChangeOnUpload(input_api, output_api):
results = []
results.extend(_CommonChecks(input_api, output_api))
return results
results = []
results.extend(_CommonChecks(input_api, output_api))
return results
def CheckChangeOnCommit(input_api, output_api):
results = []
results.extend(_CommonChecks(input_api, output_api))
return results
results = []
results.extend(_CommonChecks(input_api, output_api))
return results

View File

@ -7,7 +7,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Script to generate libwebrtc.aar for distribution.
The script has to be run from the root src folder.
@ -33,7 +32,6 @@ import sys
import tempfile
import zipfile
SCRIPT_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
SRC_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
DEFAULT_ARCHS = ['armeabi-v7a', 'arm64-v8a', 'x86', 'x86_64']
@ -41,8 +39,8 @@ NEEDED_SO_FILES = ['libjingle_peerconnection_so.so']
JAR_FILE = 'lib.java/sdk/android/libwebrtc.jar'
MANIFEST_FILE = 'sdk/android/AndroidManifest.xml'
TARGETS = [
'sdk/android:libwebrtc',
'sdk/android:libjingle_peerconnection_so',
'sdk/android:libwebrtc',
'sdk/android:libjingle_peerconnection_so',
]
sys.path.append(os.path.join(SCRIPT_DIR, '..', 'libs'))
@ -52,183 +50,209 @@ sys.path.append(os.path.join(SRC_DIR, 'build'))
import find_depot_tools
def _ParseArgs():
parser = argparse.ArgumentParser(description='libwebrtc.aar generator.')
parser.add_argument('--build-dir',
help='Build dir. By default will create and use temporary dir.')
parser.add_argument('--output', default='libwebrtc.aar',
help='Output file of the script.')
parser.add_argument('--arch', default=DEFAULT_ARCHS, nargs='*',
help='Architectures to build. Defaults to %(default)s.')
parser.add_argument('--use-goma', action='store_true', default=False,
help='Use goma.')
parser.add_argument('--verbose', action='store_true', default=False,
help='Debug logging.')
parser.add_argument('--extra-gn-args', default=[], nargs='*',
help="""Additional GN arguments to be used during Ninja generation.
parser = argparse.ArgumentParser(description='libwebrtc.aar generator.')
parser.add_argument(
'--build-dir',
help='Build dir. By default will create and use temporary dir.')
parser.add_argument('--output',
default='libwebrtc.aar',
help='Output file of the script.')
parser.add_argument(
'--arch',
default=DEFAULT_ARCHS,
nargs='*',
help='Architectures to build. Defaults to %(default)s.')
parser.add_argument('--use-goma',
action='store_true',
default=False,
help='Use goma.')
parser.add_argument('--verbose',
action='store_true',
default=False,
help='Debug logging.')
parser.add_argument(
'--extra-gn-args',
default=[],
nargs='*',
help="""Additional GN arguments to be used during Ninja generation.
These are passed to gn inside `--args` switch and
applied after any other arguments and will
override any values defined by the script.
Example of building debug aar file:
build_aar.py --extra-gn-args='is_debug=true'""")
parser.add_argument('--extra-ninja-switches', default=[], nargs='*',
help="""Additional Ninja switches to be used during compilation.
parser.add_argument(
'--extra-ninja-switches',
default=[],
nargs='*',
help="""Additional Ninja switches to be used during compilation.
These are applied after any other Ninja switches.
Example of enabling verbose Ninja output:
build_aar.py --extra-ninja-switches='-v'""")
parser.add_argument('--extra-gn-switches', default=[], nargs='*',
help="""Additional GN switches to be used during compilation.
parser.add_argument(
'--extra-gn-switches',
default=[],
nargs='*',
help="""Additional GN switches to be used during compilation.
These are applied after any other GN switches.
Example of enabling verbose GN output:
build_aar.py --extra-gn-switches='-v'""")
return parser.parse_args()
return parser.parse_args()
def _RunGN(args):
cmd = [sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py')]
cmd.extend(args)
logging.debug('Running: %r', cmd)
subprocess.check_call(cmd)
cmd = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py')
]
cmd.extend(args)
logging.debug('Running: %r', cmd)
subprocess.check_call(cmd)
def _RunNinja(output_directory, args):
cmd = [os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja'),
'-C', output_directory]
cmd.extend(args)
logging.debug('Running: %r', cmd)
subprocess.check_call(cmd)
cmd = [
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja'), '-C',
output_directory
]
cmd.extend(args)
logging.debug('Running: %r', cmd)
subprocess.check_call(cmd)
def _EncodeForGN(value):
"""Encodes value as a GN literal."""
if isinstance(value, str):
return '"' + value + '"'
elif isinstance(value, bool):
return repr(value).lower()
else:
return repr(value)
"""Encodes value as a GN literal."""
if isinstance(value, str):
return '"' + value + '"'
elif isinstance(value, bool):
return repr(value).lower()
else:
return repr(value)
def _GetOutputDirectory(build_dir, arch):
"""Returns the GN output directory for the target architecture."""
return os.path.join(build_dir, arch)
"""Returns the GN output directory for the target architecture."""
return os.path.join(build_dir, arch)
def _GetTargetCpu(arch):
"""Returns target_cpu for the GN build with the given architecture."""
if arch in ['armeabi', 'armeabi-v7a']:
return 'arm'
elif arch == 'arm64-v8a':
return 'arm64'
elif arch == 'x86':
return 'x86'
elif arch == 'x86_64':
return 'x64'
else:
raise Exception('Unknown arch: ' + arch)
"""Returns target_cpu for the GN build with the given architecture."""
if arch in ['armeabi', 'armeabi-v7a']:
return 'arm'
elif arch == 'arm64-v8a':
return 'arm64'
elif arch == 'x86':
return 'x86'
elif arch == 'x86_64':
return 'x64'
else:
raise Exception('Unknown arch: ' + arch)
def _GetArmVersion(arch):
"""Returns arm_version for the GN build with the given architecture."""
if arch == 'armeabi':
return 6
elif arch == 'armeabi-v7a':
return 7
elif arch in ['arm64-v8a', 'x86', 'x86_64']:
return None
else:
raise Exception('Unknown arch: ' + arch)
"""Returns arm_version for the GN build with the given architecture."""
if arch == 'armeabi':
return 6
elif arch == 'armeabi-v7a':
return 7
elif arch in ['arm64-v8a', 'x86', 'x86_64']:
return None
else:
raise Exception('Unknown arch: ' + arch)
def Build(build_dir, arch, use_goma, extra_gn_args, extra_gn_switches,
extra_ninja_switches):
"""Generates target architecture using GN and builds it using ninja."""
logging.info('Building: %s', arch)
output_directory = _GetOutputDirectory(build_dir, arch)
gn_args = {
'target_os': 'android',
'is_debug': False,
'is_component_build': False,
'rtc_include_tests': False,
'target_cpu': _GetTargetCpu(arch),
'use_goma': use_goma
}
arm_version = _GetArmVersion(arch)
if arm_version:
gn_args['arm_version'] = arm_version
gn_args_str = '--args=' + ' '.join([
k + '=' + _EncodeForGN(v) for k, v in gn_args.items()] + extra_gn_args)
"""Generates target architecture using GN and builds it using ninja."""
logging.info('Building: %s', arch)
output_directory = _GetOutputDirectory(build_dir, arch)
gn_args = {
'target_os': 'android',
'is_debug': False,
'is_component_build': False,
'rtc_include_tests': False,
'target_cpu': _GetTargetCpu(arch),
'use_goma': use_goma
}
arm_version = _GetArmVersion(arch)
if arm_version:
gn_args['arm_version'] = arm_version
gn_args_str = '--args=' + ' '.join(
[k + '=' + _EncodeForGN(v)
for k, v in gn_args.items()] + extra_gn_args)
gn_args_list = ['gen', output_directory, gn_args_str]
gn_args_list.extend(extra_gn_switches)
_RunGN(gn_args_list)
gn_args_list = ['gen', output_directory, gn_args_str]
gn_args_list.extend(extra_gn_switches)
_RunGN(gn_args_list)
ninja_args = TARGETS[:]
if use_goma:
ninja_args.extend(['-j', '200'])
ninja_args.extend(extra_ninja_switches)
_RunNinja(output_directory, ninja_args)
ninja_args = TARGETS[:]
if use_goma:
ninja_args.extend(['-j', '200'])
ninja_args.extend(extra_ninja_switches)
_RunNinja(output_directory, ninja_args)
def CollectCommon(aar_file, build_dir, arch):
"""Collects architecture independent files into the .aar-archive."""
logging.info('Collecting common files.')
output_directory = _GetOutputDirectory(build_dir, arch)
aar_file.write(MANIFEST_FILE, 'AndroidManifest.xml')
aar_file.write(os.path.join(output_directory, JAR_FILE), 'classes.jar')
"""Collects architecture independent files into the .aar-archive."""
logging.info('Collecting common files.')
output_directory = _GetOutputDirectory(build_dir, arch)
aar_file.write(MANIFEST_FILE, 'AndroidManifest.xml')
aar_file.write(os.path.join(output_directory, JAR_FILE), 'classes.jar')
def Collect(aar_file, build_dir, arch):
"""Collects architecture specific files into the .aar-archive."""
logging.info('Collecting: %s', arch)
output_directory = _GetOutputDirectory(build_dir, arch)
"""Collects architecture specific files into the .aar-archive."""
logging.info('Collecting: %s', arch)
output_directory = _GetOutputDirectory(build_dir, arch)
abi_dir = os.path.join('jni', arch)
for so_file in NEEDED_SO_FILES:
aar_file.write(os.path.join(output_directory, so_file),
os.path.join(abi_dir, so_file))
abi_dir = os.path.join('jni', arch)
for so_file in NEEDED_SO_FILES:
aar_file.write(os.path.join(output_directory, so_file),
os.path.join(abi_dir, so_file))
def GenerateLicenses(output_dir, build_dir, archs):
builder = LicenseBuilder(
[_GetOutputDirectory(build_dir, arch) for arch in archs], TARGETS)
builder.GenerateLicenseText(output_dir)
builder = LicenseBuilder(
[_GetOutputDirectory(build_dir, arch) for arch in archs], TARGETS)
builder.GenerateLicenseText(output_dir)
def BuildAar(archs, output_file, use_goma=False, extra_gn_args=None,
ext_build_dir=None, extra_gn_switches=None,
def BuildAar(archs,
output_file,
use_goma=False,
extra_gn_args=None,
ext_build_dir=None,
extra_gn_switches=None,
extra_ninja_switches=None):
extra_gn_args = extra_gn_args or []
extra_gn_switches = extra_gn_switches or []
extra_ninja_switches = extra_ninja_switches or []
build_dir = ext_build_dir if ext_build_dir else tempfile.mkdtemp()
extra_gn_args = extra_gn_args or []
extra_gn_switches = extra_gn_switches or []
extra_ninja_switches = extra_ninja_switches or []
build_dir = ext_build_dir if ext_build_dir else tempfile.mkdtemp()
for arch in archs:
Build(build_dir, arch, use_goma, extra_gn_args, extra_gn_switches,
extra_ninja_switches)
with zipfile.ZipFile(output_file, 'w') as aar_file:
# Architecture doesn't matter here, arbitrarily using the first one.
CollectCommon(aar_file, build_dir, archs[0])
for arch in archs:
Collect(aar_file, build_dir, arch)
Build(build_dir, arch, use_goma, extra_gn_args, extra_gn_switches,
extra_ninja_switches)
license_dir = os.path.dirname(os.path.realpath(output_file))
GenerateLicenses(license_dir, build_dir, archs)
with zipfile.ZipFile(output_file, 'w') as aar_file:
# Architecture doesn't matter here, arbitrarily using the first one.
CollectCommon(aar_file, build_dir, archs[0])
for arch in archs:
Collect(aar_file, build_dir, arch)
if not ext_build_dir:
shutil.rmtree(build_dir, True)
license_dir = os.path.dirname(os.path.realpath(output_file))
GenerateLicenses(license_dir, build_dir, archs)
if not ext_build_dir:
shutil.rmtree(build_dir, True)
def main():
args = _ParseArgs()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
args = _ParseArgs()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
BuildAar(args.arch, args.output, args.use_goma, args.extra_gn_args,
args.build_dir, args.extra_gn_switches, args.extra_ninja_switches)
BuildAar(args.arch, args.output, args.use_goma, args.extra_gn_args,
args.build_dir, args.extra_gn_switches, args.extra_ninja_switches)
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -7,7 +7,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Script for publishing WebRTC AAR on Bintray.
Set BINTRAY_USER and BINTRAY_API_KEY environment variables before running
@ -25,7 +24,6 @@ import sys
import tempfile
import time
SCRIPT_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
CHECKOUT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir, os.pardir))
@ -36,7 +34,6 @@ import jinja2
sys.path.append(os.path.join(CHECKOUT_ROOT, 'tools_webrtc'))
from android.build_aar import BuildAar
ARCHS = ['armeabi-v7a', 'arm64-v8a', 'x86', 'x86_64']
MAVEN_REPOSITORY = 'https://google.bintray.com/webrtc'
API = 'https://api.bintray.com'
@ -62,230 +59,249 @@ AAR_PROJECT_VERSION_DEPENDENCY = "implementation 'org.webrtc:google-webrtc:%s'"
def _ParseArgs():
parser = argparse.ArgumentParser(description='Releases WebRTC on Bintray.')
parser.add_argument('--use-goma', action='store_true', default=False,
help='Use goma.')
parser.add_argument('--skip-tests', action='store_true', default=False,
help='Skips running the tests.')
parser.add_argument('--publish', action='store_true', default=False,
help='Automatically publishes the library if the tests pass.')
parser.add_argument('--build-dir', default=None,
help='Temporary directory to store the build files. If not specified, '
'a new directory will be created.')
parser.add_argument('--verbose', action='store_true', default=False,
help='Debug logging.')
return parser.parse_args()
parser = argparse.ArgumentParser(description='Releases WebRTC on Bintray.')
parser.add_argument('--use-goma',
action='store_true',
default=False,
help='Use goma.')
parser.add_argument('--skip-tests',
action='store_true',
default=False,
help='Skips running the tests.')
parser.add_argument(
'--publish',
action='store_true',
default=False,
help='Automatically publishes the library if the tests pass.')
parser.add_argument(
'--build-dir',
default=None,
help='Temporary directory to store the build files. If not specified, '
'a new directory will be created.')
parser.add_argument('--verbose',
action='store_true',
default=False,
help='Debug logging.')
return parser.parse_args()
def _GetCommitHash():
commit_hash = subprocess.check_output(
['git', 'rev-parse', 'HEAD'], cwd=CHECKOUT_ROOT).strip()
return commit_hash
commit_hash = subprocess.check_output(['git', 'rev-parse', 'HEAD'],
cwd=CHECKOUT_ROOT).strip()
return commit_hash
def _GetCommitPos():
commit_message = subprocess.check_output(
['git', 'rev-list', '--format=%B', '--max-count=1', 'HEAD'],
cwd=CHECKOUT_ROOT)
commit_pos_match = re.search(
COMMIT_POSITION_REGEX, commit_message, re.MULTILINE)
if not commit_pos_match:
raise Exception('Commit position not found in the commit message: %s'
% commit_message)
return commit_pos_match.group(1)
commit_message = subprocess.check_output(
['git', 'rev-list', '--format=%B', '--max-count=1', 'HEAD'],
cwd=CHECKOUT_ROOT)
commit_pos_match = re.search(COMMIT_POSITION_REGEX, commit_message,
re.MULTILINE)
if not commit_pos_match:
raise Exception('Commit position not found in the commit message: %s' %
commit_message)
return commit_pos_match.group(1)
def _UploadFile(user, password, filename, version, target_file):
# URL is of format:
# <repository_api>/<version>/<group_id>/<artifact_id>/<version>/<target_file>
# Example:
# https://api.bintray.com/content/google/webrtc/google-webrtc/1.0.19742/org/webrtc/google-webrtc/1.0.19742/google-webrtc-1.0.19742.aar
# URL is of format:
# <repository_api>/<version>/<group_id>/<artifact_id>/<version>/<target_file>
# Example:
# https://api.bintray.com/content/google/webrtc/google-webrtc/1.0.19742/org/webrtc/google-webrtc/1.0.19742/google-webrtc-1.0.19742.aar
target_dir = version + '/' + GROUP_ID + '/' + ARTIFACT_ID + '/' + version
target_path = target_dir + '/' + target_file
url = CONTENT_API + '/' + target_path
target_dir = version + '/' + GROUP_ID + '/' + ARTIFACT_ID + '/' + version
target_path = target_dir + '/' + target_file
url = CONTENT_API + '/' + target_path
logging.info('Uploading %s to %s', filename, url)
with open(filename) as fh:
file_data = fh.read()
logging.info('Uploading %s to %s', filename, url)
with open(filename) as fh:
file_data = fh.read()
for attempt in xrange(UPLOAD_TRIES):
try:
response = requests.put(url, data=file_data, auth=(user, password),
timeout=API_TIMEOUT_SECONDS)
break
except requests.exceptions.Timeout as e:
logging.warning('Timeout while uploading: %s', e)
time.sleep(UPLOAD_RETRY_BASE_SLEEP_SECONDS ** attempt)
else:
raise Exception('Failed to upload %s' % filename)
for attempt in xrange(UPLOAD_TRIES):
try:
response = requests.put(url,
data=file_data,
auth=(user, password),
timeout=API_TIMEOUT_SECONDS)
break
except requests.exceptions.Timeout as e:
logging.warning('Timeout while uploading: %s', e)
time.sleep(UPLOAD_RETRY_BASE_SLEEP_SECONDS**attempt)
else:
raise Exception('Failed to upload %s' % filename)
if not response.ok:
raise Exception('Failed to upload %s. Response: %s' % (filename, response))
logging.info('Uploaded %s: %s', filename, response)
if not response.ok:
raise Exception('Failed to upload %s. Response: %s' %
(filename, response))
logging.info('Uploaded %s: %s', filename, response)
def _GeneratePom(target_file, version, commit):
env = jinja2.Environment(
loader=jinja2.PackageLoader('release_aar'),
)
template = env.get_template('pom.jinja')
pom = template.render(version=version, commit=commit)
with open(target_file, 'w') as fh:
fh.write(pom)
env = jinja2.Environment(loader=jinja2.PackageLoader('release_aar'), )
template = env.get_template('pom.jinja')
pom = template.render(version=version, commit=commit)
with open(target_file, 'w') as fh:
fh.write(pom)
def _TestAAR(tmp_dir, username, password, version):
"""Runs AppRTCMobile tests using the AAR. Returns true if the tests pass."""
logging.info('Testing library.')
env = jinja2.Environment(
loader=jinja2.PackageLoader('release_aar'),
)
"""Runs AppRTCMobile tests using the AAR. Returns true if the tests pass."""
logging.info('Testing library.')
env = jinja2.Environment(loader=jinja2.PackageLoader('release_aar'), )
gradle_backup = os.path.join(tmp_dir, 'build.gradle.backup')
app_gradle_backup = os.path.join(tmp_dir, 'app-build.gradle.backup')
gradle_backup = os.path.join(tmp_dir, 'build.gradle.backup')
app_gradle_backup = os.path.join(tmp_dir, 'app-build.gradle.backup')
# Make backup copies of the project files before modifying them.
shutil.copy2(AAR_PROJECT_GRADLE, gradle_backup)
shutil.copy2(AAR_PROJECT_APP_GRADLE, app_gradle_backup)
# Make backup copies of the project files before modifying them.
shutil.copy2(AAR_PROJECT_GRADLE, gradle_backup)
shutil.copy2(AAR_PROJECT_APP_GRADLE, app_gradle_backup)
try:
maven_repository_template = env.get_template('maven-repository.jinja')
maven_repository = maven_repository_template.render(
url=MAVEN_REPOSITORY, username=username, password=password)
# Append Maven repository to build file to download unpublished files.
with open(AAR_PROJECT_GRADLE, 'a') as gradle_file:
gradle_file.write(maven_repository)
# Read app build file.
with open(AAR_PROJECT_APP_GRADLE, 'r') as gradle_app_file:
gradle_app = gradle_app_file.read()
if AAR_PROJECT_DEPENDENCY not in gradle_app:
raise Exception(
'%s not found in the build file.' % AAR_PROJECT_DEPENDENCY)
# Set version to the version to be tested.
target_dependency = AAR_PROJECT_VERSION_DEPENDENCY % version
gradle_app = gradle_app.replace(AAR_PROJECT_DEPENDENCY, target_dependency)
# Write back.
with open(AAR_PROJECT_APP_GRADLE, 'w') as gradle_app_file:
gradle_app_file.write(gradle_app)
# Uninstall any existing version of AppRTCMobile.
logging.info('Uninstalling previous AppRTCMobile versions. It is okay for '
'these commands to fail if AppRTCMobile is not installed.')
subprocess.call([ADB_BIN, 'uninstall', 'org.appspot.apprtc'])
subprocess.call([ADB_BIN, 'uninstall', 'org.appspot.apprtc.test'])
# Run tests.
try:
# First clean the project.
subprocess.check_call([GRADLEW_BIN, 'clean'], cwd=AAR_PROJECT_DIR)
# Then run the tests.
subprocess.check_call([GRADLEW_BIN, 'connectedDebugAndroidTest'],
cwd=AAR_PROJECT_DIR)
except subprocess.CalledProcessError:
logging.exception('Test failure.')
return False # Clean or tests failed
maven_repository_template = env.get_template('maven-repository.jinja')
maven_repository = maven_repository_template.render(
url=MAVEN_REPOSITORY, username=username, password=password)
return True # Tests pass
finally:
# Restore backups.
shutil.copy2(gradle_backup, AAR_PROJECT_GRADLE)
shutil.copy2(app_gradle_backup, AAR_PROJECT_APP_GRADLE)
# Append Maven repository to build file to download unpublished files.
with open(AAR_PROJECT_GRADLE, 'a') as gradle_file:
gradle_file.write(maven_repository)
# Read app build file.
with open(AAR_PROJECT_APP_GRADLE, 'r') as gradle_app_file:
gradle_app = gradle_app_file.read()
if AAR_PROJECT_DEPENDENCY not in gradle_app:
raise Exception('%s not found in the build file.' %
AAR_PROJECT_DEPENDENCY)
# Set version to the version to be tested.
target_dependency = AAR_PROJECT_VERSION_DEPENDENCY % version
gradle_app = gradle_app.replace(AAR_PROJECT_DEPENDENCY,
target_dependency)
# Write back.
with open(AAR_PROJECT_APP_GRADLE, 'w') as gradle_app_file:
gradle_app_file.write(gradle_app)
# Uninstall any existing version of AppRTCMobile.
logging.info(
'Uninstalling previous AppRTCMobile versions. It is okay for '
'these commands to fail if AppRTCMobile is not installed.')
subprocess.call([ADB_BIN, 'uninstall', 'org.appspot.apprtc'])
subprocess.call([ADB_BIN, 'uninstall', 'org.appspot.apprtc.test'])
# Run tests.
try:
# First clean the project.
subprocess.check_call([GRADLEW_BIN, 'clean'], cwd=AAR_PROJECT_DIR)
# Then run the tests.
subprocess.check_call([GRADLEW_BIN, 'connectedDebugAndroidTest'],
cwd=AAR_PROJECT_DIR)
except subprocess.CalledProcessError:
logging.exception('Test failure.')
return False # Clean or tests failed
return True # Tests pass
finally:
# Restore backups.
shutil.copy2(gradle_backup, AAR_PROJECT_GRADLE)
shutil.copy2(app_gradle_backup, AAR_PROJECT_APP_GRADLE)
def _PublishAAR(user, password, version, additional_args):
args = {
'publish_wait_for_secs': 0 # Publish asynchronously.
}
args.update(additional_args)
args = {
'publish_wait_for_secs': 0 # Publish asynchronously.
}
args.update(additional_args)
url = CONTENT_API + '/' + version + '/publish'
response = requests.post(url, data=json.dumps(args), auth=(user, password),
timeout=API_TIMEOUT_SECONDS)
url = CONTENT_API + '/' + version + '/publish'
response = requests.post(url,
data=json.dumps(args),
auth=(user, password),
timeout=API_TIMEOUT_SECONDS)
if not response.ok:
raise Exception('Failed to publish. Response: %s' % response)
if not response.ok:
raise Exception('Failed to publish. Response: %s' % response)
def _DeleteUnpublishedVersion(user, password, version):
url = PACKAGES_API + '/versions/' + version
response = requests.get(url, auth=(user, password),
timeout=API_TIMEOUT_SECONDS)
if not response.ok:
raise Exception('Failed to get version info. Response: %s' % response)
url = PACKAGES_API + '/versions/' + version
response = requests.get(url,
auth=(user, password),
timeout=API_TIMEOUT_SECONDS)
if not response.ok:
raise Exception('Failed to get version info. Response: %s' % response)
version_info = json.loads(response.content)
if version_info['published']:
logging.info('Version has already been published, not deleting.')
return
version_info = json.loads(response.content)
if version_info['published']:
logging.info('Version has already been published, not deleting.')
return
logging.info('Deleting unpublished version.')
response = requests.delete(url, auth=(user, password),
timeout=API_TIMEOUT_SECONDS)
if not response.ok:
raise Exception('Failed to delete version. Response: %s' % response)
logging.info('Deleting unpublished version.')
response = requests.delete(url,
auth=(user, password),
timeout=API_TIMEOUT_SECONDS)
if not response.ok:
raise Exception('Failed to delete version. Response: %s' % response)
def ReleaseAar(use_goma, skip_tests, publish, build_dir):
version = '1.0.' + _GetCommitPos()
commit = _GetCommitHash()
logging.info('Releasing AAR version %s with hash %s', version, commit)
version = '1.0.' + _GetCommitPos()
commit = _GetCommitHash()
logging.info('Releasing AAR version %s with hash %s', version, commit)
user = os.environ.get('BINTRAY_USER', None)
api_key = os.environ.get('BINTRAY_API_KEY', None)
if not user or not api_key:
raise Exception('Environment variables BINTRAY_USER and BINTRAY_API_KEY '
'must be defined.')
user = os.environ.get('BINTRAY_USER', None)
api_key = os.environ.get('BINTRAY_API_KEY', None)
if not user or not api_key:
raise Exception(
'Environment variables BINTRAY_USER and BINTRAY_API_KEY '
'must be defined.')
# If build directory is not specified, create a temporary directory.
use_tmp_dir = not build_dir
if use_tmp_dir:
build_dir = tempfile.mkdtemp()
try:
base_name = ARTIFACT_ID + '-' + version
aar_file = os.path.join(build_dir, base_name + '.aar')
third_party_licenses_file = os.path.join(build_dir, 'LICENSE.md')
pom_file = os.path.join(build_dir, base_name + '.pom')
logging.info('Building at %s', build_dir)
BuildAar(ARCHS, aar_file,
use_goma=use_goma,
ext_build_dir=os.path.join(build_dir, 'aar-build'))
_GeneratePom(pom_file, version, commit)
_UploadFile(user, api_key, aar_file, version, base_name + '.aar')
_UploadFile(user, api_key, third_party_licenses_file, version,
'THIRD_PARTY_LICENSES.md')
_UploadFile(user, api_key, pom_file, version, base_name + '.pom')
tests_pass = skip_tests or _TestAAR(build_dir, user, api_key, version)
if not tests_pass:
logging.info('Discarding library.')
_PublishAAR(user, api_key, version, {'discard': True})
_DeleteUnpublishedVersion(user, api_key, version)
raise Exception('Test failure. Discarded library.')
if publish:
logging.info('Publishing library.')
_PublishAAR(user, api_key, version, {})
else:
logging.info('Note: The library has not not been published automatically.'
' Please do so manually if desired.')
finally:
# If build directory is not specified, create a temporary directory.
use_tmp_dir = not build_dir
if use_tmp_dir:
shutil.rmtree(build_dir, True)
build_dir = tempfile.mkdtemp()
try:
base_name = ARTIFACT_ID + '-' + version
aar_file = os.path.join(build_dir, base_name + '.aar')
third_party_licenses_file = os.path.join(build_dir, 'LICENSE.md')
pom_file = os.path.join(build_dir, base_name + '.pom')
logging.info('Building at %s', build_dir)
BuildAar(ARCHS,
aar_file,
use_goma=use_goma,
ext_build_dir=os.path.join(build_dir, 'aar-build'))
_GeneratePom(pom_file, version, commit)
_UploadFile(user, api_key, aar_file, version, base_name + '.aar')
_UploadFile(user, api_key, third_party_licenses_file, version,
'THIRD_PARTY_LICENSES.md')
_UploadFile(user, api_key, pom_file, version, base_name + '.pom')
tests_pass = skip_tests or _TestAAR(build_dir, user, api_key, version)
if not tests_pass:
logging.info('Discarding library.')
_PublishAAR(user, api_key, version, {'discard': True})
_DeleteUnpublishedVersion(user, api_key, version)
raise Exception('Test failure. Discarded library.')
if publish:
logging.info('Publishing library.')
_PublishAAR(user, api_key, version, {})
else:
logging.info(
'Note: The library has not not been published automatically.'
' Please do so manually if desired.')
finally:
if use_tmp_dir:
shutil.rmtree(build_dir, True)
def main():
args = _ParseArgs()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
ReleaseAar(args.use_goma, args.skip_tests, args.publish, args.build_dir)
args = _ParseArgs()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
ReleaseAar(args.use_goma, args.skip_tests, args.publish, args.build_dir)
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

File diff suppressed because it is too large Load Diff

View File

@ -14,7 +14,6 @@ import sys
import tempfile
import unittest
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PARENT_DIR = os.path.join(SCRIPT_DIR, os.pardir)
sys.path.append(PARENT_DIR)
@ -27,15 +26,15 @@ from roll_deps import CalculateChangedDeps, FindAddedDeps, \
import mock
TEST_DATA_VARS = {
'chromium_git': 'https://chromium.googlesource.com',
'chromium_revision': '1b9c098a08e40114e44b6c1ec33ddf95c40b901d',
'chromium_git': 'https://chromium.googlesource.com',
'chromium_revision': '1b9c098a08e40114e44b6c1ec33ddf95c40b901d',
}
DEPS_ENTRIES = {
'src/build': 'https://build.com',
'src/third_party/depot_tools': 'https://depottools.com',
'src/testing/gtest': 'https://gtest.com',
'src/testing/gmock': 'https://gmock.com',
'src/build': 'https://build.com',
'src/third_party/depot_tools': 'https://depottools.com',
'src/testing/gtest': 'https://gtest.com',
'src/testing/gmock': 'https://gmock.com',
}
BUILD_OLD_REV = '52f7afeca991d96d68cf0507e20dbdd5b845691f'
@ -47,291 +46,298 @@ NO_CHROMIUM_REVISION_UPDATE = ChromiumRevisionUpdate('cafe', 'cafe')
class TestError(Exception):
pass
pass
class FakeCmd(object):
def __init__(self):
self.expectations = []
def __init__(self):
self.expectations = []
def AddExpectation(self, *args, **kwargs):
returns = kwargs.pop('_returns', None)
ignores = kwargs.pop('_ignores', [])
self.expectations.append((args, kwargs, returns, ignores))
def AddExpectation(self, *args, **kwargs):
returns = kwargs.pop('_returns', None)
ignores = kwargs.pop('_ignores', [])
self.expectations.append((args, kwargs, returns, ignores))
def __call__(self, *args, **kwargs):
if not self.expectations:
raise TestError('Got unexpected\n%s\n%s' % (args, kwargs))
exp_args, exp_kwargs, exp_returns, ignores = self.expectations.pop(0)
for item in ignores:
kwargs.pop(item, None)
if args != exp_args or kwargs != exp_kwargs:
message = 'Expected:\n args: %s\n kwargs: %s\n' % (exp_args, exp_kwargs)
message += 'Got:\n args: %s\n kwargs: %s\n' % (args, kwargs)
raise TestError(message)
return exp_returns
def __call__(self, *args, **kwargs):
if not self.expectations:
raise TestError('Got unexpected\n%s\n%s' % (args, kwargs))
exp_args, exp_kwargs, exp_returns, ignores = self.expectations.pop(0)
for item in ignores:
kwargs.pop(item, None)
if args != exp_args or kwargs != exp_kwargs:
message = 'Expected:\n args: %s\n kwargs: %s\n' % (exp_args,
exp_kwargs)
message += 'Got:\n args: %s\n kwargs: %s\n' % (args, kwargs)
raise TestError(message)
return exp_returns
class NullCmd(object):
"""No-op mock when calls mustn't be checked. """
"""No-op mock when calls mustn't be checked. """
def __call__(self, *args, **kwargs):
# Empty stdout and stderr.
return None, None
def __call__(self, *args, **kwargs):
# Empty stdout and stderr.
return None, None
class TestRollChromiumRevision(unittest.TestCase):
def setUp(self):
self._output_dir = tempfile.mkdtemp()
test_data_dir = os.path.join(SCRIPT_DIR, 'testdata', 'roll_deps')
for test_file in glob.glob(os.path.join(test_data_dir, '*')):
shutil.copy(test_file, self._output_dir)
join = lambda f: os.path.join(self._output_dir, f)
self._webrtc_depsfile = join('DEPS')
self._new_cr_depsfile = join('DEPS.chromium.new')
self._webrtc_depsfile_android = join('DEPS.with_android_deps')
self._new_cr_depsfile_android = join('DEPS.chromium.with_android_deps')
self.fake = FakeCmd()
def setUp(self):
self._output_dir = tempfile.mkdtemp()
test_data_dir = os.path.join(SCRIPT_DIR, 'testdata', 'roll_deps')
for test_file in glob.glob(os.path.join(test_data_dir, '*')):
shutil.copy(test_file, self._output_dir)
join = lambda f: os.path.join(self._output_dir, f)
self._webrtc_depsfile = join('DEPS')
self._new_cr_depsfile = join('DEPS.chromium.new')
self._webrtc_depsfile_android = join('DEPS.with_android_deps')
self._new_cr_depsfile_android = join('DEPS.chromium.with_android_deps')
self.fake = FakeCmd()
def tearDown(self):
shutil.rmtree(self._output_dir, ignore_errors=True)
self.assertEqual(self.fake.expectations, [])
def tearDown(self):
shutil.rmtree(self._output_dir, ignore_errors=True)
self.assertEqual(self.fake.expectations, [])
def testVarLookup(self):
local_scope = {'foo': 'wrong', 'vars': {'foo': 'bar'}}
lookup = roll_deps.VarLookup(local_scope)
self.assertEquals(lookup('foo'), 'bar')
def testVarLookup(self):
local_scope = {'foo': 'wrong', 'vars': {'foo': 'bar'}}
lookup = roll_deps.VarLookup(local_scope)
self.assertEquals(lookup('foo'), 'bar')
def testUpdateDepsFile(self):
new_rev = 'aaaaabbbbbcccccdddddeeeeefffff0000011111'
current_rev = TEST_DATA_VARS['chromium_revision']
def testUpdateDepsFile(self):
new_rev = 'aaaaabbbbbcccccdddddeeeeefffff0000011111'
current_rev = TEST_DATA_VARS['chromium_revision']
with open(self._new_cr_depsfile_android) as deps_file:
new_cr_contents = deps_file.read()
with open(self._new_cr_depsfile_android) as deps_file:
new_cr_contents = deps_file.read()
UpdateDepsFile(self._webrtc_depsfile,
ChromiumRevisionUpdate(current_rev, new_rev),
[],
new_cr_contents)
with open(self._webrtc_depsfile) as deps_file:
deps_contents = deps_file.read()
self.assertTrue(new_rev in deps_contents,
'Failed to find %s in\n%s' % (new_rev, deps_contents))
UpdateDepsFile(self._webrtc_depsfile,
ChromiumRevisionUpdate(current_rev, new_rev), [],
new_cr_contents)
with open(self._webrtc_depsfile) as deps_file:
deps_contents = deps_file.read()
self.assertTrue(
new_rev in deps_contents,
'Failed to find %s in\n%s' % (new_rev, deps_contents))
def _UpdateDepsSetup(self):
with open(self._webrtc_depsfile_android) as deps_file:
webrtc_contents = deps_file.read()
with open(self._new_cr_depsfile_android) as deps_file:
new_cr_contents = deps_file.read()
webrtc_deps = ParseDepsDict(webrtc_contents)
new_cr_deps = ParseDepsDict(new_cr_contents)
def _UpdateDepsSetup(self):
with open(self._webrtc_depsfile_android) as deps_file:
webrtc_contents = deps_file.read()
with open(self._new_cr_depsfile_android) as deps_file:
new_cr_contents = deps_file.read()
webrtc_deps = ParseDepsDict(webrtc_contents)
new_cr_deps = ParseDepsDict(new_cr_contents)
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
with mock.patch('roll_deps._RunCommand', NullCmd()):
UpdateDepsFile(self._webrtc_depsfile_android,
NO_CHROMIUM_REVISION_UPDATE,
changed_deps,
new_cr_contents)
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
with mock.patch('roll_deps._RunCommand', NullCmd()):
UpdateDepsFile(self._webrtc_depsfile_android,
NO_CHROMIUM_REVISION_UPDATE, changed_deps,
new_cr_contents)
with open(self._webrtc_depsfile_android) as deps_file:
updated_contents = deps_file.read()
with open(self._webrtc_depsfile_android) as deps_file:
updated_contents = deps_file.read()
return webrtc_contents, updated_contents
return webrtc_contents, updated_contents
def testUpdateAndroidGeneratedDeps(self):
_, updated_contents = self._UpdateDepsSetup()
def testUpdateAndroidGeneratedDeps(self):
_, updated_contents = self._UpdateDepsSetup()
changed = 'third_party/android_deps/libs/android_arch_core_common'
changed_version = '1.0.0-cr0'
self.assertTrue(changed in updated_contents)
self.assertTrue(changed_version in updated_contents)
changed = 'third_party/android_deps/libs/android_arch_core_common'
changed_version = '1.0.0-cr0'
self.assertTrue(changed in updated_contents)
self.assertTrue(changed_version in updated_contents)
def testAddAndroidGeneratedDeps(self):
webrtc_contents, updated_contents = self._UpdateDepsSetup()
def testAddAndroidGeneratedDeps(self):
webrtc_contents, updated_contents = self._UpdateDepsSetup()
added = 'third_party/android_deps/libs/android_arch_lifecycle_common'
self.assertFalse(added in webrtc_contents)
self.assertTrue(added in updated_contents)
added = 'third_party/android_deps/libs/android_arch_lifecycle_common'
self.assertFalse(added in webrtc_contents)
self.assertTrue(added in updated_contents)
def testRemoveAndroidGeneratedDeps(self):
webrtc_contents, updated_contents = self._UpdateDepsSetup()
def testRemoveAndroidGeneratedDeps(self):
webrtc_contents, updated_contents = self._UpdateDepsSetup()
removed = 'third_party/android_deps/libs/android_arch_lifecycle_runtime'
self.assertTrue(removed in webrtc_contents)
self.assertFalse(removed in updated_contents)
removed = 'third_party/android_deps/libs/android_arch_lifecycle_runtime'
self.assertTrue(removed in webrtc_contents)
self.assertFalse(removed in updated_contents)
def testParseDepsDict(self):
with open(self._webrtc_depsfile) as deps_file:
deps_contents = deps_file.read()
local_scope = ParseDepsDict(deps_contents)
vars_dict = local_scope['vars']
def testParseDepsDict(self):
with open(self._webrtc_depsfile) as deps_file:
deps_contents = deps_file.read()
local_scope = ParseDepsDict(deps_contents)
vars_dict = local_scope['vars']
def AssertVar(variable_name):
self.assertEquals(vars_dict[variable_name], TEST_DATA_VARS[variable_name])
AssertVar('chromium_git')
AssertVar('chromium_revision')
self.assertEquals(len(local_scope['deps']), 3)
self.assertEquals(len(local_scope['deps_os']), 1)
def AssertVar(variable_name):
self.assertEquals(vars_dict[variable_name],
TEST_DATA_VARS[variable_name])
def testGetMatchingDepsEntriesReturnsPathInSimpleCase(self):
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/testing/gtest')
self.assertEquals(len(entries), 1)
self.assertEquals(entries[0], DEPS_ENTRIES['src/testing/gtest'])
AssertVar('chromium_git')
AssertVar('chromium_revision')
self.assertEquals(len(local_scope['deps']), 3)
self.assertEquals(len(local_scope['deps_os']), 1)
def testGetMatchingDepsEntriesHandlesSimilarStartingPaths(self):
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/testing')
self.assertEquals(len(entries), 2)
def testGetMatchingDepsEntriesReturnsPathInSimpleCase(self):
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/testing/gtest')
self.assertEquals(len(entries), 1)
self.assertEquals(entries[0], DEPS_ENTRIES['src/testing/gtest'])
def testGetMatchingDepsEntriesHandlesTwoPathsWithIdenticalFirstParts(self):
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/build')
self.assertEquals(len(entries), 1)
def testGetMatchingDepsEntriesHandlesSimilarStartingPaths(self):
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/testing')
self.assertEquals(len(entries), 2)
def testGetMatchingDepsEntriesHandlesTwoPathsWithIdenticalFirstParts(self):
entries = GetMatchingDepsEntries(DEPS_ENTRIES, 'src/build')
self.assertEquals(len(entries), 1)
def testCalculateChangedDeps(self):
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile)
with mock.patch('roll_deps._RunCommand', self.fake):
_SetupGitLsRemoteCall(
self.fake, 'https://chromium.googlesource.com/chromium/src/build',
BUILD_NEW_REV)
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
def testCalculateChangedDeps(self):
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile)
with mock.patch('roll_deps._RunCommand', self.fake):
_SetupGitLsRemoteCall(
self.fake,
'https://chromium.googlesource.com/chromium/src/build',
BUILD_NEW_REV)
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
self.assertEquals(len(changed_deps), 3)
self.assertEquals(changed_deps[0].path, 'src/build')
self.assertEquals(changed_deps[0].current_rev, BUILD_OLD_REV)
self.assertEquals(changed_deps[0].new_rev, BUILD_NEW_REV)
self.assertEquals(len(changed_deps), 3)
self.assertEquals(changed_deps[0].path, 'src/build')
self.assertEquals(changed_deps[0].current_rev, BUILD_OLD_REV)
self.assertEquals(changed_deps[0].new_rev, BUILD_NEW_REV)
self.assertEquals(changed_deps[1].path, 'src/third_party/depot_tools')
self.assertEquals(changed_deps[1].current_rev, DEPOTTOOLS_OLD_REV)
self.assertEquals(changed_deps[1].new_rev, DEPOTTOOLS_NEW_REV)
self.assertEquals(changed_deps[1].path, 'src/third_party/depot_tools')
self.assertEquals(changed_deps[1].current_rev, DEPOTTOOLS_OLD_REV)
self.assertEquals(changed_deps[1].new_rev, DEPOTTOOLS_NEW_REV)
self.assertEquals(changed_deps[2].path, 'src/third_party/xstream')
self.assertEquals(changed_deps[2].package, 'chromium/third_party/xstream')
self.assertEquals(changed_deps[2].current_version, 'version:1.4.8-cr0')
self.assertEquals(changed_deps[2].new_version, 'version:1.10.0-cr0')
self.assertEquals(changed_deps[2].path, 'src/third_party/xstream')
self.assertEquals(changed_deps[2].package,
'chromium/third_party/xstream')
self.assertEquals(changed_deps[2].current_version, 'version:1.4.8-cr0')
self.assertEquals(changed_deps[2].new_version, 'version:1.10.0-cr0')
def testWithDistinctDeps(self):
"""Check CalculateChangedDeps still works when deps are added/removed. """
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
self.assertEquals(len(changed_deps), 1)
self.assertEquals(
changed_deps[0].path,
'src/third_party/android_deps/libs/android_arch_core_common')
self.assertEquals(
changed_deps[0].package,
'chromium/third_party/android_deps/libs/android_arch_core_common')
self.assertEquals(changed_deps[0].current_version, 'version:0.9.0')
self.assertEquals(changed_deps[0].new_version, 'version:1.0.0-cr0')
def testWithDistinctDeps(self):
"""Check CalculateChangedDeps still works when deps are added/removed. """
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
self.assertEquals(len(changed_deps), 1)
self.assertEquals(
changed_deps[0].path,
'src/third_party/android_deps/libs/android_arch_core_common')
self.assertEquals(
changed_deps[0].package,
'chromium/third_party/android_deps/libs/android_arch_core_common')
self.assertEquals(changed_deps[0].current_version, 'version:0.9.0')
self.assertEquals(changed_deps[0].new_version, 'version:1.0.0-cr0')
def testFindAddedDeps(self):
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
added_android_paths, other_paths = FindAddedDeps(webrtc_deps, new_cr_deps)
self.assertEquals(
added_android_paths,
['src/third_party/android_deps/libs/android_arch_lifecycle_common'])
self.assertEquals(other_paths, [])
def testFindAddedDeps(self):
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
added_android_paths, other_paths = FindAddedDeps(
webrtc_deps, new_cr_deps)
self.assertEquals(added_android_paths, [
'src/third_party/android_deps/libs/android_arch_lifecycle_common'
])
self.assertEquals(other_paths, [])
def testFindRemovedDeps(self):
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
removed_android_paths, other_paths = FindRemovedDeps(webrtc_deps,
new_cr_deps)
self.assertEquals(removed_android_paths,
['src/third_party/android_deps/libs/android_arch_lifecycle_runtime'])
self.assertEquals(other_paths, [])
def testFindRemovedDeps(self):
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
removed_android_paths, other_paths = FindRemovedDeps(
webrtc_deps, new_cr_deps)
self.assertEquals(removed_android_paths, [
'src/third_party/android_deps/libs/android_arch_lifecycle_runtime'
])
self.assertEquals(other_paths, [])
def testMissingDepsIsDetected(self):
"""Check an error is reported when deps cannot be automatically removed."""
# The situation at test is the following:
# * A WebRTC DEPS entry is missing from Chromium.
# * The dependency isn't an android_deps (those are supported).
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
_, other_paths = FindRemovedDeps(webrtc_deps, new_cr_deps)
self.assertEquals(other_paths, ['src/third_party/xstream',
'src/third_party/depot_tools'])
def testMissingDepsIsDetected(self):
"""Check an error is reported when deps cannot be automatically removed."""
# The situation at test is the following:
# * A WebRTC DEPS entry is missing from Chromium.
# * The dependency isn't an android_deps (those are supported).
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
_, other_paths = FindRemovedDeps(webrtc_deps, new_cr_deps)
self.assertEquals(
other_paths,
['src/third_party/xstream', 'src/third_party/depot_tools'])
def testExpectedDepsIsNotReportedMissing(self):
"""Some deps musn't be seen as missing, even if absent from Chromium."""
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
removed_android_paths, other_paths = FindRemovedDeps(webrtc_deps,
new_cr_deps)
self.assertTrue('src/build' not in removed_android_paths)
self.assertTrue('src/build' not in other_paths)
def testExpectedDepsIsNotReportedMissing(self):
"""Some deps musn't be seen as missing, even if absent from Chromium."""
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
removed_android_paths, other_paths = FindRemovedDeps(
webrtc_deps, new_cr_deps)
self.assertTrue('src/build' not in removed_android_paths)
self.assertTrue('src/build' not in other_paths)
def _CommitMessageSetup(self):
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
def _CommitMessageSetup(self):
webrtc_deps = ParseLocalDepsFile(self._webrtc_depsfile_android)
new_cr_deps = ParseLocalDepsFile(self._new_cr_depsfile_android)
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
added_paths, _ = FindAddedDeps(webrtc_deps, new_cr_deps)
removed_paths, _ = FindRemovedDeps(webrtc_deps, new_cr_deps)
changed_deps = CalculateChangedDeps(webrtc_deps, new_cr_deps)
added_paths, _ = FindAddedDeps(webrtc_deps, new_cr_deps)
removed_paths, _ = FindRemovedDeps(webrtc_deps, new_cr_deps)
current_commit_pos = 'cafe'
new_commit_pos = 'f00d'
current_commit_pos = 'cafe'
new_commit_pos = 'f00d'
with mock.patch('roll_deps._RunCommand', self.fake):
# We don't really care, but it's needed to construct the message.
self.fake.AddExpectation(['git', 'config', 'user.email'],
_returns=('nobody@nowhere.no', None),
_ignores=['working_dir'])
with mock.patch('roll_deps._RunCommand', self.fake):
# We don't really care, but it's needed to construct the message.
self.fake.AddExpectation(['git', 'config', 'user.email'],
_returns=('nobody@nowhere.no', None),
_ignores=['working_dir'])
commit_msg = GenerateCommitMessage(
NO_CHROMIUM_REVISION_UPDATE, current_commit_pos, new_commit_pos,
changed_deps, added_paths, removed_paths)
commit_msg = GenerateCommitMessage(NO_CHROMIUM_REVISION_UPDATE,
current_commit_pos,
new_commit_pos, changed_deps,
added_paths, removed_paths)
return [l.strip() for l in commit_msg.split('\n')]
return [l.strip() for l in commit_msg.split('\n')]
def testChangedDepsInCommitMessage(self):
commit_lines = self._CommitMessageSetup()
def testChangedDepsInCommitMessage(self):
commit_lines = self._CommitMessageSetup()
changed = '* src/third_party/android_deps/libs/' \
'android_arch_core_common: version:0.9.0..version:1.0.0-cr0'
self.assertTrue(changed in commit_lines)
# Check it is in adequate section.
changed_line = commit_lines.index(changed)
self.assertTrue('Changed' in commit_lines[changed_line-1])
changed = '* src/third_party/android_deps/libs/' \
'android_arch_core_common: version:0.9.0..version:1.0.0-cr0'
self.assertTrue(changed in commit_lines)
# Check it is in adequate section.
changed_line = commit_lines.index(changed)
self.assertTrue('Changed' in commit_lines[changed_line - 1])
def testAddedDepsInCommitMessage(self):
commit_lines = self._CommitMessageSetup()
def testAddedDepsInCommitMessage(self):
commit_lines = self._CommitMessageSetup()
added = '* src/third_party/android_deps/libs/' \
'android_arch_lifecycle_common'
self.assertTrue(added in commit_lines)
# Check it is in adequate section.
added_line = commit_lines.index(added)
self.assertTrue('Added' in commit_lines[added_line-1])
added = '* src/third_party/android_deps/libs/' \
'android_arch_lifecycle_common'
self.assertTrue(added in commit_lines)
# Check it is in adequate section.
added_line = commit_lines.index(added)
self.assertTrue('Added' in commit_lines[added_line - 1])
def testRemovedDepsInCommitMessage(self):
commit_lines = self._CommitMessageSetup()
def testRemovedDepsInCommitMessage(self):
commit_lines = self._CommitMessageSetup()
removed = '* src/third_party/android_deps/libs/' \
'android_arch_lifecycle_runtime'
self.assertTrue(removed in commit_lines)
# Check it is in adequate section.
removed_line = commit_lines.index(removed)
self.assertTrue('Removed' in commit_lines[removed_line-1])
removed = '* src/third_party/android_deps/libs/' \
'android_arch_lifecycle_runtime'
self.assertTrue(removed in commit_lines)
# Check it is in adequate section.
removed_line = commit_lines.index(removed)
self.assertTrue('Removed' in commit_lines[removed_line - 1])
class TestChooseCQMode(unittest.TestCase):
def testSkip(self):
self.assertEquals(ChooseCQMode(True, 99, 500000, 500100), 0)
def testSkip(self):
self.assertEquals(ChooseCQMode(True, 99, 500000, 500100), 0)
def testDryRun(self):
self.assertEquals(ChooseCQMode(False, 101, 500000, 500100), 1)
def testDryRun(self):
self.assertEquals(ChooseCQMode(False, 101, 500000, 500100), 1)
def testSubmit(self):
self.assertEquals(ChooseCQMode(False, 100, 500000, 500100), 2)
def testSubmit(self):
self.assertEquals(ChooseCQMode(False, 100, 500000, 500100), 2)
def _SetupGitLsRemoteCall(cmd_fake, url, revision):
cmd = ['git', 'ls-remote', url, revision]
cmd_fake.AddExpectation(cmd, _returns=(revision, None))
cmd = ['git', 'ls-remote', url, revision]
cmd_fake.AddExpectation(cmd, _returns=(revision, None))
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Invoke clang-tidy tool.
Usage: clang_tidy.py file.cc [clang-tidy-args...]
@ -25,7 +24,6 @@ import tempfile
from presubmit_checks_lib.build_helpers import GetClangTidyPath, \
GetCompilationCommand
# We enable all checkers by default for investigation purpose.
# This includes clang-analyzer-* checks.
# Individual checkers can be disabled via command line options.
@ -34,63 +32,66 @@ CHECKER_OPTION = '-checks=*'
def Process(filepath, args):
# Build directory is needed to gather compilation flags.
# Create a temporary one (instead of reusing an existing one)
# to keep the CLI simple and unencumbered.
out_dir = tempfile.mkdtemp('clang_tidy')
# Build directory is needed to gather compilation flags.
# Create a temporary one (instead of reusing an existing one)
# to keep the CLI simple and unencumbered.
out_dir = tempfile.mkdtemp('clang_tidy')
try:
gn_args = [] # Use default build.
command = GetCompilationCommand(filepath, gn_args, out_dir)
try:
gn_args = [] # Use default build.
command = GetCompilationCommand(filepath, gn_args, out_dir)
# Remove warning flags. They aren't needed and they cause trouble
# when clang-tidy doesn't match most recent clang.
# Same battle for -f (e.g. -fcomplete-member-pointers).
command = [arg for arg in command if not (arg.startswith('-W') or
arg.startswith('-f'))]
# Remove warning flags. They aren't needed and they cause trouble
# when clang-tidy doesn't match most recent clang.
# Same battle for -f (e.g. -fcomplete-member-pointers).
command = [
arg for arg in command
if not (arg.startswith('-W') or arg.startswith('-f'))
]
# Path from build dir.
rel_path = os.path.relpath(os.path.abspath(filepath), out_dir)
# Path from build dir.
rel_path = os.path.relpath(os.path.abspath(filepath), out_dir)
# Replace clang++ by clang-tidy
command[0:1] = [GetClangTidyPath(),
CHECKER_OPTION,
rel_path] + args + ['--'] # Separator for clang flags.
print "Running: %s" % ' '.join(command)
# Run from build dir so that relative paths are correct.
p = subprocess.Popen(command, cwd=out_dir,
stdout=sys.stdout, stderr=sys.stderr)
p.communicate()
return p.returncode
finally:
shutil.rmtree(out_dir, ignore_errors=True)
# Replace clang++ by clang-tidy
command[0:1] = [GetClangTidyPath(), CHECKER_OPTION, rel_path
] + args + ['--'] # Separator for clang flags.
print "Running: %s" % ' '.join(command)
# Run from build dir so that relative paths are correct.
p = subprocess.Popen(command,
cwd=out_dir,
stdout=sys.stdout,
stderr=sys.stderr)
p.communicate()
return p.returncode
finally:
shutil.rmtree(out_dir, ignore_errors=True)
def ValidateCC(filepath):
"""We can only analyze .cc files. Provide explicit message about that."""
if filepath.endswith('.cc'):
return filepath
msg = ('%s not supported.\n'
'For now, we can only analyze translation units (.cc files).' %
filepath)
raise argparse.ArgumentTypeError(msg)
"""We can only analyze .cc files. Provide explicit message about that."""
if filepath.endswith('.cc'):
return filepath
msg = ('%s not supported.\n'
'For now, we can only analyze translation units (.cc files).' %
filepath)
raise argparse.ArgumentTypeError(msg)
def Main():
description = (
"Run clang-tidy on single cc file.\n"
"Use flags, defines and include paths as in default debug build.\n"
"WARNING, this is a POC version with rough edges.")
parser = argparse.ArgumentParser(description=description)
parser.add_argument('filepath',
help='Specifies the path of the .cc file to analyze.',
type=ValidateCC)
parser.add_argument('args',
nargs=argparse.REMAINDER,
help='Arguments passed to clang-tidy')
parsed_args = parser.parse_args()
return Process(parsed_args.filepath, parsed_args.args)
description = (
"Run clang-tidy on single cc file.\n"
"Use flags, defines and include paths as in default debug build.\n"
"WARNING, this is a POC version with rough edges.")
parser = argparse.ArgumentParser(description=description)
parser.add_argument('filepath',
help='Specifies the path of the .cc file to analyze.',
type=ValidateCC)
parser.add_argument('args',
nargs=argparse.REMAINDER,
help='Arguments passed to clang-tidy')
parsed_args = parser.parse_args()
return Process(parsed_args.filepath, parsed_args.args)
if __name__ == '__main__':
sys.exit(Main())
sys.exit(Main())

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Generates a command-line for coverage.py. Useful for manual coverage runs.
Before running the generated command line, do this:
@ -17,39 +16,32 @@ gn gen out/coverage --args='use_clang_coverage=true is_component_build=false'
import sys
TESTS = [
'video_capture_tests',
'webrtc_nonparallel_tests',
'video_engine_tests',
'tools_unittests',
'test_support_unittests',
'slow_tests',
'system_wrappers_unittests',
'rtc_unittests',
'rtc_stats_unittests',
'rtc_pc_unittests',
'rtc_media_unittests',
'peerconnection_unittests',
'modules_unittests',
'modules_tests',
'low_bandwidth_audio_test',
'common_video_unittests',
'common_audio_unittests',
'audio_decoder_unittests'
'video_capture_tests', 'webrtc_nonparallel_tests', 'video_engine_tests',
'tools_unittests', 'test_support_unittests', 'slow_tests',
'system_wrappers_unittests', 'rtc_unittests', 'rtc_stats_unittests',
'rtc_pc_unittests', 'rtc_media_unittests', 'peerconnection_unittests',
'modules_unittests', 'modules_tests', 'low_bandwidth_audio_test',
'common_video_unittests', 'common_audio_unittests',
'audio_decoder_unittests'
]
def main():
cmd = ([sys.executable, 'tools/code_coverage/coverage.py'] + TESTS +
['-b out/coverage', '-o out/report'] +
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''] +
['-c \'out/coverage/%s\'' % t for t in TESTS])
cmd = ([sys.executable, 'tools/code_coverage/coverage.py'] + TESTS +
['-b out/coverage', '-o out/report'] +
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''] +
['-c \'out/coverage/%s\'' % t for t in TESTS])
def WithXvfb(binary):
return '-c \'%s testing/xvfb.py %s\'' % (sys.executable, binary)
modules_unittests = 'out/coverage/modules_unittests'
cmd[cmd.index('-c \'%s\'' % modules_unittests)] = WithXvfb(modules_unittests)
def WithXvfb(binary):
return '-c \'%s testing/xvfb.py %s\'' % (sys.executable, binary)
modules_unittests = 'out/coverage/modules_unittests'
cmd[cmd.index('-c \'%s\'' %
modules_unittests)] = WithXvfb(modules_unittests)
print ' '.join(cmd)
return 0
print ' '.join(cmd)
return 0
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Generates command-line instructions to produce one-time iOS coverage using
coverage.py.
@ -53,122 +52,115 @@ import sys
DIRECTORY = 'out/coverage'
TESTS = [
'audio_decoder_unittests',
'common_audio_unittests',
'common_video_unittests',
'modules_tests',
'modules_unittests',
'rtc_media_unittests',
'rtc_pc_unittests',
'rtc_stats_unittests',
'rtc_unittests',
'slow_tests',
'system_wrappers_unittests',
'test_support_unittests',
'tools_unittests',
'video_capture_tests',
'video_engine_tests',
'webrtc_nonparallel_tests',
'audio_decoder_unittests',
'common_audio_unittests',
'common_video_unittests',
'modules_tests',
'modules_unittests',
'rtc_media_unittests',
'rtc_pc_unittests',
'rtc_stats_unittests',
'rtc_unittests',
'slow_tests',
'system_wrappers_unittests',
'test_support_unittests',
'tools_unittests',
'video_capture_tests',
'video_engine_tests',
'webrtc_nonparallel_tests',
]
XC_TESTS = [
'apprtcmobile_tests',
'sdk_framework_unittests',
'sdk_unittests',
'apprtcmobile_tests',
'sdk_framework_unittests',
'sdk_unittests',
]
def FormatIossimTest(test_name, is_xctest=False):
args = ['%s/%s.app' % (DIRECTORY, test_name)]
if is_xctest:
args += ['%s/%s_module.xctest' % (DIRECTORY, test_name)]
args = ['%s/%s.app' % (DIRECTORY, test_name)]
if is_xctest:
args += ['%s/%s_module.xctest' % (DIRECTORY, test_name)]
return '-c \'%s/iossim %s\'' % (DIRECTORY, ' '.join(args))
return '-c \'%s/iossim %s\'' % (DIRECTORY, ' '.join(args))
def GetGNArgs(is_simulator):
target_cpu = 'x64' if is_simulator else 'arm64'
return ([] +
['target_os="ios"'] +
['target_cpu="%s"' % target_cpu] +
['use_clang_coverage=true'] +
['is_component_build=false'] +
['dcheck_always_on=true'])
target_cpu = 'x64' if is_simulator else 'arm64'
return ([] + ['target_os="ios"'] + ['target_cpu="%s"' % target_cpu] +
['use_clang_coverage=true'] + ['is_component_build=false'] +
['dcheck_always_on=true'])
def GenerateIOSSimulatorCommand():
gn_args_string = ' '.join(GetGNArgs(is_simulator=True))
gn_cmd = ['gn', 'gen', DIRECTORY, '--args=\'%s\'' % gn_args_string]
gn_args_string = ' '.join(GetGNArgs(is_simulator=True))
gn_cmd = ['gn', 'gen', DIRECTORY, '--args=\'%s\'' % gn_args_string]
coverage_cmd = (
[sys.executable, 'tools/code_coverage/coverage.py'] +
["%s.app" % t for t in XC_TESTS + TESTS] +
['-b %s' % DIRECTORY, '-o out/report'] +
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''] +
[FormatIossimTest(t, is_xctest=True) for t in XC_TESTS] +
[FormatIossimTest(t, is_xctest=False) for t in TESTS]
)
coverage_cmd = ([sys.executable, 'tools/code_coverage/coverage.py'] +
["%s.app" % t for t in XC_TESTS + TESTS] +
['-b %s' % DIRECTORY, '-o out/report'] +
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''] +
[FormatIossimTest(t, is_xctest=True) for t in XC_TESTS] +
[FormatIossimTest(t, is_xctest=False) for t in TESTS])
print 'To get code coverage using iOS simulator just run following commands:'
print ''
print ' '.join(gn_cmd)
print ''
print ' '.join(coverage_cmd)
return 0
print 'To get code coverage using iOS simulator just run following commands:'
print ''
print ' '.join(gn_cmd)
print ''
print ' '.join(coverage_cmd)
return 0
def GenerateIOSDeviceCommand():
gn_args_string = ' '.join(GetGNArgs(is_simulator=False))
gn_args_string = ' '.join(GetGNArgs(is_simulator=False))
coverage_report_cmd = (
[sys.executable, 'tools/code_coverage/coverage.py'] +
['%s.app' % t for t in TESTS] +
['-b %s' % DIRECTORY] +
['-o out/report'] +
['-p %s/merged.profdata' % DIRECTORY] +
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\'']
)
coverage_report_cmd = (
[sys.executable, 'tools/code_coverage/coverage.py'] +
['%s.app' % t for t in TESTS] + ['-b %s' % DIRECTORY] +
['-o out/report'] + ['-p %s/merged.profdata' % DIRECTORY] +
['-i=\'.*/out/.*|.*/third_party/.*|.*test.*\''])
print 'Computing code coverage for real iOS device is a little bit tedious.'
print ''
print 'You will need:'
print ''
print '1. Generate xcode project and open it with Xcode 10+:'
print ' gn gen %s --ide=xcode --args=\'%s\'' % (DIRECTORY, gn_args_string)
print ' open %s/all.xcworkspace' % DIRECTORY
print ''
print '2. Execute these Run targets manually with Xcode Run button and '
print 'manually save generated coverage.profraw file to %s:' % DIRECTORY
print '\n'.join('- %s' % t for t in TESTS)
print ''
print '3. Execute these Test targets manually with Xcode Test button and '
print 'manually save generated coverage.profraw file to %s:' % DIRECTORY
print '\n'.join('- %s' % t for t in XC_TESTS)
print ''
print '4. Merge *.profraw files to *.profdata using llvm-profdata tool:'
print (' build/mac_files/Xcode.app/Contents/Developer/Toolchains/' +
'XcodeDefault.xctoolchain/usr/bin/llvm-profdata merge ' +
'-o %s/merged.profdata ' % DIRECTORY +
'-sparse=true %s/*.profraw' % DIRECTORY)
print ''
print '5. Generate coverage report:'
print ' ' + ' '.join(coverage_report_cmd)
return 0
print 'Computing code coverage for real iOS device is a little bit tedious.'
print ''
print 'You will need:'
print ''
print '1. Generate xcode project and open it with Xcode 10+:'
print ' gn gen %s --ide=xcode --args=\'%s\'' % (DIRECTORY, gn_args_string)
print ' open %s/all.xcworkspace' % DIRECTORY
print ''
print '2. Execute these Run targets manually with Xcode Run button and '
print 'manually save generated coverage.profraw file to %s:' % DIRECTORY
print '\n'.join('- %s' % t for t in TESTS)
print ''
print '3. Execute these Test targets manually with Xcode Test button and '
print 'manually save generated coverage.profraw file to %s:' % DIRECTORY
print '\n'.join('- %s' % t for t in XC_TESTS)
print ''
print '4. Merge *.profraw files to *.profdata using llvm-profdata tool:'
print(' build/mac_files/Xcode.app/Contents/Developer/Toolchains/' +
'XcodeDefault.xctoolchain/usr/bin/llvm-profdata merge ' +
'-o %s/merged.profdata ' % DIRECTORY +
'-sparse=true %s/*.profraw' % DIRECTORY)
print ''
print '5. Generate coverage report:'
print ' ' + ' '.join(coverage_report_cmd)
return 0
def Main():
if len(sys.argv) < 2:
print 'Please specify type of coverage:'
print ' %s simulator' % sys.argv[0]
print ' %s device' % sys.argv[0]
elif sys.argv[1] == 'simulator':
GenerateIOSSimulatorCommand()
elif sys.argv[1] == 'device':
GenerateIOSDeviceCommand()
else:
print 'Unsupported type of coverage'
if len(sys.argv) < 2:
print 'Please specify type of coverage:'
print ' %s simulator' % sys.argv[0]
print ' %s device' % sys.argv[0]
elif sys.argv[1] == 'simulator':
GenerateIOSSimulatorCommand()
elif sys.argv[1] == 'device':
GenerateIOSDeviceCommand()
else:
print 'Unsupported type of coverage'
return 0
return 0
if __name__ == '__main__':
sys.exit(Main())
sys.exit(Main())

View File

@ -8,7 +8,6 @@
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
import psutil
import sys
@ -17,67 +16,68 @@ from matplotlib import pyplot
class CpuSnapshot(object):
def __init__(self, label):
self.label = label
self.samples = []
def __init__(self, label):
self.label = label
self.samples = []
def Capture(self, sample_count):
print ('Capturing %d CPU samples for %s...' %
((sample_count - len(self.samples)), self.label))
while len(self.samples) < sample_count:
self.samples.append(psutil.cpu_percent(1.0, False))
def Capture(self, sample_count):
print('Capturing %d CPU samples for %s...' %
((sample_count - len(self.samples)), self.label))
while len(self.samples) < sample_count:
self.samples.append(psutil.cpu_percent(1.0, False))
def Text(self):
return ('%s: avg=%s, median=%s, min=%s, max=%s' %
(self.label, numpy.average(self.samples),
numpy.median(self.samples),
numpy.min(self.samples), numpy.max(self.samples)))
def Text(self):
return ('%s: avg=%s, median=%s, min=%s, max=%s' %
(self.label, numpy.average(self.samples),
numpy.median(self.samples), numpy.min(
self.samples), numpy.max(self.samples)))
def Max(self):
return numpy.max(self.samples)
def Max(self):
return numpy.max(self.samples)
def GrabCpuSamples(sample_count):
print 'Label for snapshot (enter to quit): '
label = raw_input().strip()
if len(label) == 0:
return None
print 'Label for snapshot (enter to quit): '
label = raw_input().strip()
if len(label) == 0:
return None
snapshot = CpuSnapshot(label)
snapshot.Capture(sample_count)
snapshot = CpuSnapshot(label)
snapshot.Capture(sample_count)
return snapshot
return snapshot
def main():
print 'How many seconds to capture per snapshot (enter for 60)?'
sample_count = raw_input().strip()
if len(sample_count) > 0 and int(sample_count) > 0:
sample_count = int(sample_count)
else:
print 'Defaulting to 60 samples.'
sample_count = 60
print 'How many seconds to capture per snapshot (enter for 60)?'
sample_count = raw_input().strip()
if len(sample_count) > 0 and int(sample_count) > 0:
sample_count = int(sample_count)
else:
print 'Defaulting to 60 samples.'
sample_count = 60
snapshots = []
while True:
snapshot = GrabCpuSamples(sample_count)
if snapshot is None:
break
snapshots.append(snapshot)
snapshots = []
while True:
snapshot = GrabCpuSamples(sample_count)
if snapshot is None:
break
snapshots.append(snapshot)
if len(snapshots) == 0:
print 'no samples captured'
return -1
if len(snapshots) == 0:
print 'no samples captured'
return -1
pyplot.title('CPU usage')
pyplot.title('CPU usage')
for s in snapshots:
pyplot.plot(s.samples, label=s.Text(), linewidth=2)
for s in snapshots:
pyplot.plot(s.samples, label=s.Text(), linewidth=2)
pyplot.legend()
pyplot.legend()
pyplot.show()
return 0
pyplot.show()
return 0
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Downloads precompiled tools.
These are checked into the repository as SHA-1 hashes (see *.sha1 files in
@ -17,12 +16,10 @@ so please download and compile these tools manually if this script fails.
import os
import sys
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir))
sys.path.append(os.path.join(SRC_DIR, 'build'))
import find_depot_tools
find_depot_tools.add_depot_tools_to_path()
import gclient_utils
@ -30,32 +27,34 @@ import subprocess2
def main(directories):
if not directories:
directories = [SCRIPT_DIR]
if not directories:
directories = [SCRIPT_DIR]
for path in directories:
cmd = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH,
'download_from_google_storage.py'),
'--directory',
'--num_threads=10',
'--bucket', 'chrome-webrtc-resources',
'--auto_platform',
'--recursive',
path,
]
print 'Downloading precompiled tools...'
for path in directories:
cmd = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH,
'download_from_google_storage.py'),
'--directory',
'--num_threads=10',
'--bucket',
'chrome-webrtc-resources',
'--auto_platform',
'--recursive',
path,
]
print 'Downloading precompiled tools...'
# Perform download similar to how gclient hooks execute.
try:
gclient_utils.CheckCallAndFilter(
cmd, cwd=SRC_DIR, always_show_header=True)
except (gclient_utils.Error, subprocess2.CalledProcessError) as e:
print 'Error: %s' % str(e)
return 2
return 0
# Perform download similar to how gclient hooks execute.
try:
gclient_utils.CheckCallAndFilter(cmd,
cwd=SRC_DIR,
always_show_header=True)
except (gclient_utils.Error, subprocess2.CalledProcessError) as e:
print 'Error: %s' % str(e)
return 2
return 0
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))
sys.exit(main(sys.argv[1:]))

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Checks if a virtual webcam is running and starts it if not.
Returns a non-zero return code if the webcam could not be started.
@ -32,74 +31,73 @@ import psutil # pylint: disable=F0401
import subprocess
import sys
WEBCAM_WIN = ('schtasks', '/run', '/tn', 'ManyCam')
WEBCAM_MAC = ('open', '/Applications/ManyCam/ManyCam.app')
def IsWebCamRunning():
if sys.platform == 'win32':
process_name = 'ManyCam.exe'
elif sys.platform.startswith('darwin'):
process_name = 'ManyCam'
elif sys.platform.startswith('linux'):
# TODO(bugs.webrtc.org/9636): Currently a no-op on Linux: sw webcams no
# longer in use.
print 'Virtual webcam: no-op on Linux'
return True
else:
raise Exception('Unsupported platform: %s' % sys.platform)
for p in psutil.process_iter():
try:
if process_name == p.name:
print 'Found a running virtual webcam (%s with PID %s)' % (p.name,
p.pid)
if sys.platform == 'win32':
process_name = 'ManyCam.exe'
elif sys.platform.startswith('darwin'):
process_name = 'ManyCam'
elif sys.platform.startswith('linux'):
# TODO(bugs.webrtc.org/9636): Currently a no-op on Linux: sw webcams no
# longer in use.
print 'Virtual webcam: no-op on Linux'
return True
except psutil.AccessDenied:
pass # This is normal if we query sys processes, etc.
return False
else:
raise Exception('Unsupported platform: %s' % sys.platform)
for p in psutil.process_iter():
try:
if process_name == p.name:
print 'Found a running virtual webcam (%s with PID %s)' % (
p.name, p.pid)
return True
except psutil.AccessDenied:
pass # This is normal if we query sys processes, etc.
return False
def StartWebCam():
try:
if sys.platform == 'win32':
subprocess.check_call(WEBCAM_WIN)
print 'Successfully launched virtual webcam.'
elif sys.platform.startswith('darwin'):
subprocess.check_call(WEBCAM_MAC)
print 'Successfully launched virtual webcam.'
elif sys.platform.startswith('linux'):
# TODO(bugs.webrtc.org/9636): Currently a no-op on Linux: sw webcams no
# longer in use.
print 'Not implemented on Linux'
try:
if sys.platform == 'win32':
subprocess.check_call(WEBCAM_WIN)
print 'Successfully launched virtual webcam.'
elif sys.platform.startswith('darwin'):
subprocess.check_call(WEBCAM_MAC)
print 'Successfully launched virtual webcam.'
elif sys.platform.startswith('linux'):
# TODO(bugs.webrtc.org/9636): Currently a no-op on Linux: sw webcams no
# longer in use.
print 'Not implemented on Linux'
except Exception as e:
print 'Failed to launch virtual webcam: %s' % e
return False
except Exception as e:
print 'Failed to launch virtual webcam: %s' % e
return False
return True
return True
def _ForcePythonInterpreter(cmd):
"""Returns the fixed command line to call the right python executable."""
out = cmd[:]
if out[0] == 'python':
out[0] = sys.executable
elif out[0].endswith('.py'):
out.insert(0, sys.executable)
return out
"""Returns the fixed command line to call the right python executable."""
out = cmd[:]
if out[0] == 'python':
out[0] = sys.executable
elif out[0].endswith('.py'):
out.insert(0, sys.executable)
return out
def Main(argv):
if not IsWebCamRunning():
if not StartWebCam():
return 1
if not IsWebCamRunning():
if not StartWebCam():
return 1
if argv:
return subprocess.call(_ForcePythonInterpreter(argv))
else:
return 0
if argv:
return subprocess.call(_ForcePythonInterpreter(argv))
else:
return 0
if __name__ == '__main__':
sys.exit(Main(sys.argv[1:]))
sys.exit(Main(sys.argv[1:]))

View File

@ -55,7 +55,6 @@ import subprocess
import sys
import tempfile
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
SRC_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, os.pardir))
sys.path.append(os.path.join(SRC_DIR, 'build'))
@ -63,39 +62,40 @@ import find_depot_tools
def _ParseArgs():
desc = 'Generates a GN executable targeting the host machine.'
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--executable_name',
required=True,
help='Name of the executable to build')
args = parser.parse_args()
return args
desc = 'Generates a GN executable targeting the host machine.'
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--executable_name',
required=True,
help='Name of the executable to build')
args = parser.parse_args()
return args
@contextmanager
def HostBuildDir():
temp_dir = tempfile.mkdtemp()
try:
yield temp_dir
finally:
shutil.rmtree(temp_dir)
temp_dir = tempfile.mkdtemp()
try:
yield temp_dir
finally:
shutil.rmtree(temp_dir)
def _RunCommand(argv, cwd=SRC_DIR, **kwargs):
with open(os.devnull, 'w') as devnull:
subprocess.check_call(argv, cwd=cwd, stdout=devnull, **kwargs)
with open(os.devnull, 'w') as devnull:
subprocess.check_call(argv, cwd=cwd, stdout=devnull, **kwargs)
def DepotToolPath(*args):
return os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, *args)
return os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, *args)
if __name__ == '__main__':
ARGS = _ParseArgs()
EXECUTABLE_TO_BUILD = ARGS.executable_name
EXECUTABLE_FINAL_NAME = ARGS.executable_name + '_host'
with HostBuildDir() as build_dir:
_RunCommand([sys.executable, DepotToolPath('gn.py'), 'gen', build_dir])
_RunCommand([DepotToolPath('ninja'), '-C', build_dir, EXECUTABLE_TO_BUILD])
shutil.copy(os.path.join(build_dir, EXECUTABLE_TO_BUILD),
EXECUTABLE_FINAL_NAME)
ARGS = _ParseArgs()
EXECUTABLE_TO_BUILD = ARGS.executable_name
EXECUTABLE_FINAL_NAME = ARGS.executable_name + '_host'
with HostBuildDir() as build_dir:
_RunCommand([sys.executable, DepotToolPath('gn.py'), 'gen', build_dir])
_RunCommand(
[DepotToolPath('ninja'), '-C', build_dir, EXECUTABLE_TO_BUILD])
shutil.copy(os.path.join(build_dir, EXECUTABLE_TO_BUILD),
EXECUTABLE_FINAL_NAME)

View File

@ -15,30 +15,32 @@ import sys
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--isolated-script-test-perf-output')
args, unrecognized_args = parser.parse_known_args()
parser = argparse.ArgumentParser()
parser.add_argument('--isolated-script-test-perf-output')
args, unrecognized_args = parser.parse_known_args()
test_command = _ForcePythonInterpreter(unrecognized_args)
if args.isolated_script_test_perf_output:
test_command += ['--isolated_script_test_perf_output=' +
args.isolated_script_test_perf_output]
logging.info('Running %r', test_command)
test_command = _ForcePythonInterpreter(unrecognized_args)
if args.isolated_script_test_perf_output:
test_command += [
'--isolated_script_test_perf_output=' +
args.isolated_script_test_perf_output
]
logging.info('Running %r', test_command)
return subprocess.call(test_command)
return subprocess.call(test_command)
def _ForcePythonInterpreter(cmd):
"""Returns the fixed command line to call the right python executable."""
out = cmd[:]
if out[0] == 'python':
out[0] = sys.executable
elif out[0].endswith('.py'):
out.insert(0, sys.executable)
return out
"""Returns the fixed command line to call the right python executable."""
out = cmd[:]
if out[0] == 'python':
out[0] = sys.executable
elif out[0].endswith('.py'):
out.insert(0, sys.executable)
return out
if __name__ == '__main__':
# pylint: disable=W0101
logging.basicConfig(level=logging.INFO)
sys.exit(main())
# pylint: disable=W0101
logging.basicConfig(level=logging.INFO)
sys.exit(main())

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""
This file emits the list of reasons why a particular build needs to be clobbered
(or a list of 'landmines').
@ -20,49 +19,48 @@ CHECKOUT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir))
sys.path.insert(0, os.path.join(CHECKOUT_ROOT, 'build'))
import landmine_utils
host_os = landmine_utils.host_os # pylint: disable=invalid-name
def print_landmines(): # pylint: disable=invalid-name
"""
"""
ALL LANDMINES ARE EMITTED FROM HERE.
"""
# DO NOT add landmines as part of a regular CL. Landmines are a last-effort
# bandaid fix if a CL that got landed has a build dependency bug and all bots
# need to be cleaned up. If you're writing a new CL that causes build
# dependency problems, fix the dependency problems instead of adding a
# landmine.
# See the Chromium version in src/build/get_landmines.py for usage examples.
print 'Clobber to remove out/{Debug,Release}/args.gn (webrtc:5070)'
if host_os() == 'win':
print 'Clobber to resolve some issues with corrupt .pdb files on bots.'
print 'Clobber due to corrupt .pdb files (after #14623)'
print 'Clobber due to Win 64-bit Debug linking error (crbug.com/668961)'
print ('Clobber due to Win Clang Debug linking errors in '
'https://codereview.webrtc.org/2786603002')
print ('Clobber due to Win Debug linking errors in '
'https://codereview.webrtc.org/2832063003/')
print 'Clobber win x86 bots (issues with isolated files).'
if host_os() == 'mac':
print 'Clobber due to iOS compile errors (crbug.com/694721)'
print 'Clobber to unblock https://codereview.webrtc.org/2709573003'
print ('Clobber to fix https://codereview.webrtc.org/2709573003 after '
'landing')
print ('Clobber to fix https://codereview.webrtc.org/2767383005 before'
'landing (changing rtc_executable -> rtc_test on iOS)')
print ('Clobber to fix https://codereview.webrtc.org/2767383005 before'
'landing (changing rtc_executable -> rtc_test on iOS)')
print 'Another landmine for low_bandwidth_audio_test (webrtc:7430)'
print 'Clobber to change neteq_rtpplay type to executable'
print 'Clobber to remove .xctest files.'
print 'Clobber to remove .xctest files (take 2).'
# DO NOT add landmines as part of a regular CL. Landmines are a last-effort
# bandaid fix if a CL that got landed has a build dependency bug and all bots
# need to be cleaned up. If you're writing a new CL that causes build
# dependency problems, fix the dependency problems instead of adding a
# landmine.
# See the Chromium version in src/build/get_landmines.py for usage examples.
print 'Clobber to remove out/{Debug,Release}/args.gn (webrtc:5070)'
if host_os() == 'win':
print 'Clobber to resolve some issues with corrupt .pdb files on bots.'
print 'Clobber due to corrupt .pdb files (after #14623)'
print 'Clobber due to Win 64-bit Debug linking error (crbug.com/668961)'
print('Clobber due to Win Clang Debug linking errors in '
'https://codereview.webrtc.org/2786603002')
print('Clobber due to Win Debug linking errors in '
'https://codereview.webrtc.org/2832063003/')
print 'Clobber win x86 bots (issues with isolated files).'
if host_os() == 'mac':
print 'Clobber due to iOS compile errors (crbug.com/694721)'
print 'Clobber to unblock https://codereview.webrtc.org/2709573003'
print('Clobber to fix https://codereview.webrtc.org/2709573003 after '
'landing')
print('Clobber to fix https://codereview.webrtc.org/2767383005 before'
'landing (changing rtc_executable -> rtc_test on iOS)')
print('Clobber to fix https://codereview.webrtc.org/2767383005 before'
'landing (changing rtc_executable -> rtc_test on iOS)')
print 'Another landmine for low_bandwidth_audio_test (webrtc:7430)'
print 'Clobber to change neteq_rtpplay type to executable'
print 'Clobber to remove .xctest files.'
print 'Clobber to remove .xctest files (take 2).'
def main():
print_landmines()
return 0
print_landmines()
return 0
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -7,7 +7,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""
This tool tries to fix (some) errors reported by `gn gen --check` or
`gn check`.
@ -31,72 +30,78 @@ from collections import defaultdict
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
CHROMIUM_DIRS = ['base', 'build', 'buildtools',
'testing', 'third_party', 'tools']
CHROMIUM_DIRS = [
'base', 'build', 'buildtools', 'testing', 'third_party', 'tools'
]
TARGET_RE = re.compile(
r'(?P<indentation_level>\s*)\w*\("(?P<target_name>\w*)"\) {$')
class TemporaryDirectory(object):
def __init__(self):
self._closed = False
self._name = None
self._name = tempfile.mkdtemp()
def __init__(self):
self._closed = False
self._name = None
self._name = tempfile.mkdtemp()
def __enter__(self):
return self._name
def __enter__(self):
return self._name
def __exit__(self, exc, value, _tb):
if self._name and not self._closed:
shutil.rmtree(self._name)
self._closed = True
def __exit__(self, exc, value, _tb):
if self._name and not self._closed:
shutil.rmtree(self._name)
self._closed = True
def Run(cmd):
print 'Running:', ' '.join(cmd)
sub = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return sub.communicate()
print 'Running:', ' '.join(cmd)
sub = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return sub.communicate()
def FixErrors(filename, missing_deps, deleted_sources):
with open(filename) as f:
lines = f.readlines()
with open(filename) as f:
lines = f.readlines()
fixed_file = ''
indentation_level = None
for line in lines:
match = TARGET_RE.match(line)
if match:
target = match.group('target_name')
if target in missing_deps:
indentation_level = match.group('indentation_level')
elif indentation_level is not None:
match = re.match(indentation_level + '}$', line)
if match:
line = ('deps = [\n' +
''.join(' "' + dep + '",\n' for dep in missing_deps[target]) +
']\n') + line
indentation_level = None
elif line.strip().startswith('deps'):
is_empty_deps = line.strip() == 'deps = []'
line = 'deps = [\n' if is_empty_deps else line
line += ''.join(' "' + dep + '",\n' for dep in missing_deps[target])
line += ']\n' if is_empty_deps else ''
indentation_level = None
fixed_file = ''
indentation_level = None
for line in lines:
match = TARGET_RE.match(line)
if match:
target = match.group('target_name')
if target in missing_deps:
indentation_level = match.group('indentation_level')
elif indentation_level is not None:
match = re.match(indentation_level + '}$', line)
if match:
line = ('deps = [\n' + ''.join(' "' + dep + '",\n'
for dep in missing_deps[target])
+ ']\n') + line
indentation_level = None
elif line.strip().startswith('deps'):
is_empty_deps = line.strip() == 'deps = []'
line = 'deps = [\n' if is_empty_deps else line
line += ''.join(' "' + dep + '",\n'
for dep in missing_deps[target])
line += ']\n' if is_empty_deps else ''
indentation_level = None
if line.strip() not in deleted_sources:
fixed_file += line
if line.strip() not in deleted_sources:
fixed_file += line
with open(filename, 'w') as f:
f.write(fixed_file)
with open(filename, 'w') as f:
f.write(fixed_file)
Run(['gn', 'format', filename])
Run(['gn', 'format', filename])
def FirstNonEmpty(iterable):
"""Return first item which evaluates to True, or fallback to None."""
return next((x for x in iterable if x), None)
"""Return first item which evaluates to True, or fallback to None."""
return next((x for x in iterable if x), None)
def Rebase(base_path, dependency_path, dependency):
"""Adapt paths so they work both in stand-alone WebRTC and Chromium tree.
"""Adapt paths so they work both in stand-alone WebRTC and Chromium tree.
To cope with varying top-level directory (WebRTC VS Chromium), we use:
* relative paths for WebRTC modules.
@ -113,77 +118,82 @@ def Rebase(base_path, dependency_path, dependency):
Full target path (E.g. '../rtc_base/time:timestamp_extrapolator').
"""
root = FirstNonEmpty(dependency_path.split('/'))
if root in CHROMIUM_DIRS:
# Chromium paths must remain absolute. E.g. //third_party//abseil-cpp...
rebased = dependency_path
else:
base_path = base_path.split(os.path.sep)
dependency_path = dependency_path.split(os.path.sep)
root = FirstNonEmpty(dependency_path.split('/'))
if root in CHROMIUM_DIRS:
# Chromium paths must remain absolute. E.g. //third_party//abseil-cpp...
rebased = dependency_path
else:
base_path = base_path.split(os.path.sep)
dependency_path = dependency_path.split(os.path.sep)
first_difference = None
shortest_length = min(len(dependency_path), len(base_path))
for i in range(shortest_length):
if dependency_path[i] != base_path[i]:
first_difference = i
break
first_difference = None
shortest_length = min(len(dependency_path), len(base_path))
for i in range(shortest_length):
if dependency_path[i] != base_path[i]:
first_difference = i
break
first_difference = first_difference or shortest_length
base_path = base_path[first_difference:]
dependency_path = dependency_path[first_difference:]
rebased = os.path.sep.join((['..'] * len(base_path)) + dependency_path)
return rebased + ':' + dependency
first_difference = first_difference or shortest_length
base_path = base_path[first_difference:]
dependency_path = dependency_path[first_difference:]
rebased = os.path.sep.join((['..'] * len(base_path)) + dependency_path)
return rebased + ':' + dependency
def main():
deleted_sources = set()
errors_by_file = defaultdict(lambda: defaultdict(set))
deleted_sources = set()
errors_by_file = defaultdict(lambda: defaultdict(set))
with TemporaryDirectory() as tmp_dir:
mb_script_path = os.path.join(SCRIPT_DIR, 'mb', 'mb.py')
mb_config_file_path = os.path.join(SCRIPT_DIR, 'mb', 'mb_config.pyl')
mb_gen_command = ([
mb_script_path, 'gen',
tmp_dir,
'--config-file', mb_config_file_path,
] + sys.argv[1:])
with TemporaryDirectory() as tmp_dir:
mb_script_path = os.path.join(SCRIPT_DIR, 'mb', 'mb.py')
mb_config_file_path = os.path.join(SCRIPT_DIR, 'mb', 'mb_config.pyl')
mb_gen_command = ([
mb_script_path,
'gen',
tmp_dir,
'--config-file',
mb_config_file_path,
] + sys.argv[1:])
mb_output = Run(mb_gen_command)
errors = mb_output[0].split('ERROR')[1:]
mb_output = Run(mb_gen_command)
errors = mb_output[0].split('ERROR')[1:]
if mb_output[1]:
print mb_output[1]
return 1
if mb_output[1]:
print mb_output[1]
return 1
for error in errors:
error = error.splitlines()
target_msg = 'The target:'
if target_msg not in error:
target_msg = 'It is not in any dependency of'
if target_msg not in error:
print '\n'.join(error)
continue
index = error.index(target_msg) + 1
path, target = error[index].strip().split(':')
if error[index+1] in ('is including a file from the target:',
'The include file is in the target(s):'):
dep = error[index+2].strip()
dep_path, dep = dep.split(':')
dep = Rebase(path, dep_path, dep)
# Replacing /target:target with /target
dep = re.sub(r'/(\w+):(\1)$', r'/\1', dep)
path = os.path.join(path[2:], 'BUILD.gn')
errors_by_file[path][target].add(dep)
elif error[index+1] == 'has a source file:':
deleted_file = '"' + os.path.basename(error[index+2].strip()) + '",'
deleted_sources.add(deleted_file)
else:
print '\n'.join(error)
continue
for error in errors:
error = error.splitlines()
target_msg = 'The target:'
if target_msg not in error:
target_msg = 'It is not in any dependency of'
if target_msg not in error:
print '\n'.join(error)
continue
index = error.index(target_msg) + 1
path, target = error[index].strip().split(':')
if error[index + 1] in ('is including a file from the target:',
'The include file is in the target(s):'):
dep = error[index + 2].strip()
dep_path, dep = dep.split(':')
dep = Rebase(path, dep_path, dep)
# Replacing /target:target with /target
dep = re.sub(r'/(\w+):(\1)$', r'/\1', dep)
path = os.path.join(path[2:], 'BUILD.gn')
errors_by_file[path][target].add(dep)
elif error[index + 1] == 'has a source file:':
deleted_file = '"' + os.path.basename(
error[index + 2].strip()) + '",'
deleted_sources.add(deleted_file)
else:
print '\n'.join(error)
continue
for path, missing_deps in errors_by_file.items():
FixErrors(path, missing_deps, deleted_sources)
for path, missing_deps in errors_by_file.items():
FixErrors(path, missing_deps, deleted_sources)
return 0
return 0
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -75,165 +75,174 @@ import shutil
import subprocess
import sys
Args = collections.namedtuple('Args',
['gtest_parallel_args', 'test_env', 'output_dir',
'test_artifacts_dir'])
Args = collections.namedtuple(
'Args',
['gtest_parallel_args', 'test_env', 'output_dir', 'test_artifacts_dir'])
def _CatFiles(file_list, output_file):
with open(output_file, 'w') as output_file:
for filename in file_list:
with open(filename) as input_file:
output_file.write(input_file.read())
os.remove(filename)
with open(output_file, 'w') as output_file:
for filename in file_list:
with open(filename) as input_file:
output_file.write(input_file.read())
os.remove(filename)
def _ParseWorkersOption(workers):
"""Interpret Nx syntax as N * cpu_count. Int value is left as is."""
base = float(workers.rstrip('x'))
if workers.endswith('x'):
result = int(base * multiprocessing.cpu_count())
else:
result = int(base)
return max(result, 1) # Sanitize when using e.g. '0.5x'.
"""Interpret Nx syntax as N * cpu_count. Int value is left as is."""
base = float(workers.rstrip('x'))
if workers.endswith('x'):
result = int(base * multiprocessing.cpu_count())
else:
result = int(base)
return max(result, 1) # Sanitize when using e.g. '0.5x'.
class ReconstructibleArgumentGroup(object):
"""An argument group that can be converted back into a command line.
"""An argument group that can be converted back into a command line.
This acts like ArgumentParser.add_argument_group, but names of arguments added
to it are also kept in a list, so that parsed options from
ArgumentParser.parse_args can be reconstructed back into a command line (list
of args) based on the list of wanted keys."""
def __init__(self, parser, *args, **kwargs):
self._group = parser.add_argument_group(*args, **kwargs)
self._keys = []
def AddArgument(self, *args, **kwargs):
arg = self._group.add_argument(*args, **kwargs)
self._keys.append(arg.dest)
def __init__(self, parser, *args, **kwargs):
self._group = parser.add_argument_group(*args, **kwargs)
self._keys = []
def RemakeCommandLine(self, options):
result = []
for key in self._keys:
value = getattr(options, key)
if value is True:
result.append('--%s' % key)
elif value is not None:
result.append('--%s=%s' % (key, value))
return result
def AddArgument(self, *args, **kwargs):
arg = self._group.add_argument(*args, **kwargs)
self._keys.append(arg.dest)
def RemakeCommandLine(self, options):
result = []
for key in self._keys:
value = getattr(options, key)
if value is True:
result.append('--%s' % key)
elif value is not None:
result.append('--%s=%s' % (key, value))
return result
def ParseArgs(argv=None):
parser = argparse.ArgumentParser(argv)
parser = argparse.ArgumentParser(argv)
gtest_group = ReconstructibleArgumentGroup(parser,
'Arguments to gtest-parallel')
# These options will be passed unchanged to gtest-parallel.
gtest_group.AddArgument('-d', '--output_dir')
gtest_group.AddArgument('-r', '--repeat')
gtest_group.AddArgument('--retry_failed')
gtest_group.AddArgument('--gtest_color')
gtest_group.AddArgument('--gtest_filter')
gtest_group.AddArgument('--gtest_also_run_disabled_tests',
action='store_true', default=None)
gtest_group.AddArgument('--timeout')
gtest_group = ReconstructibleArgumentGroup(parser,
'Arguments to gtest-parallel')
# These options will be passed unchanged to gtest-parallel.
gtest_group.AddArgument('-d', '--output_dir')
gtest_group.AddArgument('-r', '--repeat')
gtest_group.AddArgument('--retry_failed')
gtest_group.AddArgument('--gtest_color')
gtest_group.AddArgument('--gtest_filter')
gtest_group.AddArgument('--gtest_also_run_disabled_tests',
action='store_true',
default=None)
gtest_group.AddArgument('--timeout')
# Syntax 'Nx' will be interpreted as N * number of cpu cores.
gtest_group.AddArgument('-w', '--workers', type=_ParseWorkersOption)
# Syntax 'Nx' will be interpreted as N * number of cpu cores.
gtest_group.AddArgument('-w', '--workers', type=_ParseWorkersOption)
# Needed when the test wants to store test artifacts, because it doesn't know
# what will be the swarming output dir.
parser.add_argument('--store-test-artifacts', action='store_true')
# Needed when the test wants to store test artifacts, because it doesn't know
# what will be the swarming output dir.
parser.add_argument('--store-test-artifacts', action='store_true')
# No-sandbox is a Chromium-specific flag, ignore it.
# TODO(oprypin): Remove (bugs.webrtc.org/8115)
parser.add_argument('--no-sandbox', action='store_true',
help=argparse.SUPPRESS)
# No-sandbox is a Chromium-specific flag, ignore it.
# TODO(oprypin): Remove (bugs.webrtc.org/8115)
parser.add_argument('--no-sandbox',
action='store_true',
help=argparse.SUPPRESS)
parser.add_argument('executable')
parser.add_argument('executable_args', nargs='*')
parser.add_argument('executable')
parser.add_argument('executable_args', nargs='*')
options, unrecognized_args = parser.parse_known_args(argv)
options, unrecognized_args = parser.parse_known_args(argv)
args_to_pass = []
for arg in unrecognized_args:
if arg.startswith('--isolated-script-test-perf-output'):
arg_split = arg.split('=')
assert len(arg_split) == 2, 'You must use the = syntax for this flag.'
args_to_pass.append('--isolated_script_test_perf_output=' + arg_split[1])
args_to_pass = []
for arg in unrecognized_args:
if arg.startswith('--isolated-script-test-perf-output'):
arg_split = arg.split('=')
assert len(
arg_split) == 2, 'You must use the = syntax for this flag.'
args_to_pass.append('--isolated_script_test_perf_output=' +
arg_split[1])
else:
args_to_pass.append(arg)
executable_args = options.executable_args + args_to_pass
if options.store_test_artifacts:
assert options.output_dir, (
'--output_dir must be specified for storing test artifacts.')
test_artifacts_dir = os.path.join(options.output_dir, 'test_artifacts')
executable_args.insert(0,
'--test_artifacts_dir=%s' % test_artifacts_dir)
else:
args_to_pass.append(arg)
test_artifacts_dir = None
executable_args = options.executable_args + args_to_pass
gtest_parallel_args = gtest_group.RemakeCommandLine(options)
if options.store_test_artifacts:
assert options.output_dir, (
'--output_dir must be specified for storing test artifacts.')
test_artifacts_dir = os.path.join(options.output_dir, 'test_artifacts')
# GTEST_SHARD_INDEX and GTEST_TOTAL_SHARDS must be removed from the
# environment. Otherwise it will be picked up by the binary, causing a bug
# where only tests in the first shard are executed.
test_env = os.environ.copy()
gtest_shard_index = test_env.pop('GTEST_SHARD_INDEX', '0')
gtest_total_shards = test_env.pop('GTEST_TOTAL_SHARDS', '1')
executable_args.insert(0, '--test_artifacts_dir=%s' % test_artifacts_dir)
else:
test_artifacts_dir = None
gtest_parallel_args.insert(0, '--shard_index=%s' % gtest_shard_index)
gtest_parallel_args.insert(1, '--shard_count=%s' % gtest_total_shards)
gtest_parallel_args = gtest_group.RemakeCommandLine(options)
gtest_parallel_args.append(options.executable)
if executable_args:
gtest_parallel_args += ['--'] + executable_args
# GTEST_SHARD_INDEX and GTEST_TOTAL_SHARDS must be removed from the
# environment. Otherwise it will be picked up by the binary, causing a bug
# where only tests in the first shard are executed.
test_env = os.environ.copy()
gtest_shard_index = test_env.pop('GTEST_SHARD_INDEX', '0')
gtest_total_shards = test_env.pop('GTEST_TOTAL_SHARDS', '1')
gtest_parallel_args.insert(0, '--shard_index=%s' % gtest_shard_index)
gtest_parallel_args.insert(1, '--shard_count=%s' % gtest_total_shards)
gtest_parallel_args.append(options.executable)
if executable_args:
gtest_parallel_args += ['--'] + executable_args
return Args(gtest_parallel_args, test_env, options.output_dir,
test_artifacts_dir)
return Args(gtest_parallel_args, test_env, options.output_dir,
test_artifacts_dir)
def main():
webrtc_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
gtest_parallel_path = os.path.join(
webrtc_root, 'third_party', 'gtest-parallel', 'gtest-parallel')
webrtc_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
gtest_parallel_path = os.path.join(webrtc_root, 'third_party',
'gtest-parallel', 'gtest-parallel')
gtest_parallel_args, test_env, output_dir, test_artifacts_dir = ParseArgs()
gtest_parallel_args, test_env, output_dir, test_artifacts_dir = ParseArgs()
command = [
sys.executable,
gtest_parallel_path,
] + gtest_parallel_args
command = [
sys.executable,
gtest_parallel_path,
] + gtest_parallel_args
if output_dir and not os.path.isdir(output_dir):
os.makedirs(output_dir)
if test_artifacts_dir and not os.path.isdir(test_artifacts_dir):
os.makedirs(test_artifacts_dir)
if output_dir and not os.path.isdir(output_dir):
os.makedirs(output_dir)
if test_artifacts_dir and not os.path.isdir(test_artifacts_dir):
os.makedirs(test_artifacts_dir)
print 'gtest-parallel-wrapper: Executing command %s' % ' '.join(command)
sys.stdout.flush()
print 'gtest-parallel-wrapper: Executing command %s' % ' '.join(command)
sys.stdout.flush()
exit_code = subprocess.call(command, env=test_env, cwd=os.getcwd())
exit_code = subprocess.call(command, env=test_env, cwd=os.getcwd())
if output_dir:
for test_status in 'passed', 'failed', 'interrupted':
logs_dir = os.path.join(output_dir, 'gtest-parallel-logs', test_status)
if not os.path.isdir(logs_dir):
continue
logs = [os.path.join(logs_dir, log) for log in os.listdir(logs_dir)]
log_file = os.path.join(output_dir, '%s-tests.log' % test_status)
_CatFiles(logs, log_file)
os.rmdir(logs_dir)
if output_dir:
for test_status in 'passed', 'failed', 'interrupted':
logs_dir = os.path.join(output_dir, 'gtest-parallel-logs',
test_status)
if not os.path.isdir(logs_dir):
continue
logs = [
os.path.join(logs_dir, log) for log in os.listdir(logs_dir)
]
log_file = os.path.join(output_dir, '%s-tests.log' % test_status)
_CatFiles(logs, log_file)
os.rmdir(logs_dir)
if test_artifacts_dir:
shutil.make_archive(test_artifacts_dir, 'zip', test_artifacts_dir)
shutil.rmtree(test_artifacts_dir)
if test_artifacts_dir:
shutil.make_archive(test_artifacts_dir, 'zip', test_artifacts_dir)
shutil.rmtree(test_artifacts_dir)
return exit_code
return exit_code
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -21,149 +21,152 @@ gtest_parallel_wrapper = __import__('gtest-parallel-wrapper')
@contextmanager
def TemporaryDirectory():
tmp_dir = tempfile.mkdtemp()
yield tmp_dir
os.rmdir(tmp_dir)
tmp_dir = tempfile.mkdtemp()
yield tmp_dir
os.rmdir(tmp_dir)
class GtestParallelWrapperHelpersTest(unittest.TestCase):
def testGetWorkersAsIs(self):
# pylint: disable=protected-access
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('12'), 12)
def testGetWorkersAsIs(self):
# pylint: disable=protected-access
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('12'), 12)
def testGetTwiceWorkers(self):
expected = 2 * multiprocessing.cpu_count()
# pylint: disable=protected-access
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('2x'),
expected)
def testGetTwiceWorkers(self):
expected = 2 * multiprocessing.cpu_count()
# pylint: disable=protected-access
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('2x'), expected)
def testGetHalfWorkers(self):
expected = max(multiprocessing.cpu_count() // 2, 1)
# pylint: disable=protected-access
self.assertEqual(
gtest_parallel_wrapper._ParseWorkersOption('0.5x'), expected)
def testGetHalfWorkers(self):
expected = max(multiprocessing.cpu_count() // 2, 1)
# pylint: disable=protected-access
self.assertEqual(gtest_parallel_wrapper._ParseWorkersOption('0.5x'),
expected)
class GtestParallelWrapperTest(unittest.TestCase):
@classmethod
def _Expected(cls, gtest_parallel_args):
return ['--shard_index=0', '--shard_count=1'] + gtest_parallel_args
@classmethod
def _Expected(cls, gtest_parallel_args):
return ['--shard_index=0', '--shard_count=1'] + gtest_parallel_args
def testOverwrite(self):
result = gtest_parallel_wrapper.ParseArgs(
['--timeout=123', 'exec', '--timeout', '124'])
expected = self._Expected(['--timeout=124', 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testOverwrite(self):
result = gtest_parallel_wrapper.ParseArgs(
['--timeout=123', 'exec', '--timeout', '124'])
expected = self._Expected(['--timeout=124', 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testMixing(self):
result = gtest_parallel_wrapper.ParseArgs([
'--timeout=123', '--param1', 'exec', '--param2', '--timeout', '124'
])
expected = self._Expected(
['--timeout=124', 'exec', '--', '--param1', '--param2'])
self.assertEqual(result.gtest_parallel_args, expected)
def testMixing(self):
result = gtest_parallel_wrapper.ParseArgs(
['--timeout=123', '--param1', 'exec', '--param2', '--timeout', '124'])
expected = self._Expected(
['--timeout=124', 'exec', '--', '--param1', '--param2'])
self.assertEqual(result.gtest_parallel_args, expected)
def testMixingPositional(self):
result = gtest_parallel_wrapper.ParseArgs([
'--timeout=123', 'exec', '--foo1', 'bar1', '--timeout', '124',
'--foo2', 'bar2'
])
expected = self._Expected([
'--timeout=124', 'exec', '--', '--foo1', 'bar1', '--foo2', 'bar2'
])
self.assertEqual(result.gtest_parallel_args, expected)
def testMixingPositional(self):
result = gtest_parallel_wrapper.ParseArgs([
'--timeout=123', 'exec', '--foo1', 'bar1', '--timeout', '124', '--foo2',
'bar2'
])
expected = self._Expected(
['--timeout=124', 'exec', '--', '--foo1', 'bar1', '--foo2', 'bar2'])
self.assertEqual(result.gtest_parallel_args, expected)
def testDoubleDash1(self):
result = gtest_parallel_wrapper.ParseArgs(
['--timeout', '123', 'exec', '--', '--timeout', '124'])
expected = self._Expected(
['--timeout=123', 'exec', '--', '--timeout', '124'])
self.assertEqual(result.gtest_parallel_args, expected)
def testDoubleDash1(self):
result = gtest_parallel_wrapper.ParseArgs(
['--timeout', '123', 'exec', '--', '--timeout', '124'])
expected = self._Expected(
['--timeout=123', 'exec', '--', '--timeout', '124'])
self.assertEqual(result.gtest_parallel_args, expected)
def testDoubleDash2(self):
result = gtest_parallel_wrapper.ParseArgs(
['--timeout=123', '--', 'exec', '--timeout=124'])
expected = self._Expected(
['--timeout=123', 'exec', '--', '--timeout=124'])
self.assertEqual(result.gtest_parallel_args, expected)
def testDoubleDash2(self):
result = gtest_parallel_wrapper.ParseArgs(
['--timeout=123', '--', 'exec', '--timeout=124'])
expected = self._Expected(['--timeout=123', 'exec', '--', '--timeout=124'])
self.assertEqual(result.gtest_parallel_args, expected)
def testArtifacts(self):
with TemporaryDirectory() as tmp_dir:
output_dir = os.path.join(tmp_dir, 'foo')
result = gtest_parallel_wrapper.ParseArgs(
['exec', '--store-test-artifacts', '--output_dir', output_dir])
exp_artifacts_dir = os.path.join(output_dir, 'test_artifacts')
exp = self._Expected([
'--output_dir=' + output_dir, 'exec', '--',
'--test_artifacts_dir=' + exp_artifacts_dir
])
self.assertEqual(result.gtest_parallel_args, exp)
self.assertEqual(result.output_dir, output_dir)
self.assertEqual(result.test_artifacts_dir, exp_artifacts_dir)
def testArtifacts(self):
with TemporaryDirectory() as tmp_dir:
output_dir = os.path.join(tmp_dir, 'foo')
result = gtest_parallel_wrapper.ParseArgs(
['exec', '--store-test-artifacts', '--output_dir', output_dir])
exp_artifacts_dir = os.path.join(output_dir, 'test_artifacts')
exp = self._Expected([
'--output_dir=' + output_dir, 'exec', '--',
'--test_artifacts_dir=' + exp_artifacts_dir
])
self.assertEqual(result.gtest_parallel_args, exp)
self.assertEqual(result.output_dir, output_dir)
self.assertEqual(result.test_artifacts_dir, exp_artifacts_dir)
def testNoDirsSpecified(self):
result = gtest_parallel_wrapper.ParseArgs(['exec'])
self.assertEqual(result.output_dir, None)
self.assertEqual(result.test_artifacts_dir, None)
def testNoDirsSpecified(self):
result = gtest_parallel_wrapper.ParseArgs(['exec'])
self.assertEqual(result.output_dir, None)
self.assertEqual(result.test_artifacts_dir, None)
def testOutputDirSpecified(self):
result = gtest_parallel_wrapper.ParseArgs(
['exec', '--output_dir', '/tmp/foo'])
self.assertEqual(result.output_dir, '/tmp/foo')
self.assertEqual(result.test_artifacts_dir, None)
def testOutputDirSpecified(self):
result = gtest_parallel_wrapper.ParseArgs(
['exec', '--output_dir', '/tmp/foo'])
self.assertEqual(result.output_dir, '/tmp/foo')
self.assertEqual(result.test_artifacts_dir, None)
def testShortArg(self):
result = gtest_parallel_wrapper.ParseArgs(['-d', '/tmp/foo', 'exec'])
expected = self._Expected(['--output_dir=/tmp/foo', 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
self.assertEqual(result.output_dir, '/tmp/foo')
def testShortArg(self):
result = gtest_parallel_wrapper.ParseArgs(['-d', '/tmp/foo', 'exec'])
expected = self._Expected(['--output_dir=/tmp/foo', 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
self.assertEqual(result.output_dir, '/tmp/foo')
def testBoolArg(self):
result = gtest_parallel_wrapper.ParseArgs(
['--gtest_also_run_disabled_tests', 'exec'])
expected = self._Expected(['--gtest_also_run_disabled_tests', 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testBoolArg(self):
result = gtest_parallel_wrapper.ParseArgs(
['--gtest_also_run_disabled_tests', 'exec'])
expected = self._Expected(['--gtest_also_run_disabled_tests', 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testNoArgs(self):
result = gtest_parallel_wrapper.ParseArgs(['exec'])
expected = self._Expected(['exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testNoArgs(self):
result = gtest_parallel_wrapper.ParseArgs(['exec'])
expected = self._Expected(['exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testDocExample(self):
with TemporaryDirectory() as tmp_dir:
output_dir = os.path.join(tmp_dir, 'foo')
result = gtest_parallel_wrapper.ParseArgs([
'some_test', '--some_flag=some_value', '--another_flag',
'--output_dir=' + output_dir, '--store-test-artifacts',
'--isolated-script-test-perf-output=SOME_OTHER_DIR',
'--foo=bar', '--baz'
])
expected_artifacts_dir = os.path.join(output_dir, 'test_artifacts')
expected = self._Expected([
'--output_dir=' + output_dir, 'some_test', '--',
'--test_artifacts_dir=' + expected_artifacts_dir,
'--some_flag=some_value', '--another_flag',
'--isolated_script_test_perf_output=SOME_OTHER_DIR',
'--foo=bar', '--baz'
])
self.assertEqual(result.gtest_parallel_args, expected)
def testDocExample(self):
with TemporaryDirectory() as tmp_dir:
output_dir = os.path.join(tmp_dir, 'foo')
result = gtest_parallel_wrapper.ParseArgs([
'some_test', '--some_flag=some_value', '--another_flag',
'--output_dir=' + output_dir, '--store-test-artifacts',
'--isolated-script-test-perf-output=SOME_OTHER_DIR', '--foo=bar',
'--baz'
])
expected_artifacts_dir = os.path.join(output_dir, 'test_artifacts')
expected = self._Expected([
'--output_dir=' + output_dir,
'some_test', '--', '--test_artifacts_dir=' + expected_artifacts_dir,
'--some_flag=some_value', '--another_flag',
'--isolated_script_test_perf_output=SOME_OTHER_DIR', '--foo=bar',
'--baz'
])
self.assertEqual(result.gtest_parallel_args, expected)
def testStandardWorkers(self):
"""Check integer value is passed as-is."""
result = gtest_parallel_wrapper.ParseArgs(['--workers', '17', 'exec'])
expected = self._Expected(['--workers=17', 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testStandardWorkers(self):
"""Check integer value is passed as-is."""
result = gtest_parallel_wrapper.ParseArgs(['--workers', '17', 'exec'])
expected = self._Expected(['--workers=17', 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testTwoWorkersPerCpuCore(self):
result = gtest_parallel_wrapper.ParseArgs(['--workers', '2x', 'exec'])
workers = 2 * multiprocessing.cpu_count()
expected = self._Expected(['--workers=%s' % workers, 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testTwoWorkersPerCpuCore(self):
result = gtest_parallel_wrapper.ParseArgs(['--workers', '2x', 'exec'])
workers = 2 * multiprocessing.cpu_count()
expected = self._Expected(['--workers=%s' % workers, 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testUseHalfTheCpuCores(self):
result = gtest_parallel_wrapper.ParseArgs(['--workers', '0.5x', 'exec'])
workers = max(multiprocessing.cpu_count() // 2, 1)
expected = self._Expected(['--workers=%s' % workers, 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
def testUseHalfTheCpuCores(self):
result = gtest_parallel_wrapper.ParseArgs(
['--workers', '0.5x', 'exec'])
workers = max(multiprocessing.cpu_count() // 2, 1)
expected = self._Expected(['--workers=%s' % workers, 'exec'])
self.assertEqual(result.gtest_parallel_args, expected)
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@ -7,7 +7,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""WebRTC iOS FAT libraries build script.
Each architecture is compiled separately before being merged together.
By default, the library is created in out_ios_libs/. (Change with -o.)
@ -21,7 +20,6 @@ import shutil
import subprocess
import sys
os.environ['PATH'] = '/usr/libexec' + os.pathsep + os.environ['PATH']
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
@ -41,198 +39,235 @@ from generate_licenses import LicenseBuilder
def _ParseArgs():
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--build_config', default='release',
choices=['debug', 'release'],
help='The build config. Can be "debug" or "release". '
'Defaults to "release".')
parser.add_argument('--arch', nargs='+', default=DEFAULT_ARCHS,
choices=ENABLED_ARCHS,
help='Architectures to build. Defaults to %(default)s.')
parser.add_argument('-c', '--clean', action='store_true', default=False,
help='Removes the previously generated build output, if any.')
parser.add_argument('-p', '--purify', action='store_true', default=False,
help='Purifies the previously generated build output by '
'removing the temporary results used when (re)building.')
parser.add_argument('-o', '--output-dir', default=SDK_OUTPUT_DIR,
help='Specifies a directory to output the build artifacts to. '
'If specified together with -c, deletes the dir.')
parser.add_argument('-r', '--revision', type=int, default=0,
help='Specifies a revision number to embed if building the framework.')
parser.add_argument('-e', '--bitcode', action='store_true', default=False,
help='Compile with bitcode.')
parser.add_argument('--verbose', action='store_true', default=False,
help='Debug logging.')
parser.add_argument('--use-goma', action='store_true', default=False,
help='Use goma to build.')
parser.add_argument('--extra-gn-args', default=[], nargs='*',
help='Additional GN args to be used during Ninja generation.')
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument('--build_config',
default='release',
choices=['debug', 'release'],
help='The build config. Can be "debug" or "release". '
'Defaults to "release".')
parser.add_argument(
'--arch',
nargs='+',
default=DEFAULT_ARCHS,
choices=ENABLED_ARCHS,
help='Architectures to build. Defaults to %(default)s.')
parser.add_argument(
'-c',
'--clean',
action='store_true',
default=False,
help='Removes the previously generated build output, if any.')
parser.add_argument(
'-p',
'--purify',
action='store_true',
default=False,
help='Purifies the previously generated build output by '
'removing the temporary results used when (re)building.')
parser.add_argument(
'-o',
'--output-dir',
default=SDK_OUTPUT_DIR,
help='Specifies a directory to output the build artifacts to. '
'If specified together with -c, deletes the dir.')
parser.add_argument(
'-r',
'--revision',
type=int,
default=0,
help='Specifies a revision number to embed if building the framework.')
parser.add_argument('-e',
'--bitcode',
action='store_true',
default=False,
help='Compile with bitcode.')
parser.add_argument('--verbose',
action='store_true',
default=False,
help='Debug logging.')
parser.add_argument('--use-goma',
action='store_true',
default=False,
help='Use goma to build.')
parser.add_argument(
'--extra-gn-args',
default=[],
nargs='*',
help='Additional GN args to be used during Ninja generation.')
return parser.parse_args()
return parser.parse_args()
def _RunCommand(cmd):
logging.debug('Running: %r', cmd)
subprocess.check_call(cmd, cwd=SRC_DIR)
logging.debug('Running: %r', cmd)
subprocess.check_call(cmd, cwd=SRC_DIR)
def _CleanArtifacts(output_dir):
if os.path.isdir(output_dir):
logging.info('Deleting %s', output_dir)
shutil.rmtree(output_dir)
if os.path.isdir(output_dir):
logging.info('Deleting %s', output_dir)
shutil.rmtree(output_dir)
def _CleanTemporary(output_dir, architectures):
if os.path.isdir(output_dir):
logging.info('Removing temporary build files.')
for arch in architectures:
arch_lib_path = os.path.join(output_dir, arch + '_libs')
if os.path.isdir(arch_lib_path):
shutil.rmtree(arch_lib_path)
if os.path.isdir(output_dir):
logging.info('Removing temporary build files.')
for arch in architectures:
arch_lib_path = os.path.join(output_dir, arch + '_libs')
if os.path.isdir(arch_lib_path):
shutil.rmtree(arch_lib_path)
def BuildWebRTC(output_dir, target_arch, flavor, gn_target_name,
ios_deployment_target, libvpx_build_vp9, use_bitcode,
use_goma, extra_gn_args):
output_dir = os.path.join(output_dir, target_arch + '_libs')
gn_args = ['target_os="ios"', 'ios_enable_code_signing=false',
'use_xcode_clang=true', 'is_component_build=false']
ios_deployment_target, libvpx_build_vp9, use_bitcode, use_goma,
extra_gn_args):
output_dir = os.path.join(output_dir, target_arch + '_libs')
gn_args = [
'target_os="ios"', 'ios_enable_code_signing=false',
'use_xcode_clang=true', 'is_component_build=false'
]
# Add flavor option.
if flavor == 'debug':
gn_args.append('is_debug=true')
elif flavor == 'release':
gn_args.append('is_debug=false')
else:
raise ValueError('Unexpected flavor type: %s' % flavor)
# Add flavor option.
if flavor == 'debug':
gn_args.append('is_debug=true')
elif flavor == 'release':
gn_args.append('is_debug=false')
else:
raise ValueError('Unexpected flavor type: %s' % flavor)
gn_args.append('target_cpu="%s"' % target_arch)
gn_args.append('target_cpu="%s"' % target_arch)
gn_args.append('ios_deployment_target="%s"' % ios_deployment_target)
gn_args.append('ios_deployment_target="%s"' % ios_deployment_target)
gn_args.append('rtc_libvpx_build_vp9=' +
('true' if libvpx_build_vp9 else 'false'))
gn_args.append('rtc_libvpx_build_vp9=' +
('true' if libvpx_build_vp9 else 'false'))
gn_args.append('enable_ios_bitcode=' +
('true' if use_bitcode else 'false'))
gn_args.append('use_goma=' + ('true' if use_goma else 'false'))
gn_args.append('enable_ios_bitcode=' +
('true' if use_bitcode else 'false'))
gn_args.append('use_goma=' + ('true' if use_goma else 'false'))
args_string = ' '.join(gn_args + extra_gn_args)
logging.info('Building WebRTC with args: %s', args_string)
args_string = ' '.join(gn_args + extra_gn_args)
logging.info('Building WebRTC with args: %s', args_string)
cmd = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py'),
'gen',
output_dir,
'--args=' + args_string,
]
_RunCommand(cmd)
logging.info('Building target: %s', gn_target_name)
cmd = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py'),
'gen',
output_dir,
'--args=' + args_string,
]
_RunCommand(cmd)
logging.info('Building target: %s', gn_target_name)
cmd = [
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja'),
'-C',
output_dir,
gn_target_name,
]
if use_goma:
cmd.extend(['-j', '200'])
_RunCommand(cmd)
cmd = [
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja'),
'-C',
output_dir,
gn_target_name,
]
if use_goma:
cmd.extend(['-j', '200'])
_RunCommand(cmd)
def main():
args = _ParseArgs()
args = _ParseArgs()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
if args.clean:
_CleanArtifacts(args.output_dir)
return 0
if args.clean:
_CleanArtifacts(args.output_dir)
return 0
architectures = list(args.arch)
gn_args = args.extra_gn_args
architectures = list(args.arch)
gn_args = args.extra_gn_args
if args.purify:
_CleanTemporary(args.output_dir, architectures)
return 0
if args.purify:
_CleanTemporary(args.output_dir, architectures)
return 0
gn_target_name = 'framework_objc'
if not args.bitcode:
gn_args.append('enable_dsyms=true')
gn_args.append('enable_stripping=true')
gn_target_name = 'framework_objc'
if not args.bitcode:
gn_args.append('enable_dsyms=true')
gn_args.append('enable_stripping=true')
# Build all architectures.
for arch in architectures:
BuildWebRTC(args.output_dir, arch, args.build_config, gn_target_name,
IOS_DEPLOYMENT_TARGET, LIBVPX_BUILD_VP9, args.bitcode,
args.use_goma, gn_args)
# Build all architectures.
for arch in architectures:
BuildWebRTC(args.output_dir, arch, args.build_config, gn_target_name,
IOS_DEPLOYMENT_TARGET, LIBVPX_BUILD_VP9, args.bitcode,
args.use_goma, gn_args)
# Create FAT archive.
lib_paths = [
os.path.join(args.output_dir, arch + '_libs') for arch in architectures
]
# Create FAT archive.
lib_paths = [os.path.join(args.output_dir, arch + '_libs')
for arch in architectures]
# Combine the slices.
dylib_path = os.path.join(SDK_FRAMEWORK_NAME, 'WebRTC')
# Dylibs will be combined, all other files are the same across archs.
# Use distutils instead of shutil to support merging folders.
distutils.dir_util.copy_tree(
os.path.join(lib_paths[0], SDK_FRAMEWORK_NAME),
os.path.join(args.output_dir, SDK_FRAMEWORK_NAME))
logging.info('Merging framework slices.')
dylib_paths = [os.path.join(path, dylib_path) for path in lib_paths]
out_dylib_path = os.path.join(args.output_dir, dylib_path)
try:
os.remove(out_dylib_path)
except OSError:
pass
cmd = ['lipo'] + dylib_paths + ['-create', '-output', out_dylib_path]
_RunCommand(cmd)
# Merge the dSYM slices.
lib_dsym_dir_path = os.path.join(lib_paths[0], 'WebRTC.dSYM')
if os.path.isdir(lib_dsym_dir_path):
distutils.dir_util.copy_tree(lib_dsym_dir_path,
os.path.join(args.output_dir, 'WebRTC.dSYM'))
logging.info('Merging dSYM slices.')
dsym_path = os.path.join('WebRTC.dSYM', 'Contents', 'Resources', 'DWARF',
'WebRTC')
lib_dsym_paths = [os.path.join(path, dsym_path) for path in lib_paths]
out_dsym_path = os.path.join(args.output_dir, dsym_path)
try:
os.remove(out_dsym_path)
except OSError:
pass
cmd = ['lipo'] + lib_dsym_paths + ['-create', '-output', out_dsym_path]
_RunCommand(cmd)
# Generate the license file.
ninja_dirs = [os.path.join(args.output_dir, arch + '_libs')
for arch in architectures]
gn_target_full_name = '//sdk:' + gn_target_name
builder = LicenseBuilder(ninja_dirs, [gn_target_full_name])
builder.GenerateLicenseText(
# Combine the slices.
dylib_path = os.path.join(SDK_FRAMEWORK_NAME, 'WebRTC')
# Dylibs will be combined, all other files are the same across archs.
# Use distutils instead of shutil to support merging folders.
distutils.dir_util.copy_tree(
os.path.join(lib_paths[0], SDK_FRAMEWORK_NAME),
os.path.join(args.output_dir, SDK_FRAMEWORK_NAME))
# Modify the version number.
# Format should be <Branch cut MXX>.<Hotfix #>.<Rev #>.
# e.g. 55.0.14986 means branch cut 55, no hotfixes, and revision 14986.
infoplist_path = os.path.join(args.output_dir, SDK_FRAMEWORK_NAME,
'Info.plist')
cmd = ['PlistBuddy', '-c',
'Print :CFBundleShortVersionString', infoplist_path]
major_minor = subprocess.check_output(cmd).strip()
version_number = '%s.%s' % (major_minor, args.revision)
logging.info('Substituting revision number: %s', version_number)
cmd = ['PlistBuddy', '-c',
'Set :CFBundleVersion ' + version_number, infoplist_path]
logging.info('Merging framework slices.')
dylib_paths = [os.path.join(path, dylib_path) for path in lib_paths]
out_dylib_path = os.path.join(args.output_dir, dylib_path)
try:
os.remove(out_dylib_path)
except OSError:
pass
cmd = ['lipo'] + dylib_paths + ['-create', '-output', out_dylib_path]
_RunCommand(cmd)
_RunCommand(['plutil', '-convert', 'binary1', infoplist_path])
logging.info('Done.')
return 0
# Merge the dSYM slices.
lib_dsym_dir_path = os.path.join(lib_paths[0], 'WebRTC.dSYM')
if os.path.isdir(lib_dsym_dir_path):
distutils.dir_util.copy_tree(
lib_dsym_dir_path, os.path.join(args.output_dir, 'WebRTC.dSYM'))
logging.info('Merging dSYM slices.')
dsym_path = os.path.join('WebRTC.dSYM', 'Contents', 'Resources',
'DWARF', 'WebRTC')
lib_dsym_paths = [os.path.join(path, dsym_path) for path in lib_paths]
out_dsym_path = os.path.join(args.output_dir, dsym_path)
try:
os.remove(out_dsym_path)
except OSError:
pass
cmd = ['lipo'] + lib_dsym_paths + ['-create', '-output', out_dsym_path]
_RunCommand(cmd)
# Generate the license file.
ninja_dirs = [
os.path.join(args.output_dir, arch + '_libs')
for arch in architectures
]
gn_target_full_name = '//sdk:' + gn_target_name
builder = LicenseBuilder(ninja_dirs, [gn_target_full_name])
builder.GenerateLicenseText(
os.path.join(args.output_dir, SDK_FRAMEWORK_NAME))
# Modify the version number.
# Format should be <Branch cut MXX>.<Hotfix #>.<Rev #>.
# e.g. 55.0.14986 means branch cut 55, no hotfixes, and revision 14986.
infoplist_path = os.path.join(args.output_dir, SDK_FRAMEWORK_NAME,
'Info.plist')
cmd = [
'PlistBuddy', '-c', 'Print :CFBundleShortVersionString',
infoplist_path
]
major_minor = subprocess.check_output(cmd).strip()
version_number = '%s.%s' % (major_minor, args.revision)
logging.info('Substituting revision number: %s', version_number)
cmd = [
'PlistBuddy', '-c', 'Set :CFBundleVersion ' + version_number,
infoplist_path
]
_RunCommand(cmd)
_RunCommand(['plutil', '-convert', 'binary1', infoplist_path])
logging.info('Done.')
return 0
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -9,24 +9,24 @@
import argparse
import sys
def GenerateModulemap():
parser = argparse.ArgumentParser(description='Generate modulemap')
parser.add_argument("-o", "--out", type=str, help="Output file.")
parser.add_argument("-n", "--name", type=str, help="Name of binary.")
parser = argparse.ArgumentParser(description='Generate modulemap')
parser.add_argument("-o", "--out", type=str, help="Output file.")
parser.add_argument("-n", "--name", type=str, help="Name of binary.")
args = parser.parse_args()
args = parser.parse_args()
with open(args.out, "w") as outfile:
module_template = 'framework module %s {\n' \
' umbrella header "%s.h"\n' \
'\n' \
' export *\n' \
' module * { export * }\n' \
'}\n' % (args.name, args.name)
outfile.write(module_template)
return 0
with open(args.out, "w") as outfile:
module_template = 'framework module %s {\n' \
' umbrella header "%s.h"\n' \
'\n' \
' export *\n' \
' module * { export * }\n' \
'}\n' % (args.name, args.name)
outfile.write(module_template)
return 0
if __name__ == '__main__':
sys.exit(GenerateModulemap())
sys.exit(GenerateModulemap())

View File

@ -14,15 +14,20 @@ import textwrap
def GenerateUmbrellaHeader():
parser = argparse.ArgumentParser(description='Generate umbrella header')
parser.add_argument("-o", "--out", type=str, help="Output file.")
parser.add_argument("-s", "--sources", default=[], type=str, nargs='+',
help="Headers to include.")
parser = argparse.ArgumentParser(description='Generate umbrella header')
parser.add_argument("-o", "--out", type=str, help="Output file.")
parser.add_argument("-s",
"--sources",
default=[],
type=str,
nargs='+',
help="Headers to include.")
args = parser.parse_args()
args = parser.parse_args()
with open(args.out, "w") as outfile:
outfile.write(textwrap.dedent("""\
with open(args.out, "w") as outfile:
outfile.write(
textwrap.dedent("""\
/*
* Copyright %d The WebRTC project authors. All Rights Reserved.
*
@ -33,11 +38,11 @@ def GenerateUmbrellaHeader():
* be found in the AUTHORS file in the root of the source tree.
*/\n\n""" % datetime.datetime.now().year))
for s in args.sources:
outfile.write("#import <WebRTC/{}>\n".format(os.path.basename(s)))
for s in args.sources:
outfile.write("#import <WebRTC/{}>\n".format(os.path.basename(s)))
return 0
return 0
if __name__ == '__main__':
sys.exit(GenerateUmbrellaHeader())
sys.exit(GenerateUmbrellaHeader())

View File

@ -7,7 +7,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Script for merging generated iOS libraries."""
import sys
@ -22,7 +21,7 @@ VALID_ARCHS = ['arm_libs', 'arm64_libs', 'ia32_libs', 'x64_libs']
def MergeLibs(lib_base_dir):
"""Merges generated iOS libraries for different archs.
"""Merges generated iOS libraries for different archs.
Uses libtool to generate FAT archive files for each generated library.
@ -33,92 +32,96 @@ def MergeLibs(lib_base_dir):
Returns:
Exit code of libtool.
"""
output_dir_name = 'fat_libs'
archs = [arch for arch in os.listdir(lib_base_dir)
if arch in VALID_ARCHS]
# For each arch, find (library name, libary path) for arch. We will merge
# all libraries with the same name.
libs = {}
for lib_dir in [os.path.join(lib_base_dir, arch) for arch in VALID_ARCHS]:
if not os.path.exists(lib_dir):
continue
for dirpath, _, filenames in os.walk(lib_dir):
for filename in filenames:
if not filename.endswith('.a'):
continue
entry = libs.get(filename, [])
entry.append(os.path.join(dirpath, filename))
libs[filename] = entry
orphaned_libs = {}
valid_libs = {}
for library, paths in libs.items():
if len(paths) < len(archs):
orphaned_libs[library] = paths
else:
valid_libs[library] = paths
for library, paths in orphaned_libs.items():
components = library[:-2].split('_')[:-1]
found = False
# Find directly matching parent libs by stripping suffix.
while components and not found:
parent_library = '_'.join(components) + '.a'
if parent_library in valid_libs:
valid_libs[parent_library].extend(paths)
found = True
break
components = components[:-1]
# Find next best match by finding parent libs with the same prefix.
if not found:
base_prefix = library[:-2].split('_')[0]
for valid_lib, valid_paths in valid_libs.items():
if valid_lib[:len(base_prefix)] == base_prefix:
valid_paths.extend(paths)
found = True
break
assert found
output_dir_name = 'fat_libs'
archs = [arch for arch in os.listdir(lib_base_dir) if arch in VALID_ARCHS]
# For each arch, find (library name, libary path) for arch. We will merge
# all libraries with the same name.
libs = {}
for lib_dir in [os.path.join(lib_base_dir, arch) for arch in VALID_ARCHS]:
if not os.path.exists(lib_dir):
continue
for dirpath, _, filenames in os.walk(lib_dir):
for filename in filenames:
if not filename.endswith('.a'):
continue
entry = libs.get(filename, [])
entry.append(os.path.join(dirpath, filename))
libs[filename] = entry
orphaned_libs = {}
valid_libs = {}
for library, paths in libs.items():
if len(paths) < len(archs):
orphaned_libs[library] = paths
else:
valid_libs[library] = paths
for library, paths in orphaned_libs.items():
components = library[:-2].split('_')[:-1]
found = False
# Find directly matching parent libs by stripping suffix.
while components and not found:
parent_library = '_'.join(components) + '.a'
if parent_library in valid_libs:
valid_libs[parent_library].extend(paths)
found = True
break
components = components[:-1]
# Find next best match by finding parent libs with the same prefix.
if not found:
base_prefix = library[:-2].split('_')[0]
for valid_lib, valid_paths in valid_libs.items():
if valid_lib[:len(base_prefix)] == base_prefix:
valid_paths.extend(paths)
found = True
break
assert found
# Create output directory.
output_dir_path = os.path.join(lib_base_dir, output_dir_name)
if not os.path.exists(output_dir_path):
os.mkdir(output_dir_path)
# Create output directory.
output_dir_path = os.path.join(lib_base_dir, output_dir_name)
if not os.path.exists(output_dir_path):
os.mkdir(output_dir_path)
# Use this so libtool merged binaries are always the same.
env = os.environ.copy()
env['ZERO_AR_DATE'] = '1'
# Use this so libtool merged binaries are always the same.
env = os.environ.copy()
env['ZERO_AR_DATE'] = '1'
# Ignore certain errors.
libtool_re = re.compile(r'^.*libtool:.*file: .* has no symbols$')
# Ignore certain errors.
libtool_re = re.compile(r'^.*libtool:.*file: .* has no symbols$')
# Merge libraries using libtool.
libtool_returncode = 0
for library, paths in valid_libs.items():
cmd_list = ['libtool', '-static', '-v', '-o',
os.path.join(output_dir_path, library)] + paths
libtoolout = subprocess.Popen(cmd_list, stderr=subprocess.PIPE, env=env)
_, err = libtoolout.communicate()
for line in err.splitlines():
if not libtool_re.match(line):
print >>sys.stderr, line
# Unconditionally touch the output .a file on the command line if present
# and the command succeeded. A bit hacky.
libtool_returncode = libtoolout.returncode
if not libtool_returncode:
for i in range(len(cmd_list) - 1):
if cmd_list[i] == '-o' and cmd_list[i+1].endswith('.a'):
os.utime(cmd_list[i+1], None)
break
return libtool_returncode
# Merge libraries using libtool.
libtool_returncode = 0
for library, paths in valid_libs.items():
cmd_list = [
'libtool', '-static', '-v', '-o',
os.path.join(output_dir_path, library)
] + paths
libtoolout = subprocess.Popen(cmd_list,
stderr=subprocess.PIPE,
env=env)
_, err = libtoolout.communicate()
for line in err.splitlines():
if not libtool_re.match(line):
print >> sys.stderr, line
# Unconditionally touch the output .a file on the command line if present
# and the command succeeded. A bit hacky.
libtool_returncode = libtoolout.returncode
if not libtool_returncode:
for i in range(len(cmd_list) - 1):
if cmd_list[i] == '-o' and cmd_list[i + 1].endswith('.a'):
os.utime(cmd_list[i + 1], None)
break
return libtool_returncode
def Main():
parser_description = 'Merge WebRTC libraries.'
parser = argparse.ArgumentParser(description=parser_description)
parser.add_argument('lib_base_dir',
help='Directory with built libraries. ',
type=str)
args = parser.parse_args()
lib_base_dir = args.lib_base_dir
MergeLibs(lib_base_dir)
parser_description = 'Merge WebRTC libraries.'
parser = argparse.ArgumentParser(description=parser_description)
parser.add_argument('lib_base_dir',
help='Directory with built libraries. ',
type=str)
args = parser.parse_args()
lib_base_dir = args.lib_base_dir
MergeLibs(lib_base_dir)
if __name__ == '__main__':
sys.exit(Main())
sys.exit(Main())

View File

@ -36,12 +36,16 @@ LIB_TO_LICENSES_DICT = {
'abseil-cpp': ['third_party/abseil-cpp/LICENSE'],
'android_ndk': ['third_party/android_ndk/NOTICE'],
'android_sdk': ['third_party/android_sdk/LICENSE'],
'auto': ['third_party/android_deps/libs/'
'com_google_auto_service_auto_service/LICENSE'],
'auto': [
'third_party/android_deps/libs/'
'com_google_auto_service_auto_service/LICENSE'
],
'bazel': ['third_party/bazel/LICENSE'],
'boringssl': ['third_party/boringssl/src/LICENSE'],
'errorprone': ['third_party/android_deps/libs/'
'com_google_errorprone_error_prone_core/LICENSE'],
'errorprone': [
'third_party/android_deps/libs/'
'com_google_errorprone_error_prone_core/LICENSE'
],
'fiat': ['third_party/boringssl/src/third_party/fiat/LICENSE'],
'guava': ['third_party/guava/LICENSE'],
'ijar': ['third_party/ijar/LICENSE'],
@ -95,11 +99,11 @@ LIB_REGEX_TO_LICENSES_DICT = {
def FindSrcDirPath():
"""Returns the abs path to the src/ dir of the project."""
src_dir = os.path.dirname(os.path.abspath(__file__))
while os.path.basename(src_dir) != 'src':
src_dir = os.path.normpath(os.path.join(src_dir, os.pardir))
return src_dir
"""Returns the abs path to the src/ dir of the project."""
src_dir = os.path.dirname(os.path.abspath(__file__))
while os.path.basename(src_dir) != 'src':
src_dir = os.path.normpath(os.path.join(src_dir, os.pardir))
return src_dir
SCRIPT_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
@ -113,29 +117,28 @@ THIRD_PARTY_LIB_REGEX_TEMPLATE = r'^.*/third_party/%s$'
class LicenseBuilder(object):
def __init__(self,
buildfile_dirs,
targets,
lib_to_licenses_dict=None,
lib_regex_to_licenses_dict=None):
if lib_to_licenses_dict is None:
lib_to_licenses_dict = LIB_TO_LICENSES_DICT
def __init__(self,
buildfile_dirs,
targets,
lib_to_licenses_dict=None,
lib_regex_to_licenses_dict=None):
if lib_to_licenses_dict is None:
lib_to_licenses_dict = LIB_TO_LICENSES_DICT
if lib_regex_to_licenses_dict is None:
lib_regex_to_licenses_dict = LIB_REGEX_TO_LICENSES_DICT
if lib_regex_to_licenses_dict is None:
lib_regex_to_licenses_dict = LIB_REGEX_TO_LICENSES_DICT
self.buildfile_dirs = buildfile_dirs
self.targets = targets
self.lib_to_licenses_dict = lib_to_licenses_dict
self.lib_regex_to_licenses_dict = lib_regex_to_licenses_dict
self.buildfile_dirs = buildfile_dirs
self.targets = targets
self.lib_to_licenses_dict = lib_to_licenses_dict
self.lib_regex_to_licenses_dict = lib_regex_to_licenses_dict
self.common_licenses_dict = self.lib_to_licenses_dict.copy()
self.common_licenses_dict.update(self.lib_regex_to_licenses_dict)
self.common_licenses_dict = self.lib_to_licenses_dict.copy()
self.common_licenses_dict.update(self.lib_regex_to_licenses_dict)
@staticmethod
def _ParseLibraryName(dep):
"""Returns library name after third_party
@staticmethod
def _ParseLibraryName(dep):
"""Returns library name after third_party
Input one of:
//a/b/third_party/libname:c
@ -144,11 +147,11 @@ class LicenseBuilder(object):
Outputs libname or None if this is not a third_party dependency.
"""
groups = re.match(THIRD_PARTY_LIB_SIMPLE_NAME_REGEX, dep)
return groups.group(1) if groups else None
groups = re.match(THIRD_PARTY_LIB_SIMPLE_NAME_REGEX, dep)
return groups.group(1) if groups else None
def _ParseLibrary(self, dep):
"""Returns library simple or regex name that matches `dep` after third_party
def _ParseLibrary(self, dep):
"""Returns library simple or regex name that matches `dep` after third_party
This method matches `dep` dependency against simple names in
LIB_TO_LICENSES_DICT and regular expression names in
@ -156,104 +159,109 @@ class LicenseBuilder(object):
Outputs matched dict key or None if this is not a third_party dependency.
"""
libname = LicenseBuilder._ParseLibraryName(dep)
libname = LicenseBuilder._ParseLibraryName(dep)
for lib_regex in self.lib_regex_to_licenses_dict:
if re.match(THIRD_PARTY_LIB_REGEX_TEMPLATE % lib_regex, dep):
return lib_regex
for lib_regex in self.lib_regex_to_licenses_dict:
if re.match(THIRD_PARTY_LIB_REGEX_TEMPLATE % lib_regex, dep):
return lib_regex
return libname
return libname
@staticmethod
def _RunGN(buildfile_dir, target):
cmd = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py'),
'desc',
'--all',
'--format=json',
os.path.abspath(buildfile_dir),
target,
]
logging.debug('Running: %r', cmd)
output_json = subprocess.check_output(cmd, cwd=WEBRTC_ROOT)
logging.debug('Output: %s', output_json)
return output_json
@staticmethod
def _RunGN(buildfile_dir, target):
cmd = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py'),
'desc',
'--all',
'--format=json',
os.path.abspath(buildfile_dir),
target,
]
logging.debug('Running: %r', cmd)
output_json = subprocess.check_output(cmd, cwd=WEBRTC_ROOT)
logging.debug('Output: %s', output_json)
return output_json
def _GetThirdPartyLibraries(self, buildfile_dir, target):
output = json.loads(LicenseBuilder._RunGN(buildfile_dir, target))
libraries = set()
for described_target in output.values():
third_party_libs = (
self._ParseLibrary(dep) for dep in described_target['deps'])
libraries |= set(lib for lib in third_party_libs if lib)
return libraries
def _GetThirdPartyLibraries(self, buildfile_dir, target):
output = json.loads(LicenseBuilder._RunGN(buildfile_dir, target))
libraries = set()
for described_target in output.values():
third_party_libs = (self._ParseLibrary(dep)
for dep in described_target['deps'])
libraries |= set(lib for lib in third_party_libs if lib)
return libraries
def GenerateLicenseText(self, output_dir):
# Get a list of third_party libs from gn. For fat libraries we must consider
# all architectures, hence the multiple buildfile directories.
third_party_libs = set()
for buildfile in self.buildfile_dirs:
for target in self.targets:
third_party_libs |= self._GetThirdPartyLibraries(buildfile, target)
assert len(third_party_libs) > 0
def GenerateLicenseText(self, output_dir):
# Get a list of third_party libs from gn. For fat libraries we must consider
# all architectures, hence the multiple buildfile directories.
third_party_libs = set()
for buildfile in self.buildfile_dirs:
for target in self.targets:
third_party_libs |= self._GetThirdPartyLibraries(
buildfile, target)
assert len(third_party_libs) > 0
missing_licenses = third_party_libs - set(self.common_licenses_dict.keys())
if missing_licenses:
error_msg = 'Missing licenses for following third_party targets: %s' % \
', '.join(missing_licenses)
logging.error(error_msg)
raise Exception(error_msg)
missing_licenses = third_party_libs - set(
self.common_licenses_dict.keys())
if missing_licenses:
error_msg = 'Missing licenses for following third_party targets: %s' % \
', '.join(missing_licenses)
logging.error(error_msg)
raise Exception(error_msg)
# Put webrtc at the front of the list.
license_libs = sorted(third_party_libs)
license_libs.insert(0, 'webrtc')
# Put webrtc at the front of the list.
license_libs = sorted(third_party_libs)
license_libs.insert(0, 'webrtc')
logging.info('List of licenses: %s', ', '.join(license_libs))
logging.info('List of licenses: %s', ', '.join(license_libs))
# Generate markdown.
output_license_file = open(os.path.join(output_dir, 'LICENSE.md'), 'w+')
for license_lib in license_libs:
if len(self.common_licenses_dict[license_lib]) == 0:
logging.info('Skipping compile time or internal dependency: %s',
license_lib)
continue # Compile time dependency
# Generate markdown.
output_license_file = open(os.path.join(output_dir, 'LICENSE.md'),
'w+')
for license_lib in license_libs:
if len(self.common_licenses_dict[license_lib]) == 0:
logging.info(
'Skipping compile time or internal dependency: %s',
license_lib)
continue # Compile time dependency
output_license_file.write('# %s\n' % license_lib)
output_license_file.write('```\n')
for path in self.common_licenses_dict[license_lib]:
license_path = os.path.join(WEBRTC_ROOT, path)
with open(license_path, 'r') as license_file:
license_text = cgi.escape(license_file.read(), quote=True)
output_license_file.write(license_text)
output_license_file.write('\n')
output_license_file.write('```\n\n')
output_license_file.write('# %s\n' % license_lib)
output_license_file.write('```\n')
for path in self.common_licenses_dict[license_lib]:
license_path = os.path.join(WEBRTC_ROOT, path)
with open(license_path, 'r') as license_file:
license_text = cgi.escape(license_file.read(), quote=True)
output_license_file.write(license_text)
output_license_file.write('\n')
output_license_file.write('```\n\n')
output_license_file.close()
output_license_file.close()
def main():
parser = argparse.ArgumentParser(description='Generate WebRTC LICENSE.md')
parser.add_argument(
'--verbose', action='store_true', default=False, help='Debug logging.')
parser.add_argument(
'--target',
required=True,
action='append',
default=[],
help='Name of the GN target to generate a license for')
parser.add_argument('output_dir', help='Directory to output LICENSE.md to.')
parser.add_argument(
'buildfile_dirs',
nargs='+',
help='Directories containing gn generated ninja files')
args = parser.parse_args()
parser = argparse.ArgumentParser(description='Generate WebRTC LICENSE.md')
parser.add_argument('--verbose',
action='store_true',
default=False,
help='Debug logging.')
parser.add_argument('--target',
required=True,
action='append',
default=[],
help='Name of the GN target to generate a license for')
parser.add_argument('output_dir',
help='Directory to output LICENSE.md to.')
parser.add_argument('buildfile_dirs',
nargs='+',
help='Directories containing gn generated ninja files')
args = parser.parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
builder = LicenseBuilder(args.buildfile_dirs, args.target)
builder.GenerateLicenseText(args.output_dir)
builder = LicenseBuilder(args.buildfile_dirs, args.target)
builder.GenerateLicenseText(args.output_dir)
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -16,10 +16,9 @@ from generate_licenses import LicenseBuilder
class TestLicenseBuilder(unittest.TestCase):
@staticmethod
def _FakeRunGN(buildfile_dir, target):
return """
@staticmethod
def _FakeRunGN(buildfile_dir, target):
return """
{
"target1": {
"deps": [
@ -32,91 +31,93 @@ class TestLicenseBuilder(unittest.TestCase):
}
"""
def testParseLibraryName(self):
self.assertEquals(
LicenseBuilder._ParseLibraryName('//a/b/third_party/libname1:c'),
'libname1')
self.assertEquals(
LicenseBuilder._ParseLibraryName('//a/b/third_party/libname2:c(d)'),
'libname2')
self.assertEquals(
LicenseBuilder._ParseLibraryName('//a/b/third_party/libname3/c:d(e)'),
'libname3')
self.assertEquals(
LicenseBuilder._ParseLibraryName('//a/b/not_third_party/c'), None)
def testParseLibraryName(self):
self.assertEquals(
LicenseBuilder._ParseLibraryName('//a/b/third_party/libname1:c'),
'libname1')
self.assertEquals(
LicenseBuilder._ParseLibraryName(
'//a/b/third_party/libname2:c(d)'), 'libname2')
self.assertEquals(
LicenseBuilder._ParseLibraryName(
'//a/b/third_party/libname3/c:d(e)'), 'libname3')
self.assertEquals(
LicenseBuilder._ParseLibraryName('//a/b/not_third_party/c'), None)
def testParseLibrarySimpleMatch(self):
builder = LicenseBuilder([], [], {}, {})
self.assertEquals(
builder._ParseLibrary('//a/b/third_party/libname:c'), 'libname')
def testParseLibrarySimpleMatch(self):
builder = LicenseBuilder([], [], {}, {})
self.assertEquals(builder._ParseLibrary('//a/b/third_party/libname:c'),
'libname')
def testParseLibraryRegExNoMatchFallbacksToDefaultLibname(self):
lib_dict = {
'libname:foo.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], lib_dict, {})
self.assertEquals(
builder._ParseLibrary('//a/b/third_party/libname:bar_java'), 'libname')
def testParseLibraryRegExNoMatchFallbacksToDefaultLibname(self):
lib_dict = {
'libname:foo.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], lib_dict, {})
self.assertEquals(
builder._ParseLibrary('//a/b/third_party/libname:bar_java'),
'libname')
def testParseLibraryRegExMatch(self):
lib_regex_dict = {
'libname:foo.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], {}, lib_regex_dict)
self.assertEquals(
builder._ParseLibrary('//a/b/third_party/libname:foo_bar_java'),
'libname:foo.*')
def testParseLibraryRegExMatch(self):
lib_regex_dict = {
'libname:foo.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], {}, lib_regex_dict)
self.assertEquals(
builder._ParseLibrary('//a/b/third_party/libname:foo_bar_java'),
'libname:foo.*')
def testParseLibraryRegExMatchWithSubDirectory(self):
lib_regex_dict = {
'libname/foo:bar.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], {}, lib_regex_dict)
self.assertEquals(
builder._ParseLibrary('//a/b/third_party/libname/foo:bar_java'),
'libname/foo:bar.*')
def testParseLibraryRegExMatchWithSubDirectory(self):
lib_regex_dict = {
'libname/foo:bar.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], {}, lib_regex_dict)
self.assertEquals(
builder._ParseLibrary('//a/b/third_party/libname/foo:bar_java'),
'libname/foo:bar.*')
def testParseLibraryRegExMatchWithStarInside(self):
lib_regex_dict = {
'libname/foo.*bar.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], {}, lib_regex_dict)
self.assertEquals(
builder._ParseLibrary('//a/b/third_party/libname/fooHAHA:bar_java'),
'libname/foo.*bar.*')
def testParseLibraryRegExMatchWithStarInside(self):
lib_regex_dict = {
'libname/foo.*bar.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], {}, lib_regex_dict)
self.assertEquals(
builder._ParseLibrary(
'//a/b/third_party/libname/fooHAHA:bar_java'),
'libname/foo.*bar.*')
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
def testGetThirdPartyLibrariesWithoutRegex(self):
builder = LicenseBuilder([], [], {}, {})
self.assertEquals(
builder._GetThirdPartyLibraries('out/arm', 'target1'),
set(['libname1', 'libname2', 'libname3']))
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
def testGetThirdPartyLibrariesWithoutRegex(self):
builder = LicenseBuilder([], [], {}, {})
self.assertEquals(
builder._GetThirdPartyLibraries('out/arm', 'target1'),
set(['libname1', 'libname2', 'libname3']))
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
def testGetThirdPartyLibrariesWithRegex(self):
lib_regex_dict = {
'libname2:c.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], {}, lib_regex_dict)
self.assertEquals(
builder._GetThirdPartyLibraries('out/arm', 'target1'),
set(['libname1', 'libname2:c.*', 'libname3']))
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
def testGetThirdPartyLibrariesWithRegex(self):
lib_regex_dict = {
'libname2:c.*': ['path/to/LICENSE'],
}
builder = LicenseBuilder([], [], {}, lib_regex_dict)
self.assertEquals(
builder._GetThirdPartyLibraries('out/arm', 'target1'),
set(['libname1', 'libname2:c.*', 'libname3']))
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
def testGenerateLicenseTextFailIfUnknownLibrary(self):
lib_dict = {
'simple_library': ['path/to/LICENSE'],
}
builder = LicenseBuilder(['dummy_dir'], ['dummy_target'], lib_dict, {})
@mock.patch('generate_licenses.LicenseBuilder._RunGN', _FakeRunGN)
def testGenerateLicenseTextFailIfUnknownLibrary(self):
lib_dict = {
'simple_library': ['path/to/LICENSE'],
}
builder = LicenseBuilder(['dummy_dir'], ['dummy_target'], lib_dict, {})
with self.assertRaises(Exception) as context:
builder.GenerateLicenseText('dummy/dir')
with self.assertRaises(Exception) as context:
builder.GenerateLicenseText('dummy/dir')
self.assertEquals(
context.exception.message,
'Missing licenses for following third_party targets: '
'libname1, libname2, libname3')
self.assertEquals(
context.exception.message,
'Missing licenses for following third_party targets: '
'libname1, libname2, libname3')
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@ -6,31 +6,31 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Configuration class for network emulation."""
class ConnectionConfig(object):
"""Configuration containing the characteristics of a network connection."""
"""Configuration containing the characteristics of a network connection."""
def __init__(self, num, name, receive_bw_kbps, send_bw_kbps, delay_ms,
packet_loss_percent, queue_slots):
self.num = num
self.name = name
self.receive_bw_kbps = receive_bw_kbps
self.send_bw_kbps = send_bw_kbps
self.delay_ms = delay_ms
self.packet_loss_percent = packet_loss_percent
self.queue_slots = queue_slots
def __init__(self, num, name, receive_bw_kbps, send_bw_kbps, delay_ms,
packet_loss_percent, queue_slots):
self.num = num
self.name = name
self.receive_bw_kbps = receive_bw_kbps
self.send_bw_kbps = send_bw_kbps
self.delay_ms = delay_ms
self.packet_loss_percent = packet_loss_percent
self.queue_slots = queue_slots
def __str__(self):
"""String representing the configuration.
def __str__(self):
"""String representing the configuration.
Returns:
A string formatted and padded like this example:
12 Name 375 kbps 375 kbps 10 145 ms 0.1 %
"""
left_aligned_name = self.name.ljust(24, ' ')
return '%2s %24s %5s kbps %5s kbps %4s %5s ms %3s %%' % (
self.num, left_aligned_name, self.receive_bw_kbps, self.send_bw_kbps,
self.queue_slots, self.delay_ms, self.packet_loss_percent)
left_aligned_name = self.name.ljust(24, ' ')
return '%2s %24s %5s kbps %5s kbps %4s %5s ms %3s %%' % (
self.num, left_aligned_name, self.receive_bw_kbps,
self.send_bw_kbps, self.queue_slots, self.delay_ms,
self.packet_loss_percent)

View File

@ -6,10 +6,8 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Script for constraining traffic on the local machine."""
import logging
import optparse
import socket
@ -18,7 +16,6 @@ import sys
import config
import network_emulator
_DEFAULT_LOG_LEVEL = logging.INFO
# Default port range to apply network constraints on.
@ -41,7 +38,7 @@ _PRESETS = [
config.ConnectionConfig(12, 'Wifi, Average Case', 40000, 33000, 1, 0, 100),
config.ConnectionConfig(13, 'Wifi, Good', 45000, 40000, 1, 0, 100),
config.ConnectionConfig(14, 'Wifi, Lossy', 40000, 33000, 1, 0, 100),
]
]
_PRESETS_DICT = dict((p.num, p) for p in _PRESETS)
_DEFAULT_PRESET_ID = 2
@ -49,147 +46,170 @@ _DEFAULT_PRESET = _PRESETS_DICT[_DEFAULT_PRESET_ID]
class NonStrippingEpilogOptionParser(optparse.OptionParser):
"""Custom parser to let us show the epilog without weird line breaking."""
"""Custom parser to let us show the epilog without weird line breaking."""
def format_epilog(self, formatter):
return self.epilog
def format_epilog(self, formatter):
return self.epilog
def _GetExternalIp():
"""Finds out the machine's external IP by connecting to google.com."""
external_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
external_socket.connect(('google.com', 80))
return external_socket.getsockname()[0]
"""Finds out the machine's external IP by connecting to google.com."""
external_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
external_socket.connect(('google.com', 80))
return external_socket.getsockname()[0]
def _ParseArgs():
"""Define and parse the command-line arguments."""
presets_string = '\n'.join(str(p) for p in _PRESETS)
parser = NonStrippingEpilogOptionParser(epilog=(
'\nAvailable presets:\n'
' Bandwidth (kbps) Packet\n'
'ID Name Receive Send Queue Delay loss \n'
'-- ---- --------- -------- ----- ------- ------\n'
'%s\n' % presets_string))
parser.add_option('-p', '--preset', type='int', default=_DEFAULT_PRESET_ID,
help=('ConnectionConfig configuration, specified by ID. '
'Default: %default'))
parser.add_option('-r', '--receive-bw', type='int',
default=_DEFAULT_PRESET.receive_bw_kbps,
help=('Receive bandwidth in kilobit/s. Default: %default'))
parser.add_option('-s', '--send-bw', type='int',
default=_DEFAULT_PRESET.send_bw_kbps,
help=('Send bandwidth in kilobit/s. Default: %default'))
parser.add_option('-d', '--delay', type='int',
default=_DEFAULT_PRESET.delay_ms,
help=('Delay in ms. Default: %default'))
parser.add_option('-l', '--packet-loss', type='float',
default=_DEFAULT_PRESET.packet_loss_percent,
help=('Packet loss in %. Default: %default'))
parser.add_option('-q', '--queue', type='int',
default=_DEFAULT_PRESET.queue_slots,
help=('Queue size as number of slots. Default: %default'))
parser.add_option('--port-range', default='%s,%s' % _DEFAULT_PORT_RANGE,
help=('Range of ports for constrained network. Specify as '
'two comma separated integers. Default: %default'))
parser.add_option('--target-ip', default=None,
help=('The interface IP address to apply the rules for. '
'Default: the external facing interface IP address.'))
parser.add_option('-v', '--verbose', action='store_true', default=False,
help=('Turn on verbose output. Will print all \'ipfw\' '
'commands that are executed.'))
"""Define and parse the command-line arguments."""
presets_string = '\n'.join(str(p) for p in _PRESETS)
parser = NonStrippingEpilogOptionParser(epilog=(
'\nAvailable presets:\n'
' Bandwidth (kbps) Packet\n'
'ID Name Receive Send Queue Delay loss \n'
'-- ---- --------- -------- ----- ------- ------\n'
'%s\n' % presets_string))
parser.add_option('-p',
'--preset',
type='int',
default=_DEFAULT_PRESET_ID,
help=('ConnectionConfig configuration, specified by ID. '
'Default: %default'))
parser.add_option(
'-r',
'--receive-bw',
type='int',
default=_DEFAULT_PRESET.receive_bw_kbps,
help=('Receive bandwidth in kilobit/s. Default: %default'))
parser.add_option('-s',
'--send-bw',
type='int',
default=_DEFAULT_PRESET.send_bw_kbps,
help=('Send bandwidth in kilobit/s. Default: %default'))
parser.add_option('-d',
'--delay',
type='int',
default=_DEFAULT_PRESET.delay_ms,
help=('Delay in ms. Default: %default'))
parser.add_option('-l',
'--packet-loss',
type='float',
default=_DEFAULT_PRESET.packet_loss_percent,
help=('Packet loss in %. Default: %default'))
parser.add_option(
'-q',
'--queue',
type='int',
default=_DEFAULT_PRESET.queue_slots,
help=('Queue size as number of slots. Default: %default'))
parser.add_option(
'--port-range',
default='%s,%s' % _DEFAULT_PORT_RANGE,
help=('Range of ports for constrained network. Specify as '
'two comma separated integers. Default: %default'))
parser.add_option(
'--target-ip',
default=None,
help=('The interface IP address to apply the rules for. '
'Default: the external facing interface IP address.'))
parser.add_option('-v',
'--verbose',
action='store_true',
default=False,
help=('Turn on verbose output. Will print all \'ipfw\' '
'commands that are executed.'))
options = parser.parse_args()[0]
options = parser.parse_args()[0]
# Find preset by ID, if specified.
if options.preset and not _PRESETS_DICT.has_key(options.preset):
parser.error('Invalid preset: %s' % options.preset)
# Find preset by ID, if specified.
if options.preset and not _PRESETS_DICT.has_key(options.preset):
parser.error('Invalid preset: %s' % options.preset)
# Simple validation of the IP address, if supplied.
if options.target_ip:
# Simple validation of the IP address, if supplied.
if options.target_ip:
try:
socket.inet_aton(options.target_ip)
except socket.error:
parser.error('Invalid IP address specified: %s' %
options.target_ip)
# Convert port range into the desired tuple format.
try:
socket.inet_aton(options.target_ip)
except socket.error:
parser.error('Invalid IP address specified: %s' % options.target_ip)
if isinstance(options.port_range, str):
options.port_range = tuple(
int(port) for port in options.port_range.split(','))
if len(options.port_range) != 2:
parser.error(
'Invalid port range specified, please specify two '
'integers separated by a comma.')
except ValueError:
parser.error('Invalid port range specified.')
# Convert port range into the desired tuple format.
try:
if isinstance(options.port_range, str):
options.port_range = tuple(int(port) for port in
options.port_range.split(','))
if len(options.port_range) != 2:
parser.error('Invalid port range specified, please specify two '
'integers separated by a comma.')
except ValueError:
parser.error('Invalid port range specified.')
_InitLogging(options.verbose)
return options
_InitLogging(options.verbose)
return options
def _InitLogging(verbose):
"""Setup logging."""
log_level = _DEFAULT_LOG_LEVEL
if verbose:
log_level = logging.DEBUG
logging.basicConfig(level=log_level, format='%(message)s')
"""Setup logging."""
log_level = _DEFAULT_LOG_LEVEL
if verbose:
log_level = logging.DEBUG
logging.basicConfig(level=log_level, format='%(message)s')
def main():
options = _ParseArgs()
options = _ParseArgs()
# Build a configuration object. Override any preset configuration settings if
# a value of a setting was also given as a flag.
connection_config = _PRESETS_DICT[options.preset]
if options.receive_bw is not _DEFAULT_PRESET.receive_bw_kbps:
connection_config.receive_bw_kbps = options.receive_bw
if options.send_bw is not _DEFAULT_PRESET.send_bw_kbps:
connection_config.send_bw_kbps = options.send_bw
if options.delay is not _DEFAULT_PRESET.delay_ms:
connection_config.delay_ms = options.delay
if options.packet_loss is not _DEFAULT_PRESET.packet_loss_percent:
connection_config.packet_loss_percent = options.packet_loss
if options.queue is not _DEFAULT_PRESET.queue_slots:
connection_config.queue_slots = options.queue
emulator = network_emulator.NetworkEmulator(connection_config,
options.port_range)
try:
emulator.CheckPermissions()
except network_emulator.NetworkEmulatorError as e:
logging.error('Error: %s\n\nCause: %s', e.fail_msg, e.error)
return -1
# Build a configuration object. Override any preset configuration settings if
# a value of a setting was also given as a flag.
connection_config = _PRESETS_DICT[options.preset]
if options.receive_bw is not _DEFAULT_PRESET.receive_bw_kbps:
connection_config.receive_bw_kbps = options.receive_bw
if options.send_bw is not _DEFAULT_PRESET.send_bw_kbps:
connection_config.send_bw_kbps = options.send_bw
if options.delay is not _DEFAULT_PRESET.delay_ms:
connection_config.delay_ms = options.delay
if options.packet_loss is not _DEFAULT_PRESET.packet_loss_percent:
connection_config.packet_loss_percent = options.packet_loss
if options.queue is not _DEFAULT_PRESET.queue_slots:
connection_config.queue_slots = options.queue
emulator = network_emulator.NetworkEmulator(connection_config,
options.port_range)
try:
emulator.CheckPermissions()
except network_emulator.NetworkEmulatorError as e:
logging.error('Error: %s\n\nCause: %s', e.fail_msg, e.error)
return -1
if not options.target_ip:
external_ip = _GetExternalIp()
else:
external_ip = options.target_ip
if not options.target_ip:
external_ip = _GetExternalIp()
else:
external_ip = options.target_ip
logging.info('Constraining traffic to/from IP: %s', external_ip)
try:
emulator.Emulate(external_ip)
logging.info(
'Started network emulation with the following configuration:\n'
' Receive bandwidth: %s kbps (%s kB/s)\n'
' Send bandwidth : %s kbps (%s kB/s)\n'
' Delay : %s ms\n'
' Packet loss : %s %%\n'
' Queue slots : %s', connection_config.receive_bw_kbps,
connection_config.receive_bw_kbps / 8,
connection_config.send_bw_kbps, connection_config.send_bw_kbps / 8,
connection_config.delay_ms, connection_config.packet_loss_percent,
connection_config.queue_slots)
logging.info('Affected traffic: IP traffic on ports %s-%s',
options.port_range[0], options.port_range[1])
raw_input('Press Enter to abort Network Emulation...')
logging.info('Flushing all Dummynet rules...')
network_emulator.Cleanup()
logging.info('Completed Network Emulation.')
return 0
except network_emulator.NetworkEmulatorError as e:
logging.error('Error: %s\n\nCause: %s', e.fail_msg, e.error)
return -2
logging.info('Constraining traffic to/from IP: %s', external_ip)
try:
emulator.Emulate(external_ip)
logging.info('Started network emulation with the following configuration:\n'
' Receive bandwidth: %s kbps (%s kB/s)\n'
' Send bandwidth : %s kbps (%s kB/s)\n'
' Delay : %s ms\n'
' Packet loss : %s %%\n'
' Queue slots : %s',
connection_config.receive_bw_kbps,
connection_config.receive_bw_kbps/8,
connection_config.send_bw_kbps,
connection_config.send_bw_kbps/8,
connection_config.delay_ms,
connection_config.packet_loss_percent,
connection_config.queue_slots)
logging.info('Affected traffic: IP traffic on ports %s-%s',
options.port_range[0], options.port_range[1])
raw_input('Press Enter to abort Network Emulation...')
logging.info('Flushing all Dummynet rules...')
network_emulator.Cleanup()
logging.info('Completed Network Emulation.')
return 0
except network_emulator.NetworkEmulatorError as e:
logging.error('Error: %s\n\nCause: %s', e.fail_msg, e.error)
return -2
if __name__ == '__main__':
sys.exit(main())
sys.exit(main())

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Script for constraining traffic on the local machine."""
import ctypes
@ -17,7 +16,7 @@ import sys
class NetworkEmulatorError(BaseException):
"""Exception raised for errors in the network emulator.
"""Exception raised for errors in the network emulator.
Attributes:
fail_msg: User defined error message.
@ -27,81 +26,88 @@ class NetworkEmulatorError(BaseException):
stderr: Error output of running the command.
"""
def __init__(self, fail_msg, cmd=None, returncode=None, output=None,
error=None):
BaseException.__init__(self, fail_msg)
self.fail_msg = fail_msg
self.cmd = cmd
self.returncode = returncode
self.output = output
self.error = error
def __init__(self,
fail_msg,
cmd=None,
returncode=None,
output=None,
error=None):
BaseException.__init__(self, fail_msg)
self.fail_msg = fail_msg
self.cmd = cmd
self.returncode = returncode
self.output = output
self.error = error
class NetworkEmulator(object):
"""A network emulator that can constrain the network using Dummynet."""
"""A network emulator that can constrain the network using Dummynet."""
def __init__(self, connection_config, port_range):
"""Constructor.
def __init__(self, connection_config, port_range):
"""Constructor.
Args:
connection_config: A config.ConnectionConfig object containing the
characteristics for the connection to be emulation.
port_range: Tuple containing two integers defining the port range.
"""
self._pipe_counter = 0
self._rule_counter = 0
self._port_range = port_range
self._connection_config = connection_config
self._pipe_counter = 0
self._rule_counter = 0
self._port_range = port_range
self._connection_config = connection_config
def Emulate(self, target_ip):
"""Starts a network emulation by setting up Dummynet rules.
def Emulate(self, target_ip):
"""Starts a network emulation by setting up Dummynet rules.
Args:
target_ip: The IP address of the interface that shall be that have the
network constraints applied to it.
"""
receive_pipe_id = self._CreateDummynetPipe(
self._connection_config.receive_bw_kbps,
self._connection_config.delay_ms,
self._connection_config.packet_loss_percent,
self._connection_config.queue_slots)
logging.debug('Created receive pipe: %s', receive_pipe_id)
send_pipe_id = self._CreateDummynetPipe(
self._connection_config.send_bw_kbps,
self._connection_config.delay_ms,
self._connection_config.packet_loss_percent,
self._connection_config.queue_slots)
logging.debug('Created send pipe: %s', send_pipe_id)
receive_pipe_id = self._CreateDummynetPipe(
self._connection_config.receive_bw_kbps,
self._connection_config.delay_ms,
self._connection_config.packet_loss_percent,
self._connection_config.queue_slots)
logging.debug('Created receive pipe: %s', receive_pipe_id)
send_pipe_id = self._CreateDummynetPipe(
self._connection_config.send_bw_kbps,
self._connection_config.delay_ms,
self._connection_config.packet_loss_percent,
self._connection_config.queue_slots)
logging.debug('Created send pipe: %s', send_pipe_id)
# Adding the rules will start the emulation.
incoming_rule_id = self._CreateDummynetRule(receive_pipe_id, 'any',
target_ip, self._port_range)
logging.debug('Created incoming rule: %s', incoming_rule_id)
outgoing_rule_id = self._CreateDummynetRule(send_pipe_id, target_ip,
'any', self._port_range)
logging.debug('Created outgoing rule: %s', outgoing_rule_id)
# Adding the rules will start the emulation.
incoming_rule_id = self._CreateDummynetRule(receive_pipe_id, 'any',
target_ip,
self._port_range)
logging.debug('Created incoming rule: %s', incoming_rule_id)
outgoing_rule_id = self._CreateDummynetRule(send_pipe_id, target_ip,
'any', self._port_range)
logging.debug('Created outgoing rule: %s', outgoing_rule_id)
@staticmethod
def CheckPermissions():
"""Checks if permissions are available to run Dummynet commands.
@staticmethod
def CheckPermissions():
"""Checks if permissions are available to run Dummynet commands.
Raises:
NetworkEmulatorError: If permissions to run Dummynet commands are not
available.
"""
try:
if os.getuid() != 0:
raise NetworkEmulatorError('You must run this script with sudo.')
except AttributeError:
try:
if os.getuid() != 0:
raise NetworkEmulatorError(
'You must run this script with sudo.')
except AttributeError:
# AttributeError will be raised on Windows.
if ctypes.windll.shell32.IsUserAnAdmin() == 0:
raise NetworkEmulatorError('You must run this script with administrator'
' privileges.')
# AttributeError will be raised on Windows.
if ctypes.windll.shell32.IsUserAnAdmin() == 0:
raise NetworkEmulatorError(
'You must run this script with administrator'
' privileges.')
def _CreateDummynetRule(self, pipe_id, from_address, to_address,
port_range):
"""Creates a network emulation rule and returns its ID.
def _CreateDummynetRule(self, pipe_id, from_address, to_address,
port_range):
"""Creates a network emulation rule and returns its ID.
Args:
pipe_id: integer ID of the pipe.
@ -115,18 +121,22 @@ class NetworkEmulator(object):
The ID of the rule, starting at 100. The rule ID increments with 100 for
each rule being added.
"""
self._rule_counter += 100
add_part = ['add', self._rule_counter, 'pipe', pipe_id,
'ip', 'from', from_address, 'to', to_address]
_RunIpfwCommand(add_part + ['src-port', '%s-%s' % port_range],
'Failed to add Dummynet src-port rule.')
_RunIpfwCommand(add_part + ['dst-port', '%s-%s' % port_range],
'Failed to add Dummynet dst-port rule.')
return self._rule_counter
self._rule_counter += 100
add_part = [
'add', self._rule_counter, 'pipe', pipe_id, 'ip', 'from',
from_address, 'to', to_address
]
_RunIpfwCommand(add_part +
['src-port', '%s-%s' % port_range],
'Failed to add Dummynet src-port rule.')
_RunIpfwCommand(add_part +
['dst-port', '%s-%s' % port_range],
'Failed to add Dummynet dst-port rule.')
return self._rule_counter
def _CreateDummynetPipe(self, bandwidth_kbps, delay_ms, packet_loss_percent,
queue_slots):
"""Creates a Dummynet pipe and return its ID.
def _CreateDummynetPipe(self, bandwidth_kbps, delay_ms,
packet_loss_percent, queue_slots):
"""Creates a Dummynet pipe and return its ID.
Args:
bandwidth_kbps: Bandwidth.
@ -136,32 +146,34 @@ class NetworkEmulator(object):
Returns:
The ID of the pipe, starting at 1.
"""
self._pipe_counter += 1
cmd = ['pipe', self._pipe_counter, 'config',
'bw', str(bandwidth_kbps/8) + 'KByte/s',
'delay', '%sms' % delay_ms,
'plr', (packet_loss_percent/100.0),
'queue', queue_slots]
error_message = 'Failed to create Dummynet pipe. '
if sys.platform.startswith('linux'):
error_message += ('Make sure you have loaded the ipfw_mod.ko module to '
'your kernel (sudo insmod /path/to/ipfw_mod.ko).')
_RunIpfwCommand(cmd, error_message)
return self._pipe_counter
self._pipe_counter += 1
cmd = [
'pipe', self._pipe_counter, 'config', 'bw',
str(bandwidth_kbps / 8) + 'KByte/s', 'delay',
'%sms' % delay_ms, 'plr', (packet_loss_percent / 100.0), 'queue',
queue_slots
]
error_message = 'Failed to create Dummynet pipe. '
if sys.platform.startswith('linux'):
error_message += (
'Make sure you have loaded the ipfw_mod.ko module to '
'your kernel (sudo insmod /path/to/ipfw_mod.ko).')
_RunIpfwCommand(cmd, error_message)
return self._pipe_counter
def Cleanup():
"""Stops the network emulation by flushing all Dummynet rules.
"""Stops the network emulation by flushing all Dummynet rules.
Notice that this will flush any rules that may have been created previously
before starting the emulation.
"""
_RunIpfwCommand(['-f', 'flush'],
'Failed to flush Dummynet rules!')
_RunIpfwCommand(['-f', 'pipe', 'flush'],
'Failed to flush Dummynet pipes!')
_RunIpfwCommand(['-f', 'flush'], 'Failed to flush Dummynet rules!')
_RunIpfwCommand(['-f', 'pipe', 'flush'], 'Failed to flush Dummynet pipes!')
def _RunIpfwCommand(command, fail_msg=None):
"""Executes a command and prefixes the appropriate command for
"""Executes a command and prefixes the appropriate command for
Windows or Linux/UNIX.
Args:
@ -172,18 +184,19 @@ def _RunIpfwCommand(command, fail_msg=None):
NetworkEmulatorError: If command fails a message is set by the fail_msg
parameter.
"""
if sys.platform == 'win32':
ipfw_command = ['ipfw.exe']
else:
ipfw_command = ['sudo', '-n', 'ipfw']
if sys.platform == 'win32':
ipfw_command = ['ipfw.exe']
else:
ipfw_command = ['sudo', '-n', 'ipfw']
cmd_list = ipfw_command[:] + [str(x) for x in command]
cmd_string = ' '.join(cmd_list)
logging.debug('Running command: %s', cmd_string)
process = subprocess.Popen(cmd_list, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
output, error = process.communicate()
if process.returncode != 0:
raise NetworkEmulatorError(fail_msg, cmd_string, process.returncode, output,
error)
return output.strip()
cmd_list = ipfw_command[:] + [str(x) for x in command]
cmd_string = ' '.join(cmd_list)
logging.debug('Running command: %s', cmd_string)
process = subprocess.Popen(cmd_list,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
output, error = process.communicate()
if process.returncode != 0:
raise NetworkEmulatorError(fail_msg, cmd_string, process.returncode,
output, error)
return output.strip()

View File

@ -7,7 +7,6 @@
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
import httplib2
import json
import subprocess
@ -20,19 +19,19 @@ from tracing.value.diagnostics import reserved_infos
def _GenerateOauthToken():
args = ['luci-auth', 'token']
p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if p.wait() == 0:
output = p.stdout.read()
return output.strip()
else:
raise RuntimeError(
'Error generating authentication token.\nStdout: %s\nStderr:%s' %
(p.stdout.read(), p.stderr.read()))
args = ['luci-auth', 'token']
p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if p.wait() == 0:
output = p.stdout.read()
return output.strip()
else:
raise RuntimeError(
'Error generating authentication token.\nStdout: %s\nStderr:%s' %
(p.stdout.read(), p.stderr.read()))
def _SendHistogramSet(url, histograms, oauth_token):
"""Make a HTTP POST with the given JSON to the Performance Dashboard.
"""Make a HTTP POST with the given JSON to the Performance Dashboard.
Args:
url: URL of Performance Dashboard instance, e.g.
@ -40,83 +39,87 @@ def _SendHistogramSet(url, histograms, oauth_token):
histograms: a histogram set object that contains the data to be sent.
oauth_token: An oauth token to use for authorization.
"""
headers = {'Authorization': 'Bearer %s' % oauth_token}
headers = {'Authorization': 'Bearer %s' % oauth_token}
serialized = json.dumps(_ApplyHacks(histograms.AsDicts()), indent=4)
serialized = json.dumps(_ApplyHacks(histograms.AsDicts()), indent=4)
if url.startswith('http://localhost'):
# The catapult server turns off compression in developer mode.
data = serialized
else:
data = zlib.compress(serialized)
if url.startswith('http://localhost'):
# The catapult server turns off compression in developer mode.
data = serialized
else:
data = zlib.compress(serialized)
print 'Sending %d bytes to %s.' % (len(data), url + '/add_histograms')
print 'Sending %d bytes to %s.' % (len(data), url + '/add_histograms')
http = httplib2.Http()
response, content = http.request(url + '/add_histograms', method='POST',
body=data, headers=headers)
return response, content
http = httplib2.Http()
response, content = http.request(url + '/add_histograms',
method='POST',
body=data,
headers=headers)
return response, content
# TODO(https://crbug.com/1029452): HACKHACK
# Remove once we have doubles in the proto and handle -infinity correctly.
def _ApplyHacks(dicts):
for d in dicts:
if 'running' in d:
def _NoInf(value):
if value == float('inf'):
return histogram.JS_MAX_VALUE
if value == float('-inf'):
return -histogram.JS_MAX_VALUE
return value
d['running'] = [_NoInf(value) for value in d['running']]
for d in dicts:
if 'running' in d:
return dicts
def _NoInf(value):
if value == float('inf'):
return histogram.JS_MAX_VALUE
if value == float('-inf'):
return -histogram.JS_MAX_VALUE
return value
d['running'] = [_NoInf(value) for value in d['running']]
return dicts
def _LoadHistogramSetFromProto(options):
hs = histogram_set.HistogramSet()
with options.input_results_file as f:
hs.ImportProto(f.read())
hs = histogram_set.HistogramSet()
with options.input_results_file as f:
hs.ImportProto(f.read())
return hs
return hs
def _AddBuildInfo(histograms, options):
common_diagnostics = {
reserved_infos.MASTERS: options.perf_dashboard_machine_group,
reserved_infos.BOTS: options.bot,
reserved_infos.POINT_ID: options.commit_position,
reserved_infos.BENCHMARKS: options.test_suite,
reserved_infos.WEBRTC_REVISIONS: str(options.webrtc_git_hash),
reserved_infos.BUILD_URLS: options.build_page_url,
}
common_diagnostics = {
reserved_infos.MASTERS: options.perf_dashboard_machine_group,
reserved_infos.BOTS: options.bot,
reserved_infos.POINT_ID: options.commit_position,
reserved_infos.BENCHMARKS: options.test_suite,
reserved_infos.WEBRTC_REVISIONS: str(options.webrtc_git_hash),
reserved_infos.BUILD_URLS: options.build_page_url,
}
for k, v in common_diagnostics.items():
histograms.AddSharedDiagnosticToAllHistograms(
k.name, generic_set.GenericSet([v]))
for k, v in common_diagnostics.items():
histograms.AddSharedDiagnosticToAllHistograms(
k.name, generic_set.GenericSet([v]))
def _DumpOutput(histograms, output_file):
with output_file:
json.dump(_ApplyHacks(histograms.AsDicts()), output_file, indent=4)
with output_file:
json.dump(_ApplyHacks(histograms.AsDicts()), output_file, indent=4)
def UploadToDashboard(options):
histograms = _LoadHistogramSetFromProto(options)
_AddBuildInfo(histograms, options)
histograms = _LoadHistogramSetFromProto(options)
_AddBuildInfo(histograms, options)
if options.output_json_file:
_DumpOutput(histograms, options.output_json_file)
if options.output_json_file:
_DumpOutput(histograms, options.output_json_file)
oauth_token = _GenerateOauthToken()
response, content = _SendHistogramSet(
options.dashboard_url, histograms, oauth_token)
oauth_token = _GenerateOauthToken()
response, content = _SendHistogramSet(options.dashboard_url, histograms,
oauth_token)
if response.status == 200:
print 'Received 200 from dashboard.'
return 0
else:
print('Upload failed with %d: %s\n\n%s' % (response.status, response.reason,
content))
return 1
if response.status == 200:
print 'Received 200 from dashboard.'
return 0
else:
print('Upload failed with %d: %s\n\n%s' %
(response.status, response.reason, content))
return 1

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Adds build info to perf results and uploads them.
The tests don't know which bot executed the tests or at what revision, so we
@ -24,78 +23,93 @@ import sys
def _CreateParser():
parser = argparse.ArgumentParser()
parser.add_argument('--perf-dashboard-machine-group', required=True,
help='The "master" the bots are grouped under. This '
'string is the group in the the perf dashboard path '
'group/bot/perf_id/metric/subtest.')
parser.add_argument('--bot', required=True,
help='The bot running the test (e.g. '
'webrtc-win-large-tests).')
parser.add_argument('--test-suite', required=True,
help='The key for the test in the dashboard (i.e. what '
'you select in the top-level test suite selector in the '
'dashboard')
parser.add_argument('--webrtc-git-hash', required=True,
help='webrtc.googlesource.com commit hash.')
parser.add_argument('--commit-position', type=int, required=True,
help='Commit pos corresponding to the git hash.')
parser.add_argument('--build-page-url', required=True,
help='URL to the build page for this build.')
parser.add_argument('--dashboard-url', required=True,
help='Which dashboard to use.')
parser.add_argument('--input-results-file', type=argparse.FileType(),
required=True,
help='A JSON file with output from WebRTC tests.')
parser.add_argument('--output-json-file', type=argparse.FileType('w'),
help='Where to write the output (for debugging).')
parser.add_argument('--outdir', required=True,
help='Path to the local out/ dir (usually out/Default)')
return parser
parser = argparse.ArgumentParser()
parser.add_argument('--perf-dashboard-machine-group',
required=True,
help='The "master" the bots are grouped under. This '
'string is the group in the the perf dashboard path '
'group/bot/perf_id/metric/subtest.')
parser.add_argument('--bot',
required=True,
help='The bot running the test (e.g. '
'webrtc-win-large-tests).')
parser.add_argument(
'--test-suite',
required=True,
help='The key for the test in the dashboard (i.e. what '
'you select in the top-level test suite selector in the '
'dashboard')
parser.add_argument('--webrtc-git-hash',
required=True,
help='webrtc.googlesource.com commit hash.')
parser.add_argument('--commit-position',
type=int,
required=True,
help='Commit pos corresponding to the git hash.')
parser.add_argument('--build-page-url',
required=True,
help='URL to the build page for this build.')
parser.add_argument('--dashboard-url',
required=True,
help='Which dashboard to use.')
parser.add_argument('--input-results-file',
type=argparse.FileType(),
required=True,
help='A JSON file with output from WebRTC tests.')
parser.add_argument('--output-json-file',
type=argparse.FileType('w'),
help='Where to write the output (for debugging).')
parser.add_argument(
'--outdir',
required=True,
help='Path to the local out/ dir (usually out/Default)')
return parser
def _ConfigurePythonPath(options):
# We just yank the python scripts we require into the PYTHONPATH. You could
# also imagine a solution where we use for instance protobuf:py_proto_runtime
# to copy catapult and protobuf code to out/. This is the convention in
# Chromium and WebRTC python scripts. We do need to build histogram_pb2
# however, so that's why we add out/ to sys.path below.
#
# It would be better if there was an equivalent to py_binary in GN, but
# there's not.
script_dir = os.path.dirname(os.path.realpath(__file__))
checkout_root = os.path.abspath(
os.path.join(script_dir, os.pardir, os.pardir))
# We just yank the python scripts we require into the PYTHONPATH. You could
# also imagine a solution where we use for instance protobuf:py_proto_runtime
# to copy catapult and protobuf code to out/. This is the convention in
# Chromium and WebRTC python scripts. We do need to build histogram_pb2
# however, so that's why we add out/ to sys.path below.
#
# It would be better if there was an equivalent to py_binary in GN, but
# there's not.
script_dir = os.path.dirname(os.path.realpath(__file__))
checkout_root = os.path.abspath(
os.path.join(script_dir, os.pardir, os.pardir))
sys.path.insert(0, os.path.join(checkout_root, 'third_party', 'catapult',
'tracing'))
sys.path.insert(0, os.path.join(checkout_root, 'third_party', 'protobuf',
'python'))
sys.path.insert(
0, os.path.join(checkout_root, 'third_party', 'catapult', 'tracing'))
sys.path.insert(
0, os.path.join(checkout_root, 'third_party', 'protobuf', 'python'))
# The webrtc_dashboard_upload gn rule will build the protobuf stub for python,
# so put it in the path for this script before we attempt to import it.
histogram_proto_path = os.path.join(
options.outdir, 'pyproto', 'tracing', 'tracing', 'proto')
sys.path.insert(0, histogram_proto_path)
# The webrtc_dashboard_upload gn rule will build the protobuf stub for python,
# so put it in the path for this script before we attempt to import it.
histogram_proto_path = os.path.join(options.outdir, 'pyproto', 'tracing',
'tracing', 'proto')
sys.path.insert(0, histogram_proto_path)
# Fail early in case the proto hasn't been built.
from tracing.proto import histogram_proto
if not histogram_proto.HAS_PROTO:
raise ImportError('Could not find histogram_pb2. You need to build the '
'webrtc_dashboard_upload target before invoking this '
'script. Expected to find '
'histogram_pb2.py in %s.' % histogram_proto_path)
# Fail early in case the proto hasn't been built.
from tracing.proto import histogram_proto
if not histogram_proto.HAS_PROTO:
raise ImportError(
'Could not find histogram_pb2. You need to build the '
'webrtc_dashboard_upload target before invoking this '
'script. Expected to find '
'histogram_pb2.py in %s.' % histogram_proto_path)
def main(args):
parser = _CreateParser()
options = parser.parse_args(args)
parser = _CreateParser()
options = parser.parse_args(args)
_ConfigurePythonPath(options)
_ConfigurePythonPath(options)
import catapult_uploader
import catapult_uploader
return catapult_uploader.UploadToDashboard(options)
return catapult_uploader.UploadToDashboard(options)
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))
sys.exit(main(sys.argv[1:]))

View File

@ -5,7 +5,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""This script helps to invoke gn and ninja
which lie in depot_tools repository."""
@ -19,11 +18,11 @@ import tempfile
def FindSrcDirPath():
"""Returns the abs path to the src/ dir of the project."""
src_dir = os.path.dirname(os.path.abspath(__file__))
while os.path.basename(src_dir) != 'src':
src_dir = os.path.normpath(os.path.join(src_dir, os.pardir))
return src_dir
"""Returns the abs path to the src/ dir of the project."""
src_dir = os.path.dirname(os.path.abspath(__file__))
while os.path.basename(src_dir) != 'src':
src_dir = os.path.normpath(os.path.join(src_dir, os.pardir))
return src_dir
SRC_DIR = FindSrcDirPath()
@ -32,16 +31,16 @@ import find_depot_tools
def RunGnCommand(args, root_dir=None):
"""Runs `gn` with provided args and return error if any."""
try:
command = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py')
] + args
subprocess.check_output(command, cwd=root_dir)
except subprocess.CalledProcessError as err:
return err.output
return None
"""Runs `gn` with provided args and return error if any."""
try:
command = [
sys.executable,
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'gn.py')
] + args
subprocess.check_output(command, cwd=root_dir)
except subprocess.CalledProcessError as err:
return err.output
return None
# GN_ERROR_RE matches the summary of an error output by `gn check`.
@ -51,49 +50,49 @@ GN_ERROR_RE = re.compile(r'^ERROR .+(?:\n.*[^_\n].*$)+', re.MULTILINE)
def RunGnCheck(root_dir=None):
"""Runs `gn gen --check` with default args to detect mismatches between
"""Runs `gn gen --check` with default args to detect mismatches between
#includes and dependencies in the BUILD.gn files, as well as general build
errors.
Returns a list of error summary strings.
"""
out_dir = tempfile.mkdtemp('gn')
try:
error = RunGnCommand(['gen', '--check', out_dir], root_dir)
finally:
shutil.rmtree(out_dir, ignore_errors=True)
return GN_ERROR_RE.findall(error) if error else []
out_dir = tempfile.mkdtemp('gn')
try:
error = RunGnCommand(['gen', '--check', out_dir], root_dir)
finally:
shutil.rmtree(out_dir, ignore_errors=True)
return GN_ERROR_RE.findall(error) if error else []
def RunNinjaCommand(args, root_dir=None):
"""Runs ninja quietly. Any failure (e.g. clang not found) is
"""Runs ninja quietly. Any failure (e.g. clang not found) is
silently discarded, since this is unlikely an error in submitted CL."""
command = [
os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja')
] + args
p = subprocess.Popen(command, cwd=root_dir,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
out, _ = p.communicate()
return out
command = [os.path.join(find_depot_tools.DEPOT_TOOLS_PATH, 'ninja')] + args
p = subprocess.Popen(command,
cwd=root_dir,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
out, _ = p.communicate()
return out
def GetClangTidyPath():
"""POC/WIP! Use the one we have, even it doesn't match clang's version."""
tidy = ('third_party/android_ndk/toolchains/'
'llvm/prebuilt/linux-x86_64/bin/clang-tidy')
return os.path.join(SRC_DIR, tidy)
"""POC/WIP! Use the one we have, even it doesn't match clang's version."""
tidy = ('third_party/android_ndk/toolchains/'
'llvm/prebuilt/linux-x86_64/bin/clang-tidy')
return os.path.join(SRC_DIR, tidy)
def GetCompilationDb(root_dir=None):
"""Run ninja compdb tool to get proper flags, defines and include paths."""
# The compdb tool expect a rule.
commands = json.loads(RunNinjaCommand(['-t', 'compdb', 'cxx'], root_dir))
# Turns 'file' field into a key.
return {v['file']: v for v in commands}
"""Run ninja compdb tool to get proper flags, defines and include paths."""
# The compdb tool expect a rule.
commands = json.loads(RunNinjaCommand(['-t', 'compdb', 'cxx'], root_dir))
# Turns 'file' field into a key.
return {v['file']: v for v in commands}
def GetCompilationCommand(filepath, gn_args, work_dir):
"""Get the whole command used to compile one cc file.
"""Get the whole command used to compile one cc file.
Typically, clang++ with flags, defines and include paths.
Args:
@ -104,31 +103,30 @@ def GetCompilationCommand(filepath, gn_args, work_dir):
Returns:
Command as a list, ready to be consumed by subprocess.Popen.
"""
gn_errors = RunGnCommand(['gen'] + gn_args + [work_dir])
if gn_errors:
raise(RuntimeError(
'FYI, cannot complete check due to gn error:\n%s\n'
'Please open a bug.' % gn_errors))
gn_errors = RunGnCommand(['gen'] + gn_args + [work_dir])
if gn_errors:
raise (RuntimeError('FYI, cannot complete check due to gn error:\n%s\n'
'Please open a bug.' % gn_errors))
# Needed for single file compilation.
commands = GetCompilationDb(work_dir)
# Needed for single file compilation.
commands = GetCompilationDb(work_dir)
# Path as referenced by ninja.
rel_path = os.path.relpath(os.path.abspath(filepath), work_dir)
# Path as referenced by ninja.
rel_path = os.path.relpath(os.path.abspath(filepath), work_dir)
# Gather defines, include path and flags (such as -std=c++11).
try:
compilation_entry = commands[rel_path]
except KeyError:
raise ValueError('%s: Not found in compilation database.\n'
'Please check the path.' % filepath)
command = compilation_entry['command'].split()
# Gather defines, include path and flags (such as -std=c++11).
try:
compilation_entry = commands[rel_path]
except KeyError:
raise ValueError('%s: Not found in compilation database.\n'
'Please check the path.' % filepath)
command = compilation_entry['command'].split()
# Remove troublesome flags. May trigger an error otherwise.
if '-MMD' in command:
command.remove('-MMD')
if '-MF' in command:
index = command.index('-MF')
del command[index:index+2] # Remove filename as well.
# Remove troublesome flags. May trigger an error otherwise.
if '-MMD' in command:
command.remove('-MMD')
if '-MF' in command:
index = command.index('-MF')
del command[index:index + 2] # Remove filename as well.
return command
return command

View File

@ -14,19 +14,20 @@ import unittest
#pylint: disable=relative-import
import build_helpers
TESTDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'testdata')
class GnCheckTest(unittest.TestCase):
def testCircularDependencyError(self):
test_dir = os.path.join(TESTDATA_DIR, 'circular_dependency')
expected_errors = ['ERROR Dependency cycle:\n'
' //:bar ->\n //:foo ->\n //:bar']
self.assertListEqual(expected_errors,
build_helpers.RunGnCheck(test_dir))
def testCircularDependencyError(self):
test_dir = os.path.join(TESTDATA_DIR, 'circular_dependency')
expected_errors = [
'ERROR Dependency cycle:\n'
' //:bar ->\n //:foo ->\n //:bar'
]
self.assertListEqual(expected_errors,
build_helpers.RunGnCheck(test_dir))
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@ -11,12 +11,11 @@ import os
import re
import string
# TARGET_RE matches a GN target, and extracts the target name and the contents.
TARGET_RE = re.compile(r'(?P<indent>\s*)\w+\("(?P<target_name>\w+)"\) {'
r'(?P<target_contents>.*?)'
r'(?P=indent)}',
re.MULTILINE | re.DOTALL)
TARGET_RE = re.compile(
r'(?P<indent>\s*)\w+\("(?P<target_name>\w+)"\) {'
r'(?P<target_contents>.*?)'
r'(?P=indent)}', re.MULTILINE | re.DOTALL)
# SOURCES_RE matches a block of sources inside a GN target.
SOURCES_RE = re.compile(
@ -27,27 +26,27 @@ SOURCE_FILE_RE = re.compile(r'.*\"(?P<source_file>.*)\"')
class NoBuildGnFoundError(Exception):
pass
pass
class WrongFileTypeError(Exception):
pass
pass
def _ReadFile(file_path):
"""Returns the content of file_path in a string.
"""Returns the content of file_path in a string.
Args:
file_path: the path of the file to read.
Returns:
A string with the content of the file.
"""
with open(file_path) as f:
return f.read()
with open(file_path) as f:
return f.read()
def GetBuildGnPathFromFilePath(file_path, file_exists_check, root_dir_path):
"""Returns the BUILD.gn file responsible for file_path.
"""Returns the BUILD.gn file responsible for file_path.
Args:
file_path: the absolute path to the .h file to check.
@ -59,23 +58,23 @@ def GetBuildGnPathFromFilePath(file_path, file_exists_check, root_dir_path):
A string with the absolute path to the BUILD.gn file responsible to include
file_path in a target.
"""
if not file_path.endswith('.h'):
raise WrongFileTypeError(
'File {} is not an header file (.h)'.format(file_path))
candidate_dir = os.path.dirname(file_path)
while candidate_dir.startswith(root_dir_path):
candidate_build_gn_path = os.path.join(candidate_dir, 'BUILD.gn')
if file_exists_check(candidate_build_gn_path):
return candidate_build_gn_path
else:
candidate_dir = os.path.abspath(os.path.join(candidate_dir,
os.pardir))
raise NoBuildGnFoundError(
'No BUILD.gn file found for file: `{}`'.format(file_path))
if not file_path.endswith('.h'):
raise WrongFileTypeError(
'File {} is not an header file (.h)'.format(file_path))
candidate_dir = os.path.dirname(file_path)
while candidate_dir.startswith(root_dir_path):
candidate_build_gn_path = os.path.join(candidate_dir, 'BUILD.gn')
if file_exists_check(candidate_build_gn_path):
return candidate_build_gn_path
else:
candidate_dir = os.path.abspath(
os.path.join(candidate_dir, os.pardir))
raise NoBuildGnFoundError(
'No BUILD.gn file found for file: `{}`'.format(file_path))
def IsHeaderInBuildGn(header_path, build_gn_path):
"""Returns True if the header is listed in the BUILD.gn file.
"""Returns True if the header is listed in the BUILD.gn file.
Args:
header_path: the absolute path to the header to check.
@ -86,15 +85,15 @@ def IsHeaderInBuildGn(header_path, build_gn_path):
at least one GN target in the BUILD.gn file specified by
the argument build_gn_path.
"""
target_abs_path = os.path.dirname(build_gn_path)
build_gn_content = _ReadFile(build_gn_path)
headers_in_build_gn = GetHeadersInBuildGnFileSources(build_gn_content,
target_abs_path)
return header_path in headers_in_build_gn
target_abs_path = os.path.dirname(build_gn_path)
build_gn_content = _ReadFile(build_gn_path)
headers_in_build_gn = GetHeadersInBuildGnFileSources(
build_gn_content, target_abs_path)
return header_path in headers_in_build_gn
def GetHeadersInBuildGnFileSources(file_content, target_abs_path):
"""Returns a set with all the .h files in the file_content.
"""Returns a set with all the .h files in the file_content.
Args:
file_content: a string with the content of the BUILD.gn file.
@ -105,15 +104,15 @@ def GetHeadersInBuildGnFileSources(file_content, target_abs_path):
A set with all the headers (.h file) in the file_content.
The set contains absolute paths.
"""
headers_in_sources = set([])
for target_match in TARGET_RE.finditer(file_content):
target_contents = target_match.group('target_contents')
for sources_match in SOURCES_RE.finditer(target_contents):
sources = sources_match.group('sources')
for source_file_match in SOURCE_FILE_RE.finditer(sources):
source_file = source_file_match.group('source_file')
if source_file.endswith('.h'):
source_file_tokens = string.split(source_file, '/')
headers_in_sources.add(os.path.join(target_abs_path,
*source_file_tokens))
return headers_in_sources
headers_in_sources = set([])
for target_match in TARGET_RE.finditer(file_content):
target_contents = target_match.group('target_contents')
for sources_match in SOURCES_RE.finditer(target_contents):
sources = sources_match.group('sources')
for source_file_match in SOURCE_FILE_RE.finditer(sources):
source_file = source_file_match.group('source_file')
if source_file.endswith('.h'):
source_file_tokens = string.split(source_file, '/')
headers_in_sources.add(
os.path.join(target_abs_path, *source_file_tokens))
return headers_in_sources

View File

@ -16,73 +16,67 @@ import check_orphan_headers
def _GetRootBasedOnPlatform():
if sys.platform.startswith('win'):
return 'C:\\'
else:
return '/'
if sys.platform.startswith('win'):
return 'C:\\'
else:
return '/'
def _GetPath(*path_chunks):
return os.path.join(_GetRootBasedOnPlatform(),
*path_chunks)
return os.path.join(_GetRootBasedOnPlatform(), *path_chunks)
class GetBuildGnPathFromFilePathTest(unittest.TestCase):
def testGetBuildGnFromSameDirectory(self):
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
expected_build_path = _GetPath('home', 'projects', 'webrtc', 'base',
'BUILD.gn')
file_exists = lambda p: p == _GetPath('home', 'projects', 'webrtc',
'base', 'BUILD.gn')
src_dir_path = _GetPath('home', 'projects', 'webrtc')
self.assertEqual(
expected_build_path,
check_orphan_headers.GetBuildGnPathFromFilePath(
file_path, file_exists, src_dir_path))
def testGetBuildGnFromSameDirectory(self):
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
expected_build_path = _GetPath('home', 'projects', 'webrtc', 'base',
'BUILD.gn')
file_exists = lambda p: p == _GetPath('home', 'projects', 'webrtc',
'base', 'BUILD.gn')
src_dir_path = _GetPath('home', 'projects', 'webrtc')
self.assertEqual(
expected_build_path,
check_orphan_headers.GetBuildGnPathFromFilePath(file_path,
file_exists,
src_dir_path))
def testGetBuildPathFromParentDirectory(self):
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
expected_build_path = _GetPath('home', 'projects', 'webrtc',
'BUILD.gn')
file_exists = lambda p: p == _GetPath('home', 'projects', 'webrtc',
'BUILD.gn')
src_dir_path = _GetPath('home', 'projects', 'webrtc')
self.assertEqual(
expected_build_path,
check_orphan_headers.GetBuildGnPathFromFilePath(
file_path, file_exists, src_dir_path))
def testGetBuildPathFromParentDirectory(self):
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
expected_build_path = _GetPath('home', 'projects', 'webrtc',
'BUILD.gn')
file_exists = lambda p: p == _GetPath('home', 'projects', 'webrtc',
'BUILD.gn')
src_dir_path = _GetPath('home', 'projects', 'webrtc')
self.assertEqual(
expected_build_path,
check_orphan_headers.GetBuildGnPathFromFilePath(file_path,
file_exists,
src_dir_path))
def testExceptionIfNoBuildGnFilesAreFound(self):
with self.assertRaises(check_orphan_headers.NoBuildGnFoundError):
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
file_exists = lambda p: False
src_dir_path = _GetPath('home', 'projects', 'webrtc')
check_orphan_headers.GetBuildGnPathFromFilePath(
file_path, file_exists, src_dir_path)
def testExceptionIfNoBuildGnFilesAreFound(self):
with self.assertRaises(check_orphan_headers.NoBuildGnFoundError):
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.h')
file_exists = lambda p: False
src_dir_path = _GetPath('home', 'projects', 'webrtc')
check_orphan_headers.GetBuildGnPathFromFilePath(file_path,
file_exists,
src_dir_path)
def testExceptionIfFilePathIsNotAnHeader(self):
with self.assertRaises(check_orphan_headers.WrongFileTypeError):
file_path = _GetPath('home', 'projects', 'webrtc', 'base', 'foo.cc')
file_exists = lambda p: False
src_dir_path = _GetPath('home', 'projects', 'webrtc')
check_orphan_headers.GetBuildGnPathFromFilePath(file_path,
file_exists,
src_dir_path)
def testExceptionIfFilePathIsNotAnHeader(self):
with self.assertRaises(check_orphan_headers.WrongFileTypeError):
file_path = _GetPath('home', 'projects', 'webrtc', 'base',
'foo.cc')
file_exists = lambda p: False
src_dir_path = _GetPath('home', 'projects', 'webrtc')
check_orphan_headers.GetBuildGnPathFromFilePath(
file_path, file_exists, src_dir_path)
class GetHeadersInBuildGnFileSourcesTest(unittest.TestCase):
def testEmptyFileReturnsEmptySet(self):
self.assertEqual(
set([]),
check_orphan_headers.GetHeadersInBuildGnFileSources('', '/a/b'))
def testEmptyFileReturnsEmptySet(self):
self.assertEqual(
set([]),
check_orphan_headers.GetHeadersInBuildGnFileSources('', '/a/b'))
def testReturnsSetOfHeadersFromFileContent(self):
file_content = """
def testReturnsSetOfHeadersFromFileContent(self):
file_content = """
# Some comments
if (is_android) {
import("//a/b/c.gni")
@ -107,17 +101,17 @@ class GetHeadersInBuildGnFileSourcesTest(unittest.TestCase):
sources = ["baz/foo.h"]
}
"""
target_abs_path = _GetPath('a', 'b')
self.assertEqual(
set([
_GetPath('a', 'b', 'foo.h'),
_GetPath('a', 'b', 'bar.h'),
_GetPath('a', 'b', 'public_foo.h'),
_GetPath('a', 'b', 'baz', 'foo.h'),
]),
check_orphan_headers.GetHeadersInBuildGnFileSources(file_content,
target_abs_path))
target_abs_path = _GetPath('a', 'b')
self.assertEqual(
set([
_GetPath('a', 'b', 'foo.h'),
_GetPath('a', 'b', 'bar.h'),
_GetPath('a', 'b', 'public_foo.h'),
_GetPath('a', 'b', 'baz', 'foo.h'),
]),
check_orphan_headers.GetHeadersInBuildGnFileSources(
file_content, target_abs_path))
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@ -14,12 +14,11 @@ import os
import re
import sys
# TARGET_RE matches a GN target, and extracts the target name and the contents.
TARGET_RE = re.compile(r'(?P<indent>\s*)\w+\("(?P<target_name>\w+)"\) {'
r'(?P<target_contents>.*?)'
r'(?P=indent)}',
re.MULTILINE | re.DOTALL)
TARGET_RE = re.compile(
r'(?P<indent>\s*)\w+\("(?P<target_name>\w+)"\) {'
r'(?P<target_contents>.*?)'
r'(?P=indent)}', re.MULTILINE | re.DOTALL)
# SOURCES_RE matches a block of sources inside a GN target.
SOURCES_RE = re.compile(r'sources \+?= \[(?P<sources>.*?)\]',
@ -31,96 +30,107 @@ ERROR_MESSAGE = ("{build_file_path} in target '{target_name}':\n"
class PackageBoundaryViolation(
collections.namedtuple('PackageBoundaryViolation',
'build_file_path target_name source_file subpackage')):
def __str__(self):
return ERROR_MESSAGE.format(**self._asdict())
collections.namedtuple(
'PackageBoundaryViolation',
'build_file_path target_name source_file subpackage')):
def __str__(self):
return ERROR_MESSAGE.format(**self._asdict())
def _BuildSubpackagesPattern(packages, query):
"""Returns a regular expression that matches source files inside subpackages
"""Returns a regular expression that matches source files inside subpackages
of the given query."""
query += os.path.sep
length = len(query)
pattern = r'\s*"(?P<source_file>(?P<subpackage>'
pattern += '|'.join(re.escape(package[length:].replace(os.path.sep, '/'))
for package in packages if package.startswith(query))
pattern += r')/[\w\./]*)"'
return re.compile(pattern)
query += os.path.sep
length = len(query)
pattern = r'\s*"(?P<source_file>(?P<subpackage>'
pattern += '|'.join(
re.escape(package[length:].replace(os.path.sep, '/'))
for package in packages if package.startswith(query))
pattern += r')/[\w\./]*)"'
return re.compile(pattern)
def _ReadFileAndPrependLines(file_path):
"""Reads the contents of a file."""
with open(file_path) as f:
return "".join(f.readlines())
"""Reads the contents of a file."""
with open(file_path) as f:
return "".join(f.readlines())
def _CheckBuildFile(build_file_path, packages):
"""Iterates over all the targets of the given BUILD.gn file, and verifies that
"""Iterates over all the targets of the given BUILD.gn file, and verifies that
the source files referenced by it don't belong to any of it's subpackages.
Returns an iterator over PackageBoundaryViolations for this package.
"""
package = os.path.dirname(build_file_path)
subpackages_re = _BuildSubpackagesPattern(packages, package)
package = os.path.dirname(build_file_path)
subpackages_re = _BuildSubpackagesPattern(packages, package)
build_file_contents = _ReadFileAndPrependLines(build_file_path)
for target_match in TARGET_RE.finditer(build_file_contents):
target_name = target_match.group('target_name')
target_contents = target_match.group('target_contents')
for sources_match in SOURCES_RE.finditer(target_contents):
sources = sources_match.group('sources')
for subpackages_match in subpackages_re.finditer(sources):
subpackage = subpackages_match.group('subpackage')
source_file = subpackages_match.group('source_file')
if subpackage:
yield PackageBoundaryViolation(build_file_path,
target_name, source_file, subpackage)
build_file_contents = _ReadFileAndPrependLines(build_file_path)
for target_match in TARGET_RE.finditer(build_file_contents):
target_name = target_match.group('target_name')
target_contents = target_match.group('target_contents')
for sources_match in SOURCES_RE.finditer(target_contents):
sources = sources_match.group('sources')
for subpackages_match in subpackages_re.finditer(sources):
subpackage = subpackages_match.group('subpackage')
source_file = subpackages_match.group('source_file')
if subpackage:
yield PackageBoundaryViolation(build_file_path,
target_name, source_file,
subpackage)
def CheckPackageBoundaries(root_dir, build_files=None):
packages = [root for root, _, files in os.walk(root_dir)
if 'BUILD.gn' in files]
packages = [
root for root, _, files in os.walk(root_dir) if 'BUILD.gn' in files
]
if build_files is not None:
if build_files is not None:
for build_file_path in build_files:
assert build_file_path.startswith(root_dir)
else:
build_files = [
os.path.join(package, 'BUILD.gn') for package in packages
]
messages = []
for build_file_path in build_files:
assert build_file_path.startswith(root_dir)
else:
build_files = [os.path.join(package, 'BUILD.gn') for package in packages]
messages = []
for build_file_path in build_files:
messages.extend(_CheckBuildFile(build_file_path, packages))
return messages
messages.extend(_CheckBuildFile(build_file_path, packages))
return messages
def main(argv):
parser = argparse.ArgumentParser(
description='Script that checks package boundary violations in GN '
'build files.')
parser = argparse.ArgumentParser(
description='Script that checks package boundary violations in GN '
'build files.')
parser.add_argument('root_dir', metavar='ROOT_DIR',
help='The root directory that contains all BUILD.gn '
'files to be processed.')
parser.add_argument('build_files', metavar='BUILD_FILE', nargs='*',
help='A list of BUILD.gn files to be processed. If no '
'files are given, all BUILD.gn files under ROOT_DIR '
'will be processed.')
parser.add_argument('--max_messages', type=int, default=None,
help='If set, the maximum number of violations to be '
'displayed.')
parser.add_argument('root_dir',
metavar='ROOT_DIR',
help='The root directory that contains all BUILD.gn '
'files to be processed.')
parser.add_argument('build_files',
metavar='BUILD_FILE',
nargs='*',
help='A list of BUILD.gn files to be processed. If no '
'files are given, all BUILD.gn files under ROOT_DIR '
'will be processed.')
parser.add_argument('--max_messages',
type=int,
default=None,
help='If set, the maximum number of violations to be '
'displayed.')
args = parser.parse_args(argv)
args = parser.parse_args(argv)
messages = CheckPackageBoundaries(args.root_dir, args.build_files)
messages = messages[:args.max_messages]
messages = CheckPackageBoundaries(args.root_dir, args.build_files)
messages = messages[:args.max_messages]
for i, message in enumerate(messages):
if i > 0:
print
print message
for i, message in enumerate(messages):
if i > 0:
print
print message
return bool(messages)
return bool(messages)
if __name__ == '__main__':
sys.exit(main(sys.argv[1:]))
sys.exit(main(sys.argv[1:]))

View File

@ -15,58 +15,60 @@ import unittest
#pylint: disable=relative-import
from check_package_boundaries import CheckPackageBoundaries
MSG_FORMAT = 'ERROR:check_package_boundaries.py: Unexpected %s.'
TESTDATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'testdata')
def ReadPylFile(file_path):
with open(file_path) as f:
return ast.literal_eval(f.read())
with open(file_path) as f:
return ast.literal_eval(f.read())
class UnitTest(unittest.TestCase):
def _RunTest(self, test_dir, check_all_build_files=False):
build_files = [os.path.join(test_dir, 'BUILD.gn')]
if check_all_build_files:
build_files = None
def _RunTest(self, test_dir, check_all_build_files=False):
build_files = [os.path.join(test_dir, 'BUILD.gn')]
if check_all_build_files:
build_files = None
messages = []
for violation in CheckPackageBoundaries(test_dir, build_files):
build_file_path = os.path.relpath(violation.build_file_path, test_dir)
build_file_path = build_file_path.replace(os.path.sep, '/')
messages.append(violation._replace(build_file_path=build_file_path))
messages = []
for violation in CheckPackageBoundaries(test_dir, build_files):
build_file_path = os.path.relpath(violation.build_file_path,
test_dir)
build_file_path = build_file_path.replace(os.path.sep, '/')
messages.append(
violation._replace(build_file_path=build_file_path))
expected_messages = ReadPylFile(os.path.join(test_dir, 'expected.pyl'))
self.assertListEqual(sorted(expected_messages), sorted(messages))
expected_messages = ReadPylFile(os.path.join(test_dir, 'expected.pyl'))
self.assertListEqual(sorted(expected_messages), sorted(messages))
def testNoErrors(self):
self._RunTest(os.path.join(TESTDATA_DIR, 'no_errors'))
def testNoErrors(self):
self._RunTest(os.path.join(TESTDATA_DIR, 'no_errors'))
def testMultipleErrorsSingleTarget(self):
self._RunTest(os.path.join(TESTDATA_DIR, 'multiple_errors_single_target'))
def testMultipleErrorsSingleTarget(self):
self._RunTest(
os.path.join(TESTDATA_DIR, 'multiple_errors_single_target'))
def testMultipleErrorsMultipleTargets(self):
self._RunTest(os.path.join(TESTDATA_DIR,
'multiple_errors_multiple_targets'))
def testMultipleErrorsMultipleTargets(self):
self._RunTest(
os.path.join(TESTDATA_DIR, 'multiple_errors_multiple_targets'))
def testCommonPrefix(self):
self._RunTest(os.path.join(TESTDATA_DIR, 'common_prefix'))
def testCommonPrefix(self):
self._RunTest(os.path.join(TESTDATA_DIR, 'common_prefix'))
def testAllBuildFiles(self):
self._RunTest(os.path.join(TESTDATA_DIR, 'all_build_files'), True)
def testAllBuildFiles(self):
self._RunTest(os.path.join(TESTDATA_DIR, 'all_build_files'), True)
def testSanitizeFilename(self):
# The `dangerous_filename` test case contains a directory with '++' in its
# name. If it's not properly escaped, a regex error would be raised.
self._RunTest(os.path.join(TESTDATA_DIR, 'dangerous_filename'), True)
def testSanitizeFilename(self):
# The `dangerous_filename` test case contains a directory with '++' in its
# name. If it's not properly escaped, a regex error would be raised.
self._RunTest(os.path.join(TESTDATA_DIR, 'dangerous_filename'), True)
def testRelativeFilename(self):
test_dir = os.path.join(TESTDATA_DIR, 'all_build_files')
with self.assertRaises(AssertionError):
CheckPackageBoundaries(test_dir, ["BUILD.gn"])
def testRelativeFilename(self):
test_dir = os.path.join(TESTDATA_DIR, 'all_build_files')
with self.assertRaises(AssertionError):
CheckPackageBoundaries(test_dir, ["BUILD.gn"])
if __name__ == '__main__':
unittest.main()
unittest.main()

View File

@ -6,8 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""This is a tool to transform a crt file into a C/C++ header.
Usage:
@ -41,172 +39,180 @@ _VERBOSE = 'verbose'
def main():
"""The main entrypoint."""
parser = OptionParser('usage %prog FILE')
parser.add_option('-v', '--verbose', dest='verbose', action='store_true')
parser.add_option('-f', '--full_cert', dest='full_cert', action='store_true')
options, args = parser.parse_args()
if len(args) < 1:
parser.error('No crt file specified.')
return
root_dir = _SplitCrt(args[0], options)
_GenCFiles(root_dir, options)
_Cleanup(root_dir)
"""The main entrypoint."""
parser = OptionParser('usage %prog FILE')
parser.add_option('-v', '--verbose', dest='verbose', action='store_true')
parser.add_option('-f',
'--full_cert',
dest='full_cert',
action='store_true')
options, args = parser.parse_args()
if len(args) < 1:
parser.error('No crt file specified.')
return
root_dir = _SplitCrt(args[0], options)
_GenCFiles(root_dir, options)
_Cleanup(root_dir)
def _SplitCrt(source_file, options):
sub_file_blocks = []
label_name = ''
root_dir = os.path.dirname(os.path.abspath(source_file)) + '/'
_PrintOutput(root_dir, options)
f = open(source_file)
for line in f:
if line.startswith('# Label: '):
sub_file_blocks.append(line)
label = re.search(r'\".*\"', line)
temp_label = label.group(0)
end = len(temp_label)-1
label_name = _SafeName(temp_label[1:end])
elif line.startswith('-----END CERTIFICATE-----'):
sub_file_blocks.append(line)
new_file_name = root_dir + _PREFIX + label_name + _EXTENSION
_PrintOutput('Generating: ' + new_file_name, options)
new_file = open(new_file_name, 'w')
for out_line in sub_file_blocks:
new_file.write(out_line)
new_file.close()
sub_file_blocks = []
else:
sub_file_blocks.append(line)
f.close()
return root_dir
sub_file_blocks = []
label_name = ''
root_dir = os.path.dirname(os.path.abspath(source_file)) + '/'
_PrintOutput(root_dir, options)
f = open(source_file)
for line in f:
if line.startswith('# Label: '):
sub_file_blocks.append(line)
label = re.search(r'\".*\"', line)
temp_label = label.group(0)
end = len(temp_label) - 1
label_name = _SafeName(temp_label[1:end])
elif line.startswith('-----END CERTIFICATE-----'):
sub_file_blocks.append(line)
new_file_name = root_dir + _PREFIX + label_name + _EXTENSION
_PrintOutput('Generating: ' + new_file_name, options)
new_file = open(new_file_name, 'w')
for out_line in sub_file_blocks:
new_file.write(out_line)
new_file.close()
sub_file_blocks = []
else:
sub_file_blocks.append(line)
f.close()
return root_dir
def _GenCFiles(root_dir, options):
output_header_file = open(root_dir + _GENERATED_FILE, 'w')
output_header_file.write(_CreateOutputHeader())
if options.full_cert:
subject_name_list = _CreateArraySectionHeader(_SUBJECT_NAME_VARIABLE,
_CHAR_TYPE, options)
public_key_list = _CreateArraySectionHeader(_PUBLIC_KEY_VARIABLE,
_CHAR_TYPE, options)
certificate_list = _CreateArraySectionHeader(_CERTIFICATE_VARIABLE,
_CHAR_TYPE, options)
certificate_size_list = _CreateArraySectionHeader(_CERTIFICATE_SIZE_VARIABLE,
_INT_TYPE, options)
output_header_file = open(root_dir + _GENERATED_FILE, 'w')
output_header_file.write(_CreateOutputHeader())
if options.full_cert:
subject_name_list = _CreateArraySectionHeader(_SUBJECT_NAME_VARIABLE,
_CHAR_TYPE, options)
public_key_list = _CreateArraySectionHeader(_PUBLIC_KEY_VARIABLE,
_CHAR_TYPE, options)
certificate_list = _CreateArraySectionHeader(_CERTIFICATE_VARIABLE,
_CHAR_TYPE, options)
certificate_size_list = _CreateArraySectionHeader(
_CERTIFICATE_SIZE_VARIABLE, _INT_TYPE, options)
for _, _, files in os.walk(root_dir):
for current_file in files:
if current_file.startswith(_PREFIX):
prefix_length = len(_PREFIX)
length = len(current_file) - len(_EXTENSION)
label = current_file[prefix_length:length]
filtered_output, cert_size = _CreateCertSection(root_dir, current_file,
label, options)
output_header_file.write(filtered_output + '\n\n\n')
if options.full_cert:
subject_name_list += _AddLabelToArray(label, _SUBJECT_NAME_ARRAY)
public_key_list += _AddLabelToArray(label, _PUBLIC_KEY_ARRAY)
certificate_list += _AddLabelToArray(label, _CERTIFICATE_ARRAY)
certificate_size_list += (' %s,\n') %(cert_size)
for _, _, files in os.walk(root_dir):
for current_file in files:
if current_file.startswith(_PREFIX):
prefix_length = len(_PREFIX)
length = len(current_file) - len(_EXTENSION)
label = current_file[prefix_length:length]
filtered_output, cert_size = _CreateCertSection(
root_dir, current_file, label, options)
output_header_file.write(filtered_output + '\n\n\n')
if options.full_cert:
subject_name_list += _AddLabelToArray(
label, _SUBJECT_NAME_ARRAY)
public_key_list += _AddLabelToArray(
label, _PUBLIC_KEY_ARRAY)
certificate_list += _AddLabelToArray(label, _CERTIFICATE_ARRAY)
certificate_size_list += (' %s,\n') % (cert_size)
if options.full_cert:
subject_name_list += _CreateArraySectionFooter()
output_header_file.write(subject_name_list)
public_key_list += _CreateArraySectionFooter()
output_header_file.write(public_key_list)
certificate_list += _CreateArraySectionFooter()
output_header_file.write(certificate_list)
certificate_size_list += _CreateArraySectionFooter()
output_header_file.write(certificate_size_list)
output_header_file.write(_CreateOutputFooter())
output_header_file.close()
if options.full_cert:
subject_name_list += _CreateArraySectionFooter()
output_header_file.write(subject_name_list)
public_key_list += _CreateArraySectionFooter()
output_header_file.write(public_key_list)
certificate_list += _CreateArraySectionFooter()
output_header_file.write(certificate_list)
certificate_size_list += _CreateArraySectionFooter()
output_header_file.write(certificate_size_list)
output_header_file.write(_CreateOutputFooter())
output_header_file.close()
def _Cleanup(root_dir):
for f in os.listdir(root_dir):
if f.startswith(_PREFIX):
os.remove(root_dir + f)
for f in os.listdir(root_dir):
if f.startswith(_PREFIX):
os.remove(root_dir + f)
def _CreateCertSection(root_dir, source_file, label, options):
command = 'openssl x509 -in %s%s -noout -C' %(root_dir, source_file)
_PrintOutput(command, options)
output = commands.getstatusoutput(command)[1]
renamed_output = output.replace('unsigned char XXX_',
'const unsigned char ' + label + '_')
filtered_output = ''
cert_block = '^const unsigned char.*?};$'
prog = re.compile(cert_block, re.IGNORECASE | re.MULTILINE | re.DOTALL)
if not options.full_cert:
filtered_output = prog.sub('', renamed_output, count=2)
else:
filtered_output = renamed_output
command = 'openssl x509 -in %s%s -noout -C' % (root_dir, source_file)
_PrintOutput(command, options)
output = commands.getstatusoutput(command)[1]
renamed_output = output.replace('unsigned char XXX_',
'const unsigned char ' + label + '_')
filtered_output = ''
cert_block = '^const unsigned char.*?};$'
prog = re.compile(cert_block, re.IGNORECASE | re.MULTILINE | re.DOTALL)
if not options.full_cert:
filtered_output = prog.sub('', renamed_output, count=2)
else:
filtered_output = renamed_output
cert_size_block = r'\d\d\d+'
prog2 = re.compile(cert_size_block, re.MULTILINE | re.VERBOSE)
result = prog2.findall(renamed_output)
cert_size = result[len(result) - 1]
cert_size_block = r'\d\d\d+'
prog2 = re.compile(cert_size_block, re.MULTILINE | re.VERBOSE)
result = prog2.findall(renamed_output)
cert_size = result[len(result) - 1]
return filtered_output, cert_size
return filtered_output, cert_size
def _CreateOutputHeader():
output = ('/*\n'
' * Copyright 2004 The WebRTC Project Authors. All rights '
'reserved.\n'
' *\n'
' * Use of this source code is governed by a BSD-style license\n'
' * that can be found in the LICENSE file in the root of the '
'source\n'
' * tree. An additional intellectual property rights grant can be '
'found\n'
' * in the file PATENTS. All contributing project authors may\n'
' * be found in the AUTHORS file in the root of the source tree.\n'
' */\n\n'
'#ifndef RTC_BASE_SSL_ROOTS_H_\n'
'#define RTC_BASE_SSL_ROOTS_H_\n\n'
'// This file is the root certificates in C form that are needed to'
' connect to\n// Google.\n\n'
'// It was generated with the following command line:\n'
'// > python tools_webrtc/sslroots/generate_sslroots.py'
'\n// https://pki.goog/roots.pem\n\n'
'// clang-format off\n'
'// Don\'t bother formatting generated code,\n'
'// also it would breaks subject/issuer lines.\n\n')
return output
output = (
'/*\n'
' * Copyright 2004 The WebRTC Project Authors. All rights '
'reserved.\n'
' *\n'
' * Use of this source code is governed by a BSD-style license\n'
' * that can be found in the LICENSE file in the root of the '
'source\n'
' * tree. An additional intellectual property rights grant can be '
'found\n'
' * in the file PATENTS. All contributing project authors may\n'
' * be found in the AUTHORS file in the root of the source tree.\n'
' */\n\n'
'#ifndef RTC_BASE_SSL_ROOTS_H_\n'
'#define RTC_BASE_SSL_ROOTS_H_\n\n'
'// This file is the root certificates in C form that are needed to'
' connect to\n// Google.\n\n'
'// It was generated with the following command line:\n'
'// > python tools_webrtc/sslroots/generate_sslroots.py'
'\n// https://pki.goog/roots.pem\n\n'
'// clang-format off\n'
'// Don\'t bother formatting generated code,\n'
'// also it would breaks subject/issuer lines.\n\n')
return output
def _CreateOutputFooter():
output = ('// clang-format on\n\n'
'#endif // RTC_BASE_SSL_ROOTS_H_\n')
return output
output = ('// clang-format on\n\n' '#endif // RTC_BASE_SSL_ROOTS_H_\n')
return output
def _CreateArraySectionHeader(type_name, type_type, options):
output = ('const %s kSSLCert%sList[] = {\n') %(type_type, type_name)
_PrintOutput(output, options)
return output
output = ('const %s kSSLCert%sList[] = {\n') % (type_type, type_name)
_PrintOutput(output, options)
return output
def _AddLabelToArray(label, type_name):
return ' %s_%s,\n' %(label, type_name)
return ' %s_%s,\n' % (label, type_name)
def _CreateArraySectionFooter():
return '};\n\n'
return '};\n\n'
def _SafeName(original_file_name):
bad_chars = ' -./\\()áéíőú'
replacement_chars = ''
for _ in bad_chars:
replacement_chars += '_'
translation_table = string.maketrans(bad_chars, replacement_chars)
return original_file_name.translate(translation_table)
bad_chars = ' -./\\()áéíőú'
replacement_chars = ''
for _ in bad_chars:
replacement_chars += '_'
translation_table = string.maketrans(bad_chars, replacement_chars)
return original_file_name.translate(translation_table)
def _PrintOutput(output, options):
if options.verbose:
print output
if options.verbose:
print output
if __name__ == '__main__':
main()
main()

View File

@ -53,7 +53,6 @@
#
# * This has only been tested on gPrecise.
import os
import os.path
import shlex
@ -62,25 +61,26 @@ import sys
# Flags from YCM's default config.
_DEFAULT_FLAGS = [
'-DUSE_CLANG_COMPLETER',
'-std=c++11',
'-x',
'c++',
'-DUSE_CLANG_COMPLETER',
'-std=c++11',
'-x',
'c++',
]
_HEADER_ALTERNATES = ('.cc', '.cpp', '.c', '.mm', '.m')
_EXTENSION_FLAGS = {
'.m': ['-x', 'objective-c'],
'.mm': ['-x', 'objective-c++'],
'.m': ['-x', 'objective-c'],
'.mm': ['-x', 'objective-c++'],
}
def PathExists(*args):
return os.path.exists(os.path.join(*args))
return os.path.exists(os.path.join(*args))
def FindWebrtcSrcFromFilename(filename):
"""Searches for the root of the WebRTC checkout.
"""Searches for the root of the WebRTC checkout.
Simply checks parent directories until it finds .gclient and src/.
@ -90,20 +90,20 @@ def FindWebrtcSrcFromFilename(filename):
Returns:
(String) Path of 'src/', or None if unable to find.
"""
curdir = os.path.normpath(os.path.dirname(filename))
while not (os.path.basename(curdir) == 'src'
and PathExists(curdir, 'DEPS')
and (PathExists(curdir, '..', '.gclient')
or PathExists(curdir, '.git'))):
nextdir = os.path.normpath(os.path.join(curdir, '..'))
if nextdir == curdir:
return None
curdir = nextdir
return curdir
curdir = os.path.normpath(os.path.dirname(filename))
while not (os.path.basename(curdir) == 'src'
and PathExists(curdir, 'DEPS') and
(PathExists(curdir, '..', '.gclient')
or PathExists(curdir, '.git'))):
nextdir = os.path.normpath(os.path.join(curdir, '..'))
if nextdir == curdir:
return None
curdir = nextdir
return curdir
def GetDefaultSourceFile(webrtc_root, filename):
"""Returns the default source file to use as an alternative to |filename|.
"""Returns the default source file to use as an alternative to |filename|.
Compile flags used to build the default source file is assumed to be a
close-enough approximation for building |filename|.
@ -115,13 +115,13 @@ def GetDefaultSourceFile(webrtc_root, filename):
Returns:
(String) Absolute path to substitute source file.
"""
if 'test.' in filename:
return os.path.join(webrtc_root, 'base', 'logging_unittest.cc')
return os.path.join(webrtc_root, 'base', 'logging.cc')
if 'test.' in filename:
return os.path.join(webrtc_root, 'base', 'logging_unittest.cc')
return os.path.join(webrtc_root, 'base', 'logging.cc')
def GetNinjaBuildOutputsForSourceFile(out_dir, filename):
"""Returns a list of build outputs for filename.
"""Returns a list of build outputs for filename.
The list is generated by invoking 'ninja -t query' tool to retrieve a list of
inputs and outputs of |filename|. This list is then filtered to only include
@ -135,32 +135,35 @@ def GetNinjaBuildOutputsForSourceFile(out_dir, filename):
(List of Strings) List of target names. Will return [] if |filename| doesn't
yield any .o or .obj outputs.
"""
# Ninja needs the path to the source file relative to the output build
# directory.
rel_filename = os.path.relpath(filename, out_dir)
# Ninja needs the path to the source file relative to the output build
# directory.
rel_filename = os.path.relpath(filename, out_dir)
p = subprocess.Popen(['ninja', '-C', out_dir, '-t', 'query', rel_filename],
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
universal_newlines=True)
stdout, _ = p.communicate()
if p.returncode != 0:
return []
p = subprocess.Popen(['ninja', '-C', out_dir, '-t', 'query', rel_filename],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True)
stdout, _ = p.communicate()
if p.returncode != 0:
return []
# The output looks like:
# ../../relative/path/to/source.cc:
# outputs:
# obj/reative/path/to/target.source.o
# obj/some/other/target2.source.o
# another/target.txt
#
outputs_text = stdout.partition('\n outputs:\n')[2]
output_lines = [line.strip() for line in outputs_text.split('\n')]
return [target for target in output_lines
if target and (target.endswith('.o') or target.endswith('.obj'))]
# The output looks like:
# ../../relative/path/to/source.cc:
# outputs:
# obj/reative/path/to/target.source.o
# obj/some/other/target2.source.o
# another/target.txt
#
outputs_text = stdout.partition('\n outputs:\n')[2]
output_lines = [line.strip() for line in outputs_text.split('\n')]
return [
target for target in output_lines
if target and (target.endswith('.o') or target.endswith('.obj'))
]
def GetClangCommandLineForNinjaOutput(out_dir, build_target):
"""Returns the Clang command line for building |build_target|
"""Returns the Clang command line for building |build_target|
Asks ninja for the list of commands used to build |filename| and returns the
final Clang invocation.
@ -173,24 +176,25 @@ def GetClangCommandLineForNinjaOutput(out_dir, build_target):
(String or None) Clang command line or None if a Clang command line couldn't
be determined.
"""
p = subprocess.Popen(['ninja', '-v', '-C', out_dir,
'-t', 'commands', build_target],
stdout=subprocess.PIPE, universal_newlines=True)
stdout, _ = p.communicate()
if p.returncode != 0:
return None
p = subprocess.Popen(
['ninja', '-v', '-C', out_dir, '-t', 'commands', build_target],
stdout=subprocess.PIPE,
universal_newlines=True)
stdout, _ = p.communicate()
if p.returncode != 0:
return None
# Ninja will return multiple build steps for all dependencies up to
# |build_target|. The build step we want is the last Clang invocation, which
# is expected to be the one that outputs |build_target|.
for line in reversed(stdout.split('\n')):
if 'clang' in line:
return line
return None
# Ninja will return multiple build steps for all dependencies up to
# |build_target|. The build step we want is the last Clang invocation, which
# is expected to be the one that outputs |build_target|.
for line in reversed(stdout.split('\n')):
if 'clang' in line:
return line
return None
def GetClangCommandLineFromNinjaForSource(out_dir, filename):
"""Returns a Clang command line used to build |filename|.
"""Returns a Clang command line used to build |filename|.
The same source file could be built multiple times using different tool
chains. In such cases, this command returns the first Clang invocation. We
@ -206,17 +210,17 @@ def GetClangCommandLineFromNinjaForSource(out_dir, filename):
(String or None): Command line for Clang invocation using |filename| as a
source. Returns None if no such command line could be found.
"""
build_targets = GetNinjaBuildOutputsForSourceFile(out_dir, filename)
for build_target in build_targets:
command_line = GetClangCommandLineForNinjaOutput(out_dir, build_target)
if command_line:
return command_line
return None
build_targets = GetNinjaBuildOutputsForSourceFile(out_dir, filename)
for build_target in build_targets:
command_line = GetClangCommandLineForNinjaOutput(out_dir, build_target)
if command_line:
return command_line
return None
def GetClangOptionsFromCommandLine(clang_commandline, out_dir,
additional_flags):
"""Extracts relevant command line options from |clang_commandline|
"""Extracts relevant command line options from |clang_commandline|
Args:
clang_commandline: (String) Full Clang invocation.
@ -228,46 +232,47 @@ def GetClangOptionsFromCommandLine(clang_commandline, out_dir,
(List of Strings) The list of command line flags for this source file. Can
be empty.
"""
clang_flags = [] + additional_flags
clang_flags = [] + additional_flags
# Parse flags that are important for YCM's purposes.
clang_tokens = shlex.split(clang_commandline)
for flag_index, flag in enumerate(clang_tokens):
if flag.startswith('-I'):
# Relative paths need to be resolved, because they're relative to the
# output dir, not the source.
if flag[2] == '/':
clang_flags.append(flag)
else:
abs_path = os.path.normpath(os.path.join(out_dir, flag[2:]))
clang_flags.append('-I' + abs_path)
elif flag.startswith('-std'):
clang_flags.append(flag)
elif flag.startswith('-') and flag[1] in 'DWFfmO':
if flag == '-Wno-deprecated-register' or flag == '-Wno-header-guard':
# These flags causes libclang (3.3) to crash. Remove it until things
# are fixed.
continue
clang_flags.append(flag)
elif flag == '-isysroot':
# On Mac -isysroot <path> is used to find the system headers.
# Copy over both flags.
if flag_index + 1 < len(clang_tokens):
clang_flags.append(flag)
clang_flags.append(clang_tokens[flag_index + 1])
elif flag.startswith('--sysroot='):
# On Linux we use a sysroot image.
sysroot_path = flag.lstrip('--sysroot=')
if sysroot_path.startswith('/'):
clang_flags.append(flag)
else:
abs_path = os.path.normpath(os.path.join(out_dir, sysroot_path))
clang_flags.append('--sysroot=' + abs_path)
return clang_flags
# Parse flags that are important for YCM's purposes.
clang_tokens = shlex.split(clang_commandline)
for flag_index, flag in enumerate(clang_tokens):
if flag.startswith('-I'):
# Relative paths need to be resolved, because they're relative to the
# output dir, not the source.
if flag[2] == '/':
clang_flags.append(flag)
else:
abs_path = os.path.normpath(os.path.join(out_dir, flag[2:]))
clang_flags.append('-I' + abs_path)
elif flag.startswith('-std'):
clang_flags.append(flag)
elif flag.startswith('-') and flag[1] in 'DWFfmO':
if flag == '-Wno-deprecated-register' or flag == '-Wno-header-guard':
# These flags causes libclang (3.3) to crash. Remove it until things
# are fixed.
continue
clang_flags.append(flag)
elif flag == '-isysroot':
# On Mac -isysroot <path> is used to find the system headers.
# Copy over both flags.
if flag_index + 1 < len(clang_tokens):
clang_flags.append(flag)
clang_flags.append(clang_tokens[flag_index + 1])
elif flag.startswith('--sysroot='):
# On Linux we use a sysroot image.
sysroot_path = flag.lstrip('--sysroot=')
if sysroot_path.startswith('/'):
clang_flags.append(flag)
else:
abs_path = os.path.normpath(os.path.join(
out_dir, sysroot_path))
clang_flags.append('--sysroot=' + abs_path)
return clang_flags
def GetClangOptionsFromNinjaForFilename(webrtc_root, filename):
"""Returns the Clang command line options needed for building |filename|.
"""Returns the Clang command line options needed for building |filename|.
Command line options are based on the command used by ninja for building
|filename|. If |filename| is a .h file, uses its companion .cc or .cpp file.
@ -283,54 +288,55 @@ def GetClangOptionsFromNinjaForFilename(webrtc_root, filename):
(List of Strings) The list of command line flags for this source file. Can
be empty.
"""
if not webrtc_root:
return []
if not webrtc_root:
return []
# Generally, everyone benefits from including WebRTC's src/, because all of
# WebRTC's includes are relative to that.
additional_flags = ['-I' + os.path.join(webrtc_root)]
# Generally, everyone benefits from including WebRTC's src/, because all of
# WebRTC's includes are relative to that.
additional_flags = ['-I' + os.path.join(webrtc_root)]
# Version of Clang used to compile WebRTC can be newer then version of
# libclang that YCM uses for completion. So it's possible that YCM's libclang
# doesn't know about some used warning options, which causes compilation
# warnings (and errors, because of '-Werror');
additional_flags.append('-Wno-unknown-warning-option')
# Version of Clang used to compile WebRTC can be newer then version of
# libclang that YCM uses for completion. So it's possible that YCM's libclang
# doesn't know about some used warning options, which causes compilation
# warnings (and errors, because of '-Werror');
additional_flags.append('-Wno-unknown-warning-option')
sys.path.append(os.path.join(webrtc_root, 'tools', 'vim'))
from ninja_output import GetNinjaOutputDirectory
out_dir = GetNinjaOutputDirectory(webrtc_root)
sys.path.append(os.path.join(webrtc_root, 'tools', 'vim'))
from ninja_output import GetNinjaOutputDirectory
out_dir = GetNinjaOutputDirectory(webrtc_root)
basename, extension = os.path.splitext(filename)
if extension == '.h':
candidates = [basename + ext for ext in _HEADER_ALTERNATES]
else:
candidates = [filename]
basename, extension = os.path.splitext(filename)
if extension == '.h':
candidates = [basename + ext for ext in _HEADER_ALTERNATES]
else:
candidates = [filename]
clang_line = None
buildable_extension = extension
for candidate in candidates:
clang_line = GetClangCommandLineFromNinjaForSource(out_dir, candidate)
if clang_line:
buildable_extension = os.path.splitext(candidate)[1]
break
clang_line = None
buildable_extension = extension
for candidate in candidates:
clang_line = GetClangCommandLineFromNinjaForSource(out_dir, candidate)
if clang_line:
buildable_extension = os.path.splitext(candidate)[1]
break
additional_flags += _EXTENSION_FLAGS.get(buildable_extension, [])
additional_flags += _EXTENSION_FLAGS.get(buildable_extension, [])
if not clang_line:
# If ninja didn't know about filename or it's companion files, then try a
# default build target. It is possible that the file is new, or build.ninja
# is stale.
clang_line = GetClangCommandLineFromNinjaForSource(
out_dir, GetDefaultSourceFile(webrtc_root, filename))
if not clang_line:
# If ninja didn't know about filename or it's companion files, then try a
# default build target. It is possible that the file is new, or build.ninja
# is stale.
clang_line = GetClangCommandLineFromNinjaForSource(
out_dir, GetDefaultSourceFile(webrtc_root, filename))
if not clang_line:
return additional_flags
if not clang_line:
return additional_flags
return GetClangOptionsFromCommandLine(clang_line, out_dir, additional_flags)
return GetClangOptionsFromCommandLine(clang_line, out_dir,
additional_flags)
def FlagsForFile(filename):
"""This is the main entry point for YCM. Its interface is fixed.
"""This is the main entry point for YCM. Its interface is fixed.
Args:
filename: (String) Path to source file being edited.
@ -340,18 +346,16 @@ def FlagsForFile(filename):
'flags': (List of Strings) Command line flags.
'do_cache': (Boolean) True if the result should be cached.
"""
abs_filename = os.path.abspath(filename)
webrtc_root = FindWebrtcSrcFromFilename(abs_filename)
clang_flags = GetClangOptionsFromNinjaForFilename(webrtc_root, abs_filename)
abs_filename = os.path.abspath(filename)
webrtc_root = FindWebrtcSrcFromFilename(abs_filename)
clang_flags = GetClangOptionsFromNinjaForFilename(webrtc_root,
abs_filename)
# If clang_flags could not be determined, then assume that was due to a
# transient failure. Preventing YCM from caching the flags allows us to try to
# determine the flags again.
should_cache_flags_for_file = bool(clang_flags)
# If clang_flags could not be determined, then assume that was due to a
# transient failure. Preventing YCM from caching the flags allows us to try to
# determine the flags again.
should_cache_flags_for_file = bool(clang_flags)
final_flags = _DEFAULT_FLAGS + clang_flags
final_flags = _DEFAULT_FLAGS + clang_flags
return {
'flags': final_flags,
'do_cache': should_cache_flags_for_file
}
return {'flags': final_flags, 'do_cache': should_cache_flags_for_file}

View File

@ -6,7 +6,6 @@
# tree. An additional intellectual property rights grant can be found
# in the file PATENTS. All contributing project authors may
# be found in the AUTHORS file in the root of the source tree.
"""Generate graphs for data generated by loopback tests.
Usage examples:
@ -34,14 +33,14 @@ import numpy
# Fields
DROPPED = 0
INPUT_TIME = 1 # ms (timestamp)
SEND_TIME = 2 # ms (timestamp)
RECV_TIME = 3 # ms (timestamp)
RENDER_TIME = 4 # ms (timestamp)
ENCODED_FRAME_SIZE = 5 # bytes
INPUT_TIME = 1 # ms (timestamp)
SEND_TIME = 2 # ms (timestamp)
RECV_TIME = 3 # ms (timestamp)
RENDER_TIME = 4 # ms (timestamp)
ENCODED_FRAME_SIZE = 5 # bytes
PSNR = 6
SSIM = 7
ENCODE_TIME = 8 # ms (time interval)
ENCODE_TIME = 8 # ms (time interval)
TOTAL_RAW_FIELDS = 9
@ -78,111 +77,116 @@ _FIELDS = [
NAME_TO_ID = {field[1]: field[0] for field in _FIELDS}
ID_TO_TITLE = {field[0]: field[2] for field in _FIELDS}
def FieldArgToId(arg):
if arg == "none":
return None
if arg in NAME_TO_ID:
return NAME_TO_ID[arg]
if arg + "_ms" in NAME_TO_ID:
return NAME_TO_ID[arg + "_ms"]
raise Exception("Unrecognized field name \"{}\"".format(arg))
if arg == "none":
return None
if arg in NAME_TO_ID:
return NAME_TO_ID[arg]
if arg + "_ms" in NAME_TO_ID:
return NAME_TO_ID[arg + "_ms"]
raise Exception("Unrecognized field name \"{}\"".format(arg))
class PlotLine(object):
"""Data for a single graph line."""
"""Data for a single graph line."""
def __init__(self, label, values, flags):
self.label = label
self.values = values
self.flags = flags
def __init__(self, label, values, flags):
self.label = label
self.values = values
self.flags = flags
class Data(object):
"""Object representing one full stack test."""
"""Object representing one full stack test."""
def __init__(self, filename):
self.title = ""
self.length = 0
self.samples = defaultdict(list)
def __init__(self, filename):
self.title = ""
self.length = 0
self.samples = defaultdict(list)
self._ReadSamples(filename)
self._ReadSamples(filename)
def _ReadSamples(self, filename):
"""Reads graph data from the given file."""
f = open(filename)
it = iter(f)
def _ReadSamples(self, filename):
"""Reads graph data from the given file."""
f = open(filename)
it = iter(f)
self.title = it.next().strip()
self.length = int(it.next())
field_names = [name.strip() for name in it.next().split()]
field_ids = [NAME_TO_ID[name] for name in field_names]
self.title = it.next().strip()
self.length = int(it.next())
field_names = [name.strip() for name in it.next().split()]
field_ids = [NAME_TO_ID[name] for name in field_names]
for field_id in field_ids:
self.samples[field_id] = [0.0] * self.length
for field_id in field_ids:
self.samples[field_id] = [0.0] * self.length
for sample_id in xrange(self.length):
for col, value in enumerate(it.next().split()):
self.samples[field_ids[col]][sample_id] = float(value)
for sample_id in xrange(self.length):
for col, value in enumerate(it.next().split()):
self.samples[field_ids[col]][sample_id] = float(value)
self._SubtractFirstInputTime()
self._GenerateAdditionalData()
self._SubtractFirstInputTime()
self._GenerateAdditionalData()
f.close()
f.close()
def _SubtractFirstInputTime(self):
offset = self.samples[INPUT_TIME][0]
for field in [INPUT_TIME, SEND_TIME, RECV_TIME, RENDER_TIME]:
if field in self.samples:
self.samples[field] = [x - offset for x in self.samples[field]]
def _SubtractFirstInputTime(self):
offset = self.samples[INPUT_TIME][0]
for field in [INPUT_TIME, SEND_TIME, RECV_TIME, RENDER_TIME]:
if field in self.samples:
self.samples[field] = [x - offset for x in self.samples[field]]
def _GenerateAdditionalData(self):
"""Calculates sender time, receiver time etc. from the raw data."""
s = self.samples
last_render_time = 0
for field_id in [SENDER_TIME, RECEIVER_TIME, END_TO_END, RENDERED_DELTA]:
s[field_id] = [0] * self.length
def _GenerateAdditionalData(self):
"""Calculates sender time, receiver time etc. from the raw data."""
s = self.samples
last_render_time = 0
for field_id in [
SENDER_TIME, RECEIVER_TIME, END_TO_END, RENDERED_DELTA
]:
s[field_id] = [0] * self.length
for k in range(self.length):
s[SENDER_TIME][k] = s[SEND_TIME][k] - s[INPUT_TIME][k]
for k in range(self.length):
s[SENDER_TIME][k] = s[SEND_TIME][k] - s[INPUT_TIME][k]
decoded_time = s[RENDER_TIME][k]
s[RECEIVER_TIME][k] = decoded_time - s[RECV_TIME][k]
s[END_TO_END][k] = decoded_time - s[INPUT_TIME][k]
if not s[DROPPED][k]:
if k > 0:
s[RENDERED_DELTA][k] = decoded_time - last_render_time
last_render_time = decoded_time
decoded_time = s[RENDER_TIME][k]
s[RECEIVER_TIME][k] = decoded_time - s[RECV_TIME][k]
s[END_TO_END][k] = decoded_time - s[INPUT_TIME][k]
if not s[DROPPED][k]:
if k > 0:
s[RENDERED_DELTA][k] = decoded_time - last_render_time
last_render_time = decoded_time
def _Hide(self, values):
"""
def _Hide(self, values):
"""
Replaces values for dropped frames with None.
These values are then skipped by the Plot() method.
"""
return [None if self.samples[DROPPED][k] else values[k]
for k in range(len(values))]
return [
None if self.samples[DROPPED][k] else values[k]
for k in range(len(values))
]
def AddSamples(self, config, target_lines_list):
"""Creates graph lines from the current data set with given config."""
for field in config.fields:
# field is None means the user wants just to skip the color.
if field is None:
target_lines_list.append(None)
continue
def AddSamples(self, config, target_lines_list):
"""Creates graph lines from the current data set with given config."""
for field in config.fields:
# field is None means the user wants just to skip the color.
if field is None:
target_lines_list.append(None)
continue
field_id = field & FIELD_MASK
values = self.samples[field_id]
field_id = field & FIELD_MASK
values = self.samples[field_id]
if field & HIDE_DROPPED:
values = self._Hide(values)
if field & HIDE_DROPPED:
values = self._Hide(values)
target_lines_list.append(PlotLine(
self.title + " " + ID_TO_TITLE[field_id],
values, field & ~FIELD_MASK))
target_lines_list.append(
PlotLine(self.title + " " + ID_TO_TITLE[field_id], values,
field & ~FIELD_MASK))
def AverageOverCycle(values, length):
"""
"""
Returns the list:
[
avg(values[0], values[length], ...),
@ -194,221 +198,272 @@ def AverageOverCycle(values, length):
Skips None values when calculating the average value.
"""
total = [0.0] * length
count = [0] * length
for k, val in enumerate(values):
if val is not None:
total[k % length] += val
count[k % length] += 1
total = [0.0] * length
count = [0] * length
for k, val in enumerate(values):
if val is not None:
total[k % length] += val
count[k % length] += 1
result = [0.0] * length
for k in range(length):
result[k] = total[k] / count[k] if count[k] else None
return result
result = [0.0] * length
for k in range(length):
result[k] = total[k] / count[k] if count[k] else None
return result
class PlotConfig(object):
"""Object representing a single graph."""
"""Object representing a single graph."""
def __init__(self, fields, data_list, cycle_length=None, frames=None,
offset=0, output_filename=None, title="Graph"):
self.fields = fields
self.data_list = data_list
self.cycle_length = cycle_length
self.frames = frames
self.offset = offset
self.output_filename = output_filename
self.title = title
def __init__(self,
fields,
data_list,
cycle_length=None,
frames=None,
offset=0,
output_filename=None,
title="Graph"):
self.fields = fields
self.data_list = data_list
self.cycle_length = cycle_length
self.frames = frames
self.offset = offset
self.output_filename = output_filename
self.title = title
def Plot(self, ax1):
lines = []
for data in self.data_list:
if not data:
# Add None lines to skip the colors.
lines.extend([None] * len(self.fields))
else:
data.AddSamples(self, lines)
def Plot(self, ax1):
lines = []
for data in self.data_list:
if not data:
# Add None lines to skip the colors.
lines.extend([None] * len(self.fields))
else:
data.AddSamples(self, lines)
def _SliceValues(values):
if self.offset:
values = values[self.offset:]
if self.frames:
values = values[:self.frames]
return values
def _SliceValues(values):
if self.offset:
values = values[self.offset:]
if self.frames:
values = values[:self.frames]
return values
length = None
for line in lines:
if line is None:
continue
length = None
for line in lines:
if line is None:
continue
line.values = _SliceValues(line.values)
if self.cycle_length:
line.values = AverageOverCycle(line.values, self.cycle_length)
line.values = _SliceValues(line.values)
if self.cycle_length:
line.values = AverageOverCycle(line.values, self.cycle_length)
if length is None:
length = len(line.values)
elif length != len(line.values):
raise Exception("All arrays should have the same length!")
if length is None:
length = len(line.values)
elif length != len(line.values):
raise Exception("All arrays should have the same length!")
ax1.set_xlabel("Frame", fontsize="large")
if any(line.flags & RIGHT_Y_AXIS for line in lines if line):
ax2 = ax1.twinx()
ax2.set_xlabel("Frame", fontsize="large")
else:
ax2 = None
ax1.set_xlabel("Frame", fontsize="large")
if any(line.flags & RIGHT_Y_AXIS for line in lines if line):
ax2 = ax1.twinx()
ax2.set_xlabel("Frame", fontsize="large")
else:
ax2 = None
# Have to implement color_cycle manually, due to two scales in a graph.
color_cycle = ["b", "r", "g", "c", "m", "y", "k"]
color_iter = itertools.cycle(color_cycle)
# Have to implement color_cycle manually, due to two scales in a graph.
color_cycle = ["b", "r", "g", "c", "m", "y", "k"]
color_iter = itertools.cycle(color_cycle)
for line in lines:
if not line:
color_iter.next()
continue
for line in lines:
if not line:
color_iter.next()
continue
if self.cycle_length:
x = numpy.array(range(self.cycle_length))
else:
x = numpy.array(range(self.offset, self.offset + len(line.values)))
y = numpy.array(line.values)
ax = ax2 if line.flags & RIGHT_Y_AXIS else ax1
ax.Plot(x, y, "o-", label=line.label, markersize=3.0, linewidth=1.0,
color=color_iter.next())
if self.cycle_length:
x = numpy.array(range(self.cycle_length))
else:
x = numpy.array(
range(self.offset, self.offset + len(line.values)))
y = numpy.array(line.values)
ax = ax2 if line.flags & RIGHT_Y_AXIS else ax1
ax.Plot(x,
y,
"o-",
label=line.label,
markersize=3.0,
linewidth=1.0,
color=color_iter.next())
ax1.grid(True)
if ax2:
ax1.legend(loc="upper left", shadow=True, fontsize="large")
ax2.legend(loc="upper right", shadow=True, fontsize="large")
else:
ax1.legend(loc="best", shadow=True, fontsize="large")
ax1.grid(True)
if ax2:
ax1.legend(loc="upper left", shadow=True, fontsize="large")
ax2.legend(loc="upper right", shadow=True, fontsize="large")
else:
ax1.legend(loc="best", shadow=True, fontsize="large")
def LoadFiles(filenames):
result = []
for filename in filenames:
if filename in LoadFiles.cache:
result.append(LoadFiles.cache[filename])
else:
data = Data(filename)
LoadFiles.cache[filename] = data
result.append(data)
return result
result = []
for filename in filenames:
if filename in LoadFiles.cache:
result.append(LoadFiles.cache[filename])
else:
data = Data(filename)
LoadFiles.cache[filename] = data
result.append(data)
return result
LoadFiles.cache = {}
def GetParser():
class CustomAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if "ordered_args" not in namespace:
namespace.ordered_args = []
namespace.ordered_args.append((self.dest, values))
class CustomAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
if "ordered_args" not in namespace:
namespace.ordered_args = []
namespace.ordered_args.append((self.dest, values))
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter)
parser.add_argument(
"-c", "--cycle_length", nargs=1, action=CustomAction,
type=int, help="Cycle length over which to average the values.")
parser.add_argument(
"-f", "--field", nargs=1, action=CustomAction,
help="Name of the field to show. Use 'none' to skip a color.")
parser.add_argument("-r", "--right", nargs=0, action=CustomAction,
help="Use right Y axis for given field.")
parser.add_argument("-d", "--drop", nargs=0, action=CustomAction,
help="Hide values for dropped frames.")
parser.add_argument("-o", "--offset", nargs=1, action=CustomAction, type=int,
help="Frame offset.")
parser.add_argument("-n", "--next", nargs=0, action=CustomAction,
help="Separator for multiple graphs.")
parser.add_argument(
"--frames", nargs=1, action=CustomAction, type=int,
help="Frame count to show or take into account while averaging.")
parser.add_argument("-t", "--title", nargs=1, action=CustomAction,
help="Title of the graph.")
parser.add_argument(
"-O", "--output_filename", nargs=1, action=CustomAction,
help="Use to save the graph into a file. "
"Otherwise, a window will be shown.")
parser.add_argument(
"files", nargs="+", action=CustomAction,
help="List of text-based files generated by loopback tests.")
return parser
parser.add_argument("-c",
"--cycle_length",
nargs=1,
action=CustomAction,
type=int,
help="Cycle length over which to average the values.")
parser.add_argument(
"-f",
"--field",
nargs=1,
action=CustomAction,
help="Name of the field to show. Use 'none' to skip a color.")
parser.add_argument("-r",
"--right",
nargs=0,
action=CustomAction,
help="Use right Y axis for given field.")
parser.add_argument("-d",
"--drop",
nargs=0,
action=CustomAction,
help="Hide values for dropped frames.")
parser.add_argument("-o",
"--offset",
nargs=1,
action=CustomAction,
type=int,
help="Frame offset.")
parser.add_argument("-n",
"--next",
nargs=0,
action=CustomAction,
help="Separator for multiple graphs.")
parser.add_argument(
"--frames",
nargs=1,
action=CustomAction,
type=int,
help="Frame count to show or take into account while averaging.")
parser.add_argument("-t",
"--title",
nargs=1,
action=CustomAction,
help="Title of the graph.")
parser.add_argument("-O",
"--output_filename",
nargs=1,
action=CustomAction,
help="Use to save the graph into a file. "
"Otherwise, a window will be shown.")
parser.add_argument(
"files",
nargs="+",
action=CustomAction,
help="List of text-based files generated by loopback tests.")
return parser
def _PlotConfigFromArgs(args, graph_num):
# Pylint complains about using kwargs, so have to do it this way.
cycle_length = None
frames = None
offset = 0
output_filename = None
title = "Graph"
# Pylint complains about using kwargs, so have to do it this way.
cycle_length = None
frames = None
offset = 0
output_filename = None
title = "Graph"
fields = []
files = []
mask = 0
for key, values in args:
if key == "cycle_length":
cycle_length = values[0]
elif key == "frames":
frames = values[0]
elif key == "offset":
offset = values[0]
elif key == "output_filename":
output_filename = values[0]
elif key == "title":
title = values[0]
elif key == "drop":
mask |= HIDE_DROPPED
elif key == "right":
mask |= RIGHT_Y_AXIS
elif key == "field":
field_id = FieldArgToId(values[0])
fields.append(field_id | mask if field_id is not None else None)
mask = 0 # Reset mask after the field argument.
elif key == "files":
files.extend(values)
fields = []
files = []
mask = 0
for key, values in args:
if key == "cycle_length":
cycle_length = values[0]
elif key == "frames":
frames = values[0]
elif key == "offset":
offset = values[0]
elif key == "output_filename":
output_filename = values[0]
elif key == "title":
title = values[0]
elif key == "drop":
mask |= HIDE_DROPPED
elif key == "right":
mask |= RIGHT_Y_AXIS
elif key == "field":
field_id = FieldArgToId(values[0])
fields.append(field_id | mask if field_id is not None else None)
mask = 0 # Reset mask after the field argument.
elif key == "files":
files.extend(values)
if not files:
raise Exception("Missing file argument(s) for graph #{}".format(graph_num))
if not fields:
raise Exception("Missing field argument(s) for graph #{}".format(graph_num))
if not files:
raise Exception(
"Missing file argument(s) for graph #{}".format(graph_num))
if not fields:
raise Exception(
"Missing field argument(s) for graph #{}".format(graph_num))
return PlotConfig(fields, LoadFiles(files), cycle_length=cycle_length,
frames=frames, offset=offset, output_filename=output_filename,
title=title)
return PlotConfig(fields,
LoadFiles(files),
cycle_length=cycle_length,
frames=frames,
offset=offset,
output_filename=output_filename,
title=title)
def PlotConfigsFromArgs(args):
"""Generates plot configs for given command line arguments."""
# The way it works:
# First we detect separators -n/--next and split arguments into groups, one
# for each plot. For each group, we partially parse it with
# argparse.ArgumentParser, modified to remember the order of arguments.
# Then we traverse the argument list and fill the PlotConfig.
args = itertools.groupby(args, lambda x: x in ["-n", "--next"])
prep_args = list(list(group) for match, group in args if not match)
"""Generates plot configs for given command line arguments."""
# The way it works:
# First we detect separators -n/--next and split arguments into groups, one
# for each plot. For each group, we partially parse it with
# argparse.ArgumentParser, modified to remember the order of arguments.
# Then we traverse the argument list and fill the PlotConfig.
args = itertools.groupby(args, lambda x: x in ["-n", "--next"])
prep_args = list(list(group) for match, group in args if not match)
parser = GetParser()
plot_configs = []
for index, raw_args in enumerate(prep_args):
graph_args = parser.parse_args(raw_args).ordered_args
plot_configs.append(_PlotConfigFromArgs(graph_args, index))
return plot_configs
parser = GetParser()
plot_configs = []
for index, raw_args in enumerate(prep_args):
graph_args = parser.parse_args(raw_args).ordered_args
plot_configs.append(_PlotConfigFromArgs(graph_args, index))
return plot_configs
def ShowOrSavePlots(plot_configs):
for config in plot_configs:
fig = plt.figure(figsize=(14.0, 10.0))
ax = fig.add_subPlot(1, 1, 1)
for config in plot_configs:
fig = plt.figure(figsize=(14.0, 10.0))
ax = fig.add_subPlot(1, 1, 1)
plt.title(config.title)
config.Plot(ax)
if config.output_filename:
print "Saving to", config.output_filename
fig.savefig(config.output_filename)
plt.close(fig)
plt.title(config.title)
config.Plot(ax)
if config.output_filename:
print "Saving to", config.output_filename
fig.savefig(config.output_filename)
plt.close(fig)
plt.show()
plt.show()
if __name__ == "__main__":
ShowOrSavePlots(PlotConfigsFromArgs(sys.argv[1:]))
ShowOrSavePlots(PlotConfigsFromArgs(sys.argv[1:]))