# Source code for esmvaltool.diag_scripts.monitor.monitor_base

"""Base class for monitoring diagnostics."""

import logging
import os
import re

import cartopy
import matplotlib.pyplot as plt
import yaml
from iris.analysis import MEAN
from mapgenerator.plotting.timeseries import PlotSeries

from esmvaltool.diag_scripts.shared import ProvenanceLogger, names

logger = logging.getLogger(__name__)


def _replace_tags(paths, variable):
    """Fill the ``{tag}`` placeholders of path templates with real values.

    Parameters
    ----------
    paths: str or iterable of str
        One template or several templates. Leading and trailing ``/`` are
        stripped from every template before processing.
    variable: dict
        Mapping from tag name to its value (or list/tuple of values).

    Returns
    -------
    set or list of str
        The templates with every tag substituted, except ``latestversion``
        which is left in place. A list is returned as soon as a substitution
        took place or ``sub_experiment`` is present in ``variable``;
        otherwise the set of stripped templates is returned unchanged.

    Raises
    ------
    ValueError
        If a tag found in the templates is not a key of ``variable``.
    """
    templates = (paths,) if isinstance(paths, str) else paths
    paths = {template.strip("/") for template in templates}

    # Collect every placeholder name used by any template.
    tags = set()
    for template in paths:
        tags.update(re.findall(r"{([^}]*)}", template))

    if "sub_experiment" in variable:
        # Each template yields two variants, one per substitution pattern.
        expanded = []
        for template in paths:
            expanded.append(
                re.sub(r"(\b{ensemble}\b)", r"{sub_experiment}-\1", template)
            )
            expanded.append(
                re.sub(r"({ensemble})", r"{sub_experiment}-\1", template)
            )
        if expanded:
            tags.add("sub_experiment")
        paths = expanded

    for full_tag in tags:
        # The lookup key is the tag without its ``.lower``/``.upper`` option.
        key, _, _ = _get_caps_options(full_tag)

        if key == "latestversion":  # handled separately later
            continue
        if key not in variable:
            raise ValueError(
                f"Dataset key '{key}' must be specified for "
                f"{variable}, check your recipe entry"
            )
        paths = _replace_tag(paths, full_tag, variable[key])
    return paths


def _replace_tag(paths, tag, replacewith):
    """Replace ``{tag}`` by ``replacewith`` in every path.

    Parameters
    ----------
    paths: iterable of str
        Path templates that may contain the ``{tag}`` placeholder.
    tag: str
        Placeholder name exactly as written in the template, including an
        optional ``.lower`` / ``.upper`` suffix that selects the case of the
        inserted text.
    replacewith: object or list or tuple
        Replacement value. For a list or tuple, every path is expanded once
        per item. Any other value is converted with ``str``.

    Returns
    -------
    list of str
        The substituted paths without duplicates. Duplicates are dropped
        while keeping first-seen order, so the result order follows the
        iteration order of ``paths`` and ``replacewith`` instead of string
        hashing (callers pick the first element).
    """
    _, lower, upper = _get_caps_options(tag)
    if isinstance(replacewith, (list, tuple)):
        result = []
        for item in replacewith:
            result.extend(_replace_tag(paths, tag, item))
    else:
        text = _apply_caps(str(replacewith), lower, upper)
        placeholder = "{" + tag + "}"
        result = [path.replace(placeholder, text) for path in paths]
    # dict keys keep insertion order, giving a stable de-duplication.
    return list(dict.fromkeys(result))


def _get_caps_options(tag):
    """Split a tag into its bare name and its case-conversion flags.

    Parameters
    ----------
    tag: str
        Tag name, optionally ending in ``.lower`` or ``.upper``.

    Returns
    -------
    tuple
        ``(name, lower, upper)`` where ``name`` is the tag without the
        recognised suffix and the two booleans tell which suffix was found.
    """
    suffix_flags = ((".lower", True, False), (".upper", False, True))
    for suffix, lower, upper in suffix_flags:
        if tag.endswith(suffix):
            return tag[: -len(suffix)], lower, upper
    return tag, False, False


def _apply_caps(original, lower, upper):
    """Return ``original`` lower-cased, upper-cased or untouched.

    ``lower`` takes precedence over ``upper`` when both flags are set; with
    neither flag set the input is returned as is.
    """
    if not (lower or upper):
        return original
    return original.lower() if lower else original.upper()


