
output image resolution details issues #126

Open
bjohn22 opened this issue Feb 2, 2025 · 1 comment
bjohn22 commented Feb 2, 2025

Any thoughts on why I am not able to reproduce the resolution reported elsewhere?
Weights: downloaded from Hugging Face and loaded from a local directory
Model: Janus-Pro-7B

Prompt:
python "/home/mytemp/Documents/ml_projects/Janus/janus_run_script.py" generate --prompt "Elephant on Dallas Texas street" --seed 142 --guidance 5.0

Here is my text-to-image snippet:

def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 16,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    """
    Internal method that generates image patches from the text input.
    """
    torch.cuda.empty_cache()
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)

    # Prepare token embeddings
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        # Odd rows get "unconditional" tokens
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    # Loop to generate tokens for the image
    for i in range(image_token_num_per_image):
        outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
        pkv = outputs.past_key_values
        hidden_states = outputs.last_hidden_state
        logits = vl_gpt.gen_head(hidden_states[:, -1, :])

        # Extract conditional/unconditional logits
        logit_cond = logits[0::2, :]
        logit_uncond = logits[1::2, :]
        logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        generated_tokens[:, i] = next_token.squeeze(dim=-1)

        # Duplicate the next token for the conditional/unconditional rows
        next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
        img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
        inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decode tokens into patches
    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size]
    )
    return generated_tokens.to(dtype=torch.int), patches
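For reference, the guidance line in the loop is standard classifier-free guidance. It can be sanity-checked in isolation (a standalone NumPy sketch with made-up logit values, not part of the script above):

```python
import numpy as np

def cfg_combine(logit_cond, logit_uncond, cfg_weight):
    # Move logits away from the unconditional distribution,
    # toward (and past) the conditional one.
    return logit_uncond + cfg_weight * (logit_cond - logit_uncond)

cond = np.array([2.0, 0.0])
uncond = np.array([1.0, 1.0])
print(cfg_combine(cond, uncond, 1.0))  # cfg_weight=1 recovers the conditional logits: [2. 0.]
print(cfg_combine(cond, uncond, 5.0))  # a larger weight widens the gap: [ 6. -4.]
```

With cfg_weight=5 (the script default and the --guidance value above), the conditional/unconditional difference is amplified fivefold before sampling.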

def unpack(dec, width, height, parallel_size=16):
    """
    Convert raw patches into final numpy images.
    """
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img
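A note on what unpack does to the value range: the decoder output is assumed to be roughly in [-1, 1] and is mapped to [0, 255] before the uint8 assignment. A minimal sketch of just that mapping, with hypothetical sample values:

```python
import numpy as np

# Map decoder outputs (assumed roughly in [-1, 1]) to displayable uint8 pixels.
dec = np.array([-1.0, 0.0, 1.0, 1.5])  # 1.5 simulates a slight overshoot
px = np.clip((dec + 1) / 2 * 255, 0, 255).astype(np.uint8)
print(px)  # [  0 127 255 255]
```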

@torch.inference_mode()
def generate_image(prompt, seed, guidance):
    """
    Generate multiple images for a given prompt.
    """
    torch.cuda.empty_cache()
    seed = seed if seed is not None else 12345
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    width = 384
    height = 384
    parallel_size = 16

    with torch.no_grad():
        # Prepare text
        messages = [
            {'role': 'User', 'content': prompt},
            {'role': 'Assistant', 'content': ''}
        ]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt=''
        )
        text = text + vl_chat_processor.image_start_tag

        input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)

        _, patches = generate(
            input_ids,
            width // 16 * 16,
            height // 16 * 16,
            cfg_weight=guidance,
            parallel_size=parallel_size
        )
        images = unpack(patches, width // 16 * 16, height // 16 * 16)

        # Convert to PIL and upscale
        return [
            Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS)
            for i in range(parallel_size)
        ]
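For what it's worth when debugging the resolution: if I read the snippet right, the native output size is fixed by the token budget, not by the width/height arguments or by the final resize. The 576 image tokens form a 24×24 grid of 16-pixel patches, i.e. 384×384, and the LANCZOS resize to 1024×1024 only interpolates and cannot add detail. A quick check of that arithmetic, using the snippet's defaults:

```python
import math

image_token_num_per_image = 576  # default in generate()
patch_size = 16                  # default in generate()

grid = math.isqrt(image_token_num_per_image)  # tokens form a square grid
native = grid * patch_size                    # native pixel resolution
print(grid, native)  # 24 384
```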

(attached image: generated output)

@bjohn22 bjohn22 changed the title resolution issues output image resolution details issues Feb 2, 2025

bjohn22 commented Feb 2, 2025

Here is my file list:

(attached image: local file listing)
