Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| # Copyright (c) XiMing Xing. All rights reserved. | |
| # Author: XiMing Xing | |
| # Description: | |
| from typing import Callable | |
| from tqdm.auto import tqdm | |
def tqdm_decorator(func: Callable) -> Callable:
    """Wrap a method so that it runs inside a tqdm progress bar.

    The wrapped function is assumed to be a method. Its first positional
    argument (the instance) must expose:

    * ``step`` -- used as the progress bar's ``initial`` position,
    * ``args.train_num_steps`` -- used as the progress bar's ``total``,
    * ``accelerator.is_main_process`` -- the bar is created with
      ``disable=True`` whenever this is falsy.

    The open progress bar is passed to ``func`` as the keyword argument
    ``pbar``. Therefore ``func`` must accept a ``pbar`` keyword, and callers
    must not pass ``pbar`` themselves, or Python raises ``TypeError`` for a
    duplicate keyword.

    Args:
        func: The method to wrap. It is invoked as
            ``func(*args, **kwargs, pbar=pbar)``.

    Returns:
        A wrapper that carries ``func``'s metadata (``__name__``,
        ``__doc__``, ``__wrapped__``) and returns whatever ``func`` returns.
        The progress bar is managed by a ``with`` block, so it is exited
        when ``func`` returns or raises.
    """
    # Imported locally so this block does not depend on changes to the
    # file-level import header.
    from functools import wraps

    @wraps(func)
    def wrapper(*args, **kwargs):
        # The first positional argument is the instance the method belongs to.
        instance = args[0]
        with tqdm(initial=instance.step,
                  total=instance.args.train_num_steps,
                  disable=not instance.accelerator.is_main_process) as pbar:
            # Propagate func's result so decorated methods keep their
            # return value instead of always yielding None.
            return func(*args, **kwargs, pbar=pbar)

    return wrapper