diff --git a/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs b/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
index b97bc19d..a6da43e2 100644
--- a/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
+++ b/BitFaster.Caching.Benchmarks/Lfu/SketchFrequency.cs
@@ -8,6 +8,7 @@
 namespace BitFaster.Caching.Benchmarks.Lfu
 {
     [SimpleJob(RuntimeMoniker.Net60)]
+    [SimpleJob(RuntimeMoniker.Net80)]
     [MemoryDiagnoser(displayGenColumns: false)]
     [HideColumns("Job", "Median", "RatioSD", "Alloc Ratio")]
     [ColumnChart(Title ="Sketch Frequency ({JOB})")]
@@ -22,7 +23,7 @@ public class SketchFrequency
         private CmSketchCore<int, DisableHardwareIntrinsics> blockStd;
         private CmSketchCore<int, DetectIsa> blockAvx;
 
-        [Params(32_768, 524_288, 8_388_608, 134_217_728)]
+        [Params(1024, 32_768, 524_288, 8_388_608, 134_217_728)]
         public int Size { get; set; }
 
         [GlobalSetup]
@@ -45,7 +46,7 @@ public int FrequencyFlat()
             return count;
         }
 
-        [Benchmark(OperationsPerInvoke = iterations)]
+        // [Benchmark(OperationsPerInvoke = iterations)]
         public int FrequencyFlatAvx()
         {
             int count = 0;
diff --git a/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs b/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
index eb005032..ca4bf3ce 100644
--- a/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
+++ b/BitFaster.Caching.Benchmarks/Lfu/SketchIncrement.cs
@@ -1,5 +1,6 @@
 
 using System.Collections.Generic;
+using Benchly;
 using BenchmarkDotNet.Attributes;
 using BenchmarkDotNet.Jobs;
 using BitFaster.Caching.Lfu;
@@ -7,8 +8,10 @@
 namespace BitFaster.Caching.Benchmarks.Lfu
 {
     [SimpleJob(RuntimeMoniker.Net60)]
+    [SimpleJob(RuntimeMoniker.Net80)]
     [MemoryDiagnoser(displayGenColumns: false)]
     [HideColumns("Job", "Median", "RatioSD", "Alloc Ratio")]
+    [ColumnChart(Title = "Sketch Increment ({JOB})")]
     public class SketchIncrement
     {
         const int iterations = 1_048_576;
@@ -19,7 +22,7 @@ public class SketchIncrement
         private CmSketchCore<int, DisableHardwareIntrinsics> blockStd;
         private CmSketchCore<int, DetectIsa> blockAvx;
 
-        [Params(32_768, 524_288, 8_388_608, 134_217_728)]
+        [Params(1024, 32_768, 524_288, 8_388_608, 134_217_728)]
         public int Size { get; set; }
 
         [GlobalSetup]
@@ -41,7 +44,7 @@ public void IncFlat()
             }
         }
 
-        [Benchmark(OperationsPerInvoke = iterations)]
+        //[Benchmark(OperationsPerInvoke = iterations)]
         public void IncFlatAvx()
         {
             for (int i = 0; i < iterations; i++)
diff --git a/BitFaster.Caching.UnitTests/Intrinsics.cs b/BitFaster.Caching.UnitTests/Intrinsics.cs
index 312d78a1..ebbe194a 100644
--- a/BitFaster.Caching.UnitTests/Intrinsics.cs
+++ b/BitFaster.Caching.UnitTests/Intrinsics.cs
@@ -1,6 +1,10 @@
 #if NETCOREAPP3_1_OR_GREATER
 using System.Runtime.Intrinsics.X86;
 #endif
+#if NET6_0_OR_GREATER
+using System.Runtime.Intrinsics.Arm;
+#endif
+
 using Xunit;
 
 namespace BitFaster.Caching.UnitTests
@@ -10,8 +14,14 @@ public static class Intrinsics
         public static void SkipAvxIfNotSupported<I>()
         {
 #if NETCOREAPP3_1_OR_GREATER
+    #if NET6_0_OR_GREATER
+            // when we are trying to test Avx2/Arm64, skip the test if it's not supported
+            Skip.If(typeof(I) == typeof(DetectIsa) && !(Avx2.IsSupported || AdvSimd.Arm64.IsSupported));
+    #else
             // when we are trying to test Avx2, skip the test if it's not supported
             Skip.If(typeof(I) == typeof(DetectIsa) && !Avx2.IsSupported);
+    #endif
+
 #else
             Skip.If(true);
 #endif
diff --git a/BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs b/BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
index 85de5040..ba32b42e 100644
--- a/BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
+++ b/BitFaster.Caching.UnitTests/Lfu/CmSketchTests.cs
@@ -6,12 +6,12 @@
 
 namespace BitFaster.Caching.UnitTests.Lfu
 {
-    // Test with AVX2 if it is supported
-    public class CMSketchAvx2Tests : CmSketchTestBase<DetectIsa>
+    // Test with AVX2/ARM64 if it is supported
+    public class CMSketchIntrinsicsTests : CmSketchTestBase<DetectIsa>
     {
     }
 
-    // Test with AVX2 disabled
+    // Test with AVX2/ARM64 disabled
     public class CmSketchTests : CmSketchTestBase<DisableHardwareIntrinsics>
     {
     }
@@ -29,14 +29,23 @@ public CmSketchTestBase()
         public void Repro()
         {
             sketch = new CmSketchCore<int, I>(1_048_576, EqualityComparer<int>.Default);
+            var baseline = new CmSketchCore<int, DisableHardwareIntrinsics>(1_048_576, EqualityComparer<int>.Default);
 
             for (int i = 0; i < 1_048_576; i++)
             {
                 if (i % 3 == 0)
                 {
                     sketch.Increment(i);
+                    baseline.Increment(i);
                 }
             }
+
+            baseline.Size.Should().Be(sketch.Size);
+
+            for (int i = 0; i < 1_048_576; i++)
+            {
+                sketch.EstimateFrequency(i).Should().Be(baseline.EstimateFrequency(i));
+            }
         }
 
 
diff --git a/BitFaster.Caching/Intrinsics.cs b/BitFaster.Caching/Intrinsics.cs
index 8a1bd29a..45908a01 100644
--- a/BitFaster.Caching/Intrinsics.cs
+++ b/BitFaster.Caching/Intrinsics.cs
@@ -2,6 +2,10 @@
 using System.Runtime.Intrinsics.X86;
 #endif
 
+#if NET6_0_OR_GREATER
+using System.Runtime.Intrinsics.Arm;
+#endif
+
 namespace BitFaster.Caching
 {
     /// <summary>
@@ -12,7 +16,14 @@ public interface IsaProbe
         /// <summary>
         /// Gets a value indicating whether AVX2 is supported.
         /// </summary>
-        bool IsAvx2Supported { get; }
+        bool IsAvx2Supported { get; }
+
+#if NET6_0_OR_GREATER
+        /// <summary>
+        /// Gets a value indicating whether Arm64 is supported.
+        /// </summary>
+        bool IsArm64Supported { get => false; }
+#endif
     }
 
     /// <summary>
@@ -25,7 +36,15 @@ public interface IsaProbe
         public bool IsAvx2Supported => false;
 #else
         /// <inheritdoc/>
-        public bool IsAvx2Supported => Avx2.IsSupported;
+        public bool IsAvx2Supported => Avx2.IsSupported;
+#endif
+
+#if NET6_0_OR_GREATER
+        /// <inheritdoc/>
+        public bool IsArm64Supported => AdvSimd.Arm64.IsSupported;
+#else
+        /// <inheritdoc/>
+        public bool IsArm64Supported => false;
 #endif
     }
 
@@ -35,6 +54,9 @@ public interface IsaProbe
     public readonly struct DisableHardwareIntrinsics : IsaProbe
     {
         /// <inheritdoc/>
-        public bool IsAvx2Supported => false;
+        public bool IsAvx2Supported => false;
+
+        /// <inheritdoc/>
+        public bool IsArm64Supported => false;
     }
 }
diff --git a/BitFaster.Caching/Lfu/CmSketchCore.cs b/BitFaster.Caching/Lfu/CmSketchCore.cs
index de255840..9b13c2fd 100644
--- a/BitFaster.Caching/Lfu/CmSketchCore.cs
+++ b/BitFaster.Caching/Lfu/CmSketchCore.cs
@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics.CodeAnalysis;
+using System.Runtime.CompilerServices;
 
 
 #if !NETSTANDARD2_0
@@ -8,6 +9,10 @@
 using System.Runtime.Intrinsics.X86;
 #endif
 
+#if NET6_0_OR_GREATER
+using System.Runtime.Intrinsics.Arm;
+#endif
+
 namespace BitFaster.Caching.Lfu
 {
     /// <summary>
@@ -76,6 +81,12 @@ public int EstimateFrequency(T value)
             {
                 return EstimateFrequencyAvx(value);
             }
+#if NET6_0_OR_GREATER
+            else if (isa.IsArm64Supported)
+            {
+                return EstimateFrequencyArm(value);
+            }
+#endif
             else
             {
                 return EstimateFrequencyStd(value);
@@ -99,6 +110,12 @@ public void Increment(T value)
             {
                 IncrementAvx(value);
             }
+#if NET6_0_OR_GREATER
+            else if (isa.IsArm64Supported)
+            {
+                IncrementArm(value);
+            }
+#endif
             else
             {
                 IncrementStd(value);
@@ -329,5 +346,105 @@ private unsafe void IncrementAvx(T value)
             }
         }
 #endif
+
+#if NET6_0_OR_GREATER
+        /// <summary>
+        /// Arm64 (AdvSimd) path for Increment: adds one to each of the four 4-bit counters selected
+        /// by the hash of the value, leaving any counter that has already reached 15 unchanged.
+        /// </summary>
+        [MethodImpl(MethodImplOptions.AggressiveOptimization | MethodImplOptions.AggressiveInlining)]
+        private unsafe void IncrementArm(T value)
+        {
+            int blockHash = Spread(comparer.GetHashCode(value));
+            int counterHash = Rehash(blockHash);
+            int block = (blockHash & blockMask) << 3;
+
+            Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
+            Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
+            Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
+
+            fixed (long* tablePtr = table)
+            {
+                int t0 = AdvSimd.Extract(blockOffset, 0);
+                int t1 = AdvSimd.Extract(blockOffset, 1);
+                int t2 = AdvSimd.Extract(blockOffset, 2);
+                int t3 = AdvSimd.Extract(blockOffset, 3);
+
+                Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t0), AdvSimd.LoadVector64(tablePtr + t1));
+                Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + t2), AdvSimd.LoadVector64(tablePtr + t3));
+
+                // scale each counter index to its bit offset within the 64-bit slot (4 bits per counter)
+                index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);
+
+                Vector128<int> longOffA = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1);
+                Vector128<int> longOffB = AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3);
+
+                Vector128<long> fifteen = Vector128.Create(0xfL);
+                Vector128<long> maskA = AdvSimd.ShiftArithmetic(fifteen, longOffA.AsInt64());
+                Vector128<long> maskB = AdvSimd.ShiftArithmetic(fifteen, longOffB.AsInt64());
+
+                Vector128<long> maskedA = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorA, maskA), maskA));
+                Vector128<long> maskedB = AdvSimd.Not(AdvSimd.Arm64.CompareEqual(AdvSimd.And(tableVectorB, maskB), maskB));
+
+                var one = Vector128.Create(1L);
+                Vector128<long> incA = AdvSimd.And(maskedA, AdvSimd.ShiftArithmetic(one, longOffA.AsInt64()));
+                Vector128<long> incB = AdvSimd.And(maskedB, AdvSimd.ShiftArithmetic(one, longOffB.AsInt64()));
+
+                tablePtr[t0] += AdvSimd.Extract(incA, 0);
+                tablePtr[t1] += AdvSimd.Extract(incA, 1);
+                tablePtr[t2] += AdvSimd.Extract(incB, 0);
+                tablePtr[t3] += AdvSimd.Extract(incB, 1);
+
+                // non-zero only if at least one counter was actually incremented
+                var max = AdvSimd.Arm64.MaxAcross(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.MaxAcross(incA.AsInt32()), 1, AdvSimd.Arm64.MaxAcross(incB.AsInt32()), 0).AsInt16());
+
+                if (max.ToScalar() != 0 && (++size == sampleSize))
+                {
+                    Reset();
+                }
+            }
+        }
+
+        /// <summary>
+        /// Arm64 (AdvSimd) path for EstimateFrequency: reads the four 4-bit counters selected by the
+        /// hash of the value and returns the smallest of them.
+        /// </summary>
+        [MethodImpl(MethodImplOptions.AggressiveOptimization | MethodImplOptions.AggressiveInlining)]
+        private unsafe int EstimateFrequencyArm(T value)
+        {
+            int blockHash = Spread(comparer.GetHashCode(value));
+            int counterHash = Rehash(blockHash);
+            int block = (blockHash & blockMask) << 3;
+
+            Vector128<int> h = AdvSimd.ShiftArithmetic(Vector128.Create(counterHash), Vector128.Create(0, -8, -16, -24));
+            Vector128<int> index = AdvSimd.And(AdvSimd.ShiftRightLogical(h, 1), Vector128.Create(0xf));
+            Vector128<int> blockOffset = AdvSimd.Add(AdvSimd.Add(Vector128.Create(block), AdvSimd.And(h, Vector128.Create(1))), Vector128.Create(0, 2, 4, 6));
+
+            fixed (long* tablePtr = table)
+            {
+                Vector128<long> tableVectorA = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 0)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 1)));
+                Vector128<long> tableVectorB = Vector128.Create(AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 2)), AdvSimd.LoadVector64(tablePtr + AdvSimd.Extract(blockOffset, 3)));
+
+                // scale each counter index to its bit offset within the 64-bit slot (4 bits per counter)
+                index = AdvSimd.ShiftLeftLogicalSaturate(index, 2);
+
+                Vector128<int> indexA = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 0), 2, index, 1));
+                Vector128<int> indexB = AdvSimd.Negate(AdvSimd.Arm64.InsertSelectedScalar(AdvSimd.Arm64.InsertSelectedScalar(Vector128<int>.Zero, 0, index, 2), 2, index, 3));
+
+                var fifteen = Vector128.Create(0xfL);
+                Vector128<long> a = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorA, indexA.AsInt64()), fifteen);
+                Vector128<long> b = AdvSimd.And(AdvSimd.ShiftArithmetic(tableVectorB, indexB.AsInt64()), fifteen);
+
+                // Before: < 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, A, B, C, D, E, F >
+                // After: < 0, 1, 2, 3, 8, 9, A, B, 4, 5, 6, 7, C, D, E, F >
+                var min = AdvSimd.Arm64.VectorTableLookup(a.AsByte(), Vector128.Create(0x0B0A090803020100, 0xFFFFFFFFFFFFFFFF).AsByte());
+                min = AdvSimd.Arm64.VectorTableLookupExtension(min, b.AsByte(), Vector128.Create(0xFFFFFFFFFFFFFFFF, 0x0B0A090803020100).AsByte());
+
+                var min32 = AdvSimd.Arm64.MinAcross(min.AsInt32());
+
+                return min32.ToScalar();
+            }
+        }
+#endif
     }
 }