""" |
Utilities for working with the local dataset cache.
This file is adapted from `AllenNLP <https://github.com/allenai/allennlp>`_
and `huggingface <https://github.com/huggingface>`_.
""" |
import fnmatch |
import json |
import logging |
import os |
import shutil |
import tarfile |
import tempfile |
from functools import partial, wraps |
from hashlib import sha256 |
from io import open |
# Root of the torch cache: ask torch when it is importable, otherwise fall
# back to $TORCH_HOME, then $XDG_CACHE_HOME/torch, then ~/.cache/torch.
try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv(
            "TORCH_HOME",
            os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"),
        )
    )

# Cache directory used when PYTORCH_FAIRSEQ_CACHE is not set in the environment.
default_cache_path = os.path.join(torch_cache_home, "pytorch_fairseq")

# urlparse lives in urllib.parse; the second import is a legacy fallback.
try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

# The cache location is a Path when pathlib is usable, else a plain string.
try:
    from pathlib import Path

    PYTORCH_FAIRSEQ_CACHE = Path(
        os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)
    )
except (AttributeError, ImportError):
    PYTORCH_FAIRSEQ_CACHE = os.getenv("PYTORCH_FAIRSEQ_CACHE", default_cache_path)

# File names exposed as module constants.
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

logger = logging.getLogger(__name__)
def load_archive_file(archive_file):
    """
    Resolve `archive_file` to a local path, unpacking it when it is an archive.

    The argument is resolved through `cached_path`. If the resolved path is
    not a directory it is opened as a tar archive, extracted into a temporary
    directory, and the archive file is then replaced in place by the extracted
    content, so the returned path points at a directory.

    Returns the resolved path, or ``None`` when `cached_path` raises
    ``EnvironmentError`` (nothing found for the given path or URL).
    """
    # Redirect to the cache, if necessary.
    try:
        resolved_archive_file = cached_path(archive_file, cache_dir=None)
    except EnvironmentError:
        logger.info(
            "Archive name '{}' was not found in archive name list. "
            "We assumed '{}' was a path or URL but couldn't find any file "
            "associated to this path or URL.".format(
                archive_file,
                archive_file,
            )
        )
        return None

    if resolved_archive_file == archive_file:
        logger.info("loading archive file {}".format(archive_file))
    else:
        logger.info(
            "loading archive file {} from cache at {}".format(
                archive_file, resolved_archive_file
            )
        )

    # Already a directory: nothing to extract.
    if os.path.isdir(resolved_archive_file):
        return resolved_archive_file

    # Extract the archive to a temp dir and replace the archive file with the
    # extracted content.
    tempdir = tempfile.mkdtemp()
    try:
        logger.info(
            "extracting archive file {} to temp dir {}".format(
                resolved_archive_file, tempdir
            )
        )
        # The compression mode is taken from the last extension of the
        # original name (e.g. "gz" or "bz2").
        ext = os.path.splitext(archive_file)[1][1:]
        with tarfile.open(resolved_archive_file, "r:" + ext) as archive:
            # Character-wise common prefix of the member names; for an archive
            # with a single top-level directory this is that directory's name.
            top_dir = os.path.commonprefix(archive.getnames())
            # NOTE(review): extractall() trusts the member paths stored in the
            # archive. Archives fetched from untrusted URLs should be
            # validated (or extracted with a tarfile extraction filter) first.
            archive.extractall(tempdir)
        os.remove(resolved_archive_file)
        shutil.move(os.path.join(tempdir, top_dir), resolved_archive_file)
    finally:
        # Always clean up, including when extraction or the move fails.
        # ignore_errors covers the case where `top_dir` is empty and the whole
        # temp dir was itself moved, so there is nothing left to delete.
        shutil.rmtree(tempdir, ignore_errors=True)

    return resolved_archive_file
def url_to_filename(url, etag=None):
    """
    Build a deterministic cache filename from `url`.

    The name is the SHA-256 hex digest of the UTF-8 encoded URL. When a
    non-empty `etag` is given, the SHA-256 hex digest of the etag is appended
    after a period.
    """
    parts = [sha256(url.encode("utf-8")).hexdigest()]
    if etag:
        parts.append(sha256(etag.encode("utf-8")).hexdigest())
    return ".".join(parts)
