# starry/backend/python-services/predictors/tensorflow_predictor.py
"""
TensorFlow SavedModel predictor base class.
Loads and runs inference on TensorFlow SavedModel format.
"""
import logging
import numpy as np
try:
    import tensorflow as tf
    # Opt in to incremental GPU memory allocation so TF does not reserve
    # all GPU memory at import time (must run before any GPU op executes).
    gpus = tf.config.experimental.list_physical_devices('GPU')
    if gpus:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
except ImportError:
    # Sentinel: predictors check `tf is None` and raise a clear ImportError
    # at construction instead of failing here at module import.
    tf = None
    logging.warning('TensorFlow not available')
class TensorFlowPredictor:
    """Base class for predictors backed by a TensorFlow SavedModel.

    Subclasses implement `preprocess`, `postprocess`, and `predict`;
    this base handles model loading and the TensorFlow availability check.
    """

    def __init__(self, model_path, device='gpu'):
        # Fail fast with an explicit error when the optional TF import
        # at module level did not succeed (`tf` is the None sentinel).
        if tf is None:
            raise ImportError('TensorFlow is required for this predictor')
        self.model = self._load_model(model_path)
        self.device = device
        logging.info('TensorFlow SavedModel loaded: %s', model_path)

    def _load_model(self, model_path):
        """Load a SavedModel from the given directory."""
        return tf.saved_model.load(model_path)

    def preprocess(self, images):
        """Preprocess images before inference; subclasses must override."""
        raise NotImplementedError

    def postprocess(self, outputs):
        """Postprocess raw model outputs; subclasses must override."""
        raise NotImplementedError

    def predict(self, streams, **kwargs):
        """Run prediction on input streams; subclasses must override."""
        raise NotImplementedError
class KerasPredictor:
    """Base class for Keras model predictors (.h5 files or SavedModel dirs).

    Mirrors `TensorFlowPredictor` but loads via the Keras API; subclasses
    implement `preprocess`, `postprocess`, and `predict`.
    """

    def __init__(self, model_path, device='gpu'):
        # `tf` is the module-level None sentinel when TensorFlow is absent.
        if tf is None:
            raise ImportError('TensorFlow is required for this predictor')
        self.model = self._load_model(model_path)
        self.device = device
        logging.info('Keras model loaded: %s', model_path)

    def _load_model(self, model_path):
        """Load a Keras model (weights only; skip optimizer compilation)."""
        return tf.keras.models.load_model(model_path, compile=False)

    def preprocess(self, images):
        """Preprocess images before inference; subclasses must override."""
        raise NotImplementedError

    def postprocess(self, outputs):
        """Postprocess raw model outputs; subclasses must override."""
        raise NotImplementedError

    def predict(self, streams, **kwargs):
        """Run prediction on input streams; subclasses must override."""
        raise NotImplementedError