Source code for esmvaltool.diag_scripts.mlr.models.gpr_sklearn
"""Gaussian Process Regression model (using :mod:`sklearn`).

Use ``mlr_model_type: gpr_sklearn`` to use this MLR model in the recipe.
"""
# pylint: disable=arguments-differ
# (the ``predict`` override in this module intentionally uses a signature that
# differs from the one of the parent ``GaussianProcessRegressor.predict``)
import logging
import os
from sklearn.gaussian_process import GaussianProcessRegressor
from esmvaltool.diag_scripts.mlr.models import MLRModel
# Module-level logger, named after this file's basename (including extension)
logger = logging.getLogger(os.path.basename(__file__))
[docs]
class AdvancedGaussianProcessRegressor(GaussianProcessRegressor):
    """Expand :class:`sklearn.gaussian_process.GaussianProcessRegressor`.

    The only change is that :meth:`predict` can additionally return the
    variance of the predictive distribution (via ``return_var``).
    """

    def predict(self, x_data, return_var=False, return_cov=False,
                return_std=False):
        """Expand :meth:`predict` to accept ``return_var``.

        Parameters
        ----------
        x_data : array-like
            Query points where the Gaussian process is evaluated.
        return_var : bool, optional (default: False)
            If ``True``, return the variance of the predictive distribution
            (the square of the standard deviation computed by the parent
            class) together with the mean.
        return_cov : bool, optional (default: False)
            If ``True``, return the covariance of the joint predictive
            distribution together with the mean (forwarded to the parent).
        return_std : bool, optional (default: False)
            If ``True``, return the standard deviation of the predictive
            distribution together with the mean. Accepted so that callers
            using the standard
            :meth:`sklearn.gaussian_process.GaussianProcessRegressor.predict`
            keyword keep working.

        Returns
        -------
        numpy.ndarray or tuple
            Mean prediction, or a tuple ``(mean, var)``, ``(mean, std)`` or
            ``(mean, cov)`` depending on the requested additional output.

        Raises
        ------
        RuntimeError
            If both ``return_var`` and ``return_std`` are requested. Requesting
            ``return_cov`` together with one of them is rejected by the parent
            class.
        """
        if return_var and return_std:
            raise RuntimeError(
                "At most one of return_var and return_std can be requested"
            )

        # The parent only offers the standard deviation; request it whenever
        # either the standard deviation or the variance is wanted.
        pred = super().predict(
            x_data,
            return_std=(return_var or return_std),
            return_cov=return_cov,
        )
        if return_var:
            return (pred[0], pred[1] ** 2)
        return pred
[docs]
@MLRModel.register_mlr_model("gpr_sklearn")
class SklearnGPRModel(MLRModel):
    """Gaussian Process Regression model (:mod:`sklearn` implementation)."""

    _CLF_TYPE = AdvancedGaussianProcessRegressor

    def print_kernel_info(self):
        """Print information of the fitted kernel of the GPR model.

        Logs the fitted kernel and the fitted (log-transformed) value of every
        non-fixed hyperparameter.

        Note
        ----
        ``kernel.theta`` only contains the *non-fixed* hyperparameters, and a
        single hyperparameter may occupy several entries of it (e.g. an
        anisotropic length scale). ``theta`` therefore cannot be indexed by
        the position of a hyperparameter in ``kernel.hyperparameters``; the
        values are assigned by walking ``theta`` with a running offset
        instead.
        """
        self._check_fit_status("Printing kernel")

        # The GPR presumably sits inside the last pipeline step (a wrapper
        # exposing the fitted estimator as ``regressor_``); ``kernel_`` is the
        # kernel with fitted hyperparameters.
        kernel = self._clf.steps[-1][1].regressor_.kernel_
        logger.info("Fitted kernel: %s", kernel)
        logger.info("All fitted log-hyperparameters:")

        theta = kernel.theta
        offset = 0
        for hyper_param in kernel.hyperparameters:
            if hyper_param.fixed:
                # Fixed hyperparameters are not fitted and not part of theta
                continue
            n_elements = hyper_param.n_elements
            if n_elements == 1:
                value = theta[offset]
            else:
                value = theta[offset:offset + n_elements]
            offset += n_elements
            logger.info("%s: %s", hyper_param, value)