1
# Startup banner printed when this file is loaded.
banner = " *"^20
println(banner, " \n Hello from Julia!\n ", banner)
3
# Multiset of token ids: maps token id => occurrence count.
const MSet = Dict{Int,Int}
5
"""
    build_set_and_tids_mset(token_ids)

Build the query-side token structures: a `BitSet` of the distinct token ids
and an `MSet` mapping each token id to its occurrence count.

Token ids equal to `-1` (markers for already-matched tokens) are skipped.
"""
function build_set_and_tids_mset(token_ids)
    counts = MSet()
    for tid in token_ids
        # -1 flags a token id that was already matched; ignore it.
        if tid != -1
            counts[tid] = get(counts, tid, 0) + 1
        end
    end
    return BitSet(keys(counts)), counts
end
16
+
17
# Size of a token-id set: number of distinct token ids.
set_counter(s::BitSet) = length(s)
18
# Size of a token-id multiset: total occurrences, i.e. the sum of all counts.
set_counter(s::MSet) = sum(values(s))
20
# Set intersection of two token-id sets.
set_intersector(a::BitSet, b::BitSet) = intersect(a, b)
21
"""
    set_intersector(a::MSet, b::MSet)

Multiset intersection: for each token id present in both multisets, keep the
smaller of its two occurrence counts.
"""
function set_intersector(a::MSet, b::MSet)
    # Iterate over the smaller multiset so the loop cost is bounded by it.
    small, big = length(a) <= length(b) ? (a, b) : (b, a)
    common = MSet()
    for (tid, cnt) in small
        if haskey(big, tid)
            common[tid] = min(cnt, big[tid])
        end
    end
    return common
end
30
+
31
# Keep only the "high" token ids, i.e. ids at or below `cutoff`
# (callers pass `len_legalese` as the cutoff).
set_high_intersection_filter(s::BitSet, cutoff) = filter(tid -> tid <= cutoff, s)
32
# Multiset variant: keep entries whose token id is at or below `cutoff`.
set_high_intersection_filter(s::MSet, cutoff) = filter(p -> first(p) <= cutoff, s)
34
"""
    compare_token_sets(qset, iset, len_legalese, min_matched_length_high, min_matched_length;
                       minimum_containment=0, high_resemblance_threshold=0.8)

Compare a query token set (or multiset) `qset` with a rule token set (or
multiset) `iset` and score the overlap.

Returns `(score_vectors, high_intersection)` where `score_vectors` is a pair of
named tuples — a rounded, coarse vector for initial ranking and a
full-precision vector for refinement — or `(nothing, nothing)` when any
threshold (`min_matched_length_high`, `min_matched_length`,
`minimum_containment`) is not met.

"High" tokens are those with token id `<= len_legalese`.
"""
function compare_token_sets(qset, iset, len_legalese, min_matched_length_high, min_matched_length; minimum_containment=0, high_resemblance_threshold=0.8)
    intersection = set_intersector(qset, iset)
    isempty(intersection) && return nothing, nothing

    high_intersection = set_high_intersection_filter(intersection, len_legalese)
    isempty(high_intersection) && return nothing, nothing
    # BUG FIX: previous code compared `length(set_counter(high_intersection))`;
    # `set_counter` returns an Int and `length(::Number) == 1`, so the
    # high-matched-length threshold was never actually enforced. Compare the
    # count itself, mirroring the `min_matched_length` check below.
    set_counter(high_intersection) < min_matched_length_high && return nothing, nothing

    rule_length = set_counter(iset)
    matched_length = set_counter(intersection)
    matched_length < min_matched_length && return nothing, nothing

    union_len = set_counter(qset) + rule_length - matched_length
    resemblance = matched_length / union_len
    containment = matched_length / rule_length
    containment < minimum_containment && return nothing, nothing

    # Squaring amplifies the separation between high- and low-resemblance
    # candidates when ranking.
    amplified_resemblance = resemblance^2

    # Coarse vector: rounded/bucketed values so near-equal candidates compare
    # equal and ranking falls through to the precise vector and beyond.
    score_vec1 = (;
        is_highly_resemblant=round(resemblance; digits=1) >= high_resemblance_threshold,
        containment=round(containment; digits=1),
        resemblance=round(amplified_resemblance; digits=1),
        matched_length=round(Int, matched_length / 20))

    # Precise vector: unrounded values for fine-grained ordering.
    score_vec2 = (;
        is_highly_resemblant=resemblance >= high_resemblance_threshold,
        containment=containment,
        resemblance=amplified_resemblance,
        matched_length=matched_length)

    return (score_vec1, score_vec2), high_intersection
end
67
+
68
# Candidate score tuple. NamedTuple comparison is lexicographic over the
# fields in declaration order, so this order defines the ranking priority.
const ScoreVector = @NamedTuple{is_highly_resemblant::Bool, containment::Float64, resemblance::Float64, matched_length::Int64}
70
"""
Per-rule matching thresholds, converted once from the Python-side rule objects
so the ranking loops work on plain Julia values.
"""
struct RuleInfo
    min_matched_length_unique::Int
    min_matched_length::Int
    min_high_matched_length_unique::Int
    min_high_matched_length::Int
    minimum_containment::Float64
