Skip to content

Commit fcf57fd

Browse files
committed
Refactor Kagome mapping to new PESS structure
1 parent 2a3f009 commit fcf57fd

File tree

1 file changed

+275
-74
lines changed

1 file changed

+275
-74
lines changed

peps_ad/mapping/kagome.py

+275-74
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
from dataclasses import dataclass
2+
from os import PathLike
23

34
import jax.numpy as jnp
45
from jax import jit
56

7+
import h5py
8+
9+
import peps_ad.config
610
from peps_ad.peps import PEPS_Tensor, PEPS_Unit_Cell
711
from peps_ad.contractions import apply_contraction
812
from peps_ad.expectation.model import Expectation_Model
913
from peps_ad.expectation.one_site import calc_one_site_multi_gates
1014
from peps_ad.expectation.three_sites import _three_site_triangle_workhorse
1115
from peps_ad.typing import Tensor
16+
from peps_ad.utils.random import PEPS_Random_Number_Generator
17+
from peps_ad.mapping import Map_To_PEPS_Model
1218

13-
from typing import Sequence, Union, List, Callable, TypeVar, Optional
19+
from typing import Sequence, Union, List, Callable, TypeVar, Optional, Tuple, Type
1420

1521
T_float_complex = TypeVar("T_float_complex", float, complex)
22+
T_Kagome_Map_PESS3_To_Single_PEPS_Site = TypeVar(
23+
"T_Kagome_Map_PESS3_To_Single_PEPS_Site",
24+
bound="Kagome_Map_PESS3_To_Single_PEPS_Site",
25+
)
1626

1727