class MonitorBase:
    """Base class for monitoring diagnostic.

    It contains the common methods for path creation, provenance
    recording, option parsing and to create some common plots.
    """

    def __init__(self, config):
        """Store the diagnostic configuration and load the plot settings.

        Parameters
        ----------
        config: dict
            Diagnostic configuration. The optional keys ``plot_folder``,
            ``plot_filename``, ``plots``, ``cartopy_data_dir`` and
            ``config_file`` are read here.
        """
        self.cfg = config
        plot_folder = config.get(
            "plot_folder",
            "{plot_dir}/../../{dataset}/{exp}/{modeling_realm}/{real_name}",
        )
        plot_folder = plot_folder.replace(
            "{plot_dir}", self.cfg[names.PLOT_DIR]
        )
        self.plot_folder = os.path.abspath(
            os.path.expandvars(os.path.expanduser(plot_folder))
        )
        self.plot_filename = config.get(
            "plot_filename",
            "{plot_type}_{real_name}_{dataset}_{mip}_{exp}_{ensemble}",
        )
        self.plots = config.get("plots", {})
        default_config = os.path.join(
            os.path.dirname(__file__), "monitor_config.yml"
        )
        cartopy_data_dir = config.get("cartopy_data_dir", None)
        if cartopy_data_dir:
            cartopy.config["data_dir"] = cartopy_data_dir
        config_path = config.get("config_file", default_config)
        # Explicit encoding so parsing does not depend on the locale.
        with open(config_path, encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    def _add_file_extension(self, filename):
        """Add extension to plot filename."""
        return f"{filename}.{self.cfg['output_file_type']}"

    def _get_proj_options(self, map_name):
        """Return the settings stored under ``maps`` for ``map_name``."""
        return self.config["maps"][map_name]

    def _get_variable_options(self, variable_group, map_name):
        """Return the plot options of a variable group for a given map.

        The options of ``variable_group`` are looked up under ``variables``
        in the loaded configuration, falling back to the ``default`` entry.
        When those options have a ``default`` section, it is used as the
        base and the section named ``map_name`` (if any) overrides it.
        Non-string ``bounds`` are converted to a list of floats.
        """
        options = self.config["variables"].get(
            variable_group, self.config["variables"]["default"]
        )
        if "default" not in options:
            variable_options = options
        else:
            variable_options = options["default"]
            if map_name in options:
                variable_options = {**variable_options, **options[map_name]}

        if "bounds" in variable_options:
            if not isinstance(variable_options["bounds"], str):
                # May update the loaded configuration in place; converting
                # again on a later call yields the same list of floats.
                variable_options["bounds"] = [
                    float(n) for n in variable_options["bounds"]
                ]
        logger.debug(variable_options)
        return variable_options

    def plot_timeseries(self, cube, var_info, period="", **kwargs):
        """Plot timeseries from a cube.

        It also automatically smoothes it for long timeseries of monthly data:
            - Between 10 and 70 years long, it also plots the 12-month rolling
              average along the raw series
            - For 70 years or more, it plots the 12-month and 10-years
              rolling averages and not the raw series

        Parameters
        ----------
        cube: iris.cube.Cube
            Data to plot; it must provide ``year`` and ``time`` coordinates.
        var_info: dict
            Variable information from ESMValTool
        period: str, optional (default: "")
            Text appended to the plot type in the file name and stored in
            the provenance record.
        **kwargs:
            Forwarded to :meth:`plot_cube`. ``xlimits`` defaults to
            ``"auto"`` when not given.
        """
        if "xlimits" not in kwargs:
            kwargs["xlimits"] = "auto"
        years = cube.coord("year")
        length = years.points.max() - years.points.min()
        filename = self.get_plot_path(
            f"timeseries{period}", var_info, add_ext=False
        )
        # The metadata is interpolated once, and the title is prepended by
        # plain concatenation: calling str.format on text that already holds
        # metadata would fail on any brace contained in a name.
        caption_tail = (
            " of "
            f"{var_info[names.LONG_NAME]} of dataset "
            f"{var_info[names.DATASET]} (project "
            f"{var_info[names.PROJECT]}) from "
            f"{var_info[names.START_YEAR]} to "
            f"{var_info[names.END_YEAR]}."
        )
        raw_caption = "Time series" + caption_tail
        smoothed_12_caption = (
            "Smoothed (12-months running mean) time series" + caption_tail
        )
        smoothed_120_caption = (
            "Smoothed (10-years running mean) time series" + caption_tail
        )

        # Second condition: fewer than about 11 points per covered year,
        # presumably meaning the data is not monthly -- no smoothing then.
        if length < 10 or length * 11 > years.shape[0]:
            self._plot_and_record(
                cube, filename, var_info, period, raw_caption, kwargs
            )
        elif length < 70:
            self._plot_and_record(
                cube, filename, var_info, period, raw_caption, kwargs
            )
            # Smoothed time series (12-month running mean)
            plt.gca().set_prop_cycle(None)
            self._plot_and_record(
                cube.rolling_window("time", MEAN, 12),
                f"{filename}_smoothed_12_months",
                var_info,
                period,
                smoothed_12_caption,
                kwargs,
            )
        else:
            # Smoothed time series (12-month running mean)
            self._plot_and_record(
                cube.rolling_window("time", MEAN, 12),
                f"{filename}_smoothed_12_months",
                var_info,
                period,
                smoothed_12_caption,
                kwargs,
            )
            # Smoothed time series (10-year running mean)
            self._plot_and_record(
                cube.rolling_window("time", MEAN, 120),
                f"{filename}_smoothed_10_years",
                var_info,
                period,
                smoothed_120_caption,
                kwargs,
            )

    def _plot_and_record(
        self, cube, filename, var_info, period, caption, plot_kwargs
    ):
        """Plot one timeseries and write the provenance of its file."""
        self.plot_cube(cube, filename, **plot_kwargs)
        self.record_plot_provenance(
            self._add_file_extension(filename),
            var_info,
            "timeseries",
            period=period,
            caption=caption,
        )

    def record_plot_provenance(self, filename, var_info, plot_type, **kwargs):
        """Write provenance info for a given file.

        Parameters
        ----------
        filename: str
            File the provenance record is logged for.
        var_info: dict
            Variable information from ESMValTool; its ``filename`` entry is
            recorded as the ancestor file.
        plot_type: str
            Stored in the record under ``plot_type``.
        **kwargs:
            Extra entries added to the provenance record.
        """
        with ProvenanceLogger(self.cfg) as provenance_logger:
            prov = self.get_provenance_record(
                ancestor_files=[var_info["filename"]],
                plot_type=plot_type,
                long_names=[var_info[names.LONG_NAME]],
                **kwargs,
            )
            provenance_logger.log(filename, prov)

    def plot_cube(self, cube, filename, linestyle="-", **kwargs):
        """Plot a timeseries from a cube.

        Supports multiplot layouts for cubes with extra dimensions
        `shape_id` or `region`.

        Parameters
        ----------
        cube: iris.cube.Cube
            Data to plot along its ``time`` coordinate.
        filename: str
            Plot path without extension, used as the plotter's template.
        linestyle: str, optional (default: "-")
            Line style, used only for the single-plot layout.
        **kwargs:
            Forwarded to the plotter.
        """
        plotter = PlotSeries()
        plotter.filefmt = self.cfg["output_file_type"]
        plotter.img_template = filename
        region_coords = ("shape_id", "region")

        for region_coord in region_coords:
            if cube.coords(region_coord):
                if cube.coord(region_coord).shape[0] > 1:
                    plotter.multiplot_cube(
                        cube, "time", region_coord, **kwargs
                    )
                    return
        plotter.plot_cube(cube, "time", linestyle=linestyle, **kwargs)

    @staticmethod
    def get_provenance_record(ancestor_files, **kwargs):
        """Create provenance record for the diagnostic data and plots.

        Parameters
        ----------
        ancestor_files: list
            Stored in the record under ``ancestors``.
        **kwargs:
            Extra entries; they override the fixed ones on a key clash.
        """
        record = {
            "authors": [
                "vegas-regidor_javier",
            ],
            "references": [
                "acknow_project",
            ],
            "ancestors": ancestor_files,
            **kwargs,
        }
        return record

    def get_plot_path(self, plot_type, var_info, add_ext=True):
        """Get plot full path from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        return os.path.join(
            self.get_plot_folder(var_info),
            self.get_plot_name(plot_type, var_info, add_ext=add_ext),
        )

    def get_plot_folder(self, var_info):
        """Get plot storage folder from variable info.

        The folder is created when it does not exist yet.

        Parameters
        ----------
        var_info: dict
            Variable information from ESMValTool
        """
        info = {
            "real_name": self._real_name(var_info["variable_group"]),
            **var_info,
        }
        folder = list(_replace_tags(self.plot_folder, info))[0]
        # Tag replacement strips the surrounding "/"; restore the root.
        if self.plot_folder.startswith("/"):
            folder = "/" + folder
        os.makedirs(folder, exist_ok=True)
        return folder

    def get_plot_name(self, plot_type, var_info, add_ext=True):
        """Get plot filename from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        info = {
            "plot_type": plot_type,
            "real_name": self._real_name(var_info["variable_group"]),
            **var_info,
        }
        file_name = list(_replace_tags(self.plot_filename, info))[0]
        if add_ext:
            file_name = self._add_file_extension(file_name)
        return file_name

    @staticmethod
    def _set_rasterized(axes=None):
        """Rasterize all artists and collection of axes if desired."""
        if axes is None:
            axes = plt.gca()
        if not isinstance(axes, list):
            axes = [axes]
        for single_axes in axes:
            for artist in single_axes.artists:
                artist.set_rasterized(True)
            for collection in single_axes.collections:
                collection.set_rasterized(True)

    @staticmethod
    def _real_name(variable_group):
        """Return the variable group without its statistic suffix.

        The suffixes ``Ymean``, ``Ysum``, ``mean`` and ``sum`` are checked in
        that order and removed from the end of the name only; the same text
        elsewhere in the name is kept.
        """
        for subfix in ("Ymean", "Ysum", "mean", "sum"):
            if variable_group.endswith(subfix):
                variable_group = variable_group[: -len(subfix)]
        return variable_group