8
8
from jax import random
9
9
from collections import namedtuple
10
10
from numpyro .infer .mcmc import MCMCKernel
11
- from qgml .data import SpinConfigurationGeneratorBase
12
11
from tqdm .auto import tqdm
13
12
14
- def create_isotropic_interaction_matrix (grid_size : int ):
15
- """Create an interaction matrix for a 2D isotropic square lattice."""
16
- J = jnp .zeros ((grid_size * grid_size , grid_size * grid_size ))
17
-
18
- for i in range (grid_size ):
19
- for j in range (grid_size ):
20
- # Spin index in the grid
21
- idx = i * grid_size + j
22
-
23
- # Calculate the indices of the neighbors
24
- right_idx = i * grid_size + (j + 1 ) % grid_size
25
- left_idx = i * grid_size + (j - 1 ) % grid_size
26
- bottom_idx = ((i + 1 ) % grid_size ) * grid_size + j
27
- top_idx = ((i - 1 ) % grid_size ) * grid_size + j
28
-
29
- # Set the interactions, ensuring each pair is only added once
30
- J = J .at [idx , right_idx ].set (1 )
31
- J = J .at [idx , left_idx ].set (1 )
32
- J = J .at [idx , bottom_idx ].set (1 )
33
- J = J .at [idx , top_idx ].set (1 )
34
- return J
35
-
36
-
37
13
@jax .jit
38
14
def energy (s , J , b , J_sparse = None ):
39
15
"""Calculate the Ising energy. For sparse Hamiltonians, it is recommneded to supply a list of nonzero indices of
@@ -51,7 +27,6 @@ def energy(s, J, b, J_sparse=None):
51
27
else :
52
28
return - jnp .einsum ("i,j,ij->" , s , s , J ) / 2.0 - jnp .dot (s , b )
53
29
54
-
55
30
def initialize_spins (rng_key , num_spins , num_chains ):
56
31
if num_chains == 1 :
57
32
spins = random .bernoulli (rng_key , 0.5 , (num_spins ,))
@@ -119,11 +94,19 @@ def mh_step(i, val):
119
94
return MHState (spins , rng_key )
120
95
121
96
122
- # Define the Ising model class
123
- class IsingSpins (SpinConfigurationGeneratorBase ):
124
- """
125
- class object used to generate datasets
126
- ArgsL
97
+ class IsingSpins :
98
+ r"""
99
+ class object used to generate datasets by sampling an ising distrbution of a specified interaction
100
+ matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings
101
+ algorithm.
102
+
103
+ In the case of perfect sampling, a spin configuration s is sampled with probabability
104
+ :math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i`
105
+ corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued.
106
+
107
+ The final sampled configurations are converted from a :math:`\pm1` representation to to a binary
108
+ representation via x = (s+1)//2.
109
+
127
110
N (int): Number of spins
128
111
J (np.array): interaction matrix
129
112
b (np.array): bias terms
@@ -134,14 +117,15 @@ class object used to generate datasets
134
117
def __init__ (
135
118
self , N : int , J : jnp .array , b : jnp .array , T : float , sparse = False , compute_partition_fn = False
136
119
) -> None :
137
- super ().__init__ (N )
120
+
121
+ self .N = N
138
122
self .kernel = MetropolisHastings ()
139
123
self .J = J
140
124
self .T = T
141
125
self .b = b
142
126
self .J_sparse = jnp .nonzero (J ) if sparse else None
143
127
144
- if compute_partition_fn :
128
+ if compute_partition_fn :
145
129
Z = 0
146
130
for i in tqdm (range (2 ** self .N ), desc = "Computing partition function" ):
147
131
lattice = (- 1 ) ** jnp .array (jnp .unravel_index (i , [2 ] * self .N ))
@@ -181,22 +165,70 @@ def sample(
181
165
J_sparse = self .J_sparse ,
182
166
)
183
167
samples = mcmc .get_samples ()
184
- return samples .reshape ((- 1 , self .N ))
168
+ samples .reshape ((- 1 , self .N ))
169
+ return (samples + 1 )// 2
170
+
171
+ def probability (self , x : ndarray ) -> float :
172
+ """
173
+ compute the probability of a binary configuration x
174
+ Args:
175
+ x: binary configuration array
176
+ Returns:
177
+ (float): the probability of sampling x according to the ising distribution
178
+ """
179
+
180
+ if not (hasattr (self , 'Z' )):
181
+ raise Exception ('probability requires partition fuction to have been computed' )
185
182
186
- def probability (self , spin_configuration : ndarray ) -> float :
187
183
return (
188
- jnp .exp (- energy (spin_configuration , self .J , self .b , self .J_sparse ) / self .T )
184
+ jnp .exp (- energy (x , self .J , self .b , self .J_sparse ) / self .T )
189
185
/ self .Z
190
186
)
191
187
192
- def generate_isometric_ising (
193
- num_samples : int = 100 , T : float = 2.5 , grid_size : int = 4
194
- ) -> (ndarray , None ):
195
- num_spins = grid_size * grid_size
196
- num_chains = 2
197
- num_steps = 1000
198
- J = create_isotropic_interaction_matrix (grid_size )
199
- model = IsingSpins (num_spins , J , b = 1.0 , T = T )
200
- # Plot the magnetization and energy trajectories for a single T
201
- samples = model .sample (num_samples * num_steps , num_chains = num_chains , num_warmup = 10000 , key = 0 )
202
- return samples [- num_samples :], None
188
+ def generate_ising (N : int ,
189
+ num_samples : int ,
190
+ J : jnp .array ,
191
+ b : jnp .array ,
192
+ T : float ,
193
+ sparse = False ,
194
+ num_chains = 1 ,
195
+ thinning = 1 ,
196
+ num_warmup = 1000 ,
197
+ key = 42 ):
198
+ r"""
199
+ Generating function for ising datasets.
200
+
201
+ The dataset is generated by sampling an ising distrbution of a specified interaction
202
+ matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings
203
+ algorithm.
204
+
205
+ In the case of perfect sampling, a spin configuration s is sampled with probabability
206
+ :math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i`
207
+ corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued.
208
+
209
+ The final sampled configurations are converted from a :math:`\pm1` representation to to a binary
210
+ representation via x = (s+1)//2.
211
+
212
+ Note that in order to use parallelization, the number of avaliable cores has to be specified explicitly
213
+ to numpyro. i.e. the line `numpyro.set_host_device_count(num_cores)` should appear before running the
214
+ generator, where num_cores is the number of avaliable CPU cores you want to use.
215
+
216
+ N (int): Number of spins
217
+ num_samples (int): total number of samples to generate per chain
218
+ J (np.array): interaction matrix of shape (N,N)
219
+ b (np.array): bias array of shape (N,)
220
+ T (float): temperature
221
+ num_chains (int): number of chains, defaults to 1.
222
+ thinning (int): how much to thin the sampling. e.g. if thinning = 10 a sample will be drawn after each
223
+ 10 steps of mcmc sampling. Larger numbers result in more unbiased samples.
224
+ num_warmup (int): number of mcmc 'burn in' steps to perform before collecting any samples.
225
+ key (int): random seed used to initialize sampling.
226
+ sparse (bool): If true, J is converted to a sparse representation (faster for sparse Hamiltonians)
227
+
228
+ Returns:
229
+ Array of data samples, and Nonetype object (since there are no labels)
230
+ """
231
+
232
+ sampler = IsingSpins (N , J , b , T , sparse = sparse , compute_partition_fn = False )
233
+ samples = sampler .sample (num_samples , num_chains = num_chains , thinning = thinning , num_warmup = num_warmup , key = key )
234
+ return samples , None
0 commit comments