diff --git a/BitFaster.Caching.Benchmarks/BitFaster.Caching.Benchmarks.csproj b/BitFaster.Caching.Benchmarks/BitFaster.Caching.Benchmarks.csproj
index 789e1174..aa79c2b9 100644
--- a/BitFaster.Caching.Benchmarks/BitFaster.Caching.Benchmarks.csproj
+++ b/BitFaster.Caching.Benchmarks/BitFaster.Caching.Benchmarks.csproj
@@ -3,12 +3,14 @@
   <PropertyGroup>
     <OutputType>Exe</OutputType>
     <LangVersion>latest</LangVersion>
-    <TargetFrameworks>net48;net6.0;net8.0</TargetFrameworks>
+    <TargetFrameworks>net6.0;net8.0</TargetFrameworks>
     <AllowUnsafeBlocks>True</AllowUnsafeBlocks>
     <Optimize>true</Optimize>
     <TieredCompilation>true</TieredCompilation>
     <TieredPGO>true</TieredPGO>
+    <ServerGarbageCollection>true</ServerGarbageCollection>
+    <ConcurrentGarbageCollection>true</ConcurrentGarbageCollection>
   </PropertyGroup>
@@ -41,5 +43,11 @@
     <DefineConstants>MacOS</DefineConstants>
   </PropertyGroup>
 
+  <PropertyGroup Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)' == 'Arm64'">
+    <DefineConstants>Arm64</DefineConstants>
+  </PropertyGroup>
+  <PropertyGroup Condition="'$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture)' == 'X64'">
+    <DefineConstants>X64</DefineConstants>
+  </PropertyGroup>
-</Project>
\ No newline at end of file
+</Project>
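The project file drops net48 from the benchmark targets (System.Runtime.Intrinsics is unavailable there) and defines `Arm64`/`X64` compilation constants per build architecture, which the benchmark sources below use to select and rename ISA-specific methods. A minimal sketch of the equivalent runtime probe, assuming only the standard `System.Runtime.InteropServices` BCL API (the `ArchProbe` class name is invented for illustration):

```csharp
using System.Runtime.InteropServices;

// Illustrative only: the runtime equivalent of the Arm64/X64 compile-time
// constants the csproj defines per architecture.
static class ArchProbe
{
    public static bool IsArm64 => RuntimeInformation.ProcessArchitecture == Architecture.Arm64;
    public static bool IsX64 => RuntimeInformation.ProcessArchitecture == Architecture.X64;
}
```
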
diff --git a/BitFaster.Caching.Benchmarks/Lfu/CmSketchNoPin.cs b/BitFaster.Caching.Benchmarks/Lfu/CmSketchNoPin.cs
index 809a3f9b..68428e06 100644
--- a/BitFaster.Caching.Benchmarks/Lfu/CmSketchNoPin.cs
+++ b/BitFaster.Caching.Benchmarks/Lfu/CmSketchNoPin.cs
@@ -1,9 +1,12 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
+
 #if NET6_0_OR_GREATER
 using System.Runtime.Intrinsics;
+using System.Runtime.Intrinsics.Arm;
 using System.Runtime.Intrinsics.X86;
 #endif
@@ -61,6 +64,12 @@ public int EstimateFrequency(T value)
         {
             return EstimateFrequencyAvx(value);
         }
+#if NET6_0_OR_GREATER
+        else if (isa.IsArm64Supported)
+        {
+            return EstimateFrequencyArm(value);
+        }
+#endif
         else
         {
             return EstimateFrequencyStd(value);
@@ -84,6 +93,12 @@ public void Increment(T value)
         {
             IncrementAvx(value);
         }
+#if NET6_0_OR_GREATER
+        else if (isa.IsArm64Supported)
+        {
+            IncrementArm(value);
+        }
+#endif
         else
         {
             IncrementStd(value);
@@ -314,5 +329,94 @@ private unsafe void IncrementAvx(T value)
             }
         }
 #endif
+
+#if NET6_0_OR_GREATER
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private unsafe void IncrementArm(T value)
+        {
+            int blockHash = Spread(comparer.GetHashCode(value));
+            int counterHash = Rehash(blockHash);
+            int block = (blockHash & blockMask) << 3;
+
+            Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
+            Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
+            Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
+
+            fixed (long* tablePtr = table)
+            {
+                int t0 = AdvSimd.Extract(blockOffset, 0);
+                int t1 = AdvSimd.Extract(blockOffset, 1);
+                int t2 = AdvSimd.Extract(blockOffset, 2);
+                int t3 = AdvSimd.Extract(blockOffset, 3);
+
+                Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t0), AdvSimd.LoadVector64(tablePtr + t1));
+                Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t2), AdvSimd.LoadVector64(tablePtr + t3));
+
+                index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);
+
+                Vector128<int> longOffA = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1);
+                Vector128<int> longOffB = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3);
+
+                Vector128<long> fifteen = Vector128.Create(0xfL);
+                Vector128<long> maskA = AdvSimd.ShiftArithmetic(fifteen, longOffA.AsInt64());
+                Vector128<long> maskB = AdvSimd.ShiftArithmetic(fifteen, longOffB.AsInt64());
+
+                Vector128<long> maskedA = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorA, maskA), maskA));
+                Vector128<long> maskedB = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorB, maskB), maskB));
+
+                var one = Vector128.Create(1L);
+                Vector128<long> incA = AdvSimd.And(maskedA, AdvSimd.ShiftArithmetic(one, longOffA.AsInt64()));
+                Vector128<long> incB = AdvSimd.And(maskedB, AdvSimd.ShiftArithmetic(one, longOffB.AsInt64()));
+
+                tablePtr[t0] += AdvSimd.Extract(incA, 0);
+                tablePtr[t1] += AdvSimd.Extract(incA, 1);
+                tablePtr[t2] += AdvSimd.Extract(incB, 0);
+                tablePtr[t3] += AdvSimd.Extract(incB, 1);
+
+                var max = AdvSimd.Arm64.MaxAcross(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.MaxAcross(incA.AsInt32()), 1, AdvSimd.Arm64.MaxAcross(incB.AsInt32()), 0).AsInt16());
+
+                if (max.ToScalar() != 0 && (++size == sampleSize))
+                {
+                    Reset();
+                }
+            }
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private unsafe int EstimateFrequencyArm(T value)
+        {
+            int blockHash = Spread(comparer.GetHashCode(value));
+            int counterHash = Rehash(blockHash);
+            int block = (blockHash & blockMask) << 3;
+
+            Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
+            Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
+            Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
+
+            fixed (long* tablePtr = table)
+            {
+                Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 0)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 1)));
+                Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 2)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 3)));
+
+                index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);
+
+                Vector128<int> indexA = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1));
+                Vector128<int> indexB = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3));
+
+                var fifteen = Vector128.Create(0xfL);
+                Vector128<long> a = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorA, indexA.AsInt64()), fifteen);
+                Vector128<long> b = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorB, indexB.AsInt64()), fifteen);
+
+                // Before: < 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F >
+                // After:  < 0, 1, 2, 3, 8, 9, A, B, 4, 5, 6, 7, C, D, E, F >
+                var min = AdvSimd.Arm64.VectorTableLookup(a.AsByte(), Vector128.Create(0x0B0A090803020100, 0xFFFFFFFFFFFFFFFF).AsByte());
+                min = AdvSimd.Arm64.VectorTableLookupExtension(min, b.AsByte(), Vector128.Create(0xFFFFFFFFFFFFFFFF, 0x0B0A090803020100).AsByte());
+
+                var min32 = AdvSimd.Arm64.MinAcross(min.AsInt32());
+
+                return min32.ToScalar();
+            }
+        }
+#endif
     }
 }
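Each 64-bit word in the sketch's table packs sixteen 4-bit counters, and a block is eight consecutive words. `AdvSimd.ShiftArithmetic` shifts left for positive shift counts and right for negative ones, so `Vector128.Create(0, -8, -16, -24)` derives all four row hashes from `counterHash` with one instruction; the same trick later builds the per-counter masks. A scalar model of what a single NEON lane of `IncrementArm` computes (illustrative, not repo code; names mirror the vector locals):

```csharp
// Illustrative scalar model of one NEON lane of IncrementArm.
// table: array of 64-bit words, each holding sixteen 4-bit counters.
static class SketchIncrementModel
{
    public static void IncrementLane(long[] table, int block, int counterHash, int lane)
    {
        int h = counterHash >> (8 * lane);        // ShiftArithmetic by -8 * lane
        int index = (h >> 1) & 0xf;               // which 4-bit counter within the word
        int offset = block + (h & 1) + 2 * lane;  // which word within the 8-word block
        long mask = 0xfL << (index * 4);          // counter already saturated at 15?
        if ((table[offset] & mask) != mask)       // only bump unsaturated counters
        {
            table[offset] += 1L << (index * 4);   // add 1 at the counter's bit position
        }
    }
}
```

The `MaxAcross` reduction at the end recovers whether any lane actually incremented, so the sample counter and periodic `Reset()` behave exactly as in the scalar path.
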
diff --git a/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs b/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
index 137b9dcd..ba206ddf 100644
--- a/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
+++ b/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
@@ -51,7 +51,7 @@ public int FrequencyFlat()
 
             return count;
         }
-
+#if X64
         [Benchmark(OperationsPerInvoke = iterations)]
         public int FrequencyFlatAvx()
         {
@@ -61,7 +61,7 @@ public int FrequencyFlatAvx()
 
             return count;
         }
-
+#endif
         [Benchmark(OperationsPerInvoke = iterations)]
         public int FrequencyBlock()
         {
@@ -73,7 +73,11 @@ public int FrequencyBlock()
         }
 
         [Benchmark(OperationsPerInvoke = iterations)]
+#if Arm64
+        public int FrequencyBlockNeonNotPinned()
+#else
         public int FrequencyBlockAvxNotPinned()
+#endif
         {
             int count = 0;
             for (int i = 0; i < iterations; i++)
@@ -83,7 +87,12 @@ public int FrequencyBlockAvxNotPinned()
         }
 
         [Benchmark(OperationsPerInvoke = iterations)]
+
+#if Arm64
+        public int FrequencyBlockNeonPinned()
+#else
         public int FrequencyBlockAvxPinned()
+#endif
         {
             int count = 0;
             for (int i = 0; i < iterations; i++)
diff --git a/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs b/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
index 6f6ab1e7..e2fb1e02 100644
--- a/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
+++ b/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
@@ -27,6 +27,7 @@ public class SketchIncrement
         private CmSketchNoPin<int, DetectIsa> blockAvxNoPin;
         private CmSketchCore<int, DetectIsa> blockAvx;
 
+
         [Params(32_768, 524_288, 8_388_608, 134_217_728)]
         public int Size { get; set; }
@@ -49,7 +50,7 @@ public void IncFlat()
                 flatStd.Increment(i);
             }
         }
-
+#if X64
         [Benchmark(OperationsPerInvoke = iterations)]
         public void IncFlatAvx()
         {
@@ -58,7 +59,7 @@ public void IncFlatAvx()
                 flatAvx.Increment(i);
             }
         }
-
+#endif
         [Benchmark(OperationsPerInvoke = iterations)]
         public void IncBlock()
         {
@@ -69,7 +70,11 @@ public void IncBlock()
         }
 
         [Benchmark(OperationsPerInvoke = iterations)]
+#if Arm64
+        public void IncBlockNeonNotPinned()
+#else
         public void IncBlockAvxNotPinned()
+#endif
         {
             for (int i = 0; i < iterations; i++)
             {
@@ -78,7 +83,11 @@ public void IncBlockAvxNotPinned()
         }
 
         [Benchmark(OperationsPerInvoke = iterations)]
+#if Arm64
+        public void IncBlockNeonPinned()
+#else
         public void IncBlockAvxPinned()
+#endif
         {
             for (int i = 0; i < iterations; i++)
             {
diff --git a/BitFaster.Caching.UnitTests/Intrinsics.cs b/BitFaster.Caching.UnitTests/Intrinsics.cs
index 312d78a1..ebbe194a 100644
--- a/BitFaster.Caching.UnitTests/Intrinsics.cs
+++ b/BitFaster.Caching.UnitTests/Intrinsics.cs
@@ -1,6 +1,10 @@
 #if NETCOREAPP3_1_OR_GREATER
 using System.Runtime.Intrinsics.X86;
 #endif
+#if NET6_0_OR_GREATER
+using System.Runtime.Intrinsics.Arm;
+#endif
+
 using Xunit;
 
 namespace BitFaster.Caching.UnitTests
@@ -10,8 +14,14 @@ public static class Intrinsics
         public static void SkipAvxIfNotSupported<I>()
         {
 #if NETCOREAPP3_1_OR_GREATER
+    #if NET6_0_OR_GREATER
+            // when we are trying to test Avx2/Arm64, skip the test if it's not supported
+            Skip.If(typeof(I) == typeof(DetectIsa) && !(Avx2.IsSupported || AdvSimd.Arm64.IsSupported));
+    #else
             // when we are trying to test Avx2, skip the test if it's not supported
             Skip.If(typeof(I) == typeof(DetectIsa) && !Avx2.IsSupported);
+    #endif
+
 #else
             Skip.If(true);
 #endif
diff --git a/BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs b/BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
index 85de5040..ba32b42e 100644
--- a/BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
+++ b/BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
@@ -6,12 +6,12 @@
 
 namespace BitFaster.Caching.UnitTests.Lfu
 {
-    // Test with AVX2 if it is supported
-    public class CMSketchAvx2Tests : CmSketchTestBase<DetectIsa>
+    // Test with AVX2/ARM64 if it is supported
+    public class CMSketchIntrinsicsTests : CmSketchTestBase<DetectIsa>
     {
     }
 
-    // Test with AVX2 disabled
+    // Test with AVX2/ARM64 disabled
     public class CmSketchTests : CmSketchTestBase<DisableHardwareIntrinsics>
     {
     }
@@ -29,14 +29,23 @@ public CmSketchTestBase()
         public void Repro()
         {
             sketch = new CmSketchCore<int, I>(1_048_576, EqualityComparer<int>.Default);
+            var baseline = new CmSketchCore<int, DisableHardwareIntrinsics>(1_048_576, EqualityComparer<int>.Default);
 
             for (int i = 0; i < 1_048_576; i++)
             {
                 if (i % 3 == 0)
                 {
                     sketch.Increment(i);
+                    baseline.Increment(i);
                 }
             }
+
+            baseline.Size.Should().Be(sketch.Size);
+
+            for (int i = 0; i < 1_048_576; i++)
+            {
+                sketch.EstimateFrequency(i).Should().Be(baseline.EstimateFrequency(i));
+            }
         }
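The reworked `Repro` test is differential: it feeds identical increments to the intrinsics-enabled sketch and to a `DisableHardwareIntrinsics` baseline, then requires the sizes and every estimate to match, which validates the new NEON path against the scalar reference. A hypothetical companion test using the same skip helper (not part of the PR; the test name and assertion are invented for illustration):

```csharp
using System.Collections.Generic;
using BitFaster.Caching;
using BitFaster.Caching.Lfu;
using FluentAssertions;
using Xunit;

public class SketchSkipExample
{
    // Hypothetical test showing the skip-helper pattern: DetectIsa tests are
    // skipped, not failed, on hardware with neither AVX2 nor NEON support.
    [SkippableFact]
    public void EstimateIsAtLeastTrueCount()
    {
        BitFaster.Caching.UnitTests.Intrinsics.SkipAvxIfNotSupported<DetectIsa>();

        var sketch = new CmSketchCore<int, DetectIsa>(1024, EqualityComparer<int>.Default);
        sketch.Increment(42);
        sketch.Increment(42);

        // a count-min sketch can over-estimate but never under-estimate
        sketch.EstimateFrequency(42).Should().BeGreaterOrEqualTo(2);
    }
}
```
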
diff --git a/BitFaster.Caching/Intrinsics.cs b/BitFaster.Caching/Intrinsics.cs
index 8a1bd29a..45908a01 100644
--- a/BitFaster.Caching/Intrinsics.cs
+++ b/BitFaster.Caching/Intrinsics.cs
@@ -2,6 +2,10 @@
 using System.Runtime.Intrinsics.X86;
 #endif
 
+#if NET6_0_OR_GREATER
+using System.Runtime.Intrinsics.Arm;
+#endif
+
 namespace BitFaster.Caching
 {
     /// <summary>
@@ -12,7 +16,14 @@ public interface IsaProbe
         /// <summary>
         /// Gets a value indicating whether AVX2 is supported.
         /// </summary>
-        bool IsAvx2Supported { get; }
+        bool IsAvx2Supported { get; }
+
+#if NET6_0_OR_GREATER
+        /// <summary>
+        /// Gets a value indicating whether Arm64 is supported.
+        /// </summary>
+        bool IsArm64Supported { get => false; }
+#endif
     }
 
     /// <summary>
@@ -25,7 +36,15 @@ public interface IsaProbe
         public bool IsAvx2Supported => false;
 #else
         /// <inheritdoc/>
-        public bool IsAvx2Supported => Avx2.IsSupported;
+        public bool IsAvx2Supported => Avx2.IsSupported;
+#endif
+
+#if NET6_0_OR_GREATER
+        /// <inheritdoc/>
+        public bool IsArm64Supported => AdvSimd.Arm64.IsSupported;
+#else
+        /// <inheritdoc/>
+        public bool IsArm64Supported => false;
 #endif
     }
 
@@ -35,6 +54,9 @@ public interface IsaProbe
     public readonly struct DisableHardwareIntrinsics : IsaProbe
     {
         /// <inheritdoc/>
-        public bool IsAvx2Supported => false;
+        public bool IsAvx2Supported => false;
+
+        /// <inheritdoc/>
+        public bool IsArm64Supported => false;
     }
 }
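`IsArm64Supported` is declared as a default interface member returning `false` so that existing `IsaProbe` implementors keep compiling; the `#if` guard exists because default interface members are unavailable on netstandard2.0. The probe is consumed as a generic struct type parameter, so the JIT specializes each method per probe type and folds these properties to constants, eliminating the untaken branches. A small sketch of that dispatch pattern, assuming .NET 6+ (the `SelectPath` helper is invented for illustration):

```csharp
using BitFaster.Caching;

// Illustrative sketch: with I constrained to a struct, the JIT specializes
// SelectPath<I> per probe type and folds the probe properties to constants,
// so the branches below are removed from the generated code.
static class DispatchSketch
{
    public static string SelectPath<I>() where I : struct, IsaProbe
    {
        I isa = default;

        if (isa.IsAvx2Supported) return "AVX2 path";
        if (isa.IsArm64Supported) return "NEON path";
        return "scalar path";
    }
}

// SelectPath<DetectIsa>() compiles down to a single unconditional path on a
// given machine; SelectPath<DisableHardwareIntrinsics>() is always scalar.
```
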
diff --git a/BitFaster.Caching/Lfu/CmSketchCore.cs b/BitFaster.Caching/Lfu/CmSketchCore.cs
index 733b1ea0..46f516b7 100644
--- a/BitFaster.Caching/Lfu/CmSketchCore.cs
+++ b/BitFaster.Caching/Lfu/CmSketchCore.cs
@@ -10,6 +10,10 @@
 using System.Runtime.Intrinsics.X86;
 #endif
 
+#if NET6_0_OR_GREATER
+using System.Runtime.Intrinsics.Arm;
+#endif
+
 namespace BitFaster.Caching.Lfu
 {
     /// <summary>
@@ -81,6 +85,12 @@ public int EstimateFrequency(T value)
         {
             return EstimateFrequencyAvx(value);
         }
+#if NET6_0_OR_GREATER
+        else if (isa.IsArm64Supported)
+        {
+            return EstimateFrequencyArm(value);
+        }
+#endif
         else
         {
             return EstimateFrequencyStd(value);
@@ -104,6 +114,12 @@ public void Increment(T value)
         {
             IncrementAvx(value);
         }
+#if NET6_0_OR_GREATER
+        else if (isa.IsArm64Supported)
+        {
+            IncrementArm(value);
+        }
+#endif
         else
         {
             IncrementStd(value);
@@ -127,7 +143,7 @@ private void EnsureCapacity(long maximumSize)
 #if NET6_0_OR_GREATER
             I isa = default;
 
-            if (isa.IsAvx2Supported)
+            if (isa.IsAvx2Supported || isa.IsArm64Supported)
             {
                 // over alloc by 8 to give 64 bytes padding, tableAddr is then aligned to 64 bytes
                 const int pad = 8;
@@ -329,5 +345,94 @@ private unsafe void IncrementAvx(T value)
             }
         }
 #endif
+
+#if NET6_0_OR_GREATER
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private unsafe void IncrementArm(T value)
+        {
+            int blockHash = Spread(comparer.GetHashCode(value));
+            int counterHash = Rehash(blockHash);
+            int block = (blockHash & blockMask) << 3;
+
+            Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
+            Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
+            Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
+
+            long* tablePtr = tableAddr;
+            {
+                int t0 = AdvSimd.Extract(blockOffset, 0);
+                int t1 = AdvSimd.Extract(blockOffset, 1);
+                int t2 = AdvSimd.Extract(blockOffset, 2);
+                int t3 = AdvSimd.Extract(blockOffset, 3);
+
+                Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t0), AdvSimd.LoadVector64(tablePtr + t1));
+                Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t2), AdvSimd.LoadVector64(tablePtr + t3));
+
+                index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);
+
+                Vector128<int> longOffA = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1);
+                Vector128<int> longOffB = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3);
+
+                Vector128<long> fifteen = Vector128.Create(0xfL);
+                Vector128<long> maskA = AdvSimd.ShiftArithmetic(fifteen, longOffA.AsInt64());
+                Vector128<long> maskB = AdvSimd.ShiftArithmetic(fifteen, longOffB.AsInt64());
+
+                Vector128<long> maskedA = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorA, maskA), maskA));
+                Vector128<long> maskedB = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorB, maskB), maskB));
+
+                var one = Vector128.Create(1L);
+                Vector128<long> incA = AdvSimd.And(maskedA, AdvSimd.ShiftArithmetic(one, longOffA.AsInt64()));
+                Vector128<long> incB = AdvSimd.And(maskedB, AdvSimd.ShiftArithmetic(one, longOffB.AsInt64()));
+
+                tablePtr[t0] += AdvSimd.Extract(incA, 0);
+                tablePtr[t1] += AdvSimd.Extract(incA, 1);
+                tablePtr[t2] += AdvSimd.Extract(incB, 0);
+                tablePtr[t3] += AdvSimd.Extract(incB, 1);
+
+                var max = AdvSimd.Arm64.MaxAcross(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.MaxAcross(incA.AsInt32()), 1, AdvSimd.Arm64.MaxAcross(incB.AsInt32()), 0).AsInt16());
+
+                if (max.ToScalar() != 0 && (++size == sampleSize))
+                {
+                    Reset();
+                }
+            }
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        private unsafe int EstimateFrequencyArm(T value)
+        {
+            int blockHash = Spread(comparer.GetHashCode(value));
+            int counterHash = Rehash(blockHash);
+            int block = (blockHash & blockMask) << 3;
+
+            Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
+            Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
+            Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
+
+            long* tablePtr = tableAddr;
+            {
+                Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 0)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 1)));
+                Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 2)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 3)));
+
+                index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);
+
+                Vector128<int> indexA = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1));
+                Vector128<int> indexB = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3));
+
+                var fifteen = Vector128.Create(0xfL);
+                Vector128<long> a = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorA, indexA.AsInt64()), fifteen);
+                Vector128<long> b = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorB, indexB.AsInt64()), fifteen);
+
+                // Before: < 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F >
+                // After:  < 0, 1, 2, 3, 8, 9, A, B, 4, 5, 6, 7, C, D, E, F >
+                var min = AdvSimd.Arm64.VectorTableLookup(a.AsByte(), Vector128.Create(0x0B0A090803020100, 0xFFFFFFFFFFFFFFFF).AsByte());
+                min = AdvSimd.Arm64.VectorTableLookupExtension(min, b.AsByte(), Vector128.Create(0xFFFFFFFFFFFFFFFF, 0x0B0A090803020100).AsByte());
+
+                var min32 = AdvSimd.Arm64.MinAcross(min.AsInt32());
+
+                return min32.ToScalar();
+            }
+        }
+#endif
     }
 }
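`EstimateFrequencyArm` reads the same four counters that `IncrementArm` writes and returns their minimum. Negating the bit offsets turns `ShiftArithmetic` into a right shift, the two `VectorTableLookup` calls interleave the four extracted counters into one vector, and a single `MinAcross` then reduces all lanes. This pinned `CmSketchCore` version differs from the `fixed` version in the benchmark copy above only in obtaining the table pointer from the pre-pinned, 64-byte-aligned `tableAddr`, which avoids a pinning operation per call. A scalar model of the whole estimate (illustrative, not repo code):

```csharp
// Illustrative scalar model of EstimateFrequencyArm: the count-min estimate
// is the minimum of the four 4-bit counters selected by the hash.
static class SketchEstimateModel
{
    public static int Estimate(long[] table, int block, int counterHash)
    {
        int min = int.MaxValue;
        for (int lane = 0; lane < 4; lane++)
        {
            int h = counterHash >> (8 * lane);       // ShiftArithmetic by -8 * lane
            int index = (h >> 1) & 0xf;              // counter within the word
            int offset = block + (h & 1) + 2 * lane; // word within the 8-word block
            // the & 0xf mask makes the sign of the arithmetic shift irrelevant
            int count = (int)((table[offset] >> (index * 4)) & 0xf);
            if (count < min) min = count;
        }
        return min;
    }
}
```

Keeping all four counters for a key inside one 64-byte block is what makes both paths fast: every read and write in an increment or estimate touches a single cache line.
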