Spaces:
Runtime error
Runtime error
import argparse | |
import re | |
from pathlib import Path | |
import nbformat | |
import nbconvert | |
from traitlets.config import Config | |
# Notebook file names skipped by the CI run: patch_notebooks() produces no
# test_ copy for any file whose name appears here.
EXCLUDED_NOTEBOOKS = [
    "data-preparation-ct-scan.ipynb",
    "pytorch-monai-training.ipynb",
]
# Substring that identifies the cell defining the device-selection dropdown
# widget; patch_notebooks() rewrites that cell when a test device is given.
DEVICE_WIDGET = "device = widgets.Dropdown("
def disable_gradio_debug(nb, notebook_path):
    """Switch ``debug=True`` to ``debug=False`` in Gradio-related cells.

    A cell is considered Gradio-related when its source mentions both
    ``"gradio"`` and ``"debug"``. Presumably debug mode would keep the demo
    cell running and stall automated execution (TODO confirm).

    :param nb: Notebook node whose ``"cells"`` each have a string ``"source"``.
        It is modified in place.
    :param notebook_path: Used only in the log message.
    :return: The same ``nb`` object.
    """
    gradio_cells = [
        cell
        for cell in nb["cells"]
        if "gradio" in cell["source"] and "debug" in cell["source"]
    ]
    for gradio_cell in gradio_cells:
        gradio_cell["source"] = gradio_cell["source"].replace("debug=True", "debug=False")
    if gradio_cells:
        print(f"Disabled gradio debug mode for {notebook_path}")
    return nb
def disable_skip_ext(nb, notebook_path, test_device=""):
    """Neutralise ``%%skip`` cell magics so CI executes notebooks deterministically.

    By default, the ``%%skip ...`` line of every cell that contains one is
    replaced by an empty line, so the rest of the cell always runs.

    When ``test_device`` is non-empty, the function looks for a cell containing
    both ``skip_for_device = "<test_device>" in device.value`` and
    ``to_quantize = widgets.Checkbox(value=not skip_for_device``. Starting with
    that cell, every ``%%skip`` cell is instead wrapped in a triple-quoted
    string, which turns it into a no-op.

    :param nb: Notebook node whose ``"cells"`` each have a string ``"source"``.
        It is modified in place.
    :param notebook_path: Used only in the log message.
    :param test_device: Device name inserted verbatim into the marker above.
    :return: The same ``nb`` object.
    """
    found = False
    # None  -> device-specific skipping not detected yet (keep looking);
    # False -> no test device, never look;  True -> detected, wrap cells.
    skip_for_device = None if test_device else False
    device_marker = 'skip_for_device = "{}" in device.value'.format(test_device)
    for cell in nb["cells"]:
        if test_device is not None and skip_for_device is None:
            if (
                device_marker in cell["source"]
                and "to_quantize = widgets.Checkbox(value=not skip_for_device" in cell["source"]
            ):
                skip_for_device = True
        if "%%skip" in cell["source"]:
            found = True
            if not skip_for_device:
                # ``.*`` (rather than ``.*.``) so that a bare ``%%skip`` line
                # with no arguments is removed as well.
                cell["source"] = re.sub(r"%%skip.*\n", "\n", cell["source"])
            else:
                cell["source"] = '"""\n' + cell["source"] + '\n"""'
    if found:
        print(f"Disabled skip extension mode for {notebook_path}")
    return nb
def remove_ov_install(cell):
    """Remove OpenVINO packages from ``pip install`` lines of a notebook cell, in place.

    Every line of ``cell["source"]`` that lists an OpenVINO package is
    rewritten. A package token is ``openvino``, ``openvino>...``,
    ``openvino=...``, ``openvino-dev``, ``openvino-nightly`` or
    ``openvino-tokenizers``.

    * If other dependencies remain on the line, the line is replaced by the
      same command without the OpenVINO package(s), followed by the original
      line commented out with ``"# "``.
    * If no other dependency remains, the line is dropped.

    Lines without such a package, and Python ``import``/``from`` statements,
    are kept unchanged. Indentation of rewritten lines is preserved.

    :param cell: Notebook cell whose ``"source"`` is a single string.
    """

    def has_additional_deps(str_part):
        """Return True if ``str_part`` looks like a dependency, not part of the pip command."""
        # Empty tokens come from leading, repeated or trailing spaces. They must
        # not count as dependencies, otherwise a bare ``%pip install`` is emitted.
        if not str_part:
            return False
        if "%pip" in str_part:
            return False
        if "install" in str_part:
            return False
        # Tokens starting with "-" (pip flags) or "https://" (URLs).
        if str_part.startswith("-"):
            return False
        if str_part.startswith("https://"):
            return False
        return True

    def is_ov_package(part):
        """Return True if the space-separated token ``part`` names an OpenVINO package."""
        return (
            "openvino-dev" in part
            or "openvino-nightly" in part
            or "openvino-tokenizers" in part
            or "openvino>" in part
            or "openvino=" in part
            or part == "openvino"
        )

    updated_lines = []
    for line in cell["source"].split("\n"):
        # Python imports are not install commands; without this guard
        # ``import openvino as ov`` would be turned into ``import as ov``.
        if "openvino" not in line or line.lstrip().startswith(("import ", "from ")):
            updated_lines.append(line)
            continue
        updated_line_content = []
        empty = True
        package_found = False
        for part in line.split(" "):
            if is_ov_package(part):
                package_found = True
                continue
            if empty:
                empty = not has_additional_deps(part)
            updated_line_content.append(part)
        if not package_found:
            updated_lines.append(line)
        elif not empty:
            # Splitting on single spaces keeps the leading indentation as empty
            # tokens, so the join already restores the original indentation.
            updated_lines.append(" ".join(updated_line_content) + "\n# " + line)
        # Otherwise only OpenVINO packages were listed: the line is dropped.
    cell["source"] = "\n".join(updated_lines)
