# Source code for esmvaltool.diag_scripts.monitor.monitor_base

"""Base class for monitoring diagnostics."""

import logging
import os
import re

import cartopy
import matplotlib.pyplot as plt
import yaml
from iris.analysis import MEAN
from mapgenerator.plotting.timeseries import PlotSeries

from esmvaltool.diag_scripts.shared import ProvenanceLogger, names

logger = logging.getLogger(__name__)


def _replace_tags(paths, variable):
    """Replace tags in the config-developer's file with actual values."""
    if isinstance(paths, str):
        paths = set((paths.strip('/'), ))
    else:
        paths = set(path.strip('/') for path in paths)
    tlist = set()
    for path in paths:
        tlist = tlist.union(re.findall(r'{([^}]*)}', path))
    if 'sub_experiment' in variable:
        new_paths = []
        for path in paths:
            new_paths.extend(
                (re.sub(r'(\b{ensemble}\b)', r'{sub_experiment}-\1', path),
                 re.sub(r'({ensemble})', r'{sub_experiment}-\1', path)))
            tlist.add('sub_experiment')
        paths = new_paths

    for tag in tlist:
        original_tag = tag
        tag, _, _ = _get_caps_options(tag)

        if tag == 'latestversion':  # handled separately later
            continue
        if tag in variable:
            replacewith = variable[tag]
        else:
            raise ValueError(f"Dataset key '{tag}' must be specified for "
                             f"{variable}, check your recipe entry")
        paths = _replace_tag(paths, original_tag, replacewith)
    return paths


def _replace_tag(paths, tag, replacewith):
    """Substitute ``{tag}`` in every path and return the distinct results.

    Parameters
    ----------
    paths: iterable of str
        Path templates to process.
    tag: str
        Tag exactly as written in the template, including any ``.lower``
        or ``.upper`` suffix.
    replacewith: object or list or tuple
        Replacement value. For a list or tuple, every item is substituted
        in turn and all resulting paths are collected.

    Returns
    -------
    list of str
        The substituted paths without duplicates.
    """
    _, lower, upper = _get_caps_options(tag)

    if isinstance(replacewith, (list, tuple)):
        collected = set()
        for item in replacewith:
            collected.update(_replace_tag(paths, tag, item))
        return list(collected)

    placeholder = '{' + tag + '}'
    value = _apply_caps(str(replacewith), lower, upper)
    return list({path.replace(placeholder, value) for path in paths})


def _get_caps_options(tag):
    lower = False
    upper = False
    if tag.endswith('.lower'):
        lower = True
        tag = tag[0:-6]
    elif tag.endswith('.upper'):
        upper = True
        tag = tag[0:-6]
    return tag, lower, upper


def _apply_caps(original, lower, upper):
    if lower:
        return original.lower()
    if upper:
        return original.upper()
    return original


