#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2025 Apple Inc. All Rights Reserved.
#
import io
import os
import csv
import json
import random
import torch
import numpy as np
import math
import time
import contextlib
from typing import Optional, Union
from PIL import Image
from collections import defaultdict
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.utils.data import default_collate, get_worker_info
import tarfile
import tqdm
import gc
import threading
import psutil
import tempfile

# Optional import for video processing (not available on macOS ARM)
try:
    import decord
    from decord import VideoReader
    DECORD_AVAILABLE = True
except ImportError:
    DECORD_AVAILABLE = False
    print("⚠️ decord not available. Video processing will be disabled.")

import concurrent.futures
from concurrent.futures import ThreadPoolExecutor, TimeoutError

from misc import print, xprint
from misc.condition_utils import get_camera_condition, get_point_condition, get_wind_condition

# Lazy initialization of the multiprocessing manager (only when needed, not at
# import time). This avoids issues on macOS, which uses 'spawn' instead of 'fork'.
_manager = None


def get_manager():
    """Get or create the multiprocessing manager lazily."""
    global _manager
    if _manager is None:
        try:
            # Only create the manager when actually needed (not at import time).
            # This avoids a RuntimeError on macOS with the spawn start method.
            _manager = torch.multiprocessing.Manager()
        except (RuntimeError, EOFError) as e:
            # If manager creation fails (e.g., on macOS with spawn), fall back to None.
            # The rest of the code handles a None manager gracefully.
            print(f"⚠️ Could not create multiprocessing manager: {e}")
            print("   Continuing without multiprocessing manager (may affect some features)")
            _manager = False  # False marks "attempted but failed"
    return _manager if _manager is not False else None


# For backward compatibility; stays None until get_manager() is called.
manager = None

# ==== helpers ==== #
@contextlib.contextmanager
def ram_temp_file(data, suffix=".mp4"):
    """Write `data` to a temporary file, preferring a RAM disk, and yield its path."""
    available_ram = psutil.virtual_memory().available
    video_size = len(data)
    # Use the RAM disk if it exists and leaves ~500 MB of headroom; otherwise
    # fall back to the default system temp directory (disk).
    if os.path.isdir("/dev/shm") and video_size < available_ram - (500 * 1024 * 1024):
        temp_dir = "/dev/shm"  # RAM disk (Linux only)
    else:
        temp_dir = None  # default system temp (disk)
    with tempfile.NamedTemporaryFile(dir=temp_dir, suffix=suffix, delete=True) as temp_file:
        temp_file.write(data)
        temp_file.flush()
        yield temp_file.name
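
# Usage sketch (illustrative; `raw_mp4_bytes` is a placeholder, not a name from
# this repo): materialize raw bytes as a short-lived, preferably RAM-backed
# file for decoders that need a real path.
#   with ram_temp_file(raw_mp4_bytes) as path:
#       frames, idxs = sample_clip(path, num_frames=8)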
def _nearest_multiple(x: float, base: int = 8) -> int:
    """Round x to the nearest multiple of `base`."""
    return int(round(x / base)) * base


def aspect_ratio_to_image_size(target_size, R, multiple=8):
    if R is None:
        return target_size, target_size
    if isinstance(R, str):
        rw, rh = map(int, R.split(':'))
        R = rw / rh
    area = target_size ** 2
    out_h = _nearest_multiple(math.sqrt(area / R), multiple)
    out_w = _nearest_multiple(math.sqrt(area * R), multiple)
    return out_h, out_w
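
# Worked example (illustrative numbers): for target_size=256 and R="16:9",
# area = 256**2 = 65536, so h = sqrt(65536 / (16/9)) = 192 exactly and
# w = sqrt(65536 * (16/9)) ≈ 341.3; rounding each side to the nearest
# multiple of 8 via _nearest_multiple gives
#   aspect_ratio_to_image_size(256, "16:9")  ->  (192, 344)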
def read_tsv(filename):
    # Open the TSV file for reading
    with open(filename, 'r', newline='') as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        rows = []
        while True:
            try:
                r = next(reader)
                rows.append(r)
            except csv.Error as e:
                print(f'{e}')  # skip malformed rows
            except StopIteration:
                break
    return rows


def sample_clip(
    video_path: str,
    num_frames: int = 8,
    out_fps: Optional[float] = None,  # ← pass an fps here
):
    if not DECORD_AVAILABLE:
        raise ImportError("decord is required for video processing but is not available. "
                          "Install with: pip install decord (Note: not available on macOS ARM)")
    vr = VideoReader(video_path)
    src_fps = vr.get_avg_fps()  # native fps
    total = len(vr)
    if out_fps is None or out_fps >= src_fps:
        step = 1  # keep the native rate (or up-sample later)
    else:
        target_duration = (num_frames - 1) / out_fps  # clip duration in seconds
        frame_span = target_duration * src_fps        # source frames spanning that duration
        step = max(frame_span / (num_frames - 1), 1)
    max_start = total - step * (num_frames - 1)
    if max_start <= 1:  # video too short for the requested clip
        indices = np.linspace(0, total - 1, num_frames, dtype=int)
        return vr.get_batch(indices.tolist()), indices
    max_start = int(np.floor(max_start - 1))
    start = random.randint(0, max_start) if max_start > 0 else 0
    idxs = [int(np.round(start + i * step)) for i in range(num_frames)]
    return vr.get_batch(idxs), idxs
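
# Worked example of the sampling step (illustrative numbers): with a source at
# src_fps=30, num_frames=16 and out_fps=8, the clip must span
# (16 - 1) / 8 = 1.875 s, i.e. 1.875 * 30 = 56.25 source frames, giving
# step = 56.25 / 15 = 3.75; the 16 indices are then rounded to integers from a
# random start position within the video.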
class InfiniteDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Initialize an iterator over the dataset.
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.dataset_iterator)
        except StopIteration:
            # Dataset exhausted; start a fresh iterator.
            print('Another loop over the dataset', flush=True)
            self.dataset_iterator = super().__iter__()
            batch = next(self.dataset_iterator)
        return batch


class DataLoaderWrapper(InfiniteDataLoader):
    def __iter__(self):
        return IterWrapper(super().__iter__())


class IterWrapper:
    def __init__(self, obj):
        self.obj = obj

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def next(self):
        return next(self.obj)
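
# Usage sketch (illustrative): the wrapper yields batches forever, restarting
# the underlying iterator whenever the dataset is exhausted.
#   loader = DataLoaderWrapper(dataset, batch_size=8, collate_fn=dataset.collate_fn)
#   it = iter(loader)
#   batch = it.next()  # IterWrapper also exposes an explicit .next()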
# ==== Dataset Implementation, Load your own data ==== #
class ImageTarDataset(Dataset):
    def __init__(self, dataset_tsv, image_size, temporal_size=None, rank=0, world_size=1,
                 use_image_bucket=False, multiple=8, no_flip=False, edit=False):
        all_lines = []
        # gather all data lines
        self.buckets = {}
        self.weights = {}
        self.image_buckets = defaultdict(lambda: 0)
        self.image_buckets['1:1'] = 0  # default bucket
        skipped = 0
        for line in tqdm.tqdm(read_tsv(dataset_tsv)[1:]):
            tsv_file = line[0]
            bucket = line[1] if len(line) > 1 else 'mlx'
            caption = line[2] if len(line) > 2 else 'caption'
            weights = float(line[3] if len(line) > 3 else "1")
            all_data = read_tsv(tsv_file)
            all_maps = {all_data[0][i]: i for i in range(len(all_data[0]))}
            self.weights[all_data[1][0]] = weights
            for row in all_data[1:]:  # `row`, not `line`, to avoid shadowing the outer loop
                try:
                    if 'width' in all_maps:  # filter out too-small images
                        width, height = int(row[all_maps['width']]), int(row[all_maps['height']])
                        if width * height < (image_size * image_size) / 2:  # smaller than half the target area
                            skipped += 1; continue
                    if caption != 'folder':  # an explicit caption spec has higher priority
                        captions = caption.split('|')[0].split(':')
                        operation = caption.split('|')[1] if len(caption.split('|')) > 1 else "none"
                        caption_line = ([row[all_maps[c]] for c in captions], operation)
                    else:
                        caption_line = (row[all_maps['file']].split('/')[-2], "none")  # use the folder name as caption
                    items = {'tar': row[all_maps['tar']], 'file': row[all_maps['file']], 'caption': caption_line,
                             'image_bucket': row[all_maps['image_bucket']] if 'image_bucket' in all_maps else "1:1"}
                    if "camera_file" in all_maps:  # dl3dv data
                        items["camera_file"] = row[all_maps["camera_file"]]
                    if "force_caption" in all_maps:  # force dataset
                        items["force_caption"] = row[all_maps["force_caption"]]
                    if "wind_speed" in all_maps:  # wind force
                        items["wind_speed"] = row[all_maps["wind_speed"]]
                        items["wind_angle"] = row[all_maps["wind_angle"]]
                    elif "force" in all_maps:  # point-wise force
                        items["force"] = row[all_maps["force"]]
                        items["angle"] = row[all_maps["angle"]]
                        items["coordx"] = row[all_maps["coordx"]]
                        items["coordy"] = row[all_maps["coordy"]]
                    if edit:
                        if row[all_maps['visual_file']] != 'none': continue  # TODO: for now, we only support one image, no visual clue
                        items['edit_instruction'] = row[all_maps['edit_instruction']]
                        items['edited_file'] = row[all_maps['edited_file']]
                    all_lines.append(items)
                except Exception:
                    skipped += 1; continue
                image_bucket = all_lines[-1]['image_bucket']
                self.image_buckets[image_bucket] += 1
                if all_lines[-1]['tar'] not in self.buckets:
                    self.buckets[all_lines[-1]['tar']] = bucket
        if "force_caption" in all_lines[0]:
            # TSV values are strings; convert before taking min/max so the
            # comparison is numeric rather than lexicographic
            wind_forces = ([float(l["wind_speed"]) for l in all_lines] if "wind_speed" in all_lines[0]
                           else [float(l["force"]) for l in all_lines])
            self.min_wind_force = min(wind_forces)
            self.max_wind_force = max(wind_forces)
        self.use_image_bucket = use_image_bucket
        self.all_lines = all_lines[rank:][::world_size]  # all_lines is sorted by tar file
        self.num_samples_per_rank = None
        self.image_size = image_size
        self.multiple = multiple
        self.temporal_size = tuple(map(int, temporal_size.split(':'))) if isinstance(temporal_size, str) else None
        self.edit_mode = edit
        def center_crop_resize(img, ratio="1:1", target_size: int = 256, multiple: int = 8):
            """
            1. Center-crop `img` to the largest window with aspect ratio = `ratio`.
            2. Resize so H*W ≈ target_size² (each side a multiple of `multiple`).

            Args
            ----
            img         : PIL Image or torch tensor (CHW/HWC)
            ratio       : "3:2", (3, 2), "1:1", etc.
            target_size : reference side length (area = target_size²)
            multiple    : force each output side to be a multiple of this number
            """
            # --- parse ratio ----------------------------------------------------
            if isinstance(ratio, str):
                rw, rh = map(int, ratio.split(':'))
            else:  # already a tuple/list
                rw, rh = ratio
            R = rw / rh  # width / height
            # --- crop to that aspect ratio --------------------------------------
            # note: tensors also expose a .size *method*, so test for PIL explicitly
            w, h = img.size if isinstance(img, Image.Image) else (img.shape[-1], img.shape[-2])
            if w / h > R:  # image too wide → trim width
                crop_h, crop_w = h, int(round(h * R))
            else:          # image too tall → trim height
                crop_w, crop_h = w, int(round(w / R))
            img = transforms.functional.center_crop(img, (crop_h, crop_w))
            # --- compute output dimensions --------------------------------------
            area = target_size ** 2
            out_h = _nearest_multiple(math.sqrt(area / R), multiple)
            out_w = _nearest_multiple(math.sqrt(area * R), multiple)
            # --- resize & return -------------------------------------------------
            return transforms.functional.resize(img, (out_h, out_w), antialias=True)
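
        # Worked example for the helper above (illustrative numbers): a
        # 1920x1080 PIL image with ratio="1:1" and target_size=256 is
        # center-cropped to 1080x1080 (w/h > 1, so the width is trimmed) and
        # resized to 256x256; with ratio="16:9" the resize target would
        # instead be (192, 344).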
        self.transforms = {}
        self.size_bucket_maps = {}
        self.bucket_size_maps = {}
        for bucket in self.image_buckets:
            trans = [transforms.Lambda(lambda img, r=bucket: center_crop_resize(img, ratio=r, target_size=image_size, multiple=multiple))]
            if not no_flip:
                trans.append(transforms.RandomHorizontalFlip())
            trans.extend([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
            self.transforms[bucket] = transforms.Compose(trans)
            w, h = map(int, bucket.split(':'))
            out_h, out_w = aspect_ratio_to_image_size(image_size, w / h, multiple=multiple)
            self.size_bucket_maps[(out_h, out_w)] = bucket
            self.bucket_size_maps[bucket] = (out_h, out_w)
        self.transform = self.transforms['1:1']  # default transform
        print(f"Rank0 -- Loading {len(self.all_lines)} lines of data | {skipped} lines are skipped due to size or error")

    def __len__(self):
        if self.num_samples_per_rank is not None:
            return self.num_samples_per_rank
        return len(self.all_lines)
    def __getitem__(self, idx):
        image_item = self.all_lines[idx]
        tar_file = image_item['tar']
        img_file = image_item['file']
        img_bucket = image_item['image_bucket']
        try:
            with tarfile.open(tar_file, mode='r') as tar:
                img, _, _ = self._read_image(tar, img_file, img_bucket)
            H0, W0 = img.shape[-2], img.shape[-1]  # transformed size
            scale = self.image_size / min(H0, W0)
            state = np.array([scale, H0, W0])
        except Exception as e:
            print(f'Reading data error {e}')
            raise  # re-raise: `img` and `state` would be undefined below
        sample = image_item.copy()
        sample.update(image=img, state=state)
        return sample
    def _read_image(self, tar, img_file, img_bucket):
        def _transform(img):
            if not self.use_image_bucket:
                return self.transform(img)
            else:
                return self.transforms[img_bucket](img)

        x_shape = aspect_ratio_to_image_size(self.image_size, img_bucket, multiple=self.multiple)
        if self.temporal_size is not None:  # read video
            num_frames, out_fps = self.temporal_size[0], self.temporal_size[1:]
            if len(out_fps) == 1:
                out_fps = out_fps[0]
            else:
                out_fps = random.choice(out_fps)  # randomly choose one fps from the list
            assert img_file.endswith('.mp4'), "Only mp4 video is supported for now"
            try:
                with tar.extractfile(img_file) as video_data:
                    with ram_temp_file(video_data.read()) as tmp_path:
                        frames, frame_inds = sample_clip(tmp_path, num_frames=num_frames, out_fps=out_fps)
                        frames = frames.asnumpy()
            except Exception as e:
                print(f'Reading data error {e} {img_file}')
                frames = np.zeros((num_frames, x_shape[0], x_shape[1], 3), dtype=np.uint8)
                frame_inds = None  # must be defined on the failure path too
            return torch.stack([_transform(Image.fromarray(frame)) for frame in frames]), out_fps, frame_inds
        try:
            original_img = Image.open(tar.extractfile(img_file)).convert('RGB')
        except Exception as e:
            print(f'Reading data error {e} {img_file}')
            original_img = Image.new('RGB', (x_shape[1], x_shape[0]), (0, 0, 0))  # PIL expects (width, height)
        return _transform(original_img), 0, None
    def collate_fn(self, batch):
        batch = default_collate(batch)
        return batch

    def get_batch_modes(self, x):
        x_aspect = self.size_bucket_maps.get(x.size()[-2:], "1:1")
        video_mode = self.temporal_size is not None
        return x_aspect, video_mode
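
# Sketch of the expected TSV layout (hypothetical paths, reconstructed from the
# parsing logic in __init__ above): the top-level TSV has a header row followed by
#   <per-dataset tsv path> \t <s3 bucket> \t <caption spec> \t <weight>
# and each per-dataset TSV has a header naming at least 'tar' and 'file',
# optionally 'width'/'height', 'image_bucket', 'camera_file', 'force_caption'
# with either 'wind_speed'/'wind_angle' or 'force'/'angle'/'coordx'/'coordy',
# and 'visual_file'/'edit_instruction'/'edited_file' for editing data.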
class OnlineImageTarDataset(ImageTarDataset):
    max_retry_n = 20
    max_read = 4096
    # tar_keys_lock is initialized in __init__ to avoid import-time issues

    def __init__(self, dataset_tsv, image_size, batch_size=None, **kwargs):
        super().__init__(dataset_tsv, image_size, **kwargs)
        # Initialize the manager lazily (only when this class is instantiated)
        manager = get_manager()
        # Use threading.Lock as a fallback if the multiprocessing manager is unavailable
        if manager is not None:
            self.tar_keys_lock = manager.Lock()
        else:
            # Fallback to a threading lock for single-process use
            self.tar_keys_lock = threading.Lock()
        self.tar_lists = defaultdict(lambda: [])
        self.tar_image_buckets = defaultdict(lambda: defaultdict(lambda: 0))
        for i, line in enumerate(self.all_lines):
            tar_file = line['tar']
            image_bucket = line['image_bucket']
            self.tar_lists[tar_file] += [i]
            self.tar_image_buckets[tar_file][image_bucket] += 1
        self.reset_tar_keys = []
        for key in self.tar_lists.keys():
            repeat = int(self.weights.get(key, 1))
            self.reset_tar_keys.extend([key] * repeat)
        self.tar_keys = manager.list(self.reset_tar_keys) if manager is not None else list(self.reset_tar_keys)
        # Use more workers for better prefetching, but keep the number reasonable
        self.worker_executors = {}
        self.worker_caches = {}  # each entry: {active: {tar, key, cnt, inner_idx}, prefetch: <future>}
        self.worker_caches_lock = threading.Lock()  # protect worker_caches access
        self.shuffle_everything()
        if self.use_image_bucket:
            assert batch_size, "batch_size should be set when use_image_bucket is True"
            self.batch_size = batch_size
        if self.temporal_size is not None:
            assert self.temporal_size[0] > 1, "temporal_size should be greater than 1 for video data"
            self.max_read = 512
    def cleanup_worker_cache(self, wid):
        """Clean up a worker's cache entry and its associated resources."""
        with self.worker_caches_lock:
            if wid in self.worker_caches:
                cache_entry = self.worker_caches[wid]
                # Cancel the prefetch future if it is still pending
                if 'prefetch' in cache_entry and hasattr(cache_entry['prefetch'], 'cancel'):
                    cache_entry['prefetch'].cancel()
                # The open tar lives under the 'active' sub-dict, not at the top level
                active = cache_entry.get('active') or {}
                if active.get('tar') is not None:
                    self._close_tar(active['tar'])
                    active['tar'] = None
                # Remove the entire cache entry
                del self.worker_caches[wid]
        gc.collect()
    def _s3(self):
        raise NotImplementedError("Please implement your own _s3() method to return a boto3 session/client")
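
    # A minimal override sketch (an assumption, not part of this repo: it
    # presumes boto3 is installed and AWS credentials are configured; the
    # subclass name is illustrative):
    #
    #   import boto3
    #   class MyOnlineDataset(OnlineImageTarDataset):
    #       def _s3(self):
    #           if not hasattr(self, "_s3_client"):
    #               self._s3_client = boto3.client("s3")  # one client per worker process
    #           return self._s3_client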
    def shuffle_everything(self):
        for key in tqdm.tqdm(self.tar_keys):
            random.shuffle(self.tar_lists[key])
        random.shuffle(self.tar_keys)
        print("shuffle everything done!")
    def download_tar(self, prefetch=True, wid=None):
        i = 0
        file_stream = None
        tar_file = None
        download = f'prefetch {wid}' if prefetch else 'just download'
        while True:
            if i % self.max_retry_n == 0:  # after max_retry_n failures, retry a different tar file
                tar_file = self._get_next_key()  # get the next tar file key
                file_stream = None
            try:
                file_stream = io.BytesIO()
                self._s3().download_fileobj(self.buckets[tar_file], tar_file, file_stream)  # hard-coded
                file_stream.seek(0)
                tar = tarfile.open(fileobj=file_stream, mode='r')
                # Store the stream reference so it can be closed later
                tar._file_stream = file_stream
                xprint(f'[INFO] {download} tar file: {tar_file}')
                return tar, tar_file
            except Exception as e:
                xprint(f'[ERROR] {download} tar file {tar_file} failed: {e}')
                i += 1
                if file_stream:
                    file_stream.close()
                    file_stream = None
                time.sleep(min(i * 0.1, 5))  # linear backoff, capped at 5 s
    def _get_next_key(self):
        with self.tar_keys_lock:
            if len(self.tar_keys) == 0:
                xprint(f'[WARN] all dataset keys exhausted... this should not happen usually')
                self.tar_keys.extend(list(self.reset_tar_keys))  # reset
                random.shuffle(self.tar_keys)
            return self.tar_keys.pop(0)  # remove and return the first key
    def _start_prefetch(self, wid):
        """Start prefetching the next tar file for this worker."""
        # Create a per-worker executor if it does not exist yet
        if wid not in self.worker_executors:
            self.worker_executors[wid] = ThreadPoolExecutor(max_workers=1)
        # Download the next tar file in a separate thread
        future = self.worker_executors[wid].submit(self.download_tar, prefetch=True, wid=wid)
        self.worker_caches[wid]['prefetch'] = future

    def _close_tar(self, tar):
        # Close both the tar and its underlying file stream
        if hasattr(tar, '_file_stream') and tar._file_stream:
            tar._file_stream.close()
            tar._file_stream = None
        tar.close()
        del tar
        gc.collect()
    def __getitem__(self, idx):
        try:
            wid = get_worker_info().id
        except Exception:
            wid = -1  # main process (get_worker_info() returns None there)
        # ─── first time this worker is used ─── #
        if wid not in self.worker_caches:
            tar, key = self.download_tar(prefetch=False)  # download a tar file synchronously
            with self.worker_caches_lock:
                self.worker_caches[wid] = dict(
                    active=dict(tar=tar, key=key, cnt=0, inner_idx=0),  # active cache
                )
            self._start_prefetch(wid)  # start prefetching the next tar file
        cache = self.worker_caches[wid]
        active = cache['active']
        tar = active['tar']
        key = active['key']
        cnt = active['cnt']
        inner_idx = active['inner_idx']
        # handle image bucketing
        if self.use_image_bucket:
            if inner_idx % self.batch_size == 0:
                # sample based on local tar file statistics, since some datasets only have one image bucket
                tar_buckets = self.tar_image_buckets[key]
                target_image_bucket = random.choices(
                    list(tar_buckets.keys()), weights=list(tar_buckets.values()), k=1)[0]
                self.worker_caches[wid]['target_image_bucket'] = target_image_bucket
            # scan the list to find the nearest entry with the target image bucket
            target_image_bucket, t_cnt = self.worker_caches[wid]['target_image_bucket'], cnt
            while self.all_lines[self.tar_lists[key][t_cnt]]['image_bucket'] != target_image_bucket:
                t_cnt += 1
                if t_cnt >= len(self.tar_lists[key]): t_cnt = 0
            # swap the image locations
            if cnt != t_cnt:
                self.tar_lists[key][cnt], self.tar_lists[key][t_cnt] = self.tar_lists[key][t_cnt], self.tar_lists[key][cnt]
        img_id = self.tar_lists[key][cnt]
        image_item = self.all_lines[img_id]
        sample = {k: image_item[k] for k in image_item}
        image, fps, frame_inds = self._read_image(tar, image_item['file'], image_item['image_bucket'])
        sample.update(image=image, fps=fps, local_idx=img_id, inner_idx=inner_idx)
        if self.edit_mode:
            image, fps, _ = self._read_image(tar, image_item['edited_file'], image_item['image_bucket'])
            sample.update(edited_image=image, fps=fps, edit_instruction=image_item['edit_instruction'])
| if "camera_file" in image_item: # dl3dv data | |
| sample["condition"] = get_camera_condition(tar, image_item["camera_file"], width=image.shape[3], height=image.shape[2], factor=self.multiple, frame_inds=frame_inds) | |
| if "force_caption" in image_item: # force dataset | |
| if "wind_speed" in image_item: # wind force | |
| sample["condition"] = get_wind_condition(image_item["wind_speed"], image_item["wind_angle"], min_force=self.min_wind_force, max_force=self.max_wind_force, num_frames=image.shape[1], width=image.shape[3], height=image.shape[2]) | |
| elif "force" in image_item: # point-wise | |
| sample["condition"] = get_point_condition(image_item["force"], image_item["angle"], image_item["coordx"], image_item["coordy"], min_force=self.min_wind_force, max_force=self.max_wind_force, num_frames=image.shape[1], width=image.shape[3], height=image.shape[2]) | |
        # update cnt
        cnt, inner_idx = cnt + 1, inner_idx + 1
        if (cnt == len(self.tar_lists[key])) or (cnt == self.max_read):
            # -- active tar finished, switch to the prefetched tar -- #
            self._close_tar(tar)  # close the current tar file
            try:
                # wait for the prefetch, with a 5-minute timeout
                new_tar, new_key = cache['prefetch'].result(timeout=300)
            except Exception as e:
                xprint(f'[WARN] Prefetch failed, downloading new tar synchronously: {e}')
                new_tar, new_key = self.download_tar(prefetch=False)
            cache['active'] = dict(tar=new_tar, key=new_key, cnt=0, inner_idx=inner_idx)  # update the active cache
            # shuffle the image list
            random.shuffle(self.tar_lists[key])
            with self.tar_keys_lock:
                self.tar_keys.append(key)  # return the key so other workers can use it
            self._start_prefetch(wid)  # start prefetching the next tar file
        else:
            cache['active']['cnt'] = cnt
        # always update inner_idx (IMPORTANT)
        cache['active']['inner_idx'] = inner_idx
        return sample
class OnlineImageCaptionDataset(OnlineImageTarDataset):
    def __getitem__(self, idx):
        sample = super().__getitem__(idx)
        captions, caption_op = sample['caption']
        if caption_op == 'none':
            sample['caption'] = captions[0] if isinstance(captions, list) else captions
        elif ':' in caption_op:
            sample['caption'] = random.choices(captions, weights=[float(a) for a in caption_op.split(':')])[0]
        else:
            raise NotImplementedError(f"Unknown caption operation: {caption_op}")
        return sample
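
    # Worked example (hypothetical column names): a caption spec of
    # "short_cap:long_cap|0.3:0.7" in the top-level TSV reads both columns and
    # samples the short caption 30% of the time; a spec without "|" always
    # takes the first listed column.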
    def collate_fn(self, batch):
        batch = super().collate_fn(batch)
        image = batch['image']
        caption = batch['caption']
        if self.edit_mode:
            image = torch.cat([image, batch['edited_image']], dim=0)
            caption.extend(batch['edit_instruction'])
        meta = {key: batch[key] for key in batch if key not in
                ['image', 'caption', 'edited_image', 'edit_instruction']}
        return image, caption, meta


# ==== Dummy Dataset Implementation for Open Source Release ====
class DummyImageCaptionDataset(Dataset):
    """
    Dummy dataset that generates synthetic image-caption pairs for training/testing.
    Supports mixed aspect ratios and batch-wise aspect ratio consistency.
    """
    def __init__(
        self,
        num_samples: int = 10000,
        image_size: int = 256,
        temporal_size: Optional[str] = None,
        use_image_bucket: bool = False,
        batch_size: Optional[int] = None,
        multiple: int = 8,
        no_flip: bool = False,
        edit: bool = False
    ):
        """
        Args:
            num_samples: Number of samples in the dataset
            image_size: Base image size for generation
            temporal_size: Video size specification (e.g., "16:8" for frames:fps)
            use_image_bucket: Whether to use aspect ratio bucketing
            batch_size: Batch size for bucketing (required if use_image_bucket=True)
            multiple: Multiple for dimension rounding
            no_flip: Whether to disable horizontal flipping
            edit: Whether this is an editing dataset
        """
        self.num_samples = num_samples
        self.image_size = image_size
        self.temporal_size = temporal_size
        self.use_image_bucket = use_image_bucket
        self.batch_size = batch_size
        self.multiple = multiple
        self.no_flip = no_flip
        self.edit_mode = edit
        # Parse video parameters
        self.is_video = temporal_size is not None
        if self.is_video:
            # "frames:fps"; take the first fps if several are listed (e.g. "16:8:12")
            parts = list(map(int, temporal_size.split(':')))
            self.num_frames = parts[0]
            self.fps = parts[1]
        else:
            self.num_frames = 1
            self.fps = None
        # Aspect ratios for mixed aspect ratio training
        self.aspect_ratios = [
            "1:1", "2:3", "3:2", "16:9", "9:16",
            "4:5", "5:4", "21:9", "9:21"
        ] if use_image_bucket else ["1:1"]
        # Generate image buckets for the aspect ratios
        self.image_buckets = {}
        for ar in self.aspect_ratios:
            h, w = aspect_ratio_to_image_size(image_size, ar, multiple)
            self.image_buckets[ar] = (h, w, ar)
        # Sample captions for dummy data
        self.sample_captions = [
            "A beautiful landscape with mountains and trees",
            "A cute cat sitting on a wooden table",
            "A modern city skyline at sunset",
            "A vintage car parked on a street",
            "A delicious meal on a white plate",
            "A person walking in a park",
            "A colorful flower garden in bloom",
            "A cozy living room with furniture",
            "A stormy ocean with large waves",
            "A peaceful forest path in autumn",
            "A group of friends laughing together",
            "A majestic eagle flying in the sky",
            "A busy marketplace with vendors",
            "A snow-covered mountain peak",
            "A child playing with toys",
            "A romantic candlelit dinner",
            "A train traveling through countryside",
            "A lighthouse on a rocky coast",
            "A field of sunflowers under blue sky",
            "A family having a picnic outdoors"
        ]
        # Create the transform pipeline
        def center_crop_resize(img, ratio="1:1", target_size: int = 256, multiple: int = 8):
            """
            1. Center-crop `img` to the largest window with aspect ratio = `ratio`.
            2. Resize so H*W ≈ target_size² (each side a multiple of `multiple`).

            Args
            ----
            img         : PIL Image or torch tensor (CHW/HWC)
            ratio       : "3:2", (3, 2), "1:1", etc.
            target_size : reference side length (area = target_size²)
            multiple    : force each output side to be a multiple of this number
            """
            # --- parse ratio ----------------------------------------------------
            if isinstance(ratio, str):
                rw, rh = map(int, ratio.split(':'))
            else:  # already a tuple/list
                rw, rh = ratio
            R = rw / rh  # width / height
            # --- crop to that aspect ratio --------------------------------------
            # note: tensors also expose a .size *method*, so test for PIL explicitly
            w, h = img.size if isinstance(img, Image.Image) else (img.shape[-1], img.shape[-2])
            if w / h > R:  # image too wide → trim width
                crop_h, crop_w = h, int(round(h * R))
            else:          # image too tall → trim height
                crop_w, crop_h = w, int(round(w / R))
            img = transforms.functional.center_crop(img, (crop_h, crop_w))
            # --- compute output dimensions --------------------------------------
            area = target_size ** 2
            out_h = _nearest_multiple(math.sqrt(area / R), multiple)
            out_w = _nearest_multiple(math.sqrt(area * R), multiple)
            # --- resize & return -------------------------------------------------
            return transforms.functional.resize(img, (out_h, out_w), antialias=True)
        self.transforms = {}
        self.size_bucket_maps = {}
        self.bucket_size_maps = {}
        for bucket in self.image_buckets:
            trans = [transforms.Lambda(lambda img, r=bucket: center_crop_resize(img, ratio=r, target_size=image_size, multiple=multiple))]
            if not no_flip:
                trans.append(transforms.RandomHorizontalFlip())
            trans.extend([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
            self.transforms[bucket] = transforms.Compose(trans)
            w, h = map(int, bucket.split(':'))
            out_h, out_w = aspect_ratio_to_image_size(image_size, w / h, multiple=multiple)
            self.size_bucket_maps[(out_h, out_w)] = bucket
            self.bucket_size_maps[bucket] = (out_h, out_w)
        self.transform = self.transforms['1:1']  # default transform

    def __len__(self) -> int:
        return self.num_samples
    def __getitem__(self, idx: int) -> dict:
        """Get a single sample from the dataset."""
        # Choose an aspect ratio
        if self.use_image_bucket:
            bucket_name = random.choice(list(self.image_buckets.keys()))
            h, w, aspect_ratio = self.image_buckets[bucket_name]
        else:
            h, w, aspect_ratio = self.image_size, self.image_size, "1:1"
            bucket_name = aspect_ratio
        # Generate a dummy image
        if self.is_video:
            # Generate a video tensor (T, C, H, W)
            image = torch.randn(self.num_frames, 3, h, w)
            # Squash to the [-1, 1] range
            image = torch.tanh(image)
        else:
            # Generate an RGB image
            image = Image.new('RGB', (w, h), color=(
                random.randint(50, 200),
                random.randint(50, 200),
                random.randint(50, 200)
            ))
            # Add some random patterns for variety
            if random.random() > 0.5:
                # Add a gradient
                pixels = []
                for y in range(h):
                    for x in range(w):
                        r = int(255 * x / w)
                        g = int(255 * y / h)
                        b = int(255 * (x + y) / (w + h))
                        pixels.append((r, g, b))
                image.putdata(pixels)
            # Use the transform matching the chosen bucket; the default '1:1'
            # transform would center-crop away the chosen aspect ratio
            image = self.transforms[bucket_name](image)
        # Generate a caption
        caption = random.choice(self.sample_captions)
        # Add some variation to captions
        if random.random() > 0.7:
            adjectives = ["beautiful", "stunning", "amazing", "incredible", "magnificent"]
            caption = f"{random.choice(adjectives)} {caption.lower()}"
        sample = {
            'image': image,
            'caption': caption,
            'image_bucket': bucket_name,
            'aspect_ratio': aspect_ratio,
            'idx': idx
        }
        # Add video-specific metadata
        if self.is_video:
            sample.update({
                'num_frames': self.num_frames,
                'fps': self.fps,
                'temporal_size': self.temporal_size
            })
        # Add editing data if needed
        if self.edit_mode:
            # Generate a slightly modified image for editing tasks
            edited_image = image + torch.randn_like(image) * 0.1
            edited_image = torch.clamp(edited_image, -1, 1)
            sample.update({
                'edited_image': edited_image,
                'edit_instruction': f"Edit this image to make it more {random.choice(['colorful', 'bright', 'artistic', 'realistic'])}"
            })
        return sample
    def collate_fn(self, batch: list) -> tuple:
        """Collate function for batching samples."""
        if self.use_image_bucket:
            # Samples may come from different buckets, but torch.stack needs one
            # common size; resize everything to the first sample's bucket.
            target_bucket = batch[0]['image_bucket']
            target_size = self.bucket_size_maps[target_bucket]
            for item in batch:
                if tuple(item['image'].shape[-2:]) != target_size:
                    item['image'] = transforms.functional.resize(item['image'], target_size, antialias=True)
                    if 'edited_image' in item:
                        item['edited_image'] = transforms.functional.resize(item['edited_image'], target_size, antialias=True)
                    item['image_bucket'] = target_bucket
        # Standard collation
        collated = {}
        images = torch.stack([item['image'] for item in batch], dim=0)
        captions = [item['caption'] for item in batch]
        # Collect metadata
        for key in ['image_bucket', 'aspect_ratio', 'idx']:
            if key in batch[0]:
                collated[key] = [item[key] for item in batch]
        # Handle video metadata
        if self.is_video:
            for key in ['num_frames', 'fps', 'temporal_size']:
                if key in batch[0]:
                    collated[key] = [item[key] for item in batch]
        # Handle editing data
        if self.edit_mode and 'edited_image' in batch[0]:
            edited_images = torch.stack([item['edited_image'] for item in batch], dim=0)
            collated['edited_image'] = edited_images
            collated['edit_instruction'] = [item['edit_instruction'] for item in batch]
        return images, captions, collated
    def get_batch_modes(self, x):
        x_aspect = self.size_bucket_maps.get(x.size()[-2:], "1:1")
        video_mode = self.temporal_size is not None
        return x_aspect, video_mode
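
# Usage sketch (illustrative): a mixed-aspect dummy batch; collate_fn returns
# (images, captions, metadata) just like the online datasets.
#   ds = DummyImageCaptionDataset(num_samples=100, image_size=256,
#                                 use_image_bucket=True, batch_size=4)
#   images, captions, meta = ds.collate_fn([ds[i] for i in range(4)])
#   print(images.shape, meta['image_bucket'])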
class DummyDataLoaderWrapper:
    """
    Wrapper that mimics the DataLoaderWrapper functionality.
    Provides infinite iteration over the dataset.
    """
    def __init__(self, dataset, batch_size=1, num_workers=0, **kwargs):
        self.dataset = dataset
        self.batch_size = batch_size
        self.dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            collate_fn=dataset.collate_fn,
            shuffle=True,
            drop_last=True,
            **kwargs
        )
        self.iterator = None
        self.secondary_loader = None

    def __iter__(self):
        """Infinite iteration over the dataset."""
        while True:
            if self.iterator is None:
                self.iterator = iter(self.dataloader)
            try:
                yield next(self.iterator)
            except StopIteration:
                self.iterator = iter(self.dataloader)
                yield next(self.iterator)

    def __len__(self):
        return len(self.dataloader)
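
# Usage sketch (illustrative): the wrapper never raises StopIteration, so cap
# the loop yourself.
#   loader = DummyDataLoaderWrapper(dataset, batch_size=4)
#   for step, (images, captions, meta) in enumerate(loader):
#       if step == 10:
#           break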
def create_dummy_dataloader(
    dataset_name: str,
    img_size: int,
    vid_size: Optional[str] = None,
    batch_size: int = 16,
    use_mixed_aspect: bool = False,
    multiple: int = 8,
    num_samples: int = 10000,
    infinite: bool = False
) -> Union[DataLoader, DummyDataLoaderWrapper]:
    """
    Create a dummy dataloader that mimics the original functionality.

    Args:
        dataset_name: Name of the dataset (used for deterministic seeding)
        img_size: Base image size
        vid_size: Video specification (e.g., "16:8")
        batch_size: Batch size
        use_mixed_aspect: Whether to use mixed aspect ratio training
        multiple: Multiple for dimension rounding
        num_samples: Number of samples in the dataset
        infinite: Whether to create an infinite dataloader

    Returns:
        DataLoader or DummyDataLoaderWrapper
    """
    # Seed from the dataset name for reproducibility. The built-in hash() is
    # salted per process, so use a stable digest instead.
    import hashlib
    seed = int(hashlib.md5(dataset_name.encode()).hexdigest(), 16) % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)
    # Create the dataset
    dataset = DummyImageCaptionDataset(
        num_samples=num_samples,
        image_size=img_size,
        temporal_size=vid_size,
        use_image_bucket=use_mixed_aspect,
        batch_size=batch_size,
        multiple=multiple,
        edit='edit' in dataset_name.lower()
    )
    # Set dataset attributes expected by the training code
    dataset.total_num_samples = num_samples
    dataset.num_samples_per_rank = num_samples
    # Create the dataloader
    if infinite:
        # Note: the wrapper already sets drop_last=True internally; passing it
        # again here would raise a duplicate-keyword TypeError.
        return DummyDataLoaderWrapper(
            dataset,
            batch_size=batch_size,
            num_workers=2,
            pin_memory=True,
            persistent_workers=True
        )
    else:
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=2,
            pin_memory=True,
            drop_last=True,
            shuffle=True,
            collate_fn=dataset.collate_fn,
            persistent_workers=True
        )
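

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training entry points):
    # iterate a couple of batches from the dummy loader and check shapes.
    loader = create_dummy_dataloader("dummy", img_size=64, batch_size=2, num_samples=8)
    for step, (images, captions, meta) in enumerate(loader):
        print(images.shape, captions[0])
        if step >= 1:
            break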