# mypy: allow-untyped-defs
import enum
from typing import Any, Callable, overload

import torch
from torch.distributed.algorithms.join import Joinable, JoinHook
from torch.optim import Optimizer

class _ZeROJoinHook(JoinHook):
    zero: Any = ...
    def __init__(self, zero: Any) -> None: ...
    def main_hook(self) -> None: ...

# Assignment of one DDP gradient bucket (or a shard of one) to a single rank
# when overlapping ZeRO with DDP.
class _DDPBucketAssignment:
    bucket_index: int
    parameters: list[torch.Tensor]
    offset: int
    device: torch.device
    tensor: torch.Tensor | None

# Progress of the DDP-overlap setup: DDP must rebuild its gradient buckets
# once before ZeRO can finalize its shard assignments.
class _OverlapStatus(enum.IntEnum):
    UNINITIALIZED: int = ...
    DDP_HAS_REBUILT_BUCKETS: int = ...
    INITIALIZED: int = ...

# Bookkeeping for the DDP-overlap path; the per-iteration fields are reset by
# clear_per_iter_info().
class _OverlapInfo:
    status: Any = ...
    params_per_bucket: Any = ...
    params_per_rank: Any = ...
    offsets: Any = ...
    broadcast_handles: Any = ...
    bucket_index_to_future: Any = ...
    bucket_index_to_bucket: Any = ...
    bucket_indices_seen: Any = ...
    assigned_ranks_per_bucket: list[set[int]] = ...
    total_size: int = ...
    shard_buckets: bool = ...
    def __init__(self) -> None: ...
    def wait_for_broadcasts(self) -> None: ...
    def clear_per_iter_info(self) -> None: ...

class ZeroRedundancyOptimizer(Optimizer, Joinable):
    functional_optim_map: Any = ...
    initialized: bool = ...
    process_group: Any = ...
    world_size: int = ...
    rank: int = ...
    global_rank: int = ...
    parameters_as_bucket_view: bool = ...
    optim: Any = ...
    _device_to_device_index: dict[torch.device, int] = ...
    _overlap_with_ddp: bool = ...
    _overlap_info: _OverlapInfo = ...
    _buckets: list[list[torch.Tensor]] = ...
    _bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ...
    def __init__(
        self,
        params: Any,
        optimizer_class: type[Optimizer],
        process_group: Any | None = ...,
        parameters_as_bucket_view: bool = ...,
        overlap_with_ddp: bool = ...,
        **defaults: Any,
    ) -> None: ...
    def add_param_group(self, param_group: dict[str, Any]) -> None: ...
    def consolidate_state_dict(self, to: int = ...) -> None: ...
    @overload
    def step(self, closure: None = ..., **kwargs: Any) -> None: ...
    @overload
    def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ...
    def load_state_dict(self, state_dict: dict[str, Any]) -> None: ...
    def state_dict(self) -> dict[str, Any]: ...
    def _local_step(
        self,
        gradients: list[torch.Tensor | None] | None = None,
        closure: Callable[[], float] | None = None,
        **kwargs: Any,
    ) -> float | None: ...
    def _get_assigned_rank(self, bucket_index: int) -> int: ...
    def _init_zero_for_overlap(self) -> None: ...
    def join_hook(self, **kwargs: Any) -> JoinHook: ...
    @property
    def join_device(self) -> torch.device: ...
    @property
    def join_process_group(self) -> Any: ...
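
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the stub). ZeroRedundancyOptimizer
# wraps a regular optimizer class and shards its per-parameter state across the
# ranks of a process group; each rank steps only its own shard and then
# broadcasts the updated parameters. A minimal DDP training step, assuming
# torch.distributed is already initialized and `MyModel` is a hypothetical
# module, looks roughly like:
#
#     import torch
#     from torch.distributed.optim import ZeroRedundancyOptimizer
#     from torch.nn.parallel import DistributedDataParallel as DDP
#
#     model = DDP(MyModel().to(rank), device_ids=[rank])
#     optimizer = ZeroRedundancyOptimizer(
#         model.parameters(),
#         optimizer_class=torch.optim.Adam,  # local optimizer for this shard
#         lr=1e-3,
#     )
#     model(inputs).sum().backward()
#     optimizer.step()  # updates this rank's shard, then syncs parameters
#
#     # The sharded state must be gathered before saving a full state dict:
#     optimizer.consolidate_state_dict(to=0)
#     if torch.distributed.get_rank() == 0:
#         torch.save(optimizer.state_dict(), "optim.pt")
# ---------------------------------------------------------------------------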