Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions makani/models/networks/pangu.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,7 @@ def forward(self, x: torch.Tensor, mask=None):
x: input features with shape of (B * num_lon, num_pl*num_lat, N, C)
mask: (0/-inf) mask with shape of (num_lon, num_pl*num_lat, Wpl*Wlat*Wlon, Wpl*Wlat*Wlon)
"""

B_, nW_, N, C = x.shape
qkv = (
self.qkv(x)
Expand All @@ -478,18 +478,18 @@ def forward(self, x: torch.Tensor, mask=None):
attn = self.attn_drop_fn(attn)

x = self.apply_attention(attn, v, B_, nW_, N, C)

else:
if mask is not None:
bias = mask.unsqueeze(1).unsqueeze(0) + earth_position_bias.unsqueeze(0).unsqueeze(0)
# squeeze the bias if needed in dim 2
#bias = bias.squeeze(2)
else:
bias = earth_position_bias.unsqueeze(0)

# extract batch size for q,k,v
nLon = self.num_lon
q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4])
q = q.view(B_ // nLon, nLon, q.shape[1], q.shape[2], q.shape[3], q.shape[4])
k = k.view(B_ // nLon, nLon, k.shape[1], k.shape[2], k.shape[3], k.shape[4])
v = v.view(B_ // nLon, nLon, v.shape[1], v.shape[2], v.shape[3], v.shape[4])
####
Expand Down Expand Up @@ -736,7 +736,7 @@ class Pangu(nn.Module):
- https://arxiv.org/abs/2211.02556
"""

