Spaces:
Runtime error
Runtime error
# Copyright 2024 MIT Han Lab
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
| from typing import Any, Optional, Union | |
| import numpy as np | |
| import torch | |
# Names exported by ``from <module> import *``: the torch-backed random helpers
# defined in this file.
__all__ = [
    "torch_randint",
    "torch_random",
    "torch_shuffle",
    "torch_uniform",
    "torch_random_choices",
]
def torch_randint(low: int, high: int, generator: Optional[torch.Generator] = None) -> int:
    """Draw one integer uniformly from the half-open range [low, high).

    When ``low == high`` the range is degenerate: ``low`` is returned directly
    and no random number is drawn.
    """
    if low != high:
        assert low < high
        sample = torch.randint(low=low, high=high, size=(1,), generator=generator)
        return int(sample)
    return low
def torch_random(generator: Optional[torch.Generator] = None) -> float:
    """Return one sample from the uniform distribution on [0, 1) as a Python float."""
    sample = torch.rand(1, generator=generator)
    return float(sample)
def torch_shuffle(src_list: list[Any], generator: Optional[torch.Generator] = None) -> list[Any]:
    """Return a new list with the items of ``src_list`` in a random order.

    The order comes from ``torch.randperm``; ``src_list`` itself is not modified.
    """
    permutation = torch.randperm(len(src_list), generator=generator)
    return [src_list[position] for position in permutation.tolist()]
def torch_uniform(low: float, high: float, generator: Optional[torch.Generator] = None) -> float:
    """Return one sample from the uniform distribution on [low, high).

    A unit sample from ``torch_random`` is scaled by the interval width and
    shifted by ``low``.
    """
    unit_sample = torch_random(generator)
    span = high - low
    return span * unit_sample + low
def torch_random_choices(
    src_list: list[Any],
    generator: Optional[torch.Generator] = None,
    k: int = 1,
    weight_list: Optional[list[float]] = None,
) -> Union[Any, list]:
    """Sample ``k`` items from ``src_list`` with replacement.

    Args:
        src_list: Candidate items to draw from.
        generator: Optional torch generator used for every random draw.
        k: Number of items to draw.
        weight_list: Optional relative weights, one per entry of ``src_list``.
            When omitted, every item is equally likely. Weights are expected
            to be non-negative, so that their cumulative sums are sorted and
            can be binary-searched.

    Returns:
        The single drawn item when ``k == 1``; otherwise a list of the drawn
        items.

    Raises:
        ValueError: If ``weight_list`` is given and its length differs from
            the length of ``src_list``.
    """
    if weight_list is None:
        rand_idx = torch.randint(low=0, high=len(src_list), generator=generator, size=(k,))
        # Convert to plain ints once instead of indexing the list with 0-d tensors.
        out_list = [src_list[i] for i in rand_idx.tolist()]
    else:
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if len(weight_list) != len(src_list):
            raise ValueError(
                f"weight_list has {len(weight_list)} entries but src_list has {len(src_list)}"
            )
        cumulative = np.cumsum(weight_list)
        last_idx = len(src_list) - 1
        out_list = []
        for _ in range(k):
            # One ``torch.rand(1)`` call per sample, scaled to [0, total weight),
            # so a seeded generator is consumed one value at a time.
            threshold = cumulative[-1] * float(torch.rand(1, generator=generator))
            # side="right" gives the first position whose cumulative weight is
            # strictly greater than the threshold, so zero-weight items are skipped.
            pos = int(np.searchsorted(cumulative, threshold, side="right"))
            # Clamp: if no cumulative weight exceeds the threshold, pick the last item.
            out_list.append(src_list[min(pos, last_idx)])
    return out_list[0] if k == 1 else out_list