Supporting a New Model

Learn how to add a new model.

Supporting a new model

Let’s see what the code for the AutoMPG regressor module looks like.

Press + to interact
from typing import TYPE_CHECKING, Dict, Tuple

from joblib import load, dump
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score

from ml_pipeline.mixins.reporting_mixin import ReportingMixin
from ml_pipeline.mixins.training_mixin import TrainingMixin
from ml_pipeline.model import Model

if TYPE_CHECKING:
    # Only needed for (string) type annotations; not imported at runtime.
    import logging

    import numpy as np
    import pandas as pd
    from omegaconf import DictConfig
class AutoMPGRegressor(TrainingMixin, Model, ReportingMixin):
    """Regressor for the AutoMPG data, backed by scikit-learn's ``LinearRegression``.

    The class plugs a ``LinearRegression`` estimator into the pipeline's
    ``TrainingMixin`` / ``Model`` / ``ReportingMixin`` interfaces. It performs no
    feature or target encoding, and it reports mean squared error and the
    coefficient of determination (R^2).
    """

    def __init__(
        self,
        model_params: "DictConfig",
        training_params: "DictConfig",
        artifact_dir: str,
        logger: "logging.Logger" = None,
    ) -> None:
        """Build the underlying estimator and store the run configuration.

        Args:
            model_params: Keyword arguments forwarded to ``LinearRegression``.
            training_params: Training configuration; stored as-is on the
                instance (not read by this class itself).
            artifact_dir: Directory that ``save`` writes ``model.joblib`` into.
            logger: Optional logger; when ``None``, ``save`` logs nothing.
        """
        self.model = LinearRegression(**model_params)
        self.training_params = training_params
        self.artifact_dir = artifact_dir
        self.logger = logger

    def load(self, model_path: str) -> None:
        """Replace ``self.model`` with the estimator stored at ``model_path``.

        Inside this method, ``load`` resolves to the module-level joblib
        function, not to this method.
        NOTE(review): joblib files are pickle-based; only load trusted files.
        """
        self.model = load(model_path)

    def _encode_train_data(
        self, X: "pd.DataFrame" = None, y: "pd.Series" = None
    ) -> Tuple["pd.DataFrame", "pd.Series"]:
        """Return the training features and target unchanged."""
        # in this example, we don't do any encoding
        return X, y

    def _encode_test_data(
        self, X: "pd.DataFrame" = None, y: "pd.Series" = None
    ) -> Tuple["pd.DataFrame", "pd.Series"]:
        """Return the test features and target unchanged."""
        # in this example, we don't do any encoding
        return X, y

    def _compute_metrics(
        self, y_true: "pd.Series", y_pred: "pd.Series"
    ) -> Dict:
        """Compute regression metrics, store them on ``self.metrics``, and return them.

        Args:
            y_true: Ground-truth target values.
            y_pred: Predicted target values.

        Returns:
            A dict with the keys ``"mean_squared_error"`` and ``"r2_score"``.
        """
        self.metrics = {
            "mean_squared_error": mean_squared_error(y_true, y_pred),
            "r2_score": r2_score(y_true, y_pred),
        }
        # The signature promises a Dict, so return the metrics instead of None.
        return self.metrics

    def create_report(self) -> None:
        """Persist the computed metrics via ``save_metrics``.

        ``save_metrics`` is not defined in this class; presumably it is
        provided by ``ReportingMixin`` -- TODO confirm.
        """
        self.save_metrics()

    def save(self) -> None:
        """Serialize the estimator to ``<artifact_dir>/model.joblib``."""
        filename = f"{self.artifact_dir}/model.joblib"
        dump(self.model, filename)
        # ``logger`` defaults to None in __init__, so logging must be optional.
        if self.logger is not None:
            self.logger.debug(f"Saved {filename}.")

    def predict(self, X: "pd.DataFrame") -> "np.ndarray":
        """Return the estimator's predictions for the rows of ``X``."""
        return self.model.predict(X)

If we compare this to the iris classifier module, we see similarities, mainly because, like IrisClassifier, this class derives from TrainingMixin, Model, and ReportingMixin. We also see several differences:

  • In line 26, we see that this model uses the LinearRegression class from scikit-learn instead of the LogisticRegression class we used for iris classification.

  • Iris classification is a classification problem, so it required classification-related metrics, such as accuracy and the confusion matrix. This project, however, is a regression problem, so it computes regression-related metrics, such as mean squared error and the coefficient of determination. ...