Spaces:
Configuration error
Configuration error
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023
from typing import Any, Optional

import numpy as np
import torch
# Names exported by `from <this module> import *`. Every helper takes an
# optional torch.Generator so callers can make sampling reproducible.
__all__ = [
    "torch_randint",
    "torch_random",
    "torch_shuffle",
    "torch_uniform",
    "torch_random_choices",
]
def torch_randint(
    low: int, high: int, generator: Optional[torch.Generator] = None
) -> int:
    """Draw an integer uniformly from the half-open interval [low, high).

    Args:
        low: lower bound (inclusive).
        high: upper bound (exclusive). Must satisfy ``low <= high``.
        generator: optional torch RNG forwarded to ``torch.randint``.

    Returns:
        A plain Python ``int``. When ``low == high`` the interval is empty,
        so ``low`` is returned directly without consuming any randomness.

    Raises:
        AssertionError: if ``low > high``.
    """
    if low == high:
        return low
    assert low < high
    # size=(1,) yields a one-element tensor; int() unwraps it to a Python int.
    return int(torch.randint(low=low, high=high, generator=generator, size=(1,)))
def torch_random(generator: Optional[torch.Generator] = None) -> float:
    """Draw a float uniformly from the half-open interval [0, 1).

    Args:
        generator: optional torch RNG forwarded to ``torch.rand``.

    Returns:
        A plain Python ``float``.
    """
    # torch.rand(1) yields a one-element tensor; float() unwraps it.
    return float(torch.rand(1, generator=generator))
def torch_shuffle(
    src_list: list[Any], generator: Optional[torch.Generator] = None
) -> list[Any]:
    """Return a new list holding the elements of ``src_list`` in random order.

    The input list is not modified.

    Args:
        src_list: elements to shuffle.
        generator: optional torch RNG forwarded to ``torch.randperm``.

    Returns:
        A new list that is a random permutation of ``src_list``.
    """
    rand_indexes = torch.randperm(len(src_list), generator=generator).tolist()
    return [src_list[i] for i in rand_indexes]
def torch_uniform(
    low: float, high: float, generator: Optional[torch.Generator] = None
) -> float:
    """Draw a float uniformly from the interval [low, high).

    Args:
        low: lower bound (inclusive).
        high: upper bound (exclusive, up to floating-point rounding).
        generator: optional torch RNG passed to ``torch_random``.

    Returns:
        ``(high - low) * u + low``, where ``u`` is a single draw from
        ``torch_random``.
    """
    rand_val = torch_random(generator)
    return (high - low) * rand_val + low
def torch_random_choices(
    src_list: list[Any],
    generator: Optional[torch.Generator] = None,
    k: int = 1,
    weight_list: Optional[list[float]] = None,
) -> Any:
    """Sample ``k`` elements from ``src_list`` with replacement.

    Args:
        src_list: candidates to sample from.
        generator: optional torch RNG used for every draw.
        k: number of samples to draw.
        weight_list: optional unnormalized weights, one per element of
            ``src_list``. When None, every element is equally likely.

    Returns:
        The single sampled element when ``k == 1``; otherwise a list of
        ``k`` sampled elements (empty when ``k == 0``).

    Raises:
        AssertionError: if ``weight_list`` and ``src_list`` differ in length.
    """
    if weight_list is None:
        # .tolist() converts to Python ints, so the list is not indexed with
        # 0-d tensors (one __index__ call per element).
        rand_idx = torch.randint(
            low=0, high=len(src_list), generator=generator, size=(k,)
        ).tolist()
        out_list = [src_list[i] for i in rand_idx]
    else:
        assert len(weight_list) == len(src_list)
        accumulate_weight_list = np.cumsum(weight_list)
        last_id = len(accumulate_weight_list) - 1
        out_list = []
        for _ in range(k):
            # One torch_uniform draw per sample, in order, so the RNG stream
            # is consumed exactly as before.
            val = torch_uniform(0, accumulate_weight_list[-1], generator)
            # Pick the first index whose cumulative weight is strictly greater
            # than val. If there is none (e.g. rounding put val at the total),
            # fall back to the last index. Vectorized, but the semantics match
            # the former linear scan even if the cumsum is not monotone.
            above = accumulate_weight_list > val
            active_id = int(np.argmax(above)) if above.any() else last_id
            out_list.append(src_list[active_id])
    return out_list[0] if k == 1 else out_list