88from jax import random
99from collections import namedtuple
1010from numpyro .infer .mcmc import MCMCKernel
11- from qgml .data import SpinConfigurationGeneratorBase
1211from tqdm .auto import tqdm
1312
14- def create_isotropic_interaction_matrix (grid_size : int ):
15- """Create an interaction matrix for a 2D isotropic square lattice."""
16- J = jnp .zeros ((grid_size * grid_size , grid_size * grid_size ))
17-
18- for i in range (grid_size ):
19- for j in range (grid_size ):
20- # Spin index in the grid
21- idx = i * grid_size + j
22-
23- # Calculate the indices of the neighbors
24- right_idx = i * grid_size + (j + 1 ) % grid_size
25- left_idx = i * grid_size + (j - 1 ) % grid_size
26- bottom_idx = ((i + 1 ) % grid_size ) * grid_size + j
27- top_idx = ((i - 1 ) % grid_size ) * grid_size + j
28-
29- # Set the interactions, ensuring each pair is only added once
30- J = J .at [idx , right_idx ].set (1 )
31- J = J .at [idx , left_idx ].set (1 )
32- J = J .at [idx , bottom_idx ].set (1 )
33- J = J .at [idx , top_idx ].set (1 )
34- return J
35-
36-
3713@jax .jit
3814def energy (s , J , b , J_sparse = None ):
3915 """Calculate the Ising energy. For sparse Hamiltonians, it is recommneded to supply a list of nonzero indices of
@@ -51,7 +27,6 @@ def energy(s, J, b, J_sparse=None):
5127 else :
5228 return - jnp .einsum ("i,j,ij->" , s , s , J ) / 2.0 - jnp .dot (s , b )
5329
54-
5530def initialize_spins (rng_key , num_spins , num_chains ):
5631 if num_chains == 1 :
5732 spins = random .bernoulli (rng_key , 0.5 , (num_spins ,))
@@ -119,11 +94,19 @@ def mh_step(i, val):
11994 return MHState (spins , rng_key )
12095
12196
122- # Define the Ising model class
123- class IsingSpins (SpinConfigurationGeneratorBase ):
124- """
125- class object used to generate datasets
126- ArgsL
97+ class IsingSpins :
98+ r"""
99+ class object used to generate datasets by sampling an ising distrbution of a specified interaction
100+ matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings
101+ algorithm.
102+
103+ In the case of perfect sampling, a spin configuration s is sampled with probabability
104+ :math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i`
105+ corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued.
106+
107+ The final sampled configurations are converted from a :math:`\pm1` representation to to a binary
108+ representation via x = (s+1)//2.
109+
127110 N (int): Number of spins
128111 J (np.array): interaction matrix
129112 b (np.array): bias terms
@@ -134,14 +117,15 @@ class object used to generate datasets
134117 def __init__ (
135118 self , N : int , J : jnp .array , b : jnp .array , T : float , sparse = False , compute_partition_fn = False
136119 ) -> None :
137- super ().__init__ (N )
120+
121+ self .N = N
138122 self .kernel = MetropolisHastings ()
139123 self .J = J
140124 self .T = T
141125 self .b = b
142126 self .J_sparse = jnp .nonzero (J ) if sparse else None
143127
144- if compute_partition_fn :
128+ if compute_partition_fn :
145129 Z = 0
146130 for i in tqdm (range (2 ** self .N ), desc = "Computing partition function" ):
147131 lattice = (- 1 ) ** jnp .array (jnp .unravel_index (i , [2 ] * self .N ))
@@ -181,22 +165,70 @@ def sample(
181165 J_sparse = self .J_sparse ,
182166 )
183167 samples = mcmc .get_samples ()
184- return samples .reshape ((- 1 , self .N ))
168+ samples .reshape ((- 1 , self .N ))
169+ return (samples + 1 )// 2
170+
171+ def probability (self , x : ndarray ) -> float :
172+ """
173+ compute the probability of a binary configuration x
174+ Args:
175+ x: binary configuration array
176+ Returns:
177+ (float): the probability of sampling x according to the ising distribution
178+ """
179+
180+ if not (hasattr (self , 'Z' )):
181+ raise Exception ('probability requires partition fuction to have been computed' )
185182
186- def probability (self , spin_configuration : ndarray ) -> float :
187183 return (
188- jnp .exp (- energy (spin_configuration , self .J , self .b , self .J_sparse ) / self .T )
184+ jnp .exp (- energy (x , self .J , self .b , self .J_sparse ) / self .T )
189185 / self .Z
190186 )
191187
192- def generate_isometric_ising (
193- num_samples : int = 100 , T : float = 2.5 , grid_size : int = 4
194- ) -> (ndarray , None ):
195- num_spins = grid_size * grid_size
196- num_chains = 2
197- num_steps = 1000
198- J = create_isotropic_interaction_matrix (grid_size )
199- model = IsingSpins (num_spins , J , b = 1.0 , T = T )
200- # Plot the magnetization and energy trajectories for a single T
201- samples = model .sample (num_samples * num_steps , num_chains = num_chains , num_warmup = 10000 , key = 0 )
202- return samples [- num_samples :], None
188+ def generate_ising (N : int ,
189+ num_samples : int ,
190+ J : jnp .array ,
191+ b : jnp .array ,
192+ T : float ,
193+ sparse = False ,
194+ num_chains = 1 ,
195+ thinning = 1 ,
196+ num_warmup = 1000 ,
197+ key = 42 ):
198+ r"""
199+ Generating function for ising datasets.
200+
201+ The dataset is generated by sampling an ising distrbution of a specified interaction
202+ matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings
203+ algorithm.
204+
205+ In the case of perfect sampling, a spin configuration s is sampled with probabability
206+ :math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i`
207+ corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued.
208+
209+ The final sampled configurations are converted from a :math:`\pm1` representation to to a binary
210+ representation via x = (s+1)//2.
211+
212+ Note that in order to use parallelization, the number of avaliable cores has to be specified explicitly
213+ to numpyro. i.e. the line `numpyro.set_host_device_count(num_cores)` should appear before running the
214+ generator, where num_cores is the number of avaliable CPU cores you want to use.
215+
216+ N (int): Number of spins
217+ num_samples (int): total number of samples to generate per chain
218+ J (np.array): interaction matrix of shape (N,N)
219+ b (np.array): bias array of shape (N,)
220+ T (float): temperature
221+ num_chains (int): number of chains, defaults to 1.
222+ thinning (int): how much to thin the sampling. e.g. if thinning = 10 a sample will be drawn after each
223+ 10 steps of mcmc sampling. Larger numbers result in more unbiased samples.
224+ num_warmup (int): number of mcmc 'burn in' steps to perform before collecting any samples.
225+ key (int): random seed used to initialize sampling.
226+ sparse (bool): If true, J is converted to a sparse representation (faster for sparse Hamiltonians)
227+
228+ Returns:
229+ Array of data samples, and Nonetype object (since there are no labels)
230+ """
231+
232+ sampler = IsingSpins (N , J , b , T , sparse = sparse , compute_partition_fn = False )
233+ samples = sampler .sample (num_samples , num_chains = num_chains , thinning = thinning , num_warmup = num_warmup , key = key )
234+ return samples , None
0 commit comments