end
77
+
78
"""
    convert_rule_list(rules_by_rid)

Convert Python rule objects into `RuleInfo` structs, calling each rule's
threshold accessors once up front so the hot ranking loops avoid repeated
Python attribute access.

NOTE(review): `pyconvert` is presumably PythonCall's — confirm the import at
the top of this file.
"""
function convert_rule_list(rules_by_rid)
    return [
        RuleInfo(
            pyconvert(Any, rule.get_min_matched_length(true)),
            pyconvert(Any, rule.get_min_matched_length(false)),
            pyconvert(Any, rule.get_min_high_matched_length(true)),
            pyconvert(Any, rule.get_min_high_matched_length(false)),
            pyconvert(Any, rule._minimum_containment),
        )
        for rule in rules_by_rid
    ]
end
86
+
87
# Convert a list of optional token-id collections into `BitSet`s,
# passing `nothing` entries through unchanged.
function convert_set_list(sets)
    return [s === nothing ? nothing : BitSet(s) for s in sets]
end
90
+
91
# Convert a list of optional id => count mappings into `MSet`s,
# passing `nothing` entries through unchanged.
function convert_mset_list(msets)
    return [m === nothing ? nothing : MSet(m) for m in msets]
end
94
+
95
"""
    compute_candidates(token_ids, len_legalese, rules_by_rid, sets_by_rid, msets_by_rid,
                       matchable_rids, top=50, high_resemblance=false,
                       high_resemblance_threshold=0.8)

Rank candidate rules against a query in two passes: a cheap pass over token-id
`BitSet`s, then a refinement pass over token-id multisets for the best
`10 * top` survivors. Returns at most `top` candidates, best first, each as
`(score_vectors, rid, rule, high_set_intersection)`.

`rid`s are zero-based for Python compatibility; `sets_by_rid`/`msets_by_rid`
are indexed with `rid + 1`. Observed argument types (from a previous debug
run — TODO confirm): `token_ids::Vector{Int64}`, `rules_by_rid::Vector{RuleInfo}`,
`sets_by_rid::Vector{Union{Nothing,BitSet}}`,
`msets_by_rid::Vector{Union{Nothing,Dict{Int64,Int64}}}`, `matchable_rids::BitSet`.
"""
function compute_candidates(token_ids, len_legalese, rules_by_rid, sets_by_rid, msets_by_rid,
                            matchable_rids, top=50, high_resemblance=false, high_resemblance_threshold=0.8)
    # Collect query-side set and multiset used for matching.
    qset, qmset = build_set_and_tids_mset(token_ids)

    # A candidate is kept when high-resemblance filtering is off, or when both
    # the rounded and the full-precision score vectors flag high resemblance.
    accept(svr, svf) = !high_resemblance || (svr.is_highly_resemblant && svf.is_highly_resemblant)

    # ##########################################################################
    # Step 1: coarse ranking on token-id sets.
    # ##########################################################################
    sortable_candidates = Tuple{Tuple{ScoreVector,ScoreVector},Int,RuleInfo,BitSet}[]

    for (i, rule) in enumerate(rules_by_rid)
        rid = i - 1  # zero-based rule id (Python compatibility)
        rid in matchable_rids || continue

        scores_vectors, high_set_intersection = compare_token_sets(
            qset,
            sets_by_rid[rid + 1],
            len_legalese,
            rule.min_high_matched_length_unique,
            rule.min_matched_length_unique;
            minimum_containment=rule.minimum_containment,
            high_resemblance_threshold)

        isnothing(scores_vectors) && continue
        svr, svf = scores_vectors
        if accept(svr, svf)
            push!(sortable_candidates, (scores_vectors, rid, rule, high_set_intersection))
        end
    end

    isempty(sortable_candidates) && return sortable_candidates

    # Tuple comparison is lexicographic: score vectors first, then `rid`, which
    # is unique and breaks ties before the non-comparable RuleInfo/BitSet slots.
    sort!(sortable_candidates; rev=true)

    # ##########################################################################
    # Step 2: refinement on token-id multisets, over the best 10 x top only.
    # ##########################################################################
    refined = eltype(sortable_candidates)[]
    for (k, (_scores, rid, rule, high_set_intersection)) in enumerate(sortable_candidates)
        # FIX: was `k >= 10 * top`, which dropped the 10*top-th candidate and
        # so kept one fewer than the documented "10 x top".
        k > 10 * top && break

        scores_vectors, _intersection = compare_token_sets(
            qmset,
            msets_by_rid[rid + 1],
            len_legalese,
            rule.min_high_matched_length,
            rule.min_matched_length;
            minimum_containment=rule.minimum_containment,
            high_resemblance_threshold)

        isnothing(scores_vectors) && continue
        svr, svf = scores_vectors
        if accept(svr, svf)
            push!(refined, (scores_vectors, rid, rule, high_set_intersection))
        end
    end

    isempty(refined) && return refined

    # Final ranking: best first, truncated to at most `top` entries.
    return sort!(refined; rev=true)[1:min(top, length(refined))]
end
0 commit comments