1010import quimb .tensor as qtn
1111
1212import pepsy as py
13+ import pepsy .optimizers .mps .optimizer as mps_optimizer_module
1314
1415
1516def _non_unitary_entangling_gate ():
@@ -386,6 +387,54 @@ def test_mps_optimizer_canonical_span_norm_matches_full_target_norm(where):
386387 assert local_norm == pytest .approx (target .norm ())
387388
388389
390+ def test_mps_optimizer_canonical_span_norm_ignores_stored_exponent ():
391+ """Internal normalization should measure raw data, not represented scale."""
392+ p0 = qtn .MPS_rand_state (4 , bond_dim = 2 , phys_dim = 2 , dtype = "complex128" )
393+ opt = py .MpsOptimizer (p0 .copy (), gates = [], chi = 8 , mode = "svd" )
394+ opt .p .exponent = 3.0
395+
396+ raw = opt .p .copy ()
397+ raw .exponent = 0.0
398+ measured = opt ._canonical_span_norm (opt .p , (0 , 3 )) # pylint: disable=protected-access
399+
400+ assert measured == pytest .approx (raw .norm ())
401+ assert opt .p .exponent == pytest .approx (3.0 )
402+
403+
404+ def test_mps_optimizer_norm_infidelity_uses_tn_norm_strip_exponent (monkeypatch ):
405+ """Norm-infidelity diagnostics should measure raw norms through ``tn_norm``."""
406+ calls = []
407+ original_tn_norm = mps_optimizer_module .tn_norm
408+
409+ def _spy_tn_norm (* args , ** kwargs ):
410+ calls .append (kwargs .copy ())
411+ return original_tn_norm (* args , ** kwargs )
412+
413+ monkeypatch .setattr (mps_optimizer_module , "tn_norm" , _spy_tn_norm )
414+
415+ p0 = qtn .MPS_computational_state ("0000" , dtype = "complex128" )
416+ gates = [
417+ (qu .hadamard (), (0 ,)),
418+ (qu .hadamard (), (1 ,)),
419+ (_non_unitary_entangling_gate (), (0 , 1 )),
420+ ]
421+
422+ opt = py .MpsOptimizer (p0 .copy (), gates = gates , chi = 1 , mode = "mpo" )
423+ opt .run (
424+ progbar = False ,
425+ cutoff = 1e-12 ,
426+ non_unitary = True ,
427+ normalize_final = True ,
428+ track_norm_infidelity = True ,
429+ )
430+
431+ samples = opt .get_norm_infidelity_samples ()
432+ assert len (samples ) == 1
433+ assert len (calls ) >= 2
434+ assert all (call ["strip_exponent" ] is True for call in calls )
435+ assert all (call ["contraction_opt" ] == opt .contraction_opt for call in calls )
436+
437+
389438def test_mps_optimizer_non_unitary_norm_infidelity_matches_svd_target ():
390439 """SVD non-unitary proxy should match quimb's target infidelity."""
391440 p0 = qtn .MPS_computational_state ("0000" , dtype = "complex128" )
@@ -664,9 +713,9 @@ def test_mps_optimizer_dmrg_non_unitary_matches_mpo_accuracy():
664713 )
665714
666715
667- @pytest .mark .parametrize ("mode" , ["dmrg" , "mpo" ])
716+ @pytest .mark .parametrize ("mode" , ["dmrg" , "mpo" , "swap" , "svd" ])
668717def test_mps_optimizer_non_unitary_norm_infidelity_smoke_other_modes (mode ):
669- """Other compressed modes should expose a bounded non-unitary proxy."""
718+ """All compressed modes should expose a bounded non-unitary proxy."""
670719 p0 = qtn .MPS_computational_state ("0000" , dtype = "complex128" )
671720 gates = [
672721 (qu .hadamard (), (0 ,)),
@@ -675,6 +724,8 @@ def test_mps_optimizer_non_unitary_norm_infidelity_smoke_other_modes(mode):
675724 ]
676725
677726 opt = py .MpsOptimizer (p0 .copy (), gates = gates , chi = 1 , mode = mode )
727+ if mode == "swap" and not hasattr (opt .p , "gate_with_auto_swap_" ):
728+ pytest .skip ("swap mode requires gate_with_auto_swap_ in this quimb version." )
678729 opt .run (
679730 progbar = False ,
680731 cutoff = 1e-12 ,
0 commit comments