Would like to add an arg to `calc_transformer_flops.py` that computes the FLOPs needed to run inference on _t_ tokens. Should be as simple as just turning off the bwd pass.
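A rough sketch of the idea, assuming the common approximation that a forward pass costs about 2 FLOPs per parameter per token and the backward pass costs about twice the forward. Under that assumption, training is about 6·N·t and inference (fwd only) is about 2·N·t. The `--infer` flag and the `estimate_flops` helper are hypothetical names for illustration, not existing code in `calc_transformer_flops.py`, and this ignores attention-score FLOPs that grow with sequence length.

```python
import argparse

# Hypothetical sketch, not the actual calc_transformer_flops.py code.
# Approximation: forward ~ 2 * N FLOPs per token, backward ~ 2x forward,
# so training ~ 6 * N * t and inference (forward only) ~ 2 * N * t.
FWD_FLOPS_PER_PARAM_PER_TOKEN = 2


def estimate_flops(num_params: float, num_tokens: float, infer: bool = False) -> float:
    """Approximate total FLOPs for training (fwd + bwd) or inference (fwd only)."""
    # Passes measured in units of one forward pass:
    # training = 1 fwd + 2 (bwd costs ~2x fwd) = 3; inference = fwd only = 1.
    iter_factor = 1 if infer else 3
    return FWD_FLOPS_PER_PARAM_PER_TOKEN * iter_factor * num_params * num_tokens


def main() -> None:
    parser = argparse.ArgumentParser(description="Rough transformer FLOP estimate")
    parser.add_argument("--num-params", type=float, required=True)
    parser.add_argument("--tokens", type=float, required=True, help="number of tokens t")
    # Hypothetical flag: skip the backward pass and count inference FLOPs only.
    parser.add_argument("--infer", action="store_true")
    args = parser.parse_args()
    flops = estimate_flops(args.num_params, args.tokens, args.infer)
    print(f"{flops:.3e} FLOPs")


if __name__ == "__main__":
    main()
```

So "turning off the bwd pass" amounts to dropping the iteration factor from 3 to 1, e.g. `python sketch.py --num-params 1e9 --tokens 1e3 --infer`. If the script also counts activation-checkpointing recompute, that extra forward pass should be dropped for inference too.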