Spaces:
Sleeping
Sleeping
| # YOLOv5 π by Ultralytics, AGPL-3.0 license | |
| """Callback utils.""" | |
| import threading | |
class Callbacks:
    """Handles all registered callbacks for YOLOv5 Hooks.

    Each hook name maps to an ordered list of registered actions, where an action is a
    dict ``{"name": <str>, "callback": <callable>}``. ``run`` fires a hook's actions in
    the order they were registered.
    """

    def __init__(self):
        """Create a registry containing every supported hook name, each with no actions."""
        # Define the available callbacks (hook name -> list of registered actions)
        self._callbacks = {
            "on_pretrain_routine_start": [],
            "on_pretrain_routine_end": [],
            "on_train_start": [],
            "on_train_epoch_start": [],
            "on_train_batch_start": [],
            "optimizer_step": [],
            "on_before_zero_grad": [],
            "on_train_batch_end": [],
            "on_train_epoch_end": [],
            "on_val_start": [],
            "on_val_batch_start": [],
            "on_val_image_end": [],
            "on_val_batch_end": [],
            "on_val_end": [],
            "on_fit_epoch_end": [],  # fit = train + val
            "on_model_save": [],
            "on_train_end": [],
            "on_params_update": [],
            "teardown": [],
        }
        self.stop_training = False  # set True to interrupt training

    def _check_hook(self, hook):
        """Raise AssertionError if ``hook`` is not one of the supported hook names.

        An explicit raise is used instead of an ``assert`` statement so the check is not
        stripped under ``python -O``, while callers still see the same exception type.
        """
        if hook not in self._callbacks:
            raise AssertionError(f"hook '{hook}' not found in callbacks {self._callbacks}")

    def register_action(self, hook, name="", callback=None):
        """
        Register a new action to a callback hook.

        Args:
            hook: The callback hook name to register the action to; must be a supported hook.
            name: The name of the action for later reference.
            callback: The callable to fire when the hook runs.

        Raises:
            AssertionError: if ``hook`` is unknown or ``callback`` is not callable.
        """
        self._check_hook(hook)
        if not callable(callback):
            # Explicit raise (not `assert`) so a bad callback is rejected even under -O,
            # instead of failing later inside `run`.
            raise AssertionError(f"callback '{callback}' is not callable")
        self._callbacks[hook].append({"name": name, "callback": callback})

    def get_registered_actions(self, hook=None):
        """
        Return the registered actions for one hook, or the whole registry.

        Args:
            hook: The hook name to look up. If falsy (``None`` by default), the entire
                hook-name -> actions mapping is returned instead.

        Returns:
            The internal list (or dict) itself, not a copy, so mutating the result
            mutates the registry.

        Raises:
            KeyError: if ``hook`` is truthy but not a supported hook name.
        """
        if not hook:
            return self._callbacks
        return self._callbacks[hook]

    def run(self, hook, *args, thread=False, **kwargs):
        """
        Fire every action registered on ``hook``, in registration order.

        Args:
            hook: The name of the hook to fire (required; must be a supported hook).
            args: Positional arguments forwarded to each callback.
            thread: If True, start each callback in its own daemon thread and return
                without waiting for them; otherwise call them synchronously in the
                calling thread.
            kwargs: Keyword arguments forwarded to each callback.

        Raises:
            AssertionError: if ``hook`` is not a supported hook name. In synchronous mode,
                any exception raised by a callback propagates to the caller and the
                remaining callbacks are not run.
        """
        self._check_hook(hook)
        # Iterate over a snapshot so that a callback registering another action on this
        # same hook does not cause that new action to fire (or loop) during this pass.
        for action in list(self._callbacks[hook]):
            fn = action["callback"]
            if thread:
                threading.Thread(target=fn, args=args, kwargs=kwargs, daemon=True).start()
            else:
                fn(*args, **kwargs)