diff --git a/src/weather_model_graphs/spherical_utils.py b/src/weather_model_graphs/spherical_utils.py new file mode 100644 index 0000000..ca55905 --- /dev/null +++ b/src/weather_model_graphs/spherical_utils.py @@ -0,0 +1,18 @@ +import torch + +def lat_lon_to_cartesian(lat: torch.Tensor, lon: torch.Tensor, radius: float = 1.0) -> torch.Tensor: + """ + Vectorized conversion from Latitude/Longitude (degrees) to 3D Cartesian coordinates. + Assumes lat is in [-90, 90] and lon is in [-180, 180]. + """ + # Convert degrees to radians + lat_rad = torch.deg2rad(lat) + lon_rad = torch.deg2rad(lon) + + # Calculate x, y, z components + x = radius * torch.cos(lat_rad) * torch.cos(lon_rad) + y = radius * torch.cos(lat_rad) * torch.sin(lon_rad) + z = radius * torch.sin(lat_rad) + + # Stack into a single tensor of shape (..., 3) + return torch.stack([x, y, z], dim=-1) \ No newline at end of file diff --git a/tests/test_spherical_utils.py b/tests/test_spherical_utils.py new file mode 100644 index 0000000..a621372 --- /dev/null +++ b/tests/test_spherical_utils.py @@ -0,0 +1,26 @@ +import torch +# We import from the new file you just created +from weather_model_graphs.spherical_utils import lat_lon_to_cartesian + +def test_north_pole_singularity(): + """Test that different longitudes at the North Pole converge to the same 3D point.""" + lat = torch.tensor([90.0, 90.0]) + lon = torch.tensor([0.0, 180.0]) + + coords = lat_lon_to_cartesian(lat, lon) + distance = torch.norm(coords[0] - coords[1]) + + # The physical distance between them in 3D space should be exactly 0 + assert torch.isclose(distance, torch.tensor(0.0), atol=1e-6), "Distance at the pole must be zero." + +def test_anti_meridian_crossing(): + """Test that points across the Date Line calculate physical distance correctly.""" + lat = torch.tensor([0.0, 0.0]) + lon = torch.tensor([179.0, -179.0]) + + coords = lat_lon_to_cartesian(lat, lon) + distance = torch.norm(coords[0] - coords[1]) + + # 2 degrees apart on a unit sphere (chord length) + expected_distance = 2 * torch.sin(torch.tensor(torch.pi / 180.0)) + assert torch.isclose(distance, expected_distance, atol=1e-5), "Anti-meridian distance calculation failed." \ No newline at end of file