I've built a scikit-learn pipeline whose last step is a Keras LSTM model wrapped in keras.wrappers.scikit_learn.KerasClassifier. Once the pipeline finishes training, I save the whole pipeline to disk (see below). I'm having trouble loading the pipeline back into memory and then making predictions with it. scikit-learn pipelines and Keras models don't seem to play well together at the moment, which is making things tricky. Does anyone have experience with this?

Versions: tensorflow 2.3.1, keras 2.4.3, scikit-learn 0.23.2

Code:

import pandas as pd
from model_lstm.config import config
import joblib
import keras
from keras.wrappers.scikit_learn import KerasClassifier
from model_lstm.utils import data_management as dm

def save_fitted_pipeline(pipeline):
    model_path = config.TRAINED_MODEL_DIR / config.TRAINED_MODEL_FILE
    pipeline_path = config.TRAINED_MODEL_DIR / config.TRAINED_PIPELINE_FILE
    pipeline.named_steps["lstm_model"].model.save(model_path)
    pipeline.named_steps["lstm_model"].model = None
    joblib.dump(pipeline, pipeline_path)

def load_fitted_pipeline():
    model_path = config.TRAINED_MODEL_DIR / config.TRAINED_MODEL_FILE
    pipeline_path = config.TRAINED_MODEL_DIR / config.TRAINED_PIPELINE_FILE
    pipeline = joblib.load(pipeline_path)
    model_func = lambda: keras.models.load_model(model_path)
    wrapped_model = KerasClassifier(build_fn=model_func)
    pipeline.named_steps["lstm_model"] = wrapped_model
    pipeline.named_steps["lstm_model"].model = keras.models.load_model(model_path)
    return pipeline

def predict():
    lstm_pipeline = load_fitted_pipeline()
    data_path = config.DATA_DIR / config.TRAINING_DATA_FILE
    X_train, y_train = dm.load_data(data_path)
    pred = lstm_pipeline.predict(X_train)

Current error:

../model_lstm/predict.py:8: in predict
    pred = lstm_pipeline.predict(X_train)
../../../../anaconda/envs/sa_model_lstm/lib/python3.7/site-packages/sklearn/utils/metaestimators.py:119: in <lambda>
    out = lambda *args, **kwargs: self.fn(obj, *args, **kwargs)
../../../../anaconda/envs/sa_model_lstm/lib/python3.7/site-packages/sklearn/pipeline.py:408: in predict
    return self.steps[-1][-1].predict(Xt, **predict_params)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <tensorflow.python.keras.wrappers.scikit_learn.KerasClassifier object at 0x1a4256f690>
x = array([[   0,    0,    0, ...,  125,  309,  310],
       [   0,    0,    0, ...,   19,    3,  312],
       [   0,    0...076],
       [   0,    0,    0, ...,    2, 1077,   13],
       [   0,    0,    0, ..., 1080,  160, 1081]], dtype=int32)
kwargs = {'batch_size': 128, 'verbose': 1}

    def predict(self, x, **kwargs):
      """Returns the class predictions for the given test data.
    
      Arguments:
          x: array-like, shape `(n_samples, n_features)`
              Test samples where `n_samples` is the number of samples
              and `n_features` is the number of features.
          **kwargs: dictionary arguments
              Legal arguments are the arguments
              of `Sequential.predict_classes`.
    
      Returns:
          preds: array-like, shape `(n_samples,)`
              Class predictions.
      """
      kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
>     classes = self.model.predict_classes(x, **kwargs)
E     AttributeError: 'Functional' object has no attribute 'predict_classes'

../../../../anaconda/envs/sa_model_lstm/lib/python3.7/site-packages/tensorflow/python/keras/wrappers/scikit_learn.py:241: AttributeError
Swazy

1 Answer

This is my question. For those who are interested, I've managed to figure this out myself. It turns out the issue had nothing to do with writing the pipeline to disk and reading it back. The keras.wrappers.scikit_learn.KerasClassifier wrapper only seems to work correctly when your Keras model is an instance of Sequential, not Model as it was in my case: the wrapper's predict calls self.model.predict_classes, and the traceback shows that a 'Functional' object has no such attribute. I converted my model to Sequential and everything worked fine. In fact, the save and load logic became rather simpler than what I have shown in the code above.
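For anyone who cannot easily convert their model to Sequential, here is a minimal sketch of the class-picking rule that, as far as I can tell, Sequential.predict_classes applies to the model's raw output: argmax over a multi-unit output, or a 0.5 threshold on a single unit. This is not what I ended up doing (I converted to Sequential). The helper name classes_from_probabilities and the sample numbers are made up for illustration, and it is written in plain Python so it runs without TensorFlow installed.

```python
def classes_from_probabilities(proba):
    """Map each row of model output to a class index.

    Assumed to mirror Sequential.predict_classes: argmax over a
    multi-unit (softmax) output, or a 0.5 threshold on a single
    (sigmoid) unit. `proba` is any 2-D sequence of rows, e.g. the
    array returned by model.predict(X).
    """
    classes = []
    for row in proba:
        row = list(row)
        if len(row) > 1:
            # Multi-class: index of the largest probability.
            classes.append(max(range(len(row)), key=row.__getitem__))
        else:
            # Binary with one output unit: threshold at 0.5.
            classes.append(int(row[0] > 0.5))
    return classes


# Hypothetical softmax output for three samples and three classes.
softmax_output = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5]]
print(classes_from_probabilities(softmax_output))  # prints [1, 0, 2]
```

With a functional model you would pass the result of model.predict(X) to this helper instead of calling predict_classes. The returned values are positional class indices, so they may still need mapping back to your original labels.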

Swazy