1828
@dataclass
@@ -195,98 +205,289 @@ def _kagome_mapping_workhorse(
195205
return result / jnp.linalg.norm(result)
196206

197207

198-
class iPESS3_Single_PEPS_Site:
208+
@dataclass
209+
class Kagome_Map_PESS3_To_Single_PEPS_Site(Map_To_PEPS_Model):
199210
"""
200-
Map a 3-site Kagome PESS unit cell to PEPS structure using a PEPS unitcell
201-
consisting of one site.
211+
Map a 3-site Kagome iPESS unit cell to a iPEPS structure.
212+
213+
Create a PEPS unitcell from a Kagome 3-PESS structure. To this end, the
214+
two simplex tensor and all three sites are mapped into PEPS sites.
215+
216+
The axes of the simplex tensors are expected to be in the order:
217+
- Up: PESS site 1, PESS site 2, PESS site 3
218+
- Down: PESS site 3, PESS site 2, PESS site 1
219+
220+
The axes of the site tensors are expected to be in the order
221+
`connection to down simplex, physical bond, connection to up simplex`.
222+
223+
The PESS structure is contracted in the way that all site tensors are
224+
connected to the up simplex and the down simplex to site 1.
202225
"""
203226

204-
@staticmethod
205-
def unitcell_from_pess_tensors(
206-
up_simplex: Tensor,
207-
down_simplex: Tensor,
208-
site_1: Tensor,
209-
site_2: Tensor,
210-
site_3: Tensor,
227+
unitcell_structure: Sequence[Sequence[int]]
228+
chi: int
229+
230+
def __call__(
231+
self,
232+
input_tensors: Sequence[jnp.ndarray],
233+
*,
234+
generate_unitcell: bool = True,
235+
) -> Union[List[jnp.ndarray], Tuple[List[jnp.ndarray], PEPS_Unit_Cell]]:
236+
num_peps_sites = len(input_tensors) // 5
237+
if num_peps_sites * 5 != len(input_tensors):
238+
raise ValueError(
239+
"Input tensors seems not be a list for a square Kagome simplex system."
240+
)
241+
242+
peps_tensors = [
243+
_kagome_mapping_workhorse(*(input_tensors[(i * 5) : (i * 5 + 5)]))
244+
for i in range(num_peps_sites)
245+
]
246+
247+
if generate_unitcell:
248+
peps_tensor_objs = [
249+
PEPS_Tensor.from_tensor(
250+
i,
251+
i.shape[2],
252+
(i.shape[0], i.shape[1], i.shape[3], i.shape[4]),
253+
self.chi,
254+
)
255+
for i in peps_tensors
256+
]
257+
unitcell = PEPS_Unit_Cell.from_tensor_list(
258+
peps_tensor_objs, self.unitcell_structure
259+
)
260+
261+
return peps_tensors, unitcell
262+
263+
return peps_tensors
264+
265+
@classmethod
266+
def random(
267+
cls: Type[T_Kagome_Map_PESS3_To_Single_PEPS_Site],
268+
structure: Sequence[Sequence[int]],
211269
d: int,
212270
D: int,
213-
chi: int,
214-
) -> PEPS_Unit_Cell:
215-
"""
216-
Create a PEPS unitcell from a Kagome 3-PESS 3-sites structure. To this
217-
end, the two simplex tensor and all three sites are mapped into a single
218-
PEPS site.
271+
chi: Union[int, Sequence[int]],
272+
dtype: Type[jnp.number],
273+
*,
274+
seed: Optional[int] = None,
275+
destroy_random_state: bool = True,
276+
) -> Tuple[List[jnp.ndarray], T_Kagome_Map_PESS3_To_Single_PEPS_Site]:
277+
structure_arr = jnp.asarray(structure)
219278

220-
The axes of the simplex tensors are expected to be in the order:
221-
- Up: PESS site 1, PESS site 2, PESS site 3
222-
- Down: PESS site 3, PESS site 2, PESS site 1
279+
structure_arr, tensors_i = PEPS_Unit_Cell._check_structure(structure_arr)
223280

224-
The axes site tensors are expected to be in the order
225-
connection to down simplex, physical bond, connection to up simplex.
281+
# Check the inputs
282+
if not isinstance(d, int):
283+
raise ValueError("d has to be a single integer.")
226284

227-
The PESS structure is contracted in the way that all site tensors are
228-
connected to the up simplex and the down simplex to site 1.
285+
if not isinstance(D, int):
286+
raise ValueError("D has to be a single integer.")
287+
288+
if not isinstance(chi, int):
289+
raise ValueError("chi has to be a single integer.")
290+
291+
# Generate the PEPS tensors
292+
if destroy_random_state:
293+
PEPS_Random_Number_Generator.destroy_state()
294+
295+
rng = PEPS_Random_Number_Generator.get_generator(seed, backend="jax")
296+
297+
result_tensors = []
298+
299+
for i in tensors_i:
300+
result_tensors.append(rng.block((D, D, D), dtype=dtype)) # simplex_up
301+
result_tensors.append(rng.block((D, D, D), dtype=dtype)) # simplex_down
302+
result_tensors.append(rng.block((D, d, D), dtype=dtype)) # site1
303+
result_tensors.append(rng.block((D, d, D), dtype=dtype)) # site2
304+
result_tensors.append(rng.block((D, d, D), dtype=dtype)) # site3
305+
306+
return result_tensors, cls(unitcell_structure=structure, chi=chi)
307+
308+
@classmethod
309+
def save_to_file(
310+
cls: Type[T_Kagome_Map_PESS3_To_Single_PEPS_Site],
311+
path: PathLike,
312+
tensors: List[jnp.ndarray],
313+
unitcell: PEPS_Unit_Cell,
314+
*,
315+
store_config: bool = True,
316+
) -> None:
317+
"""
318+
Save unit cell to a HDF5 file.
319+
320+
This function creates a single group "kagome_pess" in the file
321+
and pass this group to the method
322+
:obj:`~Kagome_Map_PESS3_To_Single_PEPS_Site.save_to_group` then.
229323
230324
Args:
231-
up_simplex (:obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray`):
232-
The tensor of the up simplex.
233-
down_simplex (:obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray`):
234-
The tensor of the down simplex.
235-
site_1 (:obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray`):
236-
The tensor of the first PESS site.
237-
site_2 (:obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray`):
238-
The tensor of the second PESS site.
239-
site_3 (:obj:`numpy.ndarray` or :obj:`jax.numpy.ndarray`):
240-
The tensor of the third PESS site.
241-
d (:obj:`int`):
242-
Physical dimension
243-
D (:obj:`int`):
244-
Bond dimension.
245-
chi (:obj:`int`):
246-
Environment bond dimension.
247-
Returns:
248-
~peps_ad.peps.PEPS_Unit_Cell:
249-
PEPS unitcell with the mapped PESS structure and initialized
250-
environment tensors.
325+
path (:obj:`os.PathLike`):
326+
Path of the new file. Caution: The file will overwritten if existing.
327+
store_config (:obj:`bool`):
328+
Store the current values of the global config object into the HDF5
329+
file as attrs of an extra group.
251330
"""
252-
if not isinstance(d, int) or not isinstance(D, int) or not isinstance(chi, int):
253-
raise ValueError("Dimensions have to be integers.")
331+
with h5py.File(path, "w", libver=("earliest", "v110")) as f:
332+
grp = f.create_group("kagome_pess")
333+
334+
cls.save_to_group(grp, tensors, unitcell, store_config=store_config)
254335

255-
if not (site_1.shape[1] == site_2.shape[1] == site_3.shape[1] == d):
336+
@staticmethod
337+
def save_to_group(
338+
grp: h5py.Group,
339+
tensors: List[jnp.ndarray],
340+
unitcell: PEPS_Unit_Cell,
341+
*,
342+
store_config: bool = True,
343+
) -> None:
344+
"""
345+
Save unit cell to a HDF5 group which is be passed to the method.
346+
347+
Args:
348+
grp (:obj:`h5py.Group`):
349+
HDF5 group object to store the data into.
350+
store_config (:obj:`bool`):
351+
Store the current values of the global config object into the HDF5
352+
file as attrs of an extra group.
353+
"""
354+
num_peps_sites = len(tensors) // 5
355+
if num_peps_sites * 5 != len(tensors):
256356
raise ValueError(
257-
"Dimension of site tensor mismatches physical dimension argument."
357+
"Input tensors seems not be a list for a Kagome simplex system."
258358
)
259359

260-
if not (
261-
site_1.shape[0]
262-
== site_1.shape[2]
263-
== site_2.shape[0]
264-
== site_2.shape[2]
265-
== site_3.shape[0]
266-
== site_3.shape[2]
267-
== D
268-
) or not (
269-
up_simplex.shape[0]
270-
== up_simplex.shape[1]
271-
== up_simplex.shape[2]
272-
== down_simplex.shape[0]
273-
== down_simplex.shape[1]
274-
== down_simplex.shape[2]
275-
== D
276-
):
277-
raise ValueError("Dimension of tensor mismatches bond dimension argument.")
360+
grp_pess = grp.create_group("pess_tensors", track_order=True)
361+
grp_pess.attrs["num_peps_sites"] = num_peps_sites
362+
363+
for i in range(num_peps_sites):
364+
(
365+
simplex_up,
366+
simplex_down,
367+
t1,
368+
t2,
369+
t3,
370+
) = tensors[(i * 5) : (i * 5 + 5)]
371+
372+
grp_pess.create_dataset(
373+
f"site{i}_simplex_up",
374+
data=simplex_up,
375+
compression="gzip",
376+
compression_opts=6,
377+
)
378+
grp_pess.create_dataset(
379+
f"site{i}_simplex_down",
380+
data=simplex_down,
381+
compression="gzip",
382+
compression_opts=6,
383+
)
384+
grp_pess.create_dataset(
385+
f"site{i}_t1", data=t1, compression="gzip", compression_opts=6
386+
)
387+
grp_pess.create_dataset(
388+
f"site{i}_t2", data=t2, compression="gzip", compression_opts=6
389+
)
390+
grp_pess.create_dataset(
391+
f"site{i}_t3", data=t3, compression="gzip", compression_opts=6
392+
)
278393

279-
peps_tensor = _kagome_mapping_workhorse(
280-
jnp.asarray(up_simplex),
281-
jnp.asarray(down_simplex),
282-
jnp.asarray(site_1),
283-
jnp.asarray(site_2),
284-
jnp.asarray(site_3),
394+
grp_unitcell = grp.create_group("unitcell")
395+
unitcell.save_to_group(grp_unitcell, store_config=store_config)
396+
397+
@classmethod
398+
def load_from_file(
399+
cls: Type[T_Kagome_Map_PESS3_To_Single_PEPS_Site],
400+
path: PathLike,
401+
*,
402+
return_config: bool = False,
403+
) -> Union[
404+
Tuple[List[jnp.ndarray], PEPS_Unit_Cell],
405+
Tuple[List[jnp.ndarray], PEPS_Unit_Cell, peps_ad.config.PEPS_AD_Config],
406+
]:
407+
"""
408+
Load unit cell from a HDF5 file.
409+
410+
This function read the group "kagome_pess" from the file and pass
411+
this group to the method
412+
:obj:`~Kagome_Map_PESS3_To_Single_PEPS_Site.load_from_group` then.
413+
414+
Args:
415+
path (:obj:`os.PathLike`):
416+
Path of the HDF5 file.
417+
return_config (:obj:`bool`):
418+
Return a config object initialized with the values from the HDF5
419+
files. If no config is stored in the file, just the data is returned.
420+
Missing config flags in the file uses the default values from the
421+
config object.
422+
Returns:
423+
:obj:`tuple`\ (:obj:`list`\ (:obj:`jax.numpy.ndarray`), :obj:`~peps_ad.peps.PEPS_Unit_Cell`) or :obj:`tuple`\ (:obj:`list`\ (:obj:`jax.numpy.ndarray`), :obj:`~peps_ad.peps.PEPS_Unit_Cell`, :obj:`~peps_ad.config.PEPS_AD_Config`):
424+
The tuple with the list of the PESS tensors and the PEPS unitcell
425+
is returned. If ``return_config = True``. the config is returned
426+
as well.
427+
"""
428+
with h5py.File(path, "r") as f:
429+
out = cls.load_from_group(f["kagome_pess"], return_config=return_config)
430+
431+
if return_config:
432+
return out[0], out[1], out[2]
433+
434+
return out[0], out[1]
435+
436+
@staticmethod
437+
def load_from_group(
438+
grp: h5py.Group,
439+
*,
440+
return_config: bool = False,
441+
) -> Union[
442+
Tuple[List[jnp.ndarray], PEPS_Unit_Cell],
443+
Tuple[List[jnp.ndarray], PEPS_Unit_Cell, peps_ad.config.PEPS_AD_Config],
444+
]:
445+
"""
446+
Load the unit cell from a HDF5 group which is be passed to the method.
447+
448+
Args:
449+
grp (:obj:`h5py.Group`):
450+
HDF5 group object to load the data from.
451+
return_config (:obj:`bool`):
452+
Return a config object initialized with the values from the HDF5
453+
files. If no config is stored in the file, just the data is returned.
454+
Missing config flags in the file uses the default values from the
455+
config object.
456+
Returns:
457+
:obj:`tuple`\ (:obj:`list`\ (:obj:`jax.numpy.ndarray`), :obj:`~peps_ad.peps.PEPS_Unit_Cell`) or :obj:`tuple`\ (:obj:`list`\ (:obj:`jax.numpy.ndarray`), :obj:`~peps_ad.peps.PEPS_Unit_Cell`, :obj:`~peps_ad.config.PEPS_AD_Config`):
458+
The tuple with the list of the PESS tensors and the PEPS unitcell
459+
is returned. If ``return_config = True``. the config is returned
460+
as well.
461+
"""
462+
grp_pess = grp["pess_tensors"]
463+
num_peps_sites = grp_pess.attrs["num_peps_sites"]
464+
465+
tensors = []
466+
467+
for i in range(num_peps_sites):
468+
tensors.append(jnp.asarray(grp_pess[f"site{i}_simplex_up"]))
469+
tensors.append(jnp.asarray(grp_pess[f"site{i}_simplex_down"]))
470+
tensors.append(jnp.asarray(grp_pess[f"site{i}_t1"]))
471+
tensors.append(jnp.asarray(grp_pess[f"site{i}_t2"]))
472+
tensors.append(jnp.asarray(grp_pess[f"site{i}_t3"]))
473+
474+
out = PEPS_Unit_Cell.load_from_group(
475+
grp["unitcell"], return_config=return_config
285476
)
286477

287-
peps_tensor_obj = PEPS_Tensor.from_tensor(peps_tensor, d**3, D, chi)
478+
if return_config:
479+
return tensors, out[0], out[1]
288480

289-
return PEPS_Unit_Cell.from_tensor_list((peps_tensor_obj,), ((0,),))
481+
return tensors, out
482+
483+
@classmethod
484+
def autosave_wrapper(
485+
cls: Type[T_Kagome_Map_PESS3_To_Single_PEPS_Site],
486+
filename: PathLike,
487+
tensors: jnp.ndarray,
488+
unitcell: PEPS_Unit_Cell,
489+
) -> None:
490+
cls.save_to_file(filename, tensors, unitcell)
290491

291492

292493
class iPESS3_9Sites_Three_PEPS_Site:

0 commit comments

Comments
 (0)