"""Base class for monitoring diagnostics."""
import logging
import os
import re
import cartopy
import matplotlib.pyplot as plt
import yaml
from iris.analysis import MEAN
from mapgenerator.plotting.timeseries import PlotSeries
from esmvaltool.diag_scripts.shared import ProvenanceLogger, names
logger = logging.getLogger(__name__)
def _replace_tags(paths, variable):
"""Replace tags in the config-developer's file with actual values."""
if isinstance(paths, str):
paths = set((paths.strip('/'), ))
else:
paths = set(path.strip('/') for path in paths)
tlist = set()
for path in paths:
tlist = tlist.union(re.findall(r'{([^}]*)}', path))
if 'sub_experiment' in variable:
new_paths = []
for path in paths:
new_paths.extend(
(re.sub(r'(\b{ensemble}\b)', r'{sub_experiment}-\1', path),
re.sub(r'({ensemble})', r'{sub_experiment}-\1', path)))
tlist.add('sub_experiment')
paths = new_paths
for tag in tlist:
original_tag = tag
tag, _, _ = _get_caps_options(tag)
if tag == 'latestversion': # handled separately later
continue
if tag in variable:
replacewith = variable[tag]
else:
raise ValueError(f"Dataset key '{tag}' must be specified for "
f"{variable}, check your recipe entry")
paths = _replace_tag(paths, original_tag, replacewith)
return paths
def _replace_tag(paths, tag, replacewith):
    """Replace the ``{tag}`` placeholder by ``replacewith`` in all paths.

    Parameters
    ----------
    paths: iterable of str
        Templates in which the placeholder is substituted.
    tag: str
        Placeholder name exactly as written in the templates, including
        an optional ``.lower`` / ``.upper`` modifier.
    replacewith: object or list or tuple
        Replacement value, converted with ``str``. For a list or tuple,
        every item is substituted in turn and the results are pooled.

    Returns
    -------
    list of str
        The substituted paths without duplicates, in no particular order.
    """
    _, lower, upper = _get_caps_options(tag)
    if isinstance(replacewith, (list, tuple)):
        substituted = [
            new_path
            for item in replacewith
            for new_path in _replace_tag(paths, tag, item)
        ]
    else:
        text = _apply_caps(str(replacewith), lower, upper)
        placeholder = '{' + tag + '}'
        substituted = [path.replace(placeholder, text) for path in paths]
    # Going through a set removes duplicates produced by the expansion.
    return list(set(substituted))
