Supporting a New Model
Learn how to add a new model.
Let’s see what the code for the AutoMPG regressor module looks like.
```python
from typing import TYPE_CHECKING, Dict, Optional, Tuple

from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score
from joblib import load, dump

if TYPE_CHECKING:
    # Needed only for type hints, so they are not imported at runtime.
    import logging
    import numpy as np
    import pandas as pd
    from omegaconf import DictConfig

from ml_pipeline.mixins.reporting_mixin import ReportingMixin
from ml_pipeline.mixins.training_mixin import TrainingMixin
from ml_pipeline.model import Model


class AutoMPGRegressor(TrainingMixin, Model, ReportingMixin):
    def __init__(
        self,
        model_params: "DictConfig",
        training_params: "DictConfig",
        artifact_dir: str,
        logger: Optional["logging.Logger"] = None,
    ) -> None:
        self.model = LinearRegression(**model_params)
        self.training_params = training_params
        self.artifact_dir = artifact_dir
        self.logger = logger

    def load(self, model_path: str) -> None:
        # Inside the method, `load` resolves to joblib's module-level function.
        self.model = load(model_path)

    def _encode_train_data(
        self, X: "pd.DataFrame" = None, y: "pd.Series" = None
    ) -> Tuple["pd.DataFrame", "pd.Series"]:
        # in this example, we don't do any encoding
        return X, y

    def _encode_test_data(
        self, X: "pd.DataFrame" = None, y: "pd.Series" = None
    ) -> Tuple["pd.DataFrame", "pd.Series"]:
        # in this example, we don't do any encoding
        return X, y

    def _compute_metrics(self, y_true: "pd.Series", y_pred: "pd.Series") -> Dict:
        self.metrics = {}
        self.metrics["mean_squared_error"] = mean_squared_error(y_true, y_pred)
        self.metrics["r2_score"] = r2_score(y_true, y_pred)
        # Return the metrics so the method matches its Dict annotation.
        return self.metrics

    def create_report(self) -> None:
        self.save_metrics()

    def save(self) -> None:
        filename = f"{self.artifact_dir}/model.joblib"
        dump(self.model, filename)
        # The logger is optional, so only log when one was provided.
        if self.logger is not None:
            self.logger.debug(f"Saved {filename}.")

    def predict(self, X: "pd.DataFrame") -> "np.ndarray":
        # LinearRegression.predict returns an array with one value per row.
        return self.model.predict(X)
```
If we compare this to the iris classifier module, we see similarities, mainly because, like IrisClassifier, this class derives from TrainingMixin, Model, and ReportingMixin. We also see several differences:
- This model uses the LinearRegression class from scikit-learn instead of the LogisticRegression class we used for iris classification (see the self.model assignment in __init__).
- Iris classification is a classification problem, so it required classification metrics, such as accuracy and the confusion matrix. This project, however, computes regression metrics, such as the mean squared error and the coefficient of determination. ...
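To make the regression metrics concrete, here is a minimal standalone sketch that fits a LinearRegression model and computes the same two values that _compute_metrics stores. The synthetic data and variable names are assumptions made for illustration. The sketch uses neither the Auto MPG dataset nor the ml_pipeline package.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

# Synthetic stand-in data (an assumption; not the Auto MPG dataset):
# a linear target with a small amount of Gaussian noise.
rng = np.random.default_rng(0)
X = rng.uniform(0.0, 10.0, size=(100, 2))
y = 3.0 * X[:, 0] - 2.0 * X[:, 1] + 5.0 + rng.normal(0.0, 0.1, size=100)

# Fit the regressor and predict on the same data.
model = LinearRegression().fit(X, y)
y_pred = model.predict(X)

# The same two regression metrics the class computes.
metrics = {
    "mean_squared_error": mean_squared_error(y, y_pred),
    "r2_score": r2_score(y, y_pred),
}
print(metrics)
```

A lower mean squared error means the predictions sit closer to the true values, and an r2_score near 1 means the model explains most of the variance in the target. Accuracy and a confusion matrix would not be meaningful here because the predictions are continuous values, not class labels.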