Source code for esmvaltool.diag_scripts.monitor.monitor_base

"""Base class for monitoring diagnostics."""

import logging
import os
import re
from typing import Optional

import cartopy
import matplotlib.pyplot as plt
import yaml
from iris.analysis import MEAN
from mapgenerator.plotting.timeseries import PlotSeries
from matplotlib.axes import Axes

from esmvaltool.diag_scripts.shared import ProvenanceLogger, names

logger = logging.getLogger(__name__)


def _replace_tags(paths, variable):
    """Replace tags in the config-developer's file with actual values."""
    if isinstance(paths, str):
        paths = set((paths.strip("/"),))
    else:
        paths = set(path.strip("/") for path in paths)
    tlist = set()
    for path in paths:
        tlist = tlist.union(re.findall(r"{([^}]*)}", path))
    if "sub_experiment" in variable:
        new_paths = []
        for path in paths:
            new_paths.extend(
                (
                    re.sub(r"(\b{ensemble}\b)", r"{sub_experiment}-\1", path),
                    re.sub(r"({ensemble})", r"{sub_experiment}-\1", path),
                )
            )
            tlist.add("sub_experiment")
        paths = new_paths

    for tag in tlist:
        original_tag = tag
        tag, _, _ = _get_caps_options(tag)

        if tag == "latestversion":  # handled separately later
            continue
        if tag in variable:
            replacewith = variable[tag]
        else:
            raise ValueError(
                f"Dataset key '{tag}' must be specified for "
                f"{variable}, check your recipe entry"
            )
        paths = _replace_tag(paths, original_tag, replacewith)
    return paths


def _replace_tag(paths, tag, replacewith):
    """Substitute ``{tag}`` in every path and return the unique results.

    Parameters
    ----------
    paths: iterable of str
        Path templates in which the placeholder is replaced.
    tag: str
        Placeholder name as written in the template, including an
        optional ``.lower`` / ``.upper`` modifier.
    replacewith: object or list or tuple
        Replacement value. A list or tuple yields one set of paths per
        item; anything else is converted with ``str``.

    Returns
    -------
    list of str
        The resulting paths without duplicates (order not guaranteed).
    """
    _, to_lower, to_upper = _get_caps_options(tag)
    if isinstance(replacewith, (list, tuple)):
        # One expansion per value, flattened into a single list.
        expanded = [
            new_path
            for value in replacewith
            for new_path in _replace_tag(paths, tag, value)
        ]
    else:
        placeholder = "{" + tag + "}"
        value = _apply_caps(str(replacewith), to_lower, to_upper)
        expanded = [path.replace(placeholder, value) for path in paths]
    return list(set(expanded))


def _get_caps_options(tag):
    lower = False
    upper = False
    if tag.endswith(".lower"):
        lower = True
        tag = tag[0:-6]
    elif tag.endswith(".upper"):
        upper = True
        tag = tag[0:-6]
    return tag, lower, upper


def _apply_caps(original, lower, upper):
    if lower:
        return original.lower()
    if upper:
        return original.upper()
    return original