def _get_caps_options(tag):
lower = False
upper = False
if tag.endswith('.lower'):
lower = True
tag = tag[0:-6]
elif tag.endswith('.upper'):
upper = True
tag = tag[0:-6]
return tag, lower, upper
def _apply_caps(original, lower, upper):
if lower:
return original.lower()
if upper:
return original.upper()
return original
# [docs]  <- leftover "view source" link marker from the rendered HTML documentation; not code
class MonitorBase():
    """Base class for monitoring diagnostic.

    It contains the common methods for path creation, provenance
    recording, option parsing and to create some common plots.

    Parameters
    ----------
    config: dict
        Diagnostic configuration. The keys read here are the plot
        directory (``names.PLOT_DIR``, required) and the optional
        ``plot_folder``, ``plot_filename``, ``plots``,
        ``cartopy_data_dir`` and ``config_file``. ``output_file_type`` is
        read later, when file names are built.
    """

    def __init__(self, config):
        self.cfg = config

        # Folder template: ``{plot_dir}`` is resolved right away, the
        # remaining tags are filled per variable in ``get_plot_folder``.
        plot_folder = config.get(
            'plot_folder',
            '{plot_dir}/../../{dataset}/{exp}/{modeling_realm}/{real_name}',
        )
        plot_folder = plot_folder.replace('{plot_dir}',
                                          self.cfg[names.PLOT_DIR])
        self.plot_folder = os.path.abspath(
            os.path.expandvars(os.path.expanduser(plot_folder)))

        # File name template, filled per variable in ``get_plot_name``.
        self.plot_filename = config.get(
            'plot_filename',
            '{plot_type}_{real_name}_{dataset}_{mip}_{exp}_{ensemble}')
        self.plots = config.get('plots', {})

        cartopy_data_dir = config.get('cartopy_data_dir', None)
        if cartopy_data_dir:
            cartopy.config['data_dir'] = cartopy_data_dir

        # Plot options come from a YAML file: either the one given by the
        # user or the ``monitor_config.yml`` located next to this module.
        default_config = os.path.join(os.path.dirname(__file__),
                                      "monitor_config.yml")
        with open(config.get('config_file', default_config),
                  encoding='utf-8') as config_file:
            self.config = yaml.safe_load(config_file)

    def _add_file_extension(self, filename):
        """Add extension to plot filename."""
        return f"{filename}.{self.cfg['output_file_type']}"

    def _get_proj_options(self, map_name):
        """Return the ``maps`` entry of the plot configuration for a map."""
        return self.config['maps'][map_name]

    def _get_variable_options(self, variable_group, map_name):
        """Return the plot options of a variable group for a given map.

        The options of ``variable_group`` are looked up in the
        ``variables`` section of the plot configuration, falling back to
        its ``default`` entry. If the selected entry has a ``default``
        key, that sub-entry is the starting point, otherwise the entry
        itself is; the options stored under ``map_name`` (if any) are
        merged on top. Bounds that are not given as a string are converted
        to a list of floats.

        The result is a new dictionary: the loaded configuration is not
        modified.
        """
        variables = self.config['variables']
        options = variables.get(variable_group, variables['default'])
        base_options = options['default'] if 'default' in options else options
        # Work on a copy so that the float conversion below (and whatever
        # the caller does with the result) never leaks into ``self.config``.
        variable_options = dict(base_options)
        if map_name in options:
            variable_options.update(options[map_name])
        if 'bounds' in variable_options:
            if not isinstance(variable_options['bounds'], str):
                variable_options['bounds'] = [
                    float(n) for n in variable_options['bounds']
                ]
        logger.debug(variable_options)
        return variable_options

    def plot_timeseries(self, cube, var_info, period='', **kwargs):
        """Plot timeseries from a cube.

        What is plotted depends on the time span of the cube, measured as
        the difference between the largest and the smallest value of its
        ``year`` coordinate:

        - Shorter than 10 years, or fewer than 11 points per year of time
          span (presumably data that is not monthly): only the raw series.
        - From 10 to less than 70 years: the raw series and, in a separate
          file, its 12-point rolling mean.
        - 70 years or more: the 12-point and the 120-point rolling means
          (12 months and 10 years for monthly data), not the raw series.

        Every plot is registered in the provenance.

        Parameters
        ----------
        cube: iris.cube.Cube
            Data to plot. Needs a ``year`` and a ``time`` coordinate.
        var_info: dict
            Variable information from ESMValTool
        period: str, optional (default: '')
            Appended to ``timeseries`` to form the plot type used in the
            file name; also stored in the provenance record.
        **kwargs
            Passed to :meth:`plot_cube`. ``xlimits`` is set to ``'auto'``
            when not given.
        """
        kwargs.setdefault('xlimits', 'auto')
        year_coord = cube.coord("year")
        length = year_coord.points.max() - year_coord.points.min()
        filename = self.get_plot_path(f'timeseries{period}', var_info,
                                      add_ext=False)
        # The caption is assembled by concatenation instead of
        # ``str.format`` so that braces inside the metadata (e.g. in a long
        # name) cannot be mistaken for format fields.
        caption_tail = (f" of {var_info[names.LONG_NAME]} of dataset "
                        f"{var_info[names.DATASET]} (project "
                        f"{var_info[names.PROJECT]}) from "
                        f"{var_info[names.START_YEAR]} to "
                        f"{var_info[names.END_YEAR]}.")
        raw_caption = "Time series" + caption_tail
        smoothed_12_caption = (
            "Smoothed (12-months running mean) time series" + caption_tail)
        smoothed_120_caption = (
            "Smoothed (10-years running mean) time series" + caption_tail)

        if length < 10 or length * 11 > year_coord.shape[0]:
            self._plot_and_record(cube, filename, var_info, period,
                                  raw_caption, kwargs)
        elif length < 70:
            self._plot_and_record(cube, filename, var_info, period,
                                  raw_caption, kwargs)
            # Restart the colour cycle of the current axes before drawing
            # the smoothed series.
            plt.gca().set_prop_cycle(None)
            self._plot_and_record(cube.rolling_window('time', MEAN, 12),
                                  f"{filename}_smoothed_12_months",
                                  var_info, period, smoothed_12_caption,
                                  kwargs)
        else:
            self._plot_and_record(cube.rolling_window('time', MEAN, 12),
                                  f"{filename}_smoothed_12_months",
                                  var_info, period, smoothed_12_caption,
                                  kwargs)
            self._plot_and_record(cube.rolling_window('time', MEAN, 120),
                                  f"{filename}_smoothed_10_years",
                                  var_info, period, smoothed_120_caption,
                                  kwargs)

    def _plot_and_record(self, cube, filename, var_info, period, caption,
                         plot_kwargs):
        """Plot one timeseries and write the provenance of the plot.

        ``filename`` comes without extension; ``plot_kwargs`` is the
        dictionary of keyword arguments handed to :meth:`plot_cube`.
        """
        self.plot_cube(cube, filename, **plot_kwargs)
        self.record_plot_provenance(
            self._add_file_extension(filename),
            var_info,
            'timeseries',
            period=period,
            caption=caption,
        )

    def record_plot_provenance(self, filename, var_info, plot_type, **kwargs):
        """Write provenance info for a given file.

        Parameters
        ----------
        filename: str
            File the provenance record is attached to.
        var_info: dict
            Variable information from ESMValTool. Its ``filename`` entry
            is recorded as the only ancestor and its long name as
            ``long_names``.
        plot_type: str
            Plot type stored in the record.
        **kwargs
            Additional entries for the provenance record.
        """
        with ProvenanceLogger(self.cfg) as provenance_logger:
            prov = self.get_provenance_record(
                ancestor_files=[var_info['filename']],
                plot_type=plot_type,
                long_names=[var_info[names.LONG_NAME]],
                **kwargs,
            )
            provenance_logger.log(filename, prov)

    def plot_cube(self, cube, filename, linestyle='-', **kwargs):
        """Plot a timeseries from a cube.

        Supports multiplot layouts for cubes with extra dimensions
        `shape_id` or `region`.

        Parameters
        ----------
        cube: iris.cube.Cube
            Data to plot along its ``time`` coordinate.
        filename: str
            Used as image template of the plotter; the file format is
            taken from the ``output_file_type`` configuration entry.
        linestyle: str, optional (default: '-')
            Line style of the single plot. It is not forwarded to the
            multiplot layout.
        **kwargs
            Passed to the plotter.
        """
        plotter = PlotSeries()
        plotter.filefmt = self.cfg['output_file_type']
        plotter.img_template = filename

        # The first of these coordinates holding more than one point
        # selects the multiplot layout, with one panel set per point.
        region_coords = ('shape_id', 'region')
        for region_coord in region_coords:
            if cube.coords(region_coord):
                if cube.coord(region_coord).shape[0] > 1:
                    plotter.multiplot_cube(cube, 'time', region_coord,
                                           **kwargs)
                    return
        plotter.plot_cube(cube, 'time', linestyle=linestyle, **kwargs)

    @staticmethod
    def get_provenance_record(ancestor_files, **kwargs):
        """Create provenance record for the diagnostic data and plots.

        Parameters
        ----------
        ancestor_files: list
            Files stored under the ``ancestors`` key.
        **kwargs
            Extra entries of the record. They are applied last, so they
            take precedence over the default entries of the same name.

        Returns
        -------
        dict
            The provenance record.
        """
        record = {
            'authors': [
                'vegas-regidor_javier',
            ],
            'references': [
                'acknow_project',
            ],
            'ancestors': ancestor_files,
            **kwargs
        }
        return record

    def get_plot_path(self, plot_type, var_info, add_ext=True):
        """Get plot full path from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        return os.path.join(
            self.get_plot_folder(var_info),
            self.get_plot_name(plot_type, var_info, add_ext=add_ext),
        )

    def get_plot_folder(self, var_info):
        """Get plot storage folder from variable info.

        The folder is created if it does not exist yet.

        Parameters
        ----------
        var_info: dict
            Variable information from ESMValTool
        """
        info = {
            'real_name': self._real_name(var_info['variable_group']),
            **var_info
        }
        # NOTE(review): when the template expands to several candidates
        # their order is not guaranteed (they pass through a set); only the
        # first one is used.
        folder = list(_replace_tags(self.plot_folder, info))[0]
        # The tag replacement strips the slashes at both ends of the
        # template, so the root of an absolute path has to be restored.
        if self.plot_folder.startswith('/'):
            folder = '/' + folder
        os.makedirs(folder, exist_ok=True)
        return folder

    def get_plot_name(self, plot_type, var_info, add_ext=True):
        """Get plot filename from variable info.

        Parameters
        ----------
        plot_type: str
            Name of the plot
        var_info: dict
            Variable information from ESMValTool
        add_ext: bool, optional (default: True)
            Add filename extension from configuration file.
        """
        info = {
            "plot_type": plot_type,
            'real_name': self._real_name(var_info['variable_group']),
            **var_info
        }
        # NOTE(review): as for the folder, only the first expanded
        # candidate is used and the candidates carry no guaranteed order.
        file_name = list(_replace_tags(self.plot_filename, info))[0]
        if add_ext:
            file_name = self._add_file_extension(file_name)
        return file_name

    @staticmethod
    def _set_rasterized(axes=None):
        """Rasterize all artists and collection of axes if desired.

        Parameters
        ----------
        axes: matplotlib axes or list of axes, optional
            Axes to process. The current axes are used when omitted.
        """
        if axes is None:
            axes = plt.gca()
        if not isinstance(axes, list):
            axes = [axes]
        for single_axes in axes:
            for artist in single_axes.artists:
                artist.set_rasterized(True)
            for collection in single_axes.collections:
                collection.set_rasterized(True)

    @staticmethod
    def _real_name(variable_group):
        """Strip the statistic suffixes from a variable group name.

        The suffixes ``Ymean``, ``Ysum``, ``mean`` and ``sum`` are checked
        in that order and removed from the end of the name. Text equal to a
        suffix elsewhere in the name is kept.
        """
        for suffix in ('Ymean', 'Ysum', 'mean', 'sum'):
            if variable_group.endswith(suffix):
                # Slice instead of ``str.replace``: the latter would also
                # delete occurrences in the middle of the name.
                variable_group = variable_group[:-len(suffix)]
        return variable_group