Skip to content

Commit 4811a49

Browse files
committed
Add minor fixes to tests
1 parent d76da47 commit 4811a49

File tree

3 files changed

+54
-74
lines changed

3 files changed

+54
-74
lines changed

.gitignore

+4-4
Original file line numberDiff line numberDiff line change
@@ -352,10 +352,10 @@ MigrationBackup/
352352
# Gitignore template for Visual Studio Code
353353
# https://github.com/github/gitignore/blob/master/Global/VisualStudioCode.gitignore
354354
.vscode/*
355-
!.vscode/settings.json
356-
!.vscode/tasks.json
357-
!.vscode/launch.json
358-
!.vscode/extensions.json
355+
# !.vscode/settings.json
356+
# !.vscode/tasks.json
357+
# !.vscode/launch.json
358+
# !.vscode/extensions.json
359359
*.code-workspace
360360

361361
# Local History for Visual Studio Code

tests/CSRMultiplication.Tests/CSRMatrixTests.fs

+5-7
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,11 @@ module CSRMatrixTests =
1212

1313
type FloatMatrix =
1414
static member FloatSparseMatrix () =
15-
fun size ->
16-
Gen.oneof [
17-
Arb.Default.NormalFloat () |> Arb.toGen |> Gen.map float
18-
Gen.constant 0.
19-
]
20-
|> Gen.array2DOf
21-
|> Gen.sized
15+
Gen.oneof [
16+
Arb.Default.NormalFloat () |> Arb.toGen |> Gen.map float
17+
Gen.constant 0.
18+
]
19+
|> Gen.array2DOf
2220
|> Arb.fromGen
2321

2422
[<Property(Arbitrary=[| typeof<FloatMatrix> |])>]

tests/CSRMultiplication.Tests/SparseMatrixMultiplicationTests.fs

+45-63
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ open OpenCL.Net
99
open Brahma.OpenCL
1010
open SparseMatrixMultiplication
1111

12-
[<Properties(Verbose=true, MaxTest=100, EndSize=900)>]
1312
module SparseMatrixMultiplicationTests =
1413

1514
type MatrixMultiplicationPair =
@@ -64,41 +63,24 @@ module SparseMatrixMultiplicationTests =
6463
|> Gen.sized
6564
|> Arb.fromGen
6665

67-
// type SetupTestFixture() =
68-
// let provider =
69-
// try ComputeProvider.Create("INTEL*", DeviceType.Cpu)
70-
// with
71-
// | ex -> failwith ex.Message
72-
// let mutable commandQueue = new CommandQueue(provider, provider.Devices |> Seq.head)
73-
74-
// member this.CsrVectorMultiply = multiplySpMV provider commandQueue
75-
// member this.CsrDenseMultiply = multiplySpMM provider commandQueue
76-
// member this.CsrCscMultiply = multiplySpMSpM provider commandQueue
77-
// member this.CsrCsrMultiply = multiplySpMSpM2 provider commandQueue
78-
79-
// interface System.IDisposable with
80-
// member this.Dispose () =
81-
// commandQueue.Dispose ()
82-
// provider.CloseAllBuffers ()
83-
// provider.Dispose ()
84-
85-
let func<'a when 'a :> System.Collections.IEnumerable> (result: 'a) (expected: 'a) =
66+
let getFlattenedDiff<'a when 'a :> System.Collections.IEnumerable> (result: 'a) (expected: 'a) =
8667
let pairMap f (x, y) = f x, f y
8768
(result, expected)
8869
|> pairMap Seq.cast<float>
8970
||> Seq.zip
9071
|> Seq.map (fun (x, y) -> x - y)
9172

92-
let check<'a when 'a :> System.Collections.IEnumerable> (result: 'a) (expected: 'a) =
93-
func result expected
73+
let checkEquality<'a when 'a :> System.Collections.IEnumerable> (result: 'a) (expected: 'a) =
74+
getFlattenedDiff result expected
9475
|> Seq.forall (fun diff -> diff < 1e-8)
9576

9677
let getLabel<'a when 'a :> System.Collections.IEnumerable> (result: 'a) (expected: 'a) =
9778
sprintf "\n Total diff:\n %A\n Result:\n %A\n Expected:\n %A\n"
98-
(func result expected |> Seq.sum)
79+
(getFlattenedDiff result expected |> Seq.sum)
9980
result
10081
expected
10182

83+
[<Properties(Verbose=false, MaxTest=100, EndSize=1600)>]
10284
type Tests() =
10385
let matrixVectorMultiply (vector: float[]) (matrix: float[,]) =
10486
let rows = matrix |> Array2D.length1
@@ -139,64 +121,64 @@ module SparseMatrixMultiplicationTests =
139121
member this.``CSR x Vector multiplication should work correctly on nonempty and nonzero objects`` (matrix: float[,], vector: float[]) =
140122
let result = CSRMatrix.makeFromDenseMatrix matrix |> csrVectorMultiply vector
141123
let expected = matrix |> matrixVectorMultiply vector
142-
(check result expected) |@ (getLabel result expected)
124+
(checkEquality result expected) |@ (getLabel result expected)
143125

144126
[<Trait("Category", "dense")>]
145127
[<Property(Arbitrary=[| typeof<MatrixMultiplicationPair> |])>]
146128
member this.``CSR x Dense multiplication should work correctly on nonempty and nonzero matrices`` (left: float[,], right: float[,]) =
147129
let result = CSRMatrix.makeFromDenseMatrix left |> csrDenseMultiply right
148130
let expected = left |> matrixMatrixMultiply right
149-
(check result expected) |@ (getLabel result expected)
131+
(checkEquality result expected) |@ (getLabel result expected)
150132

151133
[<Trait("Category", "csc")>]
152134
[<Property(Arbitrary=[| typeof<MatrixMultiplicationPair> |])>]
153135
member this.``CSR x CSC multiplication should work correctly on nonempty and nonzero matrices`` (left: float[,], right: float[,]) =
154136
let result = CSRMatrix.makeFromDenseMatrix left |> csrCscMultiply (CSCMatrix.makeFromDenseMatrix right)
155137
let expected = left |> matrixMatrixMultiply right
156-
(check result expected) |@ (getLabel result expected)
138+
(checkEquality result expected) |@ (getLabel result expected)
157139

158140
[<Trait("Category", "csr")>]
159141
[<Property(Arbitrary=[| typeof<MatrixMultiplicationPair> |])>]
160142
member this.``CSR x CSR multiplication should work correctly on nonempty and nonzero matrices`` (left: float[,], right: float[,]) =
161143
let result = CSRMatrix.makeFromDenseMatrix left |> csrCsrMultiply (CSRMatrix.makeFromDenseMatrix right)
162144
let expected = left |> matrixMatrixMultiply right
163-
(check result expected) |@ (getLabel result expected)
164-
165-
// [<Trait("Category", "csr-cpu")>]
166-
// [<Property(Arbitrary=[| typeof<MatrixMultiplicationPair> |])>]
167-
// member this.``csr cpu`` (left: float[,], right: float[,]) =
168-
// let csrMultAlgo (csrMatrixRight: CSRMatrix.CSRMatrix) (csrMatrixLeft: CSRMatrix.CSRMatrix) =
169-
// let leftMatrixRowCount = csrMatrixLeft |> CSRMatrix.rowCount
170-
// let leftMatrixColumnCount = csrMatrixLeft |> CSRMatrix.columnCount
171-
// let rightMatrixRowCount = csrMatrixRight |> CSRMatrix.rowCount
172-
// let rightMatrixColumnCount = csrMatrixRight |> CSRMatrix.columnCount
173-
// if leftMatrixColumnCount <> rightMatrixRowCount then failwith "fail"
174-
175-
// let leftCsrValuesBuffer = csrMatrixLeft.GetValues
176-
// let leftCsrColumnsBuffer = csrMatrixLeft.GetColumns
177-
// let leftCsrRowPointersBuffer = csrMatrixLeft.GetRowPointers
178-
// let rightCsrValuesBuffer = csrMatrixRight.GetValues
179-
// let rightCsrColumnsBuffer = csrMatrixRight.GetColumns
180-
// let rightCsrRowPointersBuffer = csrMatrixRight.GetRowPointers
181-
182-
// let resultMatrix = Array2D.zeroCreate<float> leftMatrixRowCount rightMatrixColumnCount
183-
// for i in 0 .. rightMatrixRowCount - 1 do
184-
// for j in 0 .. leftMatrixRowCount - 1 do
185-
// for k in rightCsrRowPointersBuffer.[i] .. rightCsrRowPointersBuffer.[i + 1] - 1 do
186-
// let mutable localResultBuffer = resultMatrix.[j, rightCsrColumnsBuffer.[k]]
187-
// let mutable pointer = leftCsrRowPointersBuffer.[j]
188-
// while (pointer < leftCsrRowPointersBuffer.[j + 1] && leftCsrColumnsBuffer.[pointer] <= i) do
189-
// if leftCsrColumnsBuffer.[pointer] = i then
190-
// localResultBuffer <- localResultBuffer +
191-
// rightCsrValuesBuffer.[k] * leftCsrValuesBuffer.[pointer]
192-
// pointer <- pointer + 1
193-
// resultMatrix.[j, rightCsrColumnsBuffer.[k]] <- localResultBuffer
194-
195-
// resultMatrix
196-
197-
// let result = CSRMatrix.makeFromDenseMatrix left |> csrMultAlgo (CSRMatrix.makeFromDenseMatrix right)
198-
// let expected = left |> matrixMatrixMultiply right
199-
// result = expected |@ (sprintf "\n %A \n %A \n %A" (result |> Array2D.mapi (fun i j elem -> elem - expected.[i, j])) result expected)
145+
(checkEquality result expected) |@ (getLabel result expected)
146+
147+
[<Trait("Category", "csr-cpu")>]
148+
[<Property(Arbitrary=[| typeof<MatrixMultiplicationPair> |])>]
149+
member this.``CSR x CSR multiplication algo shoud work correctly on cpu`` (left: float[,], right: float[,]) =
150+
let csrMultAlgo (csrMatrixRight: CSRMatrix.CSRMatrix) (csrMatrixLeft: CSRMatrix.CSRMatrix) =
151+
let leftMatrixRowCount = csrMatrixLeft |> CSRMatrix.rowCount
152+
let leftMatrixColumnCount = csrMatrixLeft |> CSRMatrix.columnCount
153+
let rightMatrixRowCount = csrMatrixRight |> CSRMatrix.rowCount
154+
let rightMatrixColumnCount = csrMatrixRight |> CSRMatrix.columnCount
155+
if leftMatrixColumnCount <> rightMatrixRowCount then failwith "fail"
156+
157+
let leftCsrValuesBuffer = csrMatrixLeft.GetValues
158+
let leftCsrColumnsBuffer = csrMatrixLeft.GetColumns
159+
let leftCsrRowPointersBuffer = csrMatrixLeft.GetRowPointers
160+
let rightCsrValuesBuffer = csrMatrixRight.GetValues
161+
let rightCsrColumnsBuffer = csrMatrixRight.GetColumns
162+
let rightCsrRowPointersBuffer = csrMatrixRight.GetRowPointers
163+
164+
let resultMatrix = Array2D.zeroCreate<float> leftMatrixRowCount rightMatrixColumnCount
165+
for i in 0 .. rightMatrixRowCount - 1 do
166+
for j in 0 .. leftMatrixRowCount - 1 do
167+
for k in rightCsrRowPointersBuffer.[i] .. rightCsrRowPointersBuffer.[i + 1] - 1 do
168+
let mutable localResultBuffer = resultMatrix.[j, rightCsrColumnsBuffer.[k]]
169+
let mutable pointer = leftCsrRowPointersBuffer.[j]
170+
while (pointer < leftCsrRowPointersBuffer.[j + 1] && leftCsrColumnsBuffer.[pointer] <= i) do
171+
if leftCsrColumnsBuffer.[pointer] = i then
172+
localResultBuffer <- localResultBuffer +
173+
rightCsrValuesBuffer.[k] * leftCsrValuesBuffer.[pointer]
174+
pointer <- pointer + 1
175+
resultMatrix.[j, rightCsrColumnsBuffer.[k]] <- localResultBuffer
176+
177+
resultMatrix
178+
179+
let result = CSRMatrix.makeFromDenseMatrix left |> csrMultAlgo (CSRMatrix.makeFromDenseMatrix right)
180+
let expected = left |> matrixMatrixMultiply right
181+
(checkEquality result expected) |@ (getLabel result expected)
200182

201183
interface System.IDisposable with
202184
member this.Dispose () =

0 commit comments

Comments
 (0)