Skip to content

Commit 35a9afc

Browse files
committed
Basic structure for QUBOConstraints
1 parent 489228e commit 35a9afc

File tree

6 files changed

+171
-9
lines changed

6 files changed

+171
-9
lines changed

docs/make.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,11 @@ makedocs(;
8585
"Aggregation Layer" => "learning/aggregation.md",
8686
"Comparison Layer" => "learning/comparison.md",
8787
],
88-
"QUBOConstraints.jl" => "learning/qubo_constraints.md",
88+
"QUBOConstraints.jl" => [
89+
"Model as QUBO" => "learning/qubo_constraints.md",
90+
"Encoding" => "learning/qubo_encoding.md",
91+
"Learning" => "learning/qubo_learning.md",
92+
],
8993
"ConstraintLearning.jl" => "learning/constraint_learning.md",
9094
],
9195
"Solvers" => [

docs/src/full_api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@
33
```@autodocs
44
Modules=[
55
ConstraintCommons, ConstraintDomains, Constraints,
6-
CompositionalNetworks,
6+
CompositionalNetworks, QUBOConstraints,
77
]
8-
```
8+
```

docs/src/learning/qubo_constraints.md

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
1-
# QUBOConstraints.jl
1+
# Introduction to QUBOConstraints.jl
22

3-
Documentation for `QUBOConstraints.jl`.
3+
Introduction to `QUBOConstraints.jl`.
44

5-
```@autodocs
6-
Modules=[QUBOConstraints]
5+
```@meta
6+
CurrentModule = QUBOConstraints
7+
```
8+
9+
## Basic features
10+
11+
```@docs; canonical=false
12+
QUBO_base
13+
QUBO_linear_sum
714
```

docs/src/learning/qubo_encoding.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Encoding for QUBO programs
2+
3+
```@meta
4+
CurrentModule = QUBOConstraints
5+
```
6+
7+
```@docs; canonical=false
8+
is_valid
9+
binarize
10+
debinarize
11+
```

docs/src/learning/qubo_learning.md

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# Learning QUBO matrices
2+
3+
```@meta
4+
CurrentModule = QUBOConstraints
5+
```
6+
7+
## Interface
8+
9+
```@docs; canonical=false
10+
AbstractOptimizer
11+
train
12+
```
13+
14+
## Examples with various optimizers
15+
16+
### Gradient Descent
17+
18+
```julia
19+
struct GradientDescentOptimizer <: QUBOConstraints.AbstractOptimizer
20+
binarization::Symbol
21+
η::Float64
22+
precision::Int
23+
oversampling::Bool
24+
end
25+
26+
function GradientDescentOptimizer(;
27+
binarization = :one_hot,
28+
η = .001,
29+
precision = 5,
30+
oversampling = false,
31+
)
32+
return GradientDescentOptimizer(binarization, η, precision, oversampling)
33+
end
34+
35+
36+
predict(x, Q) = transpose(x) * Q * x
37+
38+
loss(x, y, Q) = (predict(x, Q) .-y).^2
39+
40+
function make_df(X, Q, penalty, binarization, domains)
41+
df = DataFrame()
42+
for (i,x) in enumerate(X)
43+
if i == 1
44+
df = DataFrame(transpose(x), :auto)
45+
else
46+
push!(df, transpose(x))
47+
end
48+
end
49+
50+
dim = length(df[1,:])
51+
52+
if binarization == :none
53+
df[!,:penalty] = map(r -> penalty(Vector(r)), eachrow(df))
54+
df[!,:predict] = map(r -> predict(Vector(r), Q), eachrow(df[:, 1:dim]))
55+
else
56+
df[!,:penalty] = map(
57+
r -> penalty(binarize(Vector(r), domains; binarization)),
58+
eachrow(df)
59+
)
60+
df[!,:predict] = map(
61+
r -> predict(binarize(Vector(r), domains; binarization), Q),
62+
eachrow(df[:, 1:dim])
63+
)
64+
end
65+
66+
min_false = minimum(
67+
filter(:penalty => >(minimum(df[:,:penalty])), df)[:,:predict];
68+
init = typemax(Int)
69+
)
70+
df[!,:shifted] = df[:,:predict] .- min_false
71+
df[!,:accurate] = df[:, :penalty] .* df[:,:shifted] .≥ 0.
72+
73+
return df
74+
end
75+
76+
function preliminaries(X, domains, binarization)
77+
if binarization==:none
78+
n = length(first(X))
79+
return X, zeros(n,n)
80+
else
81+
Y = map(x -> collect(binarize(x, domains; binarization)), X)
82+
n = length(first(Y))
83+
return Y, zeros(n,n)
84+
end
85+
end
86+
87+
function preliminaries(X, _)
88+
n = length(first(X))
89+
return X, zeros(n,n)
90+
end
91+
92+
function train!(Q, X, penalty, η, precision, X_test, oversampling, binarization, domains)
93+
θ = params(Q)
94+
try
95+
penalty(first(X))
96+
catch e
97+
if isa(e, UndefKeywordError)
98+
penalty = (x; dom_size = δ_extrema(Iterators.flatten(X)))-> penalty(x; dom_size)
99+
else
100+
throw(e)
101+
end
102+
end
103+
for x in (oversampling ? oversample(X, penalty) : X)
104+
grads = gradient(() -> loss(x, penalty(x), Q), θ)
105+
Q .-= η * grads[Q]
106+
end
107+
108+
Q[:,:] = round.(precision*Q)
109+
110+
df = make_df(X_test, Q, penalty, binarization, domains)
111+
return pretty_table(describe(df[!, [:penalty, :predict, :shifted, :accurate]]))
112+
end
113+
114+
function train(
115+
X,
116+
penalty,
117+
domains::Vector{D};
118+
optimizer = GradientDescentOptimizer(),
119+
X_test = X,
120+
) where {D <: DiscreteDomain}
121+
Y, Q = preliminaries(X, domains, optimizer.binarization)
122+
train!(
123+
Q, Y, penalty, optimizer.η, optimizer.precision, X_test,
124+
optimizer.oversampling, optimizer.binarization, domains
125+
)
126+
return Q
127+
end
128+
129+
function train(
130+
X,
131+
penalty,
132+
dom_stuff = nothing;
133+
optimizer = GradientDescentOptimizer(),
134+
X_test = X,
135+
)
136+
return train(X, penalty, to_domains(X, dom_stuff); optimizer, X_test)
137+
end
138+
```
139+
140+
### Constraint-based Local Search

docs/src/public_api.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
```@autodocs; canonical=false
44
Modules=[
55
ConstraintCommons, ConstraintDomains, Constraints,
6-
CompositionalNetworks,
6+
CompositionalNetworks, QUBOConstraints,
77
]
88
Private = false
9-
```
9+
```

0 commit comments

Comments
 (0)