feat: add ELLA Text Encode node
JettHu committed Apr 30, 2024
1 parent 25a4c3d commit 1c62890
Showing 8 changed files with 159 additions and 11 deletions.
12 changes: 12 additions & 0 deletions NODES.md
@@ -46,6 +46,18 @@
#### Outputs
- **conds**: (CONDITIONING), for `KSamplers`.

### ELLA Text Encode
> If a CLIP model is provided, this node automatically concatenates the ELLA condition and the CLIP condition (see the sketch below).
#### Inputs
- **ella**: (ELLA), ELLA model loaded by the `Load ELLA Model` node. `Set ELLA Timesteps` must be applied first.
- **text_encoder**: (T5_TEXT_ENCODER)
- **text**: (STRING), prompt to encode.
#### Optional
- **clip**: (CLIP), clip model to encode text_clip.
- **text_clip**: (STRING), prompt to encode with clip.
#### Outputs
- **conds**: (CONDITIONING), for `KSamplers`.
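
As a rough sketch of the behavior described above (this only paraphrases the `EllaTextEncode` implementation added to `ella.py` further down; `ella_encode`, the conditioning format, and the argument names follow that code, everything else is illustrative):

```python
import torch

def ella_text_encode_sketch(ella_model, timesteps, t5_encoder, text, clip=None, text_clip=""):
    # Encode the prompt with T5, then run ELLA over the prepared timesteps.
    t5_embeds = t5_encoder(text, max_length=None)
    ella_conds = ella_encode(ella_model, timesteps, {"t5_embeds": t5_embeds})
    if clip is None:
        return ella_conds
    # With a CLIP model connected, encode text_clip and append it to the
    # ELLA condition along the token dimension.
    tokens = clip.tokenize(text_clip)
    clip_cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
    return [[torch.cat((cond, clip_cond), dim=1), extra.copy()] for cond, extra in ella_conds]
```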

### T5 Text Encode #ELLA

#### Inputs
12 changes: 11 additions & 1 deletion README.md
@@ -10,7 +10,7 @@
[ComfyUI](https://github.com/comfyanonymous/ComfyUI) implementation for [ELLA](https://github.com/TencentQQGYLab/ELLA).

## :star2: Changelog

- **[2024.4.30]** Add a new node `ELLA Text Encode` that automatically concatenates the ELLA and CLIP conditions.
- **[2024.4.24]** Upgraded ELLA Apply method. Better compatibility with the ComfyUI ecosystem. Refer to the method mentioned in [ComfyUI_ELLA PR #25](https://github.com/ExponentialML/ComfyUI_ELLA/pull/25)
  - **DEPRECATED**: `Apply ELLA` without `sigmas` is deprecated and will be removed in a future version.
- **[2024.4.22]** Fix unstable image quality with multi-batch generation. Add CLIP concat (lora trigger words are now supported).
@@ -52,6 +52,16 @@ With the upgrade (2024.4.24), some interesting workflows can be implemented, such

However, there is no guarantee that positive-only will bring better results.

Workflow with [AYS](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/).

![workflow_ella_ays](./examples/workflow_ella_ays.png)

AYS yields more visual detail and better text alignment; see the [paper](https://research.nvidia.com/labs/toronto-ai/AlignYourSteps/).

| w/ AYS | w/o AYS |
| :---: | :---: |
| ![](./assets/AYS_output.png) | ![](./assets/wo_AYS_output.png) |

And [EMMA](https://github.com/TencentQQGYLab/ELLA/issues/15) is a work in progress.

## :green_book: Install
Binary file added assets/AYS_output.png
Binary file added assets/wo_AYS_output.png
83 changes: 73 additions & 10 deletions ella.py
@@ -255,6 +255,64 @@ def encode(self, ella, embeds: dict, **kwargs):
        return (conds,)


class EllaTextEncode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "ella": (ELLA_TYPE,),
                "text_encoder": ("T5_TEXT_ENCODER",),
                "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
            },
            "optional": {
                "clip": ("CLIP", {"default": None}),
                "text_clip": ("STRING", {"default": "", "multiline": True, "dynamicPrompts": True}),
            },
        }

    RETURN_NAMES = ("CONDITIONING", "CLIP CONDITIONING")
    RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
    FUNCTION = "encode"

    CATEGORY = "ella/conditioning"

    def encode(self, ella, text_encoder, text, clip=None, text_clip="", **kwargs):
        text_encoder_model = text_encoder["model"]
        cond = text_encoder_model(text, max_length=None)
        embeds = {}
        embeds[f"{ELLA_EMBEDS_PREFIX}t5_embeds"] = cond

        timesteps = ella.get("timesteps", None)
        if timesteps is None:
            raise ValueError("timesteps are required but not provided, use the 'Set ELLA Timesteps' node first.")
        embeds = {k[ELLA_EMBEDS_PREFIX_LEN:]: v for k, v in embeds.items() if k.startswith(ELLA_EMBEDS_PREFIX)}
        ella_conds = ella_encode(ella["model"], timesteps, embeds)

        clip_conds = None
        if clip is None and text_clip:
            raise ValueError("text_clip needs a clip to encode")
        if clip is not None:
            tokens = clip.tokenize(text_clip)
            cond, pooled = clip.encode_from_tokens(tokens, return_pooled=True)
            clip_conds = [[cond, {"pooled_output": pooled}]]

        if clip_conds is not None:
            return (self.concat(ella_conds, clip_conds), clip_conds)

        return (ella_conds, None)

    def concat(self, conditioning_to, conditioning_from):
        out = []
        cond_from = conditioning_from[0][0]

        for i in range(len(conditioning_to)):
            t1 = conditioning_to[i][0]
            tw = torch.cat((t1, cond_from), 1)
            n = [tw, conditioning_to[i][1].copy()]
            out.append(n)

        return out
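
The `concat` helper above joins the CLIP condition onto each ELLA condition along the token dimension (dim 1). A tiny self-contained illustration with placeholder shapes (the 64/77 token counts and the 768-wide embeddings are just examples, not values taken from this commit):

```python
import torch

# Placeholder conditions with the same embedding width but different token counts.
ella_cond = torch.randn(1, 64, 768)   # stand-in for an ELLA condition
clip_cond = torch.randn(1, 77, 768)   # stand-in for a CLIP condition
merged = torch.cat((ella_cond, clip_cond), dim=1)
print(merged.shape)  # torch.Size([1, 141, 768])
```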

"""
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Loaders
@@ -409,24 +467,27 @@ def INPUT_TYPES(cls):
                 "scheduler": (samplers.SCHEDULER_NAMES,),
                 "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
                 "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
-            }
+            },
+            "optional": {
+                "sigmas": ("SIGMAS", {"default": None}),
+            },
         }

     RETURN_TYPES = (ELLA_TYPE,)
     CATEGORY = "ella/helper"

     FUNCTION = "set_timesteps"

-    def set_timesteps(self, model, ella, scheduler, steps, denoise):
-        total_steps = steps
-        if denoise < 1.0:
-            if denoise <= 0.0:
-                return (torch.FloatTensor([]),)
-            total_steps = int(steps / denoise)
+    def set_timesteps(self, model, ella, scheduler, steps, denoise, sigmas=None):
         model_sampling = model.get_model_object("model_sampling")
-        sigmas = samplers.calculate_sigmas(model_sampling, scheduler, total_steps).cpu()
-        timesteps = model_sampling.timestep(sigmas[-(steps + 1) :])
+        if sigmas is None:
+            total_steps = steps
+            if denoise < 1.0:
+                if denoise <= 0.0:
+                    return (torch.FloatTensor([]),)
+                total_steps = int(steps / denoise)
+            sigmas = samplers.calculate_sigmas(model_sampling, scheduler, total_steps).cpu()[-(steps + 1) :]
+        timesteps = model_sampling.timestep(sigmas)
         return ({**ella, "timesteps": timesteps},)
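
To make the new branching concrete: when no `sigmas` input is connected, the node still derives a schedule from `steps`/`denoise` and keeps only the last `steps + 1` sigmas. A small stand-alone sketch of that arithmetic (the linear schedule below is a stand-in for `samplers.calculate_sigmas`, not the real scheduler):

```python
import torch

def fake_calculate_sigmas(total_steps):
    # Stand-in schedule: total_steps + 1 descending values ending at 0.
    return torch.linspace(14.6, 0.0, total_steps + 1)

steps, denoise = 20, 0.5
total_steps = int(steps / denoise) if 0.0 < denoise < 1.0 else steps  # 40
sigmas = fake_calculate_sigmas(total_steps)[-(steps + 1):]            # last 21 values
print(sigmas.shape)  # torch.Size([21])
# When a custom SIGMAS tensor is connected (e.g. from an AYS scheduler workflow),
# this computation is skipped and the sigmas are converted to timesteps directly.
```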


@@ -440,6 +501,7 @@ def set_timesteps(self, model, ella, scheduler, steps, denoise):
    "EllaApply": EllaApply,
    "EllaEncode": EllaEncode,
    "T5TextEncode #ELLA": T5TextEncode,
    "EllaTextEncode": EllaTextEncode,
    # Loaders
    "ELLALoader": ELLALoader,
    "T5TextEncoderLoader #ELLA": T5TextEncoderLoader,
@@ -456,6 +518,7 @@ def set_timesteps(self, model, ella, scheduler, steps, denoise):
    "EllaApply": "Apply ELLA",
    "EllaEncode": "ELLA Encode",
    "T5TextEncode #ELLA": "T5 Text Encode #ELLA",
    "EllaTextEncode": "ELLA Text Encode",
    # Loaders
    "ELLALoader": "Load ELLA Model",
    "T5TextEncoderLoader #ELLA": "Load T5 TextEncoder #ELLA",
Binary file added examples/workflow_ella_ays.png
3 changes: 3 additions & 0 deletions model.py
@@ -9,6 +9,7 @@
from torch import nn

from .activations import get_activation
from .utils import remove_weights


class AdaLayerNorm(nn.Module):
@@ -133,6 +134,8 @@ def load_model(self):

    def __call__(self, caption, text_input_ids=None, attention_mask=None, max_length=None, **kwargs):
        self.load_model()
        # remove a1111/comfyui prompt weight, t5 embedder currently does not accept weight
        caption = remove_weights(caption)
        if max_length is None:
            max_length = self.max_length

60 changes: 60 additions & 0 deletions utils.py
@@ -0,0 +1,60 @@
from contextlib import suppress


def escape_important(text):
    return text.replace("\\)", "\0\1").replace("\\(", "\0\2")


def unescape_important(text):
    return text.replace("\0\1", ")").replace("\0\2", "(")


def parse_parentheses(string):
    result = []
    current_item = ""
    nesting_level = 0
    for char in string:
        if char == "(":
            if nesting_level == 0:
                if current_item:
                    result.append(current_item)
                    current_item = "("
                else:
                    current_item = "("
            else:
                current_item += char
            nesting_level += 1
        elif char == ")":
            nesting_level -= 1
            if nesting_level == 0:
                result.append(current_item + ")")
                current_item = ""
            else:
                current_item += char
        else:
            current_item += char
    if current_item:
        result.append(current_item)
    return result


def _remove_weights(string):
    a = parse_parentheses(string)
    out = []
    for x in a:
        if len(x) >= 2 and x[-1] == ")" and x[0] == "(":
            x = x[1:-1]
            xx = x.rfind(":")
            if xx > 0:
                with suppress(Exception):
                    x = x[:xx]
            out += _remove_weights(x)
        else:
            out += [x]
    return out


def remove_weights(text: str):
    text = escape_important(text)
    parsed_weights = _remove_weights(text)
    return "".join([unescape_important(segment) for segment in parsed_weights])
