[Tool]ImgDataset2WebDatasetMS #130

Merged
merged 4 commits into from Jan 12, 2025
22 changes: 22 additions & 0 deletions README.md
@@ -274,6 +274,8 @@ where each line of [`asset/samples_mini.txt`](asset/samples_mini.txt) contains a

- 32GB VRAM is required to train both the 0.6B and 1.6B models

### 1). Train with image-text pairs in directory

We provide a training example here; you can also select a config file from the [config files dir](configs/sana_config) that matches your data structure.

To launch Sana training, you will first need to prepare data in the following formats. [Here](asset/example_data) is an example of the data structure for reference.
@@ -310,6 +312,26 @@ bash train_scripts/train.sh \
--train.train_batch_size=8
```
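The directory format above can be sketched in code. This is a minimal illustration, assuming (based on the converter added in this PR) that each sample is an image file plus a same-stem `.txt` caption file; the path `asset/example_data_demo` is hypothetical:

```python
import os
from PIL import Image

# Hypothetical demo directory: one image per sample, plus a caption
# file with the same stem (0001.png pairs with 0001.txt).
demo_dir = "asset/example_data_demo"
os.makedirs(demo_dir, exist_ok=True)

Image.new("RGB", (512, 512)).save(os.path.join(demo_dir, "0001.png"))
with open(os.path.join(demo_dir, "0001.txt"), "w", encoding="utf-8") as f:
    f.write("a photo of a red apple on a wooden table")
```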

### 2). Train with image-text pairs in WebDataset format

We also provide a conversion script that packages an image-text directory into the required WebDataset format. See the [data conversion scripts](asset/data_conversion_scripts) for details.

```bash
python tools/convert_ImgDataset_to_WebDatasetMS_format.py
```
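For each `.png`/`.txt` pair, the script writes a JSON sidecar alongside the image before packaging everything into the tar. The field names below come from the script added in this PR; the values are illustrative:

```python
import json

# JSON sidecar written for a pair (0001.png, 0001.txt); "prompt" is the
# stripped contents of the caption file, width/height come from the image.
sidecar = {
    "file_name": "0001.png",
    "prompt": "a photo of a red apple on a wooden table",
    "width": 512,
    "height": 512,
}
print(json.dumps(sidecar, indent=4, ensure_ascii=False))
```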

Then Sana's training can be launched via

```bash
# Example of training Sana 0.6B with 512x512 resolution from scratch
bash train_scripts/train.sh \
configs/sana_config/512ms/Sana_600M_img512.yaml \
--data.data_dir="[asset/example_data_tar]" \
--data.type=SanaWebDatasetMS \
--model.multi_scale=true \
--train.train_batch_size=32
```

# 💻 4. Metric toolkit

Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).
72 changes: 72 additions & 0 deletions tools/convert_ImgDataset_to_WebDatasetMS_format.py
@@ -0,0 +1,72 @@
```python
# @Author: Pevernow ([email protected])
# @Date: 2025/1/5
# @License: (Follow the main project)
import json
import os
import tarfile

from PIL import Image, PngImagePlugin

PngImagePlugin.MAX_TEXT_CHUNK = 100 * 1024 * 1024  # Increase maximum size for text chunks


def process_data(input_dir, output_tar_name="output.tar"):
    """
    Processes a directory containing PNG files, generates corresponding JSON files,
    and packages all files into a TAR file. It also counts the number of processed
    PNG images and saves the height and width of each PNG file to the JSON.

    Args:
        input_dir (str): The input directory containing PNG files.
        output_tar_name (str): The name of the output TAR file (default is "output.tar").
    """
    png_count = 0
    json_files_created = []

    for filename in os.listdir(input_dir):
        if filename.lower().endswith(".png"):
            png_count += 1
            base_name = filename[:-4]  # Remove the ".png" extension
            txt_filename = os.path.join(input_dir, base_name + ".txt")
            json_filename = base_name + ".json"
            json_filepath = os.path.join(input_dir, json_filename)
            png_filepath = os.path.join(input_dir, filename)

            if os.path.exists(txt_filename):
                try:
                    # Get the dimensions of the PNG image
                    with Image.open(png_filepath) as img:
                        width, height = img.size

                    with open(txt_filename, encoding="utf-8") as f:
                        caption_content = f.read().strip()

                    data = {"file_name": filename, "prompt": caption_content, "width": width, "height": height}

                    with open(json_filepath, "w", encoding="utf-8") as outfile:
                        json.dump(data, outfile, indent=4, ensure_ascii=False)

                    print(f"Generated: {json_filename}")
                    json_files_created.append(json_filepath)

                except Exception as e:
                    print(f"Error processing file {filename}: {e}")
            else:
                print(f"Warning: No corresponding TXT file found for {filename}.")

    # Create a TAR file and include all files
    with tarfile.open(output_tar_name, "w") as tar:
        for item in os.listdir(input_dir):
            item_path = os.path.join(input_dir, item)
            tar.add(item_path, arcname=item)  # arcname keeps the file's relative path inside the tar

    print(f"\nAll files have been packaged into: {output_tar_name}")
    print(f"Number of PNG images processed: {png_count}")


if __name__ == "__main__":
    input_directory = input("Please enter the directory path containing PNG and TXT files: ")
    output_tar_filename = (
        input("Please enter the name of the output TAR file (default is output.tar): ") or "output.tar"
    )
    process_data(input_directory, output_tar_filename)
```
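A quick smoke test of the conversion logic can be sketched as follows. Rather than importing the script, this sketch inlines the essential steps (same JSON fields, same tar packaging) on a single temporary sample, so it runs standalone:

```python
import json
import os
import tarfile
import tempfile

from PIL import Image

# Build one sample pair in a temporary directory.
tmp = tempfile.mkdtemp()
Image.new("RGB", (64, 32)).save(os.path.join(tmp, "0001.png"))
with open(os.path.join(tmp, "0001.txt"), "w", encoding="utf-8") as f:
    f.write("a tiny test image")

# Same steps the converter performs: read dimensions, write the JSON sidecar.
with Image.open(os.path.join(tmp, "0001.png")) as img:
    width, height = img.size
with open(os.path.join(tmp, "0001.json"), "w", encoding="utf-8") as f:
    json.dump(
        {"file_name": "0001.png", "prompt": "a tiny test image", "width": width, "height": height},
        f,
        ensure_ascii=False,
    )

# Package the three per-sample files into a tar, as the converter does.
tar_path = os.path.join(tmp, "output.tar")
with tarfile.open(tar_path, "w") as tar:
    for item in ["0001.png", "0001.txt", "0001.json"]:
        tar.add(os.path.join(tmp, item), arcname=item)

with tarfile.open(tar_path) as tar:
    print(sorted(tar.getnames()))  # ['0001.json', '0001.png', '0001.txt']
```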