Skip to content

Commit e765172

Browse files
committed
optimize duplicate-heavy exact block evaluation
1 parent 7bb6318 commit e765172

4 files changed

Lines changed: 150 additions & 2 deletions

File tree

docs/progress.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ This document tracks the next correctness-first phase as repository-visible work
2323
- [x] Config-matrix differential and invariant test layer added
2424
- [x] Local benchmark harness and first measured optimization pass added
2525
- [x] Exact acceptance-cost optimization using permutation delta energy added
26+
- [x] Duplicate-aware grouped exact evaluator path added
2627

2728
## Active next-phase checklist
2829

@@ -59,6 +60,7 @@ This document tracks the next correctness-first phase as repository-visible work
5960
- keep benchmark datasets aligned with correctness stress patterns
6061
- use benchmark observations to justify small, auditable optimization passes
6162
- continue reducing transport acceptance cost without weakening exactness
63+
- improve duplicate-heavy paths without regressing transport-first semantics
6264

6365
### Docs and positioning
6466

src/core/energy.zig

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,66 @@
11
const Config = @import("config.zig").Config;
22
const key = @import("key.zig");
33

4+
pub fn collectDistinctGroups(
5+
comptime T: type,
6+
keys: []const key.KeyType(T),
7+
distinct_keys_out: []key.KeyType(T),
8+
group_ids_out: []u8,
9+
) usize {
10+
var distinct_count: usize = 0;
11+
12+
for (keys) |value| {
13+
var found = false;
14+
for (distinct_keys_out[0..distinct_count]) |existing| {
15+
if (existing == value) {
16+
found = true;
17+
break;
18+
}
19+
}
20+
if (!found) {
21+
distinct_keys_out[distinct_count] = value;
22+
distinct_count += 1;
23+
}
24+
}
25+
26+
var i: usize = 1;
27+
while (i < distinct_count) : (i += 1) {
28+
const current = distinct_keys_out[i];
29+
var j = i;
30+
while (j > 0 and distinct_keys_out[j - 1] > current) : (j -= 1) {
31+
distinct_keys_out[j] = distinct_keys_out[j - 1];
32+
}
33+
distinct_keys_out[j] = current;
34+
}
35+
36+
for (keys, 0..) |value, idx| {
37+
var group: usize = 0;
38+
while (group < distinct_count and distinct_keys_out[group] != value) : (group += 1) {}
39+
group_ids_out[idx] = @intCast(group);
40+
}
41+
42+
return distinct_count;
43+
}
44+
45+
pub fn buildGroupWeightMatrix(
46+
comptime T: type,
47+
distinct_keys: []const key.KeyType(T),
48+
cfg: Config,
49+
weight_matrix: []u16,
50+
) void {
51+
const stride = distinct_keys.len;
52+
for (distinct_keys, 0..) |left_key, i| {
53+
for (distinct_keys, 0..) |right_key, j| {
54+
const index = i * stride + j;
55+
if (left_key > right_key) {
56+
weight_matrix[index] = @intCast(pairWeightFromKeys(T, left_key, right_key, cfg));
57+
} else {
58+
weight_matrix[index] = 0;
59+
}
60+
}
61+
}
62+
}
63+
464
pub fn pairWeight(comptime T: type, left: T, right: T, cfg: Config) u64 {
565
return 1 + key.closenessBonus(T, left, right, cfg);
666
}
@@ -41,6 +101,24 @@ pub fn blockEnergyFromKeys(comptime T: type, keys: []const key.KeyType(T), cfg:
41101
return total;
42102
}
43103

