-
Notifications
You must be signed in to change notification settings - Fork 3
Open
Description
你好,你好,Jianping Zhou,关于下述问题的咨询:
问题:在计算predict_hidden = forward_hidden.reshape(B, -1, K, L) * (1 - cond_mask.unsqueeze(1)) + reverse_hidden.reshape(B, -1, K, L) * (1 - reverse_cond_mask.unsqueeze(1))时,原始数据为nan的地方的表征加了两次。
请问,如果需要优化这个小问题,我的思路是:1、不能将原始数据为nan的地方的predict_hidden置为0,否则0值也会被作为真实的值参与InfoNCE()损失的计算;2、不能将原始数据为nan的地方的predict_hidden数据直接移除,否则会导致一个batch中每个样本的长度不同;3、将原始数据为nan的地方的predict_hidden置为0后,统一移动到每行数据的最后面,非0元素顺序保持不变,这样不影响余弦相似度的计算,也就不影响InfoNCE()损失的计算。风险点位一个batch中每个样本的有效非零元素数量不同,影响训练过程。
请问,上述思路3的方法,您是怎么看的,是否认同?谢谢。
Metadata
Metadata
Assignees
Labels
No labels