# Source code for esmvaltool.diag_scripts.mlr.models.linear_base
"""Base class for linear Machine Learning Regression models."""
import logging
import os
import matplotlib.pyplot as plt
import numpy as np
from esmvaltool.diag_scripts import mlr
from esmvaltool.diag_scripts.mlr.models import MLRModel
# Module-level logger, named after the basename of this file.
logger = logging.getLogger(os.path.basename(__file__))
# [docs]
class LinearModel(MLRModel):
    """Base class for linear Machine Learning models.

    Extends :class:`MLRModel` with plots that are based on the linear
    coefficients of the fitted regressor (``self._clf.coef_``).

    """

    # No regressor type is set at this level (presumably concrete subclasses
    # override this attribute -- TODO confirm against the subclasses).
    _CLF_TYPE = None

    def plot_coefs(self, filename=None):
        """Plot linear coefficients of models.

        The coefficients are drawn as horizontal bars in ascending order on
        an x-axis that is symmetric around zero. The plot is saved to
        ``<mlr_plot_dir>/<filename>.<output_file_type>`` (both read from
        ``self._cfg``) and provenance information is written for it.
        Nothing is done if ``self._is_ready_for_plotting()`` is falsy.

        Note
        ----
        The features plotted here are not necessarily the real input features,
        but the ones after preprocessing.

        Parameters
        ----------
        filename : str, optional (default: 'coefs')
            Name of the plot file.

        """
        if not self._is_ready_for_plotting():
            return
        logger.info("Plotting linear coefficients")
        if filename is None:
            filename = 'coefs'
        (_, axes) = plt.subplots()

        # Plot (bars sorted by ascending coefficient value)
        coefs = self._clf.coef_
        sorted_idx = np.argsort(coefs)
        sorted_coefs = coefs[sorted_idx]
        pos = np.arange(sorted_idx.shape[0]) + 0.5
        axes.barh(pos, sorted_coefs, align='center')

        # Plot appearance
        axes.tick_params(axis='y', which='minor', left=False, right=False)
        axes.tick_params(axis='y', which='major', left=True, right=False)
        y_tick_labels = self.features_after_preprocessing[sorted_idx]
        title = f"Linear coefficients ({self._cfg['mlr_model_name']})"
        axes.set_title(title)
        axes.set_yticks(pos)
        axes.set_yticklabels(y_tick_labels)

        # Make the x-axis symmetric around zero so that positive and negative
        # coefficients can be compared visually
        x_max = np.max(np.abs(axes.get_xlim()))
        axes.set_xlim(-x_max, x_max)
        axes.axvline(0.0, color='k')

        # Save plot
        new_filename = filename + '.' + self._cfg['output_file_type']
        plot_path = os.path.join(self._cfg['mlr_plot_dir'], new_filename)
        plt.savefig(plot_path, **self._cfg['savefig_kwargs'])
        logger.info("Wrote %s", plot_path)
        plt.close()

        # Save provenance
        cube = mlr.get_1d_cube(
            y_tick_labels,
            sorted_coefs,
            x_kwargs={'var_name': 'feature',
                      'long_name': 'Feature name',
                      'units': 'no unit'},
            y_kwargs={'var_name': 'coef',
                      'long_name': '(Normalized) Linear Coefficients',
                      'units': '1',
                      'attributes': {'project': '', 'dataset': ''}},
        )
        self._write_plot_provenance(
            cube, plot_path, ancestors=self.get_ancestors(prediction_names=[]),
            caption=title + '.', plot_types=['bar'])

    def plot_feature_importance(self, filename=None, color_coded=True):
        """Plot feature importance given by linear coefficients.

        The importance of a feature is the absolute value of its linear
        coefficient divided by the sum of the absolute values of all
        coefficients. If all coefficients are exactly zero, all importances
        are set to zero (instead of the NaNs that ``0 / 0`` would give).
        The plot is written to
        ``<mlr_plot_dir>/<filename>.<output_file_type>`` (both read from
        ``self._cfg``) by ``self._plot_feature_importance``. Nothing is done
        if ``self._is_ready_for_plotting()`` is falsy.

        Note
        ----
        The features plotted here are not necessarily the real input features,
        but the ones after preprocessing.

        Parameters
        ----------
        filename : str, optional (default: 'feature_importance')
            Name of the plot file.
        color_coded : bool, optional (default: True)
            If ``True``, mark positive (linear) correlations with red bars and
            negative (linear) correlations with blue bars. If ``False``, all
            bars are blue.

        """
        if not self._is_ready_for_plotting():
            return

        # Get plot path
        if filename is None:
            filename = 'feature_importance'
        new_filename = filename + '.' + self._cfg['output_file_type']
        plot_path = os.path.join(self._cfg['mlr_plot_dir'], new_filename)

        # Get feature importance dictionary and colors for bars
        abs_coefs = np.abs(self._clf.coef_)
        total = np.sum(abs_coefs)
        if total == 0.0:
            # Only the exact all-zero case is special-cased; a NaN total
            # still propagates through the division below as before
            logger.warning(
                "All linear coefficients are zero, setting all feature "
                "importances to zero")
            feature_importances = np.zeros_like(abs_coefs, dtype=float)
        else:
            feature_importances = abs_coefs / total
        feature_importance_dict = dict(zip(self.features_after_preprocessing,
                                           feature_importances))
        colors = self._get_colors_for_features(color_coded=color_coded)

        # Plot
        self._plot_feature_importance(feature_importance_dict, colors,
                                      plot_path)