104+
pub fn blockEnergyFromGroupIds(group_ids: []const u8, distinct_count: usize, weight_matrix: []const u16) u64 {
105+
var counts: [Config.max_block_size]usize = [_]usize{0} ** Config.max_block_size;
106+
var total: u64 = 0;
107+
108+
var idx = group_ids.len;
109+
while (idx > 0) {
110+
idx -= 1;
111+
const group = group_ids[idx];
112+
var smaller: usize = 0;
113+
while (smaller < group) : (smaller += 1) {
114+
total += @as(u64, counts[smaller]) * weight_matrix[@as(usize, group) * distinct_count + smaller];
115+
}
116+
counts[group] += 1;
117+
}
118+
119+
return total;
120+
}
121+
44122
pub fn energyAfterPermutationFromKeys(
45123
comptime T: type,
46124
keys: []const key.KeyType(T),

src/core/transport.zig

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,32 @@ fn useMovedDeltaPath(block_len: usize, moved_count: usize) bool {
1919
return affected_pairs * 2 < total_pairs;
2020
}
2121

22+
fn useGroupedExactPath(block_len: usize, distinct_count: usize) bool {
23+
return distinct_count <= 8 and distinct_count * 4 <= block_len;
24+
}
25+
26+
fn cheapDistinctCount(comptime T: type, keys: []const key.KeyType(T)) usize {
27+
const table_len = Config.max_block_size * 2;
28+
var used: [table_len]bool = [_]bool{false} ** table_len;
29+
var table: [table_len]key.KeyType(T) = undefined;
30+
var count: usize = 0;
31+
32+
for (keys) |value| {
33+
var hash: usize = @intCast(value ^ (value >> @min(@bitSizeOf(key.KeyType(T)) / 2, 16)));
34+
hash %= table_len;
35+
while (used[hash]) : (hash = (hash + 1) % table_len) {
36+
if (table[hash] == value) break;
37+
}
38+
if (!used[hash]) {
39+
used[hash] = true;
40+
table[hash] = value;
41+
count += 1;
42+
}
43+
}
44+
45+
return count;
46+
}
47+
2248
fn isPermutation(mapping: []const usize) bool {
2349
var seen: [Config.max_block_size]bool = [_]bool{false} ** Config.max_block_size;
2450
for (mapping) |value| {
@@ -76,8 +102,26 @@ pub fn tryTransportBlock(comptime T: type, block: []T, cfg: Config, stats: ?*Sta
76102
var desired: [Config.max_block_size]usize = undefined;
77103
var source_to_final: [Config.max_block_size]usize = undefined;
78104
var moved_indices: [Config.max_block_size]usize = undefined;
105+
var group_ids: [Config.max_block_size]u8 = undefined;
106+
var final_group_ids: [Config.max_block_size]u8 = undefined;
107+
var distinct_keys: [Config.max_block_size]key.KeyType(T) = undefined;
108+
var weight_matrix: [Config.max_block_size * Config.max_block_size]u16 = undefined;
109+
110+
const estimated_distinct = cheapDistinctCount(T, keys[0..block.len]);
111+
const grouped_path = useGroupedExactPath(block.len, estimated_distinct);
112+
const distinct_count = if (grouped_path)
113+
energy.collectDistinctGroups(T, keys[0..block.len], distinct_keys[0..block.len], group_ids[0..block.len])
114+
else
115+
0;
79116

80-
const before_energy = pressure.computeFromKeysWithEnergy(T, keys[0..block.len], cfg, pressures[0..block.len]);
117+
const before_energy = if (grouped_path)
118+
blk: {
119+
energy.buildGroupWeightMatrix(T, distinct_keys[0..distinct_count], cfg, weight_matrix[0 .. distinct_count * distinct_count]);
120+
pressure.computeFromKeys(T, keys[0..block.len], cfg, pressures[0..block.len]);
121+
break :blk energy.blockEnergyFromGroupIds(group_ids[0..block.len], distinct_count, weight_matrix[0 .. distinct_count * distinct_count]);
122+
}
123+
else
124+
pressure.computeFromKeysWithEnergy(T, keys[0..block.len], cfg, pressures[0..block.len]);
81125
if (before_energy == 0) {
82126
if (stats) |s| s.transport_blocks_rejected += 1;
83127
return .{ .accepted = false, .before_energy = before_energy, .after_energy = before_energy };
@@ -116,7 +160,14 @@ pub fn tryTransportBlock(comptime T: type, block: []T, cfg: Config, stats: ?*Sta
116160
return .{ .accepted = false, .before_energy = before_energy, .after_energy = before_energy };
117161
}
118162

119-
const after_energy = if (useMovedDeltaPath(block.len, moved_count))
163+
const after_energy = if (grouped_path)
164+
blk: {
165+
for (group_ids[0..block.len], 0..) |group, source_index| {
166+
final_group_ids[source_to_final[source_index]] = group;
167+
}
168+
break :blk energy.blockEnergyFromGroupIds(final_group_ids[0..block.len], distinct_count, weight_matrix[0 .. distinct_count * distinct_count]);
169+
}
170+
else if (useMovedDeltaPath(block.len, moved_count))
120171
energy.energyAfterPermutationFromMovedKeys(
121172
T,
122173
keys[0..block.len],

test/unit/energy_test.zig

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,3 +59,20 @@ test "moved-only permutation delta energy matches exact recomputation" {
5959

6060
try std.testing.expectEqual(after_exact, after_delta);
6161
}
62+
63+
test "grouped energy matches exact recomputation on duplicate-heavy block" {
64+
const cfg = Config{ .valuation_cap = 8 };
65+
const xs = [_]i32{ 2, 2, 1, 1, 0, 0, 2, 1 };
66+
var keys = [_]key.KeyType(i32){ 0, 0, 0, 0, 0, 0, 0, 0 };
67+
var distinct_keys = [_]key.KeyType(i32){0} ** 8;
68+
var group_ids = [_]u8{0} ** 8;
69+
var weights = [_]u16{0} ** (8 * 8);
70+
71+
for (xs, 0..) |value, i| keys[i] = key.biasedKey(i32, value);
72+
const distinct_count = energy.collectDistinctGroups(i32, keys[0..], distinct_keys[0..], group_ids[0..]);
73+
energy.buildGroupWeightMatrix(i32, distinct_keys[0..distinct_count], cfg, weights[0 .. distinct_count * distinct_count]);
74+
75+
const exact = energy.blockEnergy(i32, xs[0..], cfg);
76+
const grouped = energy.blockEnergyFromGroupIds(group_ids[0..], distinct_count, weights[0 .. distinct_count * distinct_count]);
77+
try std.testing.expectEqual(exact, grouped);
78+
}

0 commit comments

Comments
 (0)