"""Base class for monitoring diagnostics."""
import logging
import os
import re
import cartopy
import matplotlib.pyplot as plt
import yaml
from iris.analysis import MEAN
from mapgenerator.plotting.timeseries import PlotSeries
from esmvaltool.diag_scripts.shared import ProvenanceLogger, names
# Module-level logger; MonitorBase._get_variable_options logs through it.
logger = logging.getLogger(__name__)
def _replace_tags(paths, variable):
    """Replace tags in the config-developer's file with actual values.

    Parameters
    ----------
    paths: str or iterable of str
        Path template(s) containing ``{tag}`` placeholders (optionally
        ``{tag.lower}`` / ``{tag.upper}``). Leading and trailing ``/`` are
        stripped from every template.
    variable: dict
        Values used to fill the placeholders, keyed by tag name. If it has
        a ``sub_experiment`` key, every template is expanded (by two regex
        variants) so that ``{ensemble}`` is prefixed by
        ``{sub_experiment}-``.

    Returns
    -------
    Collection of path strings with the tags substituted. ``latestversion``
    tags are left in place.

    Raises
    ------
    ValueError
        If a tag found in the templates has no value in ``variable``.
    """
    if isinstance(paths, str):
        templates = {paths.strip("/")}
    else:
        templates = {template.strip("/") for template in paths}

    # Collect the names between braces over all templates.
    found_tags = set()
    for template in templates:
        found_tags = found_tags.union(re.findall(r"{([^}]*)}", template))

    if "sub_experiment" in variable:
        expanded = []
        for template in templates:
            expanded.append(
                re.sub(r"(\b{ensemble}\b)", r"{sub_experiment}-\1", template)
            )
            expanded.append(
                re.sub(r"({ensemble})", r"{sub_experiment}-\1", template)
            )
        found_tags.add("sub_experiment")
        templates = expanded

    for raw_tag in found_tags:
        bare_tag = _get_caps_options(raw_tag)[0]
        if bare_tag == "latestversion":  # handled separately later
            continue
        if bare_tag not in variable:
            raise ValueError(
                f"Dataset key '{bare_tag}' must be specified for "
                f"{variable}, check your recipe entry"
            )
        templates = _replace_tag(templates, raw_tag, variable[bare_tag])
    return templates
def _replace_tag(paths, tag, replacewith):
    """Replace the ``{tag}`` placeholder in every path by ``replacewith``.

    Parameters
    ----------
    paths: iterable of str
        Path templates.
    tag: str
        Placeholder name as written between the braces, possibly with a
        ``.lower`` / ``.upper`` suffix that controls the case of the
        inserted value.
    replacewith: object or list/tuple of objects
        Value to insert (converted with ``str``). A list or tuple produces
        one set of paths per item.

    Returns
    -------
    list of str
        Unique resulting paths, in sorted order. Sorting (instead of
        ``list(set(...))``) makes the order independent of string hash
        randomization, so callers taking the first element get the same
        path on every run.
    """
    _, lower, upper = _get_caps_options(tag)
    result = []
    if isinstance(replacewith, (list, tuple)):
        for item in replacewith:
            result.extend(_replace_tag(paths, tag, item))
    else:
        text = _apply_caps(str(replacewith), lower, upper)
        placeholder = "{" + tag + "}"
        result.extend(path.replace(placeholder, text) for path in paths)
    return sorted(set(result))
def _get_caps_options(tag):
    """Split an optional ``.lower`` / ``.upper`` modifier off a tag name.

    Returns
    -------
    tuple
        ``(bare_tag, lower, upper)``; ``lower`` / ``upper`` tell whether the
        corresponding modifier was present.
    """
    for modifier in (".lower", ".upper"):
        if tag.endswith(modifier):
            bare = tag[: -len(modifier)]
            return bare, modifier == ".lower", modifier == ".upper"
    return tag, False, False
def _apply_caps(original, lower, upper):
    """Return ``original`` lower-cased, upper-cased or unchanged.

    ``lower`` wins when both flags are set.
    """
    if not (lower or upper):
        return original
    return original.lower() if lower else original.upper()
