diff --git a/one_file_ref.py b/one_file_ref.py index 61364afb..a32785b4 100644 --- a/one_file_ref.py +++ b/one_file_ref.py @@ -258,7 +258,7 @@ def forward( ) mask = torch.tril(tensor, diagonal=0).to(h.dtype) # make the mask banded to account for sliding window - mask = torch.triu(mask, diagonal=-self.args.sliding_window) + mask = torch.triu(mask, diagonal=-self.args.sliding_window+1) mask = torch.log(mask) for layer in self.layers: @@ -362,4 +362,4 @@ def demo(model_path: str, max_tokens: int = 35): print("=====================") if __name__ == "__main__": - fire.Fire(demo) \ No newline at end of file + fire.Fire(demo)