def __init__(self,
def __init__(self,
inp_shape=(721,1440),
out_shape=(721,1440),
grid_in="equiangular",
Expand Down Expand Up @@ -773,14 +773,14 @@ def __init__(self,
self.checkpointing_level = checkpointing_level

drop_path = np.linspace(0, drop_path_rate, 8).tolist()

# Add static channels to surface
self.num_aux = len(self.aux_channel_names)
N_total_surface = self.num_aux + self.num_surface

# compute static permutations to extract
self._precompute_channel_groups(self.channel_names, self.aux_channel_names)

# Patch embeddings are 2D or 3D convolutions, mapping the data to the required patches
self.patchembed2d = PatchEmbed2D(
img_size=self.inp_shape,
Expand All @@ -791,7 +791,7 @@ def __init__(self,
flatten=False,
norm_layer=None,
)

self.patchembed3d = PatchEmbed3D(
img_size=(num_levels, self.inp_shape[0], self.inp_shape[1]),
patch_size=patch_size,
Expand Down Expand Up @@ -870,7 +870,7 @@ def __init__(self,
self.patchrecovery3d = PatchRecovery3D(
(num_levels, self.inp_shape[0], self.inp_shape[1]), patch_size, 2 * embed_dim, num_atmospheric
)

def _precompute_channel_groups(
self,
channel_names=[],
Expand Down Expand Up @@ -901,7 +901,7 @@ def _precompute_channel_groups(

def prepare_input(self, input):
"""
Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric,
Prepares the input tensor for the Pangu model by splitting it into surface * static variables and atmospheric,
and reshaping the atmospheric variables into the required format.
"""

Expand Down Expand Up @@ -932,13 +932,13 @@ def prepare_output(self, output_surface, output_atmospheric):
level_dict = {level: [idx for idx, value in enumerate(self.channel_names) if value[1:] == level] for level in levels}
reordered_ids = [idx for level in levels for idx in level_dict[level]]
check_reorder = [f'{level}_{idx}' for level in levels for idx in level_dict[level]]

# Flatten & reorder the output atmospheric to original order (doublechecked that this is working correctly!)
flattened_atmospheric = output_atmospheric.reshape(output_atmospheric.shape[0], -1, output_atmospheric.shape[3], output_atmospheric.shape[4])
reordered_atmospheric = torch.cat([torch.zeros_like(output_surface), torch.zeros_like(flattened_atmospheric)], dim=1)
for i in range(len(reordered_ids)):
reordered_atmospheric[:, reordered_ids[i], :, :] = flattened_atmospheric[:, i, :, :]

# Append the surface output, this has not been reordered.
if output_surface is not None:
_, surf_chans, _, _ = features.get_channel_groups(self.channel_names, self.aux_channel_names)
Expand All @@ -948,7 +948,7 @@ def prepare_output(self, output_surface, output_atmospheric):
output = reordered_atmospheric

return output

def forward(self, input):

# Prep the input by splitting into surface and atmospheric variables
Expand All @@ -959,7 +959,7 @@ def forward(self, input):
surface = checkpoint(self.patchembed2d, surface_aux, use_reentrant=False)
atmospheric = checkpoint(self.patchembed3d, atmospheric, use_reentrant=False)
else:
surface = self.patchembed2d(surface_aux)
surface = self.patchembed2d(surface_aux)
atmospheric = self.patchembed3d(atmospheric)

if surface.shape[1] == 0:
Expand Down Expand Up @@ -1011,11 +1011,5 @@ def forward(self, input):
output_atmospheric = self.patchrecovery3d(output_atmospheric)

output = self.prepare_output(output_surface, output_atmospheric)

return output






return output
12 changes: 6 additions & 6 deletions makani/models/networks/pangu_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class PanguOnnx(OnnxWrapper):
channel_order_PL: List containing the names of the pressure levels with the ordering that the ONNX model expects
onnx_file: Path to the ONNX file containing the model
'''
def __init__(self,
def __init__(self,
channel_names=[],
aux_channel_names=[],
onnx_file=None,
Expand Down Expand Up @@ -78,12 +78,12 @@ def prepare_input(self, input):
B,V,Lat,Long=input.shape

if B>1:
raise NotImplementedError("Not implemented yet for batch size greater than 1")
raise NotImplementedError("Not implemented yet for batch size greater than 1")

input=input.squeeze(0)
surface_aux_inp=input[self.surf_channels]
atmospheric_inp=input[self.atmo_channels].reshape(self.n_atmo_groups,self.n_atmo_chans,Lat,Long).transpose(1,0)

return surface_aux_inp, atmospheric_inp

def prepare_output(self, output_surface, output_atmospheric):
Expand All @@ -99,15 +99,15 @@ def prepare_output(self, output_surface, output_atmospheric):

return output.unsqueeze(0)


    def forward(self, input):
        """
        Run one forward pass through the wrapped ONNX model.

        Splits `input` into surface and atmospheric tensors, feeds them to the
        ONNX session, and reassembles the two outputs into a single tensor.
        """

        # Split the flat channel tensor into the surface and atmospheric groups
        # in the layout the ONNX graph expects (see prepare_input).
        surface, atmospheric = self.prepare_input(input)

        # NOTE(review): onnx_session_run presumably wraps an ONNX Runtime session
        # inherited from OnnxWrapper; the keys 'input'/'input_surface' must match
        # the graph's declared input names — confirm against the exported model.
        output,output_surface=self.onnx_session_run({'input':atmospheric,'input_surface':surface})

        # Recombine the surface and atmospheric outputs back into the original
        # channel ordering (see prepare_output).
        output = self.prepare_output(output_surface, output)

        return output
158 changes: 121 additions & 37 deletions makani/models/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,54 @@
from physicsnemo.distributed.utils import split_tensor_along_dim, compute_split_shapes


class BaseNoiseS2(nn.Module):
class BaseNoise(nn.Module):
    """
    Base class for noise modules.

    Provides seeded CPU/GPU random-number generators, RNG-state
    (de)serialization, and generic handling of an optional ``state`` buffer
    that stateful derived classes register (with shape ``(batch, ...)``).
    Stateless derived classes simply never register a ``state`` buffer and
    the state-related methods become no-ops.
    """

    def __init__(self, seed=333, **kwargs):
        super().__init__()
        self.set_rng(seed=seed)

    def set_rng(self, seed=333):
        """Create fresh generators so noise is reproducible independently of the global RNG."""
        self.rng_cpu = torch.Generator(device=torch.device("cpu"))
        self.rng_cpu.manual_seed(seed)
        if torch.cuda.is_available():
            # one generator on this process's local device; comm supplies the local rank
            self.rng_gpu = torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}"))
            self.rng_gpu.manual_seed(seed)

    def reset(self, batch_size=None):
        """
        Zero the internal state, optionally re-allocating it for a new batch size.

        If ``batch_size`` differs from the current leading dimension of the
        registered ``state`` buffer, the buffer is re-allocated with the same
        trailing dimensions, dtype and device (mirroring the reset logic of the
        derived noise classes). No-op when no state was registered.
        """
        if hasattr(self, "state") and self.state is not None:
            if (batch_size is not None) and (batch_size != self.state.shape[0]):
                # assigning a tensor to a registered-buffer name replaces the buffer entry
                self.state = torch.zeros((batch_size, *self.state.shape[1:]), dtype=self.state.dtype, device=self.state.device)

            with torch.no_grad():
                self.state.fill_(0.0)

    def set_rng_state(self, cpu_state, gpu_state):
        """Restore generator states previously obtained from :meth:`get_rng_state`."""
        if cpu_state is not None:
            self.rng_cpu.set_state(cpu_state)
        if torch.cuda.is_available() and (gpu_state is not None):
            self.rng_gpu.set_state(gpu_state)

    def get_rng_state(self):
        """Return ``(cpu_state, gpu_state)``; ``gpu_state`` is None when CUDA is unavailable."""
        cpu_state = self.rng_cpu.get_state()
        gpu_state = None
        if torch.cuda.is_available():
            gpu_state = self.rng_gpu.get_state()
        return cpu_state, gpu_state

    def get_tensor_state(self):
        """Return a detached copy of the state buffer, or None for stateless classes."""
        if hasattr(self, "state"):
            return self.state.detach().clone()
        return None

    def set_tensor_state(self, newstate):
        """Copy ``newstate`` into the state buffer in place (no-op for stateless classes)."""
        if hasattr(self, "state"):
            with torch.no_grad():
                self.state.copy_(newstate)


class BaseNoiseS2(BaseNoise):
def __init__(
self,
img_shape,
Expand All @@ -43,7 +90,7 @@ def __init__(
Abstract base class for noise on the sphere. Initializes the inverse SHT needed by many of the
noise classes. Derived noise classes can be stateful or stateless.
"""
super().__init__()
super().__init__(seed=seed)

# Number of latitudinal modes.
self.nlat, self.nlon = img_shape
Expand Down Expand Up @@ -72,22 +119,12 @@ def __init__(
self.lmax = self.isht.lmax
self.mmax = self.isht.mmax

# generator objects:
self.set_rng(seed=seed)

# store the noise state: initialize to None
self.register_buffer("state", torch.zeros((batch_size, self.num_time_steps, self.num_channels, self.lmax_local, self.mmax_local, 2), dtype=torch.float32), persistent=False)

def is_stateful(self):
raise NotImplementedError("is_stateful method not implemented for this noise class")

def set_rng(self, seed=333):
self.rng_cpu = torch.Generator(device=torch.device("cpu"))
self.rng_cpu.manual_seed(seed)
if torch.cuda.is_available():
self.rng_gpu = torch.Generator(device=torch.device(f"cuda:{comm.get_local_rank()}"))
self.rng_gpu.manual_seed(seed)

# Resets the internal state. Can be used to change the batch size if required.
def reset(self, batch_size=None):
if self.state is not None:
Expand All @@ -100,7 +137,7 @@ def reset(self, batch_size=None):

# this routine generates a noise sample for a single time step and updates the state accordingly, by appending the last time step
def update(self, replace_state=False, batch_size=None):
# Update should always create a new state, so
# Update should always create a new state, so
# we don't need to check for replace_state
# create single occurence
with torch.no_grad():
Expand All @@ -122,30 +159,6 @@ def update(self, replace_state=False, batch_size=None):

return

def set_rng_state(self, cpu_state, gpu_state):
if cpu_state is not None:
self.rng_cpu.set_state(cpu_state)
if torch.cuda.is_available() and (gpu_state is not None):
self.rng_gpu.set_state(gpu_state)

return

def get_rng_state(self):
cpu_state = self.rng_cpu.get_state()
gpu_state = None
if torch.cuda.is_available():
gpu_state = self.rng_gpu.get_state()

return cpu_state, gpu_state

def get_tensor_state(self):
return self.state.detach().clone()

def set_tensor_state(self, newstate):
with torch.no_grad():
self.state.copy_(newstate)
return


class IsotropicGaussianRandomFieldS2(BaseNoiseS2):
def __init__(
Expand Down Expand Up @@ -518,3 +531,74 @@ def forward(self, update_internal_state=False):
self.update()

return state

class GaussianVectorNoise(BaseNoise):
    def __init__(
        self,
        img_shape,
        batch_size,
        num_channels,
        num_time_steps=1,
        sigma=1.0,
        seed=333,
        **kwargs,
    ):
        r"""
        Gaussian noise vector in R^d.

        Parameters
        ============
        img_shape : (int, int)
            Ignored, kept for compatibility.
        batch_size: int
            Batch size for the noise
        num_channels: int
            Number of channels (dimension of the vector)
        num_time_steps: int
            Number of time steps
        sigma : float, default is 1.0
            Standard deviation
        """
        super().__init__(seed=seed)

        self.num_channels = num_channels
        self.num_time_steps = num_time_steps
        self.sigma = sigma

        # noise buffer of shape (B, T, C, 1, 1), broadcastable over spatial dims
        self.register_buffer("state", torch.zeros((batch_size, self.num_time_steps, self.num_channels, 1, 1), dtype=torch.float32), persistent=False)

    def is_stateful(self):
        # samples are drawn i.i.d.; nothing carries over between time steps
        return False

    def reset(self, batch_size=None):
        """Zero the buffer, re-allocating it first when a new batch size is requested."""
        if self.state is not None:
            if batch_size is not None:
                self.state = torch.zeros(batch_size, self.num_time_steps, self.num_channels, 1, 1, dtype=self.state.dtype, device=self.state.device)
            with torch.no_grad():
                self.state.fill_(0.0)

    def update(self, replace_state=False, batch_size=None):
        """Draw a fresh Gaussian sample into the state buffer."""
        with torch.no_grad():
            bsz = self.state.shape[0] if batch_size is None else batch_size

            # draw from the generator matching the buffer's device
            gen = self.rng_gpu if self.state.is_cuda else self.rng_cpu
            sample = torch.empty(
                (bsz, self.num_time_steps, self.num_channels, 1, 1),
                dtype=self.state.dtype,
                device=self.state.device,
            )
            sample.normal_(mean=0.0, std=self.sigma, generator=gen)

            if sample.shape == self.state.shape:
                self.state.copy_(sample)
            else:
                # batch size changed: swap in the newly shaped buffer
                self.state = sample
        return

    def forward(self, update_internal_state=False):
        """Return a copy of the current noise sample, optionally refreshing it first."""
        if update_internal_state:
            self.update()

        return self.state.clone()
Loading