Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b362106

Browse files
committedMar 20, 2025
Update path condition. Refactor.
1 parent 64b5bf5 commit b362106

File tree

1 file changed

+145
-152
lines changed

1 file changed

+145
-152
lines changed
 

‎VSharp.Explorer/AISearcher.fs

+145-152
Original file line numberDiff line numberDiff line change
@@ -17,49 +17,51 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
1717
| None -> 0u<step>
1818
| Some options -> options.stepsToPlay
1919

20-
let mutable lastCollectedStatistics =
21-
Statistics ()
20+
let mutable lastCollectedStatistics = Statistics()
2221
let mutable defaultSearcherSteps = 0u<step>
23-
let mutable (gameState: Option<GameState>) =
24-
None
25-
let mutable useDefaultSearcher =
26-
stepsToSwitchToAI > 0u<step>
22+
let mutable (gameState: Option<GameState>) = None
23+
let mutable useDefaultSearcher = stepsToSwitchToAI > 0u<step>
2724
let mutable afterFirstAIPeek = false
28-
let mutable incorrectPredictedStateId =
29-
false
25+
let mutable incorrectPredictedStateId = false
26+
3027
let defaultSearcher =
3128
match aiAgentTrainingOptions with
32-
| None -> BFSSearcher () :> IForwardSearcher
29+
| None -> BFSSearcher() :> IForwardSearcher
3330
| Some options ->
3431
match options.defaultSearchStrategy with
35-
| BFSMode -> BFSSearcher () :> IForwardSearcher
36-
| DFSMode -> DFSSearcher () :> IForwardSearcher
32+
| BFSMode -> BFSSearcher() :> IForwardSearcher
33+
| DFSMode -> DFSSearcher() :> IForwardSearcher
3734
| x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now."
35+
3836
let mutable stepsPlayed = 0u<step>
37+
3938
let isInAIMode () =
4039
(not useDefaultSearcher) && afterFirstAIPeek
41-
let q = ResizeArray<_> ()
42-
let availableStates = HashSet<_> ()
40+
41+
let q = ResizeArray<_>()
42+
let availableStates = HashSet<_>()
43+
4344
let updateGameState (delta: GameState) =
4445
match gameState with
4546
| None -> gameState <- Some delta
4647
| Some s ->
47-
let updatedBasicBlocks =
48-
delta.GraphVertices |> Array.map (fun b -> b.Id) |> HashSet
49-
let updatedStates =
50-
delta.States |> Array.map (fun s -> s.Id) |> HashSet
48+
let updatedBasicBlocks = delta.GraphVertices |> Array.map (fun b -> b.Id) |> HashSet
49+
let updatedStates = delta.States |> Array.map (fun s -> s.Id) |> HashSet
50+
5151
let vertices =
5252
s.GraphVertices
5353
|> Array.filter (fun v -> updatedBasicBlocks.Contains v.Id |> not)
5454
|> ResizeArray<_>
55+
5556
vertices.AddRange delta.GraphVertices
57+
5658
let edges =
5759
s.Map
5860
|> Array.filter (fun e -> updatedBasicBlocks.Contains e.VertexFrom |> not)
5961
|> ResizeArray<_>
62+
6063
edges.AddRange delta.Map
61-
let activeStates =
62-
vertices |> Seq.collect (fun v -> v.States) |> HashSet
64+
let activeStates = vertices |> Seq.collect (fun v -> v.States) |> HashSet
6365

6466
let states =
6567
let part1 =
@@ -69,82 +71,93 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
6971

7072
part1.AddRange delta.States
7173

72-
part1.ToArray ()
74+
part1.ToArray()
7375
|> Array.map (fun s ->
74-
State (
76+
State(
7577
s.Id,
7678
s.Position,
77-
s.PathConditionSize,
79+
s.PathCondition,
7880
s.VisitedAgainVertices,
7981
s.VisitedNotCoveredVerticesInZone,
8082
s.VisitedNotCoveredVerticesOutOfZone,
8183
s.StepWhenMovedLastTime,
8284
s.InstructionsVisitedInCurrentBlock,
8385
s.History,
8486
s.Children |> Array.filter activeStates.Contains
85-
)
86-
)
87+
))
88+
89+
let pathConditionVertices =
90+
ResizeArray<PathConditionVertex> s.PathConditionVertices
91+
92+
pathConditionVertices.AddRange delta.PathConditionVertices
8793