class MonitorBase:
    """Base class for monitoring diagnostic.

    It contains the common methods for path creation, provenance
    recording, option parsing and to create some common plots.
    """

    def __init__(self, config):
        """Read output settings and load the plotting configuration file.

        Parameters
        ----------
        config: dict
            Diagnostic configuration from ESMValTool. Besides the plot
            directory, the optional keys ``plot_folder``, ``plot_filename``,
            ``plots``, ``cartopy_data_dir`` and ``config_file`` are read.
        """
        self.cfg = config
        plot_folder = config.get(
            "plot_folder",
            "{plot_dir}/../../{dataset}/{exp}/{modeling_realm}/{real_name}",
        )
        # Only {plot_dir} is known now; the remaining tags are filled in
        # per variable by get_plot_folder.
        plot_folder = plot_folder.replace(
            "{plot_dir}", self.cfg[names.PLOT_DIR]
        )
        self.plot_folder = os.path.abspath(
            os.path.expandvars(os.path.expanduser(plot_folder))
        )
        self.plot_filename = config.get(
            "plot_filename",
            "{plot_type}_{real_name}_{dataset}_{mip}_{exp}_{ensemble}",
        )
        self.plots = config.get("plots", {})
        default_config = os.path.join(
            os.path.dirname(__file__), "monitor_config.yml"
        )
        cartopy_data_dir = config.get("cartopy_data_dir", None)
        if cartopy_data_dir:
            # NOTE: this changes cartopy's process-wide configuration.
            cartopy.config["data_dir"] = cartopy_data_dir
        config_path = config.get("config_file", default_config)
        # Read as UTF-8 instead of depending on the platform's locale.
        with open(config_path, encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    def _add_file_extension(self, filename):
        """Append the configured ``output_file_type`` as file extension."""
        return f"{filename}.{self.cfg['output_file_type']}"

    def _get_proj_options(self, map_name):
        """Return the ``maps`` entry of the plot configuration for a map."""
        return self.config["maps"][map_name]

    def _get_variable_options(self, variable_group, map_name):
        """Get the plotting options for a variable group and map.

        Options are taken from the ``variables`` section of the plot
        configuration, falling back to its ``default`` entry when
        ``variable_group`` is not listed. If the selected options have a
        ``default`` key, that entry is the base and the ``map_name`` entry
        (if any) is merged on top of it. Non-string ``bounds`` are converted
        to a list of floats.

        Returns a new dict; the loaded configuration is not modified.
        """
        variables = self.config["variables"]
        # Look up "default" only when needed: passing it as the dict.get
        # fallback would evaluate it eagerly and raise KeyError even for
        # listed variable groups.
        if variable_group in variables:
            options = variables[variable_group]
        else:
            options = variables["default"]
        if "default" not in options:
            variable_options = dict(options)
        else:
            variable_options = dict(options["default"])
            if map_name in options:
                variable_options.update(options[map_name])
        if "bounds" in variable_options:
            if not isinstance(variable_options["bounds"], str):
                variable_options["bounds"] = [
                    float(n) for n in variable_options["bounds"]
                ]
        logger.debug(variable_options)
        return variable_options

    def plot_timeseries(self, cube, var_info, period="", **kwargs):
        """Plot timeseries from a cube.

        It also automatically smoothes it for long timeseries of monthly data:
            - Between 10 and 70 years long, it also plots the 12-month rolling
              average along the raw series
            - For 70 years or more, it plots the 12-month and 10-years
              rolling averages and not the raw series

        Parameters
        ----------
        cube:
            Cube with ``year`` and ``time`` coordinates.
        var_info: dict
            Variable information from ESMValTool.
        period: str, optional
            Suffix added to the ``timeseries`` plot type.
        **kwargs:
            Passed to ``plot_cube``; ``xlimits`` defaults to ``"auto"``.
        """
        if "xlimits" not in kwargs:
            kwargs["xlimits"] = "auto"
        years = cube.coord("year")
        length = years.points.max() - years.points.min()
        filename = self.get_plot_path(
            f"timeseries{period}", var_info, add_ext=False
        )
        caption = (
            "{} of "
            f"{var_info[names.LONG_NAME]} of dataset "
            f"{var_info[names.DATASET]} (project "
            f"{var_info[names.PROJECT]}) from "
            f"{var_info[names.START_YEAR]} to "
            f"{var_info[names.END_YEAR]}."
        )
        raw_caption = caption.format("Time series")
        caption_12m = caption.format(
            "Smoothed (12-months running mean) time series"
        )
        caption_10y = caption.format(
            "Smoothed (10-years running mean) time series"
        )
        file_12m = f"{filename}_smoothed_12_months"
        file_10y = f"{filename}_smoothed_10_years"

        # Fewer than ~11 points per year of span presumably means the data
        # is not monthly, so no running means are plotted.
        if length < 10 or length * 11 > years.shape[0]:
            self._plot_timeseries_variant(
                cube, filename, var_info, period, raw_caption, kwargs
            )
        elif length < 70:
            self._plot_timeseries_variant(
                cube, filename, var_info, period, raw_caption, kwargs
            )
            # Smoothed time series (12-month running mean)
            plt.gca().set_prop_cycle(None)
            self._plot_timeseries_variant(
                cube.rolling_window("time", MEAN, 12),
                file_12m, var_info, period, caption_12m, kwargs,
            )
        else:
            # Smoothed time series (12-month running mean)
            self._plot_timeseries_variant(
                cube.rolling_window("time", MEAN, 12),
                file_12m, var_info, period, caption_12m, kwargs,
            )
            # Smoothed time series (10-year running mean)
            self._plot_timeseries_variant(
                cube.rolling_window("time", MEAN, 120),
                file_10y, var_info, period, caption_10y, kwargs,
            )

    def _plot_timeseries_variant(
        self, cube, filename, var_info, period, caption, plot_kwargs
    ):
        """Plot one timeseries file and record its provenance.

        ``plot_kwargs`` is a dict (not ``**kwargs``) so that user plot
        options can never clash with this helper's own parameter names.
        """
        self.plot_cube(cube, filename, **plot_kwargs)
        self.record_plot_provenance(
            self._add_file_extension(filename),
            var_info,
            "timeseries",
            period=period,
            caption=caption,
        )

    def record_plot_provenance(self, filename, var_info, plot_type, **kwargs):
        """Write provenance info for a given file.

        The record uses ``var_info["filename"]`` as ancestor and the
        variable's long name; extra ``kwargs`` are added to the record.
        """
        with ProvenanceLogger(self.cfg) as provenance_logger:
            prov = self.get_provenance_record(
                ancestor_files=[var_info["filename"]],
                plot_type=plot_type,
                long_names=[var_info[names.LONG_NAME]],
                **kwargs,
            )
            provenance_logger.log(filename, prov)

    def plot_cube(self, cube, filename, linestyle="-", **kwargs):
        """Plot a timeseries from a cube.

        Supports multiplot layouts for cubes with extra dimensions
        `shape_id` or `region` (used when that coordinate has more than one
        point).

        NOTE(review): ``linestyle`` is only forwarded to the single-plot
        path, not to ``multiplot_cube`` -- confirm whether that is intended.
        """
        plotter = PlotSeries()
        plotter.filefmt = self.cfg["output_file_type"]
        plotter.img_template = filename
        for region_coord in ("shape_id", "region"):
            if cube.coords(region_coord):
                if cube.coord(region_coord).shape[0] > 1:
                    plotter.multiplot_cube(
                        cube, "time", region_coord, **kwargs
                    )
                    return
        plotter.plot_cube(cube, "time", linestyle=linestyle, **kwargs)

    @staticmethod
    def get_provenance_record(ancestor_files, **kwargs):
        """Create provenance record for the diagnostic data and plots.

        Extra ``kwargs`` are merged into the record and may override the
        default keys.
        """
        record = {
            "authors": [
                "vegas-regidor_javier",
            ],
            "references": [
                "acknow_project",
            ],
            "ancestors": ancestor_files,
            **kwargs,
        }
        return record

    def get_plot_path(self, plot_type, var_info, add_ext=True):
        """Get plot full path from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        return os.path.join(
            self.get_plot_folder(var_info),
            self.get_plot_name(plot_type, var_info, add_ext=add_ext),
        )

    def get_plot_folder(self, var_info):
        """Get plot storage folder from variable info, creating it if needed.

        Parameters
        ----------
        var_info: dict
            Variable information from ESMValTool
        """
        info = {
            "real_name": self._real_name(var_info["variable_group"]),
            **var_info,
        }
        # Several candidates may come back; take the smallest so the choice
        # does not depend on set iteration order.
        folder = sorted(_replace_tags(self.plot_folder, info))[0]
        # _replace_tags strips leading slashes; restore it for absolute paths.
        if self.plot_folder.startswith("/"):
            folder = "/" + folder
        os.makedirs(folder, exist_ok=True)
        return folder

    def get_plot_name(self, plot_type, var_info, add_ext=True):
        """Get plot filename from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        info = {
            "plot_type": plot_type,
            "real_name": self._real_name(var_info["variable_group"]),
            **var_info,
        }
        # Deterministic pick among possibly several candidate names.
        file_name = sorted(_replace_tags(self.plot_filename, info))[0]
        if add_ext:
            file_name = self._add_file_extension(file_name)
        return file_name

    @staticmethod
    def _set_rasterized(axes=None):
        """Rasterize all artists and collections of the given axes.

        ``axes`` defaults to the current axes; anything that is not a
        ``list`` is treated as a single axes object.
        """
        if axes is None:
            axes = plt.gca()
        if not isinstance(axes, list):
            axes = [axes]
        for single_axes in axes:
            for artist in single_axes.artists:
                artist.set_rasterized(True)
            for collection in single_axes.collections:
                collection.set_rasterized(True)

    @staticmethod
    def _real_name(variable_group):
        """Strip trailing Ymean/Ysum/mean/sum suffixes from a group name.

        Only the trailing suffix is removed (e.g. ``sumflux_sum`` becomes
        ``sumflux_``); the suffixes are checked in order, so
        ``Ymean``/``Ysum`` are tried before ``mean``/``sum``.
        """
        for suffix in ("Ymean", "Ysum", "mean", "sum"):
            if variable_group.endswith(suffix):
                variable_group = variable_group[: -len(suffix)]
        return variable_group