@@ -9,7 +9,6 @@ open OpenCL.Net
99open Brahma.OpenCL
1010open SparseMatrixMultiplication
1111
12- [<Properties( Verbose= true , MaxTest= 100 , EndSize= 900 ) >]
1312module SparseMatrixMultiplicationTests =
1413
1514 type MatrixMultiplicationPair =
@@ -64,41 +63,24 @@ module SparseMatrixMultiplicationTests =
6463 |> Gen.sized
6564 |> Arb.fromGen
6665
67- // type SetupTestFixture() =
68- // let provider =
69- // try ComputeProvider.Create("INTEL*", DeviceType.Cpu)
70- // with
71- // | ex -> failwith ex.Message
72- // let mutable commandQueue = new CommandQueue(provider, provider.Devices |> Seq.head)
73-
74- // member this.CsrVectorMultiply = multiplySpMV provider commandQueue
75- // member this.CsrDenseMultiply = multiplySpMM provider commandQueue
76- // member this.CsrCscMultiply = multiplySpMSpM provider commandQueue
77- // member this.CsrCsrMultiply = multiplySpMSpM2 provider commandQueue
78-
79- // interface System.IDisposable with
80- // member this.Dispose () =
81- // commandQueue.Dispose ()
82- // provider.CloseAllBuffers ()
83- // provider.Dispose ()
84-
85- let func < 'a when 'a :> System.Collections.IEnumerable > ( result : 'a ) ( expected : 'a ) =
66+ let getFlattenedDiff < 'a when 'a :> System.Collections.IEnumerable > ( result : 'a ) ( expected : 'a ) =
8667 let pairMap f ( x , y ) = f x, f y
8768 ( result, expected)
8869 |> pairMap Seq.cast< float>
8970 ||> Seq.zip
9071 |> Seq.map ( fun ( x , y ) -> x - y)
9172
92- let check < 'a when 'a :> System.Collections.IEnumerable > ( result : 'a ) ( expected : 'a ) =
93- func result expected
73+ let checkEquality < 'a when 'a :> System.Collections.IEnumerable > ( result : 'a ) ( expected : 'a ) =
74+ getFlattenedDiff result expected
9475 |> Seq.forall ( fun diff -> diff < 1e-8 )
9576
9677 let getLabel < 'a when 'a :> System.Collections.IEnumerable > ( result : 'a ) ( expected : 'a ) =
9778 sprintf " \n Total diff:\n %A \n Result:\n %A \n Expected:\n %A \n "
98- ( func result expected |> Seq.sum)
79+ ( getFlattenedDiff result expected |> Seq.sum)
9980 result
10081 expected
10182
83+ [<Properties( Verbose= false , MaxTest= 100 , EndSize= 1600 ) >]
10284 type Tests () =
10385 let matrixVectorMultiply ( vector : float []) ( matrix : float [,]) =
10486 let rows = matrix |> Array2D.length1
@@ -139,64 +121,64 @@ module SparseMatrixMultiplicationTests =
139121 member this. ``CSR x Vector multiplication should work correctly on nonempty and nonzero objects`` ( matrix : float [,], vector : float []) =
140122 let result = CSRMatrix.makeFromDenseMatrix matrix |> csrVectorMultiply vector
141123 let expected = matrix |> matrixVectorMultiply vector
142- ( check result expected) |@ ( getLabel result expected)
124+ ( checkEquality result expected) |@ ( getLabel result expected)
143125
144126 [<Trait( " Category" , " dense" ) >]
145127 [<Property( Arbitrary=[| typeof< MatrixMultiplicationPair> |]) >]
146128 member this. ``CSR x Dense multiplication should work correctly on nonempty and nonzero matrices`` ( left : float [,], right : float [,]) =
147129 let result = CSRMatrix.makeFromDenseMatrix left |> csrDenseMultiply right
148130 let expected = left |> matrixMatrixMultiply right
149- ( check result expected) |@ ( getLabel result expected)
131+ ( checkEquality result expected) |@ ( getLabel result expected)
150132
151133 [<Trait( " Category" , " csc" ) >]
152134 [<Property( Arbitrary=[| typeof< MatrixMultiplicationPair> |]) >]
153135 member this. ``CSR x CSC multiplication should work correctly on nonempty and nonzero matrices`` ( left : float [,], right : float [,]) =
154136 let result = CSRMatrix.makeFromDenseMatrix left |> csrCscMultiply ( CSCMatrix.makeFromDenseMatrix right)
155137 let expected = left |> matrixMatrixMultiply right
156- ( check result expected) |@ ( getLabel result expected)
138+ ( checkEquality result expected) |@ ( getLabel result expected)
157139
158140 [<Trait( " Category" , " csr" ) >]
159141 [<Property( Arbitrary=[| typeof< MatrixMultiplicationPair> |]) >]
160142 member this. ``CSR x CSR multiplication should work correctly on nonempty and nonzero matrices`` ( left : float [,], right : float [,]) =
161143 let result = CSRMatrix.makeFromDenseMatrix left |> csrCsrMultiply ( CSRMatrix.makeFromDenseMatrix right)
162144 let expected = left |> matrixMatrixMultiply right
163- ( check result expected) |@ ( getLabel result expected)
164-
165- // [<Trait("Category", "csr-cpu")>]
166- // [<Property(Arbitrary=[| typeof<MatrixMultiplicationPair> |])>]
167- // member this.``csr cpu`` (left: float[,], right: float[,]) =
168- // let csrMultAlgo (csrMatrixRight: CSRMatrix.CSRMatrix) (csrMatrixLeft: CSRMatrix.CSRMatrix) =
169- // let leftMatrixRowCount = csrMatrixLeft |> CSRMatrix.rowCount
170- // let leftMatrixColumnCount = csrMatrixLeft |> CSRMatrix.columnCount
171- // let rightMatrixRowCount = csrMatrixRight |> CSRMatrix.rowCount
172- // let rightMatrixColumnCount = csrMatrixRight |> CSRMatrix.columnCount
173- // if leftMatrixColumnCount <> rightMatrixRowCount then failwith "fail"
174-
175- // let leftCsrValuesBuffer = csrMatrixLeft.GetValues
176- // let leftCsrColumnsBuffer = csrMatrixLeft.GetColumns
177- // let leftCsrRowPointersBuffer = csrMatrixLeft.GetRowPointers
178- // let rightCsrValuesBuffer = csrMatrixRight.GetValues
179- // let rightCsrColumnsBuffer = csrMatrixRight.GetColumns
180- // let rightCsrRowPointersBuffer = csrMatrixRight.GetRowPointers
181-
182- // let resultMatrix = Array2D.zeroCreate<float> leftMatrixRowCount rightMatrixColumnCount
183- // for i in 0 .. rightMatrixRowCount - 1 do
184- // for j in 0 .. leftMatrixRowCount - 1 do
185- // for k in rightCsrRowPointersBuffer.[i] .. rightCsrRowPointersBuffer.[i + 1] - 1 do
186- // let mutable localResultBuffer = resultMatrix.[j, rightCsrColumnsBuffer.[k]]
187- // let mutable pointer = leftCsrRowPointersBuffer.[j]
188- // while (pointer < leftCsrRowPointersBuffer.[j + 1] && leftCsrColumnsBuffer.[pointer] <= i) do
189- // if leftCsrColumnsBuffer.[pointer] = i then
190- // localResultBuffer <- localResultBuffer +
191- // rightCsrValuesBuffer.[k] * leftCsrValuesBuffer.[pointer]
192- // pointer <- pointer + 1
193- // resultMatrix.[j, rightCsrColumnsBuffer.[k]] <- localResultBuffer
194-
195- // resultMatrix
196-
197- // let result = CSRMatrix.makeFromDenseMatrix left |> csrMultAlgo (CSRMatrix.makeFromDenseMatrix right)
198- // let expected = left |> matrixMatrixMultiply right
199- // result = expected |@ (sprintf "\n %A \n %A \n %A" (result |> Array2D.mapi (fun i j elem -> elem - expected.[i, j])) result expected)
145+ ( checkEquality result expected) |@ ( getLabel result expected)
146+
147+ [<Trait( " Category" , " csr-cpu" ) >]
148+ [<Property( Arbitrary=[| typeof< MatrixMultiplicationPair> |]) >]
149+ member this. ``CSR x CSR multiplication algo shoud work correctly on cpu`` ( left : float [,], right : float [,]) =
150+ let csrMultAlgo ( csrMatrixRight : CSRMatrix.CSRMatrix ) ( csrMatrixLeft : CSRMatrix.CSRMatrix ) =
151+ let leftMatrixRowCount = csrMatrixLeft |> CSRMatrix.rowCount
152+ let leftMatrixColumnCount = csrMatrixLeft |> CSRMatrix.columnCount
153+ let rightMatrixRowCount = csrMatrixRight |> CSRMatrix.rowCount
154+ let rightMatrixColumnCount = csrMatrixRight |> CSRMatrix.columnCount
155+ if leftMatrixColumnCount <> rightMatrixRowCount then failwith " fail"
156+
157+ let leftCsrValuesBuffer = csrMatrixLeft.GetValues
158+ let leftCsrColumnsBuffer = csrMatrixLeft.GetColumns
159+ let leftCsrRowPointersBuffer = csrMatrixLeft.GetRowPointers
160+ let rightCsrValuesBuffer = csrMatrixRight.GetValues
161+ let rightCsrColumnsBuffer = csrMatrixRight.GetColumns
162+ let rightCsrRowPointersBuffer = csrMatrixRight.GetRowPointers
163+
164+ let resultMatrix = Array2D.zeroCreate< float> leftMatrixRowCount rightMatrixColumnCount
165+ for i in 0 .. rightMatrixRowCount - 1 do
166+ for j in 0 .. leftMatrixRowCount - 1 do
167+ for k in rightCsrRowPointersBuffer.[ i] .. rightCsrRowPointersBuffer.[ i + 1 ] - 1 do
168+ let mutable localResultBuffer = resultMatrix.[ j, rightCsrColumnsBuffer.[ k]]
169+ let mutable pointer = leftCsrRowPointersBuffer.[ j]
170+ while ( pointer < leftCsrRowPointersBuffer.[ j + 1 ] && leftCsrColumnsBuffer.[ pointer] <= i) do
171+ if leftCsrColumnsBuffer.[ pointer] = i then
172+ localResultBuffer <- localResultBuffer +
173+ rightCsrValuesBuffer.[ k] * leftCsrValuesBuffer.[ pointer]
174+ pointer <- pointer + 1
175+ resultMatrix.[ j, rightCsrColumnsBuffer.[ k]] <- localResultBuffer
176+
177+ resultMatrix
178+
179+ let result = CSRMatrix.makeFromDenseMatrix left |> csrMultAlgo ( CSRMatrix.makeFromDenseMatrix right)
180+ let expected = left |> matrixMatrixMultiply right
181+ ( checkEquality result expected) |@ ( getLabel result expected)
200182
201183 interface System.IDisposable with
202184 member this.Dispose () =
0 commit comments