def patch_notebooks(notebooks_dir, test_device="", skip_ov_install=False):
    """
    Patch notebooks in notebooks directory with replacement values
    found in notebook metadata to speed up test execution.
    This function is specific for the OpenVINO notebooks
    Github Actions CI.

    For example: change nr of epochs from 15 to 1 in
    tensorflow-training-openvino-nncf.ipynb by adding
    {"test_replace": {"epochs = 15": "epochs = 1"}} to the cell
    metadata of the cell that contains `epochs = 15`

    Which notebooks are processed: every ``*.ipynb`` below ``notebooks_dir``,
    except files whose name starts with ``test_`` and files listed in
    ``EXCLUDED_NOTEBOOKS``.

    How each notebook is patched:

    * device-widget rewriting, if ``test_device`` is given;
    * optional OpenVINO pip-install removal;
    * ``test_replace`` substitutions;
    * gradio-debug and ``%%skip`` neutralisation.

    The result, with all outputs cleared, is written next to the original as
    ``test_<name>.ipynb``.

    :param notebooks_dir: Directory that contains the notebook subdirectories.
        For example: openvino_notebooks/notebooks
    :param test_device: If non-empty, device name forced (upper-cased) into the
        device dropdown widget and used for device-specific ``%%skip`` handling.
    :param skip_ov_install: If True, strip OpenVINO packages from ``%pip`` cells.
    :raises ValueError: if a ``test_replace`` source value is missing from its cell.
    """
    nb_convert_config = Config()
    nb_convert_config.NotebookExporter.preprocessors = ["nbconvert.preprocessors.ClearOutputPreprocessor"]
    output_remover = nbconvert.NotebookExporter(nb_convert_config)
    # Snapshot the file list: test_*.ipynb files are created while iterating.
    for notebookfile in sorted(Path(notebooks_dir).glob("**/*.ipynb")):
        if notebookfile.name.startswith("test_") or notebookfile.name in EXCLUDED_NOTEBOOKS:
            continue
        nb = nbformat.read(notebookfile, as_version=nbformat.NO_CONVERT)
        found = False
        device_found = False
        for cell in nb["cells"]:
            if skip_ov_install and "%pip" in cell["source"]:
                remove_ov_install(cell)
            if test_device and DEVICE_WIDGET in cell["source"]:
                device_found = True
                # NOTE: ``value=.*,`` is greedy up to the last comma on its line;
                # this assumes the widget's keyword arguments are one per line.
                cell["source"] = re.sub(r"value=.*,", f"value='{test_device.upper()}',", cell["source"])
                cell["source"] = re.sub(
                    r"options=",
                    f"options=['{test_device.upper()}'] + ",
                    cell["source"],
                )
                print(f"Replaced testing device to {test_device}")
            replace_dict = cell.get("metadata", {}).get("test_replace")
            if replace_dict is not None:
                found = True
                for source_value, target_value in replace_dict.items():
                    if source_value not in cell["source"]:
                        raise ValueError(f"Processing {notebookfile} failed: {source_value} does not exist in cell")
                    cell["source"] = cell["source"].replace(source_value, target_value)
                    cell["source"] = "# Modified for testing\n" + cell["source"]
                    print(f"Processed {notebookfile}: {source_value} -> {target_value}")
        if test_device and not device_found:
            print(f"No device replacement found for {notebookfile}")
        if not found:
            print(f"No replacements found for {notebookfile}")
        disable_gradio_debug(nb, notebookfile)
        # Use the parameter, not the CLI-only global ``args``: the latter does not
        # exist when this function is imported and called from other code.
        disable_skip_ext(nb, notebookfile, test_device)
        nb_without_out, _ = output_remover.from_notebook_node(nb)
        with notebookfile.with_name(f"test_{notebookfile.name}").open("w", encoding="utf-8") as out_file:
            out_file.write(nb_without_out)
if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate the directory, and
    # patch every eligible notebook below it.
    parser = argparse.ArgumentParser("Notebook patcher")
    # nargs="?" makes the positional optional so that default="." actually
    # applies; without it argparse always requires the argument.
    parser.add_argument(
        "notebooks_dir",
        nargs="?",
        default=".",
        help="Directory searched recursively for notebooks (default: current directory)",
    )
    parser.add_argument(
        "-td",
        "--test_device",
        default="",
        help="Device to force in the device selection widget, e.g. CPU or GPU",
    )
    parser.add_argument(
        "--skip_ov_install",
        action="store_true",
        help="Remove OpenVINO packages from %%pip install cells",
    )
    args = parser.parse_args()
    if not Path(args.notebooks_dir).is_dir():
        raise ValueError(f"'{args.notebooks_dir}' is not an existing directory")
    patch_notebooks(args.notebooks_dir, args.test_device, args.skip_ov_install)