Spaces:
Running
Running
| """ | |
| TensorFlow SavedModel predictor base class. | |
| Loads and runs inference on TensorFlow SavedModel format. | |
| """ | |
| import logging | |
| import numpy as np | |
try:
    import tensorflow as tf
except ImportError:
    # TensorFlow is optional at import time; the predictor constructors
    # raise ImportError when they find tf is None.
    tf = None
    logging.warning('TensorFlow not available')
else:
    # Enable on-demand GPU memory growth so TensorFlow does not reserve all
    # GPU memory up front.
    # set_memory_growth() raises RuntimeError once the runtime has already
    # initialized the GPUs (e.g. TensorFlow was used before this module was
    # imported). That must not make importing this module fail, so log it
    # and continue with TensorFlow's current allocation behaviour.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as err:
            logging.warning(
                'Could not enable memory growth for %s: %s', gpu, err)
class TensorFlowPredictor:
    """Base class for predictors backed by a TensorFlow SavedModel.

    This base only loads the model. Subclasses are expected to implement
    ``preprocess``, ``postprocess`` and ``predict``; the base versions raise
    NotImplementedError.
    """

    def __init__(self, model_path, device='gpu'):
        """Load the SavedModel at *model_path* and keep it on ``self.model``.

        Args:
            model_path: Location of the SavedModel, passed unchanged to
                ``_load_model``.
            device: Stored as ``self.device``. This base class does not use
                it further (NOTE(review): presumably read by subclasses —
                confirm).

        Raises:
            ImportError: If TensorFlow could not be imported.
        """
        if tf is None:
            raise ImportError('TensorFlow is required for this predictor')
        # Set before loading so an overriding _load_model may read it.
        self.device = device
        loaded_model = self._load_model(model_path)
        self.model = loaded_model
        logging.info('TensorFlow SavedModel loaded: %s', model_path)

    def _load_model(self, model_path):
        """Return ``tf.saved_model.load(model_path)``; override to customize."""
        loader = tf.saved_model.load
        return loader(model_path)

    def preprocess(self, images):
        """Turn raw *images* into model input. Must be overridden."""
        raise NotImplementedError

    def postprocess(self, outputs):
        """Turn raw model *outputs* into results. Must be overridden."""
        raise NotImplementedError

    def predict(self, streams, **kwargs):
        """Run inference on *streams*. Must be overridden."""
        raise NotImplementedError
class KerasPredictor:
    """Base class for predictors backed by a Keras model (.h5 or SavedModel).

    This base only loads the model. Subclasses are expected to implement
    ``preprocess``, ``postprocess`` and ``predict``; the base versions raise
    NotImplementedError.
    """

    def __init__(self, model_path, device='gpu'):
        """Load the Keras model at *model_path* and keep it on ``self.model``.

        Args:
            model_path: Model file or directory, passed unchanged to
                ``_load_model``.
            device: Stored as ``self.device``. This base class does not use
                it further (NOTE(review): presumably read by subclasses —
                confirm).

        Raises:
            ImportError: If TensorFlow could not be imported.
        """
        if tf is None:
            raise ImportError('TensorFlow is required for this predictor')
        # Set before loading so an overriding _load_model may read it.
        self.device = device
        loaded_model = self._load_model(model_path)
        self.model = loaded_model
        logging.info('Keras model loaded: %s', model_path)

    def _load_model(self, model_path):
        """Return the Keras model at *model_path*, loaded uncompiled.

        compile=False skips compiling the model, which is enough for
        inference. NOTE(review): under Keras 3, load_model no longer accepts
        legacy SavedModel directories — confirm the TF/Keras version in use.
        """
        return tf.keras.models.load_model(model_path, compile=False)

    def preprocess(self, images):
        """Turn raw *images* into model input. Must be overridden."""
        raise NotImplementedError

    def postprocess(self, outputs):
        """Turn raw model *outputs* into results. Must be overridden."""
        raise NotImplementedError

    def predict(self, streams, **kwargs):
        """Run inference on *streams*. Must be overridden."""
        raise NotImplementedError