
Conversation

@susanbao (Collaborator) commented Sep 26, 2025

  • Implement the new eval pipeline so that it matches the one in MLPerf (a sketch of the computation follows this list).

    ALGORITHM: Validation Loss Computation
    
    INPUT:
      - validation_samples: set of validation data samples
    
    INITIALIZE:
      - sum[8]: array of zeros for accumulating losses
      - count[8]: array of zeros for counting samples per timestep
    
    FOR each (sample, t) in validation_samples:    # t is the timestep bucket index, 0..7
        loss = forward_pass(sample, timestep=t/8)
        sum[t] += loss
        count[t] += 1
    
    mean_per_timestep = sum / count
    validation_loss = mean(mean_per_timestep)
    
    RETURN validation_loss
    
  • Add a new parameter enable_ssim to help with debugging.
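
A minimal sketch of the validation-loss computation above, assuming 8 timestep buckets and keeping the helper name forward_pass from the pseudocode (illustrative only, not the PR's actual code):

    import numpy as np

    NUM_BUCKETS = 8  # timestep buckets, as in the pseudocode

    def validation_loss(validation_samples, forward_pass):
        """validation_samples yields (sample, t) pairs with integer t in [0, NUM_BUCKETS)."""
        loss_sum = np.zeros(NUM_BUCKETS)
        count = np.zeros(NUM_BUCKETS)
        for sample, t in validation_samples:
            # Evaluate the loss at the normalized timestep for this bucket.
            loss_sum[t] += forward_pass(sample, timestep=t / NUM_BUCKETS)
            count[t] += 1
        mean_per_timestep = loss_sum / count          # element-wise per-bucket mean
        return float(np.mean(mean_per_timestep))      # average of the per-bucket means

Averaging within each bucket before averaging across buckets keeps the metric stable even when the buckets receive different numbers of samples.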

@susanbao requested a review from coolkp on October 2, 2025 16:45
@coolkp (Collaborator) commented Oct 3, 2025

An idea for optimizing this PR: if I understand correctly, we have a nested loop where we split RNG keys first in the training loop and then again in the loop wrapping the loss function. It would be good to generate the keys upfront and move the key splitting out of the jitted code; since eval_step is a jitted function, splitting inside it also increases compile time. https://stackoverflow.com/a/75339951 . We should vectorize that and move the key generation out. You can give it a shot or create a TODO for me, and I can take a look at it later.
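A rough sketch of what the suggestion could look like, with placeholder names (eval_step, params, batches here are illustrative, not the PR's actual code): pre-split all RNG keys outside the jitted function and pass one key per step in as an argument, so no key splitting is traced inside eval_step.

    import jax
    import jax.numpy as jnp

    @jax.jit
    def eval_step(params, batch, rng):
        # The per-step key arrives as an argument; no jax.random.split inside the jitted code.
        noise = jax.random.normal(rng, batch.shape)
        return jnp.mean((batch + noise - params) ** 2)  # placeholder loss

    def run_eval(params, batches, rng):
        # Generate every per-step key upfront, outside any jitted/traced code.
        keys = jax.random.split(rng, len(batches))
        losses = [eval_step(params, batch, key) for batch, key in zip(batches, keys)]
        return jnp.mean(jnp.stack(losses))

If the batches stack to a uniform shape, the Python loop could further be replaced by jax.vmap over the stacked batches and keys, which is the vectorization mentioned above.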
