Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Feb 12, 2024
1 parent 532782a commit dd23a63
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 24 deletions.
53 changes: 33 additions & 20 deletions phi/geom/_proximity_graph.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
from typing import Dict, Tuple, Any, Union

from phiml.math import is_sparse, tensor_like
from ._geom import Geometry
from ._sphere import sphere_radius_with_same_volume
from .. import math
from ..math import Tensor, pairwise_distances, vec_length, Shape, non_channel, dual, where, PI


class ProximityGraph(Geometry):

def __init__(self, nodes: Geometry, boundary: Dict[str, Dict[str, slice]], kernel: str, format='dense'):
def __init__(self, nodes: Geometry, boundary: Dict[str, Dict[str, slice]], kernel: str, target_num_neighbors: Tensor = None, format='dense'):
assert isinstance(nodes, Geometry), f"nodes must be a Geometry instance but got {type(nodes)}"
self._nodes = nodes
self._boundary = boundary
self._kernel = kernel
self._target_num_neighbors = default_target_num_neighbors(kernel) if target_num_neighbors is None else target_num_neighbors
self._format = format
self._boundary = boundary
self._deltas = None
self._distances = None
max_distance = get_kernel_cutoff(self._kernel, self.element_size)
self._deltas = pairwise_distances(self.center, max_distance, format=self._format)
self._distances = vec_length(self._deltas)
if is_sparse(self._deltas):
self._connectivity = tensor_like(self._deltas, True)
else:
self._connectivity = self._distances > 0

@property
def nodes(self):
Expand All @@ -31,6 +40,14 @@ def format(self):
def center(self) -> Tensor:
return self._nodes.center

@property
def volume(self) -> Tensor:
return self._nodes.volume

@property
def radius(self):
return sphere_radius_with_same_volume(self._nodes)

@property
def boundary_elements(self) -> Dict[str, Dict[str, slice]]:
return self._boundary
Expand All @@ -40,24 +57,19 @@ def boundary_faces(self) -> Dict[str, Tuple[Dict[str, slice], Dict[str, slice]]]
return {key: {'~' + dim: s for dim, s in slices.items()} for key, slices in self._boundary_elements.items()}

@property
def distances(self) -> Tensor:
if self._distances is None:
self._distances = vec_length(self.deltas)
return self._distances
def connectivity(self) -> Tensor:
return self._connectivity

@property
def element_size(self):
return 2 * math.max(self._nodes.bounding_half_extent(), 'vector')
def distances(self) -> Tensor:
return self._distances

@property
def deltas(self) -> Tensor:
"""
Returns the pairwise position deltas between all elements as `Tensor`, possibly sparse depending on `format´.
The result has shape (elements, ~elements, vector).
"""
if self._deltas is None:
max_distance = get_kernel_cutoff(self._kernel, self.element_size)
self._deltas = pairwise_distances(self.center, max_distance, format=self._format)
return self._deltas

def __with_attrs__(self, **attrs): # Make sure cached distances are invalidated
Expand All @@ -71,10 +83,6 @@ def __variable_attrs__(self) -> Tuple[str, ...]:
def at(self, center: Tensor) -> 'Geometry':
return ProximityGraph(self._nodes.at(center), self._boundary, self._kernel, self._format)

@property
def volume(self) -> Tensor:
return self._nodes.volume

@property
def shape(self) -> Shape:
return self._nodes.shape
Expand Down Expand Up @@ -124,7 +132,12 @@ def __getitem__(self, item):
return ProximityGraph(self._nodes[item], self._boundary, self._kernel, self._format)


def get_kernel_cutoff(kernel: str, element_size):
DEFAULT_TARGET_NUM_NEIGHBORS = {
'quintic-spline':
}


def get_kernel_cutoff(kernel: str, element_radius):
"""
Returns the cut-off distance for a kernel given the element size.
Expand All @@ -137,9 +150,9 @@ def get_kernel_cutoff(kernel: str, element_size):
Cut-off distance as float or float `Tensor`
"""
if kernel == 'quintic-spline':
return 3. * element_size
return 3. * element_radius
elif kernel == 'wendland-c2':
return 2. * element_size
return 2. * element_radius
else:
raise ValueError(kernel)

Expand Down
18 changes: 14 additions & 4 deletions phi/geom/_sphere.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Union, Dict, Tuple

from phi import math
from phiml.math import Shape, dual
from phiml.math import Shape, dual, PI
from ._geom import Geometry, _keep_vector, NO_GEOMETRY
from ..math import wrap, Tensor, expand
from ..math.magic import slicing_dict
Expand Down Expand Up @@ -54,11 +54,11 @@ def volume(self) -> math.Tensor:
if self.spatial_rank == 1:
return 2 * self._radius
elif self.spatial_rank == 2:
return math.PI * self._radius ** 2
return PI * self._radius ** 2
elif self.spatial_rank == 3:
return 4 / 3 * math.PI * self._radius ** 3
return 4 / 3 * PI * self._radius ** 3
else:
raise NotImplementedError()
raise NotImplementedError(f"Only spatial ranks up to 3 supported but Sphere has rank {self.spatial_rank}")
# n = self.spatial_rank
# return math.pi ** (n // 2) / math.faculty(math.ceil(n / 2)) * self._radius ** n

Expand Down Expand Up @@ -138,3 +138,13 @@ def boundary_faces(self) -> Dict[str, Tuple[Dict[str, slice], Dict[str, slice]]]
@property
def face_shape(self) -> Shape:
return self.shape.without('vector') & dual(shell=0)


def sphere_radius_with_same_volume(g: Geometry):
if g.spatial_rank == 1:
return g.volume / 2
elif g.spatial_rank == 2:
return math.sqrt(g.volume / PI)
elif g.spatial_rank == 3:
return (g.volume / (4 / 3 * PI)) ** (1/3)
raise NotImplementedError(f"Only spatial ranks up to 3 supported but got {g}")

0 comments on commit dd23a63

Please sign in to comment.