88-
gameState <- Some <| GameState (vertices.ToArray (), states, edges.ToArray ())
94+
gameState <-
95+
Some
96+
<| GameState(vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
8997

9098

9199
let init states =
92100
q.AddRange states
93101
defaultSearcher.Init q
94102
states |> Seq.iter (availableStates.Add >> ignore)
103+
95104
let reset () =
96-
defaultSearcher.Reset ()
105+
defaultSearcher.Reset()
97106
defaultSearcherSteps <- 0u<step>
98-
lastCollectedStatistics <- Statistics ()
107+
lastCollectedStatistics <- Statistics()
99108
gameState <- None
100109
afterFirstAIPeek <- false
101110
incorrectPredictedStateId <- false
102111
useDefaultSearcher <- stepsToSwitchToAI > 0u<step>
103-
q.Clear ()
104-
availableStates.Clear ()
105-
let update (parent, newSates) =
112+
q.Clear()
113+
availableStates.Clear()
114+
115+
let update (parent, newStates) =
106116
if useDefaultSearcher then
107-
defaultSearcher.Update (parent, newSates)
108-
newSates |> Seq.iter (availableStates.Add >> ignore)
117+
defaultSearcher.Update(parent, newStates)
118+
119+
newStates |> Seq.iter (availableStates.Add >> ignore)
120+
109121
let remove state =
110122
if useDefaultSearcher then
111123
defaultSearcher.Remove state
124+
112125
let removed = availableStates.Remove state
113126
assert removed
127+
114128
for bb in state._history do
115129
bb.Key.AssociatedStates.Remove state |> ignore
116130

117-
let inTrainMode =
118-
aiAgentTrainingOptions.IsSome
131+
let inTrainMode = aiAgentTrainingOptions.IsSome
119132

120133
let pick selector =
121134
if useDefaultSearcher then
122135
defaultSearcherSteps <- defaultSearcherSteps + 1u<step>
136+
123137
if Seq.length availableStates > 0 then
124-
let gameStateDelta =
125-
collectGameStateDelta ()
138+
let gameStateDelta = collectGameStateDelta ()
126139
updateGameState gameStateDelta
127-
let statistics =
128-
computeStatistics gameState.Value
129-
Application.applicationGraphDelta.Clear ()
140+
let statistics = computeStatistics gameState.Value
141+
Application.applicationGraphDelta.Clear()
130142
lastCollectedStatistics <- statistics
131143
useDefaultSearcher <- defaultSearcherSteps < stepsToSwitchToAI
132-
defaultSearcher.Pick ()
144+
145+
defaultSearcher.Pick()
133146
elif Seq.length availableStates = 0 then
134147
None
135148
elif Seq.length availableStates = 1 then
136-
Some (Seq.head availableStates)
149+
Some(Seq.head availableStates)
137150
else
138-
let gameStateDelta =
139-
collectGameStateDelta ()
151+
let gameStateDelta = collectGameStateDelta ()
140152
updateGameState gameStateDelta
141-
let statistics =
142-
computeStatistics gameState.Value
153+
let statistics = computeStatistics gameState.Value
154+
143155
if isInAIMode () then
144-
let reward =
145-
computeReward lastCollectedStatistics statistics
146-
oracle.Feedback (Feedback.MoveReward reward)
147-
Application.applicationGraphDelta.Clear ()
156+
let reward = computeReward lastCollectedStatistics statistics
157+
oracle.Feedback(Feedback.MoveReward reward)
158+
159+
Application.applicationGraphDelta.Clear()
160+
148161
if inTrainMode && stepsToPlay = stepsPlayed then
149162
None
150163
else
@@ -153,17 +166,18 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
153166
gameStateDelta
154167
else
155168
gameState.Value
169+
156170
let stateId = oracle.Predict toPredict
157171
afterFirstAIPeek <- true
158-
let state =
159-
availableStates |> Seq.tryFind (fun s -> s.internalId = stateId)
172+
let state = availableStates |> Seq.tryFind (fun s -> s.internalId = stateId)
160173
lastCollectedStatistics <- statistics
161174
stepsPlayed <- stepsPlayed + 1u<step>
175+
162176
match state with
163177
| Some state -> Some state
164178
| None ->
165179
incorrectPredictedStateId <- true
166-
oracle.Feedback (Feedback.IncorrectPredictedStateId stateId)
180+
oracle.Feedback(Feedback.IncorrectPredictedStateId stateId)
167181
None
168182

169183
new(pathToONNX: string, useGPU: bool, optimize: bool) =
@@ -174,40 +188,36 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
174188
let createOracle (pathToONNX: string) =
175189
let sessionOptions =
176190
if useGPU then
177-
SessionOptions.MakeSessionOptionWithCudaProvider (0)
191+
SessionOptions.MakeSessionOptionWithCudaProvider(0)
178192
else
179-
new SessionOptions ()
193+
new SessionOptions()
180194

181195
if optimize then
182196
sessionOptions.ExecutionMode <- ExecutionMode.ORT_PARALLEL
183197
sessionOptions.GraphOptimizationLevel <- GraphOptimizationLevel.ORT_ENABLE_ALL
184198
else
185199
sessionOptions.GraphOptimizationLevel <- GraphOptimizationLevel.ORT_ENABLE_BASIC
186200

187-
let session =
188-
new InferenceSession (pathToONNX, sessionOptions)
189-
let runOptions = new RunOptions ()
201+
let session = new InferenceSession(pathToONNX, sessionOptions)
202+
let runOptions = new RunOptions()
190203
let feedback (x: Feedback) = ()
191204

192205
let predict (gameState: GameState) =
193-
let stateIds =
194-
Dictionary<uint<stateId>, int> ()
195-
let verticesIds =
196-
Dictionary<uint<basicBlockGlobalId>, int> ()
206+
let stateIds = Dictionary<uint<stateId>, int>()
207+
let verticesIds = Dictionary<uint<basicBlockGlobalId>, int>()
197208

198209
let networkInput =
199-
let res = Dictionary<_, _> ()
210+
let res = Dictionary<_, _>()
211+
200212
let gameVertices =
201-
let shape =
202-
[|
203-
int64 gameState.GraphVertices.Length
204-
numOfVertexAttributes
205-
|]
213+
let shape = [| int64 gameState.GraphVertices.Length; numOfVertexAttributes |]
214+
206215
let attributes =
207216
Array.zeroCreate (gameState.GraphVertices.Length * numOfVertexAttributes)
217+
208218
for i in 0 .. gameState.GraphVertices.Length - 1 do
209219
let v = gameState.GraphVertices.[i]
210-
verticesIds.Add (v.Id, i)
220+
verticesIds.Add(v.Id, i)
211221
let j = i * numOfVertexAttributes
212222
attributes.[j] <- float32 <| if v.InCoverageZone then 1u else 0u
213223
attributes.[j + 1] <- float32 <| v.BasicBlockSize
@@ -216,111 +226,97 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
216226
attributes.[j + 4] <- float32 <| if v.TouchedByState then 1u else 0u
217227
attributes.[j + 5] <- float32 <| if v.ContainsCall then 1u else 0u
218228
attributes.[j + 6] <- float32 <| if v.ContainsThrow then 1u else 0u
219-
OrtValue.CreateTensorValueFromMemory (attributes, shape)
229+
230+
OrtValue.CreateTensorValueFromMemory(attributes, shape)
220231

221232
let states, numOfParentOfEdges, numOfHistoryEdges =
222233
let mutable numOfParentOfEdges = 0
223234
let mutable numOfHistoryEdges = 0
224-
let shape =
225-
[|
226-
int64 gameState.States.Length
227-
numOfStateAttributes
228-
|]
229-
let attributes =
230-
Array.zeroCreate (gameState.States.Length * numOfStateAttributes)
235+
let shape = [| int64 gameState.States.Length; numOfStateAttributes |]
236+
let attributes = Array.zeroCreate (gameState.States.Length * numOfStateAttributes)
237+
231238
for i in 0 .. gameState.States.Length - 1 do
232239
let v = gameState.States.[i]
233240
numOfHistoryEdges <- numOfHistoryEdges + v.History.Length
234241
numOfParentOfEdges <- numOfParentOfEdges + v.Children.Length
235-
stateIds.Add (v.Id, i)
242+
stateIds.Add(v.Id, i)
236243
let j = i * numOfStateAttributes
237244
attributes.[j] <- float32 v.Position
238-
attributes.[j + 1] <- float32 v.PathConditionSize
245+
// TODO: Support path condition
246+
// attributes.[j + 1] <- float32 v.PathConditionSize
239247
attributes.[j + 2] <- float32 v.VisitedAgainVertices
240248
attributes.[j + 3] <- float32 v.VisitedNotCoveredVerticesInZone
241249
attributes.[j + 4] <- float32 v.VisitedNotCoveredVerticesOutOfZone
242250
attributes.[j + 5] <- float32 v.StepWhenMovedLastTime
243251
attributes.[j + 6] <- float32 v.InstructionsVisitedInCurrentBlock
244-
OrtValue.CreateTensorValueFromMemory (attributes, shape), numOfParentOfEdges, numOfHistoryEdges
252+
253+
OrtValue.CreateTensorValueFromMemory(attributes, shape), numOfParentOfEdges, numOfHistoryEdges
245254

246255
let vertexToVertexEdgesIndex, vertexToVertexEdgesAttributes =
247-
let shapeOfIndex =
248-
[| 2L ; gameState.Map.Length |]
249-
let shapeOfAttributes =
250-
[| int64 gameState.Map.Length |]
251-
let index =
252-
Array.zeroCreate (2 * gameState.Map.Length)
253-
let attributes =
254-
Array.zeroCreate gameState.Map.Length
256+
let shapeOfIndex = [| 2L; gameState.Map.Length |]
257+
let shapeOfAttributes = [| int64 gameState.Map.Length |]
258+
let index = Array.zeroCreate (2 * gameState.Map.Length)
259+
let attributes = Array.zeroCreate gameState.Map.Length
260+
255261
gameState.Map
256262
|> Array.iteri (fun i e ->
257263
index[i] <- int64 verticesIds[e.VertexFrom]
258264
index[gameState.Map.Length + i] <- int64 verticesIds[e.VertexTo]
259-
attributes[i] <- int64 e.Label.Token
260-
)
265+
attributes[i] <- int64 e.Label.Token)
261266