class MonitorBase:
    """Base class for monitoring diagnostic.

    It contains the common methods for path creation, provenance
    recording, option parsing and to create some common plots.
    """

    def __init__(self, config):
        """Read the diagnostic settings and load the plot configuration.

        Parameters
        ----------
        config: dict
            Diagnostic configuration. The optional keys ``plot_folder``,
            ``plot_filename``, ``plots``, ``cartopy_data_dir`` and
            ``config_file`` are read here.
        """
        self.cfg = config
        plot_folder = config.get(
            'plot_folder',
            '{plot_dir}/../../{dataset}/{exp}/{modeling_realm}/{real_name}',
        )
        plot_folder = plot_folder.replace('{plot_dir}',
                                          self.cfg[names.PLOT_DIR])
        self.plot_folder = os.path.abspath(
            os.path.expandvars(os.path.expanduser(plot_folder))
        )
        self.plot_filename = config.get(
            'plot_filename',
            '{plot_type}_{real_name}_{dataset}_{mip}_{exp}_{ensemble}')
        self.plots = config.get('plots', {})
        default_config = os.path.join(os.path.dirname(__file__),
                                      "monitor_config.yml")
        cartopy_data_dir = config.get('cartopy_data_dir', None)
        if cartopy_data_dir:
            cartopy.config['data_dir'] = cartopy_data_dir
        # Explicit encoding so the result does not depend on the locale.
        with open(config.get('config_file', default_config),
                  encoding='utf-8') as config_file:
            self.config = yaml.safe_load(config_file)

    def _add_file_extension(self, filename):
        """Add extension to plot filename."""
        return f"{filename}.{self.cfg['output_file_type']}"

    def _get_proj_options(self, map_name):
        """Return the configured options of the map ``map_name``."""
        return self.config['maps'][map_name]

    def _get_variable_options(self, variable_group, map_name):
        """Return the plot options for a variable group on a given map.

        Falls back to the ``default`` entry of the ``variables`` section
        when the group has no entry of its own. If the selected entry has
        a ``default`` key, that sub-entry is used and overridden by the
        sub-entry named after ``map_name`` when present. Non-string
        ``bounds`` are converted to a list of floats.

        The returned dict is a copy, so the loaded configuration is never
        modified.
        """
        variables = self.config['variables']
        options = variables.get(variable_group, variables['default'])
        if 'default' not in options:
            variable_options = dict(options)
        else:
            variable_options = dict(options['default'])
            if map_name in options:
                variable_options.update(options[map_name])

        if 'bounds' in variable_options:
            bounds = variable_options['bounds']
            if not isinstance(bounds, str):
                variable_options['bounds'] = [float(n) for n in bounds]
        logger.debug(variable_options)
        return variable_options

    def plot_timeseries(self, cube, var_info, period='', **kwargs):
        """Plot timeseries from a cube.

        Depending on the time span covered by the ``year`` coordinate of
        the cube, smoothed versions are plotted as well:

        - Less than 10 years, or fewer than 11 points per year of span:
          only the raw series is plotted
        - Between 10 and 70 years: the raw series and its 12-month
          rolling average are plotted
        - 70 years or more: the 12-month and 10-year rolling averages are
          plotted, but not the raw series

        Parameters
        ----------
        cube: iris.cube.Cube
            Data to plot. Needs a ``year`` and a ``time`` coordinate.
        var_info: dict
            Variable information from ESMValTool
        period: str, optional (default: '')
            Appended to the plot type in the file name and recorded in
            the provenance.
        **kwargs:
            Passed on to :meth:`plot_cube`. ``xlimits`` defaults to
            ``'auto'``.
        """
        kwargs.setdefault('xlimits', 'auto')
        years = cube.coord("year")
        length = years.points.max() - years.points.min()
        filename = self.get_plot_path(f'timeseries{period}',
                                      var_info,
                                      add_ext=False)
        caption = ("{} of "
                   f"{var_info[names.LONG_NAME]} of dataset "
                   f"{var_info[names.DATASET]} (project "
                   f"{var_info[names.PROJECT]}) from "
                   f"{var_info[names.START_YEAR]} to "
                   f"{var_info[names.END_YEAR]}.")

        def _plot(data, target, title):
            """Plot ``data`` to ``target`` and record its provenance."""
            self.plot_cube(data, target, **kwargs)
            self.record_plot_provenance(
                self._add_file_extension(target),
                var_info,
                'timeseries',
                period=period,
                caption=caption.format(title),
            )

        def _plot_smoothed_12_months():
            _plot(cube.rolling_window('time', MEAN, 12),
                  f"{filename}_smoothed_12_months",
                  "Smoothed (12-months running mean) time series")

        if length < 10 or length * 11 > years.shape[0]:
            _plot(cube, filename, "Time series")
        elif length < 70:
            _plot(cube, filename, "Time series")
            # Restart the colour cycle before drawing the smoothed series.
            plt.gca().set_prop_cycle(None)
            _plot_smoothed_12_months()
        else:
            _plot_smoothed_12_months()
            _plot(cube.rolling_window('time', MEAN, 120),
                  f"{filename}_smoothed_10_years",
                  "Smoothed (10-years running mean) time series")

    def record_plot_provenance(self, filename, var_info, plot_type, **kwargs):
        """Write provenance info for a given file.

        Parameters
        ----------
        filename: str
            File the provenance record is logged for.
        var_info: dict
            Variable information from ESMValTool
        plot_type: str
            Name of the plot
        **kwargs:
            Extra entries added to the provenance record.
        """
        with ProvenanceLogger(self.cfg) as provenance_logger:
            prov = self.get_provenance_record(
                ancestor_files=[var_info['filename']],
                plot_type=plot_type,
                long_names=[var_info[names.LONG_NAME]],
                **kwargs,
            )
            provenance_logger.log(filename, prov)

    def plot_cube(self, cube, filename, linestyle='-', **kwargs):
        """Plot a timeseries from a cube.

        Supports multiplot layouts for cubes with extra dimensions
        `shape_id` or `region`.

        Parameters
        ----------
        cube: iris.cube.Cube
            Data to plot against its ``time`` coordinate.
        filename: str
            Plot path without extension.
        linestyle: str, optional (default: '-')
            Line style of the single-plot layout. Not passed to the
            multiplot layout.
        **kwargs:
            Passed on to the plotter.
        """
        plotter = PlotSeries()
        plotter.filefmt = self.cfg['output_file_type']
        plotter.img_template = filename
        for region_coord in ('shape_id', 'region'):
            if (cube.coords(region_coord)
                    and cube.coord(region_coord).shape[0] > 1):
                plotter.multiplot_cube(cube, 'time', region_coord, **kwargs)
                return
        plotter.plot_cube(cube, 'time', linestyle=linestyle, **kwargs)

    @staticmethod
    def get_provenance_record(ancestor_files, **kwargs):
        """Create provenance record for the diagnostic data and plots.

        Parameters
        ----------
        ancestor_files: list
            Stored under the ``ancestors`` key of the record.
        **kwargs:
            Extra entries merged into the record. They override the
            default entries on key clashes.
        """
        record = {
            'authors': [
                'vegas-regidor_javier',
            ],
            'references': [
                'acknow_project',
            ],
            'ancestors': ancestor_files,
            **kwargs
        }
        return record

    def get_plot_path(self, plot_type, var_info, add_ext=True):
        """Get plot full path from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        return os.path.join(
            self.get_plot_folder(var_info),
            self.get_plot_name(plot_type, var_info, add_ext=add_ext),
        )

    def get_plot_folder(self, var_info):
        """Get plot storage folder from variable info.

        The folder is created if it does not exist yet.

        Parameters
        ----------
        var_info: dict
            Variable information from ESMValTool
        """
        info = {
            'real_name': self._real_name(var_info['variable_group']),
            **var_info
        }
        folder = list(_replace_tags(self.plot_folder, info))[0]
        # Tag replacement strips the surrounding slashes; restore the root.
        if self.plot_folder.startswith('/'):
            folder = '/' + folder
        os.makedirs(folder, exist_ok=True)
        return folder

    def get_plot_name(self, plot_type, var_info, add_ext=True):
        """Get plot filename from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        info = {
            "plot_type": plot_type,
            'real_name': self._real_name(var_info['variable_group']),
            **var_info
        }
        file_name = list(_replace_tags(self.plot_filename, info))[0]
        if add_ext:
            file_name = self._add_file_extension(file_name)
        return file_name

    @staticmethod
    def _set_rasterized(axes=None):
        """Rasterize all artists and collections of the given axes.

        ``axes`` may be a single axes object or a list of them. The
        current axes are used when it is None.
        """
        if axes is None:
            axes = plt.gca()
        if not isinstance(axes, list):
            axes = [axes]
        for single_axes in axes:
            for artist in single_axes.artists:
                artist.set_rasterized(True)
            for collection in single_axes.collections:
                collection.set_rasterized(True)

    @staticmethod
    def _real_name(variable_group):
        """Strip the statistic suffixes from a variable group name.

        The suffixes ``Ymean``, ``Ysum``, ``mean`` and ``sum`` are checked
        in that order. Only the trailing occurrence is removed, so the
        same text elsewhere in the name is left untouched.
        """
        for subfix in ('Ymean', 'Ysum', 'mean', 'sum'):
            if variable_group.endswith(subfix):
                variable_group = variable_group[:-len(subfix)]
        return variable_group