Skip to content

Commit

Permalink
made model :fastmath/ols frezzable by nippy
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Jan 9, 2025
1 parent 39dce97 commit 82b613f
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 36 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ unreleased
* fixes #30 - dummy classifier does not predict by majority #30
* added dummy regression model
* improved design-matrix feature . Breaking ! columns need to be refered know by "precise name" (string, symbol, keyword)
* made model :fastmath/ols frezzable by nippy


0.11.1
Expand Down
14 changes: 10 additions & 4 deletions src/scicloj/metamorph/ml/regression.clj
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

[tech.v3.dataset.column-filters :as cf]
[tablecloth.column.api :as tcc]
[tech.v3.dataset.modelling :as ds-mod])
[tech.v3.dataset.modelling :as ds-mod]
[fastmath.ml.regression :as regression])
(:import [org.apache.commons.math3.stat.regression OLSMultipleLinearRegression]
[fastmath.java Array]))

Expand Down Expand Up @@ -75,9 +76,14 @@
(-> target-ds
cf/target
first
second)]

(fm-reg/lm ys xss clean-options)))
second)
model (fm-reg/lm ys xss clean-options)]

(assoc model
:analysis
(-> model :analysis deref))

))

(defn- predict-fm-ols [feature-ds thawed-model model]
(let [prediction (map
Expand Down
114 changes: 82 additions & 32 deletions test/scicloj/metamorph/linear_regression_test.clj
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
(ns scicloj.metamorph.linear-regression-test
(:require
[clojure.pprint :as pp]
[clojure.set :as set]
[clojure.test :refer [deftest is]]
[scicloj.metamorph.ml :as ml]
[scicloj.metamorph.ml.regression]
[scicloj.metamorph.ml.toydata :as data]
[taoensso.nippy :as nippy]
[tech.v3.dataset :as ds]
[clojure.test :refer [deftest is]]
[scicloj.metamorph.ml.regression]
[tech.v3.dataset.modelling :as ds-mod]
[clojure.math :as math]))
[tech.v3.dataset.modelling :as ds-mod]))


(defn approx? [x0 x1]
Expand All @@ -20,21 +22,25 @@
(->> (map approx? v0 v1)
(every? true?)))

(deftest linear-regression-mtcars-fm-ols
(let [ds
(->
(data/mtcars-ds)
(ds/drop-columns [:model])
(ds-mod/set-inference-target :mpg))

model (ml/train ds {:model-type :fastmath/ols})
(defn validate-model--linear-regression-mtcars-fm-ols [ ds model]
;;(def ds ds)
;;(def model model)

(let [
glance (-> (ml/glance model) (ds/rows :as-map) first)
tidy (ml/tidy model)
augment (ml/augment model ds)
prediction (:mpg (ml/predict ds model))]
prediction (:mpg (ml/predict ds model))


]
;; (def glance glance)
;; (def tidy tidy)
;; (def augment augment)
;; (def prediction prediction)




(is (approx? 0.8066423189909859 (-> glance :adj.r.squared)))
(is (approx? 0.8690157644777647 (:r.squared glance)))
Expand All @@ -49,7 +55,7 @@
(is (= 32 (:n glance)))
(is (= 21 (:df.residual glance)))
(is (approx? 13.932463690208833 (:statistic glance)))

(is (all-approx? [22.599505761262364
22.11188607935665
26.25064408479878
Expand Down Expand Up @@ -83,9 +89,9 @@
13.941118382059862
24.368267683243772]
prediction))



(is (= [:mpg :cyl :disp :hp :drat :wt :qsec :vs :am :gear :carb] (-> tidy :term)))
(is (all-approx? [12.303374155996154
-0.11144047788686227
Expand Down Expand Up @@ -203,15 +209,47 @@
19.693828154474765
13.941118382059862
24.368267683243772]
(augment :.fitted)))))
;; => #'scicloj.metamorph.linear-regression-test/linear-regression-mtcars
(augment :.fitted)))
)

)

(defn thaw-fm-ols [frozen-model]
(nippy/thaw-from-string frozen-model
{:serializable-allowlist
(set/union nippy/*thaw-serializable-allowlist*
#{"org.apache.commons.math3.linear.Array2DRowRealMatrix"})})

)

(defn pretty-spit
[file-name collection]
(spit (java.io.File. file-name)
(with-out-str (pp/write collection :dispatch pp/code-dispatch))))


(deftest linear-regression-mtcars-fm-ols
(let [ds
(->
(data/mtcars-ds)
(ds/drop-columns [:model])
(ds-mod/set-inference-target :mpg))

model (ml/train ds {:model-type :fastmath/ols})


frozen-model (nippy/freeze-to-string model)
unfrozen-model (thaw-fm-ols frozen-model)
]
(validate-model--linear-regression-mtcars-fm-ols ds model)
(validate-model--linear-regression-mtcars-fm-ols ds unfrozen-model)



))
;; => #'scicloj.metamorph.linear-regression-test/linear-regression-mtcars


;; (->
;; (ml/glance model)
;; (ds/rows)
Expand Down Expand Up @@ -335,18 +373,8 @@
;; 10 Merc 280 19.2 6 168. 123 3.92 3.44 18.3 1 0 4 4 18.7 0.501 0.429 2.71 0.00428 0.250


(deftest metamorph.ml-ols


(let [ds
(->
(data/mtcars-ds)
(ds/drop-columns [:model])
(ds-mod/set-inference-target :mpg))

model (ml/train ds {:model-type :metamorph.ml/ols})
prediction (:mpg (ml/predict ds model))]

(defn- validate-model--ols [model ds]
(let [prediction (:mpg (ml/predict ds model))]
(is (=
[22.599505761262364
22.11188607935665
Expand Down Expand Up @@ -406,10 +434,11 @@
;; | :am | 2.52022689 | 2.05665055 |
;; | :gear | 0.65541302 | 1.49325996 |
;; | :carb | -0.19941925 | 0.82875250 |



(is (ds/dataset?
(ml/augment model ds)))))
(ml/augment model ds))))
;; => _unnamed [32 12]:
;; | :mpg | :cyl | :disp | :hp | :drat | :wt | :qsec | :vs | :am | :gear | :carb | :.residuals |
;; |-----:|-----:|------:|----:|------:|------:|------:|----:|----:|------:|------:|------------:|
Expand Down Expand Up @@ -501,3 +530,24 @@
;; [9] 24.41909 18.69903 19.19165 14.17216 15.59957 15.74222 12.03401 10.93644
;; [17] 10.49363 27.77291 29.89674 29.51237 23.64310 16.94305 17.73218 13.30602
;; [25] 16.69168 28.29347 26.15295 27.63627 18.87004 19.69383 13.94112 24.36827

)

(deftest metamorph.ml-ols


(let [ds
(->
(data/mtcars-ds)
(ds/drop-columns [:model])
(ds-mod/set-inference-target :mpg))

model (ml/train ds {:model-type :metamorph.ml/ols})
;; cannot be persisted, as not serializable
;model-frozen (nippy/freeze model)
;model-unfrozen (nippy/thaw model-frozen)

]
(validate-model--ols model ds)
;(validate-model--ols model-unfrozen ds)
))

0 comments on commit 82b613f

Please sign in to comment.