""" TensorFlow SavedModel predictor base class. Loads and runs inference on TensorFlow SavedModel format. """ import logging import numpy as np try: import tensorflow as tf # Limit GPU memory growth gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True) except ImportError: tf = None logging.warning('TensorFlow not available') class TensorFlowPredictor: """Base class for TensorFlow SavedModel predictors.""" def __init__(self, model_path, device='gpu'): if tf is None: raise ImportError('TensorFlow is required for this predictor') self.device = device self.model = self._load_model(model_path) logging.info('TensorFlow SavedModel loaded: %s', model_path) def _load_model(self, model_path): """Load SavedModel from directory.""" return tf.saved_model.load(model_path) def preprocess(self, images): """ Preprocess images before inference. Override in subclass. """ raise NotImplementedError def postprocess(self, outputs): """ Postprocess model outputs. Override in subclass. """ raise NotImplementedError def predict(self, streams, **kwargs): """ Run prediction on input streams. Override in subclass. """ raise NotImplementedError class KerasPredictor: """Base class for Keras model predictors (for .h5 or SavedModel).""" def __init__(self, model_path, device='gpu'): if tf is None: raise ImportError('TensorFlow is required for this predictor') self.device = device self.model = self._load_model(model_path) logging.info('Keras model loaded: %s', model_path) def _load_model(self, model_path): """Load Keras model.""" return tf.keras.models.load_model(model_path, compile=False) def preprocess(self, images): raise NotImplementedError def postprocess(self, outputs): raise NotImplementedError def predict(self, streams, **kwargs): raise NotImplementedError