262-
OrtValue.CreateTensorValueFromMemory (index, shapeOfIndex),
263-
OrtValue.CreateTensorValueFromMemory (attributes, shapeOfAttributes)
267+
OrtValue.CreateTensorValueFromMemory(index, shapeOfIndex),
268+
OrtValue.CreateTensorValueFromMemory(attributes, shapeOfAttributes)
264269

265270
let historyEdgesIndex_vertexToState, historyEdgesAttributes, parentOfEdges =
266-
let shapeOfParentOf =
267-
[| 2L ; numOfParentOfEdges |]
268-
let parentOf =
269-
Array.zeroCreate (2 * numOfParentOfEdges)
270-
let shapeOfHistory =
271-
[| 2L ; numOfHistoryEdges |]
272-
let historyIndex_vertexToState =
273-
Array.zeroCreate (2 * numOfHistoryEdges)
271+
let shapeOfParentOf = [| 2L; numOfParentOfEdges |]
272+
let parentOf = Array.zeroCreate (2 * numOfParentOfEdges)
273+
let shapeOfHistory = [| 2L; numOfHistoryEdges |]
274+
let historyIndex_vertexToState = Array.zeroCreate (2 * numOfHistoryEdges)
275+
274276
let shapeOfHistoryAttributes =
275-
[|
276-
int64 numOfHistoryEdges
277-
int64 numOfHistoryEdgeAttributes
278-
|]
279-
let historyAttributes =
280-
Array.zeroCreate (2 * numOfHistoryEdges)
277+
[| int64 numOfHistoryEdges; int64 numOfHistoryEdgeAttributes |]
278+
279+
let historyAttributes = Array.zeroCreate (2 * numOfHistoryEdges)
281280
let mutable firstFreePositionInParentsOf = 0
282-
let mutable firstFreePositionInHistoryIndex =
283-
0
284-
let mutable firstFreePositionInHistoryAttributes =
285-
0
281+
let mutable firstFreePositionInHistoryIndex = 0
282+
let mutable firstFreePositionInHistoryAttributes = 0
283+
286284
gameState.States
287285
|> Array.iter (fun state ->
288286
state.Children
289287
|> Array.iteri (fun i children ->
290288
let j = firstFreePositionInParentsOf + i
291289
parentOf[j] <- int64 stateIds[state.Id]
292-
parentOf[numOfParentOfEdges + j] <- int64 stateIds[children]
293-
)
290+
parentOf[numOfParentOfEdges + j] <- int64 stateIds[children])
291+
294292
firstFreePositionInParentsOf <- firstFreePositionInParentsOf + state.Children.Length
293+
295294
state.History
296295
|> Array.iteri (fun i historyElem ->
297296
let j = firstFreePositionInHistoryIndex + i
298297
historyIndex_vertexToState[j] <- int64 verticesIds[historyElem.GraphVertexId]
299298
historyIndex_vertexToState[numOfHistoryEdges + j] <- int64 stateIds[state.Id]
300299

301-
let j =
302-
firstFreePositionInHistoryAttributes + numOfHistoryEdgeAttributes * i
300+
let j = firstFreePositionInHistoryAttributes + numOfHistoryEdgeAttributes * i
303301
historyAttributes[j] <- int64 historyElem.NumOfVisits
304-
historyAttributes[j + 1] <- int64 historyElem.StepWhenVisitedLastTime
305-
)
302+
historyAttributes[j + 1] <- int64 historyElem.StepWhenVisitedLastTime)
303+
306304
firstFreePositionInHistoryIndex <- firstFreePositionInHistoryIndex + state.History.Length
305+
307306
firstFreePositionInHistoryAttributes <-
308307
firstFreePositionInHistoryAttributes
309-
+ numOfHistoryEdgeAttributes * state.History.Length
310-
)
308+
+ numOfHistoryEdgeAttributes * state.History.Length)
311309

