Skip to content

Commit 94b1a57

Browse files
committed
Add from_tensor method for split transfer
1 parent 832e290 commit 94b1a57

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

varipeps/peps/tensor.py

+138
Original file line numberDiff line numberDiff line change
@@ -1987,6 +1987,144 @@ def __post_init__(self) -> None:
19871987
# "At least one transfer tensors mismatch bond dimensions of PEPS tensor."
19881988
# )
19891989

1990+
@classmethod
1991+
def from_tensor(
1992+
cls: Type[T_PEPS_Tensor],
1993+
tensor: Tensor,
1994+
d: int,
1995+
D: Union[int, Sequence[int]],
1996+
chi: int,
1997+
interlayer_chi: Optional[int] = None,
1998+
max_chi: Optional[int] = None,
1999+
*,
2000+
ctm_tensors_are_identities: bool = True,
2001+
normalize: bool = True,
2002+
seed: Optional[int] = None,
2003+
backend: str = "jax",
2004+
) -> T_PEPS_Tensor:
2005+
"""
2006+
Initialize a PEPS tensor object with a given tensor and new CTM tensors.
2007+
2008+
Args:
2009+
tensor (:obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray`):
2010+
PEPS tensor to initialize the object with
2011+
d (:obj:`int`):
2012+
Physical dimension
2013+
D (:obj:`int` or :term:`sequence` of :obj:`int`):
2014+
Bond dimensions for the PEPS tensor
2015+
chi (:obj:`int`):
2016+
Bond dimension for the environment tensors
2017+
interlayer_chi (:obj:`int`):
2018+
Bond dimension for the interlayer bonds of the environment tensors
2019+
max_chi (:obj:`int`):
2020+
Maximal allowed bond dimension for the environment tensors
2021+
Keyword args:
2022+
ctm_tensors_are_identities (:obj:`bool`, optional):
2023+
Flag if the CTM tensors are initialized as identities. Otherwise,
2024+
they are initialized randomly. Defaults to True.
2025+
normalize (:obj:`bool`, optional):
2026+
Flag if the generated tensors are normalized. Defaults to True.
2027+
seed (:obj:`int`, optional):
2028+
Seed for the random number generator.
2029+
backend (:obj:`str`, optional):
2030+
Backend for the generated tensors (may be ``jax`` or ``numpy``).
2031+
Defaults to ``jax``.
2032+
Returns:
2033+
PEPS_Tensor:
2034+
Instance of PEPS_Tensor with the randomly initialized tensors.
2035+
"""
2036+
if not is_tensor(tensor):
2037+
raise ValueError("Invalid argument for tensor.")
2038+
2039+
if isinstance(D, int):
2040+
D = (D,) * 4
2041+
elif isinstance(D, collections.abc.Sequence) and not isinstance(D, tuple):
2042+
D = tuple(D)
2043+
2044+
if not all(isinstance(i, int) for i in D) or not len(D) == 4:
2045+
raise ValueError("Invalid argument for D.")
2046+
2047+
if (
2048+
tensor.shape[0] != D[0]
2049+
or tensor.shape[1] != D[1]
2050+
or tensor.shape[3] != D[2]
2051+
or tensor.shape[4] != D[3]
2052+
or tensor.shape[2] != d
2053+
):
2054+
raise ValueError("Tensor dimensions mismatch the dimension arguments.")
2055+
2056+
if interlayer_chi is None:
2057+
interlayer_chi = chi
2058+
if max_chi is None:
2059+
max_chi = chi
2060+
2061+
dtype = tensor.dtype
2062+
2063+
if ctm_tensors_are_identities:
2064+
C1 = jnp.ones((1, 1), dtype=dtype)
2065+
C2 = jnp.ones((1, 1), dtype=dtype)
2066+
C3 = jnp.ones((1, 1), dtype=dtype)
2067+
C4 = jnp.ones((1, 1), dtype=dtype)
2068+
2069+
T1 = jnp.eye(D[3], dtype=dtype).reshape(1, D[3], D[3], 1)
2070+
T2 = jnp.eye(D[2], dtype=dtype).reshape(D[2], D[2], 1, 1)
2071+
T3 = jnp.eye(D[1], dtype=dtype).reshape(1, 1, D[1], D[1])
2072+
T4 = jnp.eye(D[0], dtype=dtype).reshape(1, D[0], D[0], 1)
2073+
2074+
return cls(
2075+
tensor=tensor,
2076+
C1=C1,
2077+
C2=C2,
2078+
C3=C3,
2079+
C4=C4,
2080+
T1=T1,
2081+
T2=T2,
2082+
T3=T3,
2083+
T4=T4,
2084+
d=d,
2085+
D=D, # type: ignore
2086+
chi=chi,
2087+
interlayer_chi=interlayer_chi,
2088+
max_chi=max_chi,
2089+
)
2090+
else:
2091+
rng = PEPS_Random_Number_Generator.get_generator(seed, backend=backend)
2092+
2093+
C1 = rng.block((chi, chi), dtype, normalize=normalize)
2094+
C2 = rng.block((chi, chi), dtype, normalize=normalize)
2095+
C3 = rng.block((chi, chi), dtype, normalize=normalize)
2096+
C4 = rng.block((chi, chi), dtype, normalize=normalize)
2097+
2098+
T1_ket = rng.block((chi, D[3], interlayer_chi), dtype, normalize=normalize)
2099+
T1_bra = rng.block((interlayer_chi, D[3], chi), dtype, normalize=normalize)
2100+
T2_ket = rng.block((interlayer_chi, D[2], chi), dtype, normalize=normalize)
2101+
T2_bra = rng.block((chi, D[2], interlayer_chi), dtype, normalize=normalize)
2102+
T3_ket = rng.block((chi, D[1], interlayer_chi), dtype, normalize=normalize)
2103+
T3_bra = rng.block((interlayer_chi, D[1], chi), dtype, normalize=normalize)
2104+
T4_ket = rng.block((interlayer_chi, D[0], chi), dtype, normalize=normalize)
2105+
T4_bra = rng.block((chi, D[0], interlayer_chi), dtype, normalize=normalize)
2106+
2107+
return cls(
2108+
tensor=tensor,
2109+
C1=C1,
2110+
C2=C2,
2111+
C3=C3,
2112+
C4=C4,
2113+
T1_ket=T1_ket,
2114+
T1_bra=T1_bra,
2115+
T2_ket=T2_ket,
2116+
T2_bra=T2_bra,
2117+
T3_ket=T3_ket,
2118+
T3_bra=T3_bra,
2119+
T4_ket=T4_ket,
2120+
T4_bra=T4_bra,
2121+
d=d,
2122+
D=D, # type: ignore
2123+
chi=chi,
2124+
interlayer_chi=interlayer_chi,
2125+
max_chi=max_chi,
2126+
)
2127+
19902128
@property
19912129
def left_upper_transfer_shape(self) -> Tensor:
19922130
return self.T4_ket.shape[2]

0 commit comments

Comments
 (0)