"""Module for downloading files from ESGF."""
import concurrent.futures
import contextlib
import datetime
import functools
import hashlib
import itertools
import logging
import os
import random
import re
import shutil
from pathlib import Path
from statistics import median
from tempfile import NamedTemporaryFile
from urllib.parse import urlparse
import requests
import yaml
from humanfriendly import format_size, format_timespan
from esmvalcore.typing import Facets
from ..local import LocalFile
from .facets import DATASET_MAP, FACETS
logger = logging.getLogger(__name__)
TIMEOUT = 5 * 60
"""Timeout (in seconds) for downloads."""
HOSTS_FILE = Path.home() / ".esmvaltool" / "cache" / "esgf-hosts.yml"
SIZE = "size (bytes)"
DURATION = "duration (s)"
SPEED = "speed (MB/s)"
class DownloadError(Exception):
"""An error occurred while downloading."""
def compute_speed(size, duration):
"""Compute download speed in MB/s."""
if duration != 0:
speed = size / duration / 10**6
else:
speed = 0
return speed
def load_speeds():
"""Load average download speeds from HOSTS_FILE."""
try:
content = HOSTS_FILE.read_text(encoding="utf-8")
except FileNotFoundError:
content = "{}"
speeds = yaml.safe_load(content)
return speeds
def log_speed(url, size, duration):
"""Write the downloaded file size and duration to HOSTS_FILE."""
speeds = load_speeds()
host = urlparse(url).hostname
size += speeds.get(host, {}).get(SIZE, 0)
duration += speeds.get(host, {}).get(DURATION, 0)
speed = compute_speed(size, duration)
speeds[host] = {
SIZE: size,
DURATION: round(duration),
SPEED: round(speed, 1),
"error": False,
}
with atomic_write(HOSTS_FILE) as file:
yaml.safe_dump(speeds, file)
def log_error(url):
"""Write the hosts that errored to HOSTS_FILE."""
speeds = load_speeds()
host = urlparse(url).hostname
entry = speeds.get(host, {SIZE: 0, DURATION: 0, SPEED: 0})
entry["error"] = True
speeds[host] = entry
with atomic_write(HOSTS_FILE) as file:
yaml.safe_dump(speeds, file)
@contextlib.contextmanager
def atomic_write(filename):
"""Write a file without the risk of interfering with other processes."""
filename.parent.mkdir(parents=True, exist_ok=True)
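    # The NamedTemporaryFile below is used only to reserve a unique file name
    # next to the destination (the placeholder is deleted when its context
    # exits); the content is then written to that name and moved into place
    # in a single step, so readers never see a partially written file.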
with NamedTemporaryFile(prefix=f"{filename}.") as file:
tmp_file = file.name
with open(tmp_file, "w", encoding="utf-8") as file:
yield file
shutil.move(tmp_file, filename)
def get_preferred_hosts():
"""Get a list of preferred hosts.
    The list will be sorted by download speed. Hosts that recently
returned an error will be at the end.
"""
speeds = load_speeds()
if not speeds:
return []
# Compute speeds from size and duration
for entry in speeds.values():
entry[SPEED] = compute_speed(entry[SIZE], entry[DURATION])
    # Hosts from which no data has been downloaded yet get the median speed;
    # if there is no host with a non-zero speed, assign a value of 0.0.
speeds_list = [speeds[h][SPEED] for h in speeds if speeds[h][SPEED] != 0.0]
if not speeds_list:
median_speed = 0.0
else:
median_speed = median(speeds_list)
for host in speeds:
if speeds[host][SIZE] == 0:
speeds[host][SPEED] = median_speed
# Sort hosts by download speed
hosts = sorted(speeds, key=lambda h: speeds[h][SPEED], reverse=True)
# Figure out which hosts recently returned an error
mtime = HOSTS_FILE.stat().st_mtime
now = datetime.datetime.now().timestamp()
age = now - mtime
if age > 60 * 60:
# Ignore errors older than an hour
errored = []
else:
errored = [h for h in speeds if speeds[h]["error"]]
# Move hosts with an error to the end of the list
for host in errored:
if host in hosts:
hosts.pop(hosts.index(host))
hosts.append(host)
return hosts
def sort_hosts(urls):
"""Sort a list of URLs by preferred hosts.
Parameters
----------
urls : :obj:`list` of :obj:`str`
List of all available URLs.
Returns
-------
:obj:`list` of :obj:`str`
        The list of URLs, reordered so that URLs on faster hosts are tried
        earlier. URLs on hosts without any recorded download statistics
        are kept at the start of the list so they also get used, and URLs
        on hosts that recently returned an error are moved to the end.
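        For example (hostnames are illustrative), with recorded speeds of
        4.0 MB/s for ``fast.example.org`` and 0.5 MB/s for
        ``slow.example.org`` and no statistics for ``new.example.org``,
        the URLs are returned in the order ``new.example.org``,
        ``fast.example.org``, ``slow.example.org``.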
"""
urls = list(urls)
hosts = [urlparse(url).hostname for url in urls]
preferred_hosts = get_preferred_hosts()
for host in preferred_hosts:
if host in hosts:
            # Move this host and the corresponding URL to the end of the
            # list. Because the preferred hosts are iterated over from
            # fastest to slowest, the final order is: hosts without recorded
            # statistics first (so these get used and measured too), then the
            # known hosts from fastest to slowest.
idx = hosts.index(host)
hosts.append(hosts.pop(idx))
urls.append(urls.pop(idx))
return urls
@functools.total_ordering
class ESGFFile:
"""File on the ESGF.
This is the object returned by :func:`esmvalcore.esgf.find_files`.
Attributes
----------
dataset : str
The name of the dataset that the file is part of.
facets : dict[str,str]
Facets describing the file.
name : str
The name of the file.
size : int
The size of the file in bytes.
urls : list[str]
The URLs where the file can be downloaded.
"""
def __init__(self, results):
results = list(results)
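        # Some files on ESGF are named e.g. `<name>.nc_0` or `<name>.nc_1`;
        # normalize the local file name so that it always ends in `.nc`.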
self.name = str(Path(results[0].filename).with_suffix(".nc"))
self.size = results[0].size
self.dataset = self._get_dataset_id(results)
self.facets = self._get_facets(results)
self.urls = []
self._checksums = []
for result in results:
self.urls.append(result.download_url)
self._checksums.append((result.checksum_type, result.checksum))
@classmethod
def _from_results(cls, results, facets):
"""Return a list of files from a pyesgf.search.results.ResultSet."""
def same_file(result):
# Remove the hostname from the dataset_id
dataset = result.json["dataset_id"].split("|")[0]
# Ignore the extension (some files are called .nc_0, .nc_1)
filename = Path(result.filename).stem
# Ignore case
return (dataset.lower(), filename.lower())
files = []
results = sorted(results, key=same_file)
for _, file_results in itertools.groupby(results, key=same_file):
file = cls(file_results)
# Filter out files containing the wrong variable, e.g. for
# cmip5.output1.ICHEC.EC-EARTH.historical
# .mon.atmos.Amon.r1i1p1.v20121115
variable = file.name.split("_")[0]
if "variable" not in facets or facets["variable"] == variable:
files.append(file)
else:
logger.debug(
"Ignoring file(s) %s containing wrong variable '%s' in"
" found in search for variable '%s'",
file.urls,
variable,
facets.get("variable", facets.get("variable_id", "?")),
)
return files
def _get_facets(self, results):
"""Read the facets.
This works by first reading the facets from the json response of
the first search result. Next, an alternative set of facets is
read from the `dataset_id` and filename and used to correct any
wrong facets values.
"""
project = results[0].json["project"][0]
# Read the facets from the metadata
facets = {
our_facet: results[0].json[their_facet]
for our_facet, their_facet in FACETS[project].items()
if their_facet in results[0].json
}
facets = {
facet: value[0]
if isinstance(value, list) and len(value) == 1
else value
for facet, value in facets.items()
}
facets["project"] = project
if "dataset" in facets:
reverse_dataset_map = {
v: k for k, v in DATASET_MAP.get(project, {}).items()
}
facets["dataset"] = reverse_dataset_map.get(
facets["dataset"], facets["dataset"]
)
# Update the facets with information from the dataset_id and filename
more_reliable_facets = self._get_facets_from_dataset_id(results)
for facet, value in more_reliable_facets.items():
if facet not in facets or facets[facet] != value:
logger.debug(
"Correcting facet '%s' from '%s' to '%s' for %s.%s",
facet,
facets.get(facet),
value,
self.dataset,
self.name,
)
facets[facet] = value
return facets
@staticmethod
def _get_facets_from_dataset_id(results) -> Facets:
"""Read the facets from the `dataset_id`."""
# This reads the facets from the dataset_id because the facets
# provided by ESGF are unreliable.
#
# Example dataset_id_template_ values:
# CMIP3: '%(project)s.%(institute)s.%(model)s.%(experiment)s.
# %(time_frequency)s.%(realm)s.%(ensemble)s.%(variable)s'
# CMIP5: 'cmip5.%(product)s.%(valid_institute)s.%(model)s.
# %(experiment)s.%(time_frequency)s.%(realm)s.%(cmor_table)s.
# %(ensemble)s'
# CMIP6: '%(mip_era)s.%(activity_drs)s.%(institution_id)s.
# %(source_id)s.%(experiment_id)s.%(member_id)s.%(table_id)s.
# %(variable_id)s.%(grid_label)s'
# CORDEX: 'cordex.%(product)s.%(domain)s.%(institute)s.
# %(driving_model)s.%(experiment)s.%(ensemble)s.%(rcm_name)s.
# %(rcm_version)s.%(time_frequency)s.%(variable)s'
# obs4MIPs: '%(project)s.%(institute)s.%(source_id)s.%(realm)s.
# %(time_frequency)s'
project = results[0].json["project"][0]
# Read the keys from `dataset_id_template_` and translate to our keys
template = results[0].json["dataset_id_template_"][0]
keys = re.findall(r"%\((.*?)\)s", template)
reverse_facet_map = {v: k for k, v in FACETS[project].items()}
reverse_facet_map["realm"] = "modeling_realm"
reverse_facet_map["mip_era"] = "project" # CMIP6 oddity
reverse_facet_map["variable_id"] = "short_name" # CMIP6 oddity
reverse_facet_map["valid_institute"] = "institute" # CMIP5 oddity
keys = [reverse_facet_map.get(k, k) for k in keys]
keys.append("version")
if keys[0] == "project":
# The project is sometimes hardcoded all lowercase in the template
keys = keys[1:]
# Read values from dataset_id
# Pick the first dataset_id if there are differences in case
dataset_id = sorted(
r.json["dataset_id"].split("|")[0] for r in results
)[0]
values = dataset_id.split(".")[1:]
facets = {}
if len(keys) == len(values):
for idx, key in enumerate(keys):
facets[key] = values[idx]
else:
logger.debug(
"Wrong dataset_id_template_ %s or facet values containing '.' "
"for dataset %s",
template,
dataset_id,
)
facets["version"] = dataset_id.split(".")[-1]
# The dataset_id does not contain the short_name for all projects,
# so get it from the filename:
facets["short_name"] = results[0].json["title"].split("_")[0]
return facets
@staticmethod
def _get_dataset_id(results):
"""Simplify dataset_id so it is always composed of the same facets."""
# Pick the first dataset_id if there are differences in case
dataset_id = sorted(
r.json["dataset_id"].split("|")[0] for r in results
)[0]
project = results[0].json["project"][0]
if project != "obs4MIPs":
return dataset_id
# Simplify the obs4MIPs dataset_id so it contains only facets that are
# present for all datasets.
version = dataset_id.rsplit(".", 1)[1]
dataset_key = FACETS[project]["dataset"]
dataset_name = results[0].json[dataset_key][0]
dataset_name = DATASET_MAP[project].get(dataset_name, dataset_name)
return f"{project}.{dataset_name}.{version}"
def _get_relative_path(self) -> Path:
"""Get the subdirectories."""
if self.facets["project"] == "obs4MIPs":
            # Avoid errors due to a `.` in the dataset name
facets = ["project", "dataset", "version"]
path = Path(*[self.facets[f] for f in facets])
else:
path = Path(*self.dataset.split("."))
return path / self.name
def __repr__(self):
"""Represent the file as a string."""
hosts = [urlparse(u).hostname for u in self.urls]
return f"ESGFFile:{self._get_relative_path()} on hosts {hosts}"
def __eq__(self, other):
"""Compare `self` to `other`."""
return isinstance(other, self.__class__) and (
self.dataset,
self.name,
) == (other.dataset, other.name)
def __lt__(self, other):
"""Compare `self` to `other`."""
return (self.dataset, self.name) < (other.dataset, other.name)
def __hash__(self):
"""Compute a unique hash value."""
return hash((self.dataset, self.name))
def local_file(self, dest_folder):
"""Return the path to the local file after download.
Arguments
---------
dest_folder: Path
The destination folder.
Returns
-------
LocalFile
The path where the file will be located after download.
"""
file = LocalFile(dest_folder, self._get_relative_path())
file.facets = self.facets
return file
def download(self, dest_folder):
"""Download the file.
Arguments
---------
dest_folder: Path
The destination folder.
Raises
------
DownloadError:
Raised if downloading the file failed.
Returns
-------
LocalFile
The path where the file will be located after download.
"""
local_file = self.local_file(dest_folder)
if local_file.exists():
logger.debug("Skipping download of existing file %s", local_file)
return local_file
os.makedirs(local_file.parent, exist_ok=True)
errors = {}
for url in sort_hosts(self.urls):
try:
self._download(local_file, url)
except (
DownloadError,
requests.exceptions.RequestException,
) as error:
logger.debug(
"Not able to download %s. Error message: %s", url, error
)
errors[url] = error
log_error(url)
else:
break
if not local_file.exists():
raise DownloadError(
f"Failed to download file {local_file}, errors:"
"\n" + "\n".join(f"{url}: {errors[url]}" for url in errors)
)
return local_file
@staticmethod
def _tmp_local_file(local_file):
"""Return the path to a temporary local file for downloading to."""
with NamedTemporaryFile(prefix=f"{local_file}.") as tmp_file:
return Path(tmp_file.name)
def _download(self, local_file, url):
"""Download file from a single url."""
idx = self.urls.index(url)
checksum_type, checksum = self._checksums[idx]
if checksum_type is None:
hasher = None
else:
hasher = hashlib.new(checksum_type)
tmp_file = self._tmp_local_file(local_file)
logger.debug("Downloading %s to %s", url, tmp_file)
start_time = datetime.datetime.now()
response = requests.get(url, stream=True, timeout=TIMEOUT)
response.raise_for_status()
with tmp_file.open("wb") as file:
# Specify chunk_size to avoid
# https://github.com/psf/requests/issues/5536
megabyte = 2**20
for chunk in response.iter_content(chunk_size=megabyte):
if hasher is not None:
hasher.update(chunk)
file.write(chunk)
duration = datetime.datetime.now() - start_time
if hasher is None:
logger.warning(
"No checksum available, unable to check data"
" integrity for %s, ",
url,
)
else:
local_checksum = hasher.hexdigest()
if local_checksum != checksum:
raise DownloadError(
f"Wrong {checksum_type} checksum for file {tmp_file},"
f" downloaded from {url}: expected {checksum}, but got"
f" {local_checksum}. Try downloading the file again."
)
shutil.move(tmp_file, local_file)
log_speed(url, self.size, duration.total_seconds())
logger.info(
"Downloaded %s (%s) in %s (%s/s) from %s",
local_file,
format_size(self.size),
format_timespan(duration.total_seconds()),
format_size(self.size / duration.total_seconds()),
urlparse(url).hostname,
)
def get_download_message(files):
"""Create a log message describing what will be downloaded."""
total_size = 0
lines = []
for file in files:
total_size += file.size
lines.append(f"{format_size(file.size)}\t{file}")
lines.insert(0, "Will download the following files:")
lines.insert(0, f"Will download {format_size(total_size)}")
lines.append(f"Downloading {format_size(total_size)}..")
return "\n".join(lines)
def download(files, dest_folder, n_jobs=4):
"""Download multiple ESGFFiles in parallel.
Arguments
---------
files: list of :obj:`ESGFFile`
The files to download.
dest_folder: Path
The destination folder.
n_jobs: int
The number of files to download in parallel.
Raises
------
DownloadError:
Raised if one or more files failed to download.
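    Examples
    --------
    An illustrative sketch of a typical call, using files returned by
    :func:`esmvalcore.esgf.find_files` (the facet names, facet values and
    destination folder below are only an illustration)::

        from esmvalcore.esgf import download, find_files

        files = find_files(
            project="CMIP6",
            mip="Amon",
            short_name="tas",
            dataset="EC-Earth3",
            exp="historical",
            ensemble="r1i1p1f1",
        )
        download(files, dest_folder="/path/to/download_folder")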
"""
files = [
file
for file in files
if isinstance(file, ESGFFile)
and not file.local_file(dest_folder).exists()
]
if not files:
logger.debug(
"All required data is available locally, not downloading anything."
)
return
files = sorted(files)
logger.info(get_download_message(files))
def _download(file: ESGFFile):
"""Download file to dest_folder."""
file.download(dest_folder)
total_size = 0
start_time = datetime.datetime.now()
errors = []
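    # Shuffle the processing order (the files were sorted above only to
    # produce a readable log message); this likely helps to spread
    # simultaneous downloads over different hosts.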
random.shuffle(files)
with concurrent.futures.ThreadPoolExecutor(max_workers=n_jobs) as executor:
future_to_file = {
executor.submit(_download, file): file for file in files
}
for future in concurrent.futures.as_completed(future_to_file):
file = future_to_file[future]
try:
future.result()
except DownloadError as error:
logger.error(
"Failed to download %s, error message %s", file, error
)
errors.append(error)
else:
total_size += file.size
duration = datetime.datetime.now() - start_time
logger.info(
"Downloaded %s in %s (%s/s)",
format_size(total_size),
format_timespan(duration.total_seconds()),
format_size(total_size / duration.total_seconds()),
)
if errors:
msg = "Failed to download the following files:\n" + "\n".join(
sorted(str(error) for error in errors)
)
raise DownloadError(msg)
logger.info("Successfully downloaded all requested files.")