Source code for src.serialize

"""
Serialize and deserialize trained model pipelines.
"""
import logging
import os

import joblib
import sklearn.pipeline

from src import load_data

logger = logging.getLogger(__name__)


[docs]def save_pipeline(pipeline: sklearn.pipeline.Pipeline, save_path: str) -> None: """ Serialize a fitted model pipeline. Args: pipeline (:obj:`sklearn.pipeline.Pipeline`): Fitted model pipeline save_path (str): Where to save the pipeline Returns: None """ # Saving to S3 requires some extra care compared to a local directory, # so first save to a local directory and then upload in a separate step. # Put the local copy in the same place it would have gone inside S3. if save_path.startswith("s3://"): _, s3path = load_data.parse_s3(save_path) local_path = s3path joblib.dump(pipeline, local_path) logger.debug("Saved a copy of the model to %s", local_path) load_data.upload_file_to_s3(local_path=local_path, s3path=save_path) else: joblib.dump(pipeline, save_path) logger.info("Saved model to %s", save_path)
[docs]def load_pipeline(load_path: str) -> sklearn.pipeline.Pipeline: """ Deserialize a fitted model pipeline. Args: load_path (str): Path to joblib-saved pipeline Returns: Fitted :obj:`sklearn.pipeline.Pipeline` object """ # Download from S3 if a local copy does not already exist # This helps improve inference speed by reducing unnecessary # I/O and network calls if load_path.startswith("s3://"): _, s3path = load_data.parse_s3(load_path) local_path = s3path if not os.path.exists(local_path): load_data.download_file_from_s3(local_path=local_path, s3path=load_path) logger.debug("Downloaded a copy of the model to %s", local_path) else: logger.debug("Using existing local copy of model at %s", local_path) pipeline = joblib.load(local_path) else: pipeline = joblib.load(load_path) logger.info("Loaded model pipeline from %s", load_path) return pipeline