Update managers/ltx_manager.py
Browse files- managers/ltx_manager.py +2 -2
managers/ltx_manager.py
CHANGED
|
@@ -139,7 +139,7 @@ class LtxPoolManager:
|
|
| 139 |
def _prepare_pipeline_params(self, worker: 'LtxWorker', **kwargs) -> dict:
|
| 140 |
pipeline_params = {
|
| 141 |
"height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
|
| 142 |
-
"
|
| 143 |
"frame_rate": kwargs.get('video_fps', 24),
|
| 144 |
"generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)),
|
| 145 |
"is_video": True, "vae_per_channel_normalize": True,
|
|
@@ -172,7 +172,7 @@ class LtxPoolManager:
|
|
| 172 |
height, width = kwargs['height'], kwargs['width']
|
| 173 |
padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
|
| 174 |
padding_vals = calculate_padding(height, width, padded_h, padded_w)
|
| 175 |
-
kwargs['
|
| 176 |
kwargs['height'], kwargs['width'] = padded_h, padded_w
|
| 177 |
pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
|
| 178 |
logger.info(f"Initiating GENERATION on {worker_to_use.device} with shape {padded_w}x{padded_h}")
|
|
|
|
| 139 |
def _prepare_pipeline_params(self, worker: 'LtxWorker', **kwargs) -> dict:
|
| 140 |
pipeline_params = {
|
| 141 |
"height": kwargs['height'], "width": kwargs['width'], "num_frames": kwargs['video_total_frames'],
|
| 142 |
+
"callback_on_step_end" : kwargs['callback_on_step_end']
|
| 143 |
"frame_rate": kwargs.get('video_fps', 24),
|
| 144 |
"generator": torch.Generator(device=worker.device).manual_seed(int(time.time()) + kwargs.get('current_fragment_index', 0)),
|
| 145 |
"is_video": True, "vae_per_channel_normalize": True,
|
|
|
|
| 172 |
height, width = kwargs['height'], kwargs['width']
|
| 173 |
padded_h, padded_w = ((height - 1) // 32 + 1) * 32, ((width - 1) // 32 + 1) * 32
|
| 174 |
padding_vals = calculate_padding(height, width, padded_h, padded_w)
|
| 175 |
+
kwargs['callback_on_step_end'],
|
| 176 |
kwargs['height'], kwargs['width'] = padded_h, padded_w
|
| 177 |
pipeline_params = self._prepare_pipeline_params(worker_to_use, **kwargs)
|
| 178 |
logger.info(f"Initiating GENERATION on {worker_to_use.device} with shape {padded_w}x{padded_h}")
|