#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
"""
Training utilities for STARFlow.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed
import torch.distributed.checkpoint as dcp
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy, CPUOffloadPolicy
from torch.distributed._tensor import DeviceMesh
from torch.distributed.device_mesh import init_device_mesh

import datetime
import math
import os
import random
import numpy as np
import contextlib
import typing as t
from typing import Any, Dict, List, Union, Optional
from collections import defaultdict, OrderedDict
from fnmatch import fnmatch

# ==== Learning Rate Schedule ====
class CosineLRSchedule(torch.nn.Module):
    """Linear warmup to `max_lr`, followed by cosine decay to `min_lr`."""

    counter: torch.Tensor

    def __init__(self, optimizer, warmup_steps: int, total_steps: int, min_lr: float, max_lr: float):
        super().__init__()
        self.register_buffer('counter', torch.zeros(()))
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.optimizer = optimizer
        self.min_lr = min_lr
        self.start_lr = min(min_lr, 1e-6)
        self.max_lr = max_lr
        self.set_lr(min_lr)

    def set_lr(self, lr: float) -> float:
        # Only write learning rates that fall inside [min_lr, max_lr];
        # always report the optimizer's current learning rate.
        if self.min_lr <= lr <= self.max_lr:
            for pg in self.optimizer.param_groups:
                pg['lr'] = lr
        return self.optimizer.param_groups[0]['lr']

    def step(self) -> float:
        with torch.no_grad():
            counter = self.counter.add_(1).item()
        if counter <= self.warmup_steps:
            # Linear warmup from start_lr up to max_lr.
            new_lr = self.start_lr + counter / self.warmup_steps * (self.max_lr - self.start_lr)
            return self.set_lr(new_lr)
        # Cosine decay from max_lr down to min_lr over the remaining steps.
        t = (counter - self.warmup_steps) / (self.total_steps - self.warmup_steps)
        new_lr = self.min_lr + 0.5 * (1 + math.cos(math.pi * t)) * (self.max_lr - self.min_lr)
        return self.set_lr(new_lr)
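
# Usage sketch (illustrative only; the optimizer choice, step counts, and
# learning rates below are placeholder values, not settings from this file).
# The schedule warms up linearly to `max_lr`, then decays with a cosine to
# `min_lr`; `step()` is called once per optimizer step.
def _example_lr_schedule(model: nn.Module) -> CosineLRSchedule:
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    schedule = CosineLRSchedule(
        optimizer, warmup_steps=1_000, total_steps=100_000, min_lr=1e-5, max_lr=1e-4
    )
    # In the training loop: optimizer.step(); current_lr = schedule.step()
    return schedule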

# ==== Distributed Training ====
class Distributed:
    """Thin wrapper around torch.distributed that also works in single-process debug runs."""

    timeout: float = 72000

    def __init__(self):
        if os.environ.get('MASTER_PORT'):  # When running with torchrun
            self.rank = int(os.environ['RANK'])
            self.local_rank = int(os.environ['LOCAL_RANK'])
            self.world_size = int(os.environ['WORLD_SIZE'])
            self.distributed = True
            torch.distributed.init_process_group(
                backend='nccl',
                init_method='env://',
                world_size=self.world_size,
                timeout=datetime.timedelta(seconds=self.timeout),
                rank=self.rank,
            )
        else:  # When running with plain `python` for debugging
            self.rank, self.local_rank, self.world_size = 0, 0, 1
            self.distributed = False
        # Only set the CUDA device if CUDA is available.
        if torch.cuda.is_available():
            torch.cuda.set_device(self.local_rank)
        self.barrier()

    def barrier(self) -> None:
        if self.distributed:
            torch.distributed.barrier()

    def gather_concat(self, x: torch.Tensor) -> torch.Tensor:
        """All-gather `x` from every rank and concatenate along the first dimension."""
        if not self.distributed:
            return x
        x_list = [torch.empty_like(x) for _ in range(self.world_size)]
        torch.distributed.all_gather(x_list, x)
        return torch.cat(x_list)

    def reduce(self, x: torch.Tensor) -> torch.Tensor:
        """Sum `x` across all ranks (in place)."""
        if not self.distributed:
            return x
        torch.distributed.all_reduce(x, op=torch.distributed.ReduceOp.SUM)
        return x

    def __del__(self):
        if self.distributed:
            torch.distributed.destroy_process_group()
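
# Usage sketch (illustrative; assumes a per-rank scalar metric as an example).
# `Distributed()` reads the torchrun environment when present and falls back to
# single-process mode, so the same code path works for debugging and training.
def _example_average_metric(dist: Distributed, local_metric: torch.Tensor) -> torch.Tensor:
    """Average a per-rank scalar across all workers (no-op when not distributed)."""
    return dist.reduce(local_metric.detach().clone()) / dist.world_size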

def get_local_rank() -> int:
    if os.environ.get('MASTER_PORT'):  # When running with torchrun
        return int(os.environ['LOCAL_RANK'])
    return 0

def get_device_mesh(dp_size: int, tp_size: int = 1) -> DeviceMesh:
    """Create a DeviceMesh from the data- and tensor-parallel sizes."""
    # TP defaults to 1 for simplicity; only the "dp" dimension is used here.
    mesh_shape = (dp_size, tp_size)
    names = ("dp", "tp")
    return init_device_mesh("cuda", mesh_shape=mesh_shape, mesh_dim_names=names)

def wrap_matching_layers(
    model: nn.Module,
    layer_patterns: t.List[str],
    wrapper_fn: t.Callable[[nn.Module], nn.Module],
) -> None:
    """
    Recursively wrap submodules, one pass per pattern in `layer_patterns`.

    For each pattern (in order), the model is traversed and every submodule
    whose class name matches the pattern is replaced by `wrapper_fn(submodule)`.
    """
    def _wrap_single_pattern(mod: nn.Module, pattern: str):
        """
        Recurse over `mod`, wrapping submodules that match `pattern`.
        Post-order traversal, so children are wrapped before their parent.
        """
        for child_name, child_module in list(mod.named_children()):
            # Wrap grandchildren first.
            _wrap_single_pattern(child_module, pattern)
            # Check whether the child's class name matches the pattern.
            if fnmatch(child_module.__class__.__name__, pattern):
                # Replace the child in the parent.
                wrapped = wrapper_fn(child_module)
                setattr(mod, child_name, wrapped)

    # One pass per pattern, in order.
    for pattern in layer_patterns:
        _wrap_single_pattern(model, pattern)
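
# Toy illustration (not used by the training code): `wrapper_fn` receives a
# matched module and returns its replacement, so any class-name pattern can be
# rewritten in place. Here every `Linear` in a small model is wrapped with
# dropout; the same mechanism is used below to apply `fully_shard` to blocks.
class _LinearWithDropout(nn.Module):
    def __init__(self, inner: nn.Module, p: float = 0.1):
        super().__init__()
        self.inner = inner
        self.drop = nn.Dropout(p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.drop(self.inner(x))


def _example_wrap_linears() -> nn.Module:
    toy = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    wrap_matching_layers(toy, ["Linear"], _LinearWithDropout)
    return toy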

def parallelize_model(args, model: nn.Module, dist: Distributed, device='cuda',
                      block_names=['AttentionBlock']) -> t.Tuple[nn.Module, nn.Module]:
    if not getattr(args, "fsdp", False):  # use standard DDP
        model = model.to(device=device)
        if dist.distributed:
            print("Using DDP")
            model_ddp = torch.nn.parallel.DistributedDataParallel(model, device_ids=[dist.local_rank])
        else:
            model_ddp = model  # keep the (model, model_ddp) interface in debug runs
        return model, model_ddp

    # Instantiate the mixed-precision policy.
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        output_dtype=torch.bfloat16,
        cast_forward_inputs=True,
    )
    print(f"Using FSDP2 with: {mp_policy}")

    # Apply FSDP wrapping along the data-parallel mesh dimension.
    dp_mesh = get_device_mesh(dist.world_size)["dp"]

    # Core FSDP parameters.
    fsdp_config = {"mp_policy": mp_policy, "mesh": dp_mesh, "reshard_after_forward": True}

    # Wrap the specified layer patterns with FSDP first...
    wrap_matching_layers(model, block_names, lambda m: fully_shard(m, **fsdp_config))
    # ...then wrap the full model so the remaining modules are captured.
    model = fully_shard(model, **fsdp_config)
    model = model.to(device=device)
    return model, model  # for compatibility with the DDP return signature
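
# Call-site sketch (illustrative; assumes a CUDA device and an argparse-style
# `args` with a boolean `fsdp` flag, which is how `parallelize_model` reads it).
# The first return value is kept for checkpointing; the second is the module
# actually called in the training loop (a DDP wrapper in the DDP case, the
# sharded model itself in the FSDP2 case).
def _example_parallelize(model: nn.Module, dist: Distributed):
    import argparse
    args = argparse.Namespace(fsdp=False)  # set True to use FSDP2 sharding
    model, model_ddp = parallelize_model(args, model, dist)
    return model, model_ddp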

def save_model(args, dist, model, model_ckpt_file):
    states = model.state_dict()
    if not getattr(args, "fsdp", False):  # save DDP checkpoints
        if dist.local_rank == 0:
            torch.save(states, model_ckpt_file)
    else:  # save FSDP checkpoints
        dcp.save(states, checkpoint_id=str(model_ckpt_file))

def save_optimizer(args, dist, optimizer, lr_schedule, opt_ckpt_file):
    optim_states, lr_states = optimizer.state_dict(), lr_schedule.state_dict()
    if not getattr(args, "fsdp", False):  # save DDP checkpoints
        if dist.local_rank == 0:
            torch.save({"optimizer": optim_states, "lr_schedule": lr_states}, opt_ckpt_file)
    else:
        filename = str(opt_ckpt_file)
        dcp.save(optim_states, checkpoint_id=f"{filename}/optimizer")
        torch.save(lr_states, f"{filename}/lr_schedule.bin")  # lr_schedule is not fsdp
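
# Checkpointing sketch (illustrative; the output paths are placeholders). DDP
# checkpoints are single files written by rank 0, while FSDP checkpoints are
# sharded directories written collectively through `dcp`.
def _example_save(args, dist: Distributed, model, optimizer, lr_schedule, step: int,
                  out_dir: str = "checkpoints"):
    os.makedirs(out_dir, exist_ok=True)
    save_model(args, dist, model, os.path.join(out_dir, f"model_{step}.pt"))
    save_optimizer(args, dist, optimizer, lr_schedule, os.path.join(out_dir, f"optim_{step}"))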

@contextlib.contextmanager
def _fsdp2_no_sync(module, sync):
    # FSDP2 API: toggle gradient synchronization for the duration of the block.
    module.set_requires_gradient_sync(sync, recurse=True)
    try:
        yield
    finally:
        module.set_requires_gradient_sync(True, recurse=True)


def sync_ctx(model, sync=True):
    """Context manager controlling gradient synchronization for FSDP2 or DDP models."""
    if hasattr(model, 'set_requires_gradient_sync'):
        return _fsdp2_no_sync(model, sync)
    elif not sync and hasattr(model, 'no_sync'):
        return model.no_sync()
    return contextlib.nullcontext()
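
# Gradient-accumulation sketch (illustrative; `model_ddp` is the object returned
# by `parallelize_model`, and each element of `micro_batches` is assumed to
# yield a scalar loss via `.mean()`). Gradient synchronization is skipped on
# every micro-batch except the last, for both DDP (`no_sync`) and FSDP2
# (`set_requires_gradient_sync`).
def _example_accumulate(model_ddp, optimizer, micro_batches):
    last = len(micro_batches) - 1
    for i, batch in enumerate(micro_batches):
        with sync_ctx(model_ddp, sync=(i == last)):
            loss = model_ddp(batch).mean()
            loss.backward()
    optimizer.step()
    optimizer.zero_grad()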

# ==== Utility Functions ====
def set_random_seed(seed: int) -> None:
    """Set random seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
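
# A common pattern (suggestion only, not prescribed by this file) is to offset
# the seed by the process rank so that data augmentation differs across workers
# while remaining reproducible.
def _example_seed_per_rank(base_seed: int, dist: Distributed) -> None:
    set_random_seed(base_seed + dist.rank)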