Skip to content

Commit 75d16a4

Browse files
authored
Fix multiprocessing segfault (#252)
1 parent f41036d commit 75d16a4

File tree

4 files changed

+15
-13
lines changed

4 files changed

+15
-13
lines changed

.github/workflows/UnitTests.yml

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,13 @@ jobs:
5252
- name: Analysing the code with ruff
5353
run: |
5454
ruff check .
55+
- name: version check
56+
run: |
57+
python --version
58+
pip show jax jaxlib flax transformers datasets tensorflow tensorflow_datasets
5559
- name: PyTest
56-
run: |
57-
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py --deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py -x
60+
run: | #--deselect=src/maxdiffusion/tests/input_pipeline_interface_test.py
61+
HF_HUB_CACHE=/mnt/disks/github-runner-disk/ HF_HOME=/mnt/disks/github-runner-disk/ TOKENIZERS_PARALLELISM=false python3 -m pytest --deselect=src/maxdiffusion/tests/ltx_transformer_step_test.py -x
5862
# add_pull_ready:
5963
# if: github.ref != 'refs/heads/main'
6064
# permissions:

src/maxdiffusion/input_pipeline/_tfds_data_processing.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
from maxdiffusion import multihost_dataloading, max_logging
2323

2424
AUTOTUNE = tf.data.AUTOTUNE
25-
25+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
2626

2727
def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
2828
dataset = dataset.with_format("tensorflow")[:]
@@ -50,7 +50,7 @@ def make_tf_iterator(
5050
function=tokenize_fn,
5151
batched=True,
5252
remove_columns=[config.caption_column],
53-
num_proc=1 if config.cache_latents_text_encoder_outputs else config.tokenize_captions_num_proc,
53+
num_proc=None,
5454
desc="Running tokenizer on train dataset",
5555
)
5656
# need to do it before load_as_tf_dataset
@@ -60,7 +60,7 @@ def make_tf_iterator(
6060
function=image_transforms_fn,
6161
batched=True,
6262
remove_columns=[config.image_column],
63-
num_proc=1 if config.cache_latents_text_encoder_outputs else config.transform_images_num_proc,
63+
num_proc=None,
6464
desc="Transforming images",
6565
)
6666
if config.cache_latents_text_encoder_outputs:

src/maxdiffusion/input_pipeline/input_pipeline_interface.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@
4040
from PIL import Image
4141

4242
AUTOTUNE = tf.data.experimental.AUTOTUNE
43-
43+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
4444

4545
def make_data_iterator(
4646
config,
@@ -159,7 +159,7 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v
159159
function=tokenize_fn,
160160
batched=True,
161161
remove_columns=[INSTANCE_PROMPT_IDS],
162-
num_proc=1,
162+
num_proc=None,
163163
desc="Running tokenizer on instance dataset",
164164
)
165165
rng = jax.random.key(config.seed)
@@ -177,7 +177,7 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v
177177
function=transform_images_fn,
178178
batched=True,
179179
remove_columns=[INSTANCE_IMAGES],
180-
num_proc=1,
180+
num_proc=None,
181181
desc="Running vae on instance dataset",
182182
)
183183

@@ -188,7 +188,7 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v
188188
function=tokenize_fn,
189189
batched=True,
190190
remove_columns=[CLASS_PROMPT_IDS],
191-
num_proc=1,
191+
num_proc=None,
192192
desc="Running tokenizer on class dataset",
193193
)
194194
transform_images_fn = partial(
@@ -204,7 +204,7 @@ def make_dreambooth_train_iterator(config, mesh, global_batch_size, tokenizer, v
204204
function=transform_images_fn,
205205
batched=True,
206206
remove_columns=[CLASS_IMAGES],
207-
num_proc=1,
207+
num_proc=None,
208208
desc="Running vae on instance dataset",
209209
)
210210

src/maxdiffusion/tests/input_pipeline_interface_test.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,7 @@
2121
import subprocess
2222
import unittest
2323
from absl.testing import absltest
24-
2524
import numpy as np
26-
import pytest
2725
import tensorflow as tf
2826
import tensorflow.experimental.numpy as tnp
2927
import jax
@@ -70,7 +68,6 @@ class InputPipelineInterface(unittest.TestCase):
7068
def setUp(self):
7169
InputPipelineInterface.dummy_data = {}
7270

73-
@pytest.mark.skip(reason="Debug segfault")
7471
def test_make_dreambooth_train_iterator(self):
7572

7673
instance_class_gcs_dir = "gs://maxdiffusion-github-runner-test-assets/datasets/dreambooth/instance_class"
@@ -85,6 +82,7 @@ def test_make_dreambooth_train_iterator(self):
8582
os.path.join(THIS_DIR, "..", "configs", "base14.yml"),
8683
"cache_latents_text_encoder_outputs=True",
8784
"dataset_name=my_dreambooth_dataset",
85+
"transform_images_num_proc=1",
8886
f"instance_data_dir={instance_class_local_dir}",
8987
f"class_data_dir={class_class_local_dir}",
9088
"instance_prompt=photo of ohwx dog",

0 commit comments

Comments
 (0)