312-
OrtValue.CreateTensorValueFromMemory (historyIndex_vertexToState, shapeOfHistory),
313-
OrtValue.CreateTensorValueFromMemory (historyAttributes, shapeOfHistoryAttributes),
314-
OrtValue.CreateTensorValueFromMemory (parentOf, shapeOfParentOf)
310+
OrtValue.CreateTensorValueFromMemory(historyIndex_vertexToState, shapeOfHistory),
311+
OrtValue.CreateTensorValueFromMemory(historyAttributes, shapeOfHistoryAttributes),
312+
OrtValue.CreateTensorValueFromMemory(parentOf, shapeOfParentOf)
315313

316314
let statePosition_stateToVertex, statePosition_vertexToState =
317-
let data_stateToVertex =
318-
Array.zeroCreate (2 * gameState.States.Length)
319-
let data_vertexToState =
320-
Array.zeroCreate (2 * gameState.States.Length)
321-
let shape =
322-
[| 2L ; gameState.States.Length |]
315+
let data_stateToVertex = Array.zeroCreate (2 * gameState.States.Length)
316+
let data_vertexToState = Array.zeroCreate (2 * gameState.States.Length)
317+
let shape = [| 2L; gameState.States.Length |]
323318
let mutable firstFreePosition = 0
319+
324320
gameState.GraphVertices
325321
|> Array.iter (fun v ->
326322
v.States
@@ -332,46 +328,43 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
332328
data_stateToVertex[stateIds.Count + j] <- vertexIndex
333329

334330
data_vertexToState[j] <- vertexIndex
335-
data_vertexToState[stateIds.Count + j] <- stateIndex
336-
)
337-
firstFreePosition <- firstFreePosition + v.States.Length
338-
)
339-
OrtValue.CreateTensorValueFromMemory (data_stateToVertex, shape),
340-
OrtValue.CreateTensorValueFromMemory (data_vertexToState, shape)
331+
data_vertexToState[stateIds.Count + j] <- stateIndex)
332+
333+
firstFreePosition <- firstFreePosition + v.States.Length)
334+
335+
OrtValue.CreateTensorValueFromMemory(data_stateToVertex, shape),
336+
OrtValue.CreateTensorValueFromMemory(data_vertexToState, shape)
341337

