diff --git a/.golangci.yaml b/.golangci.yaml index b42ba769..25c0fd79 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -35,8 +35,6 @@ linters-settings: errcheck: exclude-functions: - io.Copy(os.Stdout) - - (*github.com/peterstace/simplefeatures/rtree.RTree).RangeSearch - - (*github.com/peterstace/simplefeatures/rtree.RTree).PrioritySearch # NOTE: every linter supported by golangci-lint is either explicitly included # or excluded. @@ -79,7 +77,6 @@ linters: - importas - ineffassign - intrange - - ireturn - loggercheck - makezero - mirror @@ -143,6 +140,7 @@ linters: - gomnd - inamedparam - interfacebloat + - ireturn - lll - maintidx - nestif diff --git a/CHANGELOG.md b/CHANGELOG.md index 2455acda..c0e5ca0e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,19 @@ # Changelog +## Unreleased + +- **Breaking change:** The `rtree` package types and functions are now generic + over the record type. The `RTree` type is now `RTree[T]`, `BulkItem` is now + `BulkItem[T]`, and `BulkLoad` is now `BulkLoad[T]`. The `RecordID int` field + in `BulkItem` has been renamed to `Record T`. This allows users to store + their records directly in the tree rather than maintaining separate mappings + between integer IDs and records. Users can upgrade by adding type parameters + to their rtree usage (e.g., `RTree[int]` to maintain existing behavior with + integer IDs, or use a custom type like `RTree[MyRecord]` to store records + directly). The `RecordID` field in `BulkItem` should be renamed to `Record`, + and callback function signatures should change from `func(recordID int)` to + `func(record T)` where `T` is the type parameter. + ## v0.56.0 2025-11-21 @@ -15,7 +29,7 @@ assertions. - **Breaking change:** The minimum required Go version is now 1.18 (previously - 1.17). This is required to support the `any` keyword. + 1.17). This is required to support the `any` keyword and generics. ## v0.55.0 diff --git a/geom/alg_distance.go b/geom/alg_distance.go index 95148352..6add7e93 100644 --- a/geom/alg_distance.go +++ b/geom/alg_distance.go @@ -37,60 +37,42 @@ func Distance(g1, g2 Geometry) (float64, bool) { lns1, lns2 = lns2, lns1 } - tr := loadTree(xys2, lns2) + xyTree := loadXYTree(xys2) + lnTree := loadLineTree(lns2) minDist := math.Inf(+1) - searchBody := func( - env Envelope, - recordID int, - xyDist func(int) float64, - lnDist func(int) float64, - ) error { - // Convert recordID back to array indexes. - xyIdx := recordID - 1 - lnIdx := -recordID - 1 - - // Abort the search if we're gone further away compared to our best - // distance so far. - var recordEnv Envelope - if recordID > 0 { - recordEnv = xys2[xyIdx].uncheckedEnvelope() - } else { - recordEnv = lns2[lnIdx].uncheckedEnvelope() - } - if d, ok := recordEnv.Distance(env); ok && d > minDist { - return rtree.Stop - } - - // See if the current item in the tree is better than our current best - // distance. - if recordID > 0 { - minDist = fastMin(minDist, xyDist(xyIdx)) - } else { - minDist = fastMin(minDist, lnDist(lnIdx)) - } - return nil - } - for _, xy := range xys1 { - xyEnv := xy.uncheckedEnvelope() - tr.PrioritySearch(xy.box(), func(recordID int) error { - return searchBody( - xyEnv, - recordID, - func(i int) float64 { return distBetweenXYs(xy, xys2[i]) }, - func(i int) float64 { return distBetweenXYAndLine(xy, lns2[i]) }, - ) + for _, xy1 := range xys1 { + xy1Env := xy1.uncheckedEnvelope() + _ = xyTree.PrioritySearch(xy1.box(), func(xy2 XY) error { + if d, ok := xy2.uncheckedEnvelope().Distance(xy1Env); ok && d > minDist { + return rtree.Stop + } + minDist = fastMin(minDist, distBetweenXYs(xy1, xy2)) + return nil + }) + _ = lnTree.PrioritySearch(xy1.box(), func(ln2 line) error { + if d, ok := ln2.uncheckedEnvelope().Distance(xy1Env); ok && d > minDist { + return rtree.Stop + } + minDist = fastMin(minDist, distBetweenXYAndLine(xy1, ln2)) + return nil }) } - for _, ln := range lns1 { - lnEnv := ln.uncheckedEnvelope() - tr.PrioritySearch(ln.box(), func(recordID int) error { - return searchBody( - lnEnv, - recordID, - func(i int) float64 { return distBetweenXYAndLine(xys2[i], ln) }, - func(i int) float64 { return distBetweenLineAndLine(lns2[i], ln) }, - ) + for _, ln1 := range lns1 { + ln1Env := ln1.uncheckedEnvelope() + _ = xyTree.PrioritySearch(ln1.box(), func(xy2 XY) error { + if d, ok := xy2.uncheckedEnvelope().Distance(ln1Env); ok && d > minDist { + return rtree.Stop + } + minDist = fastMin(minDist, distBetweenXYAndLine(xy2, ln1)) + return nil + }) + _ = lnTree.PrioritySearch(ln1.box(), func(ln2 line) error { + if d, ok := ln2.uncheckedEnvelope().Distance(ln1Env); ok && d > minDist { + return rtree.Stop + } + minDist = fastMin(minDist, distBetweenLineAndLine(ln1, ln2)) + return nil }) } @@ -128,22 +110,23 @@ func extractXYsAndLines(g Geometry) ([]XY, []line) { } } -// loadTree creates a new RTree that indexes both the XYs and the lines. It -// uses positive record IDs to refer to the XYs, and negative recordIDs to -// refer to the lines. Because +0 and -0 are the same, indexing is 1-based and -// recordID 0 is not used. -func loadTree(xys []XY, lns []line) *rtree.RTree { - items := make([]rtree.BulkItem, len(xys)+len(lns)) +func loadXYTree(xys []XY) *rtree.RTree[XY] { + items := make([]rtree.BulkItem[XY], len(xys)) for i, xy := range xys { - items[i] = rtree.BulkItem{ - Box: xy.box(), - RecordID: i + 1, + items[i] = rtree.BulkItem[XY]{ + Box: xy.box(), + Record: xy, } } + return rtree.BulkLoad(items) +} + +func loadLineTree(lns []line) *rtree.RTree[line] { + items := make([]rtree.BulkItem[line], len(lns)) for i, ln := range lns { - items[i+len(xys)] = rtree.BulkItem{ - Box: ln.box(), - RecordID: -(i + 1), + items[i] = rtree.BulkItem[line]{ + Box: ln.box(), + Record: ln, } } return rtree.BulkLoad(items) diff --git a/geom/alg_intersection.go b/geom/alg_intersection.go index 23762c1a..89e56e5a 100644 --- a/geom/alg_intersection.go +++ b/geom/alg_intersection.go @@ -10,7 +10,7 @@ func intersectionOfIndexedLines( var pts []Point seen := make(map[XY]bool) for i := range lines1.lines { - lines2.tree.RangeSearch(lines1.lines[i].box(), func(j int) error { + _ = lines2.tree.RangeSearch(lines1.lines[i].box(), func(j int) error { inter := lines1.lines[i].intersectLine(lines2.lines[j]) if inter.empty { return nil diff --git a/geom/alg_intersects.go b/geom/alg_intersects.go index b541d781..94d80f4d 100644 --- a/geom/alg_intersects.go +++ b/geom/alg_intersects.go @@ -195,11 +195,11 @@ func hasIntersectionBetweenLines( lines1, lines2 = lines2, lines1 } - bulk := make([]rtree.BulkItem, len(lines1)) + bulk := make([]rtree.BulkItem[int], len(lines1)) for i, ln := range lines1 { - bulk[i] = rtree.BulkItem{ - Box: ln.box(), - RecordID: i, + bulk[i] = rtree.BulkItem[int]{ + Box: ln.box(), + Record: i, } } tree := rtree.BulkLoad(bulk) @@ -209,7 +209,7 @@ func hasIntersectionBetweenLines( var env Envelope for _, lnA := range lines2 { - tree.RangeSearch(lnA.box(), func(i int) error { + _ = tree.RangeSearch(lnA.box(), func(i int) error { lnB := lines1[i] inter := lnA.intersectLine(lnB) if inter.empty { diff --git a/geom/alg_point_in_ring.go b/geom/alg_point_in_ring.go index e8b68dc4..3f03ac68 100644 --- a/geom/alg_point_in_ring.go +++ b/geom/alg_point_in_ring.go @@ -61,7 +61,7 @@ func relatePointToPolygon(pt XY, polyBoundary indexedLines) side { } var onBound bool var count int - polyBoundary.tree.RangeSearch(box, func(i int) error { + _ = polyBoundary.tree.RangeSearch(box, func(i int) error { ln := polyBoundary.lines[i] crossing, onLine := hasCrossing(pt, ln) if onLine { diff --git a/geom/dcel_ghosts.go b/geom/dcel_ghosts.go index d428d9d7..e6ae07f9 100644 --- a/geom/dcel_ghosts.go +++ b/geom/dcel_ghosts.go @@ -30,9 +30,9 @@ func spanningTree(xys []XY) MultiLineString { // Load points into r-tree. xys = sortAndUniquifyXYs(xys) - items := make([]rtree.BulkItem, len(xys)) + items := make([]rtree.BulkItem[int], len(xys)) for i, xy := range xys { - items[i] = rtree.BulkItem{Box: xy.box(), RecordID: i} + items[i] = rtree.BulkItem[int]{Box: xy.box(), Record: i} } tree := rtree.BulkLoad(items) @@ -49,7 +49,7 @@ func spanningTree(xys []XY) MultiLineString { // of being the closest to another point. continue } - tree.PrioritySearch(xyi.box(), func(j int) error { + _ = tree.PrioritySearch(xyi.box(), func(j int) error { // We don't want to include a new edge in the spanning tree if it // would cause a cycle (i.e. the two endpoints are already in the // same tree). This is checked via dset. diff --git a/geom/dcel_re_noding.go b/geom/dcel_re_noding.go index 71a103a1..943ac8bf 100644 --- a/geom/dcel_re_noding.go +++ b/geom/dcel_re_noding.go @@ -48,7 +48,7 @@ func reNodeGeometries(g1, g2 Geometry, mls MultiLineString) (Geometry, Geometry, // Create new nodes for point/line intersections. ptIndex := newIndexedPoints(nodes.list()) appendCutsForPointXLine := func(ln line, cuts []XY) []XY { - ptIndex.tree.RangeSearch(ln.box(), func(i int) error { + _ = ptIndex.tree.RangeSearch(ln.box(), func(i int) error { xy := ptIndex.points[i] if !ln.hasEndpoint(xy) && distBetweenXYAndLine(xy, ln) < ulp*0x200 { cuts = append(cuts, xy) @@ -64,7 +64,7 @@ func reNodeGeometries(g1, g2 Geometry, mls MultiLineString) (Geometry, Geometry, // Create new nodes for line/line intersections. lnIndex := newIndexedLines(appendLines(nil, all())) appendCutsLineXLine := func(ln line, cuts []XY) []XY { - lnIndex.tree.RangeSearch(ln.box(), func(i int) error { + _ = lnIndex.tree.RangeSearch(ln.box(), func(i int) error { other := lnIndex.lines[i] // TODO: This is a hacky approach (re-orders inputs, rather than diff --git a/geom/rtree.go b/geom/rtree.go index eb4f9d60..5080afd9 100644 --- a/geom/rtree.go +++ b/geom/rtree.go @@ -2,20 +2,44 @@ package geom import "github.com/peterstace/simplefeatures/rtree" +// TODO: Use this instead of indexedLines/Points where possible. +func newLineRTree(lines []line) *rtree.RTree[line] { //nolint:unused + items := make([]rtree.BulkItem[line], len(lines)) + for i, ln := range lines { + items[i] = rtree.BulkItem[line]{ + Box: ln.box(), + Record: ln, + } + } + return rtree.BulkLoad(items) +} + +// TODO: Use this instead of indexedLines/Points where possible. +func newPointRTree(points []XY) *rtree.RTree[XY] { //nolint:unused + items := make([]rtree.BulkItem[XY], len(points)) + for i, pt := range points { + items[i] = rtree.BulkItem[XY]{ + Box: pt.box(), + Record: pt, + } + } + return rtree.BulkLoad(items) +} + // indexedLines is a simple container to hold a list of lines, and a r-tree // structure indexing those lines. The record IDs in the rtree correspond to // the indices of the lines slice. type indexedLines struct { lines []line - tree *rtree.RTree + tree *rtree.RTree[int] } func newIndexedLines(lines []line) indexedLines { - bulk := make([]rtree.BulkItem, len(lines)) + bulk := make([]rtree.BulkItem[int], len(lines)) for i, ln := range lines { - bulk[i] = rtree.BulkItem{ - Box: ln.box(), - RecordID: i, + bulk[i] = rtree.BulkItem[int]{ + Box: ln.box(), + Record: i, } } return indexedLines{lines, rtree.BulkLoad(bulk)} @@ -26,15 +50,15 @@ func newIndexedLines(lines []line) indexedLines { // the indices of the points slice. type indexedPoints struct { points []XY - tree *rtree.RTree + tree *rtree.RTree[int] } func newIndexedPoints(points []XY) indexedPoints { - bulk := make([]rtree.BulkItem, len(points)) + bulk := make([]rtree.BulkItem[int], len(points)) for i, pt := range points { - bulk[i] = rtree.BulkItem{ - Box: rtree.Box{MinX: pt.X, MaxX: pt.X, MinY: pt.Y, MaxY: pt.Y}, - RecordID: i, + bulk[i] = rtree.BulkItem[int]{ + Box: rtree.Box{MinX: pt.X, MaxX: pt.X, MinY: pt.Y, MaxY: pt.Y}, + Record: i, } } return indexedPoints{points, rtree.BulkLoad(bulk)} diff --git a/geom/type_line_string.go b/geom/type_line_string.go index 72814448..f4ac322e 100644 --- a/geom/type_line_string.go +++ b/geom/type_line_string.go @@ -116,13 +116,13 @@ func (s LineString) IsSimple() bool { } n := s.seq.Length() - items := make([]rtree.BulkItem, 0, n-1) + items := make([]rtree.BulkItem[int], 0, n-1) for i := 0; i < n; i++ { ln, ok := getLine(s.seq, i) if !ok { continue } - items = append(items, rtree.BulkItem{Box: ln.box(), RecordID: i}) + items = append(items, rtree.BulkItem[int]{Box: ln.box(), Record: i}) } tree := rtree.BulkLoad(items) @@ -142,7 +142,7 @@ func (s LineString) IsSimple() bool { } simple := true // assume simple until proven otherwise - tree.RangeSearch(ln.box(), func(j int) error { + _ = tree.RangeSearch(ln.box(), func(j int) error { // Skip finding the original line (i == j) or cases where we have // already checked that pair (i > j). if i >= j { diff --git a/geom/type_multi_line_string.go b/geom/type_multi_line_string.go index 89caa230..3b3ded7b 100644 --- a/geom/type_multi_line_string.go +++ b/geom/type_multi_line_string.go @@ -110,22 +110,16 @@ func (m MultiLineString) IsSimple() bool { } } - // Map between record ID in the rtree and a particular line segment: - toRecordID := func(lineStringIdx, seqIdx int) int { - return int(uint64(lineStringIdx)<<32 | uint64(seqIdx)) - } - fromRecordID := func(recordID int) (lineStringIdx, seqIdx int) { - lineStringIdx = int(uint64(recordID) >> 32) - seqIdx = int((uint64(recordID) << 32) >> 32) - return - } - // Create an RTree containing all line segments. + type record struct { + lineStringIdx int + seqIdx int + } var numItems int for _, ls := range m.lines { numItems += maxInt(0, ls.Coordinates().Length()-1) } - items := make([]rtree.BulkItem, 0, numItems) + items := make([]rtree.BulkItem[record], 0, numItems) for i, ls := range m.lines { seq := ls.Coordinates() seqLen := seq.Length() @@ -134,9 +128,9 @@ func (m MultiLineString) IsSimple() bool { if !ok { continue } - items = append(items, rtree.BulkItem{ - Box: ln.box(), - RecordID: toRecordID(i, j), + items = append(items, rtree.BulkItem[record]{ + Box: ln.box(), + Record: record{lineStringIdx: i, seqIdx: j}, }) } } @@ -151,15 +145,14 @@ func (m MultiLineString) IsSimple() bool { continue } isSimple := true // assume simple until proven otherwise - tree.RangeSearch(ln.box(), func(recordID int) error { + _ = tree.RangeSearch(ln.box(), func(rec record) error { // Ignore the intersection if it's for the same LineString that we're currently looking up. - lineStringIdx, seqIdx := fromRecordID(recordID) - if lineStringIdx == i { + if rec.lineStringIdx == i { return nil } - otherLS := m.lines[lineStringIdx] - other, ok := getLine(otherLS.Coordinates(), seqIdx) + otherLS := m.lines[rec.lineStringIdx] + other, ok := getLine(otherLS.Coordinates(), rec.seqIdx) if !ok { // Shouldn't even happen, since we were able to insert this // entry into the RTree. diff --git a/geom/type_multi_polygon.go b/geom/type_multi_polygon.go index 235770c4..9b821584 100644 --- a/geom/type_multi_polygon.go +++ b/geom/type_multi_polygon.go @@ -62,11 +62,11 @@ func (m MultiPolygon) checkMultiPolygonConstraints() error { // Construct RTree of Polygons. boxes := make([]rtree.Box, len(m.polys)) - items := make([]rtree.BulkItem, 0, len(m.polys)) + items := make([]rtree.BulkItem[int], 0, len(m.polys)) for i, p := range m.polys { if box, ok := p.Envelope().AsBox(); ok { boxes[i] = box - item := rtree.BulkItem{Box: boxes[i], RecordID: i} + item := rtree.BulkItem[int]{Box: boxes[i], Record: i} items = append(items, item) } } @@ -142,7 +142,7 @@ func validatePolyNotInsidePoly(p1, p2 indexedLines) error { for j := range p2.lines { // Find intersection points. var pts []XY - p1.tree.RangeSearch(p2.lines[j].box(), func(i int) error { + _ = p1.tree.RangeSearch(p2.lines[j].box(), func(i int) error { inter := p1.lines[i].intersectLine(p2.lines[j]) if inter.empty { return nil diff --git a/geom/type_null_geometry_test.go b/geom/type_null_geometry_test.go index a6d444f7..ca91fe6d 100644 --- a/geom/type_null_geometry_test.go +++ b/geom/type_null_geometry_test.go @@ -13,7 +13,7 @@ func TestNullGeometryScan(t *testing.T) { for _, tc := range []struct { description string - value interface{} + value any wantValid bool wantWKT string }{ diff --git a/geom/type_polygon.go b/geom/type_polygon.go index 8352d5a9..75ebc36e 100644 --- a/geom/type_polygon.go +++ b/geom/type_polygon.go @@ -65,7 +65,7 @@ func (p Polygon) Validate() error { // Construct RTree of rings. boxes := make([]rtree.Box, len(p.rings)) - items := make([]rtree.BulkItem, len(p.rings)) + items := make([]rtree.BulkItem[int], len(p.rings)) for i, r := range p.rings { box, ok := r.Envelope().AsBox() if !ok { @@ -74,7 +74,7 @@ func (p Polygon) Validate() error { panic("unexpected empty ring") } boxes[i] = box - items[i] = rtree.BulkItem{Box: boxes[i], RecordID: i} + items[i] = rtree.BulkItem[int]{Box: boxes[i], Record: i} } tree := rtree.BulkLoad(items) diff --git a/rtree/box.go b/rtree/box.go index 704be2d3..e6f3461f 100644 --- a/rtree/box.go +++ b/rtree/box.go @@ -6,7 +6,7 @@ type Box struct { } // calculateBound calculates the smallest bounding box that fits a node. -func calculateBound(n *node) Box { +func calculateBound[T any](n *node[T]) Box { box := n.entries[0].box for i := 1; i < n.numEntries; i++ { box = combine(box, n.entries[i].box) diff --git a/rtree/bulk.go b/rtree/bulk.go index 18482716..425d4b3c 100644 --- a/rtree/bulk.go +++ b/rtree/bulk.go @@ -1,23 +1,23 @@ package rtree // BulkItem is an item that can be inserted for bulk loading. -type BulkItem struct { - Box Box - RecordID int +type BulkItem[T any] struct { + Box Box + Record T } // BulkLoad bulk loads multiple items into a new R-Tree. The bulk load // operation is optimised for creating R-Trees with minimal node overlap. This // allows for fast searching. -func BulkLoad(items []BulkItem) *RTree { +func BulkLoad[T any](items []BulkItem[T]) *RTree[T] { if len(items) == 0 { - return &RTree{} + return &RTree[T]{} } root := bulkInsert(items) - return &RTree{root, len(items)} + return &RTree[T]{root, len(items)} } -func bulkInsert(items []BulkItem) *node { +func bulkInsert[T any](items []BulkItem[T]) *node[T] { if len(items) == 0 { panic("should not have recursed into bulkInsert without any items") } @@ -27,11 +27,11 @@ func bulkInsert(items []BulkItem) *node { // 4 or fewer items can fit into a single node. if len(items) <= 4 { - n := &node{numEntries: len(items)} + n := &node[T]{numEntries: len(items)} for i, item := range items { - n.entries[i] = entry{ - box: item.Box, - recordID: item.RecordID, + n.entries[i] = entry[T]{ + box: item.Box, + record: item.Record, } } return n @@ -52,8 +52,8 @@ func bulkInsert(items []BulkItem) *node { return bulkNode(firstQuarter, secondQuarter, thirdQuarter, fourthQuarter) } -func bulkNode(parts ...[]BulkItem) *node { - root := &node{numEntries: len(parts)} +func bulkNode[T any](parts ...[]BulkItem[T]) *node[T] { + root := &node[T]{numEntries: len(parts)} for i, part := range parts { child := bulkInsert(part) root.entries[i].child = child @@ -62,7 +62,7 @@ func bulkNode(parts ...[]BulkItem) *node { return root } -func splitBulkItems2Ways(items []BulkItem) ([]BulkItem, []BulkItem) { +func splitBulkItems2Ways[T any](items []BulkItem[T]) ([]BulkItem[T], []BulkItem[T]) { horizontal := itemsAreHorizontal(items) split := len(items) / 2 quickPartition(items, split, horizontal) @@ -72,7 +72,7 @@ func splitBulkItems2Ways(items []BulkItem) ([]BulkItem, []BulkItem) { // quickPartition performs a partial in-place sort on the items slice. The // partial sort is such that items 0 through k-1 are less than or equal to item // k, and items k+1 through n-1 are greater than or equal to item k. -func quickPartition(items []BulkItem, k int, horizontal bool) { +func quickPartition[T any](items []BulkItem[T], k int, horizontal bool) { // Use a custom linear congruential random number generator. This is used // because we don't need high quality random numbers. Using a regular // rand.Rand generator causes a significant bottleneck due to the reliance @@ -150,7 +150,7 @@ func quickPartition(items []BulkItem, k int, horizontal bool) { } } -func itemsAreHorizontal(items []BulkItem) bool { +func itemsAreHorizontal[T any](items []BulkItem[T]) bool { box := items[0].Box for _, item := range items[1:] { box = combine(box, item.Box) diff --git a/rtree/golden_internal_test.go b/rtree/golden_internal_test.go index 74bec92e..9ffcb90c 100644 --- a/rtree/golden_internal_test.go +++ b/rtree/golden_internal_test.go @@ -135,12 +135,12 @@ func TestBulkLoadGolden(t *testing.T) { } } -func checksum(n *node) uint64 { +func checksum(n *node[int]) uint64 { var entries []string for i := 0; i < n.numEntries; i++ { var entry string if n.entries[i].child == nil { - entry = strconv.Itoa(n.entries[i].recordID) + entry = strconv.Itoa(n.entries[i].record) } else { entry = strconv.FormatUint(checksum(n.entries[i].child), 10) } diff --git a/rtree/nearest.go b/rtree/nearest.go index 39b27038..e4dc8e9b 100644 --- a/rtree/nearest.go +++ b/rtree/nearest.go @@ -9,13 +9,13 @@ import ( // as measured by the Euclidean metric. Note that there may be multiple records // that are equidistant from the input box, in which case one is chosen // arbitrarily. If the RTree is empty, then false is returned. -func (t *RTree) Nearest(box Box) (recordID int, found bool) { - t.PrioritySearch(box, func(rid int) error { - recordID = rid +func (t *RTree[T]) Nearest(box Box) (record T, found bool) { + _ = t.PrioritySearch(box, func(rec T) error { + record = rec found = true return Stop }) - return recordID, found + return record, found } // PrioritySearch iterates over the records in the RTree in priority order of @@ -25,13 +25,13 @@ func (t *RTree) Nearest(box Box) (recordID int, found bool) { // error returned from the callback is returned by PrioritySearch, except for // the case where the special Stop sentinel error is returned (in which case // nil will be returned from PrioritySearch). Stop may be wrapped. -func (t *RTree) PrioritySearch(box Box, callback func(recordID int) error) error { +func (t *RTree[T]) PrioritySearch(box Box, callback func(record T) error) error { if t.root == nil { return nil } - queue := entriesQueue{origin: box} - equeueNode := func(n *node) { + queue := entriesQueue[T]{origin: box} + equeueNode := func(n *node[T]) { for i := 0; i < n.numEntries; i++ { heap.Push(&queue, &n.entries[i]) } @@ -39,9 +39,9 @@ func (t *RTree) PrioritySearch(box Box, callback func(recordID int) error) error equeueNode(t.root) for len(queue.entries) > 0 { - nearest := heap.Pop(&queue).(*entry) + nearest := heap.Pop(&queue).(*entry[T]) if nearest.child == nil { - if err := callback(nearest.recordID); err != nil { + if err := callback(nearest.record); err != nil { if errors.Is(err, Stop) { return nil } @@ -54,30 +54,30 @@ func (t *RTree) PrioritySearch(box Box, callback func(recordID int) error) error return nil } -type entriesQueue struct { - entries []*entry +type entriesQueue[T any] struct { + entries []*entry[T] origin Box } -func (q *entriesQueue) Len() int { +func (q *entriesQueue[T]) Len() int { return len(q.entries) } -func (q *entriesQueue) Less(i int, j int) bool { +func (q *entriesQueue[T]) Less(i int, j int) bool { d1 := squaredEuclideanDistance(q.entries[i].box, q.origin) d2 := squaredEuclideanDistance(q.entries[j].box, q.origin) return d1 < d2 } -func (q *entriesQueue) Swap(i int, j int) { +func (q *entriesQueue[T]) Swap(i int, j int) { q.entries[i], q.entries[j] = q.entries[j], q.entries[i] } -func (q *entriesQueue) Push(x any) { - q.entries = append(q.entries, x.(*entry)) +func (q *entriesQueue[T]) Push(x any) { + q.entries = append(q.entries, x.(*entry[T])) } -func (q *entriesQueue) Pop() any { +func (q *entriesQueue[T]) Pop() any { e := q.entries[len(q.entries)-1] q.entries = q.entries[:len(q.entries)-1] return e diff --git a/rtree/nearest_internal_test.go b/rtree/nearest_internal_test.go index e518d143..e95aeac0 100644 --- a/rtree/nearest_internal_test.go +++ b/rtree/nearest_internal_test.go @@ -21,7 +21,7 @@ func TestNearest(t *testing.T) { } } -func checkNearest(t *testing.T, rt *RTree, boxes []Box, rnd *rand.Rand) { +func checkNearest(t *testing.T, rt *RTree[int], boxes []Box, rnd *rand.Rand) { t.Helper() for i := 0; i < 10; i++ { originBB := randomBox(rnd, 0.9, 0.1) @@ -47,13 +47,13 @@ func checkNearest(t *testing.T, rt *RTree, boxes []Box, rnd *rand.Rand) { } } -func checkPrioritySearch(t *testing.T, rt *RTree, boxes []Box, rnd *rand.Rand) { +func checkPrioritySearch(t *testing.T, rt *RTree[int], boxes []Box, rnd *rand.Rand) { t.Helper() for i := 0; i < 10; i++ { var got []int originBB := randomBox(rnd, 0.9, 0.1) t.Logf("origin: %v", originBB) - rt.PrioritySearch(originBB, func(recordID int) error { + _ = rt.PrioritySearch(originBB, func(recordID int) error { got = append(got, recordID) return nil }) @@ -79,10 +79,10 @@ func TestPrioritySearchEarlyStop(t *testing.T) { boxes[i] = randomBox(rnd, 0.9, 0.1) } - inserts := make([]BulkItem, len(boxes)) + inserts := make([]BulkItem[int], len(boxes)) for i := range inserts { inserts[i].Box = boxes[i] - inserts[i].RecordID = i + inserts[i].Record = i } rt := BulkLoad(inserts) origin := randomBox(rnd, 0.9, 0.1) diff --git a/rtree/perf_internal_test.go b/rtree/perf_internal_test.go index e9f6bf36..664c5b0e 100644 --- a/rtree/perf_internal_test.go +++ b/rtree/perf_internal_test.go @@ -13,10 +13,10 @@ func BenchmarkBulk(b *testing.B) { for i := range boxes { boxes[i] = randomBox(rnd, 0.9, 0.1) } - inserts := make([]BulkItem, len(boxes)) + inserts := make([]BulkItem[int], len(boxes)) for i := range inserts { inserts[i].Box = boxes[i] - inserts[i].RecordID = i + inserts[i].Record = i } b.Run(fmt.Sprintf("n=%d", pop), func(b *testing.B) { for i := 0; i < b.N; i++ { @@ -33,7 +33,7 @@ func BenchmarkRangeSearch(b *testing.B) { tree, _ := testBulkLoad(rnd, pop) b.ResetTimer() for i := 0; i < b.N; i++ { - tree.RangeSearch(Box{0.5, 0.5, 0.5, 0.5}, func(int) error { return nil }) + _ = tree.RangeSearch(Box{0.5, 0.5, 0.5, 0.5}, func(int) error { return nil }) } }) } diff --git a/rtree/quick_partition_internal_test.go b/rtree/quick_partition_internal_test.go index 625065a8..50116ea7 100644 --- a/rtree/quick_partition_internal_test.go +++ b/rtree/quick_partition_internal_test.go @@ -45,10 +45,10 @@ func TestQuickPartition(t *testing.T) { t.Run(strconv.Itoa(i), func(t *testing.T) { for k := range tc { t.Run(fmt.Sprintf("k=%d", k), func(t *testing.T) { - items := make([]BulkItem, 0, len(tc)) + items := make([]BulkItem[int], 0, len(tc)) for _, num := range tc { f := float64(num) - items = append(items, BulkItem{ + items = append(items, BulkItem[int]{ Box{f, f, f, f}, len(items), }) diff --git a/rtree/rtree.go b/rtree/rtree.go index d1fc92a9..48af0247 100644 --- a/rtree/rtree.go +++ b/rtree/rtree.go @@ -9,29 +9,26 @@ const ( maxEntries = 4 ) -// node is a node in an R-Tree, holding user record IDs and/or links to deeper +// node is a node in an R-Tree, holding user records and/or links to deeper // nodes in the tree. -type node struct { - entries [maxEntries]entry +type node[T any] struct { + entries [maxEntries]entry[T] numEntries int } // entry is an entry contained inside a node. An entry can either hold a user -// record ID, or point to a deeper node in the tree (but not both). Because 0 -// is a valid record ID, the child pointer should be used to distinguish -// between the two types of entries. -type entry struct { - box Box - child *node - recordID int +// record, or point to a deeper node in the tree (but not both). The child +// pointer should be used to distinguish between the two types of entries. +type entry[T any] struct { + box Box + child *node[T] + record T } -// RTree is an in-memory R-Tree data structure. It holds record ID and bounding -// box pairs (the actual records aren't stored in the tree; the user is -// responsible for storing their own records). Its zero value is an empty -// R-Tree. -type RTree struct { - root *node +// RTree is an in-memory R-Tree data structure. It holds records of type T +// along with their bounding boxes. Its zero value is an empty R-Tree. +type RTree[T any] struct { + root *node[T] count int } @@ -40,24 +37,24 @@ type RTree struct { var Stop = errors.New("stop") //nolint:stylecheck,revive // RangeSearch looks for any items in the tree that overlap with the given -// bounding box. The callback is called with the record ID for each found item. -// If an error is returned from the callback then the search is terminated -// early. Any error returned from the callback is returned by RangeSearch, -// except for the case where the special Stop sentinel error is returned (in -// which case nil will be returned from RangeSearch). Stop may be wrapped. -func (t *RTree) RangeSearch(box Box, callback func(recordID int) error) error { +// bounding box. The callback is called with each found item's record. If an +// error is returned from the callback then the search is terminated early. +// Any error returned from the callback is returned by RangeSearch, except for +// the case where the special Stop sentinel error is returned (in which case +// nil will be returned from RangeSearch). Stop may be wrapped. +func (t *RTree[T]) RangeSearch(box Box, callback func(record T) error) error { if t.root == nil { return nil } - var recurse func(*node) error - recurse = func(n *node) error { + var recurse func(*node[T]) error + recurse = func(n *node[T]) error { for i := 0; i < n.numEntries; i++ { entry := n.entries[i] if !overlap(entry.box, box) { continue } if entry.child == nil { - if err := callback(entry.recordID); errors.Is(err, Stop) { + if err := callback(entry.record); errors.Is(err, Stop) { return nil } else if err != nil { return err @@ -75,7 +72,7 @@ func (t *RTree) RangeSearch(box Box, callback func(recordID int) error) error { // Extent gives the Box that most closely bounds the RTree. If the RTree is // empty, then false is returned. -func (t *RTree) Extent() (Box, bool) { +func (t *RTree[T]) Extent() (Box, bool) { if t.root == nil || t.root.numEntries == 0 { return Box{}, false } @@ -83,6 +80,6 @@ func (t *RTree) Extent() (Box, bool) { } // Count gives the number of entries in the RTree. -func (t *RTree) Count() int { +func (t *RTree[T]) Count() int { return t.count } diff --git a/rtree/rtree_internal_test.go b/rtree/rtree_internal_test.go index 5b39b1bb..1afb2f92 100644 --- a/rtree/rtree_internal_test.go +++ b/rtree/rtree_internal_test.go @@ -9,7 +9,7 @@ import ( "testing" ) -func testBulkLoad(rnd *rand.Rand, pop int) (*RTree, []Box) { +func testBulkLoad(rnd *rand.Rand, pop int) (*RTree[int], []Box) { boxes := make([]Box, pop) seenX := make(map[float64]bool) seenY := make(map[float64]bool) @@ -27,10 +27,10 @@ func testBulkLoad(rnd *rand.Rand, pop int) (*RTree, []Box) { } boxes[i] = box } - inserts := make([]BulkItem, len(boxes)) + inserts := make([]BulkItem[int], len(boxes)) for i := range inserts { inserts[i].Box = boxes[i] - inserts[i].RecordID = i + inserts[i].Record = i } return BulkLoad(inserts), boxes } @@ -57,12 +57,12 @@ func TestRandom(t *testing.T) { } } -func checkSearch(t *testing.T, rt *RTree, boxes []Box, rnd *rand.Rand) { +func checkSearch(t *testing.T, rt *RTree[int], boxes []Box, rnd *rand.Rand) { t.Helper() for i := 0; i < 10; i++ { searchBB := randomBox(rnd, 0.5, 0.5) var got []int - rt.RangeSearch(searchBB, func(idx int) error { + _ = rt.RangeSearch(searchBB, func(idx int) error { got = append(got, idx) return nil }) @@ -99,16 +99,16 @@ func randomBox(rnd *rand.Rand, maxStart, maxWidth float64) Box { return box } -func checkInvariants(t *testing.T, rt *RTree, boxes []Box) { +func checkInvariants(t *testing.T, rt *RTree[int], boxes []Box) { t.Helper() - var recurse func(*node, string) - recurse = func(current *node, indent string) { + var recurse func(*node[int], string) + recurse = func(current *node[int], indent string) { t.Logf("%sNode addr=%p numEntries=%d", indent, current, current.numEntries) indent += "\t" for i := 0; i < current.numEntries; i++ { e := current.entries[i] if e.child == nil { - t.Logf("%sEntry[%d] recordID=%d box=%v", indent, i, e.recordID, e.box) + t.Logf("%sEntry[%d] recordID=%d box=%v", indent, i, e.record, e.box) } else { t.Logf("%sEntry[%d] box=%v", indent, i, e.box) recurse(e.child, indent+"\t") @@ -134,20 +134,20 @@ func checkInvariants(t *testing.T, rt *RTree, boxes []Box) { minLeafLevel := math.MaxInt maxLeafLevel := math.MinInt - var check func(n *node, level int) - check = func(current *node, level int) { + var check func(n *node[int], level int) + check = func(current *node[int], level int) { for i := 0; i < current.numEntries; i++ { e := current.entries[i] if e.child == nil { minLeafLevel = minInt(minLeafLevel, level) maxLeafLevel = maxInt(maxLeafLevel, level) - if _, ok := unfound[e.recordID]; !ok { + if _, ok := unfound[e.record]; !ok { t.Fatal("record ID found in tree but wasn't in unfound map") } - delete(unfound, e.recordID) + delete(unfound, e.record) } else { - if e.recordID != 0 { - t.Fatal("non-leaf has recordID") + if e.record != 0 { + t.Fatal("non-leaf has record") } box := e.child.entries[0].box for j := 1; j < e.child.numEntries; j++ { @@ -161,7 +161,7 @@ func checkInvariants(t *testing.T, rt *RTree, boxes []Box) { } for i := current.numEntries; i < len(current.entries); i++ { e := current.entries[i] - if e != (entry{}) { + if e != (entry[int]{}) { t.Fatal("entry past numEntries is not the zero value") } }