def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.

    `cache_dir` defaults to ``PYTORCH_FAIRSEQ_CACHE`` and may be a string or a
    ``Path``. The url and etag are read from the JSON file stored next to the
    cached file (``<filename>.json``).

    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    # Resolve the default first, so the join below never receives None.
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + ".json"
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata["url"]
    etag = metadata["etag"]

    return url, etag
def cached_path_from_pm(url_or_filename):
    """
    Try to obtain a local path for `url_or_filename` through PathManager.

    Returns whatever ``PathManager.get_local_path`` returns, or ``None`` when
    the import or the lookup raises any exception (best effort by design).
    """
    try:
        from fairseq.file_io import PathManager

        return PathManager.get_local_path(url_or_filename)
    except Exception:
        return None
def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.

    `cache_dir` defaults to ``PYTORCH_FAIRSEQ_CACHE``. Both arguments may be
    given as strings or ``Path`` objects.

    Raises ``EnvironmentError`` when a scheme-less path does not exist, and
    ``ValueError`` when the input has an unhandled scheme that PathManager
    cannot resolve either.
    """
    # Resolve the default before it is handed on to get_from_cache.
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ("http", "https", "s3"):
        # URL, so get it from the cache (downloading if necessary).
        return get_from_cache(url_or_filename, cache_dir)
    if os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    if parsed.scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))

    # Some other scheme: give PathManager a chance before giving up. The
    # result is kept in a local that does not shadow this function's name.
    local_path = cached_path_from_pm(url_or_filename)
    if local_path:
        return local_path
    raise ValueError(
        "unable to parse {} as a URL or as a local path".format(url_or_filename)
    )
def split_s3_path(url):
    """
    Split a full s3 path into ``(bucket_name, s3_path)``.

    The bucket is the URL's network location and the key is the URL path with
    one leading slash removed. Raises ``ValueError`` when either part is empty.
    """
    parts = urlparse(url)
    bucket_name, key = parts.netloc, parts.path
    if not (bucket_name and key):
        raise ValueError("bad s3 path {}".format(url))
    # Drop only the single separator slash between bucket and key.
    if key[:1] == "/":
        key = key[1:]
    return bucket_name, key
def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.

    The wrapped function is called unchanged. A botocore ``ClientError`` whose
    error code is ``404`` is re-raised as ``EnvironmentError`` naming the URL;
    every other ``ClientError`` propagates as-is.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        # Imported lazily so botocore is only needed when S3 is actually used.
        from botocore.exceptions import ClientError

        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            # Error codes are not always numeric (e.g. "NoSuchKey",
            # "AccessDenied"), so compare as text: int() on such a code would
            # raise ValueError and hide the real S3 error.
            error_code = exc.response.get("Error", {}).get("Code")
            if str(error_code) == "404":
                raise EnvironmentError("file {} not found".format(url)) from exc
            raise

    return wrapper
@s3_request
def s3_etag(url):
    """Return the ETag of the S3 object that `url` points to."""
    import boto3

    resource = boto3.resource("s3")
    bucket, key = split_s3_path(url)
    return resource.Object(bucket, key).e_tag
@s3_request
def s3_get(url, temp_file):
    """Download the S3 object at `url` into the open file object `temp_file`."""
    import boto3

    resource = boto3.resource("s3")
    bucket, key = split_s3_path(url)
    bucket_handle = resource.Bucket(bucket)
    bucket_handle.download_fileobj(key, temp_file)
def request_wrap_timeout(func, url, timeouts=(10, 20, 40, 60, 60)):
    """
    Call `func` with increasing timeouts until it stops timing out.

    `func` is invoked as ``func(timeout=<seconds>)`` once per entry of
    `timeouts`, in order. The first call that does not raise
    ``requests.exceptions.Timeout`` provides the return value. `url` is only
    used in log and error messages.

    Raises ``RuntimeError`` when every attempt timed out; the last timeout
    exception is attached as its cause.
    """
    import requests

    timeouts = tuple(timeouts)
    last_error = None
    for attempt, timeout in enumerate(timeouts, start=1):
        try:
            return func(timeout=timeout)
        except requests.exceptions.Timeout as e:
            last_error = e
            if attempt < len(timeouts):
                # Report the timeout the NEXT attempt will use, not the one
                # that has just expired.
                logger.warning(
                    "Request for %s timed-out (attempt %d). Retrying with a timeout of %d secs",
                    url,
                    attempt,
                    timeouts[attempt],
                    exc_info=e,
                )
            else:
                logger.warning(
                    "Request for %s timed-out (attempt %d). No retries left",
                    url,
                    attempt,
                    exc_info=e,
                )
    raise RuntimeError(f"Unable to fetch file {url}") from last_error
def http_get(url, temp_file):
    """
    Stream the content at `url` into the open binary file object `temp_file`.

    The GET request goes through `request_wrap_timeout`, and progress is shown
    with a tqdm bar sized from the ``Content-Length`` header when present.
    """
    import requests
    from tqdm import tqdm

    req = request_wrap_timeout(partial(requests.get, url, stream=True), url)
    try:
        # NOTE(review): the HTTP status is not checked here, so the body of an
        # error response would be written like any other content.
        content_length = req.headers.get("Content-Length")
        total = int(content_length) if content_length is not None else None
        progress = tqdm(unit="B", total=total)
        try:
            for chunk in req.iter_content(chunk_size=1024):
                if chunk:  # skip empty chunks
                    progress.update(len(chunk))
                    temp_file.write(chunk)
        finally:
            # Close the bar even when the download is interrupted.
            progress.close()
    finally:
        # A streamed response holds its connection until it is closed.
        req.close()
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    `cache_dir` defaults to ``PYTORCH_FAIRSEQ_CACHE`` and is created when
    missing. The cache filename is derived from the URL and, when available,
    the resource's ETag. A ``<cache_path>.json`` file recording the url and
    etag is written next to every downloaded file.
    """
    # Resolve the default first, so the filesystem calls never receive None.
    if cache_dir is None:
        cache_dir = PYTORCH_FAIRSEQ_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    # exist_ok avoids failing when another process creates the directory
    # between a check and the creation.
    os.makedirs(cache_dir, exist_ok=True)

    # Get the ETag to add to the filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        try:
            import requests

            response = request_wrap_timeout(
                partial(requests.head, url, allow_redirects=True), url
            )
            if response.status_code != 200:
                etag = None
            else:
                etag = response.headers.get("ETag")
        except RuntimeError:
            # Every HEAD attempt timed out: continue without an ETag.
            etag = None

    filename = url_to_filename(url, etag)

    # Get the cache path to put the file.
    cache_path = os.path.join(cache_dir, filename)

    # Without an ETag the exact filename is unknown, so fall back to a
    # previously cached "<url-hash>.<etag-hash>" file for the same URL.
    # os.listdir order is unspecified, so which match is picked is arbitrary.
    if not os.path.exists(cache_path) and etag is None:
        matching_files = fnmatch.filter(os.listdir(cache_dir), filename + ".*")
        matching_files = [
            name for name in matching_files if not name.endswith(".json")
        ]
        if matching_files:
            cache_path = os.path.join(cache_dir, matching_files[-1])

    if not os.path.exists(cache_path):
        # Download to a temporary file, then copy to the cache dir once
        # finished, so an interrupted download leaves no partial cache entry.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # Flush before copying so every byte is visible to the reader.
            temp_file.flush()
            # Rewind: copyfileobj starts at the current position.
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, "wb") as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {"url": url, "etag": etag}
            meta_path = cache_path + ".json"
            with open(meta_path, "w") as meta_file:
                output_string = json.dumps(meta)
                meta_file.write(output_string)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path
def read_set_from_file(filename):
    """
    Read `filename` (UTF-8, one item per line) into a set.

    Trailing whitespace, including the newline, is stripped from every line,
    and duplicate lines collapse into a single entry.
    """
    with open(filename, "r", encoding="utf-8") as file_:
        return {line.rstrip() for line in file_}
def get_file_extension(path, dot=True, lower=True):
    """
    Return the last extension of `path` as given by ``os.path.splitext``.

    With ``dot=False`` the leading period is dropped; with ``lower=True`` the
    result is lower-cased. A path without an extension yields ``""``.
    """
    _, extension = os.path.splitext(path)
    if not dot:
        extension = extension[1:]
    if lower:
        return extension.lower()
    return extension