Skip to content

Commit a43c901

Browse files
committed
clean up
1 parent c81b0b1 commit a43c901

File tree

3 files changed

+139
-61
lines changed

3 files changed

+139
-61
lines changed

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ pyyaml~=6.0
1010
pennyLane~=0.34
1111
scipy~=1.11
1212
pandas~=2.2
13+
numpyro~=0.14.0

src/qml_benchmarks/data/ising.py

+78-46
Original file line numberDiff line numberDiff line change
@@ -8,32 +8,8 @@
88
from jax import random
99
from collections import namedtuple
1010
from numpyro.infer.mcmc import MCMCKernel
11-
from qgml.data import SpinConfigurationGeneratorBase
1211
from tqdm.auto import tqdm
1312

14-
def create_isotropic_interaction_matrix(grid_size: int):
15-
"""Create an interaction matrix for a 2D isotropic square lattice."""
16-
J = jnp.zeros((grid_size * grid_size, grid_size * grid_size))
17-
18-
for i in range(grid_size):
19-
for j in range(grid_size):
20-
# Spin index in the grid
21-
idx = i * grid_size + j
22-
23-
# Calculate the indices of the neighbors
24-
right_idx = i * grid_size + (j + 1) % grid_size
25-
left_idx = i * grid_size + (j - 1) % grid_size
26-
bottom_idx = ((i + 1) % grid_size) * grid_size + j
27-
top_idx = ((i - 1) % grid_size) * grid_size + j
28-
29-
# Set the interactions, ensuring each pair is only added once
30-
J = J.at[idx, right_idx].set(1)
31-
J = J.at[idx, left_idx].set(1)
32-
J = J.at[idx, bottom_idx].set(1)
33-
J = J.at[idx, top_idx].set(1)
34-
return J
35-
36-
3713
@jax.jit
3814
def energy(s, J, b, J_sparse=None):
3915
"""Calculate the Ising energy. For sparse Hamiltonians, it is recommneded to supply a list of nonzero indices of
@@ -51,7 +27,6 @@ def energy(s, J, b, J_sparse=None):
5127
else:
5228
return -jnp.einsum("i,j,ij->", s, s, J) / 2.0 - jnp.dot(s, b)
5329

54-
5530
def initialize_spins(rng_key, num_spins, num_chains):
5631
if num_chains == 1:
5732
spins = random.bernoulli(rng_key, 0.5, (num_spins,))
@@ -119,11 +94,19 @@ def mh_step(i, val):
11994
return MHState(spins, rng_key)
12095

12196

122-
# Define the Ising model class
123-
class IsingSpins(SpinConfigurationGeneratorBase):
124-
"""
125-
class object used to generate datasets
126-
ArgsL
97+
class IsingSpins:
98+
r"""
99+
class object used to generate datasets by sampling an ising distrbution of a specified interaction
100+
matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings
101+
algorithm.
102+
103+
In the case of perfect sampling, a spin configuration s is sampled with probabability
104+
:math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i`
105+
corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued.
106+
107+
The final sampled configurations are converted from a :math:`\pm1` representation to to a binary
108+
representation via x = (s+1)//2.
109+
127110
N (int): Number of spins
128111
J (np.array): interaction matrix
129112
b (np.array): bias terms
@@ -134,14 +117,15 @@ class object used to generate datasets
134117
def __init__(
135118
self, N: int, J: jnp.array, b: jnp.array, T: float, sparse=False, compute_partition_fn=False
136119
) -> None:
137-
super().__init__(N)
120+
121+
self.N = N
138122
self.kernel = MetropolisHastings()
139123
self.J = J
140124
self.T = T
141125
self.b = b
142126
self.J_sparse = jnp.nonzero(J) if sparse else None
143127

144-
if compute_partition_fn:
128+
if compute_partition_fn:
145129
Z = 0
146130
for i in tqdm(range(2**self.N), desc="Computing partition function"):
147131
lattice = (-1) ** jnp.array(jnp.unravel_index(i, [2] * self.N))
@@ -181,22 +165,70 @@ def sample(
181165
J_sparse=self.J_sparse,
182166
)
183167
samples = mcmc.get_samples()
184-
return samples.reshape((-1, self.N))
168+
samples.reshape((-1, self.N))
169+
return (samples+1)//2
170+
171+
def probability(self, x: ndarray) -> float:
172+
"""
173+
compute the probability of a binary configuration x
174+
Args:
175+
x: binary configuration array
176+
Returns:
177+
(float): the probability of sampling x according to the ising distribution
178+
"""
179+
180+
if not(hasattr(self, 'Z')):
181+
raise Exception('probability requires partition fuction to have been computed')
185182

186-
def probability(self, spin_configuration: ndarray) -> float:
187183
return (
188-
jnp.exp(-energy(spin_configuration, self.J, self.b, self.J_sparse) / self.T)
184+
jnp.exp(-energy(x, self.J, self.b, self.J_sparse) / self.T)
189185
/ self.Z
190186
)
191187

192-
def generate_isometric_ising(
193-
num_samples: int = 100, T: float = 2.5, grid_size: int = 4
194-
) -> (ndarray, None):
195-
num_spins = grid_size * grid_size
196-
num_chains = 2
197-
num_steps = 1000
198-
J = create_isotropic_interaction_matrix(grid_size)
199-
model = IsingSpins(num_spins, J, b=1.0, T=T)
200-
# Plot the magnetization and energy trajectories for a single T
201-
samples = model.sample(num_samples*num_steps, num_chains=num_chains, num_warmup=10000, key=0)
202-
return samples[-num_samples:], None
188+
def generate_ising(N: int,
189+
num_samples: int,
190+
J: jnp.array,
191+
b: jnp.array,
192+
T: float,
193+
sparse=False,
194+
num_chains=1,
195+
thinning=1,
196+
num_warmup=1000,
197+
key=42):
198+
r"""
199+
Generating function for ising datasets.
200+
201+
The dataset is generated by sampling an ising distrbution of a specified interaction
202+
matrix. The distribution is sampled via markov chain Monte Carlo via the Metrolopis Hastings
203+
algorithm.
204+
205+
In the case of perfect sampling, a spin configuration s is sampled with probabability
206+
:math:`p(s)=exp(-H(s)/T)`, where the energy :math:`H(s)=\sum_{i\neq j}s_i s_i J_{ij}+\sum_i b_i s_i`
207+
corresponds to an ising Hamiltonian and configurations s are :math:`\pm1` valued.
208+
209+
The final sampled configurations are converted from a :math:`\pm1` representation to to a binary
210+
representation via x = (s+1)//2.
211+
212+
Note that in order to use parallelization, the number of avaliable cores has to be specified explicitly
213+
to numpyro. i.e. the line `numpyro.set_host_device_count(num_cores)` should appear before running the
214+
generator, where num_cores is the number of avaliable CPU cores you want to use.
215+
216+
N (int): Number of spins
217+
num_samples (int): total number of samples to generate per chain
218+
J (np.array): interaction matrix of shape (N,N)
219+
b (np.array): bias array of shape (N,)
220+
T (float): temperature
221+
num_chains (int): number of chains, defaults to 1.
222+
thinning (int): how much to thin the sampling. e.g. if thinning = 10 a sample will be drawn after each
223+
10 steps of mcmc sampling. Larger numbers result in more unbiased samples.
224+
num_warmup (int): number of mcmc 'burn in' steps to perform before collecting any samples.
225+
key (int): random seed used to initialize sampling.
226+
sparse (bool): If true, J is converted to a sparse representation (faster for sparse Hamiltonians)
227+
228+
Returns:
229+
Array of data samples, and Nonetype object (since there are no labels)
230+
"""
231+
232+
sampler = IsingSpins(N, J, b, T, sparse=sparse, compute_partition_fn=False)
233+
samples = sampler.sample(num_samples, num_chains=num_chains, thinning=thinning, num_warmup=num_warmup, key=key)
234+
return samples, None

src/qml_benchmarks/data/spin_blobs.py

+60-15
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,23 @@
1616

1717
import numpy as np
1818

19-
2019
class RandomSpinBlobs:
21-
"""Generate spin configurations with high probabilites for certain spins.
22-
23-
The dataset is generated by creating random spin samples close to a few
24-
chosen `peak_spin` configurations of dimension `N` with each spin having
25-
the possible values 0 or 1. We can vary the `peak_probabilities` parameter
26-
to create data with different modes, where some samples will have higher
27-
probabilities allowing us to study the effects of imbalance in the data.
20+
"""
21+
Class object used to generate spin blob datasets: a binary analog of the
22+
'gaussian blobs' dataset, in which bitstrings are sampled close in Hamming
23+
distance to a set of specified configurations.
2824
29-
Samples are generated by selecting one of the peak spin configurations
30-
distributed according `peak_probabilities`, and then by flipping some of the
31-
spins. The number of spins that are flipped each time, is drawn from a
32-
Binomial distribution bin(`N`, `p`) where `p=1` will flip all the spins
33-
and `p=0` will not flip any spins therefore creating very narrow distributions
34-
around the peak spins.
25+
The dataset is generated by specifying a list of configurations (peak spins)
26+
that mark the centre of the 'blobs'. Data points are sampled by chosing one of
27+
the peak spins (with probabilities specified by peak probabilities), and then
28+
flipping some of the bits. Each bit is flipped with probability specified by
29+
p, so that (for small p) datapoints are close in Hamming distance to one of
30+
the peak probabilities.
3531
3632
Args:
3733
N (int): The number of spins.
3834
num_blobs (int):
39-
The number of blobs or peak probabilities.
35+
The number of blobs.
4036
peak_probabilities (list[float], optional):
4137
The probability of each spin to be selected. If not specified,
4238
the probabilities are distributed uniformly.
@@ -56,6 +52,7 @@ def __init__(
5652
peak_spins: list[np.array] = None,
5753
p: float = 0.01,
5854
) -> None:
55+
5956
self.N = N
6057
self.num_blobs = num_blobs
6158

@@ -122,6 +119,54 @@ def sample(self, num_samples: int, return_labels=False) -> np.array:
122119
else:
123120
return samples
124121

122+
def generate_spin_blobs(N: int, num_blobs: int, num_samples:int, peak_probabilities: list[float] = None, peak_spins: list[np.array] = None,
123+
p: float = 0.01):
124+
125+
"""
126+
Generator function for spin blob datasets: a binary analog of the
127+
'gaussian blobs' dataset, in which bitstrings are sampled close in Hamming
128+
distance to a set of specified configurations.
129+
130+
The dataset is generated by specifying a list of configurations (peak spins)
131+
that mark the centre of the 'blobs'. Data points are sampled by chosing one of
132+
the peak spins (with probabilities specified by peak probabilities), and then
133+
flipping some of the bits. Each bit is flipped with probability specified by
134+
p, so that (for small p) datapoints are close in Hamming distance to one of
135+
the peak probabilities.
136+
137+
Args:
138+
N (int): The number of spins.
139+
num_blobs (int):
140+
The number of blobs.
141+
num_samples (int): The number of samples to generate.
142+
peak_probabilities (list[float], optional):
143+
The probability of each spin to be selected. If not specified,
144+
the probabilities are distributed uniformly.
145+
peak_spins (list[np.array], optional):
146+
The peak spin configurations. Selected randomly by default.
147+
p (float, optional):
148+
The value of the parameter `p` in a Binomial distribution specifying
149+
the number of spins that are flipped each time during sampling.
150+
Defaults to 0.01.
151+
152+
Returns:
153+
tuple(np.ndarray): Dataset array and label array specifying the peak spin
154+
that was used to sample each datapoint.
155+
"""
156+
157+
sampler = RandomSpinBlobs(
158+
N=N,
159+
num_blobs=num_blobs,
160+
peak_probabilities=peak_probabilities,
161+
peak_spins=peak_spins,
162+
p=p,
163+
)
164+
165+
X, y = sampler.sample(num_samples=num_samples, return_labels=True)
166+
X = X.reshape(-1, N)
167+
168+
return X, y
169+
125170

126171
def generate_8blobs(
127172
num_samples: int,

0 commit comments

Comments
 (0)