342-
res.Add ("game_vertex", gameVertices)
343-
res.Add ("state_vertex", states)
338+
res.Add("game_vertex", gameVertices)
339+
res.Add("state_vertex", states)
344340

345-
res.Add ("gamevertex_to_gamevertex_index", vertexToVertexEdgesIndex)
346-
res.Add ("gamevertex_to_gamevertex_type", vertexToVertexEdgesAttributes)
341+
res.Add("gamevertex_to_gamevertex_index", vertexToVertexEdgesIndex)
342+
res.Add("gamevertex_to_gamevertex_type", vertexToVertexEdgesAttributes)
347343

348-
res.Add ("gamevertex_history_statevertex_index", historyEdgesIndex_vertexToState)
349-
res.Add ("gamevertex_history_statevertex_attrs", historyEdgesAttributes)
344+
res.Add("gamevertex_history_statevertex_index", historyEdgesIndex_vertexToState)
345+
res.Add("gamevertex_history_statevertex_attrs", historyEdgesAttributes)
350346

351-
res.Add ("gamevertex_in_statevertex", statePosition_vertexToState)
352-
res.Add ("statevertex_parentof_statevertex", parentOfEdges)
347+
res.Add("gamevertex_in_statevertex", statePosition_vertexToState)
348+
res.Add("statevertex_parentof_statevertex", parentOfEdges)
353349

