Skip to content

Commit

Permalink
feat: add display rate back for saving bandwidth (#132)
Browse files Browse the repository at this point in the history
* feat: add display rate back for saving bandwidth

* style: fix overload and cli autocomplete

Co-authored-by: Jina Dev Bot <dev-bot@jina.ai>
hanxiao and jina-bot authored Aug 3, 2022
1 parent 5baa0da commit 471c6dc
Showing 6 changed files with 22 additions and 13 deletions.
19 changes: 9 additions & 10 deletions discoart/config.py
Original file line number Diff line number Diff line change
@@ -27,7 +27,7 @@
) as ymlfile:
cut_schedules = yaml.load(ymlfile, Loader=Loader)

_legacy_args = {'clip_sequential_evaluation', 'fuzzy_prompt', 'display_rate'}
_legacy_args = {'clip_sequential_evaluation', 'fuzzy_prompt'}


def load_config(
@@ -52,16 +52,15 @@ def load_config(

cfg.update(**user_config)

for k, v in cfg.items():
if k in (
'batch_size',
'display_rate',
int_keys = {k for k, v in default_args.items() if isinstance(v, int)}
int_keys.union(
{
'seed',
'skip_steps',
'steps',
'n_batches',
'cutn_batches',
) and isinstance(v, float):
}
)

for k, v in cfg.items():
if k in int_keys and isinstance(v, float):
cfg[k] = int(v)
if k == 'width_height':
cfg[k] = [int(vv) for vv in v]
2 changes: 2 additions & 0 deletions discoart/create.py
Original file line number Diff line number Diff line change
@@ -39,6 +39,7 @@ def create(
diffusion_model: Optional[str] = '512x512_diffusion_uncond_finetune_008100',
diffusion_model_config: Optional[Dict[str, Any]] = None,
diffusion_sampling_mode: Optional[str] = 'ddim',
display_rate: Optional[int] = 1,
eta: Optional[float] = 0.8,
gif_fps: Optional[int] = 20,
gif_size_ratio: Optional[float] = 0.5,
@@ -114,6 +115,7 @@ def create(**kwargs) -> Optional['DocumentArray']:
:param diffusion_model: Diffusion_model of choice. Note that you don't have to write the full name of the diffusion model, e.g. any prefix is enough.To use a listed all diffusion models, you can do:```pythonfrom discoart import createcreate(diffusion_model='portrait_generator', ...)```
:param diffusion_model_config: [DiscoArt] The customized diffusion model config as a dictionary, if specified will override the values with the same name in the default model config.
:param diffusion_sampling_mode: Two alternate diffusion denoising algorithms. ddim has been around longer, and is more established and tested. plms is a newly added alternate method that promises good diffusion results in fewer steps, but has not been as fully tested and may have side effects. This new plms mode is actively being researched in the #settings-and-techniques channel in the DD Discord.
:param display_rate: [DiscoArt] The refresh rate of displaying the generated images in Notebook environment. The value has nothing to do with the rate of saving images and the speed of generation or sampling. It is purely about your browser refreshing. Smaller value (1 is the smallest, 0 will disable the refresh) will consume more network bandwidth, as your browser will actively fetch refreshed images to local. Change it to a bigger value if you have limited network bandwidth.
:param eta: eta (greek letter η) is a diffusion model variable that mixes in a random amount of scaled noise into each timestep. 0 is no noise, 1.0 is more noise. As with most DD parameters, you can go below zero for eta, but it may give you unpredictable results. The steps parameter has a close relationship with the eta parameter. If you set eta to 0, then you can get decent output with only 50-75 steps. Setting eta to 1.0 favors higher step counts, ideally around 250 and up. eta has a subtle, unpredictable effect on image, so you’ll need to experiment to see how this affects your projects.
:param gif_fps: [DiscoArt] The frame rate of the generated GIF. Set it to -1 for not saving GIF.
:param gif_size_ratio: [DiscoArt] The relative size vs. the original image, small size ratio gives smaller file size.
4 changes: 3 additions & 1 deletion discoart/persist.py
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ def _sample(
is_save_step,
is_save_gif,
is_image_output,
is_display_step,
):
with threading.Lock():
is_sampling_done.clear()
@@ -78,7 +79,8 @@ def _sample(

_display_html.append(f'<img src="{c.uri}" alt="step {j} minibatch {k}">')

_handlers.preview.value = '<br>\n'.join(_display_html)
if is_display_step:
_handlers.preview.value = '<br>\n'.join(_display_html)
logger.debug('sample and plot is done')
is_sampling_done.set()

3 changes: 2 additions & 1 deletion discoart/resources/default.yml
Original file line number Diff line number Diff line change
@@ -62,4 +62,5 @@ stop_event:
text_clip_on_cpu: False
truncate_overlength_prompt: False
image_output: True
visualize_cuts: False
visualize_cuts: False
display_rate: 1
5 changes: 4 additions & 1 deletion discoart/resources/docstrings.yml
Original file line number Diff line number Diff line change
@@ -227,4 +227,7 @@ image_output: |
[DiscoArt] If set, then output will be saved as images. This includes intermediate, final results in the form of PNG and GIF. If set to False, then no images will be saved, everything will be saved in a Protobuf LZ4 format. https://docarray.jina.ai/fundamentals/documentarray/serialization/#from-to-bytes
visualize_cuts: |
[DiscoArt] If set, then `cuts-{step}.png` will be saved for each step, visualizing all cuts in a sprite image at each step.
[DiscoArt] If set, then `cuts-{step}.png` will be saved for each step, visualizing all cuts in a sprite image at each step.
display_rate: |
[DiscoArt] The refresh rate of displaying the generated images in Notebook environment. The value has nothing to do with the rate of saving images and the speed of generation or sampling. It is purely about your browser refreshing. Smaller value (1 is the smallest, 0 will disable the refresh) will consume more network bandwidth, as your browser will actively fetch refreshed images to local. Change it to a bigger value if you have limited network bandwidth.
2 changes: 2 additions & 0 deletions discoart/runner.py
Original file line number Diff line number Diff line change
@@ -459,6 +459,7 @@ def cond_fn(x, t, **kwargs):

is_save_step = args.save_rate > 0 and j % args.save_rate == 0
is_complete = cur_t == -1
is_display_step = args.display_rate > 0 and j % args.display_rate == 0

threads.append(
_sample_thread(
@@ -475,6 +476,7 @@ def cond_fn(x, t, **kwargs):
is_save_step or is_complete,
args.gif_fps > 0,
args.image_output,
is_display_step,
)
)

0 comments on commit 471c6dc

Please sign in to comment.