class MonitorBase:
    """Base class for monitoring diagnostic.

    It contains the common methods for path creation, provenance
    recording, option parsing and to create some common plots.

    Parameters
    ----------
    config: dict
        Diagnostic configuration. The keys read here are
        ``plot_folder``, ``plot_filename``, ``plots``,
        ``cartopy_data_dir`` and ``config_file`` (all optional) plus the
        plot directory entry named by ``names.PLOT_DIR``.
    """

    def __init__(self, config):
        self.cfg = config
        plot_folder = config.get(
            "plot_folder",
            "{plot_dir}/../../{dataset}/{exp}/{modeling_realm}/{real_name}",
        )
        # {plot_dir} is resolved right away; the remaining tags are
        # filled per variable in get_plot_folder.
        plot_folder = plot_folder.replace(
            "{plot_dir}", self.cfg[names.PLOT_DIR]
        )
        self.plot_folder = os.path.abspath(
            os.path.expandvars(os.path.expanduser(plot_folder))
        )
        self.plot_filename = config.get(
            "plot_filename",
            "{plot_type}_{real_name}_{dataset}_{mip}_{exp}_{ensemble}",
        )
        self.plots = config.get("plots", {})
        default_config = os.path.join(
            os.path.dirname(__file__), "monitor_config.yml"
        )
        cartopy_data_dir = config.get("cartopy_data_dir", None)
        if cartopy_data_dir:
            cartopy.config["data_dir"] = cartopy_data_dir
        # Explicit encoding so the YAML file is read the same way
        # regardless of the platform's default locale.
        config_path = config.get("config_file", default_config)
        with open(config_path, encoding="utf-8") as config_file:
            self.config = yaml.safe_load(config_file)

    def _add_file_extension(self, filename):
        """Add extension to plot filename."""
        return f"{filename}.{self.cfg['output_file_type']}"

    def _get_proj_options(self, map_name):
        """Return the options stored under ``maps`` for `map_name`."""
        return self.config["maps"][map_name]

    def _get_variable_options(self, variable_group, map_name):
        """Return the plot options of a variable group for a given map.

        The entry of `variable_group` under ``variables`` is used, or the
        ``default`` entry if the group is not listed. If that entry has a
        ``default`` section, it is taken as the base and updated with the
        section named `map_name` (when present). Non-string ``bounds``
        are converted to a list of floats.

        The returned dictionary is a copy, so the loaded configuration
        is not modified.
        """
        variables = self.config["variables"]
        options = variables.get(variable_group, variables["default"])
        if "default" not in options:
            variable_options = dict(options)
        else:
            variable_options = dict(options["default"])
            if map_name in options:
                variable_options.update(options[map_name])

        bounds = variable_options.get("bounds")
        if "bounds" in variable_options and not isinstance(bounds, str):
            variable_options["bounds"] = [float(n) for n in bounds]
        logger.debug(variable_options)
        return variable_options

    def plot_timeseries(self, cube, var_info, period="", **kwargs):
        """Plot timeseries from a cube.

        Depending on the span covered by the ``year`` coordinate
        (``max - min``, called *length* below) different plots are made:

        - Less than 10 years, or fewer than 11 points per year of
          length: only the raw series is plotted.
        - Between 10 and 70 years: the raw series and its 12-point
          rolling mean are plotted.
        - 70 years or more: the 12-point and the 120-point rolling
          means are plotted, but not the raw series.

        The rolling windows assume monthly data (12 points = 12 months,
        120 points = 10 years).

        Parameters
        ----------
        cube: iris.cube.Cube
            Data to plot. Must have ``year`` and ``time`` coordinates.
        var_info: dict
            Variable information from ESMValTool
        period: str, optional (default: "")
            Suffix appended to the plot type and stored in the
            provenance record.
        **kwargs:
            Passed to :meth:`plot_cube`. ``xlimits`` defaults to
            ``"auto"`` when not given.
        """
        kwargs.setdefault("xlimits", "auto")
        year_coord = cube.coord("year")
        length = year_coord.points.max() - year_coord.points.min()
        filename = self.get_plot_path(
            f"timeseries{period}", var_info, add_ext=False
        )
        # Built with f-strings only (no str.format on the result) so that
        # braces in the metadata cannot be mistaken for format fields.
        description = (
            f"of {var_info[names.LONG_NAME]} of dataset "
            f"{var_info[names.DATASET]} (project "
            f"{var_info[names.PROJECT]}) from "
            f"{var_info[names.START_YEAR]} to "
            f"{var_info[names.END_YEAR]}."
        )
        raw_caption = f"Time series {description}"
        monthly_caption = (
            f"Smoothed (12-months running mean) time series {description}"
        )
        decadal_caption = (
            f"Smoothed (10-years running mean) time series {description}"
        )

        if length < 10 or length * 11 > year_coord.shape[0]:
            self._plot_and_record_timeseries(
                cube, filename, var_info, period, raw_caption, kwargs
            )
        elif length < 70:
            self._plot_and_record_timeseries(
                cube, filename, var_info, period, raw_caption, kwargs
            )
            # Smoothed time series (12-month running mean)
            plt.gca().set_prop_cycle(None)
            self._plot_and_record_timeseries(
                cube.rolling_window("time", MEAN, 12),
                f"{filename}_smoothed_12_months",
                var_info,
                period,
                monthly_caption,
                kwargs,
            )
        else:
            # Smoothed time series (12-month running mean)
            self._plot_and_record_timeseries(
                cube.rolling_window("time", MEAN, 12),
                f"{filename}_smoothed_12_months",
                var_info,
                period,
                monthly_caption,
                kwargs,
            )
            # Smoothed time series (10-year running mean)
            self._plot_and_record_timeseries(
                cube.rolling_window("time", MEAN, 120),
                f"{filename}_smoothed_10_years",
                var_info,
                period,
                decadal_caption,
                kwargs,
            )

    def _plot_and_record_timeseries(
        self, cube, filename, var_info, period, caption, plot_kwargs
    ):
        """Plot one timeseries and write its provenance record.

        `plot_kwargs` is taken as a plain dict (not ``**kwargs``) so that
        its keys can never clash with this method's own parameters.
        """
        self.plot_cube(cube, filename, **plot_kwargs)
        self.record_plot_provenance(
            self._add_file_extension(filename),
            var_info,
            "timeseries",
            period=period,
            caption=caption,
        )

    def record_plot_provenance(self, filename, var_info, plot_type, **kwargs):
        """Write provenance info for a given file.

        Parameters
        ----------
        filename: str
            File the provenance record refers to.
        var_info: dict
            Variable information from ESMValTool. Its ``filename`` entry
            is recorded as the only ancestor.
        plot_type: str
            Name of the plot
        **kwargs:
            Additional entries added to the provenance record.
        """
        with ProvenanceLogger(self.cfg) as provenance_logger:
            prov = self.get_provenance_record(
                ancestor_files=[var_info["filename"]],
                plot_type=plot_type,
                long_names=[var_info[names.LONG_NAME]],
                **kwargs,
            )
            provenance_logger.log(filename, prov)

    def plot_cube(self, cube, filename, linestyle="-", **kwargs):
        """Plot a timeseries from a cube.

        Supports multiplot layouts for cubes with extra dimensions
        `shape_id` or `region`.

        Parameters
        ----------
        cube: iris.cube.Cube
            Data to plot along its ``time`` coordinate.
        filename: str
            Plot path without extension; the extension is taken from the
            ``output_file_type`` configuration entry.
        linestyle: str, optional (default: "-")
            Line style, only used for the single-plot layout.
        **kwargs:
            Passed to the plotter.
        """
        plotter = PlotSeries()
        plotter.filefmt = self.cfg["output_file_type"]
        plotter.img_template = filename
        region_coords = ("shape_id", "region")

        for region_coord in region_coords:
            if cube.coords(region_coord):
                if cube.coord(region_coord).shape[0] > 1:
                    # First region-like coordinate with several points
                    # decides the multiplot layout.
                    plotter.multiplot_cube(
                        cube, "time", region_coord, **kwargs
                    )
                    return
        plotter.plot_cube(cube, "time", linestyle=linestyle, **kwargs)

    @staticmethod
    def get_provenance_record(ancestor_files, **kwargs):
        """Create provenance record for the diagnostic data and plots.

        Entries given in `kwargs` are added to the record and override
        the default ones when the keys coincide.
        """
        record = {
            "authors": [
                "vegas-regidor_javier",
            ],
            "references": [
                "acknow_project",
            ],
            "ancestors": ancestor_files,
            **kwargs,
        }
        return record

    def get_plot_path(self, plot_type, var_info, add_ext=True):
        """Get plot full path from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        return os.path.join(
            self.get_plot_folder(var_info),
            self.get_plot_name(plot_type, var_info, add_ext=add_ext),
        )

    def get_plot_folder(self, var_info):
        """Get plot storage folder from variable info.

        The folder is created if it does not exist yet.

        Parameters
        ----------
        var_info: dict
            Variable information from ESMValTool
        """
        info = {
            "real_name": self._real_name(var_info["variable_group"]),
            **var_info,
        }
        folder = list(_replace_tags(self.plot_folder, info))[0]
        # _replace_tags strips the surrounding slashes: restore the
        # leading one for absolute paths.
        if self.plot_folder.startswith("/"):
            folder = "/" + folder
        os.makedirs(folder, exist_ok=True)
        return folder

    def get_plot_name(self, plot_type, var_info, add_ext=True):
        """Get plot filename from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        info = {
            "plot_type": plot_type,
            "real_name": self._real_name(var_info["variable_group"]),
            **var_info,
        }
        file_name = list(_replace_tags(self.plot_filename, info))[0]
        if add_ext:
            file_name = self._add_file_extension(file_name)
        return file_name

    # The annotation is quoted so that it is not evaluated when the class
    # is created.
    @staticmethod
    def _set_rasterized(axes: "Optional[Axes | list[Axes]]" = None) -> None:
        """Rasterize all artists and collection of axes if desired."""
        if axes is None:
            axes = plt.gca()
        if not isinstance(axes, list):
            axes = [axes]
        for single_axes in axes:
            for artist in single_axes.artists:
                artist.set_rasterized(True)
            for collection in single_axes.collections:
                collection.set_rasterized(True)

    @staticmethod
    def _real_name(variable_group):
        """Strip the statistic suffix(es) from a variable group name.

        The suffixes ``Ymean``, ``Ysum``, ``mean`` and ``sum`` are
        removed, in that order, from the end of the name only; the same
        text appearing elsewhere in the name is kept.
        """
        for suffix in ("Ymean", "Ysum", "mean", "sum"):
            variable_group = variable_group.removesuffix(suffix)
        return variable_group