Source code for esmvaltool.diag_scripts.mlr.models.gpr_sklearn

"""Gaussian Process Regression model (using :mod:`sklearn`).

Use ``mlr_model_type: gpr_sklearn`` to use this MLR model in the recipe.

"""

# pylint: disable=arguments-differ

import logging
import os

from sklearn.gaussian_process import GaussianProcessRegressor

from esmvaltool.diag_scripts.mlr.models import MLRModel

# Name the module logger after this file's base name (e.g. "gpr_sklearn.py"),
# which is the naming convention used for loggers of MLR model modules.
logger = logging.getLogger(os.path.split(__file__)[1])


class AdvancedGaussianProcessRegressor(GaussianProcessRegressor):
    """Expand :class:`sklearn.gaussian_process.GaussianProcessRegressor`.

    The only addition is a ``return_var`` option for :meth:`predict`, which
    returns the predictive variance instead of the standard deviation.

    """

    def predict(self, x_data, return_var=False, return_cov=False):
        """Expand :meth:`predict` to accept ``return_var``.

        Parameters
        ----------
        x_data : array-like
            Query points, forwarded unchanged to the parent ``predict``.
        return_var : bool, optional (default: False)
            If ``True``, the parent is asked for the standard deviation
            (``return_std=True``) and its square is returned instead.
        return_cov : bool, optional (default: False)
            Forwarded unchanged to the parent ``predict``.

        Returns
        -------
        object
            Exactly what the parent ``predict`` returns, except when
            ``return_var`` is set: then the tuple ``(mean, std**2)``.

        """
        result = super().predict(x_data,
                                 return_std=return_var,
                                 return_cov=return_cov)
        if not return_var:
            # Mean only, or (mean, cov) when return_cov was requested
            return result
        mean, std = result[0], result[1]
        return (mean, std**2)
@MLRModel.register_mlr_model('gpr_sklearn')
class SklearnGPRModel(MLRModel):
    """Gaussian Process Regression model (:mod:`sklearn` implementation)."""

    _CLF_TYPE = AdvancedGaussianProcessRegressor

    def print_kernel_info(self):
        """Print information of the fitted kernel of the GPR model.

        Logs the fitted kernel and, for every non-fixed hyperparameter, its
        fitted value(s) on log-scale as stored in ``kernel.theta``.

        Note
        ----
        ``kernel.theta`` only contains the *non-fixed* hyperparameters, and
        hyperparameters with more than one element (e.g. an anisotropic
        length scale) are flattened into it. Entries of ``theta`` are
        therefore consumed according to ``n_elements`` of each
        hyperparameter, and fixed hyperparameters are skipped. This keeps
        every logged value attached to the hyperparameter it belongs to.

        """
        self._check_fit_status('Printing kernel')
        kernel = self._clf.steps[-1][1].regressor_.kernel_
        logger.info("Fitted kernel: %s", kernel)
        logger.info("All fitted log-hyperparameters:")
        theta = kernel.theta
        start = 0
        for hyper_param in kernel.hyperparameters:
            # Fixed hyperparameters are not part of theta (and not fitted)
            if hyper_param.fixed:
                continue
            stop = start + hyper_param.n_elements
            values = theta[start:stop]
            start = stop
            # Keep the scalar output format for one-element hyperparameters
            value = values[0] if hyper_param.n_elements == 1 else values
            logger.info("%s: %s", hyper_param, value)