354350
res
355351

356-
let output =
357-
session.Run (runOptions, networkInput, session.OutputNames)
358-
let weighedStates =
359-
output[0].GetTensorDataAsSpan<float32>().ToArray ()
352+
let output = session.Run(runOptions, networkInput, session.OutputNames)
353+
let weighedStates = output[0].GetTensorDataAsSpan<float32>().ToArray()
360354

361-
let id =
362-
weighedStates |> Array.mapi (fun i v -> i, v) |> Array.maxBy snd |> fst
355+
let id = weighedStates |> Array.mapi (fun i v -> i, v) |> Array.maxBy snd |> fst
363356
stateIds |> Seq.find (fun kvp -> kvp.Value = id) |> (fun x -> x.Key)
364357

365-
Oracle (predict, feedback)
358+
Oracle(predict, feedback)
366359

367-
AISearcher (createOracle pathToONNX, None)
360+
AISearcher(createOracle pathToONNX, None)
368361

369362
interface IForwardSearcher with
370363
override x.Init states = init states
371-
override x.Pick () = pick (always true)
364+
override x.Pick() = pick (always true)
372365
override x.Pick selector = pick selector
373-
override x.Update (parent, newStates) = update (parent, newStates)
374-
override x.States () = availableStates
375-
override x.Reset () = reset ()
366+
override x.Update(parent, newStates) = update (parent, newStates)
367+
override x.States() = availableStates
368+
override x.Reset() = reset ()
376369
override x.Remove cilState = remove cilState
377370
override x.StatesCount = availableStates.Count

0 commit comments

Comments
 (0)
Please sign in to comment.