From b3e565bf2a1ef5a26bbd98d1e8607103322a64d7 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 9 Aug 2023 18:17:20 -0300 Subject: [PATCH 001/185] mix new candlex --- candlex/.formatter.exs | 4 ++++ candlex/.gitignore | 26 ++++++++++++++++++++++++++ candlex/README.md | 21 +++++++++++++++++++++ candlex/lib/candlex.ex | 18 ++++++++++++++++++ candlex/mix.exs | 28 ++++++++++++++++++++++++++++ candlex/test/candlex_test.exs | 8 ++++++++ candlex/test/test_helper.exs | 1 + 7 files changed, 106 insertions(+) create mode 100644 candlex/.formatter.exs create mode 100644 candlex/.gitignore create mode 100644 candlex/README.md create mode 100644 candlex/lib/candlex.ex create mode 100644 candlex/mix.exs create mode 100644 candlex/test/candlex_test.exs create mode 100644 candlex/test/test_helper.exs diff --git a/candlex/.formatter.exs b/candlex/.formatter.exs new file mode 100644 index 0000000000..d2cda26edd --- /dev/null +++ b/candlex/.formatter.exs @@ -0,0 +1,4 @@ +# Used by "mix format" +[ + inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] +] diff --git a/candlex/.gitignore b/candlex/.gitignore new file mode 100644 index 0000000000..82bb4695f2 --- /dev/null +++ b/candlex/.gitignore @@ -0,0 +1,26 @@ +# The directory Mix will write compiled artifacts to. +/_build/ + +# If you run "mix test --cover", coverage assets end up here. +/cover/ + +# The directory Mix downloads your dependencies sources to. +/deps/ + +# Where third-party dependencies like ExDoc output generated docs. +/doc/ + +# Ignore .fetch files in case you like to edit your project deps locally. +/.fetch + +# If the VM crashes, it generates a dump, let's ignore it too. +erl_crash.dump + +# Also ignore archive artifacts (built via "mix archive.build"). +*.ez + +# Ignore package tarball (built via "mix hex.build"). +candlex-*.tar + +# Temporary files, for example, from tests. +/tmp/ diff --git a/candlex/README.md b/candlex/README.md new file mode 100644 index 0000000000..83509bb3cb --- /dev/null +++ b/candlex/README.md @@ -0,0 +1,21 @@ +# Candlex + +**TODO: Add description** + +## Installation + +If [available in Hex](https://hex.pm/docs/publish), the package can be installed +by adding `candlex` to your list of dependencies in `mix.exs`: + +```elixir +def deps do + [ + {:candlex, "~> 0.1.0"} + ] +end +``` + +Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_doc) +and published on [HexDocs](https://hexdocs.pm). Once published, the docs can +be found at . + diff --git a/candlex/lib/candlex.ex b/candlex/lib/candlex.ex new file mode 100644 index 0000000000..a569e059bb --- /dev/null +++ b/candlex/lib/candlex.ex @@ -0,0 +1,18 @@ +defmodule Candlex do + @moduledoc """ + Documentation for `Candlex`. + """ + + @doc """ + Hello world. + + ## Examples + + iex> Candlex.hello() + :world + + """ + def hello do + :world + end +end diff --git a/candlex/mix.exs b/candlex/mix.exs new file mode 100644 index 0000000000..59d4d2e00e --- /dev/null +++ b/candlex/mix.exs @@ -0,0 +1,28 @@ +defmodule Candlex.MixProject do + use Mix.Project + + def project do + [ + app: :candlex, + version: "0.1.0", + elixir: "~> 1.15", + start_permanent: Mix.env() == :prod, + deps: deps() + ] + end + + # Run "mix help compile.app" to learn about applications. + def application do + [ + extra_applications: [:logger] + ] + end + + # Run "mix help deps" to learn about dependencies. 
+ defp deps do + [ + # {:dep_from_hexpm, "~> 0.3.0"}, + # {:dep_from_git, git: "https://github.com/elixir-lang/my_dep.git", tag: "0.1.0"} + ] + end +end diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs new file mode 100644 index 0000000000..9a9f9e2799 --- /dev/null +++ b/candlex/test/candlex_test.exs @@ -0,0 +1,8 @@ +defmodule CandlexTest do + use ExUnit.Case + doctest Candlex + + test "greets the world" do + assert Candlex.hello() == :world + end +end diff --git a/candlex/test/test_helper.exs b/candlex/test/test_helper.exs new file mode 100644 index 0000000000..869559e709 --- /dev/null +++ b/candlex/test/test_helper.exs @@ -0,0 +1 @@ +ExUnit.start() From b1d89ec91201fe127254fceb2ff1d2f217eac622 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 11 Aug 2023 13:30:38 -0300 Subject: [PATCH 002/185] mix rustler.new --- candlex/mix.exs | 3 +- candlex/mix.lock | 5 + candlex/native/candlex/.cargo/config.toml | 5 + candlex/native/candlex/.gitignore | 1 + candlex/native/candlex/Cargo.lock | 149 ++++++++++++++++++++++ candlex/native/candlex/Cargo.toml | 13 ++ candlex/native/candlex/README.md | 20 +++ candlex/native/candlex/src/lib.rs | 6 + 8 files changed, 200 insertions(+), 2 deletions(-) create mode 100644 candlex/mix.lock create mode 100644 candlex/native/candlex/.cargo/config.toml create mode 100644 candlex/native/candlex/.gitignore create mode 100644 candlex/native/candlex/Cargo.lock create mode 100644 candlex/native/candlex/Cargo.toml create mode 100644 candlex/native/candlex/README.md create mode 100644 candlex/native/candlex/src/lib.rs diff --git a/candlex/mix.exs b/candlex/mix.exs index 59d4d2e00e..fbe64ea0e8 100644 --- a/candlex/mix.exs +++ b/candlex/mix.exs @@ -21,8 +21,7 @@ defmodule Candlex.MixProject do # Run "mix help deps" to learn about dependencies. 
defp deps do [ - # {:dep_from_hexpm, "~> 0.3.0"}, - # {:dep_from_git, git: "https://github.com/elixir-lang/my_dep.git", tag: "0.1.0"} + {:rustler, "~> 0.29.1"} ] end end diff --git a/candlex/mix.lock b/candlex/mix.lock new file mode 100644 index 0000000000..79e9e68d25 --- /dev/null +++ b/candlex/mix.lock @@ -0,0 +1,5 @@ +%{ + "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, + "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, + "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, +} diff --git a/candlex/native/candlex/.cargo/config.toml b/candlex/native/candlex/.cargo/config.toml new file mode 100644 index 0000000000..20f03f3d80 --- /dev/null +++ b/candlex/native/candlex/.cargo/config.toml @@ -0,0 +1,5 @@ +[target.'cfg(target_os = "macos")'] +rustflags = [ + "-C", "link-arg=-undefined", + "-C", "link-arg=dynamic_lookup", +] diff --git a/candlex/native/candlex/.gitignore b/candlex/native/candlex/.gitignore new file mode 100644 index 0000000000..ea8c4bf7f3 --- /dev/null +++ b/candlex/native/candlex/.gitignore @@ -0,0 +1 @@ +/target diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock new file mode 100644 index 0000000000..eadc972e3f --- /dev/null +++ b/candlex/native/candlex/Cargo.lock @@ -0,0 +1,149 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "aho-corasick" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8f9420f797f2d9e935edf629310eb938a0d839f984e25327f3c7eed22300c" +dependencies = [ + "memchr", +] + +[[package]] +name = "candlex" +version = "0.1.0" +dependencies = [ + "rustler", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "memchr" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + +[[package]] +name = "proc-macro2" +version = "1.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "regex" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" + +[[package]] +name = "rustler" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0884cb623b9f43d3e2c51f9071c5e96a5acf3e6e6007866812884ff0cb983f1e" +dependencies = [ + "lazy_static", + "rustler_codegen", + "rustler_sys", +] + +[[package]] +name = "rustler_codegen" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50e277af754f2560cf4c4ebedb68c1a735292fb354505c6133e47ec406e699cf" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "rustler_sys" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7c0740e5322b64e2b952d8f0edce5f90fcf6f6fe74cca3f6e78eb3de5ea858" +dependencies = [ + "regex", + "unreachable", +] + +[[package]] +name = "syn" +version = "2.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" + +[[package]] +name = "unreachable" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" +dependencies = [ + "void", +] + +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml new file mode 100644 index 0000000000..7cce12f0ec --- /dev/null +++ b/candlex/native/candlex/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "candlex" +version = "0.1.0" +authors = [] +edition = "2021" + +[lib] +name = "candlex" +path = "src/lib.rs" +crate-type = ["cdylib"] + +[dependencies] +rustler = "0.29.1" diff --git a/candlex/native/candlex/README.md b/candlex/native/candlex/README.md new file mode 100644 index 0000000000..8c3b89c5fb --- /dev/null +++ b/candlex/native/candlex/README.md @@ -0,0 +1,20 @@ +# NIF for Elixir.Candlex + +## To build the NIF module: + +- Your NIF will now build along with your project. + +## To load the NIF: + +```elixir +defmodule Candlex do + use Rustler, otp_app: :candlex, crate: "candlex" + + # When your NIF is loaded, it will override this function. + def add(_a, _b), do: :erlang.nif_error(:nif_not_loaded) +end +``` + +## Examples + +[This](https://github.com/rusterlium/NifIo) is a complete example of a NIF written in Rust. diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs new file mode 100644 index 0000000000..1fe0bd6445 --- /dev/null +++ b/candlex/native/candlex/src/lib.rs @@ -0,0 +1,6 @@ +#[rustler::nif] +fn add(a: i64, b: i64) -> i64 { + a + b +} + +rustler::init!("Elixir.Candlex", [add]); From 1f70213fae6a0502e5fab1baf29c7b7ef11e496a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 17 Aug 2023 14:58:35 -0300 Subject: [PATCH 003/185] [native/candlex] cargo add candle-core --- candlex/native/candlex/Cargo.lock | 542 +++++++++++++++++++++++++++++- candlex/native/candlex/Cargo.toml | 1 + candlex/native/candlex/README.md | 20 -- 3 files changed, 537 insertions(+), 26 deletions(-) delete mode 100644 candlex/native/candlex/README.md diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index eadc972e3f..1092a0ce56 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -4,38 +4,395 @@ version = 3 [[package]] name = "aho-corasick" -version = "1.0.3" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b8f9420f797f2d9e935edf629310eb938a0d839f984e25327f3c7eed22300c" +checksum = "6748e8def348ed4d14996fa801f4122cd763fff530258cdc03f64b25f89d3a5a" dependencies = [ "memchr", ] +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bytemuck" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "candle-core" +version = 
"0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e56d08f7794036648d7ba5448c82ab7c3f38c25e90cdc4032afd246d1292a42c" +dependencies = [ + "byteorder", + "candle-gemm", + "half", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "zip", +] + +[[package]] +name = "candle-gemm" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b726a1f6cdd7ff080e95e3d91694701b1e04a58acd198e4a78c39428b2274e" +dependencies = [ + "candle-gemm-c32", + "candle-gemm-c64", + "candle-gemm-common", + "candle-gemm-f16", + "candle-gemm-f32", + "candle-gemm-f64", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-c32" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "661470663389f0c99fd8449e620bfae630a662739f830a323eda4dcf80888843" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-c64" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a111ddf61db562854a6d2ff4dfe1e8a84066431b7bc68d3afae4bf60874fda0" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-common" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6dd93783ead7eeef14361667ea32014dc6f716a2fc956b075fe78729e10dd5" +dependencies = [ + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f16" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b76499bf4b858cacc526c5c8f948bc7152774247dce8568f174b743ab1363fa4" +dependencies = [ + "candle-gemm-common", + "candle-gemm-f32", + "dyn-stack", + "half", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f32" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7bec152e7d36339d3785e0d746d75ee94a4e92968fbb12ddcc91b536b938d016" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "candle-gemm-f64" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00f59ac68a5521e2ff71431bb7f1b22126ff0b60c5e66599b1f4676433da6e69" +dependencies = [ + "candle-gemm-common", + "dyn-stack", + "lazy_static", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + [[package]] name = "candlex" version = "0.1.0" dependencies = [ + "candle-core", "rustler", ] +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + 
+[[package]] +name = "crossbeam-channel" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "dyn-stack" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24269739c7c175bc12130622ef1a60b9ab2d5b30c0b9ce5110cd406d7fd497bc" +dependencies = [ + "bytemuck", + "reborrow", +] + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", +] + [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "hermit-abi" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" + +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + [[package]] name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +[[package]] +name = "libc" +version = "0.2.147" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" + +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + [[package]] name = "memchr" version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.66" @@ -47,13 +404,90 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.32" +version = "1.0.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50f3b39ccfb720540debaa0164757101c08ecb8d326b15358ce76a62c7e85965" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rayon" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] 
+name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + +[[package]] +name = "reborrow" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2962bf2e1f971c53ef59b2d7ca51d6a5e5c4a9d2be47eb1f661a321a4da85888" + [[package]] name = "regex" version = "1.9.3" @@ -116,17 +550,96 @@ dependencies = [ "unreachable", ] +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "safetensors" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad8cbd90c388a0b028565d8ad22e090101599d951c6b5f105b4f7772721a9d5f" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + +[[package]] +name = "serde" +version = "1.0.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32ac8da02677876d532745a130fc9d8e6edfa81a269b107c5b00829b91d8eb3c" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.183" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aafe972d60b0b9bee71a91b92fee2d4fb3c9d7e8f6b179aa99f27203d99a4816" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +dependencies = [ + "itoa", + "ryu", + "serde", +] + [[package]] name = "syn" -version = "2.0.28" +version = "2.0.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04361975b3f5e348b2189d8dc55bc942f278b2d482a6a0365de5bdd62d351567" +checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.47" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.11" @@ -147,3 +660,20 @@ name = "void" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "zip" +version = "0.6.6" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", +] diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 7cce12f0ec..35bd43a188 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -10,4 +10,5 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] +candle-core = "0.1.0" rustler = "0.29.1" diff --git a/candlex/native/candlex/README.md b/candlex/native/candlex/README.md deleted file mode 100644 index 8c3b89c5fb..0000000000 --- a/candlex/native/candlex/README.md +++ /dev/null @@ -1,20 +0,0 @@ -# NIF for Elixir.Candlex - -## To build the NIF module: - -- Your NIF will now build along with your project. - -## To load the NIF: - -```elixir -defmodule Candlex do - use Rustler, otp_app: :candlex, crate: "candlex" - - # When your NIF is loaded, it will override this function. - def add(_a, _b), do: :erlang.nif_error(:nif_not_loaded) -end -``` - -## Examples - -[This](https://github.com/rusterlium/NifIo) is a complete example of a NIF written in Rust. From 9fd2713f1107d22534ca0363e22ce1721d9752f4 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 17 Aug 2023 14:59:39 -0300 Subject: [PATCH 004/185] gitignore candlex nif .so --- candlex/.gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/candlex/.gitignore b/candlex/.gitignore index 82bb4695f2..1fb4d171fd 100644 --- a/candlex/.gitignore +++ b/candlex/.gitignore @@ -24,3 +24,6 @@ candlex-*.tar # Temporary files, for example, from tests. /tmp/ + +# Shared objects build by Rust. +*.so From 9e9e4453d86424d918344f693886144739093495 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 17 Aug 2023 15:01:56 -0300 Subject: [PATCH 005/185] WIP feat: can create candle tensor --- candlex/.formatter.exs | 1 + candlex/lib/candlex.ex | 13 ----- candlex/lib/candlex/backend.ex | 82 +++++++++++++++++++++++++++ candlex/lib/candlex/native.ex | 12 ++++ candlex/mix.exs | 1 + candlex/mix.lock | 2 + candlex/native/candlex/src/candlex.rs | 75 ++++++++++++++++++++++++ candlex/native/candlex/src/lib.rs | 15 +++-- candlex/test/candlex_test.exs | 23 +++++++- candlex/test/test_helper.exs | 1 + 10 files changed, 204 insertions(+), 21 deletions(-) create mode 100644 candlex/lib/candlex/backend.ex create mode 100644 candlex/lib/candlex/native.ex create mode 100644 candlex/native/candlex/src/candlex.rs diff --git a/candlex/.formatter.exs b/candlex/.formatter.exs index d2cda26edd..d9f22bc671 100644 --- a/candlex/.formatter.exs +++ b/candlex/.formatter.exs @@ -1,4 +1,5 @@ # Used by "mix format" [ + import_deps: [:nx], inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"] ] diff --git a/candlex/lib/candlex.ex b/candlex/lib/candlex.ex index a569e059bb..ee5cfce363 100644 --- a/candlex/lib/candlex.ex +++ b/candlex/lib/candlex.ex @@ -2,17 +2,4 @@ defmodule Candlex do @moduledoc """ Documentation for `Candlex`. """ - - @doc """ - Hello world. - - ## Examples - - iex> Candlex.hello() - :world - - """ - def hello do - :world - end end diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex new file mode 100644 index 0000000000..9746688af0 --- /dev/null +++ b/candlex/lib/candlex/backend.ex @@ -0,0 +1,82 @@ +defmodule Candlex.Backend do + @moduledoc """ + An opaque Nx backend with bindings to candle. 
+ """ + + defstruct [:resource] + + # TODO: enable behaviour + # @behaviour Nx.Backend + + alias Nx.Tensor, as: T + alias Candlex.Backend, as: CB + alias Candlex.Native + + # @impl true + def init(opts) do + if opts != [] do + raise ArgumentError, "Candlex.Backend accepts no options" + end + + opts + end + + # @impl true + def constant(%T{shape: {}, type: _type} = tensor, scalar, _backend_options) do + # TODO: Don't ignore backend options + + scalar + |> Native.scalar_tensor() + |> to_nx(tensor) + end + + # @impl true + def from_binary(%T{shape: {2}, type: {:u, 32}} = tensor, binary, _backend_options) do + # TODO: Don't ignore backend options + + binary + |> Native.from_binary() + |> to_nx(tensor) + end + + # @impl true + def inspect(%T{} = tensor, inspect_opts) do + limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 + + tensor + |> to_binary(min(limit, Nx.size(tensor))) + |> then(&Nx.Backend.inspect(tensor, &1, inspect_opts)) + |> maybe_add_signature(tensor) + end + + def to_binary(tensor, _limit) do + # TODO: don't ignore limit + + from_nx(tensor) + |> Native.to_binary() + end + + defp maybe_add_signature(result, %T{data: %CB{resource: ref}}) when is_reference(ref) do + Inspect.Algebra.concat([ + "Candlex.Backend(#{:erlang.ref_to_list(ref)})", + Inspect.Algebra.line(), + result + ]) + end + + # defp to_candle_type({:u, 32}), do: :u32 + + # defp device_option(_backend_options) do + # # TODO: Support CUDA + # :cpu + # end + + ## Conversions + + @doc false + defp from_nx(%T{data: data}), do: data + + defp to_nx(%{resource: ref} = backend_tensor, %T{} = t) when is_reference(ref) do + %{t | data: backend_tensor} + end +end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex new file mode 100644 index 0000000000..04a93157e4 --- /dev/null +++ b/candlex/lib/candlex/native.ex @@ -0,0 +1,12 @@ +defmodule Candlex.Native do + @moduledoc false + + use Rustler, otp_app: :candlex, crate: "candlex" + + # Rustler will override all the below stub functions with real NIFs + def scalar_tensor(_scalar), do: error() + def to_binary(_tensor), do: error() + def from_binary(_binary), do: error() + + defp error(), do: :erlang.nif_error(:nif_not_loaded) +end diff --git a/candlex/mix.exs b/candlex/mix.exs index fbe64ea0e8..d593fcff6f 100644 --- a/candlex/mix.exs +++ b/candlex/mix.exs @@ -21,6 +21,7 @@ defmodule Candlex.MixProject do # Run "mix help deps" to learn about dependencies. 
defp deps do [ + {:nx, path: "../nx"}, {:rustler, "~> 0.29.1"} ] end diff --git a/candlex/mix.lock b/candlex/mix.lock index 79e9e68d25..dd8959ea38 100644 --- a/candlex/mix.lock +++ b/candlex/mix.lock @@ -1,5 +1,7 @@ %{ + "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, + "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, } diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs new file mode 100644 index 0000000000..44d4b184e9 --- /dev/null +++ b/candlex/native/candlex/src/candlex.rs @@ -0,0 +1,75 @@ +use candle_core::{Tensor, Device}; +use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; +use std::ops::Deref; + +pub struct TensorRef(pub Tensor); + +#[derive(NifStruct)] +#[module = "Candlex.Backend"] +pub struct ExTensor { + pub resource: ResourceArc, +} + +impl TensorRef { + pub fn new(tensor: Tensor) -> Self { + Self(tensor) + } +} + +impl ExTensor { + pub fn new(tensor: Tensor) -> Self { + Self { + resource: ResourceArc::new(TensorRef::new(tensor)), + } + } +} + +// Implement Deref so we can call `Tensor` functions directly from a `ExTensor` struct. 
+impl Deref for ExTensor { + type Target = Tensor; + + fn deref(&self) -> &Self::Target { + &self.resource.0 + } +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn scalar_tensor(scalar: u32) -> ExTensor { + ExTensor::new(Tensor::new(scalar, &Device::Cpu).unwrap()) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { + let tensor = ex_tensor.flatten_all().unwrap(); + let vec = tensor.to_vec1::().unwrap(); + + let bytes: Vec = vec.iter().flat_map(|val| val.to_ne_bytes().to_vec()).collect(); + + let mut binary = NewBinary::new(env, bytes.len()); + binary.as_mut_slice().copy_from_slice(bytes.as_slice()); + + binary.into() +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn from_binary(binary: Binary) -> ExTensor { + let slice = binary.as_slice(); + + // let slice: &[u32; 2] = unsafe { + // std::mem::transmute::<&[u8], &[u32; 2]>(binary.as_slice()) + // }; + let (_prefx, slice, _suffix) = unsafe { slice.align_to::() }; + + println!("{:?}", &slice); + // let slice = u32::from_ne_bytes(slice); + ExTensor::new(Tensor::from_slice(slice, 2, &Device::Cpu).unwrap()) + + // let mut vec = vec![0u32; 2]; + // reader.read_u32_into::(&mut vec)?; + // ExTensor::new(Tensor::from_vec(vec, 2, &Device::Cpu).unwrap()) +} + +pub fn load(env: Env, _info: Term) -> bool { + rustler::resource!(TensorRef, env); + true +} diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 1fe0bd6445..500cee350e 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -1,6 +1,11 @@ -#[rustler::nif] -fn add(a: i64, b: i64) -> i64 { - a + b -} +mod candlex; -rustler::init!("Elixir.Candlex", [add]); +rustler::init! { + "Elixir.Candlex.Native", + [ + candlex::scalar_tensor, + candlex::to_binary, + candlex::from_binary + ], + load = candlex::load +} diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 9a9f9e2799..abbcc6de33 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1,8 +1,25 @@ defmodule CandlexTest do - use ExUnit.Case + use ExUnit.Case, async: true doctest Candlex - test "greets the world" do - assert Candlex.hello() == :world + describe "creation" do + test "tensor" do + Nx.tensor(100_002, type: :u32, backend: Candlex.Backend) + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() + + Nx.tensor([1, 2], type: :u32, backend: Candlex.Backend) + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() + + # Nx.tensor([[1, 2], [3, 4]], type: :u32, backend: Candlex.Backend) + # |> IO.inspect() + # |> Nx.to_binary() + # |> IO.inspect() + + # assert Nx.backend_transfer(tensor) == Nx.tensor(1, type: :u32, backend: Nx.BinaryBackend) + end end end diff --git a/candlex/test/test_helper.exs b/candlex/test/test_helper.exs index 869559e709..13695bf434 100644 --- a/candlex/test/test_helper.exs +++ b/candlex/test/test_helper.exs @@ -1 +1,2 @@ +Application.put_env(:nx, :default_backend, Candlex.Backend) ExUnit.start() From 2e600895c2c559ab517fb6f609a370d7e01a6838 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:12:17 -0300 Subject: [PATCH 006/185] support multi dimension tensor --- candlex/lib/candlex/backend.ex | 4 ++-- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/candlex.rs | 14 +++++++++++--- candlex/test/candlex_test.exs | 18 ++++++++++++++---- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 
9746688af0..e82d810830 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -31,11 +31,11 @@ defmodule Candlex.Backend do end # @impl true - def from_binary(%T{shape: {2}, type: {:u, 32}} = tensor, binary, _backend_options) do + def from_binary(%T{shape: shape, type: {:u, 32}} = tensor, binary, _backend_options) do # TODO: Don't ignore backend options binary - |> Native.from_binary() + |> Native.from_binary(shape) |> to_nx(tensor) end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 04a93157e4..48176d418f 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -6,7 +6,7 @@ defmodule Candlex.Native do # Rustler will override all the below stub functions with real NIFs def scalar_tensor(_scalar), do: error() def to_binary(_tensor), do: error() - def from_binary(_binary), do: error() + def from_binary(_binary, _shape), do: error() defp error(), do: :erlang.nif_error(:nif_not_loaded) end diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 44d4b184e9..65524746de 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -52,7 +52,7 @@ pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn from_binary(binary: Binary) -> ExTensor { +pub fn from_binary(binary: Binary, shape: Term) -> ExTensor { let slice = binary.as_slice(); // let slice: &[u32; 2] = unsafe { @@ -60,9 +60,9 @@ pub fn from_binary(binary: Binary) -> ExTensor { // }; let (_prefx, slice, _suffix) = unsafe { slice.align_to::() }; - println!("{:?}", &slice); + // println!("{:?}", &slice); // let slice = u32::from_ne_bytes(slice); - ExTensor::new(Tensor::from_slice(slice, 2, &Device::Cpu).unwrap()) + ExTensor::new(Tensor::from_slice(slice, tuple_to_vec(shape), &Device::Cpu).unwrap()) // let mut vec = vec![0u32; 2]; // reader.read_u32_into::(&mut vec)?; @@ -73,3 +73,11 @@ pub fn load(env: Env, _info: Term) -> bool { rustler::resource!(TensorRef, env); true } + +fn tuple_to_vec(term: Term) -> Vec { + rustler::types::tuple::get_tuple(term) + .unwrap() + .iter() + .map(|elem| elem.decode().unwrap()) + .collect() +} diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index abbcc6de33..565876f69f 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -14,11 +14,21 @@ defmodule CandlexTest do |> Nx.to_binary() |> IO.inspect() - # Nx.tensor([[1, 2], [3, 4]], type: :u32, backend: Candlex.Backend) - # |> IO.inspect() - # |> Nx.to_binary() - # |> IO.inspect() + Nx.tensor([[1, 2], [3, 4]], type: :u32, backend: Candlex.Backend) + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() + Nx.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], type: :u32, backend: Candlex.Backend) + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() + + + Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], type: :u32, backend: Candlex.Backend) + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() # assert Nx.backend_transfer(tensor) == Nx.tensor(1, type: :u32, backend: Nx.BinaryBackend) end end From cd43dfe62560e4ee041cadbc16c23232fa7e9b34 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:19:01 -0300 Subject: [PATCH 007/185] clean --- candlex/native/candlex/src/candlex.rs | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 65524746de..09fad77320 
100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -53,20 +53,9 @@ pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { #[rustler::nif(schedule = "DirtyCpu")] pub fn from_binary(binary: Binary, shape: Term) -> ExTensor { - let slice = binary.as_slice(); + let (_prefx, slice, _suffix) = unsafe { binary.as_slice().align_to::() }; - // let slice: &[u32; 2] = unsafe { - // std::mem::transmute::<&[u8], &[u32; 2]>(binary.as_slice()) - // }; - let (_prefx, slice, _suffix) = unsafe { slice.align_to::() }; - - // println!("{:?}", &slice); - // let slice = u32::from_ne_bytes(slice); ExTensor::new(Tensor::from_slice(slice, tuple_to_vec(shape), &Device::Cpu).unwrap()) - - // let mut vec = vec![0u32; 2]; - // reader.read_u32_into::(&mut vec)?; - // ExTensor::new(Tensor::from_vec(vec, 2, &Device::Cpu).unwrap()) } pub fn load(env: Env, _info: Term) -> bool { From 56206af415cd060513bfb8234c98dacdff66038b Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:20:37 -0300 Subject: [PATCH 008/185] clean --- candlex/native/candlex/src/candlex.rs | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 09fad77320..e68d419ace 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -40,10 +40,15 @@ pub fn scalar_tensor(scalar: u32) -> ExTensor { #[rustler::nif(schedule = "DirtyCpu")] pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { - let tensor = ex_tensor.flatten_all().unwrap(); - let vec = tensor.to_vec1::().unwrap(); - - let bytes: Vec = vec.iter().flat_map(|val| val.to_ne_bytes().to_vec()).collect(); + let bytes: Vec = + ex_tensor + .flatten_all() + .unwrap() + .to_vec1::() + .unwrap() + .iter() + .flat_map(|val| val.to_ne_bytes().to_vec()) + .collect(); let mut binary = NewBinary::new(env, bytes.len()); binary.as_mut_slice().copy_from_slice(bytes.as_slice()); From 875f6699faee4d2c41597e6ea03fff6b2eb27aaa Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:21:55 -0300 Subject: [PATCH 009/185] clean --- candlex/native/candlex/src/candlex.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index e68d419ace..49e32a4941 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -24,7 +24,7 @@ impl ExTensor { } } -// Implement Deref so we can call `Tensor` functions directly from a `ExTensor` struct. +// Implement Deref so we can call `Tensor` functions directly from an `ExTensor` struct. 
impl Deref for ExTensor { type Target = Tensor; From 868bf8813f51227eee2be4ee0db4d49d4f5f0296 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:23:11 -0300 Subject: [PATCH 010/185] clean --- candlex/native/candlex/src/candlex.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 49e32a4941..4dd4896083 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -47,7 +47,7 @@ pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { .to_vec1::() .unwrap() .iter() - .flat_map(|val| val.to_ne_bytes().to_vec()) + .flat_map(|val| val.to_ne_bytes()) .collect(); let mut binary = NewBinary::new(env, bytes.len()); From fcec30cbebd9dfebd5919f987a6bbbcfda2c5300 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:24:17 -0300 Subject: [PATCH 011/185] fn order --- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/candlex.rs | 14 +++++++------- candlex/native/candlex/src/lib.rs | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 48176d418f..430467c855 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -5,8 +5,8 @@ defmodule Candlex.Native do # Rustler will override all the below stub functions with real NIFs def scalar_tensor(_scalar), do: error() - def to_binary(_tensor), do: error() def from_binary(_binary, _shape), do: error() + def to_binary(_tensor), do: error() defp error(), do: :erlang.nif_error(:nif_not_loaded) end diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 4dd4896083..dd37c2adfb 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -38,6 +38,13 @@ pub fn scalar_tensor(scalar: u32) -> ExTensor { ExTensor::new(Tensor::new(scalar, &Device::Cpu).unwrap()) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn from_binary(binary: Binary, shape: Term) -> ExTensor { + let (_prefx, slice, _suffix) = unsafe { binary.as_slice().align_to::() }; + + ExTensor::new(Tensor::from_slice(slice, tuple_to_vec(shape), &Device::Cpu).unwrap()) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { let bytes: Vec = @@ -56,13 +63,6 @@ pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { binary.into() } -#[rustler::nif(schedule = "DirtyCpu")] -pub fn from_binary(binary: Binary, shape: Term) -> ExTensor { - let (_prefx, slice, _suffix) = unsafe { binary.as_slice().align_to::() }; - - ExTensor::new(Tensor::from_slice(slice, tuple_to_vec(shape), &Device::Cpu).unwrap()) -} - pub fn load(env: Env, _info: Term) -> bool { rustler::resource!(TensorRef, env); true diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 500cee350e..444bdd0e96 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -4,8 +4,8 @@ rustler::init! 
{ "Elixir.Candlex.Native", [ candlex::scalar_tensor, - candlex::to_binary, - candlex::from_binary + candlex::from_binary, + candlex::to_binary ], load = candlex::load } From 815a7337af25a0453e5111a5f4dfe628aafa27d0 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:26:22 -0300 Subject: [PATCH 012/185] fmt --- candlex/native/candlex/src/candlex.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index dd37c2adfb..692c13b393 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -1,4 +1,4 @@ -use candle_core::{Tensor, Device}; +use candle_core::{Device, Tensor}; use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; use std::ops::Deref; @@ -47,8 +47,7 @@ pub fn from_binary(binary: Binary, shape: Term) -> ExTensor { #[rustler::nif(schedule = "DirtyCpu")] pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { - let bytes: Vec = - ex_tensor + let bytes: Vec = ex_tensor .flatten_all() .unwrap() .to_vec1::() From 6ea27bcf3f99bcdd6bd1ccb60ba22e10f0608673 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 13:56:16 -0300 Subject: [PATCH 013/185] backend_copy --- candlex/lib/candlex/backend.ex | 7 ++++++- candlex/test/candlex_test.exs | 8 +++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index e82d810830..955e611beb 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -39,6 +39,11 @@ defmodule Candlex.Backend do |> to_nx(tensor) end + # @impl true + def backend_copy(%T{} = tensor, backend, backend_options) do + backend.from_binary(tensor, to_binary(tensor), backend_options) + end + # @impl true def inspect(%T{} = tensor, inspect_opts) do limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 @@ -49,7 +54,7 @@ defmodule Candlex.Backend do |> maybe_add_signature(tensor) end - def to_binary(tensor, _limit) do + def to_binary(tensor, _limit \\ nil) do # TODO: don't ignore limit from_nx(tensor) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 565876f69f..5be0156ad7 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -4,7 +4,9 @@ defmodule CandlexTest do describe "creation" do test "tensor" do - Nx.tensor(100_002, type: :u32, backend: Candlex.Backend) + tensor = Nx.tensor(100_002, type: :u32, backend: Candlex.Backend) + + tensor |> IO.inspect() |> Nx.to_binary() |> IO.inspect() @@ -24,12 +26,12 @@ defmodule CandlexTest do |> Nx.to_binary() |> IO.inspect() - Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], type: :u32, backend: Candlex.Backend) |> IO.inspect() |> Nx.to_binary() |> IO.inspect() - # assert Nx.backend_transfer(tensor) == Nx.tensor(1, type: :u32, backend: Nx.BinaryBackend) + + assert Nx.backend_copy(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) end end end From e44868d3a1aaffcde348ffa2ca7fd4edc43fe3b0 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 14:05:27 -0300 Subject: [PATCH 014/185] backend_transfer --- candlex/lib/candlex/backend.ex | 12 ++++++++++++ candlex/test/candlex_test.exs | 1 + 2 files changed, 13 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 955e611beb..44922e0acc 100644 --- 
a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -44,6 +44,18 @@ defmodule Candlex.Backend do backend.from_binary(tensor, to_binary(tensor), backend_options) end + # @impl true + def backend_transfer(tensor, backend, backend_options) do + backend_copy(tensor, backend, backend_options) + after + backend_deallocate(tensor) + end + + # @impl true + def backend_deallocate(%T{} = _tensor) do + true + end + # @impl true def inspect(%T{} = tensor, inspect_opts) do limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 5be0156ad7..d628d56af5 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -32,6 +32,7 @@ defmodule CandlexTest do |> IO.inspect() assert Nx.backend_copy(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) + assert Nx.backend_transfer(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) end end end From 6988885a0c6ed06f87333af9a5cc8ee4ece1891a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 14:14:07 -0300 Subject: [PATCH 015/185] enable behaviour --- candlex/lib/candlex/backend.ex | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 44922e0acc..770f3ec093 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -5,14 +5,13 @@ defmodule Candlex.Backend do defstruct [:resource] - # TODO: enable behaviour - # @behaviour Nx.Backend + @behaviour Nx.Backend alias Nx.Tensor, as: T alias Candlex.Backend, as: CB alias Candlex.Native - # @impl true + @impl true def init(opts) do if opts != [] do raise ArgumentError, "Candlex.Backend accepts no options" @@ -21,7 +20,7 @@ defmodule Candlex.Backend do opts end - # @impl true + @impl true def constant(%T{shape: {}, type: _type} = tensor, scalar, _backend_options) do # TODO: Don't ignore backend options @@ -30,7 +29,7 @@ defmodule Candlex.Backend do |> to_nx(tensor) end - # @impl true + @impl true def from_binary(%T{shape: shape, type: {:u, 32}} = tensor, binary, _backend_options) do # TODO: Don't ignore backend options @@ -39,24 +38,24 @@ defmodule Candlex.Backend do |> to_nx(tensor) end - # @impl true + @impl true def backend_copy(%T{} = tensor, backend, backend_options) do backend.from_binary(tensor, to_binary(tensor), backend_options) end - # @impl true + @impl true def backend_transfer(tensor, backend, backend_options) do backend_copy(tensor, backend, backend_options) after backend_deallocate(tensor) end - # @impl true + @impl true def backend_deallocate(%T{} = _tensor) do true end - # @impl true + @impl true def inspect(%T{} = tensor, inspect_opts) do limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 @@ -66,6 +65,7 @@ defmodule Candlex.Backend do |> maybe_add_signature(tensor) end + @impl true def to_binary(tensor, _limit \\ nil) do # TODO: don't ignore limit From d36aca6ef9131571f2247be4f317c335fe729301 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:09:43 -0300 Subject: [PATCH 016/185] f64 --- candlex/lib/candlex/backend.ex | 8 ++++++-- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/candlex.rs | 27 +++++++++++++++------------ candlex/test/candlex_test.exs | 5 +++++ 4 files changed, 27 insertions(+), 15 deletions(-) diff --git 
a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 770f3ec093..b93c7f8752 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -30,11 +30,11 @@ defmodule Candlex.Backend do end @impl true - def from_binary(%T{shape: shape, type: {:u, 32}} = tensor, binary, _backend_options) do + def from_binary(%T{shape: shape, type: type} = tensor, binary, _backend_options) do # TODO: Don't ignore backend options binary - |> Native.from_binary(shape) + |> Native.from_binary(to_candle_dtype(type), shape) |> to_nx(tensor) end @@ -96,4 +96,8 @@ defmodule Candlex.Backend do defp to_nx(%{resource: ref} = backend_tensor, %T{} = t) when is_reference(ref) do %{t | data: backend_tensor} end + + defp to_candle_dtype({:u, 32}), do: "u32" + defp to_candle_dtype({:f, 32}), do: "f32" + defp to_candle_dtype({:f, 64}), do: "f64" end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 430467c855..f5d92d3c5e 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -5,7 +5,7 @@ defmodule Candlex.Native do # Rustler will override all the below stub functions with real NIFs def scalar_tensor(_scalar), do: error() - def from_binary(_binary, _shape), do: error() + def from_binary(_binary, _dtype, _shape), do: error() def to_binary(_tensor), do: error() defp error(), do: :erlang.nif_error(:nif_not_loaded) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 692c13b393..8461e7abe2 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -1,6 +1,7 @@ -use candle_core::{Device, Tensor}; +use candle_core::{Device, DType, Tensor}; use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; use std::ops::Deref; +use std::str::FromStr; pub struct TensorRef(pub Tensor); @@ -39,22 +40,24 @@ pub fn scalar_tensor(scalar: u32) -> ExTensor { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn from_binary(binary: Binary, shape: Term) -> ExTensor { - let (_prefx, slice, _suffix) = unsafe { binary.as_slice().align_to::() }; +pub fn from_binary(binary: Binary, dtype: &str, shape: Term) -> ExTensor { + let dtype = DType::from_str(dtype).unwrap(); + // let (_prefx, slice, _suffix) = unsafe { binary.as_slice().align_to::() }; - ExTensor::new(Tensor::from_slice(slice, tuple_to_vec(shape), &Device::Cpu).unwrap()) + ExTensor::new(Tensor::from_raw_buffer(binary.as_slice(), dtype, &tuple_to_vec(shape), &Device::Cpu).unwrap()) + // ExTensor::new(Tensor::from_slice(slice, tuple_to_vec(shape), &Device::Cpu).unwrap()) } #[rustler::nif(schedule = "DirtyCpu")] pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { - let bytes: Vec = ex_tensor - .flatten_all() - .unwrap() - .to_vec1::() - .unwrap() - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(); + let tensor = ex_tensor.flatten_all().unwrap(); + + let bytes: Vec = match tensor.dtype() { + DType::U32 => tensor.to_vec1::().unwrap().iter().flat_map(|val| val.to_ne_bytes()).collect(), + DType::F64 => tensor.to_vec1::().unwrap().iter().flat_map(|val| val.to_ne_bytes()).collect(), + // TODO: Support all dtypes + _ => tensor.to_vec1::().unwrap().iter().flat_map(|val| val.to_ne_bytes()).collect() + }; let mut binary = NewBinary::new(env, bytes.len()); binary.as_mut_slice().copy_from_slice(bytes.as_slice()); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index d628d56af5..9083101fef 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -33,6 +33,11 @@ 
defmodule CandlexTest do assert Nx.backend_copy(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) assert Nx.backend_transfer(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) + + Nx.tensor([-0.5, 0.88], type: :f64, backend: Candlex.Backend) + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() end end end From be7b29d38534413bdbeb714adf6f64eb9f2af659 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 15:49:19 -0300 Subject: [PATCH 017/185] f32 --- candlex/native/candlex/src/candlex.rs | 36 +++++++++++++++++++++------ candlex/test/candlex_test.exs | 5 ++++ 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 8461e7abe2..a856c507ba 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -1,4 +1,4 @@ -use candle_core::{Device, DType, Tensor}; +use candle_core::{DType, Device, Tensor}; use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; use std::ops::Deref; use std::str::FromStr; @@ -42,10 +42,11 @@ pub fn scalar_tensor(scalar: u32) -> ExTensor { #[rustler::nif(schedule = "DirtyCpu")] pub fn from_binary(binary: Binary, dtype: &str, shape: Term) -> ExTensor { let dtype = DType::from_str(dtype).unwrap(); - // let (_prefx, slice, _suffix) = unsafe { binary.as_slice().align_to::<u32>() }; - ExTensor::new(Tensor::from_raw_buffer(binary.as_slice(), dtype, &tuple_to_vec(shape), &Device::Cpu).unwrap()) - // ExTensor::new(Tensor::from_slice(slice, tuple_to_vec(shape), &Device::Cpu).unwrap()) + ExTensor::new( + Tensor::from_raw_buffer(binary.as_slice(), dtype, &tuple_to_vec(shape), &Device::Cpu) + .unwrap(), + ) } #[rustler::nif(schedule = "DirtyCpu")] pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { let tensor = ex_tensor.flatten_all().unwrap(); let bytes: Vec<u8> = match tensor.dtype() { - DType::U32 => tensor.to_vec1::<u32>().unwrap().iter().flat_map(|val| val.to_ne_bytes()).collect(), - DType::F64 => tensor.to_vec1::<f64>().unwrap().iter().flat_map(|val| val.to_ne_bytes()).collect(), + DType::U32 => tensor + .to_vec1::<u32>() + .unwrap() + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::F32 => tensor + .to_vec1::<f32>() + .unwrap() + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::F64 => tensor + .to_vec1::<f64>() + .unwrap() + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), // TODO: Support all dtypes - _ => tensor.to_vec1::().unwrap().iter().flat_map(|val| val.to_ne_bytes()).collect() + _ => tensor + .to_vec1::() + .unwrap() + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), }; let mut binary = NewBinary::new(env, bytes.len()); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 9083101fef..af07ce30db 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -34,6 +34,11 @@ defmodule CandlexTest do assert Nx.backend_copy(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) assert Nx.backend_transfer(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) + Nx.tensor([-0.5, 0.88], type: :f32, backend: Candlex.Backend) + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() + Nx.tensor([-0.5, 0.88], type: :f64, backend: Candlex.Backend) |> IO.inspect() |> Nx.to_binary() |> IO.inspect() From 26008e9f749dce3fd6c97e487f76381328541c2c Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 
15:51:34 -0300 Subject: [PATCH 018/185] u8 --- candlex/lib/candlex/backend.ex | 1 + candlex/native/candlex/src/candlex.rs | 6 ++++++ candlex/test/candlex_test.exs | 5 +++++ 3 files changed, 12 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index b93c7f8752..4615e71c4c 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -97,6 +97,7 @@ defmodule Candlex.Backend do %{t | data: backend_tensor} end + defp to_candle_dtype({:u, 8}), do: "u8" defp to_candle_dtype({:u, 32}), do: "u32" defp to_candle_dtype({:f, 32}), do: "f32" defp to_candle_dtype({:f, 64}), do: "f64" diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index a856c507ba..66d6279cab 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -54,6 +54,12 @@ pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Binary { let tensor = ex_tensor.flatten_all().unwrap(); let bytes: Vec<u8> = match tensor.dtype() { + DType::U8 => tensor + .to_vec1::<u8>() + .unwrap() + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), DType::U32 => tensor .to_vec1::<u32>() .unwrap() diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index af07ce30db..0782657b37 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -34,6 +34,11 @@ defmodule CandlexTest do assert Nx.backend_copy(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) assert Nx.backend_transfer(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) + Nx.tensor([0, 255], type: :u8, backend: Candlex.Backend) + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() + Nx.tensor([-0.5, 0.88], type: :f32, backend: Candlex.Backend) |> IO.inspect() |> Nx.to_binary() |> IO.inspect() From a04345c9774e02a3b3a7d85d0f7532fd7fb2bd0e Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 16:17:31 -0300 Subject: [PATCH 019/185] more scalar tensor types --- candlex/lib/candlex/backend.ex | 10 ++--- candlex/lib/candlex/native.ex | 1 - candlex/native/candlex/src/candlex.rs | 5 --- candlex/native/candlex/src/lib.rs | 1 - candlex/test/candlex_test.exs | 63 +++++++++------------------ 5 files changed, 25 insertions(+), 55 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 4615e71c4c..19725c58e0 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -21,12 +21,10 @@ defmodule Candlex.Backend do end @impl true - def constant(%T{shape: {}, type: _type} = tensor, scalar, _backend_options) do - # TODO: Don't ignore backend options - - scalar - |> Native.scalar_tensor() - |> to_nx(tensor) + def constant(%T{} = tensor, scalar, backend_options) do + tensor + |> Nx.BinaryBackend.constant(scalar, []) + |> Nx.BinaryBackend.backend_transfer(__MODULE__, backend_options) end @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index f5d92d3c5e..664626e9a9 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -4,7 +4,6 @@ defmodule Candlex.Native do use Rustler, otp_app: :candlex, crate: "candlex" # Rustler will override all the below stub functions with real NIFs - def scalar_tensor(_scalar), do: error() def from_binary(_binary, _dtype, _shape), do: error() def to_binary(_tensor), do: error() defp error(), do: :erlang.nif_error(:nif_not_loaded) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 66d6279cab..0bf65e6db8 100644 --- 
a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -34,11 +34,6 @@ impl Deref for ExTensor { } } -#[rustler::nif(schedule = "DirtyCpu")] -pub fn scalar_tensor(scalar: u32) -> ExTensor { - ExTensor::new(Tensor::new(scalar, &Device::Cpu).unwrap()) -} - #[rustler::nif(schedule = "DirtyCpu")] pub fn from_binary(binary: Binary, dtype: &str, shape: Term) -> ExTensor { let dtype = DType::from_str(dtype).unwrap(); diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 444bdd0e96..2ba26e864a 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -3,7 +3,6 @@ mod candlex; rustler::init! { "Elixir.Candlex.Native", [ - candlex::scalar_tensor, candlex::from_binary, candlex::to_binary ], diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 0782657b37..34f06042ae 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -4,50 +4,29 @@ defmodule CandlexTest do describe "creation" do test "tensor" do - tensor = Nx.tensor(100_002, type: :u32, backend: Candlex.Backend) - - tensor - |> IO.inspect() - |> Nx.to_binary() - |> IO.inspect() - - Nx.tensor([1, 2], type: :u32, backend: Candlex.Backend) - |> IO.inspect() - |> Nx.to_binary() - |> IO.inspect() - - Nx.tensor([[1, 2], [3, 4]], type: :u32, backend: Candlex.Backend) - |> IO.inspect() - |> Nx.to_binary() - |> IO.inspect() - - Nx.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], type: :u32, backend: Candlex.Backend) - |> IO.inspect() - |> Nx.to_binary() - |> IO.inspect() - - Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], type: :u32, backend: Candlex.Backend) - |> IO.inspect() - |> Nx.to_binary() - |> IO.inspect() - - assert Nx.backend_copy(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) - assert Nx.backend_transfer(tensor) == Nx.tensor(100_002, type: :u32, backend: Nx.BinaryBackend) + check(255, :u8) + check(100_002, :u32) + check(1.11, :f32) + check(-0.002, :f64) + check([1, 2], :u32) + check([[1, 2], [3, 4]], :u32) + check([[1, 2, 3, 4], [5, 6, 7, 8]], :u32) + check([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], :u32) + check([0, 255], :u8) + check([-0.5, 0.88], :f32) + check([-0.5, 0.88], :f64) + end + end - Nx.tensor([0, 255], type: :u8, backend: Candlex.Backend) - |> IO.inspect() - |> Nx.to_binary() - |> IO.inspect() + defp check(value, type) do + tensor = Nx.tensor(value, type: type, backend: Candlex.Backend) - Nx.tensor([-0.5, 0.88], type: :f32, backend: Candlex.Backend) - |> IO.inspect() - |> Nx.to_binary() - |> IO.inspect() + tensor + |> IO.inspect() + |> Nx.to_binary() + |> IO.inspect() - Nx.tensor([-0.5, 0.88], type: :f64, backend: Candlex.Backend) - |> IO.inspect() - |> Nx.to_binary() - |> IO.inspect() - end + assert Nx.backend_copy(tensor) == Nx.tensor(value, type: type, backend: Nx.BinaryBackend) + assert Nx.backend_transfer(tensor) == Nx.tensor(value, type: type, backend: Nx.BinaryBackend) end end From 89b84c402e62d8c92d5307c79ca20e7da496df90 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 16:18:14 -0300 Subject: [PATCH 020/185] clean --- candlex/lib/candlex/backend.ex | 2 -- 1 file changed, 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 19725c58e0..f80161203b 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -79,8 +79,6 @@ defmodule Candlex.Backend do ]) end - # defp to_candle_type({:u, 32}), do: :u32 - # defp device_option(_backend_options) do # # 
TODO: Support CUDA # :cpu From fe6f41ec83129f96658a6e775a28dad50bf5189c Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 21 Aug 2023 16:34:32 -0300 Subject: [PATCH 021/185] clean --- candlex/native/candlex/src/candlex.rs | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 0bf65e6db8..c647ad0b0b 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -35,12 +35,15 @@ impl Deref for ExTensor { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn from_binary(binary: Binary, dtype: &str, shape: Term) -> ExTensor { - let dtype = DType::from_str(dtype).unwrap(); - +pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> ExTensor { ExTensor::new( - Tensor::from_raw_buffer(binary.as_slice(), dtype, &tuple_to_vec(shape), &Device::Cpu) - .unwrap(), + Tensor::from_raw_buffer( + binary.as_slice(), + DType::from_str(dtype_str).unwrap(), + &tuple_to_vec(shape), + &Device::Cpu, + ) + .unwrap(), ) } From 55a0693e5cca514f23ba7745582d63a0f341c8ac Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 14:45:27 -0300 Subject: [PATCH 022/185] handle errors from_binary native --- candlex/lib/candlex/backend.ex | 4 ++++ candlex/native/candlex/Cargo.lock | 1 + candlex/native/candlex/Cargo.toml | 1 + candlex/native/candlex/src/candlex.rs | 20 ++++++++++++-------- candlex/native/candlex/src/error.rs | 19 +++++++++++++++++++ candlex/native/candlex/src/lib.rs | 1 + 6 files changed, 38 insertions(+), 8 deletions(-) create mode 100644 candlex/native/candlex/src/error.rs diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index f80161203b..61aa8db77a 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -33,6 +33,7 @@ defmodule Candlex.Backend do binary |> Native.from_binary(to_candle_dtype(type), shape) + |> unwrap!() |> to_nx(tensor) end @@ -97,4 +98,7 @@ defmodule Candlex.Backend do defp to_candle_dtype({:u, 32}), do: "u32" defp to_candle_dtype({:f, 32}), do: "f32" defp to_candle_dtype({:f, 64}), do: "f64" + + defp unwrap!({:ok, result}), do: result + defp unwrap!({:error, error}), do: raise("Candlex: " <> List.to_string(error)) end diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 1092a0ce56..2ae42f332e 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -186,6 +186,7 @@ version = "0.1.0" dependencies = [ "candle-core", "rustler", + "thiserror", ] [[package]] diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 35bd43a188..069ca5099b 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -12,3 +12,4 @@ crate-type = ["cdylib"] [dependencies] candle-core = "0.1.0" rustler = "0.29.1" +thiserror = "1.0.47" diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index c647ad0b0b..66b41cdd10 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -1,6 +1,8 @@ +use crate::error::CandlexError; use candle_core::{DType, Device, Tensor}; use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; use std::ops::Deref; +use std::result::Result; use std::str::FromStr; pub struct TensorRef(pub Tensor); @@ -35,15 +37,17 @@ impl Deref for ExTensor { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn 
from_binary(binary: Binary, dtype_str: &str, shape: Term) -> ExTensor { - ExTensor::new( - Tensor::from_raw_buffer( - binary.as_slice(), - DType::from_str(dtype_str).unwrap(), - &tuple_to_vec(shape), - &Device::Cpu, +pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result<ExTensor, CandlexError> { + Ok( + ExTensor::new( + Tensor::from_raw_buffer( + binary.as_slice(), + // TODO: Handle DTypeParseError + DType::from_str(dtype_str).unwrap(), + &tuple_to_vec(shape), + &Device::Cpu, + )? ) - .unwrap(), ) } diff --git a/candlex/native/candlex/src/error.rs b/candlex/native/candlex/src/error.rs new file mode 100644 index 0000000000..a4d84cfb8a --- /dev/null +++ b/candlex/native/candlex/src/error.rs @@ -0,0 +1,19 @@ +use rustler::{Encoder, Env, Term}; +use thiserror::Error; + +// Defines the atoms for each value of ExplorerError. +rustler::atoms! { + candle, +} + +#[derive(Error, Debug)] +pub enum CandlexError { + #[error("Candle Error: {0}")] + Candle(#[from] candle_core::Error) +} + +impl Encoder for CandlexError { + fn encode<'b>(&self, env: Env<'b>) -> Term<'b> { + format!("{self}").encode(env) + } +} diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 2ba26e864a..4d7a4b810f 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -1,4 +1,5 @@ mod candlex; +mod error; rustler::init! { "Elixir.Candlex.Native", From c8ec618a5f90fb884a79139ac379889730c6e237 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 15:15:41 -0300 Subject: [PATCH 023/185] handle errors to_binary native --- candlex/lib/candlex/backend.ex | 1 + candlex/native/candlex/src/candlex.rs | 74 +++++++++++++-------------- 2 files changed, 37 insertions(+), 38 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 61aa8db77a..09a913f030 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -70,6 +70,7 @@ defmodule Candlex.Backend do from_nx(tensor) |> Native.to_binary() + |> unwrap!() end defp maybe_add_signature(result, %T{data: %CB{resource: ref}}) when is_reference(ref) do diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 66b41cdd10..cb4c0d231c 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -52,47 +52,12 @@ pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result Binary { - let tensor = ex_tensor.flatten_all().unwrap(); - - let bytes: Vec<u8> = match tensor.dtype() { - DType::U8 => tensor - .to_vec1::<u8>() - .unwrap() - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(), - DType::U32 => tensor - .to_vec1::<u32>() - .unwrap()
.collect() } + +fn tensor_bytes(tensor: Tensor) -> Result<Vec<u8>, CandlexError> { + Ok( + match tensor.dtype() { + DType::U8 => tensor + .to_vec1::<u8>()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::U32 => tensor + .to_vec1::<u32>()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::F32 => tensor + .to_vec1::<f32>()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::F64 => tensor + .to_vec1::<f64>()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + // TODO: Support all dtypes + _ => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + } + ) +} From d8682840e03607c27ed659bcafa90508013ef918 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 15:54:29 -0300 Subject: [PATCH 024/185] no panic in tuple_to_vec --- candlex/native/candlex/src/candlex.rs | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index cb4c0d231c..a1a6f04f49 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -44,7 +44,8 @@ pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result bool { true } -fn tuple_to_vec(term: Term) -> Vec<usize> { - rustler::types::tuple::get_tuple(term) - .unwrap() +fn tuple_to_vec(term: Term) -> Result<Vec<usize>, rustler::Error> { + Ok( + rustler::types::tuple::get_tuple(term)? .iter() - .map(|elem| elem.decode().unwrap()) - .collect() + .map(|elem| elem.decode()) + .collect::>()? + ) } fn tensor_bytes(tensor: Tensor) -> Result<Vec<u8>, CandlexError> { From 753f55b1601eb002f500b81fa7f6aafd1fe9c1c5 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 15:55:42 -0300 Subject: [PATCH 025/185] better comment --- candlex/native/candlex/src/candlex.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index a1a6f04f49..4b9b26cb85 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -98,7 +98,7 @@ fn tensor_bytes(tensor: Tensor) -> Result<Vec<u8>, CandlexError> { .iter() .flat_map(|val| val.to_ne_bytes()) .collect(), - // TODO: Support all dtypes + // TODO: Support f16 and bf16 _ => tensor .to_vec1::()? 
.iter() From 1a6ee73ea6dc6a92a1cecf7483532a91f189f9a9 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:02:44 -0300 Subject: [PATCH 026/185] move ex tensor to mod --- candlex/native/candlex/src/candlex.rs | 35 ++------------------------- candlex/native/candlex/src/lib.rs | 3 ++- candlex/native/candlex/src/tensor.rs | 34 ++++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 34 deletions(-) create mode 100644 candlex/native/candlex/src/tensor.rs diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/candlex.rs index 4b9b26cb85..51a2798dbc 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/candlex.rs @@ -1,41 +1,10 @@ use crate::error::CandlexError; +use crate::tensor::{ExTensor, TensorRef}; use candle_core::{DType, Device, Tensor}; -use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; -use std::ops::Deref; +use rustler::{Binary, Env, NewBinary, Term}; use std::result::Result; use std::str::FromStr; -pub struct TensorRef(pub Tensor); - -#[derive(NifStruct)] -#[module = "Candlex.Backend"] -pub struct ExTensor { - pub resource: ResourceArc<TensorRef>, -} - -impl TensorRef { - pub fn new(tensor: Tensor) -> Self { - Self(tensor) - } -} - -impl ExTensor { - pub fn new(tensor: Tensor) -> Self { - Self { - resource: ResourceArc::new(TensorRef::new(tensor)), - } - } -} - -// Implement Deref so we can call `Tensor` functions directly from an `ExTensor` struct. -impl Deref for ExTensor { - type Target = Tensor; - - fn deref(&self) -> &Self::Target { - &self.resource.0 - } -} - #[rustler::nif(schedule = "DirtyCpu")] pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result<ExTensor, CandlexError> { Ok( diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 4d7a4b810f..d0923fce11 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -1,5 +1,6 @@ mod candlex; mod error; +mod tensor; mod candlex; rustler::init! { "Elixir.Candlex.Native", diff --git a/candlex/native/candlex/src/tensor.rs b/candlex/native/candlex/src/tensor.rs new file mode 100644 index 0000000000..55dc6300da --- /dev/null +++ b/candlex/native/candlex/src/tensor.rs @@ -0,0 +1,34 @@ +use candle_core::Tensor; +use rustler::{NifStruct, ResourceArc}; +use std::ops::Deref; + +pub struct TensorRef(pub Tensor); + +#[derive(NifStruct)] +#[module = "Candlex.Backend"] +pub struct ExTensor { + pub resource: ResourceArc<TensorRef>, +} + +impl TensorRef { + pub fn new(tensor: Tensor) -> Self { + Self(tensor) + } +} + +impl ExTensor { + pub fn new(tensor: Tensor) -> Self { + Self { + resource: ResourceArc::new(TensorRef::new(tensor)), + } + } +} + +// Implement Deref so we can call `Tensor` functions directly from an `ExTensor` struct. 
+impl Deref for ExTensor { + type Target = Tensor; + + fn deref(&self) -> &Self::Target { + &self.resource.0 + } +} From 8a54be4d2520a2790de7d70ee2c806318073caee Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:05:26 -0300 Subject: [PATCH 027/185] removes unnecessary pub resource --- candlex/native/candlex/src/tensor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/tensor.rs b/candlex/native/candlex/src/tensor.rs index 55dc6300da..f8c92147b1 100644 --- a/candlex/native/candlex/src/tensor.rs +++ b/candlex/native/candlex/src/tensor.rs @@ -7,7 +7,7 @@ pub struct TensorRef(pub Tensor); #[derive(NifStruct)] #[module = "Candlex.Backend"] pub struct ExTensor { - pub resource: ResourceArc<TensorRef>, + resource: ResourceArc<TensorRef>, } impl TensorRef { From b5510401e8a32739d9c3d006e405970fe0d51314 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:05:49 -0300 Subject: [PATCH 028/185] reorder structs --- candlex/native/candlex/src/tensor.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/candlex/native/candlex/src/tensor.rs b/candlex/native/candlex/src/tensor.rs index f8c92147b1..27651365e0 100644 --- a/candlex/native/candlex/src/tensor.rs +++ b/candlex/native/candlex/src/tensor.rs @@ -4,18 +4,18 @@ use std::ops::Deref; pub struct TensorRef(pub Tensor); -#[derive(NifStruct)] -#[module = "Candlex.Backend"] -pub struct ExTensor { - resource: ResourceArc<TensorRef>, -} - impl TensorRef { pub fn new(tensor: Tensor) -> Self { Self(tensor) } } +#[derive(NifStruct)] +#[module = "Candlex.Backend"] +pub struct ExTensor { + resource: ResourceArc<TensorRef>, +} + impl ExTensor { pub fn new(tensor: Tensor) -> Self { Self { From 00515b40a3663efd749464eb9a0ba896f02169db Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:06:20 -0300 Subject: [PATCH 029/185] removes unnecessary pub tensor field --- candlex/native/candlex/src/tensor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/tensor.rs b/candlex/native/candlex/src/tensor.rs index 27651365e0..4a5080ae7e 100644 --- a/candlex/native/candlex/src/tensor.rs +++ b/candlex/native/candlex/src/tensor.rs @@ -2,7 +2,7 @@ use candle_core::Tensor; use rustler::{NifStruct, ResourceArc}; use std::ops::Deref; -pub struct TensorRef(pub Tensor); +pub struct TensorRef(Tensor); impl TensorRef { From 5e08ecd13ee45090d8df94921ff8ddec324ca883 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:09:19 -0300 Subject: [PATCH 030/185] remove unnecessary pub TensorRef --- candlex/native/candlex/src/tensor.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/tensor.rs b/candlex/native/candlex/src/tensor.rs index 4a5080ae7e..1f195749f0 100644 --- a/candlex/native/candlex/src/tensor.rs +++ b/candlex/native/candlex/src/tensor.rs @@ -2,7 +2,7 @@ use candle_core::Tensor; use rustler::{NifStruct, ResourceArc}; use std::ops::Deref; -pub struct TensorRef(Tensor); +pub(crate) struct TensorRef(Tensor); impl TensorRef { From 1423ca5ffe0782d65f2ce2bdedfe536ca4d5a1c8 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:13:04 -0300 Subject: [PATCH 031/185] removes unnecessary TensorRef::new --- candlex/native/candlex/src/tensor.rs | 8 +------- 
1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/candlex/native/candlex/src/tensor.rs b/candlex/native/candlex/src/tensor.rs index 1f195749f0..4052f47c8a 100644 --- a/candlex/native/candlex/src/tensor.rs +++ b/candlex/native/candlex/src/tensor.rs @@ -4,12 +4,6 @@ use std::ops::Deref; pub(crate) struct TensorRef(Tensor); -impl TensorRef { - pub fn new(tensor: Tensor) -> Self { - Self(tensor) - } -} - #[derive(NifStruct)] #[module = "Candlex.Backend"] pub struct ExTensor { @@ -19,7 +13,7 @@ pub struct ExTensor { impl ExTensor { pub fn new(tensor: Tensor) -> Self { Self { - resource: ResourceArc::new(TensorRef::new(tensor)), + resource: ResourceArc::new(TensorRef(tensor)), } } } From b3d2a43437efce5213638ba045d9c2a4ae3c2eb0 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:25:08 -0300 Subject: [PATCH 032/185] refactor mods --- candlex/native/candlex/src/lib.rs | 17 ++++++--- candlex/native/candlex/src/tensor.rs | 28 --------------- .../candlex/src/{candlex.rs => tensors.rs} | 35 +++++++++++++++---- 3 files changed, 40 insertions(+), 40 deletions(-) delete mode 100644 candlex/native/candlex/src/tensor.rs rename candlex/native/candlex/src/{candlex.rs => tensors.rs} (78%) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index d0923fce11..cf9b002dd0 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -1,12 +1,19 @@ mod error; -mod tensor; -mod candlex; +mod tensors; + +use rustler::{Env, Term}; +use tensors::TensorRef; + +fn load(env: Env, _info: Term) -> bool { + rustler::resource!(TensorRef, env); + true +} rustler::init! { "Elixir.Candlex.Native", [ - candlex::from_binary, - candlex::to_binary + tensors::from_binary, + tensors::to_binary ], - load = candlex::load + load = load } diff --git a/candlex/native/candlex/src/tensor.rs b/candlex/native/candlex/src/tensor.rs deleted file mode 100644 index 4052f47c8a..0000000000 --- a/candlex/native/candlex/src/tensor.rs +++ /dev/null @@ -1,28 +0,0 @@ -use candle_core::Tensor; -use rustler::{NifStruct, ResourceArc}; -use std::ops::Deref; - -pub(crate) struct TensorRef(Tensor); - -#[derive(NifStruct)] -#[module = "Candlex.Backend"] -pub struct ExTensor { - resource: ResourceArc, -} - -impl ExTensor { - pub fn new(tensor: Tensor) -> Self { - Self { - resource: ResourceArc::new(TensorRef(tensor)), - } - } -} - -// Implement Deref so we can call `Tensor` functions directly from an `ExTensor` struct. 
-impl Deref for ExTensor { - type Target = Tensor; - - fn deref(&self) -> &Self::Target { - &self.resource.0 - } -} diff --git a/candlex/native/candlex/src/candlex.rs b/candlex/native/candlex/src/tensors.rs similarity index 78% rename from candlex/native/candlex/src/candlex.rs rename to candlex/native/candlex/src/tensors.rs index 51a2798dbc..a89411667f 100644 --- a/candlex/native/candlex/src/candlex.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,10 +1,35 @@ use crate::error::CandlexError; -use crate::tensor::{ExTensor, TensorRef}; use candle_core::{DType, Device, Tensor}; -use rustler::{Binary, Env, NewBinary, Term}; +use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; +use std::ops::Deref; use std::result::Result; use std::str::FromStr; +pub(crate) struct TensorRef(Tensor); + +#[derive(NifStruct)] +#[module = "Candlex.Backend"] +pub struct ExTensor { + resource: ResourceArc, +} + +impl ExTensor { + pub fn new(tensor: Tensor) -> Self { + Self { + resource: ResourceArc::new(TensorRef(tensor)), + } + } +} + +// Implement Deref so we can call `Tensor` functions directly from an `ExTensor` struct. +impl Deref for ExTensor { + type Target = Tensor; + + fn deref(&self) -> &Self::Target { + &self.resource.0 + } +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result { Ok( @@ -21,6 +46,7 @@ pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result Result { let bytes = tensor_bytes(ex_tensor.flatten_all()?)?; @@ -30,11 +56,6 @@ pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Result Ok(binary.into()) } -pub fn load(env: Env, _info: Term) -> bool { - rustler::resource!(TensorRef, env); - true -} - fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok( rustler::types::tuple::get_tuple(term)? 
From 801ceb73304bf91ea6aebfd24e666eedb6eeada8 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 22 Aug 2023 16:46:01 -0300 Subject: [PATCH 033/185] update candle-core to v1.0.2 --- candlex/native/candlex/Cargo.lock | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 2ae42f332e..6898d66bd0 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -37,9 +37,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" -version = "0.1.1" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e56d08f7794036648d7ba5448c82ab7c3f38c25e90cdc4032afd246d1292a42c" +checksum = "1a14e585d4632a3c278c03d69db4e3595801cceeb72f329da099bb687025b05c" dependencies = [ "byteorder", "candle-gemm", @@ -581,18 +581,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.183" +version = "1.0.185" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32ac8da02677876d532745a130fc9d8e6edfa81a269b107c5b00829b91d8eb3c" +checksum = "be9b6f69f1dfd54c3b568ffa45c310d6973a5e5148fd40cf515acaf38cf5bc31" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.183" +version = "1.0.185" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aafe972d60b0b9bee71a91b92fee2d4fb3c9d7e8f6b179aa99f27203d99a4816" +checksum = "dc59dfdcbad1437773485e0367fea4b090a2e0a16d9ffc46af47764536a298ec" dependencies = [ "proc-macro2", "quote", From 384d36c2c91f505b38eb54a84e5463b744f8f65d Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 23 Aug 2023 09:41:14 -0300 Subject: [PATCH 034/185] support s64 after support added in candle --- candlex/lib/candlex/backend.ex | 3 ++- candlex/native/candlex/Cargo.lock | 5 ++--- candlex/native/candlex/Cargo.toml | 2 +- candlex/native/candlex/src/tensors.rs | 6 +++++- candlex/test/candlex_test.exs | 1 + 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 09a913f030..50eb5153c2 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -95,11 +95,12 @@ defmodule Candlex.Backend do %{t | data: backend_tensor} end + defp to_candle_dtype({:s, 64}), do: "i64" defp to_candle_dtype({:u, 8}), do: "u8" defp to_candle_dtype({:u, 32}), do: "u32" defp to_candle_dtype({:f, 32}), do: "f32" defp to_candle_dtype({:f, 64}), do: "f64" defp unwrap!({:ok, result}), do: result - defp unwrap!({:error, error}), do: raise("Candlex: " <> List.to_string(error)) + defp unwrap!({:error, error}), do: raise("Candlex: #{error}") end diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 6898d66bd0..6a633beb85 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -37,9 +37,8 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a14e585d4632a3c278c03d69db4e3595801cceeb72f329da099bb687025b05c" +version = "0.1.3" +source = "git+https://github.com/huggingface/candle#aba1e90797e430f28eec13b14b76dd5355876f9c" dependencies = [ "byteorder", "candle-gemm", diff --git 
a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 069ca5099b..0ac3fed84f 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -10,6 +10,6 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -candle-core = "0.1.0" +candle-core = { git = "https://github.com/huggingface/candle" } rustler = "0.29.1" thiserror = "1.0.47" diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index a89411667f..554c25494f 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -46,7 +46,6 @@ pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result Result { let bytes = tensor_bytes(ex_tensor.flatten_all()?)?; @@ -68,6 +67,11 @@ fn tuple_to_vec(term: Term) -> Result, rustler::Error> { fn tensor_bytes(tensor: Tensor) -> Result, CandlexError> { Ok( match tensor.dtype() { + DType::I64 => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), DType::U8 => tensor .to_vec1::()? .iter() diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 34f06042ae..68de14409f 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -6,6 +6,7 @@ defmodule CandlexTest do test "tensor" do check(255, :u8) check(100_002, :u32) + check(-101, :s64) check(1.11, :f32) check(-0.002, :f64) check([1, 2], :u32) From 0e268eaf3d9da79ba760b08131a483eda74ec44d Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 23 Aug 2023 09:41:43 -0300 Subject: [PATCH 035/185] test --- candlex/test/candlex_test.exs | 1 + 1 file changed, 1 insertion(+) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 68de14409f..1a4a1bf993 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -8,6 +8,7 @@ defmodule CandlexTest do check(100_002, :u32) check(-101, :s64) check(1.11, :f32) + check([1, 2, 3], :f32) check(-0.002, :f64) check([1, 2], :u32) check([[1, 2], [3, 4]], :u32) From 6800a05a35e42d9936d554dae89e5a33305aefef Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 23 Aug 2023 10:28:52 -0300 Subject: [PATCH 036/185] support half precision float - f16 and bf16 --- candlex/lib/candlex/backend.ex | 2 ++ candlex/native/candlex/Cargo.lock | 1 + candlex/native/candlex/Cargo.toml | 1 + candlex/native/candlex/src/tensors.rs | 11 ++++++++--- candlex/test/candlex_test.exs | 4 +++- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 50eb5153c2..72190dd898 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -98,8 +98,10 @@ defmodule Candlex.Backend do defp to_candle_dtype({:s, 64}), do: "i64" defp to_candle_dtype({:u, 8}), do: "u8" defp to_candle_dtype({:u, 32}), do: "u32" + defp to_candle_dtype({:f, 16}), do: "f16" defp to_candle_dtype({:f, 32}), do: "f32" defp to_candle_dtype({:f, 64}), do: "f64" + defp to_candle_dtype({:bf, 16}), do: "bf16" defp unwrap!({:ok, result}), do: result defp unwrap!({:error, error}), do: raise("Candlex: #{error}") diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 6a633beb85..7d8031fc20 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -184,6 +184,7 @@ name = "candlex" version = "0.1.0" dependencies = [ "candle-core", + "half", "rustler", "thiserror", ] diff --git 
a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 0ac3fed84f..4a0bd9dc07 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -11,5 +11,6 @@ crate-type = ["cdylib"] [dependencies] candle-core = { git = "https://github.com/huggingface/candle" } +half = "2.3.1" rustler = "0.29.1" thiserror = "1.0.47" diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 554c25494f..0830636bae 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,5 +1,6 @@ use crate::error::CandlexError; use candle_core::{DType, Device, Tensor}; +use half::{bf16, f16}; use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; use std::ops::Deref; use std::result::Result; @@ -82,6 +83,11 @@ fn tensor_bytes(tensor: Tensor) -> Result, CandlexError> { .iter() .flat_map(|val| val.to_ne_bytes()) .collect(), + DType::F16 => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), DType::F32 => tensor .to_vec1::()? .iter() @@ -92,9 +98,8 @@ fn tensor_bytes(tensor: Tensor) -> Result, CandlexError> { .iter() .flat_map(|val| val.to_ne_bytes()) .collect(), - // TODO: Support f16 and bf16 - _ => tensor - .to_vec1::()? + DType::BF16 => tensor + .to_vec1::()? .iter() .flat_map(|val| val.to_ne_bytes()) .collect(), diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 1a4a1bf993..2c0e9eabac 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -7,7 +7,8 @@ defmodule CandlexTest do check(255, :u8) check(100_002, :u32) check(-101, :s64) - check(1.11, :f32) + check(1.16, :f16) + check(1.32, :f32) check([1, 2, 3], :f32) check(-0.002, :f64) check([1, 2], :u32) @@ -17,6 +18,7 @@ defmodule CandlexTest do check([0, 255], :u8) check([-0.5, 0.88], :f32) check([-0.5, 0.88], :f64) + check(2.16, :bf16) end end From 2b68418711bb75877d16fd6fed76a909be3d5a2f Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 23 Aug 2023 10:41:49 -0300 Subject: [PATCH 037/185] make clear which types no yet supported --- candlex/lib/candlex/backend.ex | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 72190dd898..c0a6261a06 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -95,13 +95,24 @@ defmodule Candlex.Backend do %{t | data: backend_tensor} end + defp to_candle_dtype({:s, 8}), do: unsupported_dtype() + defp to_candle_dtype({:s, 16}), do: unsupported_dtype() + defp to_candle_dtype({:s, 32}), do: unsupported_dtype() defp to_candle_dtype({:s, 64}), do: "i64" defp to_candle_dtype({:u, 8}), do: "u8" + defp to_candle_dtype({:u, 16}), do: unsupported_dtype() defp to_candle_dtype({:u, 32}), do: "u32" + defp to_candle_dtype({:u, 64}), do: unsupported_dtype() defp to_candle_dtype({:f, 16}), do: "f16" defp to_candle_dtype({:f, 32}), do: "f32" defp to_candle_dtype({:f, 64}), do: "f64" defp to_candle_dtype({:bf, 16}), do: "bf16" + defp to_candle_dtype({:c, 64}), do: unsupported_dtype() + defp to_candle_dtype({:c, 128}), do: unsupported_dtype() + + defp unsupported_dtype do + raise("Unsupported dtype") + end defp unwrap!({:ok, result}), do: result defp unwrap!({:error, error}), do: raise("Candlex: #{error}") From 2c4721d22702fb0c71db01c3dc42b68337cea808 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 23 Aug 2023 09:23:01 -0300 Subject: [PATCH 
038/185] candlex tensor add --- candlex/lib/candlex/backend.ex | 62 ++++++++++++++++++++++----- candlex/lib/candlex/native.ex | 3 ++ candlex/mix.exs | 4 ++ candlex/native/candlex/src/lib.rs | 5 ++- candlex/native/candlex/src/tensors.rs | 27 ++++++++++++ candlex/test/candlex_test.exs | 12 +++++- candlex/test/support/candlex_case.ex | 30 +++++++++++++ 7 files changed, 131 insertions(+), 12 deletions(-) create mode 100644 candlex/test/support/candlex_case.ex diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c0a6261a06..4354bb8fc5 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -20,6 +20,8 @@ defmodule Candlex.Backend do opts end + # Creation + @impl true def constant(%T{} = tensor, scalar, backend_options) do tensor @@ -37,6 +39,8 @@ defmodule Candlex.Backend do |> to_nx(tensor) end + # Backend + @impl true def backend_copy(%T{} = tensor, backend, backend_options) do backend.from_binary(tensor, to_binary(tensor), backend_options) @@ -54,15 +58,7 @@ defmodule Candlex.Backend do true end - @impl true - def inspect(%T{} = tensor, inspect_opts) do - limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 - - tensor - |> to_binary(min(limit, Nx.size(tensor))) - |> then(&Nx.Backend.inspect(tensor, &1, inspect_opts)) - |> maybe_add_signature(tensor) - end + # Conversion @impl true def to_binary(tensor, _limit \\ nil) do @@ -73,6 +69,48 @@ defmodule Candlex.Backend do |> unwrap!() end + # Aggregates + + @impl true + def all(%T{} = out, %T{} = tensor, _opts) do + from_nx(tensor) + |> Native.all() + |> unwrap!() + |> to_nx(out) + end + + # Binary ops + + @impl true + def add(%T{} = out, %T{} = left, %T{} = right) do + from_nx(left) + |> Native.add(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def equal(%T{} = out, %T{} = left, %T{} = right) do + from_nx(left) + |> Native.eq(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + + # Unary ops + + # Inspect + + @impl true + def inspect(%T{} = tensor, inspect_opts) do + limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 + + tensor + |> to_binary(min(limit, Nx.size(tensor))) + |> then(&Nx.Backend.inspect(tensor, &1, inspect_opts)) + |> maybe_add_signature(tensor) + end + defp maybe_add_signature(result, %T{data: %CB{resource: ref}}) when is_reference(ref) do Inspect.Algebra.concat([ "Candlex.Backend(#{:erlang.ref_to_list(ref)})", @@ -111,7 +149,11 @@ defmodule Candlex.Backend do defp to_candle_dtype({:c, 128}), do: unsupported_dtype() defp unsupported_dtype do - raise("Unsupported dtype") + raise("Unsupported candle dtype") + end + + defp unsupported_op do + raise("Unsupported candle op") end defp unwrap!({:ok, result}), do: result diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 664626e9a9..dbe9ef5e04 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -6,6 +6,9 @@ defmodule Candlex.Native do # Rustler will override all the below stub functions with real NIFs def from_binary(_binary, _dtype, _shape), do: error() def to_binary(_tensor), do: error() + def add(_left, _right), do: error() + def eq(_left, _right), do: error() + def all(_tensor), do: error() defp error(), do: :erlang.nif_error(:nif_not_loaded) end diff --git a/candlex/mix.exs b/candlex/mix.exs index d593fcff6f..ffd08acef8 100644 --- a/candlex/mix.exs +++ b/candlex/mix.exs @@ -6,6 +6,7 @@ defmodule Candlex.MixProject do app: :candlex, version: "0.1.0", elixir: "~> 1.15", + 
elixirc_paths: elixirc_paths(Mix.env()), start_permanent: Mix.env() == :prod, deps: deps() ] @@ -18,6 +19,9 @@ defmodule Candlex.MixProject do ] end + defp elixirc_paths(:test), do: ["lib", "test/support"] + defp elixirc_paths(_), do: ["lib"] + # Run "mix help deps" to learn about dependencies. defp deps do [ diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index cf9b002dd0..76cb4d29e9 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -13,7 +13,10 @@ rustler::init! { "Elixir.Candlex.Native", [ tensors::from_binary, - tensors::to_binary + tensors::to_binary, + tensors::add, + tensors::eq, + tensors::all ], load = load } diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 0830636bae..505d95d791 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -56,6 +56,33 @@ pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Result Ok(binary.into()) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn add(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.add(&right)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn eq(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.eq(&right)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn all(ex_tensor: ExTensor) -> Result { + let device = &Device::Cpu; + let t = ex_tensor.flatten_all()?; + let dims = t.shape().dims(); + let on_true = Tensor::ones(dims, DType::U8, device)?; + let on_false = Tensor::zeros(dims, DType::U8, device)?; + + let bool_scalar = + match t.where_cond(&on_true, &on_false)?.min(0)?.to_scalar::()? { + 0 => 0u8, + _ => 1u8 + }; + + Ok(ExTensor::new(Tensor::new(bool_scalar, device)?)) +} + fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok( rustler::types::tuple::get_tuple(term)? diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 2c0e9eabac..980775056b 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1,5 +1,5 @@ defmodule CandlexTest do - use ExUnit.Case, async: true + use Candlex.Case, async: true doctest Candlex describe "creation" do @@ -20,6 +20,16 @@ defmodule CandlexTest do check([-0.5, 0.88], :f64) check(2.16, :bf16) end + + test "addition" do + t([1, 2, 3]) + |> Nx.add(t([10, 20, 30])) + |> assert_equal(t([11, 22, 33])) + end + end + + defp t(values, backend \\ Candlex.Backend) do + Nx.tensor(values, type: :u32, backend: backend) end defp check(value, type) do diff --git a/candlex/test/support/candlex_case.ex b/candlex/test/support/candlex_case.ex new file mode 100644 index 0000000000..d90b66eeef --- /dev/null +++ b/candlex/test/support/candlex_case.ex @@ -0,0 +1,30 @@ +defmodule Candlex.Case do + @moduledoc """ + Test case for tensor assertions + """ + + use ExUnit.CaseTemplate + + using do + quote do + import Candlex.Case + end + end + + def assert_equal(left, right) do + equals = + left + |> Nx.equal(right) + # |> Nx.logical_or(Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))) + |> Nx.all() + |> Nx.to_number() + + if equals != 1 do + flunk(""" + Tensor assertion failed. 
+ left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end +end From 010017e998a0f4d6d18e10f27ad982eb8cc2af12 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 23 Aug 2023 19:46:11 -0300 Subject: [PATCH 039/185] candlex Nx.multiply --- candlex/lib/candlex/backend.ex | 22 ++++++++-------------- candlex/lib/candlex/native.ex | 6 ++++-- candlex/native/candlex/src/lib.rs | 3 ++- candlex/native/candlex/src/tensors.rs | 7 ++++++- candlex/test/candlex_test.exs | 14 ++++++++++++++ 5 files changed, 34 insertions(+), 18 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 4354bb8fc5..0a23a1316e 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -81,20 +81,14 @@ defmodule Candlex.Backend do # Binary ops - @impl true - def add(%T{} = out, %T{} = left, %T{} = right) do - from_nx(left) - |> Native.add(from_nx(right)) - |> unwrap!() - |> to_nx(out) - end - - @impl true - def equal(%T{} = out, %T{} = left, %T{} = right) do - from_nx(left) - |> Native.eq(from_nx(right)) - |> unwrap!() - |> to_nx(out) + for op <- [:add, :equal, :multiply] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + from_nx(left) + |> Native.unquote(op)(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end end # Unary ops diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index dbe9ef5e04..d1aad8a0df 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -6,9 +6,11 @@ defmodule Candlex.Native do # Rustler will override all the below stub functions with real NIFs def from_binary(_binary, _dtype, _shape), do: error() def to_binary(_tensor), do: error() - def add(_left, _right), do: error() - def eq(_left, _right), do: error() def all(_tensor), do: error() + for op <- [:add, :equal, :multiply] do + def unquote(op)(_left, _right), do: error() + end + defp error(), do: :erlang.nif_error(:nif_not_loaded) end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 76cb4d29e9..abb8afaaa5 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -15,7 +15,8 @@ rustler::init! 
{ tensors::from_binary, tensors::to_binary, tensors::add, - tensors::eq, + tensors::multiply, + tensors::equal, tensors::all ], load = load diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 505d95d791..ad81b147ec 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -62,7 +62,12 @@ pub fn add(left: ExTensor, right: ExTensor) -> Result { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn eq(left: ExTensor, right: ExTensor) -> Result { +pub fn multiply(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.broadcast_mul(&right)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn equal(left: ExTensor, right: ExTensor) -> Result { Ok(ExTensor::new(left.eq(&right)?)) } diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 980775056b..ccb9f9cccb 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -26,6 +26,20 @@ defmodule CandlexTest do |> Nx.add(t([10, 20, 30])) |> assert_equal(t([11, 22, 33])) end + + test "multiply" do + t([1, 2]) + |> Nx.multiply(t([3, 4])) + |> assert_equal(t([3, 8])) + + t([[1], [2]]) + |> Nx.multiply(t([3, 4])) + |> assert_equal(t([[3, 4], [6, 8]])) + + t([1, 2]) + |> Nx.multiply(t([[3], [4]])) + |> assert_equal(t([[3, 6], [4, 8]])) + end end defp t(values, backend \\ Candlex.Backend) do From 1ef12449826f56be3d2113b79cf9a952a127ded6 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 24 Aug 2023 11:03:57 -0300 Subject: [PATCH 040/185] candlex tensors access --- candlex/lib/candlex/backend.ex | 49 +++++++++++++++++++++++++++ candlex/lib/candlex/native.ex | 2 ++ candlex/native/candlex/src/lib.rs | 4 ++- candlex/native/candlex/src/tensors.rs | 10 ++++++ candlex/test/candlex_test.exs | 7 ++++ 5 files changed, 71 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 0a23a1316e..425a3d181b 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -93,6 +93,38 @@ defmodule Candlex.Backend do # Unary ops + # Indexed + + @impl true + def slice( + %T{shape: _output_shape} = out, + %T{shape: input_shape} = t, + starts, + lengths, + _strides + ) do + t + |> from_nx() + |> narrow(starts, lengths, 0, input_shape) + # TODO: Support strides + # |> stride(output_shape, lengths, strides) + |> to_nx(out) + end + + # Shape + + @impl true + def squeeze(%T{} = out, %T{} = t, axes) do + # sort the axes desc so we don't have to decrease the axis numbers after each squeeze + for axis <- Enum.sort(axes, :desc), reduce: from_nx(t) do + ref -> + ref + |> Native.squeeze(axis) + |> unwrap!() + end + |> to_nx(out) + end + # Inspect @impl true @@ -118,6 +150,23 @@ defmodule Candlex.Backend do # :cpu # end + defp narrow(t, [start | starts], [length | lengths], axis, shape) do + dim = elem(shape, axis) + start = min(start, dim - length) + + if start == 0 and length == dim do + # Nothing to narrow at this step + t + else + t + |> Native.narrow(axis, start, length) + |> unwrap!() + end + |> narrow(starts, lengths, axis + 1, shape) + end + + defp narrow(t, [], [], _axis, _shape), do: t + ## Conversions @doc false diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index d1aad8a0df..836a83d56c 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -7,6 +7,8 @@ defmodule Candlex.Native do def from_binary(_binary, _dtype, _shape), do: error() def to_binary(_tensor), do: error() def 
all(_tensor), do: error() + def narrow(_tensor, _dim, _start, _length), do: error() + def squeeze(_tensor, _dim), do: error() for op <- [:add, :equal, :multiply] do def unquote(op)(_left, _right), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index abb8afaaa5..e8e5a34662 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -17,7 +17,9 @@ rustler::init! { tensors::add, tensors::multiply, tensors::equal, - tensors::all + tensors::all, + tensors::narrow, + tensors::squeeze ], load = load } diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index ad81b147ec..9c6ed68cc6 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -71,6 +71,16 @@ pub fn equal(left: ExTensor, right: ExTensor) -> Result Ok(ExTensor::new(left.eq(&right)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn narrow(t: ExTensor, dim: usize, start: usize, length: usize) -> Result { + Ok(ExTensor::new(t.narrow(dim, start, length)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn squeeze(t: ExTensor, dim: usize) -> Result { + Ok(ExTensor::new(t.squeeze(dim)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn all(ex_tensor: ExTensor) -> Result { let device = &Device::Cpu; diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index ccb9f9cccb..0adadbcbd0 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -40,6 +40,13 @@ defmodule CandlexTest do |> Nx.multiply(t([[3], [4]])) |> assert_equal(t([[3, 6], [4, 8]])) end + + test "access" do + tensor = t([[1, 2], [3, 4]]) + + assert_equal(tensor[0], t([1, 2])) + assert_equal(tensor[1], t([3, 4])) + end end defp t(values, backend \\ Candlex.Backend) do From 2bdd6722f32960905dd18399c111f49d15cca3b6 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 24 Aug 2023 14:23:34 -0300 Subject: [PATCH 041/185] candlex iota --- candlex/lib/candlex/backend.ex | 12 ++++++++++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 3 ++- candlex/native/candlex/src/tensors.rs | 5 +++++ candlex/test/candlex_test.exs | 20 +++++++++++++++++++- 5 files changed, 39 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 425a3d181b..01503e10ce 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -39,6 +39,18 @@ defmodule Candlex.Backend do |> to_nx(tensor) end + @impl true + def iota(%T{shape: {}} = out, nil, backend_options) do + constant(out, 0, backend_options) + end + + def iota(%T{shape: shape} = out, nil, _backend_options) do + # TODO: Support different types + Native.arange(0, Nx.size(shape), shape) + |> unwrap!() + |> to_nx(out) + end + # Backend @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 836a83d56c..d4eebb761d 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -9,6 +9,7 @@ defmodule Candlex.Native do def all(_tensor), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def squeeze(_tensor, _dim), do: error() + def arange(_start, _end, _shape), do: error() for op <- [:add, :equal, :multiply] do def unquote(op)(_left, _right), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index e8e5a34662..13bd3693fa 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -19,7 
+19,8 @@ rustler::init! { tensors::equal, tensors::all, tensors::narrow, - tensors::squeeze + tensors::squeeze, + tensors::arange ], load = load } diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 9c6ed68cc6..df05ca42d7 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -81,6 +81,11 @@ pub fn squeeze(t: ExTensor, dim: usize) -> Result { Ok(ExTensor::new(t.squeeze(dim)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn arange(start: i64, end: i64, shape: Term) -> Result { + Ok(ExTensor::new(Tensor::arange(start, end, &Device::Cpu)?.reshape(tuple_to_vec(shape).unwrap())?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn all(ex_tensor: ExTensor) -> Result { let device = &Device::Cpu; diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 0adadbcbd0..e3bf7e8ed0 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -27,6 +27,24 @@ defmodule CandlexTest do |> assert_equal(t([11, 22, 33])) end + test "iota" do + Nx.iota({}) + |> assert_equal(t(0)) + + Nx.iota({}, type: :f32) + |> assert_equal(t(0.0)) + + Nx.iota({5}) + |> assert_equal(t([0, 1, 2, 3, 4])) + + # TODO: Support iota with float + # Nx.iota({5}, type: :f32) + # |> assert_equal(t([0.0, 1.0, 2.0, 3.0, 4.0])) + + Nx.iota({2, 3}) + |> assert_equal(t([[0, 1, 2], [3, 4, 5]])) + end + test "multiply" do t([1, 2]) |> Nx.multiply(t([3, 4])) @@ -50,7 +68,7 @@ defmodule CandlexTest do end defp t(values, backend \\ Candlex.Backend) do - Nx.tensor(values, type: :u32, backend: backend) + Nx.tensor(values, backend: backend) end defp check(value, type) do From 4fd332825499c61f789fb5851f8d57bf9316c540 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 24 Aug 2023 14:50:28 -0300 Subject: [PATCH 042/185] named dimensions --- candlex/test/candlex_test.exs | 57 ++++++++++++++++++++++------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index e3bf7e8ed0..13df788226 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -4,21 +4,28 @@ defmodule CandlexTest do describe "creation" do test "tensor" do - check(255, :u8) - check(100_002, :u32) - check(-101, :s64) - check(1.16, :f16) - check(1.32, :f32) - check([1, 2, 3], :f32) - check(-0.002, :f64) - check([1, 2], :u32) - check([[1, 2], [3, 4]], :u32) - check([[1, 2, 3, 4], [5, 6, 7, 8]], :u32) - check([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], :u32) - check([0, 255], :u8) - check([-0.5, 0.88], :f32) - check([-0.5, 0.88], :f64) - check(2.16, :bf16) + check(255, type: :u8) + check(100_002, type: :u32) + check(-101, type: :s64) + check(1.16, type: :f16) + check(1.32, type: :f32) + check([1, 2, 3], type: :f32) + check(-0.002, type: :f64) + check([1, 2], type: :u32) + check([[1, 2], [3, 4]], type: :u32) + check([[1, 2, 3, 4], [5, 6, 7, 8]], type: :u32) + check([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], type: :u32) + check([0, 255], type: :u8) + check([-0.5, 0.88], type: :f32) + check([-0.5, 0.88], type: :f64) + check(2.16, type: :bf16) + end + + test "named dimensions" do + check([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) + + t([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) + |> assert_equal(t([[1, 2, 3], [4, 5, 6]])) end test "addition" do @@ -67,19 +74,27 @@ defmodule CandlexTest do end end - defp t(values, backend \\ Candlex.Backend) do - Nx.tensor(values, backend: backend) + defp t(values, opts \\ []) do + opts = + [backend: 
Candlex.Backend] + |> Keyword.merge(opts) + + Nx.tensor(values, opts) end - defp check(value, type) do - tensor = Nx.tensor(value, type: type, backend: Candlex.Backend) + defp check(value, opts \\ []) do + tensor = t(value, opts) tensor |> IO.inspect() |> Nx.to_binary() |> IO.inspect() - assert Nx.backend_copy(tensor) == Nx.tensor(value, type: type, backend: Nx.BinaryBackend) - assert Nx.backend_transfer(tensor) == Nx.tensor(value, type: type, backend: Nx.BinaryBackend) + opts = + [backend: Nx.BinaryBackend] + |> Keyword.merge(opts) + + assert Nx.backend_copy(tensor) == t(value, opts) + assert Nx.backend_transfer(tensor) == t(value, opts) end end From 42508353c0ad86147481df4fadb5fc5124caf4d7 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 28 Aug 2023 13:50:50 -0300 Subject: [PATCH 043/185] tril/u --- candlex/lib/candlex/backend.ex | 60 +++++++++++++++++++++++++-- candlex/lib/candlex/native.ex | 5 ++- candlex/native/candlex/src/lib.rs | 7 +++- candlex/native/candlex/src/tensors.rs | 25 +++++++++++ candlex/test/candlex_test.exs | 17 ++++++++ 5 files changed, 109 insertions(+), 5 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 01503e10ce..fd1a4a72da 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -91,13 +91,40 @@ defmodule Candlex.Backend do |> to_nx(out) end + # Element-wise + + @impl true + def select(%T{shape: shape} = out, pred, on_true, on_false) do + on_true = + on_true + |> Nx.as_type(Nx.type(out)) + |> from_nx() + |> Native.broadcast_to(shape) + |> unwrap!() + + on_false = + on_false + |> Nx.as_type(Nx.type(out)) + |> from_nx() + |> Native.broadcast_to(shape) + |> unwrap!() + + pred + |> from_nx() + |> Native.where_cond(on_true, on_false) + |> unwrap!() + |> to_nx(out) + end + # Binary ops - for op <- [:add, :equal, :multiply] do + for op <- [:add, :equal, :greater_equal, :multiply, :subtract] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do - from_nx(left) - |> Native.unquote(op)(from_nx(right)) + {left, right} = maybe_broadcast_bin_args(out.shape, left, right) + + left + |> Native.unquote(op)(right) |> unwrap!() |> to_nx(out) end @@ -125,6 +152,14 @@ defmodule Candlex.Backend do # Shape + @impl true + def as_type(%T{type: type} = out, %T{} = t) do + from_nx(t) + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> to_nx(out) + end + @impl true def squeeze(%T{} = out, %T{} = t, axes) do # sort the axes desc so we don't have to decrease the axis numbers after each squeeze @@ -179,6 +214,25 @@ defmodule Candlex.Backend do defp narrow(t, [], [], _axis, _shape), do: t + defp maybe_broadcast_bin_args(_out_shape, %{shape: {}} = l, r), do: {from_nx(l), from_nx(r)} + defp maybe_broadcast_bin_args(_out_shape, l, %{shape: {}} = r), do: {from_nx(l), from_nx(r)} + + defp maybe_broadcast_bin_args(out_shape, l, r) do + { + case l.shape do + ^out_shape -> + from_nx(l) + + _ -> + l |> from_nx() |> Native.broadcast_to(out_shape) |> unwrap!() + end, + case r.shape do + ^out_shape -> from_nx(r) + _ -> r |> from_nx() |> Native.broadcast_to(out_shape) |> unwrap!() + end + } + end + ## Conversions @doc false diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index d4eebb761d..f6e31c2cac 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -7,11 +7,14 @@ defmodule Candlex.Native do def from_binary(_binary, _dtype, _shape), do: error() def to_binary(_tensor), do: error() def all(_tensor), do: error() 
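  # The stubs below are only reached if the NIF library fails to load; at load
  # time Rustler swaps in the Rust functions registered in
  # native/candlex/src/lib.rs. `where_cond` mirrors candle's `Tensor::where_cond`
  # and backs `Nx.select/3` in Candlex.Backend.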
+ def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def squeeze(_tensor, _dim), do: error() def arange(_start, _end, _shape), do: error() + def broadcast_to(_tensor, _shape), do: error() + def to_type(_tensor, _dtype), do: error() - for op <- [:add, :equal, :multiply] do + for op <- [:add, :equal, :greater_equal, :multiply, :subtract] do def unquote(op)(_left, _right), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 13bd3693fa..b29a9ed46d 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -17,10 +17,15 @@ rustler::init! { tensors::add, tensors::multiply, tensors::equal, + tensors::greater_equal, + tensors::subtract, tensors::all, + tensors::where_cond, tensors::narrow, tensors::squeeze, - tensors::arange + tensors::arange, + tensors::to_type, + tensors::broadcast_to ], load = load } diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index df05ca42d7..1033dce410 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -71,6 +71,16 @@ pub fn equal(left: ExTensor, right: ExTensor) -> Result Ok(ExTensor::new(left.eq(&right)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn greater_equal(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.ge(&right)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn subtract(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.broadcast_sub(&right)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn narrow(t: ExTensor, dim: usize, start: usize, length: usize) -> Result { Ok(ExTensor::new(t.narrow(dim, start, length)?)) @@ -103,6 +113,21 @@ pub fn all(ex_tensor: ExTensor) -> Result { Ok(ExTensor::new(Tensor::new(bool_scalar, device)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn broadcast_to(t: ExTensor, shape: Term) -> Result { + Ok(ExTensor::new(t.broadcast_as(tuple_to_vec(shape).unwrap())?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn where_cond(t: ExTensor, on_true: ExTensor, on_false: ExTensor) -> Result { + Ok(ExTensor::new(t.where_cond(&on_true, &on_false)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn to_type(t: ExTensor, dtype_str: &str) -> Result { + Ok(ExTensor::new(t.to_dtype(DType::from_str(dtype_str).unwrap())?)) +} + fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok( rustler::types::tuple::get_tuple(term)? 
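The `where_cond`, `broadcast_to` and comparison NIFs added above are what back `Nx.select/3` — and, via Nx's default implementations, `Nx.tril/1` and `Nx.triu/1` — on this backend. A minimal usage sketch, assuming the project at this commit compiles and is started with `iex -S mix`; the expected value follows from `Nx.select/3` semantics rather than from output recorded in this patch:

```elixir
pred     = Nx.tensor([[1, 0], [0, 1]], type: :u8, backend: Candlex.Backend)
on_true  = Nx.tensor([[10, 10], [10, 10]], backend: Candlex.Backend)
on_false = Nx.tensor([[0, 0], [0, 0]], backend: Candlex.Backend)

# Candlex.Backend.select/4 broadcasts both branches to the output shape and
# then calls Candlex.Native.where_cond/3 (candle's Tensor::where_cond).
Nx.select(pred, on_true, on_false)
# should yield [[10, 0], [0, 10]]
```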
diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 13df788226..b27ef4543b 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -28,6 +28,23 @@ defmodule CandlexTest do |> assert_equal(t([[1, 2, 3], [4, 5, 6]])) end + test "tensor tensor" do + t(t([1, 2, 3])) + |> assert_equal(t([1, 2, 3])) + end + + test "tril" do + t([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + |> Nx.tril() + |> assert_equal(t([[1, 0, 0], [4, 5, 0], [7, 8, 9]])) + end + + test "triu" do + t([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + |> Nx.triu() + |> assert_equal(t([[1, 2, 3], [0, 5, 6], [0, 0, 9]])) + end + test "addition" do t([1, 2, 3]) |> Nx.add(t([10, 20, 30])) From 2c4f4120f9a69f539e6b8fddee4c2065c87b0111 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 28 Aug 2023 15:08:48 -0300 Subject: [PATCH 044/185] candlex max/min --- candlex/lib/candlex/backend.ex | 31 ++++++++++++++++++++++++++- candlex/lib/candlex/native.ex | 4 +++- candlex/native/candlex/src/lib.rs | 6 +++++- candlex/native/candlex/src/tensors.rs | 21 ++++++++++++++++++ candlex/test/candlex_test.exs | 24 +++++++++++++++++++++ 5 files changed, 83 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index fd1a4a72da..66c9f078c7 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -118,9 +118,10 @@ defmodule Candlex.Backend do # Binary ops - for op <- [:add, :equal, :greater_equal, :multiply, :subtract] do + for op <- [:add, :equal, :greater_equal, :max, :min, :multiply, :subtract] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_upcast(left, right) {left, right} = maybe_broadcast_bin_args(out.shape, left, right) left @@ -150,8 +151,27 @@ defmodule Candlex.Backend do |> to_nx(out) end + # N-dim + + @impl true + def concatenate(%T{} = out, tensors, axis) do + tensors + |> Enum.map(&from_nx/1) + |> Native.concatenate(axis) + |> unwrap!() + |> to_nx(out) + end + # Shape + @impl true + def reshape(%T{shape: shape} = out, %T{} = t) do + from_nx(t) + |> Native.reshape(shape) + |> unwrap!() + |> to_nx(out) + end + @impl true def as_type(%T{type: type} = out, %T{} = t) do from_nx(t) @@ -214,6 +234,15 @@ defmodule Candlex.Backend do defp narrow(t, [], [], _axis, _shape), do: t + defp maybe_upcast(%T{type: t} = left, %T{type: t} = right) do + {left, right} + end + defp maybe_upcast(left, right) do + type = Nx.Type.merge(left.type, right.type) + + {Nx.as_type(left, type), Nx.as_type(right, type)} + end + defp maybe_broadcast_bin_args(_out_shape, %{shape: {}} = l, r), do: {from_nx(l), from_nx(r)} defp maybe_broadcast_bin_args(_out_shape, l, %{shape: {}} = r), do: {from_nx(l), from_nx(r)} diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index f6e31c2cac..ad90f3cc4e 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -12,9 +12,11 @@ defmodule Candlex.Native do def squeeze(_tensor, _dim), do: error() def arange(_start, _end, _shape), do: error() def broadcast_to(_tensor, _shape), do: error() + def reshape(_tensor, _shape), do: error() def to_type(_tensor, _dtype), do: error() + def concatenate(_tensors, _axis), do: error() - for op <- [:add, :equal, :greater_equal, :multiply, :subtract] do + for op <- [:add, :equal, :greater_equal, :max, :min, :multiply, :subtract] do def unquote(op)(_left, _right), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 
b29a9ed46d..59ac3b94d3 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -15,6 +15,8 @@ rustler::init! { tensors::from_binary, tensors::to_binary, tensors::add, + tensors::max, + tensors::min, tensors::multiply, tensors::equal, tensors::greater_equal, @@ -25,7 +27,9 @@ rustler::init! { tensors::squeeze, tensors::arange, tensors::to_type, - tensors::broadcast_to + tensors::broadcast_to, + tensors::reshape, + tensors::concatenate ], load = load } diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 1033dce410..6dbf9fef78 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -61,6 +61,16 @@ pub fn add(left: ExTensor, right: ExTensor) -> Result { Ok(ExTensor::new(left.add(&right)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn max(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.broadcast_maximum(&right)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn min(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.broadcast_minimum(&right)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn multiply(left: ExTensor, right: ExTensor) -> Result { Ok(ExTensor::new(left.broadcast_mul(&right)?)) @@ -118,6 +128,11 @@ pub fn broadcast_to(t: ExTensor, shape: Term) -> Result Ok(ExTensor::new(t.broadcast_as(tuple_to_vec(shape).unwrap())?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn reshape(t: ExTensor, shape: Term) -> Result { + Ok(ExTensor::new(t.reshape(tuple_to_vec(shape).unwrap())?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn where_cond(t: ExTensor, on_true: ExTensor, on_false: ExTensor) -> Result { Ok(ExTensor::new(t.where_cond(&on_true, &on_false)?)) @@ -128,6 +143,12 @@ pub fn to_type(t: ExTensor, dtype_str: &str) -> Result { Ok(ExTensor::new(t.to_dtype(DType::from_str(dtype_str).unwrap())?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn concatenate(ex_tensors: Vec, dim: usize) -> Result { + let tensors = ex_tensors.iter().map(|t| t.deref()).collect::>(); + Ok(ExTensor::new(Tensor::cat(&tensors[..], dim)?)) +} + fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok( rustler::types::tuple::get_tuple(term)? 
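With `Tensor::cat` exposed above, `Nx.concatenate/2` works end to end on this backend. A small illustrative sketch — it assumes the crate at this commit builds and the session is started with `iex -S mix`; the result shown is what `Nx.concatenate/2` is expected to return, not output copied from the patch:

```elixir
a = Nx.tensor([[1, 2], [3, 4]], backend: Candlex.Backend)
b = Nx.tensor([[5, 6]], backend: Candlex.Backend)

# Candlex.Backend.concatenate/3 maps both tensors to native handles and calls
# Candlex.Native.concatenate/2, i.e. candle's Tensor::cat, along dim 0.
Nx.concatenate([a, b], axis: 0)
# should yield [[1, 2], [3, 4], [5, 6]] with shape {3, 2}
```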
diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index b27ef4543b..8bf2502b50 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -69,6 +69,30 @@ defmodule CandlexTest do |> assert_equal(t([[0, 1, 2], [3, 4, 5]])) end + test "max" do + Nx.max(1, 2) + |> assert_equal(t(2)) + + Nx.max(1, t([1.0, 2.0, 3.0], names: [:data])) + |> assert_equal(t([1.0, 2.0, 3.0])) + + t([[1], [2]], type: :f32, names: [:x, nil]) + |> Nx.max(t([[10, 20]], type: :f32, names: [nil, :y])) + |> assert_equal(t([[10.0, 20.0], [10.0, 20.0]])) + end + + test "min" do + Nx.min(1, 2) + |> assert_equal(t(1)) + + Nx.min(1, t([1.0, 2.0, 3.0], names: [:data])) + |> assert_equal(t([1.0, 1.0, 1.0])) + + t([[1], [2]], type: :f32, names: [:x, nil]) + |> Nx.min(t([[10, 20]], type: :f32, names: [nil, :y])) + |> assert_equal(t([[1.0, 1.0], [2.0, 2.0]])) + end + test "multiply" do t([1, 2]) |> Nx.multiply(t([3, 4])) From c11408a3eb4bc824792a74df471c2a2ff15fdf72 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 28 Aug 2023 17:15:19 -0300 Subject: [PATCH 045/185] candlex broadcast --- candlex/lib/candlex/backend.ex | 37 ++++++++++++++++++++++++++++++++++ candlex/test/candlex_test.exs | 9 +++++++++ 2 files changed, 46 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 66c9f078c7..dfcdb83888 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -164,6 +164,16 @@ defmodule Candlex.Backend do # Shape + @impl true + def broadcast(out, %T{} = t, shape, axes) do + t + |> maybe_reshape(shape, axes) + |> from_nx() + |> Native.broadcast_to(shape) + |> unwrap!() + |> to_nx(out) + end + @impl true def reshape(%T{shape: shape} = out, %T{} = t) do from_nx(t) @@ -234,6 +244,33 @@ defmodule Candlex.Backend do defp narrow(t, [], [], _axis, _shape), do: t + defp maybe_reshape(%T{shape: {}} = t, target_shape, _axes) do + shape = + 1 + |> List.duplicate(tuple_size(target_shape)) + |> List.to_tuple() + + t + |> Nx.reshape(shape) + end + + defp maybe_reshape(%T{shape: shape} = t, target_shape, axes) do + base_broadcast_shape = 1 |> List.duplicate(tuple_size(target_shape)) |> List.to_tuple() + + new_shape = + shape + |> Tuple.to_list() + |> Enum.zip(axes) + |> Enum.reduce(base_broadcast_shape, fn {dim_size, target_axis}, shape_acc -> + shape_acc + |> Tuple.delete_at(target_axis) + |> Tuple.insert_at(target_axis, dim_size) + end) + + t + |> Nx.reshape(new_shape) + end + defp maybe_upcast(%T{type: t} = left, %T{type: t} = right) do {left, right} end diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 8bf2502b50..8faccf1dbb 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -107,6 +107,15 @@ defmodule CandlexTest do |> assert_equal(t([[3, 6], [4, 8]])) end + test "broadcast" do + Nx.broadcast(1, {1, 2, 3}) + |> assert_equal(t([[[1, 1, 1], [1, 1, 1]]])) + + t([1, 2, 3]) + |> Nx.broadcast({3, 2}, axes: [0]) + |> assert_equal(t([[1, 1], [2, 2], [3, 3]])) + end + test "access" do tensor = t([[1, 2], [3, 4]]) From 2f8f9bdccc8a42ded4c2730318a151619d23a8fb Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:32:00 -0300 Subject: [PATCH 046/185] improve from_nx --- candlex/lib/candlex/backend.ex | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index dfcdb83888..f26188732e 100644 --- 
a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -302,7 +302,12 @@ defmodule Candlex.Backend do ## Conversions @doc false - defp from_nx(%T{data: data}), do: data + defp from_nx(%T{data: %CB{} = data}), do: data + defp from_nx(%T{} = tensor) do + tensor + |> Nx.backend_transfer(CB) + |> from_nx() + end defp to_nx(%{resource: ref} = backend_tensor, %T{} = t) when is_reference(ref) do %{t | data: backend_tensor} From 4b5065b2556eea23298808c8873a79d1e51eaf08 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:32:48 -0300 Subject: [PATCH 047/185] tests for concatenate --- candlex/test/candlex_test.exs | 37 +++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 8faccf1dbb..2329381a66 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -122,6 +122,43 @@ defmodule CandlexTest do assert_equal(tensor[0], t([1, 2])) assert_equal(tensor[1], t([3, 4])) end + + test "concatenate" do + [t([1, 2, 3])] + |> Nx.concatenate() + |> assert_equal(t([1, 2, 3])) + + [t([1, 2, 3]), t([4, 5, 6])] + |> Nx.concatenate() + |> assert_equal(t([1, 2, 3, 4, 5, 6])) + + t1 = Nx.iota({2, 2, 2}, names: [:x, :y, :z], type: :u8) + t2 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :u8) + t3 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :s64) + + [t1, t2, t3] + |> Nx.concatenate(axis: :x) + |> assert_equal(t( + [ + [ + [0, 1], + [2, 3] + ], + [ + [4, 5], + [6, 7] + ], + [ + [0, 1], + [2, 3] + ], + [ + [0, 1], + [2, 3] + ] + ] + )) + end end defp t(values, opts \\ []) do From d729dd6bb6997be218df239aab33d4540c9f9ffc Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:37:55 -0300 Subject: [PATCH 048/185] cargo fmt --- candlex/native/candlex/src/error.rs | 2 +- candlex/native/candlex/src/tensors.rs | 141 ++++++++++++++------------ 2 files changed, 77 insertions(+), 66 deletions(-) diff --git a/candlex/native/candlex/src/error.rs b/candlex/native/candlex/src/error.rs index a4d84cfb8a..2dae6e6950 100644 --- a/candlex/native/candlex/src/error.rs +++ b/candlex/native/candlex/src/error.rs @@ -9,7 +9,7 @@ rustler::atoms! { #[derive(Error, Debug)] pub enum CandlexError { #[error("Candle Error: {0}")] - Candle(#[from] candle_core::Error) + Candle(#[from] candle_core::Error), } impl Encoder for CandlexError { diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 6dbf9fef78..19e3a06224 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -33,18 +33,14 @@ impl Deref for ExTensor { #[rustler::nif(schedule = "DirtyCpu")] pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result { - Ok( - ExTensor::new( - Tensor::from_raw_buffer( - binary.as_slice(), - // TODO: Handle DTypeParseError - DType::from_str(dtype_str).unwrap(), - // TODO: Handle rustler::Error - &tuple_to_vec(shape).unwrap(), - &Device::Cpu, - )? 
- ) - ) + Ok(ExTensor::new(Tensor::from_raw_buffer( + binary.as_slice(), + // TODO: Handle DTypeParseError + DType::from_str(dtype_str).unwrap(), + // TODO: Handle rustler::Error + &tuple_to_vec(shape).unwrap(), + &Device::Cpu, + )?)) } #[rustler::nif(schedule = "DirtyCpu")] @@ -92,7 +88,12 @@ pub fn subtract(left: ExTensor, right: ExTensor) -> Result Result { +pub fn narrow( + t: ExTensor, + dim: usize, + start: usize, + length: usize, +) -> Result { Ok(ExTensor::new(t.narrow(dim, start, length)?)) } @@ -103,7 +104,9 @@ pub fn squeeze(t: ExTensor, dim: usize) -> Result { #[rustler::nif(schedule = "DirtyCpu")] pub fn arange(start: i64, end: i64, shape: Term) -> Result { - Ok(ExTensor::new(Tensor::arange(start, end, &Device::Cpu)?.reshape(tuple_to_vec(shape).unwrap())?)) + Ok(ExTensor::new( + Tensor::arange(start, end, &Device::Cpu)?.reshape(tuple_to_vec(shape).unwrap())?, + )) } #[rustler::nif(schedule = "DirtyCpu")] @@ -114,11 +117,14 @@ pub fn all(ex_tensor: ExTensor) -> Result { let on_true = Tensor::ones(dims, DType::U8, device)?; let on_false = Tensor::zeros(dims, DType::U8, device)?; - let bool_scalar = - match t.where_cond(&on_true, &on_false)?.min(0)?.to_scalar::()? { - 0 => 0u8, - _ => 1u8 - }; + let bool_scalar = match t + .where_cond(&on_true, &on_false)? + .min(0)? + .to_scalar::()? + { + 0 => 0u8, + _ => 1u8, + }; Ok(ExTensor::new(Tensor::new(bool_scalar, device)?)) } @@ -134,68 +140,73 @@ pub fn reshape(t: ExTensor, shape: Term) -> Result { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn where_cond(t: ExTensor, on_true: ExTensor, on_false: ExTensor) -> Result { +pub fn where_cond( + t: ExTensor, + on_true: ExTensor, + on_false: ExTensor, +) -> Result { Ok(ExTensor::new(t.where_cond(&on_true, &on_false)?)) } #[rustler::nif(schedule = "DirtyCpu")] pub fn to_type(t: ExTensor, dtype_str: &str) -> Result { - Ok(ExTensor::new(t.to_dtype(DType::from_str(dtype_str).unwrap())?)) + Ok(ExTensor::new( + t.to_dtype(DType::from_str(dtype_str).unwrap())?, + )) } #[rustler::nif(schedule = "DirtyCpu")] pub fn concatenate(ex_tensors: Vec, dim: usize) -> Result { - let tensors = ex_tensors.iter().map(|t| t.deref()).collect::>(); + let tensors = ex_tensors + .iter() + .map(|t| t.deref()) + .collect::>(); Ok(ExTensor::new(Tensor::cat(&tensors[..], dim)?)) } fn tuple_to_vec(term: Term) -> Result, rustler::Error> { - Ok( - rustler::types::tuple::get_tuple(term)? + Ok(rustler::types::tuple::get_tuple(term)? .iter() .map(|elem| elem.decode()) - .collect::>()? - ) + .collect::>()?) } fn tensor_bytes(tensor: Tensor) -> Result, CandlexError> { - Ok( - match tensor.dtype() { - DType::I64 => tensor - .to_vec1::()? - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(), - DType::U8 => tensor - .to_vec1::()? - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(), - DType::U32 => tensor - .to_vec1::()? - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(), - DType::F16 => tensor - .to_vec1::()? - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(), - DType::F32 => tensor - .to_vec1::()? - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(), - DType::F64 => tensor - .to_vec1::()? - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(), - DType::BF16 => tensor - .to_vec1::()? - .iter() - .flat_map(|val| val.to_ne_bytes()) - .collect(), - } - ) + Ok(match tensor.dtype() { + DType::I64 => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::U8 => tensor + .to_vec1::()? 
+ .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::U32 => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::F16 => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::F32 => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::F64 => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + DType::BF16 => tensor + .to_vec1::()? + .iter() + .flat_map(|val| val.to_ne_bytes()) + .collect(), + }) } From bc5774f97567c48ca9e7e721f9334173de82c2f1 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 11:38:55 -0300 Subject: [PATCH 049/185] support iota with float --- candlex/lib/candlex/backend.ex | 6 +-- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/tensors.rs | 11 ++++- candlex/test/candlex_test.exs | 58 +++++++++++++-------------- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index f26188732e..37b477d3e3 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -44,9 +44,8 @@ defmodule Candlex.Backend do constant(out, 0, backend_options) end - def iota(%T{shape: shape} = out, nil, _backend_options) do - # TODO: Support different types - Native.arange(0, Nx.size(shape), shape) + def iota(%T{shape: shape, type: type} = out, nil, _backend_options) do + Native.arange(0, Nx.size(shape), to_candle_dtype(type), shape) |> unwrap!() |> to_nx(out) end @@ -155,6 +154,7 @@ defmodule Candlex.Backend do @impl true def concatenate(%T{} = out, tensors, axis) do + # TODO: Support concatenating tensors of different type tensors |> Enum.map(&from_nx/1) |> Native.concatenate(axis) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index ad90f3cc4e..4dd19682dc 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -10,7 +10,7 @@ defmodule Candlex.Native do def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def squeeze(_tensor, _dim), do: error() - def arange(_start, _end, _shape), do: error() + def arange(_start, _end, _dtype, _shape), do: error() def broadcast_to(_tensor, _shape), do: error() def reshape(_tensor, _shape), do: error() def to_type(_tensor, _dtype), do: error() diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 19e3a06224..80185766a9 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -103,9 +103,16 @@ pub fn squeeze(t: ExTensor, dim: usize) -> Result { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn arange(start: i64, end: i64, shape: Term) -> Result { +pub fn arange( + start: i64, + end: i64, + dtype_str: &str, + shape: Term, +) -> Result { Ok(ExTensor::new( - Tensor::arange(start, end, &Device::Cpu)?.reshape(tuple_to_vec(shape).unwrap())?, + Tensor::arange(start, end, &Device::Cpu)? + .to_dtype(DType::from_str(dtype_str).unwrap())? 
+ .reshape(tuple_to_vec(shape).unwrap())?, )) } diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 2329381a66..2b4e1030ec 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -61,9 +61,8 @@ defmodule CandlexTest do Nx.iota({5}) |> assert_equal(t([0, 1, 2, 3, 4])) - # TODO: Support iota with float - # Nx.iota({5}, type: :f32) - # |> assert_equal(t([0.0, 1.0, 2.0, 3.0, 4.0])) + Nx.iota({5}, type: :f32) + |> assert_equal(t([0.0, 1.0, 2.0, 3.0, 4.0])) Nx.iota({2, 3}) |> assert_equal(t([[0, 1, 2], [3, 4, 5]])) @@ -132,32 +131,33 @@ defmodule CandlexTest do |> Nx.concatenate() |> assert_equal(t([1, 2, 3, 4, 5, 6])) - t1 = Nx.iota({2, 2, 2}, names: [:x, :y, :z], type: :u8) - t2 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :u8) - t3 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :s64) - - [t1, t2, t3] - |> Nx.concatenate(axis: :x) - |> assert_equal(t( - [ - [ - [0, 1], - [2, 3] - ], - [ - [4, 5], - [6, 7] - ], - [ - [0, 1], - [2, 3] - ], - [ - [0, 1], - [2, 3] - ] - ] - )) + # TODO: Support concatenating tensors of different type + # t1 = Nx.iota({2, 2, 2}, names: [:x, :y, :z], type: :f32) + # t2 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :u8) + # t3 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :s64) + + # [t1, t2, t3] + # |> Nx.concatenate(axis: :x) + # |> assert_equal(t( + # [ + # [ + # [0.0, 1.0], + # [2.0, 3.0] + # ], + # [ + # [4.0, 5.0], + # [6.0, 7.0] + # ], + # [ + # [0.0, 1.0], + # [2.0, 3.0] + # ], + # [ + # [0.0, 1.0], + # [2.0, 3.0] + # ] + # ] + # )) end end From 90da90bd7944f9343ede5f7f7e175f28d2e38e38 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 12:54:17 -0300 Subject: [PATCH 050/185] support concatenate of different type --- candlex/lib/candlex/backend.ex | 17 ++++++++++- candlex/test/candlex_test.exs | 53 +++++++++++++++++----------------- 2 files changed, 42 insertions(+), 28 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 37b477d3e3..cdc3174766 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -154,8 +154,8 @@ defmodule Candlex.Backend do @impl true def concatenate(%T{} = out, tensors, axis) do - # TODO: Support concatenating tensors of different type tensors + |> maybe_upcast() |> Enum.map(&from_nx/1) |> Native.concatenate(axis) |> unwrap!() @@ -279,6 +279,21 @@ defmodule Candlex.Backend do {Nx.as_type(left, type), Nx.as_type(right, type)} end + defp maybe_upcast([first | _] = tensors) do + type = + tensors + |> Enum.reduce( + first.type, + fn tensor, type -> + Nx.Type.merge(type, tensor.type) + end + ) + + tensors + |> Enum.map(fn tensor -> + Nx.as_type(tensor, type) + end) + end defp maybe_broadcast_bin_args(_out_shape, %{shape: {}} = l, r), do: {from_nx(l), from_nx(r)} defp maybe_broadcast_bin_args(_out_shape, l, %{shape: {}} = r), do: {from_nx(l), from_nx(r)} diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 2b4e1030ec..745b685120 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -131,33 +131,32 @@ defmodule CandlexTest do |> Nx.concatenate() |> assert_equal(t([1, 2, 3, 4, 5, 6])) - # TODO: Support concatenating tensors of different type - # t1 = Nx.iota({2, 2, 2}, names: [:x, :y, :z], type: :f32) - # t2 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :u8) - # t3 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :s64) - - # [t1, t2, t3] - # |> Nx.concatenate(axis: :x) - # |> assert_equal(t( - # [ - # [ - # 
[0.0, 1.0], - # [2.0, 3.0] - # ], - # [ - # [4.0, 5.0], - # [6.0, 7.0] - # ], - # [ - # [0.0, 1.0], - # [2.0, 3.0] - # ], - # [ - # [0.0, 1.0], - # [2.0, 3.0] - # ] - # ] - # )) + t1 = Nx.iota({2, 2, 2}, names: [:x, :y, :z], type: :f32) + t2 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :u8) + t3 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :s64) + + [t1, t2, t3] + |> Nx.concatenate(axis: :x) + |> assert_equal(t( + [ + [ + [0.0, 1.0], + [2.0, 3.0] + ], + [ + [4.0, 5.0], + [6.0, 7.0] + ], + [ + [0.0, 1.0], + [2.0, 3.0] + ], + [ + [0.0, 1.0], + [2.0, 3.0] + ] + ] + )) end end From de86886c8d36e1b3b1b968b2b3525b32fc2dbcac Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:17:32 -0300 Subject: [PATCH 051/185] broadcast add --- candlex/native/candlex/src/tensors.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 80185766a9..3a62777473 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -54,7 +54,7 @@ pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Result #[rustler::nif(schedule = "DirtyCpu")] pub fn add(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.add(&right)?)) + Ok(ExTensor::new(left.broadcast_add(&right)?)) } #[rustler::nif(schedule = "DirtyCpu")] From 021b3c6e0b094460177e21610693fbda8a80f435 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:36:28 -0300 Subject: [PATCH 052/185] non native bitwise ops --- candlex/lib/candlex/backend.ex | 21 +++++++++++++ candlex/test/candlex_test.exs | 56 ++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index cdc3174766..650dee2227 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -92,6 +92,19 @@ defmodule Candlex.Backend do # Element-wise + for op <- [:bitwise_and, :bitwise_or, :bitwise_xor] do + @impl true + def unquote(op)(out, l, r) do + # TODO: Implement in candle + out + |> Nx.BinaryBackend.unquote(op)( + backend_transfer(l, Nx.BinaryBackend, []), + backend_transfer(r, Nx.BinaryBackend, []) + ) + |> backend_transfer(__MODULE__, []) + end + end + @impl true def select(%T{shape: shape} = out, pred, on_true, on_false) do on_true = @@ -132,6 +145,14 @@ defmodule Candlex.Backend do # Unary ops + @impl true + def bitwise_not(%T{} = out, tensor) do + # TODO: Implement in candle + out + |> Nx.BinaryBackend.bitwise_not(backend_transfer(tensor, Nx.BinaryBackend, [])) + |> backend_transfer(__MODULE__, []) + end + # Indexed @impl true diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 745b685120..fd5fa1dceb 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -158,6 +158,62 @@ defmodule CandlexTest do ] )) end + + test "bitwise_and" do + Nx.bitwise_and(1, 0) + |> assert_equal(t(0)) + + t([0, 1, 2]) + |> Nx.bitwise_and(1) + |> assert_equal(t([0, 1, 0])) + + t([0, -1, -2]) + |> Nx.bitwise_and(-1) + |> assert_equal(t([0, -1, -2])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_and(t([0, 1, 0, 1])) + |> assert_equal(t([0, 0, 0, 1])) + end + + test "bitwise_not" do + Nx.bitwise_not(1) + |> assert_equal(t(-2)) + end + + test "bitwise_or" do + Nx.bitwise_or(1, 0) + |> assert_equal(t(1)) + + t([0, 1, 2]) + |> Nx.bitwise_or(1) + |> assert_equal(t([1, 1, 3])) + + t([0, -1, -2]) + |> Nx.bitwise_or(-1) + |> 
assert_equal(t([-1, -1, -1])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_or(t([0, 1, 0, 1])) + |> assert_equal(t([0, 1, 1, 1])) + end + + test "bitwise_xor" do + Nx.bitwise_xor(1, 0) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.bitwise_xor(2) + |> assert_equal(t([3, 0, 1])) + + t([-1, -2, -3]) + |> Nx.bitwise_xor(2) + |> assert_equal(t([-3, -4, -1])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_xor(t([0, 1, 0, 1])) + |> assert_equal(t([0, 1, 1, 0])) + end end defp t(values, opts \\ []) do From 5918e85b9b9bb5ee9012db31c5dafebc19b246ed Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:45:07 -0300 Subject: [PATCH 053/185] less --- candlex/lib/candlex/backend.ex | 5 +---- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 +++++ candlex/test/candlex_test.exs | 12 ++++++++++++ 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 650dee2227..b35719de9d 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -130,7 +130,7 @@ defmodule Candlex.Backend do # Binary ops - for op <- [:add, :equal, :greater_equal, :max, :min, :multiply, :subtract] do + for op <- [:add, :equal, :greater_equal, :less, :max, :min, :multiply, :subtract] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) @@ -316,9 +316,6 @@ defmodule Candlex.Backend do end) end - defp maybe_broadcast_bin_args(_out_shape, %{shape: {}} = l, r), do: {from_nx(l), from_nx(r)} - defp maybe_broadcast_bin_args(_out_shape, l, %{shape: {}} = r), do: {from_nx(l), from_nx(r)} - defp maybe_broadcast_bin_args(out_shape, l, r) do { case l.shape do diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 4dd19682dc..0e40b438ce 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:add, :equal, :greater_equal, :max, :min, :multiply, :subtract] do + for op <- [:add, :equal, :greater_equal, :less, :max, :min, :multiply, :subtract] do def unquote(op)(_left, _right), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 59ac3b94d3..04df2dbcb9 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -20,6 +20,7 @@ rustler::init! 
{ tensors::multiply, tensors::equal, tensors::greater_equal, + tensors::less, tensors::subtract, tensors::all, tensors::where_cond, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 3a62777473..c69ca56a6b 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -82,6 +82,11 @@ pub fn greater_equal(left: ExTensor, right: ExTensor) -> Result Result { + Ok(ExTensor::new(left.lt(&right)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn subtract(left: ExTensor, right: ExTensor) -> Result { Ok(ExTensor::new(left.broadcast_sub(&right)?)) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index fd5fa1dceb..4d1dd6a024 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -214,6 +214,18 @@ defmodule CandlexTest do |> Nx.bitwise_xor(t([0, 1, 0, 1])) |> assert_equal(t([0, 1, 1, 0])) end + + test "less" do + Nx.less(1, 2) + |> assert_equal(t(1)) + + Nx.less(1, t([1, 2, 3])) + |> assert_equal(t([0, 1, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 2.0, 1.0]]) + |> Nx.less(t([1, 2, 3])) + |> assert_equal(t([[0, 0, 0], [0, 0, 1]])) + end end defp t(values, opts \\ []) do From 19c51a70efe4fb3f2ca5810c2586441e4dce41ee Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:52:07 -0300 Subject: [PATCH 054/185] less_equal --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 +++++ candlex/test/candlex_test.exs | 12 ++++++++++++ 5 files changed, 20 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index b35719de9d..1ea33b9d3e 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -130,7 +130,7 @@ defmodule Candlex.Backend do # Binary ops - for op <- [:add, :equal, :greater_equal, :less, :max, :min, :multiply, :subtract] do + for op <- [:add, :equal, :greater_equal, :less, :less_equal, :max, :min, :multiply, :subtract] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 0e40b438ce..5ea58e30d6 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:add, :equal, :greater_equal, :less, :max, :min, :multiply, :subtract] do + for op <- [:add, :equal, :greater_equal, :less, :less_equal, :max, :min, :multiply, :subtract] do def unquote(op)(_left, _right), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 04df2dbcb9..7d82978448 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -21,6 +21,7 @@ rustler::init! 
{ tensors::equal, tensors::greater_equal, tensors::less, + tensors::less_equal, tensors::subtract, tensors::all, tensors::where_cond, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index c69ca56a6b..4412da1251 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -87,6 +87,11 @@ pub fn less(left: ExTensor, right: ExTensor) -> Result { Ok(ExTensor::new(left.lt(&right)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn less_equal(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.le(&right)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn subtract(left: ExTensor, right: ExTensor) -> Result { Ok(ExTensor::new(left.broadcast_sub(&right)?)) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 4d1dd6a024..b3abe93c54 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -226,6 +226,18 @@ defmodule CandlexTest do |> Nx.less(t([1, 2, 3])) |> assert_equal(t([[0, 0, 0], [0, 0, 1]])) end + + test "less_equal" do + Nx.less_equal(1, 2) + |> assert_equal(t(1)) + + Nx.less_equal(1, t([1, 2, 3])) + |> assert_equal(t([1, 1, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + |> Nx.less_equal(t([1, 2, 3])) + |> assert_equal(t([[1, 1, 1], [0, 0, 0]])) + end end defp t(values, opts \\ []) do From ad2398df7ed4f0e9b8e9a2398557be148c70dbd7 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:55:26 -0300 Subject: [PATCH 055/185] non native left_shift --- candlex/lib/candlex/backend.ex | 2 +- candlex/test/candlex_test.exs | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 1ea33b9d3e..b39e066208 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -92,7 +92,7 @@ defmodule Candlex.Backend do # Element-wise - for op <- [:bitwise_and, :bitwise_or, :bitwise_xor] do + for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift] do @impl true def unquote(op)(out, l, r) do # TODO: Implement in candle diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index b3abe93c54..f5259feb44 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -238,6 +238,19 @@ defmodule CandlexTest do |> Nx.less_equal(t([1, 2, 3])) |> assert_equal(t([[1, 1, 1], [0, 0, 0]])) end + + test "left_shift" do + Nx.left_shift(1, 0) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.left_shift(2) + |> assert_equal(t([4, 8, 12])) + + t([1, 1, -1, -1]) + |> Nx.left_shift(t([1, 2, 3, 4])) + |> assert_equal(t([2, 4, -8, -16])) + end end defp t(values, opts \\ []) do From f34799b45e6533fb51ae04fc71e665445376d089 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 16:57:46 -0300 Subject: [PATCH 056/185] non native right_shift --- candlex/lib/candlex/backend.ex | 2 +- candlex/test/candlex_test.exs | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index b39e066208..7edef5c0a0 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -92,7 +92,7 @@ defmodule Candlex.Backend do # Element-wise - for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift] do + for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] do @impl true def unquote(op)(out, l, r) do # TODO: Implement in 
candle diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index f5259feb44..68e141b9ba 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -251,6 +251,19 @@ defmodule CandlexTest do |> Nx.left_shift(t([1, 2, 3, 4])) |> assert_equal(t([2, 4, -8, -16])) end + + test "right_shift" do + Nx.right_shift(1, 0) + |> assert_equal(t(1)) + + t([2, 4, 8]) + |> Nx.right_shift(2) + |> assert_equal(t([0, 1, 2])) + + t([16, 32, -64, -128]) + |> Nx.right_shift(t([1, 2, 3, 4])) + |> assert_equal(t([8, 8, -8, -8])) + end end defp t(values, opts \\ []) do From bc0fa6c8d731b4bd7baf2dcfadfd7f2a741493c9 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 17:09:32 -0300 Subject: [PATCH 057/185] bitcast --- candlex/lib/candlex/backend.ex | 24 ++++++++++++++++-------- candlex/test/candlex_test.exs | 6 ++++++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 7edef5c0a0..9465276358 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -203,14 +203,6 @@ defmodule Candlex.Backend do |> to_nx(out) end - @impl true - def as_type(%T{type: type} = out, %T{} = t) do - from_nx(t) - |> Native.to_type(to_candle_dtype(type)) - |> unwrap!() - |> to_nx(out) - end - @impl true def squeeze(%T{} = out, %T{} = t, axes) do # sort the axes desc so we don't have to decrease the axis numbers after each squeeze @@ -223,6 +215,22 @@ defmodule Candlex.Backend do |> to_nx(out) end + # Type + + @impl true + def as_type(%T{type: type} = out, %T{} = t) do + from_nx(t) + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def bitcast(out, tensor) do + out + |> from_binary(to_binary(tensor), []) + end + # Inspect @impl true diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 68e141b9ba..d196348ee0 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -264,6 +264,12 @@ defmodule CandlexTest do |> Nx.right_shift(t([1, 2, 3, 4])) |> assert_equal(t([8, 8, -8, -8])) end + + test "bitcast" do + t([0, 0, 0], type: :s64) + |> Nx.bitcast(:f64) + |> assert_equal(t([0.0, 0.0, 0.0])) + end end defp t(values, opts \\ []) do From 8c1d8e8fbacbbe880cb7a366ebc915e40b227705 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 17:36:57 -0300 Subject: [PATCH 058/185] non native erf_inv --- candlex/lib/candlex/backend.ex | 14 ++++++++------ candlex/test/candlex_test.exs | 9 +++++++++ 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 9465276358..e8da54d0a3 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -145,12 +145,14 @@ defmodule Candlex.Backend do # Unary ops - @impl true - def bitwise_not(%T{} = out, tensor) do - # TODO: Implement in candle - out - |> Nx.BinaryBackend.bitwise_not(backend_transfer(tensor, Nx.BinaryBackend, [])) - |> backend_transfer(__MODULE__, []) + for op <- [:bitwise_not, :erf_inv] do + @impl true + def unquote(op)(out, t) do + # TODO: Implement in candle + out + |> Nx.BinaryBackend.unquote(op)(backend_transfer(t, Nx.BinaryBackend, [])) + |> backend_transfer(__MODULE__, []) + end end # Indexed diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index d196348ee0..4464e31696 100644 --- a/candlex/test/candlex_test.exs +++ 
b/candlex/test/candlex_test.exs @@ -270,6 +270,15 @@ defmodule CandlexTest do |> Nx.bitcast(:f64) |> assert_equal(t([0.0, 0.0, 0.0])) end + + test "erf_inv" do + Nx.erf_inv(0.10000000149011612) + |> assert_equal(t(0.08885598927736282)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.erf_inv() + |> assert_equal(t([0.08885598927736282, 0.4769362807273865, 1.163087010383606])) + end end defp t(values, opts \\ []) do From 22174733efd4213aeb0cf7b3038a3a6b70a7ea9d Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 29 Aug 2023 19:01:10 -0300 Subject: [PATCH 059/185] candlex lin reg example --- candlex/examples/linear_regression.exs | 77 ++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 candlex/examples/linear_regression.exs diff --git a/candlex/examples/linear_regression.exs b/candlex/examples/linear_regression.exs new file mode 100644 index 0000000000..28a5884bd1 --- /dev/null +++ b/candlex/examples/linear_regression.exs @@ -0,0 +1,77 @@ +defmodule LinearRegression do + import Nx.Defn + + # y = mx + b + defn init_random_params do + key = Nx.Random.key(42) + {m, new_key} = Nx.Random.normal(key, 0.0, 0.1, shape: {1, 1}) + {b, _new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {1}) + {m, b} + end + + defn predict({m, b}, inp) do + Nx.dot(inp, m) + b + end + + # MSE Loss + defn loss({m, b}, inp, tar) do + preds = predict({m, b}, inp) + Nx.mean(Nx.pow(tar - preds, 2)) + end + + defn update({m, b} = params, inp, tar, step) do + {grad_m, grad_b} = grad(params, &loss(&1, inp, tar)) + {m - grad_m * step, b - grad_b * step} + end + + def train(params, epochs, lin_fn) do + data = + Stream.repeatedly(fn -> for _ <- 1..32, do: :rand.uniform() * 10 end) + |> Stream.map(fn x -> Enum.zip(x, Enum.map(x, lin_fn)) end) + + for _ <- 1..epochs, reduce: params do + acc -> + data + |> Enum.take(200) + |> Enum.reduce( + acc, + fn batch, cur_params -> + {inp, tar} = Enum.unzip(batch) + x = Nx.reshape(Nx.tensor(inp), {32, 1}) + y = Nx.reshape(Nx.tensor(tar), {32, 1}) + update(cur_params, x, y, 0.001) + end + ) + end + end +end + +Nx.default_backend(Candlex.Backend) +# Nx.default_backend(Nx.BinaryBackend) + +params = LinearRegression.init_random_params() +m = :rand.normal(0.0, 10.0) +b = :rand.normal(0.0, 5.0) +IO.puts("Target m: #{m} Target b: #{b}\n") + +lin_fn = fn x -> m * x + b end +epochs = 100 + +# These will be very close to the above coefficients +{time, {trained_m, trained_b}} = :timer.tc(LinearRegression, :train, [params, epochs, lin_fn]) + +# trained_m = +# trained_m +# |> Nx.squeeze() +# |> Nx.backend_transfer() +# |> Nx.to_number() + +# trained_b = +# trained_b +# |> Nx.squeeze() +# |> Nx.backend_transfer() +# |> Nx.to_number() + +# IO.puts("Trained in #{time / 1_000_000} sec.") +# IO.puts("Trained m: #{trained_m} Trained b: #{trained_b}\n") +# IO.puts("Accuracy m: #{m - trained_m} Accuracy b: #{b - trained_b}") From 54b8db346fd56ebcc692c0ba97285d869e353bca Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 30 Aug 2023 09:55:35 -0300 Subject: [PATCH 060/185] use rust macro to define repetitive binary nifs --- candlex/native/candlex/src/tensors.rs | 64 ++++++++------------------- 1 file changed, 19 insertions(+), 45 deletions(-) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 4412da1251..45e220b6d7 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -52,51 +52,6 @@ pub fn to_binary(env: Env, 
ex_tensor: ExTensor) -> Result Ok(binary.into()) } -#[rustler::nif(schedule = "DirtyCpu")] -pub fn add(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.broadcast_add(&right)?)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -pub fn max(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.broadcast_maximum(&right)?)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -pub fn min(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.broadcast_minimum(&right)?)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -pub fn multiply(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.broadcast_mul(&right)?)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -pub fn equal(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.eq(&right)?)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -pub fn greater_equal(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.ge(&right)?)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -pub fn less(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.lt(&right)?)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -pub fn less_equal(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.le(&right)?)) -} - -#[rustler::nif(schedule = "DirtyCpu")] -pub fn subtract(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.broadcast_sub(&right)?)) -} - #[rustler::nif(schedule = "DirtyCpu")] pub fn narrow( t: ExTensor, @@ -181,6 +136,25 @@ pub fn concatenate(ex_tensors: Vec, dim: usize) -> Result { + #[rustler::nif(schedule = "DirtyCpu")] + pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.$native_fn_name(&right)?)) + } + } +} + +binary_nif!(add, broadcast_add); +binary_nif!(subtract, broadcast_sub); +binary_nif!(multiply, broadcast_mul); +binary_nif!(max, broadcast_maximum); +binary_nif!(min, broadcast_minimum); +binary_nif!(equal, eq); +binary_nif!(greater_equal, ge); +binary_nif!(less, lt); +binary_nif!(less_equal, le); + fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok(rustler::types::tuple::get_tuple(term)? 
.iter() From ba52ca1546fb13b1471cee90c542ea5948ecff65 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 30 Aug 2023 15:39:58 -0300 Subject: [PATCH 061/185] eye --- candlex/lib/candlex/backend.ex | 8 +++++ candlex/test/candlex_test.exs | 60 ++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index e8da54d0a3..02af4417f2 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -50,6 +50,14 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def eye(%T{shape: shape, type: type} = out, backend_options) do + iota = Nx.iota(shape, backend_options) + + Nx.equal(Nx.tril(iota), Nx.triu(iota)) + |> Nx.as_type(type) + end + # Backend @impl true diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 4464e31696..f304869cf5 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -279,6 +279,66 @@ defmodule CandlexTest do |> Nx.erf_inv() |> assert_equal(t([0.08885598927736282, 0.4769362807273865, 1.163087010383606])) end + + test "eye" do + Nx.eye(2) + |> assert_equal(t([[1, 0], [0, 1]])) + + Nx.eye(3, type: :f32) + |> assert_equal(t( + [ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0] + ] + )) + + Nx.eye({1, 2}) + |> assert_equal(t([[1, 0]])) + + Nx.eye({2, 4, 3}) + |> assert_equal(t( + [ + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, 0, 0] + ], + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, 0, 0] + ] + ] + )) + + # assert_equal doesn't yet work with vectorized axes + # Nx.eye({3}, vectorized_axes: [x: 1, y: 2]) + # |> assert_equal(t( + # [ + # [ + # [1, 0, 0], + # [1, 0, 0] + # ] + # ] + # )) + + # Nx.eye({2, 3}, vectorized_axes: [x: 2]) + # |> assert_equal(t( + # [ + # [ + # [1, 0, 0], + # [0, 1, 0] + # ], + # [ + # [1, 0, 0], + # [0, 1, 0] + # ] + # ] + # )) + end end defp t(values, opts \\ []) do From 19f3215bd838bcf5010baa4546375f8374efc51c Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 30 Aug 2023 17:44:21 -0300 Subject: [PATCH 062/185] negate --- candlex/lib/candlex/backend.ex | 13 +++++++++++++ candlex/lib/candlex/native.ex | 4 ++++ candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 11 +++++++++++ candlex/test/candlex_test.exs | 13 +++++++++++++ 5 files changed, 42 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 02af4417f2..35323392f7 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -153,6 +153,19 @@ defmodule Candlex.Backend do # Unary ops + unary_ops = [:negate] + + for op <- unary_ops do + @impl true + def unquote(op)(%T{} = out, %T{} = tensor) do + tensor + |> from_nx() + |> Native.unquote(op)() + |> unwrap!() + |> to_nx(out) + end + end + for op <- [:bitwise_not, :erf_inv] do @impl true def unquote(op)(out, t) do diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 5ea58e30d6..3099ea62d8 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,6 +16,10 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() + for op <- [:negate] do + def unquote(op)(_tensor), do: error() + end + for op <- [:add, :equal, :greater_equal, :less, :less_equal, :max, :min, :multiply, :subtract] do def unquote(op)(_left, _right), do: error() end diff --git a/candlex/native/candlex/src/lib.rs 
b/candlex/native/candlex/src/lib.rs index 7d82978448..7f8d48af6f 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -24,6 +24,7 @@ rustler::init! { tensors::less_equal, tensors::subtract, tensors::all, + tensors::negate, tensors::where_cond, tensors::narrow, tensors::squeeze, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 45e220b6d7..143ca8653a 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -136,6 +136,15 @@ pub fn concatenate(ex_tensors: Vec, dim: usize) -> Result { + #[rustler::nif(schedule = "DirtyCpu")] + pub fn $nif_name(ex_tensor: ExTensor) -> Result { + Ok(ExTensor::new(ex_tensor.$native_fn_name()?)) + } + } +} + macro_rules! binary_nif { ($nif_name:ident, $native_fn_name:ident) => { #[rustler::nif(schedule = "DirtyCpu")] @@ -145,6 +154,8 @@ macro_rules! binary_nif { } } +unary_nif!(negate, neg); + binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); binary_nif!(multiply, broadcast_mul); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index f304869cf5..3c0b710d76 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -339,6 +339,19 @@ defmodule CandlexTest do # ] # )) end + + test "negate" do + # TODO: candle doesn't support unary functions for integers yet + # Nx.negate(1) + # |> assert_equal(t(-1)) + + Nx.negate(1.0) + |> assert_equal(t(-1.0)) + + t([1.0, 2.0, -3.0], type: :f32) + |> Nx.negate() + |> assert_equal(t([-1.0, -2.0, 3.0])) + end end defp t(values, opts \\ []) do From 446fc8e7004f1b6e9d5e603e549801bc8ac868ed Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 30 Aug 2023 17:47:20 -0300 Subject: [PATCH 063/185] wip dot --- candlex/lib/candlex/backend.ex | 18 +++++++++++++++++ candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 3 ++- candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 29 +++++++++++++++++++++++++++ 5 files changed, 51 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 35323392f7..9d7e03212e 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -206,6 +206,24 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def dot( + %T{type: _out_type} = out, + %T{shape: left_shape, type: _left_type} = left, + _left_axes, + [] = _left_batched_axes, + %T{shape: right_shape, type: _right_type} = right, + _right_axes, + [] = _right_batched_axes + ) when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do + Native.matmul( + from_nx(left), + from_nx(right) + ) + |> unwrap!() + |> to_nx(out) + end + # Shape @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 3099ea62d8..4357fe44a4 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -20,7 +20,7 @@ defmodule Candlex.Native do def unquote(op)(_tensor), do: error() end - for op <- [:add, :equal, :greater_equal, :less, :less_equal, :max, :min, :multiply, :subtract] do + for op <- [:add, :equal, :greater_equal, :less, :less_equal, :matmul, :max, :min, :multiply, :subtract] do def unquote(op)(_left, _right), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 7f8d48af6f..eaeb978a13 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -32,7 +32,8 @@ rustler::init! 
{ tensors::to_type, tensors::broadcast_to, tensors::reshape, - tensors::concatenate + tensors::concatenate, + tensors::matmul ], load = load } diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 143ca8653a..964d579d65 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -165,6 +165,7 @@ binary_nif!(equal, eq); binary_nif!(greater_equal, ge); binary_nif!(less, lt); binary_nif!(less_equal, le); +binary_nif!(matmul, broadcast_matmul); fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok(rustler::types::tuple::get_tuple(term)? diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 3c0b710d76..f2d89c837b 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -340,6 +340,35 @@ defmodule CandlexTest do # )) end + test "dot" do + Nx.dot(5, 5) + |> assert_equal(t(25)) + + Nx.dot(-2.0, 5.0) + |> assert_equal(t(-10.0)) + + Nx.dot(2, 2.0) + |> assert_equal(t(4.0)) + + # TODO: + # t([1, 2, 3]) + # |> Nx.dot(t([4, 5, 6])) + # |> assert_equal(t(32)) + + # t([1.0, 2.0, 3.0]) + # |> Nx.dot(t([1, 2, 3])) + # |> assert_equal(t(14.0)) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.dot(t([[7, 8], [9, 10], [11, 12]])) + # |> assert_equal(t( + # [ + # [58, 64], + # [139, 154] + # ] + # )) + end + test "negate" do # TODO: candle doesn't support unary functions for integers yet # Nx.negate(1) From 9f7c26d8115d97322cd2624436c41e44b3104d46 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 09:59:51 -0300 Subject: [PATCH 064/185] removes unnecessary elixir broadcasting for a few binary ops already using candle broadcast_* fn --- candlex/lib/candlex/backend.ex | 14 +++++++++++++- candlex/test/candlex_test.exs | 7 +++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 9d7e03212e..a9e7dd52cf 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -138,7 +138,19 @@ defmodule Candlex.Backend do # Binary ops - for op <- [:add, :equal, :greater_equal, :less, :less_equal, :max, :min, :multiply, :subtract] do + for op <- [:add, :max, :min, :multiply, :subtract] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_upcast(left, right) + + from_nx(left) + |> Native.unquote(op)(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + end + + for op <- [:equal, :greater_equal, :less, :less_equal] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index f2d89c837b..d7163dc5c8 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -49,6 +49,13 @@ defmodule CandlexTest do t([1, 2, 3]) |> Nx.add(t([10, 20, 30])) |> assert_equal(t([11, 22, 33])) + + Nx.add(1, 2.2) + |> assert_equal(t(3.2)) + + t([1, 2, 3]) + |> Nx.add(1.0) + |> assert_equal(t([2.0, 3.0, 4.0])) end test "iota" do From 3694942adce1511640d6048b8dc55f10a041ac9b Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 10:32:25 -0300 Subject: [PATCH 065/185] prefer direct native call if possible --- candlex/lib/candlex/backend.ex | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index a9e7dd52cf..703e3543af 100644 --- 
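Since candle's `broadcast_add`, `broadcast_mul` and friends already handle shape broadcasting natively, the split above leaves only dtype reconciliation on the Elixir side, done by `maybe_upcast/2` via `Nx.Type.merge/2`. A usage sketch mirroring the new `add` tests (assumes `Candlex.Backend` is set as the default Nx backend, as in the test suite):

```elixir
# Mixed integer/float operands are upcast before the NIF call; the result
# type follows Nx.Type.merge/2 (the float type wins here).
Nx.add(1, 2.2)
#=> 3.2 (f32)

# Shape broadcasting itself happens inside candle's broadcast_* kernels.
Nx.tensor([1, 2, 3])
|> Nx.add(1.0)
#=> [2.0, 3.0, 4.0]
```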
a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -114,18 +114,20 @@ defmodule Candlex.Backend do end @impl true - def select(%T{shape: shape} = out, pred, on_true, on_false) do + def select(%T{shape: shape, type: type} = out, pred, on_true, on_false) do on_true = on_true - |> Nx.as_type(Nx.type(out)) |> from_nx() + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() |> Native.broadcast_to(shape) |> unwrap!() on_false = on_false - |> Nx.as_type(Nx.type(out)) |> from_nx() + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() |> Native.broadcast_to(shape) |> unwrap!() From b98c617d071836d044b513d66843ad50190f9dfb Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 10:59:47 -0300 Subject: [PATCH 066/185] prefer to flag as unsupported for now instaed of backend transfer --- candlex/lib/candlex/backend.ex | 17 +--- candlex/test/candlex_test.exs | 164 ++++++++++++++++----------------- 2 files changed, 86 insertions(+), 95 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 703e3543af..025bdab505 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -103,13 +103,7 @@ defmodule Candlex.Backend do for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] do @impl true def unquote(op)(out, l, r) do - # TODO: Implement in candle - out - |> Nx.BinaryBackend.unquote(op)( - backend_transfer(l, Nx.BinaryBackend, []), - backend_transfer(r, Nx.BinaryBackend, []) - ) - |> backend_transfer(__MODULE__, []) + unsupported_op(unquote(op)) end end @@ -183,10 +177,7 @@ defmodule Candlex.Backend do for op <- [:bitwise_not, :erf_inv] do @impl true def unquote(op)(out, t) do - # TODO: Implement in candle - out - |> Nx.BinaryBackend.unquote(op)(backend_transfer(t, Nx.BinaryBackend, [])) - |> backend_transfer(__MODULE__, []) + unsupported_op(unquote(op)) end end @@ -428,8 +419,8 @@ defmodule Candlex.Backend do raise("Unsupported candle dtype") end - defp unsupported_op do - raise("Unsupported candle op") + defp unsupported_op(op_name) do + raise("Unsupported candlex operation '#{op_name}'") end defp unwrap!({:ok, result}), do: result diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index d7163dc5c8..05c480865e 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -166,61 +166,61 @@ defmodule CandlexTest do )) end - test "bitwise_and" do - Nx.bitwise_and(1, 0) - |> assert_equal(t(0)) - - t([0, 1, 2]) - |> Nx.bitwise_and(1) - |> assert_equal(t([0, 1, 0])) - - t([0, -1, -2]) - |> Nx.bitwise_and(-1) - |> assert_equal(t([0, -1, -2])) - - t([0, 0, 1, 1]) - |> Nx.bitwise_and(t([0, 1, 0, 1])) - |> assert_equal(t([0, 0, 0, 1])) - end - - test "bitwise_not" do - Nx.bitwise_not(1) - |> assert_equal(t(-2)) - end - - test "bitwise_or" do - Nx.bitwise_or(1, 0) - |> assert_equal(t(1)) - - t([0, 1, 2]) - |> Nx.bitwise_or(1) - |> assert_equal(t([1, 1, 3])) - - t([0, -1, -2]) - |> Nx.bitwise_or(-1) - |> assert_equal(t([-1, -1, -1])) - - t([0, 0, 1, 1]) - |> Nx.bitwise_or(t([0, 1, 0, 1])) - |> assert_equal(t([0, 1, 1, 1])) - end - - test "bitwise_xor" do - Nx.bitwise_xor(1, 0) - |> assert_equal(t(1)) - - t([1, 2, 3]) - |> Nx.bitwise_xor(2) - |> assert_equal(t([3, 0, 1])) - - t([-1, -2, -3]) - |> Nx.bitwise_xor(2) - |> assert_equal(t([-3, -4, -1])) - - t([0, 0, 1, 1]) - |> Nx.bitwise_xor(t([0, 1, 0, 1])) - |> assert_equal(t([0, 1, 1, 0])) - end + # test "bitwise_and" do + # Nx.bitwise_and(1, 0) + # |> 
assert_equal(t(0)) + + # t([0, 1, 2]) + # |> Nx.bitwise_and(1) + # |> assert_equal(t([0, 1, 0])) + + # t([0, -1, -2]) + # |> Nx.bitwise_and(-1) + # |> assert_equal(t([0, -1, -2])) + + # t([0, 0, 1, 1]) + # |> Nx.bitwise_and(t([0, 1, 0, 1])) + # |> assert_equal(t([0, 0, 0, 1])) + # end + + # test "bitwise_not" do + # Nx.bitwise_not(1) + # |> assert_equal(t(-2)) + # end + + # test "bitwise_or" do + # Nx.bitwise_or(1, 0) + # |> assert_equal(t(1)) + + # t([0, 1, 2]) + # |> Nx.bitwise_or(1) + # |> assert_equal(t([1, 1, 3])) + + # t([0, -1, -2]) + # |> Nx.bitwise_or(-1) + # |> assert_equal(t([-1, -1, -1])) + + # t([0, 0, 1, 1]) + # |> Nx.bitwise_or(t([0, 1, 0, 1])) + # |> assert_equal(t([0, 1, 1, 1])) + # end + + # test "bitwise_xor" do + # Nx.bitwise_xor(1, 0) + # |> assert_equal(t(1)) + + # t([1, 2, 3]) + # |> Nx.bitwise_xor(2) + # |> assert_equal(t([3, 0, 1])) + + # t([-1, -2, -3]) + # |> Nx.bitwise_xor(2) + # |> assert_equal(t([-3, -4, -1])) + + # t([0, 0, 1, 1]) + # |> Nx.bitwise_xor(t([0, 1, 0, 1])) + # |> assert_equal(t([0, 1, 1, 0])) + # end test "less" do Nx.less(1, 2) @@ -246,31 +246,31 @@ defmodule CandlexTest do |> assert_equal(t([[1, 1, 1], [0, 0, 0]])) end - test "left_shift" do - Nx.left_shift(1, 0) - |> assert_equal(t(1)) + # test "left_shift" do + # Nx.left_shift(1, 0) + # |> assert_equal(t(1)) - t([1, 2, 3]) - |> Nx.left_shift(2) - |> assert_equal(t([4, 8, 12])) + # t([1, 2, 3]) + # |> Nx.left_shift(2) + # |> assert_equal(t([4, 8, 12])) - t([1, 1, -1, -1]) - |> Nx.left_shift(t([1, 2, 3, 4])) - |> assert_equal(t([2, 4, -8, -16])) - end + # t([1, 1, -1, -1]) + # |> Nx.left_shift(t([1, 2, 3, 4])) + # |> assert_equal(t([2, 4, -8, -16])) + # end - test "right_shift" do - Nx.right_shift(1, 0) - |> assert_equal(t(1)) + # test "right_shift" do + # Nx.right_shift(1, 0) + # |> assert_equal(t(1)) - t([2, 4, 8]) - |> Nx.right_shift(2) - |> assert_equal(t([0, 1, 2])) + # t([2, 4, 8]) + # |> Nx.right_shift(2) + # |> assert_equal(t([0, 1, 2])) - t([16, 32, -64, -128]) - |> Nx.right_shift(t([1, 2, 3, 4])) - |> assert_equal(t([8, 8, -8, -8])) - end + # t([16, 32, -64, -128]) + # |> Nx.right_shift(t([1, 2, 3, 4])) + # |> assert_equal(t([8, 8, -8, -8])) + # end test "bitcast" do t([0, 0, 0], type: :s64) @@ -278,14 +278,14 @@ defmodule CandlexTest do |> assert_equal(t([0.0, 0.0, 0.0])) end - test "erf_inv" do - Nx.erf_inv(0.10000000149011612) - |> assert_equal(t(0.08885598927736282)) + # test "erf_inv" do + # Nx.erf_inv(0.10000000149011612) + # |> assert_equal(t(0.08885598927736282)) - t([0.10000000149011612, 0.5, 0.8999999761581421]) - |> Nx.erf_inv() - |> assert_equal(t([0.08885598927736282, 0.4769362807273865, 1.163087010383606])) - end + # t([0.10000000149011612, 0.5, 0.8999999761581421]) + # |> Nx.erf_inv() + # |> assert_equal(t([0.08885598927736282, 0.4769362807273865, 1.163087010383606])) + # end test "eye" do Nx.eye(2) From a2285d7b62190373ca145cf3c23228aca2e555da Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 11:00:00 -0300 Subject: [PATCH 067/185] cargo fmt --- candlex/native/candlex/src/tensors.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 964d579d65..b34b93ddda 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -142,7 +142,7 @@ macro_rules! unary_nif { pub fn $nif_name(ex_tensor: ExTensor) -> Result { Ok(ExTensor::new(ex_tensor.$native_fn_name()?)) } - } + }; } macro_rules! 
binary_nif { @@ -151,7 +151,7 @@ macro_rules! binary_nif { pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result { Ok(ExTensor::new(left.$native_fn_name(&right)?)) } - } + }; } unary_nif!(negate, neg); From 2de0a2345a7e3c2d50679f7157cc7287cf24eb73 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 11:00:30 -0300 Subject: [PATCH 068/185] mix format --- candlex/lib/candlex/backend.ex | 12 ++++++++---- candlex/lib/candlex/native.ex | 13 ++++++++++++- candlex/test/candlex_test.exs | 24 ++++++++++++------------ 3 files changed, 32 insertions(+), 17 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 025bdab505..6f0352a609 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -220,7 +220,8 @@ defmodule Candlex.Backend do %T{shape: right_shape, type: _right_type} = right, _right_axes, [] = _right_batched_axes - ) when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do + ) + when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do Native.matmul( from_nx(left), from_nx(right) @@ -254,9 +255,9 @@ defmodule Candlex.Backend do # sort the axes desc so we don't have to decrease the axis numbers after each squeeze for axis <- Enum.sort(axes, :desc), reduce: from_nx(t) do ref -> - ref - |> Native.squeeze(axis) - |> unwrap!() + ref + |> Native.squeeze(axis) + |> unwrap!() end |> to_nx(out) end @@ -349,11 +350,13 @@ defmodule Candlex.Backend do defp maybe_upcast(%T{type: t} = left, %T{type: t} = right) do {left, right} end + defp maybe_upcast(left, right) do type = Nx.Type.merge(left.type, right.type) {Nx.as_type(left, type), Nx.as_type(right, type)} end + defp maybe_upcast([first | _] = tensors) do type = tensors @@ -390,6 +393,7 @@ defmodule Candlex.Backend do @doc false defp from_nx(%T{data: %CB{} = data}), do: data + defp from_nx(%T{} = tensor) do tensor |> Nx.backend_transfer(CB) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 4357fe44a4..548556313a 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -20,7 +20,18 @@ defmodule Candlex.Native do def unquote(op)(_tensor), do: error() end - for op <- [:add, :equal, :greater_equal, :less, :less_equal, :matmul, :max, :min, :multiply, :subtract] do + for op <- [ + :add, + :equal, + :greater_equal, + :less, + :less_equal, + :matmul, + :max, + :min, + :multiply, + :subtract + ] do def unquote(op)(_left, _right), do: error() end diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 05c480865e..03f376b486 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -144,8 +144,8 @@ defmodule CandlexTest do [t1, t2, t3] |> Nx.concatenate(axis: :x) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0.0, 1.0], [2.0, 3.0] @@ -162,8 +162,8 @@ defmodule CandlexTest do [0.0, 1.0], [2.0, 3.0] ] - ] - )) + ]) + ) end # test "bitwise_and" do @@ -292,20 +292,20 @@ defmodule CandlexTest do |> assert_equal(t([[1, 0], [0, 1]])) Nx.eye(3, type: :f32) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0] - ] - )) + ]) + ) Nx.eye({1, 2}) |> assert_equal(t([[1, 0]])) Nx.eye({2, 4, 3}) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [1, 0, 0], [0, 1, 0], @@ -318,8 +318,8 @@ defmodule CandlexTest do [0, 0, 1], [0, 0, 0] ] - ] - )) + ]) + ) # assert_equal doesn't yet work with vectorized axes # Nx.eye({3}, vectorized_axes: [x: 1, y: 2]) From 
86027f5e77e25b102bafad33e24c1d7906a5a5ea Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 15:55:25 -0300 Subject: [PATCH 069/185] fix typo --- candlex/native/candlex/src/error.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/error.rs b/candlex/native/candlex/src/error.rs index 2dae6e6950..f8d2eb9f76 100644 --- a/candlex/native/candlex/src/error.rs +++ b/candlex/native/candlex/src/error.rs @@ -1,7 +1,7 @@ use rustler::{Encoder, Env, Term}; use thiserror::Error; -// Defines the atoms for each value of ExplorerError. +// Defines the atoms for each value of CandlexError. rustler::atoms! { candle, } From 90980195f4a61a4dda9951df0d23953b78cbd8d0 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 16:13:02 -0300 Subject: [PATCH 070/185] handle device backend option but still only support CPU --- candlex/lib/candlex/backend.ex | 40 +++++++++++++++++++-------- candlex/lib/candlex/native.ex | 6 ++-- candlex/native/candlex/src/devices.rs | 4 +++ candlex/native/candlex/src/error.rs | 2 ++ candlex/native/candlex/src/lib.rs | 11 +++++++- candlex/native/candlex/src/tensors.rs | 22 +++++++++++---- candlex/test/candlex_test.exs | 5 ++++ 7 files changed, 71 insertions(+), 19 deletions(-) create mode 100644 candlex/native/candlex/src/devices.rs diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 6f0352a609..189732d6e2 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -11,6 +11,9 @@ defmodule Candlex.Backend do alias Candlex.Backend, as: CB alias Candlex.Native + @device_cuda :cuda + @device_cpu :cpu + @impl true def init(opts) do if opts != [] do @@ -30,11 +33,9 @@ defmodule Candlex.Backend do end @impl true - def from_binary(%T{shape: shape, type: type} = tensor, binary, _backend_options) do - # TODO: Don't ignore backend options - + def from_binary(%T{shape: shape, type: type} = tensor, binary, backend_options) do binary - |> Native.from_binary(to_candle_dtype(type), shape) + |> Native.from_binary(to_candle_dtype(type), shape, device_option(backend_options)) |> unwrap!() |> to_nx(tensor) end @@ -44,8 +45,8 @@ defmodule Candlex.Backend do constant(out, 0, backend_options) end - def iota(%T{shape: shape, type: type} = out, nil, _backend_options) do - Native.arange(0, Nx.size(shape), to_candle_dtype(type), shape) + def iota(%T{shape: shape, type: type} = out, nil, backend_options) do + Native.arange(0, Nx.size(shape), to_candle_dtype(type), shape, device_option(backend_options)) |> unwrap!() |> to_nx(out) end @@ -298,11 +299,6 @@ defmodule Candlex.Backend do ]) end - # defp device_option(_backend_options) do - # # TODO: Support CUDA - # :cpu - # end - defp narrow(t, [start | starts], [length | lengths], axis, shape) do dim = elem(shape, axis) start = min(start, dim - length) @@ -419,6 +415,28 @@ defmodule Candlex.Backend do defp to_candle_dtype({:c, 64}), do: unsupported_dtype() defp to_candle_dtype({:c, 128}), do: unsupported_dtype() + defp device_option(nil) do + default_device() + end + + defp device_option(backend_options) do + backend_options[:device] || default_device() + end + + defp default_device do + # TODO: Support CUDA + # if cuda_available?() do + # @device_cuda + # else + # @device_cpu + # end + @device_cpu + end + + defp cuda_available? 
do + Native.is_cuda_available() + end + defp unsupported_dtype do raise("Unsupported candle dtype") end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 548556313a..30022c47d1 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -4,13 +4,13 @@ defmodule Candlex.Native do use Rustler, otp_app: :candlex, crate: "candlex" # Rustler will override all the below stub functions with real NIFs - def from_binary(_binary, _dtype, _shape), do: error() + def from_binary(_binary, _dtype, _shape, _device), do: error() def to_binary(_tensor), do: error() def all(_tensor), do: error() def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def squeeze(_tensor, _dim), do: error() - def arange(_start, _end, _dtype, _shape), do: error() + def arange(_start, _end, _dtype, _shape, _device), do: error() def broadcast_to(_tensor, _shape), do: error() def reshape(_tensor, _shape), do: error() def to_type(_tensor, _dtype), do: error() @@ -35,5 +35,7 @@ defmodule Candlex.Native do def unquote(op)(_left, _right), do: error() end + def is_cuda_available(), do: error() + defp error(), do: :erlang.nif_error(:nif_not_loaded) end diff --git a/candlex/native/candlex/src/devices.rs b/candlex/native/candlex/src/devices.rs new file mode 100644 index 0000000000..43a615a328 --- /dev/null +++ b/candlex/native/candlex/src/devices.rs @@ -0,0 +1,4 @@ +#[rustler::nif(schedule = "DirtyCpu")] +pub fn is_cuda_available() -> bool { + candle_core::utils::cuda_is_available() +} diff --git a/candlex/native/candlex/src/error.rs b/candlex/native/candlex/src/error.rs index f8d2eb9f76..6cdccceb38 100644 --- a/candlex/native/candlex/src/error.rs +++ b/candlex/native/candlex/src/error.rs @@ -10,6 +10,8 @@ rustler::atoms! { pub enum CandlexError { #[error("Candle Error: {0}")] Candle(#[from] candle_core::Error), + #[error("Generic Error: {0}")] + Other(String), } impl Encoder for CandlexError { diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index eaeb978a13..6cbd5c9889 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -1,3 +1,11 @@ +mod atoms { + rustler::atoms! { + cpu, + cuda + } +} + +mod devices; mod error; mod tensors; @@ -33,7 +41,8 @@ rustler::init! { tensors::broadcast_to, tensors::reshape, tensors::concatenate, - tensors::matmul + tensors::matmul, + devices::is_cuda_available ], load = load } diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index b34b93ddda..778a76d822 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,7 +1,8 @@ +use crate::atoms; use crate::error::CandlexError; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; -use rustler::{Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; +use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; use std::ops::Deref; use std::result::Result; use std::str::FromStr; @@ -32,14 +33,14 @@ impl Deref for ExTensor { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term) -> Result { +pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term, device: Atom) -> Result { Ok(ExTensor::new(Tensor::from_raw_buffer( binary.as_slice(), // TODO: Handle DTypeParseError DType::from_str(dtype_str).unwrap(), // TODO: Handle rustler::Error &tuple_to_vec(shape).unwrap(), - &Device::Cpu, + &device_from_atom(device)? 
)?)) } @@ -73,9 +74,10 @@ pub fn arange( end: i64, dtype_str: &str, shape: Term, + device: Atom ) -> Result { Ok(ExTensor::new( - Tensor::arange(start, end, &Device::Cpu)? + Tensor::arange(start, end, &device_from_atom(device)?)? .to_dtype(DType::from_str(dtype_str).unwrap())? .reshape(tuple_to_vec(shape).unwrap())?, )) @@ -83,7 +85,7 @@ pub fn arange( #[rustler::nif(schedule = "DirtyCpu")] pub fn all(ex_tensor: ExTensor) -> Result { - let device = &Device::Cpu; + let device = ex_tensor.device(); let t = ex_tensor.flatten_all()?; let dims = t.shape().dims(); let on_true = Tensor::ones(dims, DType::U8, device)?; @@ -174,6 +176,16 @@ fn tuple_to_vec(term: Term) -> Result, rustler::Error> { .collect::>()?) } +fn device_from_atom(atom: Atom) -> Result { + if atom == atoms::cpu() { + Ok(Device::Cpu) + // } else if atom == atoms::cuda() { + // Ok(Device::new_cuda(0)?) + } else { + Err(CandlexError::Other(format!("unsupported device {:?}", atom))) + } +} + fn tensor_bytes(tensor: Tensor) -> Result, CandlexError> { Ok(match tensor.dtype() { DType::I64 => tensor diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 03f376b486..f70dab8417 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -21,6 +21,11 @@ defmodule CandlexTest do check(2.16, type: :bf16) end + # test "gpu" do + # t([1, 2, 3], backend: {Candlex.Backend, device: :cuda}) + # |> assert_equal(t([1, 2, 3])) + # end + test "named dimensions" do check([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) From 612ee71eaa6283f4533b58b288e524787871dca7 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:11:21 -0300 Subject: [PATCH 071/185] dont inspect in test --- candlex/test/candlex_test.exs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index f70dab8417..b8bfb2ea99 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -407,9 +407,9 @@ defmodule CandlexTest do tensor = t(value, opts) tensor - |> IO.inspect() + # |> IO.inspect() |> Nx.to_binary() - |> IO.inspect() + # |> IO.inspect() opts = [backend: Nx.BinaryBackend] From 9542d8c5ae8348987da3f304c09755635fd4be6f Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:18:52 -0300 Subject: [PATCH 072/185] sin --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 4 ++++ candlex/test/candlex_test.exs | 9 +++++++++ 5 files changed, 16 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 189732d6e2..6cdcd720fb 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:negate] + unary_ops = [:negate, :sin] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 30022c47d1..f9f3b87ecc 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:negate] do + for op <- [:negate, :sin] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 6cbd5c9889..d1b86021e8 100644 --- 
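With the device threaded through `from_binary`/`arange` as an atom and resolved by `device_from_atom/1` on the Rust side, placement becomes a plain backend option. Only `:cpu` is accepted so far; the `:cuda` branch stays commented out until `Native.is_cuda_available/0` is used to gate `Device::new_cuda(0)`. A usage sketch (the CUDA line mirrors the commented-out test above and does not work yet):

```elixir
# Explicit CPU placement works today and is also the default.
Nx.tensor([1, 2, 3], backend: {Candlex.Backend, device: :cpu})

# Aspirational: device_from_atom/1 currently returns an error for :cuda.
# Nx.tensor([1, 2, 3], backend: {Candlex.Backend, device: :cuda})
```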
a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -42,6 +42,7 @@ rustler::init! { tensors::reshape, tensors::concatenate, tensors::matmul, + tensors::sin, devices::is_cuda_available ], load = load diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 778a76d822..3fcf0ef1a0 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -145,6 +145,9 @@ macro_rules! unary_nif { Ok(ExTensor::new(ex_tensor.$native_fn_name()?)) } }; + ($nif_name:ident) => { + unary_nif!($nif_name, $nif_name); + }; } macro_rules! binary_nif { @@ -157,6 +160,7 @@ macro_rules! binary_nif { } unary_nif!(negate, neg); +unary_nif!(sin); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index b8bfb2ea99..57835451f9 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -393,6 +393,15 @@ defmodule CandlexTest do |> Nx.negate() |> assert_equal(t([-1.0, -2.0, 3.0])) end + + test "sin" do + Nx.sin(1.0) + |> assert_equal(t(0.8414709568023682)) + + t([1.0, 2.0, 3.0]) + |> Nx.sin() + |> assert_equal(t([0.8414709568023682, 0.9092974066734314, 0.14112000167369843])) + end end defp t(values, opts \\ []) do From b170caea91e4cd88881d8089c95f6a78d133c6eb Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:25:10 -0300 Subject: [PATCH 073/185] exp --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 9 +++++++++ 5 files changed, 13 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 6cdcd720fb..100ac770fa 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:negate, :sin] + unary_ops = [:exp, :negate, :sin] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index f9f3b87ecc..63088bc2eb 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:negate, :sin] do + for op <- [:exp, :negate, :sin] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index d1b86021e8..71a25b92b5 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -43,6 +43,7 @@ rustler::init! { tensors::concatenate, tensors::matmul, tensors::sin, + tensors::exp, devices::is_cuda_available ], load = load diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 3fcf0ef1a0..1e53576eae 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -161,6 +161,7 @@ macro_rules! 
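Each of these single-op commits repeats the same four-step wiring, so it is worth tracing one call end to end (names taken from the diffs; simplified sketch): `Nx.sin/1` dispatches to the `Candlex.Backend.sin/2` callback generated from the `unary_ops` list, which calls the `Candlex.Native.sin/1` stub that Rustler swaps for the `tensors::sin` NIF, which in turn is just `unary_nif!(sin)` delegating to candle's `Tensor::sin`.

```elixir
# End-to-end check, reusing the values from the "sin" test above:
t = Nx.tensor([1.0, 2.0, 3.0], backend: Candlex.Backend)

Nx.sin(t)
#=> f32[3] [0.8414709568023682, 0.9092974066734314, 0.14112000167369843]
```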
binary_nif { unary_nif!(negate, neg); unary_nif!(sin); +unary_nif!(exp); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 57835451f9..5238ac13ab 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -402,6 +402,15 @@ defmodule CandlexTest do |> Nx.sin() |> assert_equal(t([0.8414709568023682, 0.9092974066734314, 0.14112000167369843])) end + + test "exp" do + Nx.exp(1.0) + |> assert_equal(t(2.7182817459106445)) + + t([1.0, 2, 3]) + |> Nx.exp() + |> assert_equal(t([2.7182817459106445, 7.389056205749512, 20.08553695678711])) + end end defp t(values, opts \\ []) do From 20b4ed312f2a39c188aa444b69fda6535a0e78d9 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:27:29 -0300 Subject: [PATCH 074/185] cos --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 9 +++++++++ 5 files changed, 14 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 100ac770fa..27306ccba2 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:exp, :negate, :sin] + unary_ops = [:cos, :exp, :negate, :sin] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 63088bc2eb..5234e2071d 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:exp, :negate, :sin] do + for op <- [:cos, :exp, :negate, :sin] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 71a25b92b5..4fd4faaa4f 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -42,6 +42,7 @@ rustler::init! { tensors::reshape, tensors::concatenate, tensors::matmul, + tensors::cos, tensors::sin, tensors::exp, devices::is_cuda_available diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 1e53576eae..33082fee98 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -160,8 +160,9 @@ macro_rules! 
binary_nif { } unary_nif!(negate, neg); -unary_nif!(sin); +unary_nif!(cos); unary_nif!(exp); +unary_nif!(sin); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 5238ac13ab..ad71260146 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -411,6 +411,15 @@ defmodule CandlexTest do |> Nx.exp() |> assert_equal(t([2.7182817459106445, 7.389056205749512, 20.08553695678711])) end + + test "cos" do + Nx.cos(1.0) + |> assert_equal(t(0.5403022766113281)) + + t([1.0, 2, 3]) + |> Nx.cos() + |> assert_equal(t([0.5403022766113281, -0.416146844625473, -0.9899924993515015])) + end end defp t(values, opts \\ []) do From c69a2f18da41b9ed02ea91d010782bd2424aad8a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:29:45 -0300 Subject: [PATCH 075/185] log --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 9 +++++++++ 5 files changed, 13 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 27306ccba2..ee15147b70 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:cos, :exp, :negate, :sin] + unary_ops = [:cos, :exp, :log, :negate, :sin] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 5234e2071d..369e2cd7ed 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:cos, :exp, :negate, :sin] do + for op <- [:cos, :exp, :log, :negate, :sin] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 4fd4faaa4f..354b7efb65 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -45,6 +45,7 @@ rustler::init! 
{ tensors::cos, tensors::sin, tensors::exp, + tensors::log, devices::is_cuda_available ], load = load diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 33082fee98..f45b11c1fb 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -163,6 +163,7 @@ unary_nif!(negate, neg); unary_nif!(cos); unary_nif!(exp); unary_nif!(sin); +unary_nif!(log); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index ad71260146..a1a8f156db 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -420,6 +420,15 @@ defmodule CandlexTest do |> Nx.cos() |> assert_equal(t([0.5403022766113281, -0.416146844625473, -0.9899924993515015])) end + + test "log" do + Nx.log(1.0) + |> assert_equal(t(0.0)) + + t([1.0, 2, 3]) + |> Nx.log() + |> assert_equal(t([0.0, 0.6931471824645996, 1.0986123085021973])) + end end defp t(values, opts \\ []) do From 2afff410bfbd0fe33f26a1b5d6af75bbdf1ffc35 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:40:26 -0300 Subject: [PATCH 076/185] tanh --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/Cargo.lock | 36 +++++++++++++-------------- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 9 +++++++ 6 files changed, 31 insertions(+), 20 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index ee15147b70..4dfa7d5387 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:cos, :exp, :log, :negate, :sin] + unary_ops = [:cos, :exp, :log, :negate, :sin, :tanh] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 369e2cd7ed..a699287ab8 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:cos, :exp, :log, :negate, :sin] do + for op <- [:cos, :exp, :log, :negate, :sin, :tanh] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 7d8031fc20..fb062a12e4 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "aho-corasick" -version = "1.0.4" +version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6748e8def348ed4d14996fa801f4122cd763fff530258cdc03f64b25f89d3a5a" +checksum = "0c378d78423fdad8089616f827526ee33c19f2fddbd5de1629152c9593ba4783" dependencies = [ "memchr", ] @@ -37,8 +37,8 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" -version = "0.1.3" -source = "git+https://github.com/huggingface/candle#aba1e90797e430f28eec13b14b76dd5355876f9c" +version = "0.2.1" +source = "git+https://github.com/huggingface/candle#949f1eae6fbbda4dfa20743fe4214e5e1a4a5921" dependencies = [ "byteorder", "candle-gemm", @@ -331,9 +331,9 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "5486aed0026218e61b8a01d5fbd5a0a134649abb71a0e53b7bc088529dced86e" [[package]] name = "memmap2" @@ -491,9 +491,9 @@ checksum = "2962bf2e1f971c53ef59b2d7ca51d6a5e5c4a9d2be47eb1f661a321a4da85888" [[package]] name = "regex" -version = "1.9.3" +version = "1.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" +checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" dependencies = [ "aho-corasick", "memchr", @@ -503,9 +503,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.6" +version = "0.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" +checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" dependencies = [ "aho-corasick", "memchr", @@ -514,9 +514,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.7.4" +version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" +checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" [[package]] name = "rustler" @@ -559,9 +559,9 @@ checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" [[package]] name = "safetensors" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ad8cbd90c388a0b028565d8ad22e090101599d951c6b5f105b4f7772721a9d5f" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" dependencies = [ "serde", "serde_json", @@ -581,18 +581,18 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.185" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be9b6f69f1dfd54c3b568ffa45c310d6973a5e5148fd40cf515acaf38cf5bc31" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.185" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc59dfdcbad1437773485e0367fea4b090a2e0a16d9ffc46af47764536a298ec" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 354b7efb65..feb0fa73d0 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -46,6 +46,7 @@ rustler::init! 
{ tensors::sin, tensors::exp, tensors::log, + tensors::tanh, devices::is_cuda_available ], load = load diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index f45b11c1fb..aa0eb84aae 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -164,6 +164,7 @@ unary_nif!(cos); unary_nif!(exp); unary_nif!(sin); unary_nif!(log); +unary_nif!(tanh); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index a1a8f156db..0fa8247400 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -429,6 +429,15 @@ defmodule CandlexTest do |> Nx.log() |> assert_equal(t([0.0, 0.6931471824645996, 1.0986123085021973])) end + + test "tanh" do + Nx.tanh(1.0) + |> assert_equal(t(0.7615941762924194)) + + t([1.0, 2, 3]) + |> Nx.tanh() + |> assert_equal(t([0.7615941762924194, 0.9640275835990906, 0.9950547814369202])) + end end defp t(values, opts \\ []) do From c6f858308de667c6ea52245e60e8e2bed5b3c3d7 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:43:00 -0300 Subject: [PATCH 077/185] abs --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 6 ++++++ 5 files changed, 10 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 4dfa7d5387..7aa95ffeb3 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:cos, :exp, :log, :negate, :sin, :tanh] + unary_ops = [:abs, :cos, :exp, :log, :negate, :sin, :tanh] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index a699287ab8..df2709cc5e 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:cos, :exp, :log, :negate, :sin, :tanh] do + for op <- [:abs, :cos, :exp, :log, :negate, :sin, :tanh] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index feb0fa73d0..20e1d614fd 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -42,6 +42,7 @@ rustler::init! { tensors::reshape, tensors::concatenate, tensors::matmul, + tensors::abs, tensors::cos, tensors::sin, tensors::exp, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index aa0eb84aae..18ee24a6e5 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -160,6 +160,7 @@ macro_rules! 
binary_nif { } unary_nif!(negate, neg); +unary_nif!(abs); unary_nif!(cos); unary_nif!(exp); unary_nif!(sin); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 0fa8247400..385c39a865 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -438,6 +438,12 @@ defmodule CandlexTest do |> Nx.tanh() |> assert_equal(t([0.7615941762924194, 0.9640275835990906, 0.9950547814369202])) end + + test "abs" do + t([-2.0, -1, 0, 1, 2]) + |> Nx.abs() + |> assert_equal(t([2, 1, 0, 1, 2])) + end end defp t(values, opts \\ []) do From 7ff257b8a6ba7b886e0a67e0ee91b2dbfaed963d Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 17:45:33 -0300 Subject: [PATCH 078/185] sqrt --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 9 +++++++++ 5 files changed, 13 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 7aa95ffeb3..7b1e889407 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:abs, :cos, :exp, :log, :negate, :sin, :tanh] + unary_ops = [:abs, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index df2709cc5e..65ae231d5e 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:abs, :cos, :exp, :log, :negate, :sin, :tanh] do + for op <- [:abs, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 20e1d614fd..494b75de67 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -47,6 +47,7 @@ rustler::init! 
{ tensors::sin, tensors::exp, tensors::log, + tensors::sqrt, tensors::tanh, devices::is_cuda_available ], diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 18ee24a6e5..840cc41968 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -165,6 +165,7 @@ unary_nif!(cos); unary_nif!(exp); unary_nif!(sin); unary_nif!(log); +unary_nif!(sqrt); unary_nif!(tanh); binary_nif!(add, broadcast_add); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 385c39a865..9e3224be3c 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -444,6 +444,15 @@ defmodule CandlexTest do |> Nx.abs() |> assert_equal(t([2, 1, 0, 1, 2])) end + + test "sqrt" do + Nx.sqrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.sqrt() + |> assert_equal(t([1.0, 1.4142135381698608, 1.7320507764816284])) + end end defp t(values, opts \\ []) do From 3d5aba4a298347fda1f14487eb908554385c6cf7 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 13:26:36 -0300 Subject: [PATCH 079/185] argmax --- candlex/lib/candlex/backend.ex | 19 +++++++++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 12 +++++++ candlex/test/candlex_test.exs | 45 +++++++++++++++++++++++++++ 5 files changed, 78 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 7b1e889407..2383263b3a 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -99,6 +99,25 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def argmax(%T{} = out, %T{shape: {}} = tensor, _opts) do + out + |> constant(0, []) + end + + def argmax(%T{type: type} = out, %T{} = tensor, opts) do + axis = opts[:axis] || -1 + keep_axis = opts[:keep_axis] || false + + from_nx(tensor) + |> Native.argmax(axis, keep_axis) + |> unwrap!() + # # candle argmax changes to u32 + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> to_nx(out) + end + # Element-wise for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] do diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 65ae231d5e..5ef072e55f 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -7,6 +7,7 @@ defmodule Candlex.Native do def from_binary(_binary, _dtype, _shape, _device), do: error() def to_binary(_tensor), do: error() def all(_tensor), do: error() + def argmax(_tensor, _dim, _keep_dim), do: error() def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def squeeze(_tensor, _dim), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 494b75de67..84fbeaca11 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -32,6 +32,7 @@ rustler::init! 
{ tensors::less_equal, tensors::subtract, tensors::all, + tensors::argmax, tensors::negate, tensors::where_cond, tensors::narrow, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 840cc41968..9a23b170f1 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -103,6 +103,18 @@ pub fn all(ex_tensor: ExTensor) -> Result { Ok(ExTensor::new(Tensor::new(bool_scalar, device)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn argmax(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result { + let t = + if keep_dim { + ex_tensor.argmax_keepdim(dim)? + } else { + ex_tensor.argmax(dim)? + }; + + Ok(ExTensor::new(t)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn broadcast_to(t: ExTensor, shape: Term) -> Result { Ok(ExTensor::new(t.broadcast_as(tuple_to_vec(shape).unwrap())?)) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 9e3224be3c..c2ee1045a5 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -453,6 +453,51 @@ defmodule CandlexTest do |> Nx.sqrt() |> assert_equal(t([1.0, 1.4142135381698608, 1.7320507764816284])) end + + test "argmax" do + Nx.argmax(4) + |> assert_equal(t(0)) + + # TODO: Support total argmax + # t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + # |> Nx.argmax() + # |> assert_equal(t(10)) + + # t([2.0, 4.0]) + # |> Nx.argmax() + # |> assert_equal(t(1)) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + |> Nx.argmax(axis: 0) + |> assert_equal(t( + [ + [1, 0, 0], + [1, 1, 0] + ] + )) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmax(axis: :z) + |> assert_equal(t( + [ + [0, 2], + [0, 1] + ] + )) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmax(axis: :y, keep_axis: true) + |> assert_equal(t( + [ + [ + [0, 0, 0] + ], + [ + [0, 1, 0] + ] + ] + )) + end end defp t(values, opts \\ []) do From bfdfd85aea113d0d7ef3e919c4532104a0a44bdb Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 14:02:22 -0300 Subject: [PATCH 080/185] argmin --- candlex/lib/candlex/backend.ex | 33 ++++++++++---------- candlex/lib/candlex/native.ex | 5 +++- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 12 ++++++++ candlex/test/candlex_test.exs | 43 ++++++++++++++++++++++++++- 5 files changed, 77 insertions(+), 17 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 2383263b3a..f1ed2efe85 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -99,23 +99,26 @@ defmodule Candlex.Backend do |> to_nx(out) end - @impl true - def argmax(%T{} = out, %T{shape: {}} = tensor, _opts) do - out - |> constant(0, []) - end - def argmax(%T{type: type} = out, %T{} = tensor, opts) do - axis = opts[:axis] || -1 - keep_axis = opts[:keep_axis] || false + for op <- [:argmax, :argmin] do + @impl true + def unquote(op)(%T{} = out, %T{shape: {}} = _tensor, _opts) do + out + |> constant(0, []) + end + def unquote(op)(%T{type: type} = out, %T{} = tensor, opts) do + axis = opts[:axis] || -1 + keep_axis = opts[:keep_axis] || false - from_nx(tensor) - |> Native.argmax(axis, keep_axis) - |> unwrap!() - # # candle argmax changes to u32 - |> Native.to_type(to_candle_dtype(type)) - |> unwrap!() - |> to_nx(out) + tensor + |> from_nx() + |> Native.unquote(op)(axis, keep_axis) + |> unwrap!() + # candle argmax/argmin changes to u32 + |> 
Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> to_nx(out) + end end # Element-wise diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 5ef072e55f..d5dde0859e 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -7,7 +7,6 @@ defmodule Candlex.Native do def from_binary(_binary, _dtype, _shape, _device), do: error() def to_binary(_tensor), do: error() def all(_tensor), do: error() - def argmax(_tensor, _dim, _keep_dim), do: error() def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def squeeze(_tensor, _dim), do: error() @@ -36,6 +35,10 @@ defmodule Candlex.Native do def unquote(op)(_left, _right), do: error() end + for op <- [:argmax, :argmin] do + def unquote(op)(_tensor, _dim, _keep_dim), do: error() + end + def is_cuda_available(), do: error() defp error(), do: :erlang.nif_error(:nif_not_loaded) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 84fbeaca11..f616f11303 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -33,6 +33,7 @@ rustler::init! { tensors::subtract, tensors::all, tensors::argmax, + tensors::argmin, tensors::negate, tensors::where_cond, tensors::narrow, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 9a23b170f1..0686159548 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -115,6 +115,18 @@ pub fn argmax(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result Result { + let t = + if keep_dim { + ex_tensor.argmin_keepdim(dim)? + } else { + ex_tensor.argmin(dim)? + }; + + Ok(ExTensor::new(t)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn broadcast_to(t: ExTensor, shape: Term) -> Result { Ok(ExTensor::new(t.broadcast_as(tuple_to_vec(shape).unwrap())?)) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index c2ee1045a5..c6b4f54cc1 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -458,7 +458,7 @@ defmodule CandlexTest do Nx.argmax(4) |> assert_equal(t(0)) - # TODO: Support total argmax + # TODO: Support argmax without specific axis # t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) # |> Nx.argmax() # |> assert_equal(t(10)) @@ -498,6 +498,47 @@ defmodule CandlexTest do ] )) end + + test "argmin" do + Nx.argmin(4) + |> assert_equal(t(0)) + + # TODO: Support argmin without specific axis + # t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + # |> Nx.argmin() + # |> assert_equal(t(4)) + + # t([2.0, 4.0]) + # |> Nx.argmin() + # |> assert_equal(t(0)) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + |> Nx.argmin(axis: 0) + |> assert_equal(t( + [ + [0, 0, 0], + [0, 0, 0] + ] + )) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmin(axis: 1) + |> assert_equal(t( + [ + [1, 1, 0], + [1, 0, 0] + ] + )) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmin(axis: :z) + |> assert_equal(t( + [ + [1, 1], + [1, 2] + ] + )) + end end defp t(values, opts \\ []) do From 6a68279f4d0f20f45138ddc48dfd3bcc827a59b8 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 16:41:54 -0300 Subject: [PATCH 081/185] acos --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/Cargo.lock | 1 + candlex/native/candlex/Cargo.toml | 1 + 
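One detail worth keeping in mind from the refactor above: candle's `argmax`/`argmin` always return `u32`, so the backend casts the result back to the requested Nx type with `Native.to_type/2`, and reductions over the whole tensor (no `:axis`) are still marked TODO in the tests. A usage sketch with values taken from those tests:

```elixir
t =
  Nx.tensor(
    [[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]],
    backend: Candlex.Backend
  )

Nx.argmax(t, axis: 0)
#=> [[1, 0, 0],
#    [1, 1, 0]]

Nx.argmin(t, axis: 1)
#=> [[1, 1, 0],
#    [1, 0, 0]]
```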
candlex/native/candlex/src/lib.rs | 2 ++ candlex/native/candlex/src/ops.rs | 36 +++++++++++++++++++++++++++ candlex/native/candlex/src/tensors.rs | 12 +++++++++ candlex/test/candlex_test.exs | 9 +++++++ 8 files changed, 63 insertions(+), 2 deletions(-) create mode 100644 candlex/native/candlex/src/ops.rs diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index f1ed2efe85..55989227f5 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -184,7 +184,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:abs, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] + unary_ops = [:abs, :acos, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index d5dde0859e..f7349477a3 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:abs, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] do + for op <- [:abs, :acos, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index fb062a12e4..163c1efa32 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -185,6 +185,7 @@ version = "0.1.0" dependencies = [ "candle-core", "half", + "num-traits", "rustler", "thiserror", ] diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 4a0bd9dc07..91c2a71661 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -12,5 +12,6 @@ crate-type = ["cdylib"] [dependencies] candle-core = { git = "https://github.com/huggingface/candle" } half = "2.3.1" +num-traits = "0.2.16" rustler = "0.29.1" thiserror = "1.0.47" diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index f616f11303..5921c2190e 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -7,6 +7,7 @@ mod atoms { mod devices; mod error; +mod ops; mod tensors; use rustler::{Env, Term}; @@ -45,6 +46,7 @@ rustler::init! { tensors::concatenate, tensors::matmul, tensors::abs, + tensors::acos, tensors::cos, tensors::sin, tensors::exp, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs new file mode 100644 index 0000000000..e6a95158fc --- /dev/null +++ b/candlex/native/candlex/src/ops.rs @@ -0,0 +1,36 @@ +use candle_core::{CpuStorage, CustomOp1, Error, Layout, Shape}; +use num_traits::Float; + +macro_rules! custom_unary_op { + ($struct_name:ident, $name:expr, $fn_name:ident) => { + pub(crate) struct $struct_name; + + fn $fn_name(value: T) -> T { + value.$fn_name() + } + + impl CustomOp1 for $struct_name { + // Box does not support const yet, so use a function to get the name. + fn name(&self) -> &'static str { + $name + } + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. 
+ fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape), candle_core::Error> { + use candle_core::backend::BackendStorage; + + let storage = candle_core::map_dtype!( + $name, + storage, + |vec| candle_core::cpu_backend::unary_map(vec, layout, |v| $fn_name(v)), + (BF16, F16, F32, F64) + ); + + Ok((storage, layout.shape().clone())) + } + } + } +} + +custom_unary_op!(Acos, "acos", acos); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 0686159548..3795aa50bd 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,4 +1,5 @@ use crate::atoms; +use crate::ops::{Acos}; use crate::error::CandlexError; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -183,6 +184,15 @@ macro_rules! binary_nif { }; } +macro_rules! custom_unary_nif { + ($nif_name:ident, $custom_op_name:ident) => { + #[rustler::nif(schedule = "DirtyCpu")] + pub fn $nif_name(ex_tensor: ExTensor) -> Result { + Ok(ExTensor::new(ex_tensor.apply_op1_no_bwd(&$custom_op_name)?)) + } + }; +} + unary_nif!(negate, neg); unary_nif!(abs); unary_nif!(cos); @@ -192,6 +202,8 @@ unary_nif!(log); unary_nif!(sqrt); unary_nif!(tanh); +custom_unary_nif!(acos, Acos); + binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); binary_nif!(multiply, broadcast_mul); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index c6b4f54cc1..c6bd2121e8 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -539,6 +539,15 @@ defmodule CandlexTest do ] )) end + + test "acos" do + Nx.acos(0.10000000149011612) + |> assert_equal(t(1.4706288576126099)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.acos() + |> assert_equal(t([1.4706288576126099, 1.0471975803375244, 0.4510268568992615])) + end end defp t(values, opts \\ []) do From 02fd594deb00e5b0968b9c97a443942a6d10fc4a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:23:59 -0300 Subject: [PATCH 082/185] asin --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 9 +++++++++ 6 files changed, 15 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 55989227f5..129740f3ae 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -184,7 +184,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:abs, :acos, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] + unary_ops = [:abs, :acos, :asin, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index f7349477a3..84a7cdac2c 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:abs, :acos, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] do + for op <- [:abs, :acos, :asin, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 5921c2190e..d36f6f1a16 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -47,6 
+47,7 @@ rustler::init! { tensors::matmul, tensors::abs, tensors::acos, + tensors::asin, tensors::cos, tensors::sin, tensors::exp, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index e6a95158fc..f08eaa6f20 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -34,3 +34,4 @@ macro_rules! custom_unary_op { } custom_unary_op!(Acos, "acos", acos); +custom_unary_op!(Asin, "asin", asin); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 3795aa50bd..215a90ef9b 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,5 +1,5 @@ use crate::atoms; -use crate::ops::{Acos}; +use crate::ops::{Acos, Asin}; use crate::error::CandlexError; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -203,6 +203,7 @@ unary_nif!(sqrt); unary_nif!(tanh); custom_unary_nif!(acos, Acos); +custom_unary_nif!(asin, Asin); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index c6bd2121e8..5a545ae53c 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -548,6 +548,15 @@ defmodule CandlexTest do |> Nx.acos() |> assert_equal(t([1.4706288576126099, 1.0471975803375244, 0.4510268568992615])) end + + test "asin" do + Nx.asin(0.10000000149011612) + |> assert_equal(t(0.1001674234867096)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.asin() + |> assert_equal(t([0.1001674234867096, 0.5235987901687622, 1.1197694540023804])) + end end defp t(values, opts \\ []) do From c636535627cf3969469e4d153fcf52dd9dd5100b Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:25:22 -0300 Subject: [PATCH 083/185] cargo fmt --- candlex/native/candlex/src/ops.rs | 8 ++++-- candlex/native/candlex/src/tensors.rs | 40 +++++++++++++++------------ 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index f08eaa6f20..32bc6e25a8 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -17,7 +17,11 @@ macro_rules! custom_unary_op { /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, /// offsets etc so the associated layout should be used to access it. - fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape), candle_core::Error> { + fn cpu_fwd( + &self, + storage: &CpuStorage, + layout: &Layout, + ) -> Result<(CpuStorage, Shape), candle_core::Error> { use candle_core::backend::BackendStorage; let storage = candle_core::map_dtype!( @@ -30,7 +34,7 @@ macro_rules! 
custom_unary_op { Ok((storage, layout.shape().clone())) } } - } + }; } custom_unary_op!(Acos, "acos", acos); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 215a90ef9b..65cdaeae9d 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; -use crate::ops::{Acos, Asin}; use crate::error::CandlexError; +use crate::ops::{Acos, Asin}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -34,14 +34,19 @@ impl Deref for ExTensor { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn from_binary(binary: Binary, dtype_str: &str, shape: Term, device: Atom) -> Result { +pub fn from_binary( + binary: Binary, + dtype_str: &str, + shape: Term, + device: Atom, +) -> Result { Ok(ExTensor::new(Tensor::from_raw_buffer( binary.as_slice(), // TODO: Handle DTypeParseError DType::from_str(dtype_str).unwrap(), // TODO: Handle rustler::Error &tuple_to_vec(shape).unwrap(), - &device_from_atom(device)? + &device_from_atom(device)?, )?)) } @@ -75,7 +80,7 @@ pub fn arange( end: i64, dtype_str: &str, shape: Term, - device: Atom + device: Atom, ) -> Result { Ok(ExTensor::new( Tensor::arange(start, end, &device_from_atom(device)?)? @@ -106,24 +111,22 @@ pub fn all(ex_tensor: ExTensor) -> Result { #[rustler::nif(schedule = "DirtyCpu")] pub fn argmax(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result { - let t = - if keep_dim { - ex_tensor.argmax_keepdim(dim)? - } else { - ex_tensor.argmax(dim)? - }; + let t = if keep_dim { + ex_tensor.argmax_keepdim(dim)? + } else { + ex_tensor.argmax(dim)? + }; Ok(ExTensor::new(t)) } #[rustler::nif(schedule = "DirtyCpu")] pub fn argmin(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result { - let t = - if keep_dim { - ex_tensor.argmin_keepdim(dim)? - } else { - ex_tensor.argmin(dim)? - }; + let t = if keep_dim { + ex_tensor.argmin_keepdim(dim)? + } else { + ex_tensor.argmin(dim)? + }; Ok(ExTensor::new(t)) } @@ -229,7 +232,10 @@ fn device_from_atom(atom: Atom) -> Result { // } else if atom == atoms::cuda() { // Ok(Device::new_cuda(0)?) 
} else { - Err(CandlexError::Other(format!("unsupported device {:?}", atom))) + Err(CandlexError::Other(format!( + "unsupported device {:?}", + atom + ))) } } From 89f29b7a12163ed28a890f2d8841b644c7391001 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:26:13 -0300 Subject: [PATCH 084/185] mix format --- candlex/lib/candlex/backend.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 129740f3ae..83fa15bb21 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -99,13 +99,13 @@ defmodule Candlex.Backend do |> to_nx(out) end - for op <- [:argmax, :argmin] do @impl true def unquote(op)(%T{} = out, %T{shape: {}} = _tensor, _opts) do out |> constant(0, []) end + def unquote(op)(%T{type: type} = out, %T{} = tensor, opts) do axis = opts[:axis] || -1 keep_axis = opts[:keep_axis] || false From 9d371eb06d3474fad0478bbc45026e6cf40f58e1 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:27:56 -0300 Subject: [PATCH 085/185] cargo update --- candlex/native/candlex/Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 163c1efa32..2a78aca544 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -38,7 +38,7 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.2.1" -source = "git+https://github.com/huggingface/candle#949f1eae6fbbda4dfa20743fe4214e5e1a4a5921" +source = "git+https://github.com/huggingface/candle#2c1df6bba1a2017b8b4aec87a725b5b06b48cdab" dependencies = [ "byteorder", "candle-gemm", From a254c1916f67595e02819c3ebdee87253e6c4126 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:38:10 -0300 Subject: [PATCH 086/185] tan --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 9 +++++++++ 6 files changed, 15 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 83fa15bb21..85afb5ebb4 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -184,7 +184,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:abs, :acos, :asin, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] + unary_ops = [:abs, :acos, :asin, :cos, :exp, :log, :negate, :sin, :sqrt, :tan, :tanh] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 84a7cdac2c..d77fb7b7f3 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:abs, :acos, :asin, :cos, :exp, :log, :negate, :sin, :sqrt, :tanh] do + for op <- [:abs, :acos, :asin, :cos, :exp, :log, :negate, :sin, :sqrt, :tan, :tanh] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index d36f6f1a16..f741919f8b 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -53,6 +53,7 @@ rustler::init! 
{ tensors::exp, tensors::log, tensors::sqrt, + tensors::tan, tensors::tanh, devices::is_cuda_available ], diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 32bc6e25a8..1dcf95a51b 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -39,3 +39,4 @@ macro_rules! custom_unary_op { custom_unary_op!(Acos, "acos", acos); custom_unary_op!(Asin, "asin", asin); +custom_unary_op!(Tan, "tan", tan); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 65cdaeae9d..8de7b9123c 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin}; +use crate::ops::{Acos, Asin, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -207,6 +207,7 @@ unary_nif!(tanh); custom_unary_nif!(acos, Acos); custom_unary_nif!(asin, Asin); +custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 5a545ae53c..fc0655779e 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -557,6 +557,15 @@ defmodule CandlexTest do |> Nx.asin() |> assert_equal(t([0.1001674234867096, 0.5235987901687622, 1.1197694540023804])) end + + test "tan" do + Nx.tan(1.0) + |> assert_equal(t(1.5574077367782593)) + + t([1.0, 2, 3]) + |> Nx.tan() + |> assert_equal(t([1.5574077367782593, -2.185039758682251, -0.14254654943943024])) + end end defp t(values, opts \\ []) do From 3c041423e2d8315276578d560efd8e4835126f26 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:47:23 -0300 Subject: [PATCH 087/185] atan --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 9 +++++++++ 6 files changed, 15 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 85afb5ebb4..c0dd521bae 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -184,7 +184,7 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:abs, :acos, :asin, :cos, :exp, :log, :negate, :sin, :sqrt, :tan, :tanh] + unary_ops = [:abs, :acos, :asin, :atan, :cos, :exp, :log, :negate, :sin, :sqrt, :tan, :tanh] for op <- unary_ops do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index d77fb7b7f3..df06e7e610 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,7 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:abs, :acos, :asin, :cos, :exp, :log, :negate, :sin, :sqrt, :tan, :tanh] do + for op <- [:abs, :acos, :asin, :atan, :cos, :exp, :log, :negate, :sin, :sqrt, :tan, :tanh] do def unquote(op)(_tensor), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index f741919f8b..26bfc4fd76 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -48,6 +48,7 @@ rustler::init! 
{ tensors::abs, tensors::acos, tensors::asin, + tensors::atan, tensors::cos, tensors::sin, tensors::exp, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 1dcf95a51b..26c4ade0f8 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -39,4 +39,5 @@ macro_rules! custom_unary_op { custom_unary_op!(Acos, "acos", acos); custom_unary_op!(Asin, "asin", asin); +custom_unary_op!(Atan, "atan", atan); custom_unary_op!(Tan, "tan", tan); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 8de7b9123c..02c066267d 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Tan}; +use crate::ops::{Acos, Asin, Atan, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -207,6 +207,7 @@ unary_nif!(tanh); custom_unary_nif!(acos, Acos); custom_unary_nif!(asin, Asin); +custom_unary_nif!(atan, Atan); custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index fc0655779e..ea7c880724 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -566,6 +566,15 @@ defmodule CandlexTest do |> Nx.tan() |> assert_equal(t([1.5574077367782593, -2.185039758682251, -0.14254654943943024])) end + + test "atan" do + Nx.atan(0.10000000149011612) + |> assert_equal(t(0.09966865181922913)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.atan() + |> assert_equal(t([0.09966865181922913, 0.46364760398864746, 0.7328150868415833])) + end end defp t(values, opts \\ []) do From 9f7db4fbbebab3f8b22fa3f4ed9b17eec0050f61 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:53:52 -0300 Subject: [PATCH 088/185] ceil/floor --- candlex/lib/candlex/backend.ex | 19 ++++++++++++++++--- candlex/lib/candlex/native.ex | 17 ++++++++++++++++- candlex/native/candlex/src/lib.rs | 2 ++ candlex/native/candlex/src/ops.rs | 2 ++ candlex/native/candlex/src/tensors.rs | 4 +++- candlex/test/candlex_test.exs | 20 ++++++++++++++++++++ 6 files changed, 59 insertions(+), 5 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c0dd521bae..7f8c4d6ee7 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -184,9 +184,22 @@ defmodule Candlex.Backend do # Unary ops - unary_ops = [:abs, :acos, :asin, :atan, :cos, :exp, :log, :negate, :sin, :sqrt, :tan, :tanh] - - for op <- unary_ops do + for op <- [ + :abs, + :acos, + :asin, + :atan, + :ceil, + :cos, + :exp, + :floor, + :log, + :negate, + :sin, + :sqrt, + :tan, + :tanh + ] do @impl true def unquote(op)(%T{} = out, %T{} = tensor) do tensor diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index df06e7e610..04ced9f01c 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -16,7 +16,22 @@ defmodule Candlex.Native do def to_type(_tensor, _dtype), do: error() def concatenate(_tensors, _axis), do: error() - for op <- [:abs, :acos, :asin, :atan, :cos, :exp, :log, :negate, :sin, :sqrt, :tan, :tanh] do + for op <- [ + :abs, + :acos, + :asin, + :atan, + :ceil, + :cos, + :exp, + :floor, + :log, + :negate, + :sin, + :sqrt, + :tan, + :tanh + ] do def unquote(op)(_tensor), do: error() end diff --git 
a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 26bfc4fd76..f495c725cb 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -49,9 +49,11 @@ rustler::init! { tensors::acos, tensors::asin, tensors::atan, + tensors::ceil, tensors::cos, tensors::sin, tensors::exp, + tensors::floor, tensors::log, tensors::sqrt, tensors::tan, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 26c4ade0f8..2204eac8c3 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -40,4 +40,6 @@ macro_rules! custom_unary_op { custom_unary_op!(Acos, "acos", acos); custom_unary_op!(Asin, "asin", asin); custom_unary_op!(Atan, "atan", atan); +custom_unary_op!(Ceil, "ceil", ceil); +custom_unary_op!(Floor, "floor", floor); custom_unary_op!(Tan, "tan", tan); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 02c066267d..4ff921c4e1 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, Tan}; +use crate::ops::{Acos, Asin, Atan, Ceil, Floor, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -208,6 +208,8 @@ unary_nif!(tanh); custom_unary_nif!(acos, Acos); custom_unary_nif!(asin, Asin); custom_unary_nif!(atan, Atan); +custom_unary_nif!(ceil, Ceil); +custom_unary_nif!(floor, Floor); custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index ea7c880724..72a49fe71f 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -575,6 +575,26 @@ defmodule CandlexTest do |> Nx.atan() |> assert_equal(t([0.09966865181922913, 0.46364760398864746, 0.7328150868415833])) end + + test "ceil" do + t([-1, 0, 1]) + |> Nx.ceil() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.ceil() + |> assert_equal(t([-1.0, 0.0, 1.0, 2.0])) + end + + test "floor" do + t([-1, 0, 1]) + |> Nx.floor() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.floor() + |> assert_equal(t([-2.0, -1.0, 0.0, 1.0])) + end end defp t(values, opts \\ []) do From 72929247a932d59918a1b7664ef94b20dc62b302 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 17:58:44 -0300 Subject: [PATCH 089/185] cbrt --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 9 +++++++++ 6 files changed, 15 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 7f8c4d6ee7..8db6445e97 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -189,6 +189,7 @@ defmodule Candlex.Backend do :acos, :asin, :atan, + :cbrt, :ceil, :cos, :exp, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 04ced9f01c..f6ba8a1bd2 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -21,6 +21,7 @@ defmodule Candlex.Native do :acos, :asin, :atan, + :cbrt, :ceil, :cos, :exp, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index f495c725cb..c090052b6b 100644 --- 
a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -49,6 +49,7 @@ rustler::init! { tensors::acos, tensors::asin, tensors::atan, + tensors::cbrt, tensors::ceil, tensors::cos, tensors::sin, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 2204eac8c3..3455f1cd09 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -40,6 +40,7 @@ macro_rules! custom_unary_op { custom_unary_op!(Acos, "acos", acos); custom_unary_op!(Asin, "asin", asin); custom_unary_op!(Atan, "atan", atan); +custom_unary_op!(Cbrt, "cbrt", cbrt); custom_unary_op!(Ceil, "ceil", ceil); custom_unary_op!(Floor, "floor", floor); custom_unary_op!(Tan, "tan", tan); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 4ff921c4e1..3e7c828cae 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, Ceil, Floor, Tan}; +use crate::ops::{Acos, Asin, Atan, Cbrt, Ceil, Floor, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -208,6 +208,7 @@ unary_nif!(tanh); custom_unary_nif!(acos, Acos); custom_unary_nif!(asin, Asin); custom_unary_nif!(atan, Atan); +custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(ceil, Ceil); custom_unary_nif!(floor, Floor); custom_unary_nif!(tan, Tan); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 72a49fe71f..ff93325623 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -595,6 +595,15 @@ defmodule CandlexTest do |> Nx.floor() |> assert_equal(t([-2.0, -1.0, 0.0, 1.0])) end + + test "cbrt" do + Nx.cbrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.cbrt() + |> assert_equal(t([1.0, 1.2599210739135742, 1.4422495365142822])) + end end defp t(values, opts \\ []) do From ad5b8c46a29af401917b02b7a8789a04fcb7c671 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:02:00 -0300 Subject: [PATCH 090/185] round --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 10 ++++++++++ 6 files changed, 16 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 8db6445e97..0ebd7af756 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -196,6 +196,7 @@ defmodule Candlex.Backend do :floor, :log, :negate, + :round, :sin, :sqrt, :tan, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index f6ba8a1bd2..9620942000 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -28,6 +28,7 @@ defmodule Candlex.Native do :floor, :log, :negate, + :round, :sin, :sqrt, :tan, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index c090052b6b..3ca9844edc 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -55,6 +55,7 @@ rustler::init! 
{ tensors::sin, tensors::exp, tensors::floor, + tensors::round, tensors::log, tensors::sqrt, tensors::tan, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 3455f1cd09..60f4653a91 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -43,4 +43,5 @@ custom_unary_op!(Atan, "atan", atan); custom_unary_op!(Cbrt, "cbrt", cbrt); custom_unary_op!(Ceil, "ceil", ceil); custom_unary_op!(Floor, "floor", floor); +custom_unary_op!(Round, "round", round); custom_unary_op!(Tan, "tan", tan); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 3e7c828cae..2a97cf2af6 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, Cbrt, Ceil, Floor, Tan}; +use crate::ops::{Acos, Asin, Atan, Cbrt, Ceil, Floor, Round, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -211,6 +211,7 @@ custom_unary_nif!(atan, Atan); custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(ceil, Ceil); custom_unary_nif!(floor, Floor); +custom_unary_nif!(round, Round); custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index ff93325623..bf08c04242 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -596,6 +596,16 @@ defmodule CandlexTest do |> assert_equal(t([-2.0, -1.0, 0.0, 1.0])) end + test "round" do + t([-1, 0, 1]) + |> Nx.round() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.round() + |> assert_equal(t([-2.0, -1.0, 1.0, 2.0])) + end + test "cbrt" do Nx.cbrt(1.0) |> assert_equal(t(1.0)) From 587a41008f057d0429206485f34ebe962f5f9b49 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 1 Sep 2023 18:06:22 -0300 Subject: [PATCH 091/185] log1p --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 9 +++++++++ 6 files changed, 15 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 0ebd7af756..bb12bb90ea 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -195,6 +195,7 @@ defmodule Candlex.Backend do :exp, :floor, :log, + :log1p, :negate, :round, :sin, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 9620942000..abff3401ef 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -27,6 +27,7 @@ defmodule Candlex.Native do :exp, :floor, :log, + :log1p, :negate, :round, :sin, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 3ca9844edc..9a85bbc0ff 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -57,6 +57,7 @@ rustler::init! 
{ tensors::floor, tensors::round, tensors::log, + tensors::log1p, tensors::sqrt, tensors::tan, tensors::tanh, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 60f4653a91..ed6348fcff 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -43,5 +43,6 @@ custom_unary_op!(Atan, "atan", atan); custom_unary_op!(Cbrt, "cbrt", cbrt); custom_unary_op!(Ceil, "ceil", ceil); custom_unary_op!(Floor, "floor", floor); +custom_unary_op!(Log1p, "ln_1p", ln_1p); custom_unary_op!(Round, "round", round); custom_unary_op!(Tan, "tan", tan); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 2a97cf2af6..484a5c4500 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, Cbrt, Ceil, Floor, Round, Tan}; +use crate::ops::{Acos, Asin, Atan, Cbrt, Ceil, Floor, Log1p, Round, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -211,6 +211,7 @@ custom_unary_nif!(atan, Atan); custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(ceil, Ceil); custom_unary_nif!(floor, Floor); +custom_unary_nif!(log1p, Log1p); custom_unary_nif!(round, Round); custom_unary_nif!(tan, Tan); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index bf08c04242..d088a55bf9 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -614,6 +614,15 @@ defmodule CandlexTest do |> Nx.cbrt() |> assert_equal(t([1.0, 1.2599210739135742, 1.4422495365142822])) end + + test "log1p" do + Nx.log1p(1.0) + |> assert_equal(t(0.6931471824645996)) + + t([1.0, 2, 3]) + |> Nx.log1p() + |> assert_equal(t([0.6931471824645996, 1.0986123085021973, 1.3862943649291992])) + end end defp t(values, opts \\ []) do From 42ef7b7983358cc445a948be210dca406f03b241 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 4 Sep 2023 12:59:20 -0300 Subject: [PATCH 092/185] bitwise_and/or/xor as CustomOp2 --- candlex/lib/candlex/backend.ex | 4 +- candlex/lib/candlex/native.ex | 3 ++ candlex/native/candlex/Cargo.lock | 2 +- candlex/native/candlex/Cargo.toml | 4 +- candlex/native/candlex/src/lib.rs | 3 ++ candlex/native/candlex/src/ops.rs | 53 ++++++++++++++++++++++++++- candlex/native/candlex/src/tensors.rs | 15 +++++++- candlex/test/candlex_test.exs | 51 ++++++++++++++++++++++++++ 8 files changed, 129 insertions(+), 6 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index bb12bb90ea..aed6d86efa 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -123,7 +123,7 @@ defmodule Candlex.Backend do # Element-wise - for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :left_shift, :right_shift] do + for op <- [:left_shift, :right_shift] do @impl true def unquote(op)(out, l, r) do unsupported_op(unquote(op)) @@ -169,7 +169,7 @@ defmodule Candlex.Backend do end end - for op <- [:equal, :greater_equal, :less, :less_equal] do + for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :equal, :greater_equal, :less, :less_equal] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index abff3401ef..46f303efa3 100644 --- a/candlex/lib/candlex/native.ex +++ 
b/candlex/lib/candlex/native.ex @@ -40,6 +40,9 @@ defmodule Candlex.Native do for op <- [ :add, + :bitwise_and, + :bitwise_or, + :bitwise_xor, :equal, :greater_equal, :less, diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 2a78aca544..e75351215b 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -38,7 +38,7 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.2.1" -source = "git+https://github.com/huggingface/candle#2c1df6bba1a2017b8b4aec87a725b5b06b48cdab" +source = "git+https://github.com/mimiquate/candle?branch=custom-op-2#d3dbcc513c89dc6023cbbe558d3e77e9fce433d2" dependencies = [ "byteorder", "candle-gemm", diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 91c2a71661..839a275879 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -10,7 +10,9 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -candle-core = { git = "https://github.com/huggingface/candle" } +# TODO: Uncomment back when https://github.com/huggingface/candle/pull/741 merged +# candle-core = { git = "https://github.com/huggingface/candle" } +candle-core = { git = "https://github.com/mimiquate/candle", branch = "custom-op-2" } half = "2.3.1" num-traits = "0.2.16" rustler = "0.29.1" diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 9a85bbc0ff..eda48236e1 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -61,6 +61,9 @@ rustler::init! { tensors::sqrt, tensors::tan, tensors::tanh, + tensors::bitwise_and, + tensors::bitwise_or, + tensors::bitwise_xor, devices::is_cuda_available ], load = load diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index ed6348fcff..153eef204f 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -1,4 +1,4 @@ -use candle_core::{CpuStorage, CustomOp1, Error, Layout, Shape}; +use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape}; use num_traits::Float; macro_rules! custom_unary_op { @@ -37,6 +37,53 @@ macro_rules! custom_unary_op { }; } +macro_rules! custom_binary_op { + ($struct_name:ident, $name:literal, $closure:expr) => { + pub(crate) struct $struct_name; + + impl CustomOp2 for $struct_name { + fn name(&self) -> &'static str { + $name + } + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. 
+ fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + ) -> Result<(CpuStorage, Shape), candle_core::Error> { + use candle_core::backend::BackendStorage; + + // let storage = candle_core::map_dtype!( + // $name, + // s1, + // |vec1| candle_core::cpu_backend::binary_map(l1, l2, vec1, s2, |v1, v2| op_wrapper(v1, v2)), + // (U8, U32, I64) + // ); + + match (s1, s2) { + (CpuStorage::I64(lhs), CpuStorage::I64(rhs)) => { + let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure); + + Ok((CpuStorage::I64(data), l1.shape().clone())) + } + _ => { + Err(Error::DTypeMismatchBinaryOp { + lhs: s1.dtype(), + rhs: s2.dtype(), + op: self.name(), + } + .bt()) + } + } + } + } + } +} + custom_unary_op!(Acos, "acos", acos); custom_unary_op!(Asin, "asin", asin); custom_unary_op!(Atan, "atan", atan); @@ -46,3 +93,7 @@ custom_unary_op!(Floor, "floor", floor); custom_unary_op!(Log1p, "ln_1p", ln_1p); custom_unary_op!(Round, "round", round); custom_unary_op!(Tan, "tan", tan); + +custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2); +custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2); +custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 484a5c4500..538816526c 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, Cbrt, Ceil, Floor, Log1p, Round, Tan}; +use crate::ops::{Acos, Asin, Atan, BitAnd, BitOr, BitXor, Cbrt, Ceil, Floor, Log1p, Round, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -196,6 +196,15 @@ macro_rules! custom_unary_nif { }; } +macro_rules! custom_binary_nif { + ($nif_name:ident, $custom_op_name:ident) => { + #[rustler::nif(schedule = "DirtyCpu")] + pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new(left.apply_op2_no_bwd(right.deref(), &$custom_op_name)?)) + } + }; +} + unary_nif!(negate, neg); unary_nif!(abs); unary_nif!(cos); @@ -226,6 +235,10 @@ binary_nif!(less, lt); binary_nif!(less_equal, le); binary_nif!(matmul, broadcast_matmul); +custom_binary_nif!(bitwise_and, BitAnd); +custom_binary_nif!(bitwise_or, BitOr); +custom_binary_nif!(bitwise_xor, BitXor); + fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok(rustler::types::tuple::get_tuple(term)? 
.iter() diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index d088a55bf9..ca209b65da 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -623,6 +623,57 @@ defmodule CandlexTest do |> Nx.log1p() |> assert_equal(t([0.6931471824645996, 1.0986123085021973, 1.3862943649291992])) end + + test "bitwise_and" do + Nx.bitwise_and(1, 0) + |> assert_equal(t(0)) + + t([0, 1, 2]) + |> Nx.bitwise_and(1) + |> assert_equal(t([0, 1, 0])) + + t([0, -1, -2]) + |> Nx.bitwise_and(-1) + |> assert_equal(t([0, -1, -2])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_and(t([0, 1, 0, 1])) + |> assert_equal(t([0, 0, 0, 1])) + end + + test "bitwise_or" do + Nx.bitwise_or(1, 0) + |> assert_equal(t(1)) + + t([0, 1, 2]) + |> Nx.bitwise_or(1) + |> assert_equal(t([1, 1, 3])) + + t([0, -1, -2]) + |> Nx.bitwise_or(-1) + |> assert_equal(t([-1, -1, -1])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_or(t([0, 1, 0, 1])) + |> assert_equal(t([0, 1, 1, 1])) + end + + test "bitwise_xor" do + Nx.bitwise_xor(1, 0) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.bitwise_xor(2) + |> assert_equal(t([3, 0, 1])) + + t([-1, -2, -3]) + |> Nx.bitwise_xor(2) + |> assert_equal(t([-3, -4, -1])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_xor(t([0, 1, 0, 1])) + |> assert_equal(t([0, 1, 1, 0])) + end end defp t(values, opts \\ []) do From c9987e01d113215811cae435cbb42bf9009f7f66 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 4 Sep 2023 16:59:37 -0300 Subject: [PATCH 093/185] candlex right/left_shift --- candlex/lib/candlex/backend.ex | 9 +-------- candlex/lib/candlex/native.ex | 2 ++ candlex/native/candlex/src/lib.rs | 2 ++ candlex/native/candlex/src/ops.rs | 2 ++ candlex/native/candlex/src/tensors.rs | 4 +++- candlex/test/candlex_test.exs | 26 ++++++++++++++++++++++++++ 6 files changed, 36 insertions(+), 9 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index aed6d86efa..4fa2708fb9 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -123,13 +123,6 @@ defmodule Candlex.Backend do # Element-wise - for op <- [:left_shift, :right_shift] do - @impl true - def unquote(op)(out, l, r) do - unsupported_op(unquote(op)) - end - end - @impl true def select(%T{shape: shape, type: type} = out, pred, on_true, on_false) do on_true = @@ -169,7 +162,7 @@ defmodule Candlex.Backend do end end - for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :equal, :greater_equal, :less, :less_equal] do + for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :equal, :greater_equal, :left_shift, :less, :less_equal, :right_shift] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 46f303efa3..c71984f365 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -45,12 +45,14 @@ defmodule Candlex.Native do :bitwise_xor, :equal, :greater_equal, + :left_shift, :less, :less_equal, :matmul, :max, :min, :multiply, + :right_shift, :subtract ] do def unquote(op)(_left, _right), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index eda48236e1..f15e4ec0e2 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -64,6 +64,8 @@ rustler::init! 
{ tensors::bitwise_and, tensors::bitwise_or, tensors::bitwise_xor, + tensors::left_shift, + tensors::right_shift, devices::is_cuda_available ], load = load diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 153eef204f..f25b884d97 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -97,3 +97,5 @@ custom_unary_op!(Tan, "tan", tan); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2); custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2); +custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2); +custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 538816526c..4044285f7d 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, BitAnd, BitOr, BitXor, Cbrt, Ceil, Floor, Log1p, Round, Tan}; +use crate::ops::{Acos, Asin, Atan, BitAnd, BitOr, BitXor, Cbrt, Ceil, Floor, Shl, Shr, Log1p, Round, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -238,6 +238,8 @@ binary_nif!(matmul, broadcast_matmul); custom_binary_nif!(bitwise_and, BitAnd); custom_binary_nif!(bitwise_or, BitOr); custom_binary_nif!(bitwise_xor, BitXor); +custom_binary_nif!(left_shift, Shl); +custom_binary_nif!(right_shift, Shr); fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok(rustler::types::tuple::get_tuple(term)? diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index ca209b65da..66dfcc0b2e 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -674,6 +674,32 @@ defmodule CandlexTest do |> Nx.bitwise_xor(t([0, 1, 0, 1])) |> assert_equal(t([0, 1, 1, 0])) end + + test "left_shift" do + Nx.left_shift(1, 0) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.left_shift(2) + |> assert_equal(t([4, 8, 12])) + + t([1, 1, -1, -1]) + |> Nx.left_shift(t([1, 2, 3, 4])) + |> assert_equal(t([2, 4, -8, -16])) + end + + test "right_shift" do + Nx.right_shift(1, 0) + |> assert_equal(t(1)) + + t([2, 4, 8]) + |> Nx.right_shift(2) + |> assert_equal(t([0, 1, 2])) + + t([16, 32, -64, -128]) + |> Nx.right_shift(t([1, 2, 3, 4])) + |> assert_equal(t([8, 8, -8, -8])) + end end defp t(values, opts \\ []) do From 932375498a5877468ab6f68b41aa498a0c99e1fb Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 4 Sep 2023 17:14:24 -0300 Subject: [PATCH 094/185] support u32 for custom binary ops --- candlex/native/candlex/src/ops.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index f25b884d97..ffb786fd0e 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -65,6 +65,11 @@ macro_rules! 
custom_binary_op { // ); match (s1, s2) { + (CpuStorage::U32(lhs), CpuStorage::U32(rhs)) => { + let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure); + + Ok((CpuStorage::U32(data), l1.shape().clone())) + } (CpuStorage::I64(lhs), CpuStorage::I64(rhs)) => { let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure); From 23b8a068d9a4f3bec600d7f886fa229435642547 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 4 Sep 2023 17:36:29 -0300 Subject: [PATCH 095/185] removes repeated/commented tests --- candlex/test/candlex_test.exs | 56 ----------------------------------- 1 file changed, 56 deletions(-) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 66dfcc0b2e..f8057dcabd 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -171,62 +171,6 @@ defmodule CandlexTest do ) end - # test "bitwise_and" do - # Nx.bitwise_and(1, 0) - # |> assert_equal(t(0)) - - # t([0, 1, 2]) - # |> Nx.bitwise_and(1) - # |> assert_equal(t([0, 1, 0])) - - # t([0, -1, -2]) - # |> Nx.bitwise_and(-1) - # |> assert_equal(t([0, -1, -2])) - - # t([0, 0, 1, 1]) - # |> Nx.bitwise_and(t([0, 1, 0, 1])) - # |> assert_equal(t([0, 0, 0, 1])) - # end - - # test "bitwise_not" do - # Nx.bitwise_not(1) - # |> assert_equal(t(-2)) - # end - - # test "bitwise_or" do - # Nx.bitwise_or(1, 0) - # |> assert_equal(t(1)) - - # t([0, 1, 2]) - # |> Nx.bitwise_or(1) - # |> assert_equal(t([1, 1, 3])) - - # t([0, -1, -2]) - # |> Nx.bitwise_or(-1) - # |> assert_equal(t([-1, -1, -1])) - - # t([0, 0, 1, 1]) - # |> Nx.bitwise_or(t([0, 1, 0, 1])) - # |> assert_equal(t([0, 1, 1, 1])) - # end - - # test "bitwise_xor" do - # Nx.bitwise_xor(1, 0) - # |> assert_equal(t(1)) - - # t([1, 2, 3]) - # |> Nx.bitwise_xor(2) - # |> assert_equal(t([3, 0, 1])) - - # t([-1, -2, -3]) - # |> Nx.bitwise_xor(2) - # |> assert_equal(t([-3, -4, -1])) - - # t([0, 0, 1, 1]) - # |> Nx.bitwise_xor(t([0, 1, 0, 1])) - # |> assert_equal(t([0, 1, 1, 0])) - # end - test "less" do Nx.less(1, 2) |> assert_equal(t(1)) From 7941d81810d2342d379d0ab78fee1c90452f32ad Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 4 Sep 2023 17:38:46 -0300 Subject: [PATCH 096/185] candlex bitwise_not --- candlex/lib/candlex/backend.ex | 3 +- candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 52 +++++++++++++++++++++++++++ candlex/native/candlex/src/tensors.rs | 3 +- candlex/test/candlex_test.exs | 13 +++++++ 6 files changed, 71 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 4fa2708fb9..711a6447aa 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -182,6 +182,7 @@ defmodule Candlex.Backend do :acos, :asin, :atan, + :bitwise_not, :cbrt, :ceil, :cos, @@ -206,7 +207,7 @@ defmodule Candlex.Backend do end end - for op <- [:bitwise_not, :erf_inv] do + for op <- [:erf_inv] do @impl true def unquote(op)(out, t) do unsupported_op(unquote(op)) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index c71984f365..1825c08fb8 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -21,6 +21,7 @@ defmodule Candlex.Native do :acos, :asin, :atan, + :bitwise_not, :cbrt, :ceil, :cos, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index f15e4ec0e2..eaa3050b96 100644 --- a/candlex/native/candlex/src/lib.rs +++ 
b/candlex/native/candlex/src/lib.rs @@ -61,6 +61,7 @@ rustler::init! { tensors::sqrt, tensors::tan, tensors::tanh, + tensors::bitwise_not, tensors::bitwise_and, tensors::bitwise_or, tensors::bitwise_xor, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index ffb786fd0e..1cdce39952 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -37,6 +37,57 @@ macro_rules! custom_unary_op { }; } +macro_rules! custom_unary_op_closure { + ($struct_name:ident, $name:expr, $closure:expr) => { + pub(crate) struct $struct_name; + + impl CustomOp1 for $struct_name { + // Box does not support const yet, so use a function to get the name. + fn name(&self) -> &'static str { + $name + } + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd( + &self, + storage: &CpuStorage, + layout: &Layout, + ) -> Result<(CpuStorage, Shape), candle_core::Error> { + use candle_core::backend::BackendStorage; + + // TODO: Find a way to make map_dtype! play well with inferred closure + // params types? + // + // let storage = candle_core::map_dtype!( + // $name, + // storage, + // |vec| candle_core::cpu_backend::unary_map(vec, layout, $closure), + // (U8, U32, I64) + // ); + + // Ok((storage, layout.shape().clone())) + + match storage { + CpuStorage::U8(vec) => { + let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); + Ok((CpuStorage::U8(data), layout.shape().clone())) + } + CpuStorage::U32(vec) => { + let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); + Ok((CpuStorage::U32(data), layout.shape().clone())) + } + CpuStorage::I64(vec) => { + let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); + Ok((CpuStorage::I64(data), layout.shape().clone())) + } + s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? + } + } + } + }; +} + macro_rules! 
custom_binary_op { ($struct_name:ident, $name:literal, $closure:expr) => { pub(crate) struct $struct_name; @@ -98,6 +149,7 @@ custom_unary_op!(Floor, "floor", floor); custom_unary_op!(Log1p, "ln_1p", ln_1p); custom_unary_op!(Round, "round", round); custom_unary_op!(Tan, "tan", tan); +custom_unary_op_closure!(BitNot, "bit_not", |v| !v); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 4044285f7d..6040759ffa 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, BitAnd, BitOr, BitXor, Cbrt, Ceil, Floor, Shl, Shr, Log1p, Round, Tan}; +use crate::ops::{Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, Floor, Shl, Shr, Log1p, Round, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -217,6 +217,7 @@ unary_nif!(tanh); custom_unary_nif!(acos, Acos); custom_unary_nif!(asin, Asin); custom_unary_nif!(atan, Atan); +custom_unary_nif!(bitwise_not, BitNot); custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(ceil, Ceil); custom_unary_nif!(floor, Floor); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index f8057dcabd..6c4eb63801 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -619,6 +619,19 @@ defmodule CandlexTest do |> assert_equal(t([0, 1, 1, 0])) end + test "bitwise_not" do + Nx.bitwise_not(1) + |> assert_equal(t(-2)) + + t([-1, 0, 1]) + |> Nx.bitwise_not() + |> assert_equal(t([0, -1, -2])) + + t([0, 1, 254, 255], type: :u8) + |> Nx.bitwise_not() + |> assert_equal(t([255, 254, 1, 0])) + end + test "left_shift" do Nx.left_shift(1, 0) |> assert_equal(t(1)) From e2415f6def49039fa5087e1bcfe179eb453fa7ce Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 5 Sep 2023 09:24:42 -0300 Subject: [PATCH 097/185] point back to candle upstream --- candlex/native/candlex/Cargo.lock | 2 +- candlex/native/candlex/Cargo.toml | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index e75351215b..02b38b57b1 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -38,7 +38,7 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.2.1" -source = "git+https://github.com/mimiquate/candle?branch=custom-op-2#d3dbcc513c89dc6023cbbe558d3e77e9fce433d2" +source = "git+https://github.com/huggingface/candle#a8410bf35ea3ad8eb973f48d301e65309d232377" dependencies = [ "byteorder", "candle-gemm", diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 839a275879..91c2a71661 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -10,9 +10,7 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -# TODO: Uncomment back when https://github.com/huggingface/candle/pull/741 merged -# candle-core = { git = "https://github.com/huggingface/candle" } -candle-core = { git = "https://github.com/mimiquate/candle", branch = "custom-op-2" } +candle-core = { git = "https://github.com/huggingface/candle" } half = "2.3.1" num-traits = "0.2.16" rustler = "0.29.1" From 
2c650b50aa451a8e1857ee8fb2324c3002e1ca7d Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 5 Sep 2023 17:45:03 -0300 Subject: [PATCH 098/185] candlex check nx vs candle dtype --- candlex/lib/candlex/backend.ex | 16 +++++++++++++++- candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 +++++ 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 711a6447aa..3e2c07e25c 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -428,7 +428,13 @@ defmodule Candlex.Backend do |> from_nx() end - defp to_nx(%{resource: ref} = backend_tensor, %T{} = t) when is_reference(ref) do + defp to_nx(%CB{resource: ref} = backend_tensor, %T{type: nx_type} = t) when is_reference(ref) do + {:ok, candle_dtype} = Native.dtype(backend_tensor) + + if nx_type != from_candle_dtype(candle_dtype) do + raise "tensor type mismatch, Nx (#{inspect(nx_type)}) and Candle (#{inspect(candle_dtype)})" + end + %{t | data: backend_tensor} end @@ -447,6 +453,14 @@ defmodule Candlex.Backend do defp to_candle_dtype({:c, 64}), do: unsupported_dtype() defp to_candle_dtype({:c, 128}), do: unsupported_dtype() + defp from_candle_dtype("i64"), do: {:s, 64} + defp from_candle_dtype("u8"), do: {:u, 8} + defp from_candle_dtype("u32"), do: {:u, 32} + defp from_candle_dtype("f16"), do: {:f, 16} + defp from_candle_dtype("bf16"), do: {:bf, 16} + defp from_candle_dtype("f32"), do: {:f, 32} + defp from_candle_dtype("f64"), do: {:f, 64} + defp device_option(nil) do default_device() end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 1825c08fb8..0d687ed2bf 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -14,6 +14,7 @@ defmodule Candlex.Native do def broadcast_to(_tensor, _shape), do: error() def reshape(_tensor, _shape), do: error() def to_type(_tensor, _dtype), do: error() + def dtype(_tensor), do: error() def concatenate(_tensors, _axis), do: error() for op <- [ diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index eaa3050b96..17fb27a750 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -33,6 +33,7 @@ rustler::init! 
{ tensors::less_equal, tensors::subtract, tensors::all, + tensors::dtype, tensors::argmax, tensors::argmin, tensors::negate, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 6040759ffa..d01091399a 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -157,6 +157,11 @@ pub fn to_type(t: ExTensor, dtype_str: &str) -> Result { )) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn dtype(t: ExTensor) -> Result<&'static str, CandlexError> { + Ok(t.dtype().as_str()) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn concatenate(ex_tensors: Vec, dim: usize) -> Result { let tensors = ex_tensors From b467c727c38695877e68fbfe0a6e2dd55a4341cb Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:24:15 -0300 Subject: [PATCH 099/185] candlex is_infinity --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 34 +++++++++++++++++++++++++++ candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 15 ++++++++++++ 6 files changed, 54 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 3e2c07e25c..de7bf0238f 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -188,6 +188,7 @@ defmodule Candlex.Backend do :cos, :exp, :floor, + :is_infinity, :log, :log1p, :negate, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 0d687ed2bf..20d271eb6a 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -28,6 +28,7 @@ defmodule Candlex.Native do :cos, :exp, :floor, + :is_infinity, :log, :log1p, :negate, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 17fb27a750..497294bf78 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -56,6 +56,7 @@ rustler::init! { tensors::sin, tensors::exp, tensors::floor, + tensors::is_infinity, tensors::round, tensors::log, tensors::log1p, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 1cdce39952..7b8869f643 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -37,6 +37,39 @@ macro_rules! custom_unary_op { }; } +macro_rules! custom_unary_bool_op { + ($struct_name:ident, $name:expr, $fn_name:ident, ($($dtypes:ident),+)) => { + pub(crate) struct $struct_name; + + impl CustomOp1 for $struct_name { + // Box does not support const yet, so use a function to get the name. + fn name(&self) -> &'static str { + $name + } + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. + fn cpu_fwd( + &self, + storage: &CpuStorage, + layout: &Layout, + ) -> Result<(CpuStorage, Shape), candle_core::Error> { + use candle_core::backend::BackendStorage; + + match storage { + $( + CpuStorage::$dtypes(vec) => { + let data = candle_core::cpu_backend::unary_map(vec, layout, |v| u8::from(v.$fn_name())); + Ok((CpuStorage::U8(data), layout.shape().clone())) + } + )* + s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? + } + } + } + }; +} + macro_rules! 
custom_unary_op_closure { ($struct_name:ident, $name:expr, $closure:expr) => { pub(crate) struct $struct_name; @@ -149,6 +182,7 @@ custom_unary_op!(Floor, "floor", floor); custom_unary_op!(Log1p, "ln_1p", ln_1p); custom_unary_op!(Round, "round", round); custom_unary_op!(Tan, "tan", tan); +custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); custom_unary_op_closure!(BitNot, "bit_not", |v| !v); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index d01091399a..3d2d3282cf 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,6 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, Floor, Shl, Shr, Log1p, Round, Tan}; +use crate::ops::{Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, Floor, IsInf, Shl, Shr, Log1p, Round, Tan}; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -226,6 +226,7 @@ custom_unary_nif!(bitwise_not, BitNot); custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(ceil, Ceil); custom_unary_nif!(floor, Floor); +custom_unary_nif!(is_infinity, IsInf); custom_unary_nif!(log1p, Log1p); custom_unary_nif!(round, Round); custom_unary_nif!(tan, Tan); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 6c4eb63801..a4629524b9 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -657,6 +657,21 @@ defmodule CandlexTest do |> Nx.right_shift(t([1, 2, 3, 4])) |> assert_equal(t([8, 8, -8, -8])) end + + test "is_infinity" do + t([:infinity, :nan, :neg_infinity, 1, 0]) + |> Nx.is_infinity() + |> assert_equal(t([1, 0, 1, 0, 0])) + + t([:infinity, 1, :neg_infinity]) + |> Nx.is_infinity() + |> assert_equal(t([1, 0, 1])) + + # TODO: Not supported for :s64 + # t([1, 0]) + # |> Nx.is_infinity() + # |> assert_equal(t([0, 0])) + end end defp t(values, opts \\ []) do From ab244b834af3f3a5a327330eb75f316649ac9528 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:30:28 -0300 Subject: [PATCH 100/185] more flexible op macros --- candlex/native/candlex/src/ops.rs | 109 +++++++++++------------------- 1 file changed, 39 insertions(+), 70 deletions(-) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 7b8869f643..5f814ff978 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -2,13 +2,9 @@ use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape}; use num_traits::Float; macro_rules! custom_unary_op { - ($struct_name:ident, $name:expr, $fn_name:ident) => { + ($struct_name:ident, $name:expr, $fn_name:ident, ($($dtypes:ident),+)) => { pub(crate) struct $struct_name; - fn $fn_name(value: T) -> T { - value.$fn_name() - } - impl CustomOp1 for $struct_name { // Box does not support const yet, so use a function to get the name. fn name(&self) -> &'static str { @@ -24,14 +20,15 @@ macro_rules! 
custom_unary_op { ) -> Result<(CpuStorage, Shape), candle_core::Error> { use candle_core::backend::BackendStorage; - let storage = candle_core::map_dtype!( - $name, - storage, - |vec| candle_core::cpu_backend::unary_map(vec, layout, |v| $fn_name(v)), - (BF16, F16, F32, F64) - ); - - Ok((storage, layout.shape().clone())) + match storage { + $( + CpuStorage::$dtypes(vec) => { + let data = candle_core::cpu_backend::unary_map(vec, layout, |v| v.$fn_name()); + Ok((CpuStorage::$dtypes(data), layout.shape().clone())) + } + )* + s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? + } } } }; @@ -71,7 +68,7 @@ macro_rules! custom_unary_bool_op { } macro_rules! custom_unary_op_closure { - ($struct_name:ident, $name:expr, $closure:expr) => { + ($struct_name:ident, $name:expr, $closure:expr, ($($dtypes:ident),+)) => { pub(crate) struct $struct_name; impl CustomOp1 for $struct_name { @@ -89,31 +86,13 @@ macro_rules! custom_unary_op_closure { ) -> Result<(CpuStorage, Shape), candle_core::Error> { use candle_core::backend::BackendStorage; - // TODO: Find a way to make map_dtype! play well with inferred closure - // params types? - // - // let storage = candle_core::map_dtype!( - // $name, - // storage, - // |vec| candle_core::cpu_backend::unary_map(vec, layout, $closure), - // (U8, U32, I64) - // ); - - // Ok((storage, layout.shape().clone())) - match storage { - CpuStorage::U8(vec) => { - let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); - Ok((CpuStorage::U8(data), layout.shape().clone())) - } - CpuStorage::U32(vec) => { - let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); - Ok((CpuStorage::U32(data), layout.shape().clone())) - } - CpuStorage::I64(vec) => { - let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); - Ok((CpuStorage::I64(data), layout.shape().clone())) - } + $( + CpuStorage::$dtypes(vec) => { + let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); + Ok((CpuStorage::$dtypes(data), layout.shape().clone())) + } + )* s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? } } @@ -122,7 +101,7 @@ macro_rules! custom_unary_op_closure { } macro_rules! custom_binary_op { - ($struct_name:ident, $name:literal, $closure:expr) => { + ($struct_name:ident, $name:literal, $closure:expr, ($($dtypes:ident),+)) => { pub(crate) struct $struct_name; impl CustomOp2 for $struct_name { @@ -141,24 +120,14 @@ macro_rules! custom_binary_op { ) -> Result<(CpuStorage, Shape), candle_core::Error> { use candle_core::backend::BackendStorage; - // let storage = candle_core::map_dtype!( - // $name, - // s1, - // |vec1| candle_core::cpu_backend::binary_map(l1, l2, vec1, s2, |v1, v2| op_wrapper(v1, v2)), - // (U8, U32, I64) - // ); - match (s1, s2) { - (CpuStorage::U32(lhs), CpuStorage::U32(rhs)) => { - let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure); - - Ok((CpuStorage::U32(data), l1.shape().clone())) - } - (CpuStorage::I64(lhs), CpuStorage::I64(rhs)) => { - let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure); + $( + (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => { + let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure); - Ok((CpuStorage::I64(data), l1.shape().clone())) - } + Ok((CpuStorage::$dtypes(data), l1.shape().clone())) + } + )* _ => { Err(Error::DTypeMismatchBinaryOp { lhs: s1.dtype(), @@ -173,20 +142,20 @@ macro_rules! 
custom_binary_op { } } -custom_unary_op!(Acos, "acos", acos); -custom_unary_op!(Asin, "asin", asin); -custom_unary_op!(Atan, "atan", atan); -custom_unary_op!(Cbrt, "cbrt", cbrt); -custom_unary_op!(Ceil, "ceil", ceil); -custom_unary_op!(Floor, "floor", floor); -custom_unary_op!(Log1p, "ln_1p", ln_1p); -custom_unary_op!(Round, "round", round); -custom_unary_op!(Tan, "tan", tan); +custom_unary_op!(Acos, "acos", acos, (BF16, F16, F32, F64)); +custom_unary_op!(Asin, "asin", asin, (BF16, F16, F32, F64)); +custom_unary_op!(Atan, "atan", atan, (BF16, F16, F32, F64)); +custom_unary_op!(Cbrt, "cbrt", cbrt, (BF16, F16, F32, F64)); +custom_unary_op!(Ceil, "ceil", ceil, (BF16, F16, F32, F64)); +custom_unary_op!(Floor, "floor", floor, (BF16, F16, F32, F64)); +custom_unary_op!(Log1p, "ln_1p", ln_1p, (BF16, F16, F32, F64)); +custom_unary_op!(Round, "round", round, (BF16, F16, F32, F64)); +custom_unary_op!(Tan, "tan", tan, (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); -custom_unary_op_closure!(BitNot, "bit_not", |v| !v); +custom_unary_op_closure!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); -custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2); -custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2); -custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2); -custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2); -custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2); +custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); +custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); +custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64)); +custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2, (U32, I64)); +custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2, (U32, I64)); From c86fc512f4406a2e2ccc506879bd9880bcd59c21 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:46:29 -0300 Subject: [PATCH 101/185] mix format --- candlex/lib/candlex/backend.ex | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index de7bf0238f..c5d45e24f9 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,17 @@ defmodule Candlex.Backend do end end - for op <- [:bitwise_and, :bitwise_or, :bitwise_xor, :equal, :greater_equal, :left_shift, :less, :less_equal, :right_shift] do + for op <- [ + :bitwise_and, + :bitwise_or, + :bitwise_xor, + :equal, + :greater_equal, + :left_shift, + :less, + :less_equal, + :right_shift + ] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) From 3d0b46bda899f33ca1ee5aedddc43b1ad5f3345b Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:46:49 -0300 Subject: [PATCH 102/185] cargo fmt --- candlex/native/candlex/src/tensors.rs | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 3d2d3282cf..04352f68ad 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,6 +1,9 @@ use crate::atoms; use crate::error::CandlexError; -use crate::ops::{Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, Floor, IsInf, Shl, Shr, Log1p, Round, Tan}; +use crate::ops::{ + Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, Floor, IsInf, Log1p, Round, Shl, + Shr, Tan, +}; use candle_core::{DType, Device, Tensor}; use 
half::{bf16, f16}; use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; @@ -205,7 +208,9 @@ macro_rules! custom_binary_nif { ($nif_name:ident, $custom_op_name:ident) => { #[rustler::nif(schedule = "DirtyCpu")] pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.apply_op2_no_bwd(right.deref(), &$custom_op_name)?)) + Ok(ExTensor::new( + left.apply_op2_no_bwd(right.deref(), &$custom_op_name)?, + )) } }; } From e9a56982dc74323a1c9cc8191cb024dbf20e2c70 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 11:42:28 -0300 Subject: [PATCH 103/185] candlex logical_or --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 48 +++++++++++++++++++++++++++ candlex/native/candlex/src/tensors.rs | 5 +-- candlex/test/candlex_test.exs | 25 ++++++++++++++ 6 files changed, 79 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c5d45e24f9..c45b131b71 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -171,6 +171,7 @@ defmodule Candlex.Backend do :left_shift, :less, :less_equal, + :logical_or, :right_shift ] do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 20d271eb6a..8026e74bfb 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -51,6 +51,7 @@ defmodule Candlex.Native do :left_shift, :less, :less_equal, + :logical_or, :matmul, :max, :min, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 497294bf78..05f07bfc63 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -67,6 +67,7 @@ rustler::init! { tensors::bitwise_and, tensors::bitwise_or, tensors::bitwise_xor, + tensors::logical_or, tensors::left_shift, tensors::right_shift, devices::is_cuda_available diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 5f814ff978..b6bf21383e 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -142,6 +142,48 @@ macro_rules! custom_binary_op { } } +macro_rules! custom_binary_bool_op { + ($struct_name:ident, $name:literal, $closure:expr, ($($dtypes:ident),+)) => { + pub(crate) struct $struct_name; + + impl CustomOp2 for $struct_name { + fn name(&self) -> &'static str { + $name + } + + /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, + /// offsets etc so the associated layout should be used to access it. 
+ fn cpu_fwd( + &self, + s1: &CpuStorage, + l1: &Layout, + s2: &CpuStorage, + l2: &Layout, + ) -> Result<(CpuStorage, Shape), candle_core::Error> { + use candle_core::backend::BackendStorage; + + match (s1, s2) { + $( + (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => { + let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, |v1, v2| u8::from($closure(v1, v2))); + + Ok((CpuStorage::U8(data), l1.shape().clone())) + } + )* + _ => { + Err(Error::DTypeMismatchBinaryOp { + lhs: s1.dtype(), + rhs: s2.dtype(), + op: self.name(), + } + .bt()) + } + } + } + } + } +} + custom_unary_op!(Acos, "acos", acos, (BF16, F16, F32, F64)); custom_unary_op!(Asin, "asin", asin, (BF16, F16, F32, F64)); custom_unary_op!(Atan, "atan", atan, (BF16, F16, F32, F64)); @@ -159,3 +201,9 @@ custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64)); custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2, (U32, I64)); custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2, (U32, I64)); +custom_binary_bool_op!( + LogicalOr, + "logical_or", + |v1, v2| if v1 as i8 == 0 && v2 as i8 == 0 { 0 } else { 1 }, + (U8, U32, I64, F32, F64) +); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 04352f68ad..bcd67512cb 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,8 +1,8 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ - Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, Floor, IsInf, Log1p, Round, Shl, - Shr, Tan, + Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, Floor, IsInf, Log1p, LogicalOr, + Round, Shl, Shr, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -251,6 +251,7 @@ custom_binary_nif!(bitwise_and, BitAnd); custom_binary_nif!(bitwise_or, BitOr); custom_binary_nif!(bitwise_xor, BitXor); custom_binary_nif!(left_shift, Shl); +custom_binary_nif!(logical_or, LogicalOr); custom_binary_nif!(right_shift, Shr); fn tuple_to_vec(term: Term) -> Result, rustler::Error> { diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index a4629524b9..016dd160ba 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -672,6 +672,31 @@ defmodule CandlexTest do # |> Nx.is_infinity() # |> assert_equal(t([0, 0])) end + + test "logical_or" do + Nx.logical_or(0, t([-1, 0, 1])) + |> assert_equal(t([1, 0, 1])) + + t([-1, 0, 1]) + |> Nx.logical_or(t([[-1], [0], [1]])) + |> assert_equal(t( + [ + [1, 1, 1], + [1, 0, 1], + [1, 1, 1] + ] + )) + + t([-1.0, 0.0, 1.0]) + |> Nx.logical_or(t([[-1], [0], [1]])) + |> assert_equal(t( + [ + [1, 1, 1], + [1, 0, 1], + [1, 1, 1] + ] + )) + end end defp t(values, opts \\ []) do From b3cd3365fbe6460d95eb0b8938d4fe7f834a5e6e Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 11:45:35 -0300 Subject: [PATCH 104/185] candlex erf_inv --- candlex/lib/candlex/backend.ex | 8 +- candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/Cargo.lock | 144 +++++++++++++++++++++++++- candlex/native/candlex/Cargo.toml | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 2 + candlex/native/candlex/src/tensors.rs | 5 +- candlex/test/candlex_test.exs | 10 ++ candlex/test/support/candlex_case.ex | 15 +++ 9 files changed, 175 insertions(+), 12 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c45b131b71..465e76b9d1 100644 --- 
a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -197,6 +197,7 @@ defmodule Candlex.Backend do :cbrt, :ceil, :cos, + :erf_inv, :exp, :floor, :is_infinity, @@ -219,13 +220,6 @@ defmodule Candlex.Backend do end end - for op <- [:erf_inv] do - @impl true - def unquote(op)(out, t) do - unsupported_op(unquote(op)) - end - end - # Indexed @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 8026e74bfb..653c4aabfe 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -26,6 +26,7 @@ defmodule Candlex.Native do :cbrt, :ceil, :cos, + :erf_inv, :exp, :floor, :is_infinity, diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 02b38b57b1..37a5ca1406 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -11,6 +11,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -187,6 +196,7 @@ dependencies = [ "half", "num-traits", "rustler", + "statrs", "thiserror", ] @@ -330,6 +340,16 @@ version = "0.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" +[[package]] +name = "matrixmultiply" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "memchr" version = "2.6.2" @@ -354,6 +374,35 @@ dependencies = [ "autocfg", ] +[[package]] +name = "nalgebra" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d506eb7e08d6329505faa8a3a00a5dcc6de9f76e0c77e4b75763ae3c770831ff" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "rand", + "rand_distr", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "num-complex" version = "0.4.4" @@ -363,6 +412,27 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.16" @@ -462,6 +532,12 @@ dependencies = [ "bitflags", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.7.0" @@ -539,7 +615,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn", + "syn 2.0.29", ] [[package]] @@ -558,6 
+634,15 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +[[package]] +name = "safe_arch" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f398075ce1e6a179b46f51bd88d0598b92b00d3551f1a2d4ac49e771b56ac354" +dependencies = [ + "bytemuck", +] + [[package]] name = "safetensors" version = "0.3.3" @@ -597,7 +682,7 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.29", ] [[package]] @@ -611,6 +696,43 @@ dependencies = [ "serde", ] +[[package]] +name = "simba" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0b7840f121a46d63066ee7a99fc81dcabbc6105e437cae43528cea199b5a05f" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "statrs" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d08e5e1748192713cc281da8b16924fb46be7b0c2431854eadc785823e5696e" +dependencies = [ + "approx", + "lazy_static", + "nalgebra", + "num-traits", + "rand", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.29" @@ -639,9 +761,15 @@ checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.29", ] +[[package]] +name = "typenum" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" + [[package]] name = "unicode-ident" version = "1.0.11" @@ -669,6 +797,16 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wide" +version = "0.7.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa469ffa65ef7e0ba0f164183697b89b854253fd31aeb92358b7b6155177d62f" +dependencies = [ + "bytemuck", + "safe_arch", +] + [[package]] name = "zip" version = "0.6.6" diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 91c2a71661..640adb2d92 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -14,4 +14,5 @@ candle-core = { git = "https://github.com/huggingface/candle" } half = "2.3.1" num-traits = "0.2.16" rustler = "0.29.1" +statrs = "0.16.0" thiserror = "1.0.47" diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 05f07bfc63..2966be9260 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -54,6 +54,7 @@ rustler::init! 
{ tensors::ceil, tensors::cos, tensors::sin, + tensors::erf_inv, tensors::exp, tensors::floor, tensors::is_infinity, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index b6bf21383e..3c844a0bab 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -1,5 +1,6 @@ use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape}; use num_traits::Float; +use statrs::function::erf::erf_inv; macro_rules! custom_unary_op { ($struct_name:ident, $name:expr, $fn_name:ident, ($($dtypes:ident),+)) => { @@ -195,6 +196,7 @@ custom_unary_op!(Round, "round", round, (BF16, F16, F32, F64)); custom_unary_op!(Tan, "tan", tan, (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); custom_unary_op_closure!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); +custom_unary_op_closure!(ErfInv, "erf_inv", |v| erf_inv(v), (F64)); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index bcd67512cb..19c4902877 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,8 +1,8 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ - Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, Floor, IsInf, Log1p, LogicalOr, - Round, Shl, Shr, Tan, + Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, ErfInv, Floor, IsInf, Log1p, + LogicalOr, Round, Shl, Shr, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -230,6 +230,7 @@ custom_unary_nif!(atan, Atan); custom_unary_nif!(bitwise_not, BitNot); custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(ceil, Ceil); +custom_unary_nif!(erf_inv, ErfInv); custom_unary_nif!(floor, Floor); custom_unary_nif!(is_infinity, IsInf); custom_unary_nif!(log1p, Log1p); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 016dd160ba..8cdf6aa876 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -697,6 +697,16 @@ defmodule CandlexTest do ] )) end + + test "erf_inv" do + t(0.10000000149011612, type: :f64) + |> Nx.erf_inv() + |> assert_close(t(0.08885598927736282, type: :f64)) + + t([0.10000000149011612, 0.5, 0.8999999761581421], type: :f64) + |> Nx.erf_inv() + |> assert_close(t([0.0888559891358877, 0.47693629334671295, 1.1630870196442271], type: :f64)) + end end defp t(values, opts \\ []) do diff --git a/candlex/test/support/candlex_case.ex b/candlex/test/support/candlex_case.ex index d90b66eeef..b9b398aaae 100644 --- a/candlex/test/support/candlex_case.ex +++ b/candlex/test/support/candlex_case.ex @@ -27,4 +27,19 @@ defmodule Candlex.Case do """) end end + + def assert_close(left, right) do + equals = + left + |> Nx.all_close(right, atol: 1.0e-4, rtol: 1.0e-4) + |> Nx.backend_transfer(Nx.BinaryBackend) + + if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do + flunk(""" + Tensor assertion failed. 
+ left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end end From db08bbd11a3794f73a706e682e539c7470303416 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 13:39:02 -0300 Subject: [PATCH 105/185] support erf_inv for f32 --- candlex/native/candlex/src/ops.rs | 8 ++++++-- candlex/test/candlex_test.exs | 16 +++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 3c844a0bab..5c2931dccc 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -1,6 +1,10 @@ use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape}; use num_traits::Float; -use statrs::function::erf::erf_inv; +use num_traits::cast::FromPrimitive; + +fn erf_inv(v: T) -> T { + FromPrimitive::from_f64(statrs::function::erf::erf_inv(v.to_f64().unwrap())).unwrap() +} macro_rules! custom_unary_op { ($struct_name:ident, $name:expr, $fn_name:ident, ($($dtypes:ident),+)) => { @@ -196,7 +200,7 @@ custom_unary_op!(Round, "round", round, (BF16, F16, F32, F64)); custom_unary_op!(Tan, "tan", tan, (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); custom_unary_op_closure!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); -custom_unary_op_closure!(ErfInv, "erf_inv", |v| erf_inv(v), (F64)); +custom_unary_op_closure!(ErfInv, "erf_inv", |v| erf_inv(v), (F32, F64)); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 8cdf6aa876..6a4a92086f 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -227,15 +227,6 @@ defmodule CandlexTest do |> assert_equal(t([0.0, 0.0, 0.0])) end - # test "erf_inv" do - # Nx.erf_inv(0.10000000149011612) - # |> assert_equal(t(0.08885598927736282)) - - # t([0.10000000149011612, 0.5, 0.8999999761581421]) - # |> Nx.erf_inv() - # |> assert_equal(t([0.08885598927736282, 0.4769362807273865, 1.163087010383606])) - # end - test "eye" do Nx.eye(2) |> assert_equal(t([[1, 0], [0, 1]])) @@ -699,6 +690,13 @@ defmodule CandlexTest do end test "erf_inv" do + Nx.erf_inv(0.10000000149011612) + |> assert_close(t(0.08885598927736282)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.erf_inv() + |> assert_close(t([0.08885598927736282, 0.4769362807273865, 1.163087010383606])) + t(0.10000000149011612, type: :f64) |> Nx.erf_inv() |> assert_close(t(0.08885598927736282, type: :f64)) From cd20b8f9dacf7ae79f82f0d6f6150cb3eeed5c78 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 13:46:07 -0300 Subject: [PATCH 106/185] support f16/bf16 erf_inv --- candlex/native/candlex/src/ops.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 5c2931dccc..bf7b396ce4 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -200,7 +200,7 @@ custom_unary_op!(Round, "round", round, (BF16, F16, F32, F64)); custom_unary_op!(Tan, "tan", tan, (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); custom_unary_op_closure!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); -custom_unary_op_closure!(ErfInv, "erf_inv", |v| erf_inv(v), (F32, F64)); +custom_unary_op_closure!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); 
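A quick usage sketch of the erf_inv support added in the three patches above (f64 first, then f32, then f16/bf16, all backed by statrs). It assumes a project that has :nx and this candlex checkout as dependencies, with Candlex picked as the default backend; the expected values come from the test cases in this series.

```elixir
# Sketch only: assumes the :nx and :candlex dependencies are available
# and that Candlex has been picked as the default Nx backend.
Nx.default_backend(Candlex.Backend)

# erf_inv is the inverse of the Gauss error function: erf(erf_inv(x)) == x.
Nx.erf_inv(Nx.tensor(0.5))
#=> ~ 0.4769

Nx.tensor([0.1, 0.5, 0.9], type: :f64)
|> Nx.erf_inv()
#=> ~ [0.0889, 0.4769, 1.1631]
```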
custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); From 6faae6f968df2a3c71aba46332f0b3ad7694318d Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 16:57:28 -0300 Subject: [PATCH 107/185] better error message --- candlex/lib/candlex/backend.ex | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 465e76b9d1..62cd6d36af 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -444,20 +444,20 @@ defmodule Candlex.Backend do %{t | data: backend_tensor} end - defp to_candle_dtype({:s, 8}), do: unsupported_dtype() - defp to_candle_dtype({:s, 16}), do: unsupported_dtype() - defp to_candle_dtype({:s, 32}), do: unsupported_dtype() + defp to_candle_dtype({:s, 8} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:s, 16} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:s, 32} = t), do: unsupported_dtype(t) defp to_candle_dtype({:s, 64}), do: "i64" defp to_candle_dtype({:u, 8}), do: "u8" - defp to_candle_dtype({:u, 16}), do: unsupported_dtype() + defp to_candle_dtype({:u, 16} = t), do: unsupported_dtype(t) defp to_candle_dtype({:u, 32}), do: "u32" - defp to_candle_dtype({:u, 64}), do: unsupported_dtype() + defp to_candle_dtype({:u, 64} = t), do: unsupported_dtype(t) defp to_candle_dtype({:f, 16}), do: "f16" defp to_candle_dtype({:f, 32}), do: "f32" defp to_candle_dtype({:f, 64}), do: "f64" defp to_candle_dtype({:bf, 16}), do: "bf16" - defp to_candle_dtype({:c, 64}), do: unsupported_dtype() - defp to_candle_dtype({:c, 128}), do: unsupported_dtype() + defp to_candle_dtype({:c, 64} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:c, 128} = t), do: unsupported_dtype(t) defp from_candle_dtype("i64"), do: {:s, 64} defp from_candle_dtype("u8"), do: {:u, 8} @@ -489,8 +489,8 @@ defmodule Candlex.Backend do Native.is_cuda_available() end - defp unsupported_dtype do - raise("Unsupported candle dtype") + defp unsupported_dtype(t) do + raise("Unsupported candle dtype for #{inspect(t)}") end defp unsupported_op(op_name) do From a8b0288abe209ee613cd2b81e2c5fdac5cdaa5d0 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 6 Sep 2023 18:12:41 -0300 Subject: [PATCH 108/185] ensure Candlex tensor matches type of native candle Tensor --- candlex/lib/candlex/backend.ex | 3 +++ candlex/test/candlex_test.exs | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 62cd6d36af..51376803fe 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -182,6 +182,9 @@ defmodule Candlex.Backend do left |> Native.unquote(op)(right) |> unwrap!() + # TODO: Do this conditionally or as part of native op + |> Native.to_type(to_candle_dtype(out.type)) + |> unwrap!() |> to_nx(out) end end diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 6a4a92086f..bc04fc6991 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -601,6 +601,10 @@ defmodule CandlexTest do |> Nx.bitwise_xor(2) |> assert_equal(t([3, 0, 1])) + t([1, 2, 3], type: :u32) + |> Nx.bitwise_xor(2) + |> assert_equal(t([3, 0, 1])) + t([-1, -2, -3]) |> Nx.bitwise_xor(2) |> assert_equal(t([-3, -4, -1])) From 6ebcc813b3178e2fd8a8a52c01a5071647b54368 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> 
Date: Thu, 7 Sep 2023 11:40:29 -0300 Subject: [PATCH 109/185] a few more test cases --- candlex/test/candlex_test.exs | 58 +++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index bc04fc6991..695b56d47e 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -195,36 +195,18 @@ defmodule CandlexTest do |> assert_equal(t([[1, 1, 1], [0, 0, 0]])) end - # test "left_shift" do - # Nx.left_shift(1, 0) - # |> assert_equal(t(1)) - - # t([1, 2, 3]) - # |> Nx.left_shift(2) - # |> assert_equal(t([4, 8, 12])) - - # t([1, 1, -1, -1]) - # |> Nx.left_shift(t([1, 2, 3, 4])) - # |> assert_equal(t([2, 4, -8, -16])) - # end - - # test "right_shift" do - # Nx.right_shift(1, 0) - # |> assert_equal(t(1)) - - # t([2, 4, 8]) - # |> Nx.right_shift(2) - # |> assert_equal(t([0, 1, 2])) - - # t([16, 32, -64, -128]) - # |> Nx.right_shift(t([1, 2, 3, 4])) - # |> assert_equal(t([8, 8, -8, -8])) - # end - test "bitcast" do t([0, 0, 0], type: :s64) |> Nx.bitcast(:f64) |> assert_equal(t([0.0, 0.0, 0.0])) + + t([0, 0, 0], type: :u32) + |> Nx.bitcast(:f32) + |> assert_equal(t([0.0, 0.0, 0.0])) + + t([0, 0, 0], type: :u32) + |> Nx.bitcast(:u32) + |> assert_equal(t([0, 0, 0])) end test "eye" do @@ -638,6 +620,22 @@ defmodule CandlexTest do t([1, 1, -1, -1]) |> Nx.left_shift(t([1, 2, 3, 4])) |> assert_equal(t([2, 4, -8, -16])) + + t([1, 2, 3], type: :u32) + |> Nx.left_shift(2) + |> assert_equal(t([4, 8, 12])) + + t([1, 2, 3], type: :u32) + |> Nx.left_shift(t(2, type: :u8)) + |> assert_equal(t([4, 8, 12])) + + t([1, 1, 0, 0], type: :u32) + |> Nx.left_shift(t([1, 2, 3, 4])) + |> assert_equal(t([2, 4, 0, 0])) + + t([1, 1, 0, 0], type: :u32) + |> Nx.left_shift(t([1, 2, 3, 4], type: :u8)) + |> assert_equal(t([2, 4, 0, 0])) end test "right_shift" do @@ -651,6 +649,14 @@ defmodule CandlexTest do t([16, 32, -64, -128]) |> Nx.right_shift(t([1, 2, 3, 4])) |> assert_equal(t([8, 8, -8, -8])) + + t([2, 4, 8], type: :u32) + |> Nx.right_shift(2) + |> assert_equal(t([0, 1, 2])) + + t([16, 32, -64, -128], type: :u32) + |> Nx.right_shift(t([1, 2, 3, 4])) + |> assert_equal(t([8, 8, 536870904, 268435448])) end test "is_infinity" do From 8f69e1a8613bea90f8300f0d00bc19fbadff4916 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 10:36:19 -0300 Subject: [PATCH 110/185] test 2-D tensors dot product --- candlex/test/candlex_test.exs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 695b56d47e..b9882d61be 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -288,6 +288,7 @@ defmodule CandlexTest do # |> Nx.dot(t([1, 2, 3])) # |> assert_equal(t(14.0)) + # TODO: Candle matmul doesn't support integers yet # t([[1, 2, 3], [4, 5, 6]]) # |> Nx.dot(t([[7, 8], [9, 10], [11, 12]])) # |> assert_equal(t( @@ -296,6 +297,15 @@ defmodule CandlexTest do # [139, 154] # ] # )) + + t([[1.0, 2, 3], [4, 5, 6]]) + |> Nx.dot(t([[7.0, 8], [9, 10], [11, 12]])) + |> assert_equal(t( + [ + [58.0, 64], + [139, 154] + ] + )) end test "negate" do From 713f20df8bc861d4c96cb1e20ec041bfadbaf0d8 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 11:05:10 -0300 Subject: [PATCH 111/185] commented tests for later --- candlex/test/candlex_test.exs | 117 +++++++++++++++++++++++++++++++++- 1 file changed, 116 insertions(+), 1 deletion(-) diff --git 
a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index b9882d61be..cfa991f1db 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -269,7 +269,9 @@ defmodule CandlexTest do # )) end - test "dot" do + test "dot/2" do + # Dot product of scalars + Nx.dot(5, 5) |> assert_equal(t(25)) @@ -279,6 +281,8 @@ defmodule CandlexTest do Nx.dot(2, 2.0) |> assert_equal(t(4.0)) + # Dot product of vectors + # TODO: # t([1, 2, 3]) # |> Nx.dot(t([4, 5, 6])) @@ -288,6 +292,8 @@ defmodule CandlexTest do # |> Nx.dot(t([1, 2, 3])) # |> assert_equal(t(14.0)) + # Dot product of matrices (2-D tensors) + # TODO: Candle matmul doesn't support integers yet # t([[1, 2, 3], [4, 5, 6]]) # |> Nx.dot(t([[7, 8], [9, 10], [11, 12]])) @@ -306,8 +312,117 @@ defmodule CandlexTest do [139, 154] ] )) + + # Dot product of vector and n-D tensor + + # t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]], names: [:i, :j, :k]) + # |> Nx.dot(t([5.0, 10], names: [:x])) + # |> assert_equal(t( + # [ + # [25, 55], + # [85, 115] + # ] + # )) + + # t([5.0, 10], names: [:x]) + # |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j])) + # |> assert_equal(t( + # [45, 60, 75] + # )) + + # t([[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]], names: [:shard, :batch, :x, :y, :z]) + # |> Nx.dot(t([2.0, 2.0], names: [:data])) + # |> assert_equal(t( + # [ + # [ + # [ + # [6.0, 14.0], + # [22.0, 30.0] + # ] + # ] + # ] + # )) + + # Dot product of n-D and m-D tensors + + # t([[[1.0, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:x, :y, :z]) + # |> Nx.dot(t([[[1.0, 2, 3], [3, 4, 5], [5, 6, 7]]], names: [:i, :j, :k])) + # |> assert_equal(t( + # [ + # [ + # [ + # [22, 28, 34] + # ], + # [ + # [49, 64, 79] + # ], + # [ + # [76, 100, 124] + # ] + # ], + # [ + # [ + # [22, 28, 34] + # ], + # [ + # [49, 64, 79] + # ], + # [ + # [76, 100, 124] + # ] + # ] + # ] + # )) end + # TODO: + # test "dot/6" do + # Contracting along axes + + # t1 = t([[1.0, 2], [3, 4]], names: [:x, :y]) + # t2 = t([[10.0, 20], [30, 40]], names: [:height, :width]) + + # t1 + # |> Nx.dot([0], [], t2, [0], []) + # |> assert_equal(t( + # [ + # [100, 140], + # [140, 200] + # ] + # )) + + # t1 + # |> Nx.dot([0], [], t2, [1], []) + # |> assert_equal(t( + # [ + # [70, 150], + # [100, 220] + # ] + # )) + + # t1 + # |> Nx.dot([1], [], t2, [0], []) + # |> assert_equal(t( + # [ + # [70, 100], + # [150, 220] + # ] + # )) + + # t1 + # |> Nx.dot([1], [], t2, [1], []) + # |> assert_equal(t( + # [ + # [50, 110], + # [110, 250] + # ] + # )) + + # t1 + # |> Nx.dot([0, 1], [], t2, [0, 1], []) + # |> assert_equal(t(300)) + # end + test "negate" do # TODO: candle doesn't support unary functions for integers yet # Nx.negate(1) From 5814191c20037d67812167c82d9473884e254282 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 11:30:05 -0300 Subject: [PATCH 112/185] more clear about what dot is currently supported --- candlex/lib/candlex/backend.ex | 4 ++-- candlex/test/candlex_test.exs | 26 +++++++++++++------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 51376803fe..c0fbb66bd4 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -257,10 +257,10 @@ defmodule Candlex.Backend do def dot( %T{type: _out_type} = out, %T{shape: left_shape, type: _left_type} = left, - _left_axes, + [1] = _left_axes, [] = _left_batched_axes, %T{shape: right_shape, type: _right_type} = right, - 
_right_axes, + [0] = _right_axes, [] = _right_batched_axes ) when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index cfa991f1db..e7e750535b 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -375,13 +375,13 @@ defmodule CandlexTest do # )) end - # TODO: - # test "dot/6" do + test "dot/6" do # Contracting along axes - # t1 = t([[1.0, 2], [3, 4]], names: [:x, :y]) - # t2 = t([[10.0, 20], [30, 40]], names: [:height, :width]) + t1 = t([[1.0, 2], [3, 4]], names: [:x, :y]) + t2 = t([[10.0, 20], [30, 40]], names: [:height, :width]) + # TODO: # t1 # |> Nx.dot([0], [], t2, [0], []) # |> assert_equal(t( @@ -400,14 +400,14 @@ defmodule CandlexTest do # ] # )) - # t1 - # |> Nx.dot([1], [], t2, [0], []) - # |> assert_equal(t( - # [ - # [70, 100], - # [150, 220] - # ] - # )) + t1 + |> Nx.dot([1], [], t2, [0], []) + |> assert_equal(t( + [ + [70, 100], + [150, 220] + ] + )) # t1 # |> Nx.dot([1], [], t2, [1], []) @@ -421,7 +421,7 @@ defmodule CandlexTest do # t1 # |> Nx.dot([0, 1], [], t2, [0, 1], []) # |> assert_equal(t(300)) - # end + end test "negate" do # TODO: candle doesn't support unary functions for integers yet From e13bdb9108e5fdfb565dc985151470e72c44a000 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 11:45:00 -0300 Subject: [PATCH 113/185] support dot/6 with axes [0] and [0] --- candlex/lib/candlex/backend.ex | 20 ++++++++++++++++++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 +++++ candlex/test/candlex_test.exs | 18 +++++++++--------- 5 files changed, 36 insertions(+), 9 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c0fbb66bd4..62bdf6e1b9 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -272,6 +272,18 @@ defmodule Candlex.Backend do |> to_nx(out) end + def dot(out, %T{shape: left_shape} = left, [0], [], right, [0], []) when tuple_size(left_shape) == 2 do + dot( + out, + left |> Nx.transpose(axes: [1, 0]), + [1], + [], + right, + [0], + [] + ) + end + # Shape @impl true @@ -304,6 +316,14 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def transpose(out, %T{} = t, [dim1, dim2] = axes) do + from_nx(t) + |> Native.transpose(dim1, dim2) + |> unwrap!() + |> to_nx(out) + end + # Type @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 653c4aabfe..e2293782c9 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -13,6 +13,7 @@ defmodule Candlex.Native do def arange(_start, _end, _dtype, _shape, _device), do: error() def broadcast_to(_tensor, _shape), do: error() def reshape(_tensor, _shape), do: error() + def transpose(_tensor, _dim1, _dim2), do: error() def to_type(_tensor, _dtype), do: error() def dtype(_tensor), do: error() def concatenate(_tensors, _axis), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 2966be9260..2fcf19bf6c 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -40,6 +40,7 @@ rustler::init! 
{ tensors::where_cond, tensors::narrow, tensors::squeeze, + tensors::transpose, tensors::arange, tensors::to_type, tensors::broadcast_to, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 19c4902877..d1d7ac2643 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -77,6 +77,11 @@ pub fn squeeze(t: ExTensor, dim: usize) -> Result { Ok(ExTensor::new(t.squeeze(dim)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn transpose(t: ExTensor, dim1: usize, dim2: usize) -> Result { + Ok(ExTensor::new(t.transpose(dim1, dim2)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn arange( start: i64, diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index e7e750535b..be8866192e 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -381,16 +381,16 @@ defmodule CandlexTest do t1 = t([[1.0, 2], [3, 4]], names: [:x, :y]) t2 = t([[10.0, 20], [30, 40]], names: [:height, :width]) - # TODO: - # t1 - # |> Nx.dot([0], [], t2, [0], []) - # |> assert_equal(t( - # [ - # [100, 140], - # [140, 200] - # ] - # )) + t1 + |> Nx.dot([0], [], t2, [0], []) + |> assert_equal(t( + [ + [100, 140], + [140, 200] + ] + )) + # TODO: # t1 # |> Nx.dot([0], [], t2, [1], []) # |> assert_equal(t( From 33798993a39bc993cf91d274d20111aaab571fd1 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 11:52:49 -0300 Subject: [PATCH 114/185] support dot/6 with axes [0] and [1] --- candlex/lib/candlex/backend.ex | 36 +++++++++++++++++++++++++++++++--- candlex/test/candlex_test.exs | 16 +++++++-------- 2 files changed, 41 insertions(+), 11 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 62bdf6e1b9..e3310ac17e 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -272,15 +272,45 @@ defmodule Candlex.Backend do |> to_nx(out) end - def dot(out, %T{shape: left_shape} = left, [0], [], right, [0], []) when tuple_size(left_shape) == 2 do + def dot( + out, + %T{shape: left_shape} = left, + [0], + left_batched_axes, + right, + right_axes, + right_batched_axes + ) + when tuple_size(left_shape) == 2 do dot( out, left |> Nx.transpose(axes: [1, 0]), [1], - [], + left_batched_axes, right, + right_axes, + right_batched_axes + ) + end + + def dot( + out, + left, + left_axes, + left_batched_axes, + %T{shape: right_shape} = right, + [1], + right_batched_axes + ) + when tuple_size(right_shape) == 2 do + dot( + out, + left, + left_axes, + left_batched_axes, + right |> Nx.transpose(axes: [1, 0]), [0], - [] + right_batched_axes ) end diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index be8866192e..22bb0ac24b 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -391,14 +391,14 @@ defmodule CandlexTest do )) # TODO: - # t1 - # |> Nx.dot([0], [], t2, [1], []) - # |> assert_equal(t( - # [ - # [70, 150], - # [100, 220] - # ] - # )) + t1 + |> Nx.dot([0], [], t2, [1], []) + |> assert_equal(t( + [ + [70, 150], + [100, 220] + ] + )) t1 |> Nx.dot([1], [], t2, [0], []) From 547daf991472bdcd1e75b3d3d6953d3e93aaf58e Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 13:13:12 -0300 Subject: [PATCH 115/185] sum/2 --- candlex/lib/candlex/backend.ex | 12 ++++++ candlex/lib/candlex/native.ex | 2 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 11 +++++ candlex/test/candlex_test.exs 
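The dot/6 clauses above only handle 2-D contractions, rewriting the [0]/[1] axis combinations into Candle's matmul by transposing first; vector and higher-rank cases stay commented out in the tests. A sketch of the calls expected to work at this point, assuming Candlex as the default Nx backend, with results taken from the tests above.

```elixir
t1 = Nx.tensor([[1.0, 2], [3, 4]])
t2 = Nx.tensor([[10.0, 20], [30, 40]])

# Contract t1's second axis with t2's first axis (plain matrix product).
Nx.dot(t1, [1], [], t2, [0], [])
#=> [[70.0, 100.0], [150.0, 220.0]]

# Contract along axis 0 of t1 and axis 1 of t2; both operands are
# transposed down to the matmul-friendly case first.
Nx.dot(t1, [0], [], t2, [1], [])
#=> [[70.0, 150.0], [100.0, 220.0]]
```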
| 58 +++++++++++++++++++++++++++ candlex/test/support/candlex_case.ex | 2 +- 6 files changed, 85 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index e3310ac17e..864e1cf232 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -99,6 +99,18 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def sum(%T{} = out, %T{} = t, opts) do + axes = opts[:axes] || Nx.axes(t) + keep_axes = opts[:keep_axes] || false + + t + |> from_nx() + |> Native.sum(axes, keep_axes) + |> unwrap!() + |> to_nx(out) + end + for op <- [:argmax, :argmin] do @impl true def unquote(op)(%T{} = out, %T{shape: {}} = _tensor, _opts) do diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index e2293782c9..9e79b67594 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -64,6 +64,8 @@ defmodule Candlex.Native do def unquote(op)(_left, _right), do: error() end + def sum(_tensor, _dims, _keep_dims), do: error() + for op <- [:argmax, :argmin] do def unquote(op)(_tensor, _dim, _keep_dim), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 2fcf19bf6c..72fd963b38 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -33,6 +33,7 @@ rustler::init! { tensors::less_equal, tensors::subtract, tensors::all, + tensors::sum, tensors::dtype, tensors::argmax, tensors::argmin, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index d1d7ac2643..900c6a4b44 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -139,6 +139,17 @@ pub fn argmin(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result, keep_dims: bool) -> Result { + let t = if keep_dims { + ex_tensor.sum_keepdim(dims)? + } else { + ex_tensor.sum(dims)? 
+ }; + + Ok(ExTensor::new(t)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn broadcast_to(t: ExTensor, shape: Term) -> Result { Ok(ExTensor::new(t.broadcast_as(tuple_to_vec(shape).unwrap())?)) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 22bb0ac24b..ffb7405b3d 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -840,6 +840,64 @@ defmodule CandlexTest do |> Nx.erf_inv() |> assert_close(t([0.0888559891358877, 0.47693629334671295, 1.1630870196442271], type: :f64)) end + + test "sum/2" do + t(42) + |> Nx.sum() + |> assert_equal(t(42)) + + t([1, 2, 3]) + |> Nx.sum() + |> assert_equal(t(6)) + + t([[1.0, 2.0], [3.0, 4.0]]) + |> Nx.sum() + |> assert_equal(t(10.0)) + + t = Nx.iota({2, 2, 3}, names: [:x, :y, :z]) + Nx.sum(t, axes: [:x]) + |> assert_equal(t( + [ + [6, 8, 10], + [12, 14, 16] + ] + )) + + Nx.sum(t, axes: [:y]) + |> assert_equal(t( + [ + [3, 5, 7], + [15, 17, 19] + ] + )) + + Nx.sum(t, axes: [:z]) + |> assert_equal(t( + [ + [3, 12], + [21, 30] + ] + )) + + Nx.sum(t, axes: [:x, :z]) + |> assert_equal(t([24, 42])) + + Nx.sum(t, axes: [-3]) + |> assert_equal(t( + [ + [6, 8, 10], + [12, 14, 16] + ] + )) + + t([[1, 2], [3, 4]], names: [:x, :y]) + |> Nx.sum(axes: [:x], keep_axes: true) + |> assert_equal(t( + [ + [4, 6] + ] + )) + end end defp t(values, opts \\ []) do diff --git a/candlex/test/support/candlex_case.ex b/candlex/test/support/candlex_case.ex index b9b398aaae..3f11eae98d 100644 --- a/candlex/test/support/candlex_case.ex +++ b/candlex/test/support/candlex_case.ex @@ -19,7 +19,7 @@ defmodule Candlex.Case do |> Nx.all() |> Nx.to_number() - if equals != 1 do + if equals != 1 || Nx.shape(left) != Nx.shape(right) do flunk(""" Tensor assertion failed. left: #{inspect(left)} From 19dc54f726df83357e830e1b9c3a9f7fa4153530 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 13:15:28 -0300 Subject: [PATCH 116/185] uncomment linear regression --- candlex/examples/linear_regression.exs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/candlex/examples/linear_regression.exs b/candlex/examples/linear_regression.exs index 28a5884bd1..e80949a8fe 100644 --- a/candlex/examples/linear_regression.exs +++ b/candlex/examples/linear_regression.exs @@ -60,18 +60,18 @@ epochs = 100 # These will be very close to the above coefficients {time, {trained_m, trained_b}} = :timer.tc(LinearRegression, :train, [params, epochs, lin_fn]) -# trained_m = -# trained_m -# |> Nx.squeeze() -# |> Nx.backend_transfer() -# |> Nx.to_number() +trained_m = + trained_m + |> Nx.squeeze() + |> Nx.backend_transfer() + |> Nx.to_number() -# trained_b = -# trained_b -# |> Nx.squeeze() -# |> Nx.backend_transfer() -# |> Nx.to_number() +trained_b = + trained_b + |> Nx.squeeze() + |> Nx.backend_transfer() + |> Nx.to_number() -# IO.puts("Trained in #{time / 1_000_000} sec.") -# IO.puts("Trained m: #{trained_m} Trained b: #{trained_b}\n") -# IO.puts("Accuracy m: #{m - trained_m} Accuracy b: #{b - trained_b}") +IO.puts("Trained in #{time / 1_000_000} sec.") +IO.puts("Trained m: #{trained_m} Trained b: #{trained_b}\n") +IO.puts("Accuracy m: #{m - trained_m} Accuracy b: #{b - trained_b}") From afa4aad8b8b70396e4c53be450c2d03bcf23ab3a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 14:00:44 -0300 Subject: [PATCH 117/185] divide/2 --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 1 + 
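The sum/2 reduction added above maps the :axes and :keep_axes options onto Candle's sum and sum_keepdim. A short usage sketch, assuming Candlex as the default Nx backend; the expected results mirror the tests in this patch.

```elixir
t = Nx.iota({2, 2, 3}, names: [:x, :y, :z])

Nx.sum(t)                 #=> 66
Nx.sum(t, axes: [:x])     #=> [[6, 8, 10], [12, 14, 16]]
Nx.sum(t, axes: [:x, :z]) #=> [24, 42]

# keep_axes: true preserves the reduced dimension with size 1.
Nx.sum(Nx.tensor([[1, 2], [3, 4]]), axes: [0], keep_axes: true)
#=> [[4, 6]]
```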
candlex/native/candlex/src/lib.rs | 5 ++-- candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 37 +++++++++++++++++++++++++++ 5 files changed, 43 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 864e1cf232..132fd63479 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -162,7 +162,7 @@ defmodule Candlex.Backend do # Binary ops - for op <- [:add, :max, :min, :multiply, :subtract] do + for op <- [:add, :divide, :max, :min, :multiply, :subtract] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 9e79b67594..e36c291047 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -48,6 +48,7 @@ defmodule Candlex.Native do :bitwise_and, :bitwise_or, :bitwise_xor, + :divide, :equal, :greater_equal, :left_shift, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 72fd963b38..5c161efec4 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -24,14 +24,15 @@ rustler::init! { tensors::from_binary, tensors::to_binary, tensors::add, + tensors::subtract, + tensors::multiply, + tensors::divide, tensors::max, tensors::min, - tensors::multiply, tensors::equal, tensors::greater_equal, tensors::less, tensors::less_equal, - tensors::subtract, tensors::all, tensors::sum, tensors::dtype, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 900c6a4b44..5fdca18023 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -256,6 +256,7 @@ custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); binary_nif!(multiply, broadcast_mul); +binary_nif!(divide, broadcast_div); binary_nif!(max, broadcast_maximum); binary_nif!(min, broadcast_minimum); binary_nif!(equal, eq); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index ffb7405b3d..855c8140ee 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -118,6 +118,43 @@ defmodule CandlexTest do |> assert_equal(t([[3, 6], [4, 8]])) end + test "divide/2" do + 1.0 + |> Nx.divide(2) + |> assert_equal(t(0.5)) + + t([1.0, 2, 3]) + |> Nx.divide(1) + |> assert_equal(t([1.0, 2.0, 3.0])) + + t([[1.0], [2]]) + |> Nx.divide(t([[10, 20]])) + |> assert_equal(t( + [ + [0.10000000149011612, 0.05000000074505806], + [0.20000000298023224, 0.10000000149011612] + ] + )) + + # TODO: Support integers + # 1 + # |> Nx.divide(2) + # |> assert_equal(t(0.5)) + + # t([1, 2, 3]) + # |> Nx.divide(1) + # |> assert_equal(t([1.0, 2.0, 3.0])) + + # t([[1], [2]]) + # |> Nx.divide(t([[10, 20]])) + # |> assert_equal(t( + # [ + # [0.10000000149011612, 0.05000000074505806], + # [0.20000000298023224, 0.10000000149011612] + # ] + # )) + end + test "broadcast" do Nx.broadcast(1, {1, 2, 3}) |> assert_equal(t([[[1, 1, 1], [1, 1, 1]]])) From f35d6485de5f64f0cf5a524abaf85d29cf3e17ae Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 16:06:40 -0300 Subject: [PATCH 118/185] to_batched/2 --- candlex/lib/candlex/backend.ex | 21 ++++++++++++++++++++- candlex/lib/candlex/native.ex | 2 ++ candlex/native/candlex/src/lib.rs | 2 ++ candlex/native/candlex/src/tensors.rs | 16 +++++++++++++++- candlex/test/candlex_test.exs | 27 +++++++++++++++++++++++++++ 
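divide/2 above is wired straight to Candle's broadcast_div and, as the commented-out test cases show, only the floating-point paths are exercised for now. A minimal sketch, assuming Candlex as the default Nx backend.

```elixir
Nx.divide(1.0, 2)
#=> 0.5

Nx.tensor([[1.0], [2]])
|> Nx.divide(Nx.tensor([[10, 20]]))
#=> [[0.1, 0.05], [0.2, 0.1]]
```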
5 files changed, 66 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 132fd63479..665202989f 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -490,6 +490,20 @@ defmodule Candlex.Backend do ## Conversions + @impl true + def to_batched(%T{shape: out_shape} = out, %T{shape: shape} = t, _opts) do + # TODO: dont ignore opts + batch_size = elem(out_shape, 0) + t_axis_0 = elem(shape, 0) + num_batches = div(t_axis_0, batch_size) + + t + |> from_nx() + |> Native.chunk(num_batches) + |> unwrap!() + |> Stream.map(&to_nx(&1, out)) + end + @doc false defp from_nx(%T{data: %CB{} = data}), do: data @@ -499,13 +513,18 @@ defmodule Candlex.Backend do |> from_nx() end - defp to_nx(%CB{resource: ref} = backend_tensor, %T{type: nx_type} = t) when is_reference(ref) do + defp to_nx(%CB{resource: ref} = backend_tensor, %T{type: nx_type, shape: nx_shape} = t) when is_reference(ref) do {:ok, candle_dtype} = Native.dtype(backend_tensor) + {:ok, candle_shape} = Native.t_shape(backend_tensor) if nx_type != from_candle_dtype(candle_dtype) do raise "tensor type mismatch, Nx (#{inspect(nx_type)}) and Candle (#{inspect(candle_dtype)})" end + if nx_shape != candle_shape do + raise "tensor shape mismatch, Nx (#{inspect(nx_shape)}) and Candle (#{inspect(candle_shape)})" + end + %{t | data: backend_tensor} end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index e36c291047..2c3e8a81c8 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -9,6 +9,7 @@ defmodule Candlex.Native do def all(_tensor), do: error() def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() + def chunk(_tensor, _num_chunks), do: error() def squeeze(_tensor, _dim), do: error() def arange(_start, _end, _dtype, _shape, _device), do: error() def broadcast_to(_tensor, _shape), do: error() @@ -16,6 +17,7 @@ defmodule Candlex.Native do def transpose(_tensor, _dim1, _dim2), do: error() def to_type(_tensor, _dtype), do: error() def dtype(_tensor), do: error() + def t_shape(_tensor), do: error() def concatenate(_tensors, _axis), do: error() for op <- [ diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 5c161efec4..06a62d8e0d 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -36,11 +36,13 @@ rustler::init! 
{ tensors::all, tensors::sum, tensors::dtype, + tensors::t_shape, tensors::argmax, tensors::argmin, tensors::negate, tensors::where_cond, tensors::narrow, + tensors::chunk, tensors::squeeze, tensors::transpose, tensors::arange, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 5fdca18023..196b56aabc 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -6,7 +6,7 @@ use crate::ops::{ }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; -use rustler::{Atom, Binary, Env, NewBinary, NifStruct, ResourceArc, Term}; +use rustler::{Atom, Binary, Env, Encoder, NewBinary, NifStruct, ResourceArc, Term}; use std::ops::Deref; use std::result::Result; use std::str::FromStr; @@ -72,6 +72,11 @@ pub fn narrow( Ok(ExTensor::new(t.narrow(dim, start, length)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn chunk(t: ExTensor, num_chunks: usize) -> Result, CandlexError> { + Ok(t.chunk(num_chunks, 0)?.into_iter().map(|t| ExTensor::new(t)).collect()) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn squeeze(t: ExTensor, dim: usize) -> Result { Ok(ExTensor::new(t.squeeze(dim)?)) @@ -181,6 +186,11 @@ pub fn dtype(t: ExTensor) -> Result<&'static str, CandlexError> { Ok(t.dtype().as_str()) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn t_shape(env: Env, t: ExTensor) -> Result { + Ok(vec_to_tuple(env, t.shape().clone().into_dims()).unwrap()) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn concatenate(ex_tensors: Vec, dim: usize) -> Result { let tensors = ex_tensors @@ -279,6 +289,10 @@ fn tuple_to_vec(term: Term) -> Result, rustler::Error> { .collect::>()?) } +fn vec_to_tuple(env: Env, vec: Vec) -> Result { + Ok(rustler::types::tuple::make_tuple(env, &vec.into_iter().map(|elem| elem.encode(env)).collect::>())) +} + fn device_from_atom(atom: Atom) -> Result { if atom == atoms::cpu() { Ok(Device::Cpu) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 855c8140ee..453afa02c5 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -935,6 +935,33 @@ defmodule CandlexTest do ] )) end + + test "to_batched/2" do + [first, second] = + Nx.iota({2, 2, 2}) + |> Nx.to_batched(1) + |> Enum.to_list() + + first + |> assert_equal(t( + [ + [ + [0, 1], + [2, 3] + ] + ] + )) + + second + |> assert_equal(t( + [ + [ + [4, 5], + [6, 7] + ] + ] + )) + end end defp t(values, opts \\ []) do From 289ebb7bdee6bfa32de1415e549bdae0d64bbc1a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 8 Sep 2023 16:53:26 -0300 Subject: [PATCH 119/185] sigmoid/1 --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 9 +++++++++ 6 files changed, 15 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 665202989f..c9de4ecc14 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -220,6 +220,7 @@ defmodule Candlex.Backend do :log1p, :negate, :round, + :sigmoid, :sin, :sqrt, :tan, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 2c3e8a81c8..d5ada054c8 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -37,6 +37,7 @@ defmodule Candlex.Native do :log1p, :negate, :round, + :sigmoid, :sin, :sqrt, :tan, diff --git 
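The `to_batched/2` callback above slices along the first axis by delegating to candle's `chunk`, and `to_nx/2` now re-checks the shape of every tensor coming back from the NIF. A small sketch of the behaviour, with values taken from the `to_batched/2` test:

```elixir
# Each batch is itself a Candlex-backed tensor of shape {1, 2, 2}.
[first, second] =
  Nx.iota({2, 2, 2}, backend: Candlex.Backend)
  |> Nx.to_batched(1)
  |> Enum.to_list()

Nx.to_flat_list(first)  # => [0, 1, 2, 3]
Nx.to_flat_list(second) # => [4, 5, 6, 7]
```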
a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 06a62d8e0d..99837ef810 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -58,6 +58,7 @@ rustler::init! { tensors::cbrt, tensors::ceil, tensors::cos, + tensors::sigmoid, tensors::sin, tensors::erf_inv, tensors::exp, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index bf7b396ce4..489be44495 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -201,6 +201,7 @@ custom_unary_op!(Tan, "tan", tan, (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); custom_unary_op_closure!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); custom_unary_op_closure!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); +custom_unary_op_closure!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F64)); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 196b56aabc..135f7eeb0c 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -2,7 +2,7 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, ErfInv, Floor, IsInf, Log1p, - LogicalOr, Round, Shl, Shr, Tan, + LogicalOr, Round, Shl, Shr, Sigmoid, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -261,6 +261,7 @@ custom_unary_nif!(floor, Floor); custom_unary_nif!(is_infinity, IsInf); custom_unary_nif!(log1p, Log1p); custom_unary_nif!(round, Round); +custom_unary_nif!(sigmoid, Sigmoid); custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 453afa02c5..dd0401baf3 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -962,6 +962,15 @@ defmodule CandlexTest do ] )) end + + test "sigmoid/1" do + Nx.sigmoid(1.0) + |> assert_close(t(0.7310585975646973)) + + t([1.0, 2, 3]) + |> Nx.sigmoid() + |> assert_close(t([0.7310585975646973, 0.8807970881462097, 0.9525741338729858])) + end end defp t(values, opts \\ []) do From 93597dbfe7483e26eca95060b7c3bbcf00eec1c9 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 11 Sep 2023 14:43:26 -0300 Subject: [PATCH 120/185] candlex support u64 using candle i64 --- candlex/lib/candlex/backend.ex | 8 +++++--- candlex/test/candlex_test.exs | 12 ++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c9de4ecc14..900323b307 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -518,8 +518,10 @@ defmodule Candlex.Backend do {:ok, candle_dtype} = Native.dtype(backend_tensor) {:ok, candle_shape} = Native.t_shape(backend_tensor) - if nx_type != from_candle_dtype(candle_dtype) do - raise "tensor type mismatch, Nx (#{inspect(nx_type)}) and Candle (#{inspect(candle_dtype)})" + case {nx_type, from_candle_dtype(candle_dtype)} do + {{:u, 64}, {:s, 64}} -> :ok + {type, type} -> :ok + {type, other_type} -> raise "tensor type mismatch, Nx (#{inspect(type)}) and Candle (#{inspect(other_type)})" end if nx_shape != candle_shape do @@ -536,7 +538,7 @@ defmodule Candlex.Backend do defp to_candle_dtype({:u, 8}), do: "u8" defp to_candle_dtype({:u, 16} = 
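The custom `Sigmoid` op is only registered for `f32`/`f64` inputs at this point, so a usage sketch sticks to floats; the expected values are the ones asserted (with `assert_close`) in the test above:

```elixir
Nx.sigmoid(Nx.tensor(1.0, backend: Candlex.Backend))
# ≈ 0.7310585975646973

Nx.sigmoid(Nx.tensor([1.0, 2, 3], backend: Candlex.Backend))
# ≈ [0.7310585975646973, 0.8807970881462097, 0.9525741338729858]
```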
t), do: unsupported_dtype(t) defp to_candle_dtype({:u, 32}), do: "u32" - defp to_candle_dtype({:u, 64} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:u, 64} = t), do: "i64" defp to_candle_dtype({:f, 16}), do: "f16" defp to_candle_dtype({:f, 32}), do: "f32" defp to_candle_dtype({:f, 64}), do: "f64" diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index dd0401baf3..d56d613bd0 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -6,6 +6,7 @@ defmodule CandlexTest do test "tensor" do check(255, type: :u8) check(100_002, type: :u32) + check(100_102, type: :u64) check(-101, type: :s64) check(1.16, type: :f16) check(1.32, type: :f32) @@ -55,6 +56,10 @@ defmodule CandlexTest do |> Nx.add(t([10, 20, 30])) |> assert_equal(t([11, 22, 33])) + t([1, 2, 3], type: :u64) + |> Nx.add(t([10, 20, 30], type: :u64)) + |> assert_equal(t([11, 22, 33])) + Nx.add(1, 2.2) |> assert_equal(t(3.2)) @@ -73,6 +78,9 @@ defmodule CandlexTest do Nx.iota({5}) |> assert_equal(t([0, 1, 2, 3, 4])) + Nx.iota({5}, type: :u64) + |> assert_equal(t([0, 1, 2, 3, 4])) + Nx.iota({5}, type: :f32) |> assert_equal(t([0.0, 1.0, 2.0, 3.0, 4.0])) @@ -109,6 +117,10 @@ defmodule CandlexTest do |> Nx.multiply(t([3, 4])) |> assert_equal(t([3, 8])) + t([1, 2], type: :u64) + |> Nx.multiply(t([3, 4], type: :u64)) + |> assert_equal(t([3, 8])) + t([[1], [2]]) |> Nx.multiply(t([3, 4])) |> assert_equal(t([[3, 4], [6, 8]])) From 4de34e3f96eeae510723031cfdc7c99a206d51ee Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 11 Sep 2023 15:12:02 -0300 Subject: [PATCH 121/185] candlex show device on inspect --- candlex/lib/candlex/backend.ex | 6 +++--- candlex/native/candlex/src/tensors.rs | 7 +++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 900323b307..a0471e8196 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -3,7 +3,7 @@ defmodule Candlex.Backend do An opaque Nx backend with bindings to candle. 
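Mapping `{:u, 64}` onto candle's `i64` is why `to_nx/2` special-cases the `{{:u, 64}, {:s, 64}}` dtype pair above. A short sketch with the values from the new tests; one caveat (my inference, not stated in the patch) is that u64 values above 2^63 - 1 would not round-trip through the signed storage:

```elixir
Nx.add(
  Nx.tensor([1, 2, 3], type: :u64, backend: Candlex.Backend),
  Nx.tensor([10, 20, 30], type: :u64, backend: Candlex.Backend)
)
# => [11, 22, 33]

Nx.iota({5}, type: :u64, backend: Candlex.Backend)
# => [0, 1, 2, 3, 4]
```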
""" - defstruct [:resource] + defstruct [:device, :resource] @behaviour Nx.Backend @@ -395,9 +395,9 @@ defmodule Candlex.Backend do |> maybe_add_signature(tensor) end - defp maybe_add_signature(result, %T{data: %CB{resource: ref}}) when is_reference(ref) do + defp maybe_add_signature(result, %T{data: %CB{device: device, resource: ref}}) when is_reference(ref) do Inspect.Algebra.concat([ - "Candlex.Backend(#{:erlang.ref_to_list(ref)})", + "Candlex.Backend(#{device})", Inspect.Algebra.line(), result ]) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 135f7eeb0c..ed7c0a49af 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -16,12 +16,19 @@ pub(crate) struct TensorRef(Tensor); #[derive(NifStruct)] #[module = "Candlex.Backend"] pub struct ExTensor { + device: String, resource: ResourceArc, } impl ExTensor { pub fn new(tensor: Tensor) -> Self { + let dev_string = match tensor.device() { + Device::Cpu => String::from("cpu"), + Device::Cuda(_) => String::from("cuda") + }; + Self { + device: dev_string, resource: ResourceArc::new(TensorRef(tensor)), } } From 5329b4253ed1513c0acbdd3da515a3536a336feb Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 11 Sep 2023 15:13:14 -0300 Subject: [PATCH 122/185] candlex mnist example --- candlex/examples/mnist.exs | 170 +++++++++++++++++++++++++++++++++++++ 1 file changed, 170 insertions(+) create mode 100644 candlex/examples/mnist.exs diff --git a/candlex/examples/mnist.exs b/candlex/examples/mnist.exs new file mode 100644 index 0000000000..9e0254a509 --- /dev/null +++ b/candlex/examples/mnist.exs @@ -0,0 +1,170 @@ +defmodule MNIST do + import Nx.Defn + + defn init_random_params do + key = Nx.Random.key(42) + {w1, new_key} = Nx.Random.normal(key, 0.0, 0.1, shape: {784, 128}, names: [:input, :layer]) + {b1, new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {128}, names: [:layer]) + + {w2, new_key} = + Nx.Random.normal(new_key, 0.0, 0.1, shape: {128, 10}, names: [:layer, :output]) + + {b2, _new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {10}, names: [:output]) + {w1, b1, w2, b2} + end + + defn softmax(logits) do + Nx.exp(logits) / Nx.sum(Nx.exp(logits), axes: [:output], keep_axes: true) + end + + defn predict({w1, b1, w2, b2}, batch) do + batch + |> Nx.dot(w1) + |> Nx.add(b1) + |> Nx.sigmoid() + |> Nx.dot(w2) + |> Nx.add(b2) + |> softmax() + end + + defn accuracy({w1, b1, w2, b2}, batch_images, batch_labels) do + Nx.mean( + Nx.equal( + Nx.argmax(batch_labels, axis: :output), + Nx.argmax(predict({w1, b1, w2, b2}, batch_images), axis: :output) + ) + ) + end + + defn loss({w1, b1, w2, b2}, batch_images, batch_labels) do + preds = predict({w1, b1, w2, b2}, batch_images) + -Nx.sum(Nx.mean(Nx.log(preds) * batch_labels, axes: [:output])) + end + + defn update({w1, b1, w2, b2} = params, batch_images, batch_labels, step) do + {grad_w1, grad_b1, grad_w2, grad_b2} = grad(params, &loss(&1, batch_images, batch_labels)) + + { + w1 - grad_w1 * step, + b1 - grad_b1 * step, + w2 - grad_w2 * step, + b2 - grad_b2 * step + } + end + + defn update_with_averages({_, _, _, _} = cur_params, imgs, tar, avg_loss, avg_accuracy, total) do + batch_loss = loss(cur_params, imgs, tar) + batch_accuracy = accuracy(cur_params, imgs, tar) + avg_loss = avg_loss + batch_loss / total + avg_accuracy = avg_accuracy + batch_accuracy / total + {update(cur_params, imgs, tar, 0.01), avg_loss, avg_accuracy} + end + + defp unzip_cache_or_download(zip) do + base_url 
= ~c"https://storage.googleapis.com/cvdf-datasets/mnist/" + path = Path.join("tmp", zip) + + data = + if File.exists?(path) do + IO.puts("Using #{zip} from tmp/\n") + File.read!(path) + else + IO.puts("Fetching #{zip} from https://storage.googleapis.com/cvdf-datasets/mnist/\n") + :inets.start() + :ssl.start() + + {:ok, {_status, _response, data}} = :httpc.request(base_url ++ zip) + File.mkdir_p!("tmp") + File.write!(path, data) + + data + end + + :zlib.gunzip(data) + end + + def download(images, labels) do + <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = + unzip_cache_or_download(images) + + train_images = + images + |> Nx.from_binary({:u, 8}) + |> Nx.reshape({n_images, n_rows * n_cols}, names: [:batch, :input]) + |> Nx.divide(255.0) + |> Nx.to_batched(30) + + IO.puts("#{n_images} #{n_rows}x#{n_cols} images\n") + + <<_::32, n_labels::32, labels::binary>> = unzip_cache_or_download(labels) + + train_labels = + labels + |> Nx.from_binary({:u, 8}) + |> Nx.reshape({n_labels, 1}, names: [:batch, :output]) + |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) + |> Nx.to_batched(30) + + IO.puts("#{n_labels} labels\n") + + {train_images, train_labels} + end + + def train_epoch(cur_params, imgs, labels) do + total_batches = Enum.count(imgs) + + imgs + |> Stream.zip(labels) + |> Enum.reduce({cur_params, Nx.tensor(0.0), Nx.tensor(0.0)}, fn + {imgs, tar}, {cur_params, avg_loss, avg_accuracy} -> + update_with_averages(cur_params, imgs, tar, avg_loss, avg_accuracy, total_batches) + end) + end + + def train(imgs, labels, params, opts \\ []) do + epochs = opts[:epochs] || 5 + + for epoch <- 1..epochs, reduce: params do + cur_params -> + {time, {new_params, epoch_avg_loss, epoch_avg_acc}} = + :timer.tc(__MODULE__, :train_epoch, [cur_params, imgs, labels]) + + epoch_avg_loss = + epoch_avg_loss + |> Nx.backend_transfer() + |> Nx.to_number() + + epoch_avg_acc = + epoch_avg_acc + |> Nx.backend_transfer() + |> Nx.to_number() + + IO.puts("Epoch #{epoch} Time: #{time / 1_000_000}s") + IO.puts("Epoch #{epoch} average loss: #{inspect(epoch_avg_loss)}") + IO.puts("Epoch #{epoch} average accuracy: #{inspect(epoch_avg_acc)}") + IO.puts("\n") + new_params + end + end +end + +Nx.global_default_backend(Candlex.Backend) + +{train_images, train_labels} = + MNIST.download(~c"train-images-idx3-ubyte.gz", ~c"train-labels-idx1-ubyte.gz") + +IO.puts("Initializing parameters...\n") +params = MNIST.init_random_params() + +IO.puts("Training MNIST for 10 epochs...\n\n") +final_params = MNIST.train(train_images, train_labels, params, epochs: 10) + +IO.puts("The result of the first batch") + +IO.inspect( + MNIST.predict(final_params, hd(Enum.to_list(train_images))) + |> Nx.argmax(axis: :output) +) + +IO.puts("Labels for the first batch") +IO.inspect(hd(Enum.to_list(train_labels)) |> Nx.argmax(axis: :output)) From ceea256046541bd6fb023840ce4373c4090c43d9 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 11 Sep 2023 15:27:56 -0300 Subject: [PATCH 123/185] cargo fmt --- candlex/native/candlex/src/ops.rs | 2 +- candlex/native/candlex/src/tensors.rs | 22 +++++++++++++++++----- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 489be44495..6b9881d773 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -1,6 +1,6 @@ use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape}; -use num_traits::Float; use num_traits::cast::FromPrimitive; +use 
num_traits::Float; fn erf_inv(v: T) -> T { FromPrimitive::from_f64(statrs::function::erf::erf_inv(v.to_f64().unwrap())).unwrap() diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index ed7c0a49af..8a9b12a470 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -6,7 +6,7 @@ use crate::ops::{ }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; -use rustler::{Atom, Binary, Env, Encoder, NewBinary, NifStruct, ResourceArc, Term}; +use rustler::{Atom, Binary, Encoder, Env, NewBinary, NifStruct, ResourceArc, Term}; use std::ops::Deref; use std::result::Result; use std::str::FromStr; @@ -24,7 +24,7 @@ impl ExTensor { pub fn new(tensor: Tensor) -> Self { let dev_string = match tensor.device() { Device::Cpu => String::from("cpu"), - Device::Cuda(_) => String::from("cuda") + Device::Cuda(_) => String::from("cuda"), }; Self { @@ -81,7 +81,10 @@ pub fn narrow( #[rustler::nif(schedule = "DirtyCpu")] pub fn chunk(t: ExTensor, num_chunks: usize) -> Result, CandlexError> { - Ok(t.chunk(num_chunks, 0)?.into_iter().map(|t| ExTensor::new(t)).collect()) + Ok(t.chunk(num_chunks, 0)? + .into_iter() + .map(|t| ExTensor::new(t)) + .collect()) } #[rustler::nif(schedule = "DirtyCpu")] @@ -152,7 +155,11 @@ pub fn argmin(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result, keep_dims: bool) -> Result { +pub fn sum( + ex_tensor: ExTensor, + dims: Vec, + keep_dims: bool, +) -> Result { let t = if keep_dims { ex_tensor.sum_keepdim(dims)? } else { @@ -298,7 +305,12 @@ fn tuple_to_vec(term: Term) -> Result, rustler::Error> { } fn vec_to_tuple(env: Env, vec: Vec) -> Result { - Ok(rustler::types::tuple::make_tuple(env, &vec.into_iter().map(|elem| elem.encode(env)).collect::>())) + Ok(rustler::types::tuple::make_tuple( + env, + &vec.into_iter() + .map(|elem| elem.encode(env)) + .collect::>(), + )) } fn device_from_atom(atom: Atom) -> Result { From 5bf1d16c02c17f487b583b000c9f068dce8cb0e2 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 12 Sep 2023 12:54:06 -0300 Subject: [PATCH 124/185] mean --- candlex/test/candlex_test.exs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index d56d613bd0..23dba0f784 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -983,6 +983,20 @@ defmodule CandlexTest do |> Nx.sigmoid() |> assert_close(t([0.7310585975646973, 0.8807970881462097, 0.9525741338729858])) end + + test "mean/1" do + t(42) + |> Nx.mean() + |> assert_equal(t(42.0)) + + t([1, 2, 3]) + |> Nx.mean() + |> assert_equal(t(2.0)) + + t([0.1, 0.2, 0.3]) + |> Nx.mean() + |> assert_equal(t(0.2)) + end end defp t(values, opts \\ []) do From 2779dfeab111688e7330d76df1a1c2181e64925a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 12 Sep 2023 14:57:08 -0300 Subject: [PATCH 125/185] divide integer --- candlex/native/candlex/src/tensors.rs | 11 +++++++++- candlex/test/candlex_test.exs | 29 +++++++++++++-------------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 8a9b12a470..75b22fa66e 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -214,6 +214,16 @@ pub fn concatenate(ex_tensors: Vec, dim: usize) -> Result Result { + Ok(ExTensor::new( + // Need to force float in case we receive integers, given + // 
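The mean patch adds only tests because `Nx.mean/1` composes callbacks the backend already implements (sum and divide) rather than needing a dedicated NIF. A quick sketch with the same values:

```elixir
Nx.mean(Nx.tensor(42, backend: Candlex.Backend))              # => 42.0
Nx.mean(Nx.tensor([1, 2, 3], backend: Candlex.Backend))       # => 2.0
Nx.mean(Nx.tensor([0.1, 0.2, 0.3], backend: Candlex.Backend)) # ≈ 0.2
```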
candle rounds down integer division. + left.to_dtype(DType::F32)? + .broadcast_div(&right.to_dtype(DType::F32)?)?, + )) +} + macro_rules! unary_nif { ($nif_name:ident, $native_fn_name:ident) => { #[rustler::nif(schedule = "DirtyCpu")] @@ -281,7 +291,6 @@ custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); binary_nif!(multiply, broadcast_mul); -binary_nif!(divide, broadcast_div); binary_nif!(max, broadcast_maximum); binary_nif!(min, broadcast_minimum); binary_nif!(equal, eq); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 23dba0f784..890c325727 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -148,23 +148,22 @@ defmodule CandlexTest do ] )) - # TODO: Support integers - # 1 - # |> Nx.divide(2) - # |> assert_equal(t(0.5)) + 1 + |> Nx.divide(2) + |> assert_equal(t(0.5)) - # t([1, 2, 3]) - # |> Nx.divide(1) - # |> assert_equal(t([1.0, 2.0, 3.0])) + t([1, 2, 3]) + |> Nx.divide(2) + |> assert_equal(t([0.5, 1.0, 1.5])) - # t([[1], [2]]) - # |> Nx.divide(t([[10, 20]])) - # |> assert_equal(t( - # [ - # [0.10000000149011612, 0.05000000074505806], - # [0.20000000298023224, 0.10000000149011612] - # ] - # )) + t([[1], [2]]) + |> Nx.divide(t([[10, 20]])) + |> assert_equal(t( + [ + [0.10000000149011612, 0.05000000074505806], + [0.20000000298023224, 0.10000000149011612] + ] + )) end test "broadcast" do From 155f6eba0967000a32a27ffc28b707820ffad941 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 12 Sep 2023 15:04:24 -0300 Subject: [PATCH 126/185] fix sum out tensor type --- candlex/lib/candlex/backend.ex | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index a0471e8196..12a22c27c5 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -100,7 +100,7 @@ defmodule Candlex.Backend do end @impl true - def sum(%T{} = out, %T{} = t, opts) do + def sum(%T{type: out_type} = out, %T{} = t, opts) do axes = opts[:axes] || Nx.axes(t) keep_axes = opts[:keep_axes] || false @@ -108,6 +108,8 @@ defmodule Candlex.Backend do |> from_nx() |> Native.sum(axes, keep_axes) |> unwrap!() + |> Native.to_type(to_candle_dtype(out_type)) + |> unwrap!() |> to_nx(out) end From 79f02c1bd1c12f3af8073e828e67577794801024 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 12 Sep 2023 15:20:21 -0300 Subject: [PATCH 127/185] cargo update --- candlex/native/candlex/Cargo.lock | 41 ++++++++++++++------------- candlex/native/candlex/Cargo.toml | 2 +- candlex/native/candlex/src/tensors.rs | 2 +- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 37a5ca1406..42c353d099 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -34,9 +34,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bytemuck" -version = "1.13.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" [[package]] name = "byteorder" @@ -47,7 +47,8 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.2.1" -source = 
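Because candle's integer division truncates, the `divide` NIF above upcasts both operands to `f32` before `broadcast_div`, so `Nx.divide/2` now yields floats for integer inputs as well; a sketch with the updated test values:

```elixir
Nx.divide(Nx.tensor(1, backend: Candlex.Backend), 2)         # => 0.5, not a truncated 0
Nx.divide(Nx.tensor([1, 2, 3], backend: Candlex.Backend), 2) # => [0.5, 1.0, 1.5]
```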
"git+https://github.com/huggingface/candle#a8410bf35ea3ad8eb973f48d301e65309d232377" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d54bfcd86ef61f5ceb73aa43a49d08e395503e6c79f12ae8f5ed4695dd33b84e" dependencies = [ "byteorder", "candle-gemm", @@ -352,9 +353,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.6.2" +version = "2.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5486aed0026218e61b8a01d5fbd5a0a134649abb71a0e53b7bc088529dced86e" +checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" [[package]] name = "memmap2" @@ -568,9 +569,9 @@ checksum = "2962bf2e1f971c53ef59b2d7ca51d6a5e5c4a9d2be47eb1f661a321a4da85888" [[package]] name = "regex" -version = "1.9.4" +version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "12de2eff854e5fa4b1295edd650e227e9d8fb0c9e90b12e7f36d6a6811791a29" +checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" dependencies = [ "aho-corasick", "memchr", @@ -580,9 +581,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49530408a136e16e5b486e883fbb6ba058e8e4e8ae6621a77b048b314336e629" +checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" dependencies = [ "aho-corasick", "memchr", @@ -615,7 +616,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] @@ -682,14 +683,14 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2" dependencies = [ "itoa", "ryu", @@ -735,9 +736,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.29" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -746,22 +747,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97a802ec30afc17eee47b2855fc72e0c4cd62be9b4efe6591edde0ec5bd68d8f" +checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.47" +version = "1.0.48" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6bb623b56e39ab7dcd4b1b98bb6c8f8d907ed255b18de254088016b27a8ee19b" +checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.29", + "syn 2.0.32", ] [[package]] diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 640adb2d92..ddc8bebabf 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -10,7 +10,7 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -candle-core = { git = "https://github.com/huggingface/candle" } +candle-core = { version = "0.2.1" } 
half = "2.3.1" num-traits = "0.2.16" rustler = "0.29.1" diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 75b22fa66e..77f4d35ee3 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -240,7 +240,7 @@ macro_rules! binary_nif { ($nif_name:ident, $native_fn_name:ident) => { #[rustler::nif(schedule = "DirtyCpu")] pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result { - Ok(ExTensor::new(left.$native_fn_name(&right)?)) + Ok(ExTensor::new(left.$native_fn_name(right.deref())?)) } }; } From 6ea8675ea0de2f6a9ff926ef46f236c768b904ae Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 12 Sep 2023 15:32:12 -0300 Subject: [PATCH 128/185] mix format --- candlex/lib/candlex/backend.ex | 17 ++++++++++++----- candlex/test/candlex_test.exs | 6 ++++-- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 12a22c27c5..1c23923cb8 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -397,7 +397,8 @@ defmodule Candlex.Backend do |> maybe_add_signature(tensor) end - defp maybe_add_signature(result, %T{data: %CB{device: device, resource: ref}}) when is_reference(ref) do + defp maybe_add_signature(result, %T{data: %CB{device: device, resource: ref}}) + when is_reference(ref) do Inspect.Algebra.concat([ "Candlex.Backend(#{device})", Inspect.Algebra.line(), @@ -516,14 +517,20 @@ defmodule Candlex.Backend do |> from_nx() end - defp to_nx(%CB{resource: ref} = backend_tensor, %T{type: nx_type, shape: nx_shape} = t) when is_reference(ref) do + defp to_nx(%CB{resource: ref} = backend_tensor, %T{type: nx_type, shape: nx_shape} = t) + when is_reference(ref) do {:ok, candle_dtype} = Native.dtype(backend_tensor) {:ok, candle_shape} = Native.t_shape(backend_tensor) case {nx_type, from_candle_dtype(candle_dtype)} do - {{:u, 64}, {:s, 64}} -> :ok - {type, type} -> :ok - {type, other_type} -> raise "tensor type mismatch, Nx (#{inspect(type)}) and Candle (#{inspect(other_type)})" + {{:u, 64}, {:s, 64}} -> + :ok + + {type, type} -> + :ok + + {type, other_type} -> + raise "tensor type mismatch, Nx (#{inspect(type)}) and Candle (#{inspect(other_type)})" end if nx_shape != candle_shape do diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 890c325727..54fd7093f8 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -829,7 +829,7 @@ defmodule CandlexTest do t([16, 32, -64, -128], type: :u32) |> Nx.right_shift(t([1, 2, 3, 4])) - |> assert_equal(t([8, 8, 536870904, 268435448])) + |> assert_equal(t([8, 8, 536_870_904, 268_435_448])) end test "is_infinity" do @@ -886,7 +886,9 @@ defmodule CandlexTest do t([0.10000000149011612, 0.5, 0.8999999761581421], type: :f64) |> Nx.erf_inv() - |> assert_close(t([0.0888559891358877, 0.47693629334671295, 1.1630870196442271], type: :f64)) + |> assert_close( + t([0.0888559891358877, 0.47693629334671295, 1.1630870196442271], type: :f64) + ) end test "sum/2" do From c209f5273c88406245aed3f4774dc1eb9e2f3f58 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 12 Sep 2023 16:25:30 -0300 Subject: [PATCH 129/185] candlex logical_xor --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 6 ++++++ candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 26 
++++++++++++++++++++++++++ 6 files changed, 37 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 1c23923cb8..4310c27eb6 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -186,6 +186,7 @@ defmodule Candlex.Backend do :less, :less_equal, :logical_or, + :logical_xor, :right_shift ] do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index d5ada054c8..4740ab032b 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -58,6 +58,7 @@ defmodule Candlex.Native do :less, :less_equal, :logical_or, + :logical_xor, :matmul, :max, :min, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 99837ef810..fbf000ce9b 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -75,6 +75,7 @@ rustler::init! { tensors::bitwise_or, tensors::bitwise_xor, tensors::logical_or, + tensors::logical_xor, tensors::left_shift, tensors::right_shift, devices::is_cuda_available diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 6b9881d773..55b22a9b39 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -214,3 +214,9 @@ custom_binary_bool_op!( |v1, v2| if v1 as i8 == 0 && v2 as i8 == 0 { 0 } else { 1 }, (U8, U32, I64, F32, F64) ); +custom_binary_bool_op!( + LogicalXor, + "logical_xor", + |v1, v2| if (v1 as i8 != 0) == (v2 as i8 != 0) { 0 } else { 1 }, + (U8, U32, I64, F32, F64) +); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 77f4d35ee3..8844916654 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -2,7 +2,7 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, ErfInv, Floor, IsInf, Log1p, - LogicalOr, Round, Shl, Shr, Sigmoid, Tan, + LogicalOr, LogicalXor, Round, Shl, Shr, Sigmoid, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -304,6 +304,7 @@ custom_binary_nif!(bitwise_or, BitOr); custom_binary_nif!(bitwise_xor, BitXor); custom_binary_nif!(left_shift, Shl); custom_binary_nif!(logical_or, LogicalOr); +custom_binary_nif!(logical_xor, LogicalXor); custom_binary_nif!(right_shift, Shr); fn tuple_to_vec(term: Term) -> Result, rustler::Error> { diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 54fd7093f8..b01349fd6e 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -872,6 +872,32 @@ defmodule CandlexTest do )) end + test "logical_xor" do + 0 + |> Nx.logical_xor(t([-1, 0, 1])) + |> assert_equal(t([1, 0, 1])) + + t([-1, 0, 1]) + |> Nx.logical_xor(t([[-1], [0], [1]])) + |> assert_equal(t( + [ + [0, 1, 0], + [1, 0, 1], + [0, 1, 0] + ] + )) + + t([-1.0, 0.0, 1.0]) + |> Nx.logical_xor(t([[-1], [0], [1]])) + |> assert_equal(t( + [ + [0, 1, 0], + [1, 0, 1], + [0, 1, 0] + ] + )) + end + test "erf_inv" do Nx.erf_inv(0.10000000149011612) |> assert_close(t(0.08885598927736282)) From c8b8a2c342a1a247ede544929ee93baea83eda1b Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 14 Sep 2023 13:24:13 -0300 Subject: [PATCH 130/185] candlex float pow --- candlex/lib/candlex/backend.ex | 13 +++++++++++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + 
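`LogicalXor` follows the same `custom_binary_bool_op!` pattern as `LogicalOr`, returning 1 where exactly one operand is truthy (non-zero). A sketch using the values from the new test:

```elixir
Nx.logical_xor(0, Nx.tensor([-1, 0, 1], backend: Candlex.Backend))
# => [1, 0, 1]

Nx.logical_xor(
  Nx.tensor([-1, 0, 1], backend: Candlex.Backend),
  Nx.tensor([[-1], [0], [1]], backend: Candlex.Backend)
)
# => [[0, 1, 0],
#     [1, 0, 1],
#     [0, 1, 0]]
```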
candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 26 ++++++++++++++++++++++++++ 6 files changed, 44 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 4310c27eb6..42fd24e44d 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -176,6 +176,19 @@ defmodule Candlex.Backend do end end + for op <- [:pow] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_upcast(left, right) + {left, right} = maybe_broadcast_bin_args(out.shape, left, right) + + left + |> Native.unquote(op)(right) + |> unwrap!() + |> to_nx(out) + end + end + for op <- [ :bitwise_and, :bitwise_or, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 4740ab032b..c072f997f5 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -63,6 +63,7 @@ defmodule Candlex.Native do :max, :min, :multiply, + :pow, :right_shift, :subtract ] do diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index fbf000ce9b..df9df4857b 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -27,6 +27,7 @@ rustler::init! { tensors::subtract, tensors::multiply, tensors::divide, + tensors::pow, tensors::max, tensors::min, tensors::equal, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 55b22a9b39..8635b014dc 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -206,6 +206,7 @@ custom_unary_op_closure!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64)); +custom_binary_op!(Pow, "pow", |v1, v2| v1.powf(v2), (F32, F64)); custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2, (U32, I64)); custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2, (U32, I64)); custom_binary_bool_op!( diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 8844916654..0a73073531 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -2,7 +2,7 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, ErfInv, Floor, IsInf, Log1p, - LogicalOr, LogicalXor, Round, Shl, Shr, Sigmoid, Tan, + LogicalOr, LogicalXor, Pow, Round, Shl, Shr, Sigmoid, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -305,6 +305,7 @@ custom_binary_nif!(bitwise_xor, BitXor); custom_binary_nif!(left_shift, Shl); custom_binary_nif!(logical_or, LogicalOr); custom_binary_nif!(logical_xor, LogicalXor); +custom_binary_nif!(pow, Pow); custom_binary_nif!(right_shift, Shr); fn tuple_to_vec(term: Term) -> Result, rustler::Error> { diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index b01349fd6e..a0f6640b17 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1024,6 +1024,32 @@ defmodule CandlexTest do |> Nx.mean() |> assert_equal(t(0.2)) end + + test "pow" do + # Nx.pow(2, 4) + # |> assert_equal(t(16)) + + # t([1, 2, 3], type: :u32) + # |> Nx.pow(t(2, type: :u32)) + # |> assert_equal(t([1, 4, 9])) + + t([1.0, 2.0, 3.0]) + |> Nx.pow(2) + |> assert_equal(t([1.0, 4.0, 9.0])) + + 2 + |> Nx.pow(t([1.0, 2.0, 3.0])) + |> assert_equal(t([2.0, 4.0, 8.0])) 
+ + # t([[2], [3]]) + # |> Nx.pow(t([[4, 5]])) + # |> assert_equal(t( + # [ + # [16, 32], + # [81, 243] + # ] + # )) + end end defp t(values, opts \\ []) do From 719a382b9345c57da4c5e31b40e527d2ee4835ac Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 14 Sep 2023 13:30:55 -0300 Subject: [PATCH 131/185] candlex greater --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 21 +++++++++++++++++++++ 5 files changed, 25 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 42fd24e44d..866b0224a9 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -194,6 +194,7 @@ defmodule Candlex.Backend do :bitwise_or, :bitwise_xor, :equal, + :greater, :greater_equal, :left_shift, :less, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index c072f997f5..de5057bdda 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -53,6 +53,7 @@ defmodule Candlex.Native do :bitwise_xor, :divide, :equal, + :greater, :greater_equal, :left_shift, :less, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index df9df4857b..25f7feb981 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -31,6 +31,7 @@ rustler::init! { tensors::max, tensors::min, tensors::equal, + tensors::greater, tensors::greater_equal, tensors::less, tensors::less_equal, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 0a73073531..99077a6e29 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -294,6 +294,7 @@ binary_nif!(multiply, broadcast_mul); binary_nif!(max, broadcast_maximum); binary_nif!(min, broadcast_minimum); binary_nif!(equal, eq); +binary_nif!(greater, gt); binary_nif!(greater_equal, ge); binary_nif!(less, lt); binary_nif!(less_equal, le); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index a0f6640b17..99c5eadebf 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -219,6 +219,27 @@ defmodule CandlexTest do ) end + test "greater" do + Nx.greater(1, 2) + |> assert_equal(t(0)) + + Nx.greater(1, t([1, 2, 3])) + |> assert_equal(t([0, 0, 0])) + + t([1, 2, 3]) + |> Nx.greater(t([1, 2, 2])) + |> assert_equal(t([0, 0, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + |> Nx.greater(t([1, 2, 3])) + |> assert_equal(t( + [ + [0, 0, 0], + [1, 1, 1] + ] + )) + end + test "less" do Nx.less(1, 2) |> assert_equal(t(1)) From b32e4faa37cb96b4b446b9a7a2ee829dd94ebf2e Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:25:53 -0300 Subject: [PATCH 132/185] candlex to_batched (repeat) --- candlex/lib/candlex/backend.ex | 45 ++++++++++++++++++++++++++-------- candlex/test/candlex_test.exs | 40 ++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 10 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 866b0224a9..9bba22b23e 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -510,17 +510,42 @@ defmodule Candlex.Backend do ## Conversions @impl true - def to_batched(%T{shape: out_shape} = out, %T{shape: shape} = t, _opts) do - # TODO: dont ignore opts - batch_size = elem(out_shape, 0) - t_axis_0 = elem(shape, 0) - 
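Two of the ops added in these patches can be sketched together: `pow` (backed by the float-only `Pow` custom op using `powf`, which is why the integer cases above stay commented out) and `greater` (candle's `gt`). Values are the ones the tests assert:

```elixir
Nx.pow(Nx.tensor([1.0, 2.0, 3.0], backend: Candlex.Backend), 2)
# => [1.0, 4.0, 9.0]

Nx.pow(2, Nx.tensor([1.0, 2.0, 3.0], backend: Candlex.Backend))
# => [2.0, 4.0, 8.0]

Nx.greater(
  Nx.tensor([1, 2, 3], backend: Candlex.Backend),
  Nx.tensor([1, 2, 2], backend: Candlex.Backend)
)
# => [0, 0, 1]
```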
num_batches = div(t_axis_0, batch_size) + def to_batched(%T{shape: out_shape} = out, %T{shape: shape} = t, opts) do + leftover = opts[:leftover] + first_dimension = 0 + batch_size = elem(out_shape, first_dimension) + axis_total = elem(shape, first_dimension) + remainder = rem(axis_total, batch_size) + num_batches = div(axis_total, batch_size) + native_tensor = from_nx(t) + + native_batches = + cond do + remainder == 0 -> + native_tensor + |> Native.chunk(num_batches) + |> unwrap!() + remainder > 0 && leftover == :repeat -> + slice_shape = + shape + |> Tuple.delete_at(first_dimension) + |> Tuple.insert_at(first_dimension, remainder) + + [ + native_tensor, + Native.narrow(native_tensor, first_dimension, 0, batch_size - remainder) + |> unwrap!() + ] + |> Native.concatenate(first_dimension) + |> unwrap!() + |> Native.chunk(num_batches + 1) + |> unwrap!() + true -> + raise "not implemented" + end - t - |> from_nx() - |> Native.chunk(num_batches) - |> unwrap!() - |> Stream.map(&to_nx(&1, out)) + native_batches + |> Stream.map(&to_nx(&1, out)) end @doc false diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 99c5eadebf..5c80588474 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1021,6 +1021,46 @@ defmodule CandlexTest do ] ] )) + + [first, second] = + Nx.iota({10}) + |> Nx.to_batched(5) + |> Enum.to_list() + + first + |> assert_equal(Nx.tensor([0, 1, 2, 3, 4])) + + second + |> assert_equal(Nx.tensor([5, 6, 7, 8, 9])) + + [first, second, third, fourth] = + Nx.iota({10}) + |> Nx.to_batched(3) + |> Enum.to_list() + + first + |> assert_equal(Nx.tensor([0, 1, 2])) + + second + |> assert_equal(Nx.tensor([3, 4, 5])) + + third + |> assert_equal(Nx.tensor([6, 7, 8])) + + fourth + |> assert_equal(Nx.tensor([9, 0, 1])) + + # TODO: Implement with discard + # [first, second] = + # Nx.iota({10}) + # |> Nx.to_batched(4, leftover: :discard) + # |> Enum.to_list() + + # first + # |> assert_equal(Nx.tensor([0, 1, 2, 3])) + + # second + # |> assert_equal(Nx.tensor([4, 5, 6, 7])) end test "sigmoid/1" do From c92c6b8edfc53067a695588720537f2c32225829 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:05:45 -0300 Subject: [PATCH 133/185] candlex conv (partially) --- candlex/lib/candlex/backend.ex | 65 +++++++++++++++++ candlex/lib/candlex/native.ex | 3 + candlex/native/candlex/src/lib.rs | 3 + candlex/native/candlex/src/tensors.rs | 25 +++++++ candlex/test/candlex_test.exs | 101 ++++++++++++++++++++++++++ 5 files changed, 197 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 9bba22b23e..5d4fc13349 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -283,6 +283,59 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def conv(%T{type: out_type} = out, %T{shape: shape} = tensor, %T{} = kernel, opts) do + # TODO: Support more opts + unsupported_option!(opts, :batch_group_size, 1) + unsupported_option!(opts, :feature_group_size, 1) + + # For now we assume: + # strides = opts[:strides] # [1, 1] + # padding = opts[:padding] # [{0, 0}, {0, 0}] + # input_dilation = opts[:input_dilation] # [1, 1] + # kernel_dilation = opts[:kernel_dilation] # [1, 1] + + input_permutation = opts[:input_permutation] + kernel_permutation = opts[:kernel_permutation] + + output_permutation = + case opts[:output_permutation] do + nil -> + nil + + l -> + # The permutation that Nx.Shape expects is actually the reverse permutation + # for the given 
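When the first axis is not divisible by the batch size, the rewritten `to_batched` above pads the input by concatenating a narrowed slice of its start before chunking, matching Nx's default `leftover: :repeat`; `leftover: :discard` still raises, as the TODO in the tests notes. A sketch with the tested values:

```elixir
Nx.iota({10}, backend: Candlex.Backend)
|> Nx.to_batched(3)
|> Enum.map(&Nx.to_flat_list/1)
# => [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 0, 1]]
#    (the last batch wraps around to the start of the axis)
```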
input + l |> Enum.with_index() |> Enum.sort() |> Enum.map(&elem(&1, 1)) + end + + native_tensor = + tensor + |> from_nx() + |> permute(input_permutation) + |> Native.to_type(to_candle_dtype(out_type)) + |> unwrap!() + + native_kernel = + kernel + |> from_nx() + |> permute(kernel_permutation) + |> Native.to_type(to_candle_dtype(out_type)) + |> unwrap!() + + native_result = + case Nx.rank(shape) do + 3 -> Native.conv1d(native_tensor, native_kernel) + 4 -> Native.conv2d(native_tensor, native_kernel) + rank -> raise("unsupported conv for tensor of rank #{rank}, only 3 or 4 supported") + end + + native_result + |> unwrap!() + |> permute(output_permutation) + |> to_nx(out) + end + @impl true def dot( %T{type: _out_type} = out, @@ -548,6 +601,12 @@ defmodule Candlex.Backend do |> Stream.map(&to_nx(&1, out)) end + defp permute(native_tensor, permutation) do + native_tensor + |> Native.permute(permutation) + |> unwrap!() + end + @doc false defp from_nx(%T{data: %CB{} = data}), do: data @@ -633,6 +692,12 @@ defmodule Candlex.Backend do raise("Unsupported candlex operation '#{op_name}'") end + defp unsupported_option!(opts, key, acceptable_default) do + if opts[key] != nil and opts[key] != acceptable_default do + raise "#{inspect(key)} option with #{inspect(opts[key])} is not supported" + end + end + defp unwrap!({:ok, result}), do: result defp unwrap!({:error, error}), do: raise("Candlex: #{error}") end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index de5057bdda..d920672554 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -19,6 +19,8 @@ defmodule Candlex.Native do def dtype(_tensor), do: error() def t_shape(_tensor), do: error() def concatenate(_tensors, _axis), do: error() + def conv1d(_tensor, _kernel), do: error() + def conv2d(_tensor, _kernel), do: error() for op <- [ :abs, @@ -72,6 +74,7 @@ defmodule Candlex.Native do end def sum(_tensor, _dims, _keep_dims), do: error() + def permute(_tensor, _dims), do: error() for op <- [:argmax, :argmin] do def unquote(op)(_tensor, _dim, _keep_dim), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 25f7feb981..0d7478b72f 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -52,6 +52,9 @@ rustler::init! 
{ tensors::broadcast_to, tensors::reshape, tensors::concatenate, + tensors::conv1d, + tensors::conv2d, + tensors::permute, tensors::matmul, tensors::abs, tensors::acos, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 99077a6e29..970bc102a3 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -169,6 +169,11 @@ pub fn sum( Ok(ExTensor::new(t)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn permute(ex_tensor: ExTensor, dims: Vec) -> Result { + Ok(ExTensor::new(ex_tensor.permute(dims)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn broadcast_to(t: ExTensor, shape: Term) -> Result { Ok(ExTensor::new(t.broadcast_as(tuple_to_vec(shape).unwrap())?)) @@ -214,6 +219,26 @@ pub fn concatenate(ex_tensors: Vec, dim: usize) -> Result Result { + let padding = 0; + let stride = 1; + let dilation = 1; + let groups = 1; + + Ok(ExTensor::new(tensor.conv1d(kernel.deref(), padding, stride, dilation, groups)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn conv2d(tensor: ExTensor, kernel: ExTensor) -> Result { + let padding = 0; + let stride = 1; + let dilation = 1; + let groups = 1; + + Ok(ExTensor::new(tensor.conv2d(kernel.deref(), padding, stride, dilation, groups)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn divide(left: ExTensor, right: ExTensor) -> Result { Ok(ExTensor::new( diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 5c80588474..31576a91f6 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1111,6 +1111,107 @@ defmodule CandlexTest do # ] # )) end + + test "conv" do + Nx.iota({9}) + |> Nx.reshape({1, 1, 3, 3}) + |> Nx.conv( + Nx.iota({4}) + |> Nx.reshape({4, 1, 1, 1}), + strides: [1, 1] + ) + |> assert_equal(t( + [ + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0] + ], + [ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0] + ], + [ + [0.0, 2.0, 4.0], + [6.0, 8.0, 10.0], + [12.0, 14.0, 16.0] + ], + [ + [0.0, 3.0, 6.0], + [9.0, 12.0, 15.0], + [18.0, 21.0, 24.0] + ] + ] + ] + )) + + # input/output permutation + + result = + Nx.iota({1, 3, 3, 6}) + |> Nx.conv( + 1 |> Nx.broadcast({2, 6, 1, 1}), + input_permutation: [0, 3, 1, 2], + output_permutation: [0, 3, 1, 2] + ) + + assert result.shape == {1, 3, 3, 2} + + result + |> assert_close(t( + [ + [ + [15.0, 15.0], + [51.0, 51.0], + [87.0, 87.0] + ], + [ + [123.0, 123.0], + [159.0, 159.0], + [195.0, 195.0] + ], + [ + [231.0, 231.0], + [267.0, 267.0], + [303.0, 303.0] + ] + ] + )) + + # Nx.iota({9}) + # |> Nx.reshape({1, 1, 3, 3}) + # |> Nx.conv( + # Nx.iota({8}) + # |> Nx.reshape({4, 1, 2, 1}), + # strides: 2, + # padding: :same, + # kernel_dilation: [2, 1] + # ) + # |> assert_equal(t( + # [ + # [ + # [ + # [3.0, 5.0], + # [0.0, 0.0] + # ], + # [ + # [9.0, 15.0], + # [6.0, 10.0] + # ], + # [ + # [15.0, 25.0], + # [12.0, 20.0] + # ], + # [ + # [21.0, 35.0], + # [18.0, 30.0] + # ] + # ] + # ] + # )) + end end defp t(values, opts \\ []) do From 89bb0a3d5efa602462b7c266fc9f795381bbc6e6 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Mon, 18 Sep 2023 16:21:29 -0300 Subject: [PATCH 134/185] candlex rsqrt --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 +++++ candlex/test/candlex_test.exs | 9 +++++++++ 5 files changed, 17 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 
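`conv/4` currently lowers only the rank-3 and rank-4 cases to candle's `conv1d`/`conv2d`, with padding, stride and dilation fixed at their defaults inside the NIFs. A sketch of the 2-D path using the values asserted above; since the kernels are 1x1 iota weights, output channel k is just the input scaled by k:

```elixir
input  = Nx.iota({9}, backend: Candlex.Backend) |> Nx.reshape({1, 1, 3, 3})
kernel = Nx.iota({4}, backend: Candlex.Backend) |> Nx.reshape({4, 1, 1, 1})

Nx.conv(input, kernel, strides: [1, 1])
# shape {1, 4, 3, 3}; channel 1 reproduces the input:
#   [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
```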
5d4fc13349..58a5b136f7 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -237,6 +237,7 @@ defmodule Candlex.Backend do :log1p, :negate, :round, + :rsqrt, :sigmoid, :sin, :sqrt, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index d920672554..8a26b037e3 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -39,6 +39,7 @@ defmodule Candlex.Native do :log1p, :negate, :round, + :rsqrt, :sigmoid, :sin, :sqrt, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 0d7478b72f..507f1fdce1 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -72,6 +72,7 @@ rustler::init! { tensors::round, tensors::log, tensors::log1p, + tensors::rsqrt, tensors::sqrt, tensors::tan, tensors::tanh, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 970bc102a3..eb9b4c6a91 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -97,6 +97,11 @@ pub fn transpose(t: ExTensor, dim1: usize, dim2: usize) -> Result Result { + Ok(ExTensor::new(t.sqrt()?.recip()?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn arange( start: i64, diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 31576a91f6..4e2078f5c9 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -565,6 +565,15 @@ defmodule CandlexTest do |> assert_equal(t([1.0, 1.4142135381698608, 1.7320507764816284])) end + test "rsqrt" do + Nx.rsqrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.rsqrt() + |> assert_equal(t([1.0, 0.7071067690849304, 0.5773502588272095])) + end + test "argmax" do Nx.argmax(4) |> assert_equal(t(0)) From f506c311f22f19b2591c0c13e5f3b2f5ce393f77 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 19 Sep 2023 15:21:40 -0300 Subject: [PATCH 135/185] candlex reduce_max --- candlex/lib/candlex/backend.ex | 23 +++++++++++++++++ candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 11 ++++++++ candlex/test/candlex_test.exs | 37 +++++++++++++++++++++++++++ 5 files changed, 73 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 58a5b136f7..8a088761bf 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -135,6 +135,29 @@ defmodule Candlex.Backend do end end + @impl true + def reduce_max(%T{} = out, %T{shape: {}} = tensor, _opts) do + out + |> from_binary(to_binary(tensor), []) + end + def reduce_max(%T{} = out, %T{} = tensor, opts) do + axis = + case opts[:axes] do + nil -> 0 + [] -> 0 + [axis] -> axis + axes -> raise "doesn't support axes option with more than 1 axis, '#{inspect(axes)}'" + end + + keep_axis = opts[:keep_axes] || false + + tensor + |> from_nx() + |> Native.reduce_max(axis, keep_axis) + |> unwrap!() + |> to_nx(out) + end + # Element-wise @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 8a26b037e3..447da3d27a 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -77,7 +77,7 @@ defmodule Candlex.Native do def sum(_tensor, _dims, _keep_dims), do: error() def permute(_tensor, _dims), do: error() - for op <- [:argmax, :argmin] do + for op <- [:argmax, :argmin, :reduce_max] do def unquote(op)(_tensor, _dim, _keep_dim), do: error() end diff --git 
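`rsqrt` is built from existing candle kernels (`sqrt` followed by `recip`) rather than a custom op; a sketch with the values from the test:

```elixir
Nx.rsqrt(Nx.tensor([1.0, 2, 3], backend: Candlex.Backend))
# => [1.0, 0.7071067690849304, 0.5773502588272095]
```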
a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 507f1fdce1..e4005f1475 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -41,6 +41,7 @@ rustler::init! { tensors::t_shape, tensors::argmax, tensors::argmin, + tensors::reduce_max, tensors::negate, tensors::where_cond, tensors::narrow, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index eb9b4c6a91..ac64695711 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -159,6 +159,17 @@ pub fn argmin(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result Result { + let t = if keep_dim { + ex_tensor.max_keepdim(dim)? + } else { + ex_tensor.max(dim)? + }; + + Ok(ExTensor::new(t)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn sum( ex_tensor: ExTensor, diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 4e2078f5c9..7508555bdd 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1221,6 +1221,43 @@ defmodule CandlexTest do # ] # )) end + + test "reduce_max" do + t(42) + |> Nx.reduce_max() + |> assert_equal(t(42)) + + t(42.0) + |> Nx.reduce_max() + |> assert_equal(t(42.0)) + + t([1, 2, 3]) + |> Nx.reduce_max() + |> assert_equal(t(3)) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_max(axes: [:x]) + |> assert_equal(t([3, 1, 4])) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_max(axes: [:y]) + |> assert_equal(t([4, 2])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_max(axes: [:x, :z]) + # |> assert_equal(t([4, 8])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_max(axes: [:x, :z], keep_axes: true) + # |> assert_equal(t( + # [ + # [ + # [4], + # [8] + # ] + # ] + # )) + end end defp t(values, opts \\ []) do From d21ee6559406f2b6c3668f3033ee66ba6dfc776a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 20 Sep 2023 11:15:34 -0300 Subject: [PATCH 136/185] candlex take_along_axis --- candlex/lib/candlex/backend.ex | 9 +++++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 ++++ candlex/test/candlex_test.exs | 38 +++++++++++++++++++++++++++ 5 files changed, 54 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 8a088761bf..3bef224868 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -295,6 +295,15 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def take_along_axis(%T{} = out, %T{} = tensor, %T{} = indexes, axis) do + tensor + |> from_nx() + |> Native.gather(from_nx(indexes), axis) + |> unwrap!() + |> to_nx(out) + end + # N-dim @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 447da3d27a..a09143e63d 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -9,6 +9,7 @@ defmodule Candlex.Native do def all(_tensor), do: error() def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() + def gather(_tensor, _indexes, _dim), do: error() def chunk(_tensor, _num_chunks), do: error() def squeeze(_tensor, _dim), do: error() def arange(_start, _end, _dtype, _shape, _device), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index e4005f1475..e121bdbb6e 100644 --- 
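The `reduce_max` callback above maps a single reduction axis onto candle's `max`/`max_keepdim`; reducing over multiple axes still raises, which is why those test cases remain commented out. A sketch with the tested values:

```elixir
Nx.reduce_max(Nx.tensor(42, backend: Candlex.Backend)) # => 42

t = Nx.tensor([[3, 1, 4], [2, 1, 1]], names: [:x, :y], backend: Candlex.Backend)
Nx.reduce_max(t, axes: [:x]) # => [3, 1, 4]
Nx.reduce_max(t, axes: [:y]) # => [4, 2]
```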
a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -45,6 +45,7 @@ rustler::init! { tensors::negate, tensors::where_cond, tensors::narrow, + tensors::gather, tensors::chunk, tensors::squeeze, tensors::transpose, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index ac64695711..91f8c521d9 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -79,6 +79,11 @@ pub fn narrow( Ok(ExTensor::new(t.narrow(dim, start, length)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn gather(t: ExTensor, indexes: ExTensor, dim: usize) -> Result { + Ok(ExTensor::new(t.gather(indexes.deref(), dim)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn chunk(t: ExTensor, num_chunks: usize) -> Result, CandlexError> { Ok(t.chunk(num_chunks, 0)? diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 7508555bdd..ff042e264c 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1258,6 +1258,44 @@ defmodule CandlexTest do # ] # )) end + + test "take_along_axis" do + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.take_along_axis( + t( + [ + [0, 0, 2, 2, 1, 1], + [2, 2, 1, 1, 0, 0] + ] + ), + axis: 1 + ) + |> assert_equal(t( + [ + [1, 1, 3, 3, 2, 2], + [6, 6, 5, 5, 4, 4] + ] + )) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.take_along_axis( + t( + [ + [0, 1, 1], + [1, 0, 0], + [0, 1, 0] + ] + ), + axis: 0 + ) + |> assert_equal(t( + [ + [1, 5, 6], + [4, 2, 3], + [1, 5, 3] + ] + )) + end end defp t(values, opts \\ []) do From b61be28fbea1eeaf76996d97c37b5cc08b0fd884 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 20 Sep 2023 11:48:39 -0300 Subject: [PATCH 137/185] candlex iota with axis argument --- candlex/lib/candlex/backend.ex | 17 +++++++++++ candlex/test/candlex_test.exs | 53 ++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 3bef224868..3be5945bdf 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -51,6 +51,23 @@ defmodule Candlex.Backend do |> to_nx(out) end + def iota(%T{shape: shape, type: type} = out, axis, backend_options) do + # Build in one dimension, then broadcast + axis_size = elem(shape, axis) + + Native.arange( + 0, + axis_size, + to_candle_dtype(type), + Tuple.duplicate(1, Nx.rank(shape)) |> put_elem(axis, axis_size), + device_option(backend_options) + ) + |> unwrap!() + |> Native.broadcast_to(shape) + |> unwrap!() + |> to_nx(out) + end + @impl true def eye(%T{shape: shape, type: type} = out, backend_options) do iota = Nx.iota(shape, backend_options) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index ff042e264c..e4bf2e3feb 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -86,6 +86,59 @@ defmodule CandlexTest do Nx.iota({2, 3}) |> assert_equal(t([[0, 1, 2], [3, 4, 5]])) + + Nx.iota({3, 3}, axis: 1) + |> assert_equal(t( + [ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2] + ] + )) + + Nx.iota({3, 3}, axis: -1) + |> assert_equal(t( + [ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2] + ] + )) + + Nx.iota({3, 4, 3}, axis: 0, type: :f64) + |> assert_equal(t( + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0] + ], + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0] + ], + [ + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0] + ] + ] + )) + + Nx.iota({1, 3, 2}, axis: 2) 
+ |> assert_equal(t( + [ + [ + [0, 1], + [0, 1], + [0, 1] + ] + ] + )) end test "max" do From 4be4f1a8aa84018e1347df26cc7659af98260d72 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 20 Sep 2023 12:20:14 -0300 Subject: [PATCH 138/185] mix format --- candlex/lib/candlex/backend.ex | 50 +++++++++++++++++----------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 3be5945bdf..768acef9db 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -157,6 +157,7 @@ defmodule Candlex.Backend do out |> from_binary(to_binary(tensor), []) end + def reduce_max(%T{} = out, %T{} = tensor, opts) do axis = case opts[:axes] do @@ -622,33 +623,32 @@ defmodule Candlex.Backend do num_batches = div(axis_total, batch_size) native_tensor = from_nx(t) - native_batches = - cond do - remainder == 0 -> - native_tensor - |> Native.chunk(num_batches) - |> unwrap!() - remainder > 0 && leftover == :repeat -> - slice_shape = - shape - |> Tuple.delete_at(first_dimension) - |> Tuple.insert_at(first_dimension, remainder) - - [ - native_tensor, - Native.narrow(native_tensor, first_dimension, 0, batch_size - remainder) - |> unwrap!() - ] - |> Native.concatenate(first_dimension) - |> unwrap!() - |> Native.chunk(num_batches + 1) + cond do + remainder == 0 -> + native_tensor + |> Native.chunk(num_batches) + |> unwrap!() + + remainder > 0 && leftover == :repeat -> + slice_shape = + shape + |> Tuple.delete_at(first_dimension) + |> Tuple.insert_at(first_dimension, remainder) + + [ + native_tensor, + Native.narrow(native_tensor, first_dimension, 0, batch_size - remainder) |> unwrap!() - true -> - raise "not implemented" - end + ] + |> Native.concatenate(first_dimension) + |> unwrap!() + |> Native.chunk(num_batches + 1) + |> unwrap!() - native_batches - |> Stream.map(&to_nx(&1, out)) + true -> + raise "not implemented" + end + |> Stream.map(&to_nx(&1, out)) end defp permute(native_tensor, permutation) do From ff833d1dcf3463b0124ef07ec34594e633b2e375 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 20 Sep 2023 12:20:57 -0300 Subject: [PATCH 139/185] cargo fmt --- candlex/native/candlex/src/ops.rs | 6 +++++- candlex/native/candlex/src/tensors.rs | 22 +++++++++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 8635b014dc..05ba8a09fd 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -218,6 +218,10 @@ custom_binary_bool_op!( custom_binary_bool_op!( LogicalXor, "logical_xor", - |v1, v2| if (v1 as i8 != 0) == (v2 as i8 != 0) { 0 } else { 1 }, + |v1, v2| if (v1 as i8 != 0) == (v2 as i8 != 0) { + 0 + } else { + 1 + }, (U8, U32, I64, F32, F64) ); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 91f8c521d9..bfddb65ece 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -165,7 +165,11 @@ pub fn argmin(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result Result { +pub fn reduce_max( + ex_tensor: ExTensor, + dim: usize, + keep_dim: bool, +) -> Result { let t = if keep_dim { ex_tensor.max_keepdim(dim)? 
} else { @@ -247,7 +251,13 @@ pub fn conv1d(tensor: ExTensor, kernel: ExTensor) -> Result Result Date: Wed, 20 Sep 2023 12:32:11 -0300 Subject: [PATCH 140/185] cargo update --- candlex/native/candlex/Cargo.lock | 70 +++++++++++++++---------------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 42c353d099..285227e772 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -4,9 +4,9 @@ version = 3 [[package]] name = "aho-corasick" -version = "1.0.5" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c378d78423fdad8089616f827526ee33c19f2fddbd5de1629152c9593ba4783" +checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" dependencies = [ "memchr", ] @@ -46,9 +46,9 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54bfcd86ef61f5ceb73aa43a49d08e395503e6c79f12ae8f5ed4695dd33b84e" +checksum = "dde6b117b2e56ee68959aad0cb5fdfe65f39328c8c82e5bd356a6b95ade471c7" dependencies = [ "byteorder", "candle-gemm", @@ -66,9 +66,9 @@ dependencies = [ [[package]] name = "candle-gemm" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b726a1f6cdd7ff080e95e3d91694701b1e04a58acd198e4a78c39428b2274e" +checksum = "ef9b07a4b0ba1a304b44432006580980ddff9748c201261c279437e7b11bba68" dependencies = [ "candle-gemm-c32", "candle-gemm-c64", @@ -88,9 +88,9 @@ dependencies = [ [[package]] name = "candle-gemm-c32" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "661470663389f0c99fd8449e620bfae630a662739f830a323eda4dcf80888843" +checksum = "f595241dad99811de285e029889f57c29dd98e33de7a8a6b881867b1488d7d4a" dependencies = [ "candle-gemm-common", "dyn-stack", @@ -105,9 +105,9 @@ dependencies = [ [[package]] name = "candle-gemm-c64" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4a111ddf61db562854a6d2ff4dfe1e8a84066431b7bc68d3afae4bf60874fda0" +checksum = "648f22fd8f5a4f330e29d791845b514966421308a6a2b5fedb949ee07e54c77f" dependencies = [ "candle-gemm-common", "dyn-stack", @@ -122,9 +122,9 @@ dependencies = [ [[package]] name = "candle-gemm-common" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a6dd93783ead7eeef14361667ea32014dc6f716a2fc956b075fe78729e10dd5" +checksum = "e03c01b4ca3b9d71e4eb89e42946a08f8b0d2f1b861f7fa2ea0966233f1e0b08" dependencies = [ "dyn-stack", "lazy_static", @@ -138,9 +138,9 @@ dependencies = [ [[package]] name = "candle-gemm-f16" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b76499bf4b858cacc526c5c8f948bc7152774247dce8568f174b743ab1363fa4" +checksum = "97f8af2a482131713d28a337abff6debf26c529afa1837caf2ba190909b2107c" dependencies = [ "candle-gemm-common", "candle-gemm-f32", @@ -157,9 +157,9 @@ dependencies = [ [[package]] name = "candle-gemm-f32" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bec152e7d36339d3785e0d746d75ee94a4e92968fbb12ddcc91b536b938d016" +checksum = 
"938927961e2f0c0a6064fcf3524ea3f7f455fe5708419532a6fea9aea1ab45ae" dependencies = [ "candle-gemm-common", "dyn-stack", @@ -174,9 +174,9 @@ dependencies = [ [[package]] name = "candle-gemm-f64" -version = "0.15.6" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00f59ac68a5521e2ff71431bb7f1b22126ff0b60c5e66599b1f4676433da6e69" +checksum = "d192d7126e59b81ef4cf13cd9f194e6dbdc09171f65d0074d059dc009ac06775" dependencies = [ "candle-gemm-common", "dyn-stack", @@ -313,9 +313,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "itoa" @@ -331,9 +331,9 @@ checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "libc" -version = "0.2.147" +version = "0.2.148" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" [[package]] name = "libm" @@ -468,9 +468,9 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "proc-macro2" -version = "1.0.66" +version = "1.0.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" dependencies = [ "unicode-ident", ] @@ -616,7 +616,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.37", ] [[package]] @@ -683,14 +683,14 @@ checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.37", ] [[package]] name = "serde_json" -version = "1.0.106" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc66a619ed80bf7a0f6b17dd063a84b88f6dea1813737cf469aef1d081142c2" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ "itoa", "ryu", @@ -736,9 +736,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.32" +version = "2.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" +checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" dependencies = [ "proc-macro2", "quote", @@ -762,20 +762,20 @@ checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" dependencies = [ "proc-macro2", "quote", - "syn 2.0.32", + "syn 2.0.37", ] [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicode-ident" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unreachable" 
From 1dcdedc7255c75ac0a00ac1d55414475b932ab21 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 31 Aug 2023 14:12:39 -0300 Subject: [PATCH 141/185] decent CUDA support of some backend functions --- candlex/config/config.exs | 17 +++++++++ candlex/lib/candlex/backend.ex | 26 +++++++------ candlex/lib/candlex/native.ex | 3 +- candlex/native/candlex/Cargo.lock | 55 +++++++++++++++++---------- candlex/native/candlex/Cargo.toml | 5 ++- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 9 ++++- candlex/test/candlex_test.exs | 13 ++----- 8 files changed, 83 insertions(+), 46 deletions(-) create mode 100644 candlex/config/config.exs diff --git a/candlex/config/config.exs b/candlex/config/config.exs new file mode 100644 index 0000000000..aecd8eacf6 --- /dev/null +++ b/candlex/config/config.exs @@ -0,0 +1,17 @@ +import Config + +enable_cuda = + case System.get_env("CUDA") do + nil -> System.find_executable("nvcc") && System.find_executable("nvidia-smi") + "false" -> false + _ -> true + end + +crate_features = + if enable_cuda do + [:cuda] + else + [] + end + +config :candlex, crate_features: crate_features diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 768acef9db..f3809c5bdd 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -16,11 +16,7 @@ defmodule Candlex.Backend do @impl true def init(opts) do - if opts != [] do - raise ArgumentError, "Candlex.Backend accepts no options" - end - - opts + Keyword.validate!(opts, [:device]) end # Creation @@ -79,6 +75,14 @@ defmodule Candlex.Backend do # Backend @impl true + def backend_copy(%T{} = tensor, Candlex.Backend, backend_options) do + tensor + |> from_nx() + |> Native.to_device(device_option(backend_options)) + |> unwrap!() + |> to_nx(tensor) + end + def backend_copy(%T{} = tensor, backend, backend_options) do backend.from_binary(tensor, to_binary(tensor), backend_options) end @@ -721,13 +725,11 @@ defmodule Candlex.Backend do end defp default_device do - # TODO: Support CUDA - # if cuda_available?() do - # @device_cuda - # else - # @device_cpu - # end - @device_cpu + if cuda_available?() do + @device_cuda + else + @device_cpu + end end defp cuda_available? 
do diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index a09143e63d..0d4af778db 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -1,7 +1,7 @@ defmodule Candlex.Native do @moduledoc false - use Rustler, otp_app: :candlex, crate: "candlex" + use Rustler, otp_app: :candlex, features: Application.compile_env(:candlex, :crate_features, []) # Rustler will override all the below stub functions with real NIFs def from_binary(_binary, _dtype, _shape, _device), do: error() @@ -83,6 +83,7 @@ defmodule Candlex.Native do end def is_cuda_available(), do: error() + def to_device(_tensor, _device), do: error() defp error(), do: :erlang.nif_error(:nif_not_loaded) end diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 285227e772..440d8a35fe 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -46,12 +46,13 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dde6b117b2e56ee68959aad0cb5fdfe65f39328c8c82e5bd356a6b95ade471c7" +version = "0.2.3" +source = "git+https://github.com/mimiquate/candle?branch=cuda#de3f2be78379ffdaf404bb3d46b3578e6a005957" dependencies = [ "byteorder", "candle-gemm", + "candle-kernels", + "cudarc", "half", "memmap2", "num-traits", @@ -189,6 +190,15 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "candle-kernels" +version = "0.2.3" +source = "git+https://github.com/mimiquate/candle?branch=cuda#de3f2be78379ffdaf404bb3d46b3578e6a005957" +dependencies = [ + "glob", + "rayon", +] + [[package]] name = "candlex" version = "0.1.0" @@ -216,16 +226,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "crossbeam-channel" -version = "0.5.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" -dependencies = [ - "cfg-if", - "crossbeam-utils", -] - [[package]] name = "crossbeam-deque" version = "0.8.3" @@ -265,6 +265,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "cudarc" +version = "0.9.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc4cab390f4a32340211f015292a4551742a63e528e9ade9e0bde0d1a989d2a1" +dependencies = [ + "half", +] + [[package]] name = "dyn-stack" version = "0.9.0" @@ -292,6 +301,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "half" version = "2.3.1" @@ -343,9 +358,9 @@ checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" [[package]] name = "matrixmultiply" -version = "0.3.7" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "090126dc04f95dc0d1c1c91f61bdd474b3930ca064c1edc8a849da2c6cbe1e77" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" dependencies = [ "autocfg", "rawpointer", @@ -541,9 +556,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" [[package]] name = "rayon" -version = "1.7.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", @@ -551,14 +566,12 @@ dependencies = [ [[package]] name = "rayon-core" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "num_cpus", ] [[package]] diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index ddc8bebabf..576eaeb042 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -10,9 +10,12 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -candle-core = { version = "0.2.1" } +candle-core = { git = "https://github.com/mimiquate/candle", branch = "cuda" } half = "2.3.1" num-traits = "0.2.16" rustler = "0.29.1" statrs = "0.16.0" thiserror = "1.0.47" + +[features] +cuda = ["candle-core/cuda"] diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index e121bdbb6e..18ea43b5a9 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -86,6 +86,7 @@ rustler::init! { tensors::logical_xor, tensors::left_shift, tensors::right_shift, + tensors::to_device, devices::is_cuda_available ], load = load diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index bfddb65ece..9ce39cbb34 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -60,6 +60,11 @@ pub fn from_binary( )?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn to_device(ex_tensor: ExTensor, device: Atom) -> Result { + Ok(ExTensor::new(ex_tensor.to_device(&device_from_atom(device)?)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Result { let bytes = tensor_bytes(ex_tensor.flatten_all()?)?; @@ -390,8 +395,8 @@ fn vec_to_tuple(env: Env, vec: Vec) -> Result { fn device_from_atom(atom: Atom) -> Result { if atom == atoms::cpu() { Ok(Device::Cpu) - // } else if atom == atoms::cuda() { - // Ok(Device::new_cuda(0)?) + } else if atom == atoms::cuda() { + Ok(Device::new_cuda(0)?) 
} else { Err(CandlexError::Other(format!( "unsupported device {:?}", diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index e4bf2e3feb..a1aedc0e5e 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -22,11 +22,6 @@ defmodule CandlexTest do check(2.16, type: :bf16) end - # test "gpu" do - # t([1, 2, 3], backend: {Candlex.Backend, device: :cuda}) - # |> assert_equal(t([1, 2, 3])) - # end - test "named dimensions" do check([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) @@ -560,11 +555,11 @@ defmodule CandlexTest do test "sin" do Nx.sin(1.0) - |> assert_equal(t(0.8414709568023682)) + |> assert_close(t(0.8414709568023682)) t([1.0, 2.0, 3.0]) |> Nx.sin() - |> assert_equal(t([0.8414709568023682, 0.9092974066734314, 0.14112000167369843])) + |> assert_close(t([0.8414709568023682, 0.9092974066734314, 0.14112000167369843])) end test "exp" do @@ -578,11 +573,11 @@ defmodule CandlexTest do test "cos" do Nx.cos(1.0) - |> assert_equal(t(0.5403022766113281)) + |> assert_close(t(0.5403022766113281)) t([1.0, 2, 3]) |> Nx.cos() - |> assert_equal(t([0.5403022766113281, -0.416146844625473, -0.9899924993515015])) + |> assert_close(t([0.5403022766113281, -0.416146844625473, -0.9899924993515015])) end test "log" do From bbae52cf43f0b6c92da1a5d419fe7220c8a81167 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 22 Sep 2023 14:19:20 -0300 Subject: [PATCH 142/185] cargo fmt --- candlex/native/candlex/src/tensors.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 9ce39cbb34..875df40ae2 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -62,7 +62,9 @@ pub fn from_binary( #[rustler::nif(schedule = "DirtyCpu")] pub fn to_device(ex_tensor: ExTensor, device: Atom) -> Result { - Ok(ExTensor::new(ex_tensor.to_device(&device_from_atom(device)?)?)) + Ok(ExTensor::new( + ex_tensor.to_device(&device_from_atom(device)?)?, + )) } #[rustler::nif(schedule = "DirtyCpu")] From 55c76d46d05ef82db11c50b93c458b716110e0f2 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 22 Sep 2023 14:20:09 -0300 Subject: [PATCH 143/185] cuda: implement custom ops --- candlex/native/candlex/Cargo.lock | 11 +- candlex/native/candlex/Cargo.toml | 3 + candlex/native/candlex/build.rs | 251 ++++++++++++++++ candlex/native/candlex/src/kernels.rs | 4 + .../candlex/src/kernels/custom_binary.cu | 89 ++++++ .../candlex/src/kernels/custom_unary.cu | 69 +++++ .../native/candlex/src/kernels/strides.cuh | 34 +++ candlex/native/candlex/src/lib.rs | 2 + candlex/native/candlex/src/ops.rs | 276 +++++++++++++++++- candlex/test/candlex_test.exs | 8 +- 10 files changed, 735 insertions(+), 12 deletions(-) create mode 100644 candlex/native/candlex/build.rs create mode 100644 candlex/native/candlex/src/kernels.rs create mode 100644 candlex/native/candlex/src/kernels/custom_binary.cu create mode 100644 candlex/native/candlex/src/kernels/custom_unary.cu create mode 100644 candlex/native/candlex/src/kernels/strides.cuh diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 440d8a35fe..c1cafadd16 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "anyhow" +version = "1.0.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" + [[package]] name = "approx" version = "0.5.1" @@ -47,7 +53,7 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.2.3" -source = "git+https://github.com/mimiquate/candle?branch=cuda#de3f2be78379ffdaf404bb3d46b3578e6a005957" +source = "git+https://github.com/mimiquate/candle?branch=cuda#18ba60c1de589473abd53d4757e4e87484856b4e" dependencies = [ "byteorder", "candle-gemm", @@ -193,7 +199,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.2.3" -source = "git+https://github.com/mimiquate/candle?branch=cuda#de3f2be78379ffdaf404bb3d46b3578e6a005957" +source = "git+https://github.com/mimiquate/candle?branch=cuda#18ba60c1de589473abd53d4757e4e87484856b4e" dependencies = [ "glob", "rayon", @@ -203,6 +209,7 @@ dependencies = [ name = "candlex" version = "0.1.0" dependencies = [ + "anyhow", "candle-core", "half", "num-traits", diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 576eaeb042..d5d273b237 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -17,5 +17,8 @@ rustler = "0.29.1" statrs = "0.16.0" thiserror = "1.0.47" +[build-dependencies] +anyhow = "1.0.75" + [features] cuda = ["candle-core/cuda"] diff --git a/candlex/native/candlex/build.rs b/candlex/native/candlex/build.rs new file mode 100644 index 0000000000..138bf992c0 --- /dev/null +++ b/candlex/native/candlex/build.rs @@ -0,0 +1,251 @@ +#![allow(unused)] +use anyhow::{Context, Result}; +use std::io::Write; +use std::path::PathBuf; + +struct KernelDirectories { + kernel_dir: &'static str, + rust_target: &'static str, + include_dirs: &'static [&'static str], +} + +const DIRS: [KernelDirectories; 1] = [KernelDirectories { + kernel_dir: "src/kernels/", + rust_target: "src/kernels.rs", + include_dirs: &[], +}]; + +impl KernelDirectories { + fn maybe_build_ptx( + &self, + cu_file: &std::path::Path, + ptx_file: &std::path::Path, + compute_cap: usize, + ) -> Result<()> { + let should_compile = + if ptx_file.exists() { + cu_file + .metadata()? + .modified()? + .duration_since(ptx_file.metadata()?.modified()?) + .is_ok() + } else { + true + }; + + if should_compile { + #[cfg(feature = "cuda")] + { + let mut command = std::process::Command::new("nvcc"); + let out_dir = ptx_file.parent().context("no parent for ptx file")?; + let include_dirs: Vec = self.include_dirs.iter().map(|c| format!("-I{c}")).collect(); + + command + .arg(format!("--gpu-architecture=sm_{compute_cap}")) + .arg("--ptx") + .args(["--default-stream", "per-thread"]) + .args(["--output-directory", out_dir.to_str().unwrap()]) + .arg(format!("-I/{}", self.kernel_dir)) + .args(include_dirs) + .arg(cu_file); + + let output = + command + .spawn() + .context("failed spawning nvcc")? 
+ .wait_with_output()?; + + if !output.status.success() { + anyhow::bail!( + "nvcc error while compiling {cu_file:?}:\n\n# stdout\n{:#}\n\n# stderr\n{:#}", + String::from_utf8_lossy(&output.stdout), + String::from_utf8_lossy(&output.stderr) + ) + } + } + + #[cfg(not(feature = "cuda"))] + std::fs::OpenOptions::new() + .create(true) + .write(true) + .open(ptx_file)?; + } + + Ok(()) + } + + fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> { + println!("cargo:rerun-if-changed={}", self.kernel_dir); + let kernel_dir = PathBuf::from(self.kernel_dir); + let out_dir = out_dir.join(self.kernel_dir); + if !out_dir.exists() { + std::fs::create_dir_all(&out_dir)?; + } + let mut cu_files = vec![]; + let mut cuh_files = vec![]; + for file in std::fs::read_dir(kernel_dir)?.flatten() { + let file = file.path(); + match file.extension().and_then(|v| v.to_str()) { + Some("cu") => cu_files.push(file), + Some("cuh") => cuh_files.push(file), + _ => {} + } + } + + let mut ptx_paths = vec![]; + for cu_file in cu_files.iter() { + let file_stem = cu_file + .file_stem() + .with_context(|| format!("no stem {cu_file:?}"))?; + let file_stem = file_stem.to_string_lossy().into_owned(); + let ptx_file = out_dir.join(&format!("{file_stem}.ptx")); + self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?; + ptx_paths.push(ptx_file); + } + + let regenerate_rs_file = true; + if regenerate_rs_file { + let mut file = std::fs::File::create(self.rust_target)?; + for ptx_path in ptx_paths { + let name = ptx_path + .file_stem() + .context("empty stem")? + .to_string_lossy(); + file.write_all(b"#[rustfmt::skip]\n")?; + let const_definition = format!( + r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#, + name.to_uppercase().replace('.', "_"), + self.kernel_dir, + ); + file.write_all(const_definition.as_bytes())?; + file.write_all(b"\n")?; + } + } + + Ok(()) + } +} + +fn main() -> Result<()> { + println!("cargo:rerun-if-changed=build.rs"); + + let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?; + let out_dir = PathBuf::from(out_dir); + #[cfg(feature = "cuda")] + set_cuda_include_dir()?; + #[cfg(feature = "cuda")] + let compute_cap = compute_cap()?; + #[cfg(not(feature = "cuda"))] + let compute_cap = 0; + + for dir in DIRS { + dir.process(&out_dir, compute_cap)? + } + + Ok(()) +} + +fn set_cuda_include_dir() -> Result<()> { + // NOTE: copied from cudarc build.rs. + let env_vars = [ + "CUDA_PATH", + "CUDA_ROOT", + "CUDA_TOOLKIT_ROOT_DIR", + "CUDNN_LIB", + ]; + let env_vars = env_vars + .into_iter() + .map(std::env::var) + .filter_map(Result::ok) + .map(Into::::into); + + let roots = [ + "/usr", + "/usr/local/cuda", + "/opt/cuda", + "/usr/lib/cuda", + "C:/Program Files/NVIDIA GPU Computing Toolkit", + "C:/CUDA", + ]; + let roots = roots.into_iter().map(Into::::into); + let root = env_vars + .chain(roots) + .find(|path| path.join("include").join("cuda.h").is_file()) + .context("cannot find include/cuda.h")?; + println!( + "cargo:rustc-env=CUDA_INCLUDE_DIR={}", + root.join("include").display() + ); + Ok(()) +} + +#[allow(unused)] +fn compute_cap() -> Result { + // Grab compute code from nvidia-smi + let mut compute_cap = { + let out = std::process::Command::new("nvidia-smi") + .arg("--query-gpu=compute_cap") + .arg("--format=csv") + .output() + .context("`nvidia-smi` failed. 
Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?; + let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?; + let mut lines = out.lines(); + assert_eq!( + lines.next().context("missing line in stdout")?, + "compute_cap" + ); + let cap = lines + .next() + .context("missing line in stdout")? + .replace('.', ""); + cap.parse::() + .with_context(|| format!("cannot parse as int {cap}"))? + }; + + // Grab available GPU codes from nvcc and select the highest one + let max_nvcc_code = { + let out = std::process::Command::new("nvcc") + .arg("--list-gpu-code") + .output() + .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH."); + let out = std::str::from_utf8(&out.stdout).unwrap(); + + let out = out.lines().collect::>(); + let mut codes = Vec::with_capacity(out.len()); + for code in out { + let code = code.split('_').collect::>(); + if !code.is_empty() && code.contains(&"sm") { + if let Ok(num) = code[1].parse::() { + codes.push(num); + } + } + } + codes.sort(); + if !codes.contains(&compute_cap) { + anyhow::bail!( + "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}." + ); + } + *codes.last().unwrap() + }; + + // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc, + // then choose the highest gpu code in nvcc + if compute_cap > max_nvcc_code { + println!( + "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}." + ); + compute_cap = max_nvcc_code; + } + + println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP"); + + if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") { + compute_cap = compute_cap_str + .parse::() + .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?; + println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP"); + } + println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}"); + Ok(compute_cap) +} diff --git a/candlex/native/candlex/src/kernels.rs b/candlex/native/candlex/src/kernels.rs new file mode 100644 index 0000000000..13317b35c7 --- /dev/null +++ b/candlex/native/candlex/src/kernels.rs @@ -0,0 +1,4 @@ +#[rustfmt::skip] +pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx")); +#[rustfmt::skip] +pub const CUSTOM_UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_unary.ptx")); diff --git a/candlex/native/candlex/src/kernels/custom_binary.cu b/candlex/native/candlex/src/kernels/custom_binary.cu new file mode 100644 index 0000000000..b6e9bb5047 --- /dev/null +++ b/candlex/native/candlex/src/kernels/custom_binary.cu @@ -0,0 +1,89 @@ +#include +#include +#include "strides.cuh" + +#define CUSTOM_BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *dims_and_strides, \ + const TYPENAME *lhs, \ + const TYPENAME *rhs, \ + OUT_TYPENAME *out \ +) { \ + const size_t *dims = dims_and_strides; \ + const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \ + const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \ + bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \ + bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \ + if (lhs_cont && rhs_cont) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + TYPENAME x = lhs[i]; \ + TYPENAME y = rhs[i]; \ + out[i] = FUNC; \ + } \ + } else if (lhs_cont) { \ + 
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned int tmp_i = i; \ + unsigned int rhs_i = 0; \ + for (int d = num_dims - 1; d >= 0; d--) { \ + unsigned int i_dim = tmp_i % dims[d]; \ + rhs_i += i_dim * rhs_strides[d]; \ + tmp_i /= dims[d]; \ + } \ + TYPENAME x = lhs[i]; \ + TYPENAME y = rhs[rhs_i]; \ + out[i] = FUNC; \ + } \ + } else if (rhs_cont) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned int tmp_i = i; \ + unsigned int lhs_i = 0; \ + for (int d = num_dims - 1; d >= 0; d--) { \ + unsigned int i_dim = tmp_i % dims[d]; \ + lhs_i += i_dim * lhs_strides[d]; \ + tmp_i /= dims[d]; \ + } \ + TYPENAME x = lhs[lhs_i]; \ + TYPENAME y = rhs[i]; \ + out[i] = FUNC; \ + } \ + } else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned int tmp_i = i; \ + unsigned int lhs_i = 0; \ + unsigned int rhs_i = 0; \ + for (int d = num_dims - 1; d >= 0; d--) { \ + unsigned int i_dim = tmp_i % dims[d]; \ + lhs_i += i_dim * lhs_strides[d]; \ + rhs_i += i_dim * rhs_strides[d]; \ + tmp_i /= dims[d]; \ + } \ + TYPENAME x = lhs[lhs_i]; \ + TYPENAME y = rhs[rhs_i]; \ + out[i] = FUNC; \ + } \ + } \ +} \ + +#define CUSTOM_BINARY_OP(TYPENAME, FN_NAME, FUNC) \ + CUSTOM_BINARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC) + +CUSTOM_BINARY_OP(uint32_t, bit_and_u32, x & y) +CUSTOM_BINARY_OP(int64_t, bit_and_i64, x & y) +CUSTOM_BINARY_OP(uint32_t, bit_or_u32, x | y) +CUSTOM_BINARY_OP(int64_t, bit_or_i64, x | y) +CUSTOM_BINARY_OP(uint32_t, bit_xor_u32, x ^ y) +CUSTOM_BINARY_OP(int64_t, bit_xor_i64, x ^ y) +CUSTOM_BINARY_OP(float, pow_f32, pow(x, y)) +CUSTOM_BINARY_OP(double, pow_f64, pow(x, y)) +CUSTOM_BINARY_OP(uint32_t, shl_u32, x << y) +CUSTOM_BINARY_OP(int64_t, shl_i64, x << y) +CUSTOM_BINARY_OP(uint32_t, shr_u32, x >> y) +CUSTOM_BINARY_OP(int64_t, shr_i64, x >> y) + +CUSTOM_BINARY_OP_OUT(uint8_t, uint8_t, logical_or_u8, x || y) +CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_or_i64, x || y) +CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_or_f32, x || y) +CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_xor_i64, !x != !y) +CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_xor_f32, !x != !y) diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu new file mode 100644 index 0000000000..29b36de614 --- /dev/null +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -0,0 +1,69 @@ +#define _USE_MATH_DEFINES +#include +#include +#include +#include "strides.cuh" + +__device__ __forceinline__ float atang(float a) { return atanf(a); } +__device__ __forceinline__ double atang(double a) { return atan(a); } +__device__ __forceinline__ float erfinvg(float a) { return erfinvf(a); } +__device__ __forceinline__ double erfinvg(double a) { return erfinv(a); } +__device__ __forceinline__ float tang(float a) { return tanf(a); } +__device__ __forceinline__ double tang(double a) { return tan(a); } + +#define CUSTOM_UNARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \ +extern "C" __global__ void FN_NAME( \ + const size_t numel, \ + const size_t num_dims, \ + const size_t *info, \ + const TYPENAME *inp, \ + OUT_TYPENAME *out \ +) { \ + const size_t *dims = info; \ + const size_t *strides = info + num_dims; \ + if (is_contiguous(num_dims, dims, strides)) { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + TYPENAME x = inp ? 
inp[i] : out[i]; \ + out[i] = FUNC; \ + } \ + } \ + else { \ + for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \ + unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \ + TYPENAME x = inp ? inp[strided_i] : out[i]; \ + out[i] = FUNC; \ + } \ + } \ +} \ + +#define CUSTOM_UNARY_OP(TYPENAME, FN_NAME, FUNC) \ + CUSTOM_UNARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC) + +CUSTOM_UNARY_OP(float, acos_f32, acos(x)) +CUSTOM_UNARY_OP(double, acos_f64, acos(x)) +CUSTOM_UNARY_OP(float, asin_f32, asin(x)) +CUSTOM_UNARY_OP(double, asin_f64, asin(x)) +CUSTOM_UNARY_OP(float, atan_f32, atang(x)) +CUSTOM_UNARY_OP(double, atan_f64, atang(x)) +CUSTOM_UNARY_OP(uint8_t, bit_not_u8, ~x) +CUSTOM_UNARY_OP(uint32_t, bit_not_u32, ~x) +CUSTOM_UNARY_OP(int64_t, bit_not_i64, ~x) +CUSTOM_UNARY_OP(float, cbrt_f32, cbrt(x)) +CUSTOM_UNARY_OP(double, cbrt_f64, cbrt(x)) +CUSTOM_UNARY_OP(float, ceil_f32, ceil(x)) +CUSTOM_UNARY_OP(double, ceil_f64, ceil(x)) +CUSTOM_UNARY_OP(float, erf_inv_f32, erfinvg(x)) +CUSTOM_UNARY_OP(double, erf_inv_f64, erfinvg(x)) +CUSTOM_UNARY_OP(float, floor_f32, floor(x)) +CUSTOM_UNARY_OP(double, floor_f64, floor(x)) +CUSTOM_UNARY_OP(float, ln_1p_f32, log1p(x)) +CUSTOM_UNARY_OP(double, ln_1p_f64, log1p(x)) +CUSTOM_UNARY_OP(float, round_f32, round(x)) +CUSTOM_UNARY_OP(double, round_f64, round(x)) +CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + exp(-x))) +CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + exp(-x))) +CUSTOM_UNARY_OP(float, tan_f32, tang(x)) +CUSTOM_UNARY_OP(double, tan_f64, tang(x)) + +CUSTOM_UNARY_OP_OUT(float, uint8_t, is_inf_f32, isinf(x) ? 1 : 0) +CUSTOM_UNARY_OP_OUT(double, uint8_t, is_inf_f64, isinf(x) ? 1 : 0) diff --git a/candlex/native/candlex/src/kernels/strides.cuh b/candlex/native/candlex/src/kernels/strides.cuh new file mode 100644 index 0000000000..c95123de51 --- /dev/null +++ b/candlex/native/candlex/src/kernels/strides.cuh @@ -0,0 +1,34 @@ +// TODO: This is often used to check that the data is contiguous so that +// kernels can be easily mapped. However this only returns true for row +// major, if all the inputs are column major, we could apply the fast path +// too (but we wouldn't if some of them are row major and some column major). 
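+// As an illustration of the check below: for a row-major 2x3 tensor,
+// dims = {2, 3} and strides = {3, 1}, so acc runs 1 -> 3 and matches each
+// stride, and is_contiguous returns true; with column-major strides {1, 2}
+// the first comparison (acc 1 vs strides[1] == 2) already fails.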
+__device__ bool is_contiguous( + const size_t num_dims, + const size_t *dims, + const size_t *strides +) { + size_t acc = 1; + for (unsigned int d = 0; d < num_dims; d++) { + unsigned int dim_idx = num_dims - 1 - d; + if (acc != strides[dim_idx]) { + return false; + } + acc *= dims[dim_idx]; + } + return true; +} + +__device__ unsigned int get_strided_index( + unsigned int idx, + const size_t num_dims, + const size_t *dims, + const size_t *strides +) { + unsigned int strided_i = 0; + for (unsigned int d = 0; d < num_dims; d++) { + unsigned int dim_idx = num_dims - 1 - d; + strided_i += (idx % dims[dim_idx]) * strides[dim_idx]; + idx /= dims[dim_idx]; + } + return strided_i; +} diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 18ea43b5a9..b09ab68eb7 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -7,6 +7,8 @@ mod atoms { mod devices; mod error; +#[cfg(feature = "cuda")] +mod kernels; mod ops; mod tensors; diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 05ba8a09fd..eee1fa4afb 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "cuda")] +use candle_core::CudaStorage; use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape}; use num_traits::cast::FromPrimitive; use num_traits::Float; @@ -35,6 +37,55 @@ macro_rules! custom_unary_op { s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? } } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &CudaStorage, + layout: &Layout, + ) -> Result<(CudaStorage, Shape), candle_core::Error> { + use crate::kernels; + use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig}; + use candle_core::cuda_backend::{kernel_name, Map1, WrapErr}; + use candle_core::{CudaDevice, WithDType}; + + impl Map1 for $struct_name { + fn f( + &self, + src: &CudaSlice, + device: &CudaDevice, + layout: &Layout, + ) -> Result, candle_core::Error> { + let src = src.slice(layout.start_offset()..); + let func = device.get_or_load_func(&kernel_name::($name), kernels::CUSTOM_UNARY)?; + let dims = layout.shape().dims(); + let elem_count = layout.shape().elem_count(); + let launch_config = LaunchConfig::for_num_elems(elem_count as u32); + let dims_and_strides = device.htod_copy([dims, layout.stride()].concat()).w()?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { device.alloc::(elem_count) }.w()?; + let params = (elem_count, dims.len(), &dims_and_strides, &src, &dst); + // SAFETY: ffi. + unsafe { func.launch(launch_config, params) }.w()?; + + Ok(dst) + } + } + + use candle_core::backend::BackendStorage; + let device = storage.device(); + let slice = $struct_name.map(&storage.slice, device, layout)?; + + Ok( + ( + CudaStorage { + slice, + device: device.clone(), + }, + layout.shape().clone() + ) + ) + } } }; } @@ -68,12 +119,62 @@ macro_rules! custom_unary_bool_op { s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? 
} } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &CudaStorage, + layout: &Layout, + ) -> Result<(CudaStorage, Shape), candle_core::Error> { + use crate::kernels; + use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits}; + use candle_core::cuda_backend::{kernel_name, CudaStorageSlice, Map1Any, WrapErr}; + use candle_core::{CudaDevice, WithDType}; + + impl Map1Any for $struct_name { + fn f) -> CudaStorageSlice>( + &self, + src: &CudaSlice, + device: &CudaDevice, + layout: &Layout, + _wrap: W, + ) -> Result { + let src = src.slice(layout.start_offset()..); + let func = device.get_or_load_func(&kernel_name::($name), kernels::CUSTOM_UNARY)?; + let dims = layout.shape().dims(); + let elem_count = layout.shape().elem_count(); + let launch_config = LaunchConfig::for_num_elems(elem_count as u32); + let dims_and_strides = device.htod_copy([dims, layout.stride()].concat()).w()?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { device.alloc::(elem_count) }.w()?; + let params = (elem_count, dims.len(), &dims_and_strides, &src, &dst); + // SAFETY: ffi. + unsafe { func.launch(launch_config, params) }.w()?; + + Ok(CudaStorageSlice::U8(dst)) + } + } + + use candle_core::backend::BackendStorage; + let device = storage.device(); + let slice = $struct_name.map(&storage.slice, device, layout)?; + + Ok( + ( + CudaStorage { + slice, + device: device.clone(), + }, + layout.shape().clone() + ) + ) + } } }; } macro_rules! custom_unary_op_closure { - ($struct_name:ident, $name:expr, $closure:expr, ($($dtypes:ident),+)) => { + ($struct_name:ident, $name:expr, $cpu_closure:expr, ($($dtypes:ident),+)) => { pub(crate) struct $struct_name; impl CustomOp1 for $struct_name { @@ -94,19 +195,68 @@ macro_rules! custom_unary_op_closure { match storage { $( CpuStorage::$dtypes(vec) => { - let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); + let data = candle_core::cpu_backend::unary_map(vec, layout, $cpu_closure); Ok((CpuStorage::$dtypes(data), layout.shape().clone())) } )* s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? } } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + storage: &CudaStorage, + layout: &Layout, + ) -> Result<(CudaStorage, Shape), candle_core::Error> { + use crate::kernels; + use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig}; + use candle_core::cuda_backend::{kernel_name, Map1, WrapErr}; + use candle_core::{CudaDevice, WithDType}; + + impl Map1 for $struct_name { + fn f( + &self, + src: &CudaSlice, + device: &CudaDevice, + layout: &Layout, + ) -> Result, candle_core::Error> { + let src = src.slice(layout.start_offset()..); + let func = device.get_or_load_func(&kernel_name::($name), kernels::CUSTOM_UNARY)?; + let dims = layout.shape().dims(); + let elem_count = layout.shape().elem_count(); + let launch_config = LaunchConfig::for_num_elems(elem_count as u32); + let dims_and_strides = device.htod_copy([dims, layout.stride()].concat()).w()?; + // SAFETY: Set later by running the kernel. + let dst = unsafe { device.alloc::(elem_count) }.w()?; + let params = (elem_count, dims.len(), &dims_and_strides, &src, &dst); + // SAFETY: ffi. 
+ unsafe { func.launch(launch_config, params) }.w()?; + + Ok(dst) + } + } + + use candle_core::backend::BackendStorage; + let device = storage.device(); + let slice = $struct_name.map(&storage.slice, device, layout)?; + + Ok( + ( + CudaStorage { + slice, + device: device.clone(), + }, + layout.shape().clone() + ) + ) + } } }; } macro_rules! custom_binary_op { - ($struct_name:ident, $name:literal, $closure:expr, ($($dtypes:ident),+)) => { + ($struct_name:ident, $name:literal, $cpu_closure:expr, ($($dtypes:ident),+)) => { pub(crate) struct $struct_name; impl CustomOp2 for $struct_name { @@ -128,7 +278,7 @@ macro_rules! custom_binary_op { match (s1, s2) { $( (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => { - let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure); + let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $cpu_closure); Ok((CpuStorage::$dtypes(data), l1.shape().clone())) } @@ -143,12 +293,69 @@ macro_rules! custom_binary_op { } } } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &CudaStorage, + l1: &Layout, + s2: &CudaStorage, + l2: &Layout, + ) -> Result<(CudaStorage, Shape), candle_core::Error> { + use crate::kernels; + use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits}; + use candle_core::cuda_backend::{kernel_name, Map2, WrapErr}; + use candle_core::{CudaDevice, WithDType}; + + impl Map2 for $struct_name { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + device: &CudaDevice, + ) -> Result, candle_core::Error> { + let shape1 = layout1.shape(); + let dims1 = shape1.dims(); + let elem_count1 = shape1.elem_count(); + let launch_config = LaunchConfig::for_num_elems(elem_count1 as u32); + let dims_and_strides = device + .htod_copy([dims1, layout1.stride(), layout2.stride()].concat()) + .w()?; + let src1 = src1.slice(layout1.start_offset()..); + let src2 = src2.slice(layout2.start_offset()..); + let func = device.get_or_load_func(&kernel_name::($name), kernels::CUSTOM_BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { device.alloc::(elem_count1) }.w()?; + let params = (elem_count1, dims1.len(), &dims_and_strides, &src1, &src2, &out); + // SAFETY: ffi + unsafe { func.launch(launch_config, params) }.w()?; + + Ok(out) + } + } + + use candle_core::backend::BackendStorage; + let device = s1.device(); + let slice = $struct_name.map(&s1.slice, l1, &s2.slice, l2, device)?; + + Ok( + ( + CudaStorage { + slice, + device: device.clone(), + }, + l1.shape().clone() + ) + ) + } } } } macro_rules! custom_binary_bool_op { - ($struct_name:ident, $name:literal, $closure:expr, ($($dtypes:ident),+)) => { + ($struct_name:ident, $name:literal, $cpu_closure:expr, ($($dtypes:ident),+)) => { pub(crate) struct $struct_name; impl CustomOp2 for $struct_name { @@ -170,7 +377,7 @@ macro_rules! custom_binary_bool_op { match (s1, s2) { $( (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => { - let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, |v1, v2| u8::from($closure(v1, v2))); + let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, |v1, v2| u8::from($cpu_closure(v1, v2))); Ok((CpuStorage::U8(data), l1.shape().clone())) } @@ -185,6 +392,63 @@ macro_rules! 
custom_binary_bool_op { } } } + + #[cfg(feature = "cuda")] + fn cuda_fwd( + &self, + s1: &CudaStorage, + l1: &Layout, + s2: &CudaStorage, + l2: &Layout, + ) -> Result<(CudaStorage, Shape), candle_core::Error> { + use crate::kernels; + use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits}; + use candle_core::cuda_backend::{kernel_name, CudaStorageSlice, Map2Any, WrapErr}; + use candle_core::{CudaDevice, WithDType}; + + impl Map2Any for $struct_name { + fn f( + &self, + src1: &CudaSlice, + layout1: &Layout, + src2: &CudaSlice, + layout2: &Layout, + device: &CudaDevice, + ) -> Result { + let shape1 = layout1.shape(); + let dims1 = shape1.dims(); + let elem_count1 = shape1.elem_count(); + let launch_config = LaunchConfig::for_num_elems(elem_count1 as u32); + let dims_and_strides = device + .htod_copy([dims1, layout1.stride(), layout2.stride()].concat()) + .w()?; + let src1 = src1.slice(layout1.start_offset()..); + let src2 = src2.slice(layout2.start_offset()..); + let func = device.get_or_load_func(&kernel_name::($name), kernels::CUSTOM_BINARY)?; + // SAFETY: Set later by running the kernel. + let out = unsafe { device.alloc::(elem_count1) }.w()?; + let params = (elem_count1, dims1.len(), &dims_and_strides, &src1, &src2, &out); + // SAFETY: ffi + unsafe { func.launch(launch_config, params) }.w()?; + + Ok(CudaStorageSlice::U8(out)) + } + } + + use candle_core::backend::BackendStorage; + let device = s1.device(); + let slice = $struct_name.map(&s1.slice, l1, &s2.slice, l2, device)?; + + Ok( + ( + CudaStorage { + slice, + device: device.clone(), + }, + l1.shape().clone() + ) + ) + } } } } diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index a1aedc0e5e..057cda6fe9 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -728,20 +728,20 @@ defmodule CandlexTest do test "tan" do Nx.tan(1.0) - |> assert_equal(t(1.5574077367782593)) + |> assert_close(t(1.5574077367782593)) t([1.0, 2, 3]) |> Nx.tan() - |> assert_equal(t([1.5574077367782593, -2.185039758682251, -0.14254654943943024])) + |> assert_close(t([1.5574077367782593, -2.185039758682251, -0.14254654943943024])) end test "atan" do Nx.atan(0.10000000149011612) - |> assert_equal(t(0.09966865181922913)) + |> assert_close(t(0.09966865181922913)) t([0.10000000149011612, 0.5, 0.8999999761581421]) |> Nx.atan() - |> assert_equal(t([0.09966865181922913, 0.46364760398864746, 0.7328150868415833])) + |> assert_close(t([0.09966865181922913, 0.46364760398864746, 0.7328150868415833])) end test "ceil" do From 3dd0a4aeb3167a5a86c84f625ab9ffbc681bda46 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 27 Sep 2023 16:39:53 -0300 Subject: [PATCH 144/185] cuda use as much device funs as possible --- .../candlex/src/kernels/custom_binary.cu | 13 +++- .../candlex/src/kernels/custom_unary.cu | 66 ++++++++++++------- 2 files changed, 55 insertions(+), 24 deletions(-) diff --git a/candlex/native/candlex/src/kernels/custom_binary.cu b/candlex/native/candlex/src/kernels/custom_binary.cu index b6e9bb5047..87248122e6 100644 --- a/candlex/native/candlex/src/kernels/custom_binary.cu +++ b/candlex/native/candlex/src/kernels/custom_binary.cu @@ -2,6 +2,15 @@ #include #include "strides.cuh" +#define DEVICE_FN_FLOAT_WRAPPER(FN_NAME) \ + __device__ __forceinline__ float FN_NAME##g(float a, float b) { return FN_NAME##f(a, b); } + +#define DEVICE_FN_DOUBLE_WRAPPER(FN_NAME) \ + __device__ __forceinline__ double FN_NAME##g(double a, 
double b) { return FN_NAME(a, b); } + +DEVICE_FN_FLOAT_WRAPPER(pow) +DEVICE_FN_DOUBLE_WRAPPER(pow) + #define CUSTOM_BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \ extern "C" __global__ void FN_NAME( \ const size_t numel, \ @@ -75,8 +84,8 @@ CUSTOM_BINARY_OP(uint32_t, bit_or_u32, x | y) CUSTOM_BINARY_OP(int64_t, bit_or_i64, x | y) CUSTOM_BINARY_OP(uint32_t, bit_xor_u32, x ^ y) CUSTOM_BINARY_OP(int64_t, bit_xor_i64, x ^ y) -CUSTOM_BINARY_OP(float, pow_f32, pow(x, y)) -CUSTOM_BINARY_OP(double, pow_f64, pow(x, y)) +CUSTOM_BINARY_OP(float, pow_f32, powg(x, y)) +CUSTOM_BINARY_OP(double, pow_f64, powg(x, y)) CUSTOM_BINARY_OP(uint32_t, shl_u32, x << y) CUSTOM_BINARY_OP(int64_t, shl_i64, x << y) CUSTOM_BINARY_OP(uint32_t, shr_u32, x >> y) diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu index 29b36de614..441661066c 100644 --- a/candlex/native/candlex/src/kernels/custom_unary.cu +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -4,12 +4,34 @@ #include #include "strides.cuh" -__device__ __forceinline__ float atang(float a) { return atanf(a); } -__device__ __forceinline__ double atang(double a) { return atan(a); } -__device__ __forceinline__ float erfinvg(float a) { return erfinvf(a); } -__device__ __forceinline__ double erfinvg(double a) { return erfinv(a); } -__device__ __forceinline__ float tang(float a) { return tanf(a); } -__device__ __forceinline__ double tang(double a) { return tan(a); } +#define DEVICE_FN_FLOAT_WRAPPER(FN_NAME) \ + __device__ __forceinline__ float FN_NAME##g(float a) { return FN_NAME##f(a); } + +#define DEVICE_FN_DOUBLE_WRAPPER(FN_NAME) \ + __device__ __forceinline__ double FN_NAME##g(double a) { return FN_NAME(a); } + +DEVICE_FN_FLOAT_WRAPPER(acos) +DEVICE_FN_DOUBLE_WRAPPER(acos) +DEVICE_FN_FLOAT_WRAPPER(asin) +DEVICE_FN_DOUBLE_WRAPPER(asin) +DEVICE_FN_FLOAT_WRAPPER(atan) +DEVICE_FN_DOUBLE_WRAPPER(atan) +DEVICE_FN_FLOAT_WRAPPER(cbrt) +DEVICE_FN_DOUBLE_WRAPPER(cbrt) +DEVICE_FN_FLOAT_WRAPPER(ceil) +DEVICE_FN_DOUBLE_WRAPPER(ceil) +DEVICE_FN_FLOAT_WRAPPER(erfinv) +DEVICE_FN_DOUBLE_WRAPPER(erfinv) +DEVICE_FN_FLOAT_WRAPPER(exp) +DEVICE_FN_DOUBLE_WRAPPER(exp) +DEVICE_FN_FLOAT_WRAPPER(floor) +DEVICE_FN_DOUBLE_WRAPPER(floor) +DEVICE_FN_FLOAT_WRAPPER(round) +DEVICE_FN_DOUBLE_WRAPPER(round) +DEVICE_FN_FLOAT_WRAPPER(log1p) +DEVICE_FN_DOUBLE_WRAPPER(log1p) +DEVICE_FN_FLOAT_WRAPPER(tan) +DEVICE_FN_DOUBLE_WRAPPER(tan) #define CUSTOM_UNARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \ extern "C" __global__ void FN_NAME( \ @@ -39,29 +61,29 @@ extern "C" __global__ void FN_NAME( \ #define CUSTOM_UNARY_OP(TYPENAME, FN_NAME, FUNC) \ CUSTOM_UNARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC) -CUSTOM_UNARY_OP(float, acos_f32, acos(x)) -CUSTOM_UNARY_OP(double, acos_f64, acos(x)) -CUSTOM_UNARY_OP(float, asin_f32, asin(x)) -CUSTOM_UNARY_OP(double, asin_f64, asin(x)) +CUSTOM_UNARY_OP(float, acos_f32, acosg(x)) +CUSTOM_UNARY_OP(double, acos_f64, acosg(x)) +CUSTOM_UNARY_OP(float, asin_f32, asing(x)) +CUSTOM_UNARY_OP(double, asin_f64, asing(x)) CUSTOM_UNARY_OP(float, atan_f32, atang(x)) CUSTOM_UNARY_OP(double, atan_f64, atang(x)) CUSTOM_UNARY_OP(uint8_t, bit_not_u8, ~x) CUSTOM_UNARY_OP(uint32_t, bit_not_u32, ~x) CUSTOM_UNARY_OP(int64_t, bit_not_i64, ~x) -CUSTOM_UNARY_OP(float, cbrt_f32, cbrt(x)) -CUSTOM_UNARY_OP(double, cbrt_f64, cbrt(x)) -CUSTOM_UNARY_OP(float, ceil_f32, ceil(x)) -CUSTOM_UNARY_OP(double, ceil_f64, ceil(x)) +CUSTOM_UNARY_OP(float, cbrt_f32, cbrtg(x)) +CUSTOM_UNARY_OP(double, cbrt_f64, cbrtg(x)) 
+CUSTOM_UNARY_OP(float, ceil_f32, ceilg(x)) +CUSTOM_UNARY_OP(double, ceil_f64, ceilg(x)) CUSTOM_UNARY_OP(float, erf_inv_f32, erfinvg(x)) CUSTOM_UNARY_OP(double, erf_inv_f64, erfinvg(x)) -CUSTOM_UNARY_OP(float, floor_f32, floor(x)) -CUSTOM_UNARY_OP(double, floor_f64, floor(x)) -CUSTOM_UNARY_OP(float, ln_1p_f32, log1p(x)) -CUSTOM_UNARY_OP(double, ln_1p_f64, log1p(x)) -CUSTOM_UNARY_OP(float, round_f32, round(x)) -CUSTOM_UNARY_OP(double, round_f64, round(x)) -CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + exp(-x))) -CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + exp(-x))) +CUSTOM_UNARY_OP(float, floor_f32, floorg(x)) +CUSTOM_UNARY_OP(double, floor_f64, floorg(x)) +CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x)) +CUSTOM_UNARY_OP(double, ln_1p_f64, log1pg(x)) +CUSTOM_UNARY_OP(float, round_f32, roundg(x)) +CUSTOM_UNARY_OP(double, round_f64, roundg(x)) +CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + expg(-x))) +CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + expg(-x))) CUSTOM_UNARY_OP(float, tan_f32, tang(x)) CUSTOM_UNARY_OP(double, tan_f64, tang(x)) From 9cb36e512affc097b182241b015b46b9fdba2e42 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 28 Sep 2023 10:49:55 -0300 Subject: [PATCH 145/185] merge 2 custom_unary_op macros into 1 --- candlex/native/candlex/src/ops.rs | 61 +++++++------------------------ 1 file changed, 14 insertions(+), 47 deletions(-) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 05ba8a09fd..670391a3e2 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -7,7 +7,7 @@ fn erf_inv(v: T) -> T { } macro_rules! custom_unary_op { - ($struct_name:ident, $name:expr, $fn_name:ident, ($($dtypes:ident),+)) => { + ($struct_name:ident, $name:expr, $closure:expr, ($($dtypes:ident),+)) => { pub(crate) struct $struct_name; impl CustomOp1 for $struct_name { @@ -28,7 +28,7 @@ macro_rules! custom_unary_op { match storage { $( CpuStorage::$dtypes(vec) => { - let data = candle_core::cpu_backend::unary_map(vec, layout, |v| v.$fn_name()); + let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); Ok((CpuStorage::$dtypes(data), layout.shape().clone())) } )* @@ -72,39 +72,6 @@ macro_rules! custom_unary_bool_op { }; } -macro_rules! custom_unary_op_closure { - ($struct_name:ident, $name:expr, $closure:expr, ($($dtypes:ident),+)) => { - pub(crate) struct $struct_name; - - impl CustomOp1 for $struct_name { - // Box does not support const yet, so use a function to get the name. - fn name(&self) -> &'static str { - $name - } - - /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides, - /// offsets etc so the associated layout should be used to access it. - fn cpu_fwd( - &self, - storage: &CpuStorage, - layout: &Layout, - ) -> Result<(CpuStorage, Shape), candle_core::Error> { - use candle_core::backend::BackendStorage; - - match storage { - $( - CpuStorage::$dtypes(vec) => { - let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); - Ok((CpuStorage::$dtypes(data), layout.shape().clone())) - } - )* - s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? - } - } - } - }; -} - macro_rules! custom_binary_op { ($struct_name:ident, $name:literal, $closure:expr, ($($dtypes:ident),+)) => { pub(crate) struct $struct_name; @@ -189,19 +156,19 @@ macro_rules! 
custom_binary_bool_op { } } -custom_unary_op!(Acos, "acos", acos, (BF16, F16, F32, F64)); -custom_unary_op!(Asin, "asin", asin, (BF16, F16, F32, F64)); -custom_unary_op!(Atan, "atan", atan, (BF16, F16, F32, F64)); -custom_unary_op!(Cbrt, "cbrt", cbrt, (BF16, F16, F32, F64)); -custom_unary_op!(Ceil, "ceil", ceil, (BF16, F16, F32, F64)); -custom_unary_op!(Floor, "floor", floor, (BF16, F16, F32, F64)); -custom_unary_op!(Log1p, "ln_1p", ln_1p, (BF16, F16, F32, F64)); -custom_unary_op!(Round, "round", round, (BF16, F16, F32, F64)); -custom_unary_op!(Tan, "tan", tan, (BF16, F16, F32, F64)); +custom_unary_op!(Acos, "acos", |v| v.acos(), (BF16, F16, F32, F64)); +custom_unary_op!(Asin, "asin", |v| v.asin(), (BF16, F16, F32, F64)); +custom_unary_op!(Atan, "atan", |v| v.atan(), (BF16, F16, F32, F64)); +custom_unary_op!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); +custom_unary_op!(Cbrt, "cbrt", |v| v.cbrt(), (BF16, F16, F32, F64)); +custom_unary_op!(Ceil, "ceil", |v| v.ceil(), (BF16, F16, F32, F64)); +custom_unary_op!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); +custom_unary_op!(Floor, "floor", |v| v.floor(), (BF16, F16, F32, F64)); +custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64)); +custom_unary_op!(Round, "round", |v| v.round(), (BF16, F16, F32, F64)); +custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F64)); +custom_unary_op!(Tan, "tan", |v| v.tan(), (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); -custom_unary_op_closure!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); -custom_unary_op_closure!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); -custom_unary_op_closure!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F64)); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); From 4547f348c737174efe23be39c712dfb8206c0c40 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 28 Sep 2023 11:46:31 -0300 Subject: [PATCH 146/185] removes unnecessary data local var --- candlex/native/candlex/src/ops.rs | 46 ++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 670391a3e2..5bcb2e7aeb 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -28,8 +28,12 @@ macro_rules! custom_unary_op { match storage { $( CpuStorage::$dtypes(vec) => { - let data = candle_core::cpu_backend::unary_map(vec, layout, $closure); - Ok((CpuStorage::$dtypes(data), layout.shape().clone())) + Ok( + ( + CpuStorage::$dtypes(candle_core::cpu_backend::unary_map(vec, layout, $closure)), + layout.shape().clone() + ) + ) } )* s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? @@ -61,8 +65,14 @@ macro_rules! custom_unary_bool_op { match storage { $( CpuStorage::$dtypes(vec) => { - let data = candle_core::cpu_backend::unary_map(vec, layout, |v| u8::from(v.$fn_name())); - Ok((CpuStorage::U8(data), layout.shape().clone())) + Ok( + ( + CpuStorage::U8( + candle_core::cpu_backend::unary_map(vec, layout, |v| u8::from(v.$fn_name())) + ), + layout.shape().clone() + ) + ) } )* s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())? @@ -95,9 +105,14 @@ macro_rules! 
custom_binary_op { match (s1, s2) { $( (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => { - let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure); - - Ok((CpuStorage::$dtypes(data), l1.shape().clone())) + Ok( + ( + CpuStorage::$dtypes( + candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $closure) + ), + l1.shape().clone() + ) + ) } )* _ => { @@ -137,9 +152,20 @@ macro_rules! custom_binary_bool_op { match (s1, s2) { $( (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => { - let data = candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, |v1, v2| u8::from($closure(v1, v2))); - - Ok((CpuStorage::U8(data), l1.shape().clone())) + Ok( + ( + CpuStorage::U8( + candle_core::cpu_backend::binary_map( + l1, + l2, + lhs, + rhs, + |v1, v2| u8::from($closure(v1, v2)) + ) + ), + l1.shape().clone() + ) + ) } )* _ => { From 001bd81e4dec6fa200c919cdf0b05346c8ee5c16 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 28 Sep 2023 15:48:52 -0300 Subject: [PATCH 147/185] candlex transpose --- candlex/lib/candlex/backend.ex | 4 +- candlex/lib/candlex/native.ex | 1 - candlex/native/candlex/src/lib.rs | 1 - candlex/native/candlex/src/tensors.rs | 5 -- candlex/test/candlex_test.exs | 106 ++++++++++++++++++++++++++ 5 files changed, 108 insertions(+), 9 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index f3809c5bdd..fa6c039390 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -485,9 +485,9 @@ defmodule Candlex.Backend do end @impl true - def transpose(out, %T{} = t, [dim1, dim2] = axes) do + def transpose(out, %T{} = t, axes) do from_nx(t) - |> Native.transpose(dim1, dim2) + |> Native.permute(axes) |> unwrap!() |> to_nx(out) end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 0d4af778db..92d7210d1b 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -15,7 +15,6 @@ defmodule Candlex.Native do def arange(_start, _end, _dtype, _shape, _device), do: error() def broadcast_to(_tensor, _shape), do: error() def reshape(_tensor, _shape), do: error() - def transpose(_tensor, _dim1, _dim2), do: error() def to_type(_tensor, _dtype), do: error() def dtype(_tensor), do: error() def t_shape(_tensor), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 18ea43b5a9..75021333ea 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -48,7 +48,6 @@ rustler::init! 
{ tensors::gather, tensors::chunk, tensors::squeeze, - tensors::transpose, tensors::arange, tensors::to_type, tensors::broadcast_to, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 875df40ae2..623bc329f0 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -104,11 +104,6 @@ pub fn squeeze(t: ExTensor, dim: usize) -> Result { Ok(ExTensor::new(t.squeeze(dim)?)) } -#[rustler::nif(schedule = "DirtyCpu")] -pub fn transpose(t: ExTensor, dim1: usize, dim2: usize) -> Result { - Ok(ExTensor::new(t.transpose(dim1, dim2)?)) -} - #[rustler::nif(schedule = "DirtyCpu")] pub fn rsqrt(t: ExTensor) -> Result { Ok(ExTensor::new(t.sqrt()?.recip()?)) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index a1aedc0e5e..ad58f5b4dc 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1344,6 +1344,112 @@ defmodule CandlexTest do ] )) end + + test "transpose" do + t(1) + |> Nx.transpose() + |> assert_equal(t(1)) + + Nx.iota({2, 3, 4}, names: [:x, :y, :z]) + |> Nx.transpose() + |> assert_equal(t( + [ + [ + [0, 12], + [4, 16], + [8, 20] + ], + [ + [1, 13], + [5, 17], + [9, 21] + ], + [ + [2, 14], + [6, 18], + [10, 22] + ], + [ + [3, 15], + [7, 19], + [11, 23] + ] + ] + )) + + t(1) + |> Nx.transpose(axes: []) + |> assert_equal(t(1)) + + Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) + |> Nx.transpose(axes: [2, 1, :batch]) + |> assert_equal(t( + [ + [ + [0, 12], + [4, 16], + [8, 20] + ], + [ + [1, 13], + [5, 17], + [9, 21] + ], + [ + [2, 14], + [6, 18], + [10, 22] + ], + [ + [3, 15], + [7, 19], + [11, 23] + ] + ] + )) + + Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) + |> Nx.transpose(axes: [:y, :batch, :x]) + |> assert_equal(t( + [ + [ + [0, 4, 8], + [12, 16, 20] + ], + [ + [1, 5, 9], + [13, 17, 21] + ], + [ + [2, 6, 10], + [14, 18, 22] + ], + [ + [3, 7, 11], + [15, 19, 23] + ] + ] + )) + + Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) + |> Nx.transpose(axes: [:batch, :y, :x]) + |> assert_equal(t( + [ + [ + [0, 4, 8], + [1, 5, 9], + [2, 6, 10], + [3, 7, 11] + ], + [ + [12, 16, 20], + [13, 17, 21], + [14, 18, 22], + [15, 19, 23] + ] + ] + )) + end end defp t(values, opts \\ []) do From 4f40b594c077b035a1a434fa7c7ebbd054c42eb5 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 28 Sep 2023 17:00:04 -0300 Subject: [PATCH 148/185] cargo update --- candlex/native/candlex/Cargo.lock | 83 ++++++++++++++++++++++++++++--- 1 file changed, 77 insertions(+), 6 deletions(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 440d8a35fe..337e40ce83 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -47,7 +47,7 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.2.3" -source = "git+https://github.com/mimiquate/candle?branch=cuda#de3f2be78379ffdaf404bb3d46b3578e6a005957" +source = "git+https://github.com/mimiquate/candle?branch=cuda#35b062c9cce74545e3c382367709e50fdc0b562f" dependencies = [ "byteorder", "candle-gemm", @@ -62,6 +62,7 @@ dependencies = [ "rayon", "safetensors", "thiserror", + "yoke", "zip", ] @@ -193,7 +194,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.2.3" -source = "git+https://github.com/mimiquate/candle?branch=cuda#de3f2be78379ffdaf404bb3d46b3578e6a005957" +source = "git+https://github.com/mimiquate/candle?branch=cuda#35b062c9cce74545e3c382367709e50fdc0b562f" 
dependencies = [ "glob", "rayon", @@ -379,6 +380,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" dependencies = [ "libc", + "stable_deref_trait", ] [[package]] @@ -723,6 +725,12 @@ dependencies = [ "wide", ] +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + [[package]] name = "statrs" version = "0.16.0" @@ -758,20 +766,32 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "synstructure" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "285ba80e733fac80aa4270fbcdf83772a79b80aa35c97075320abfee4a915b06" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "unicode-xid", +] + [[package]] name = "thiserror" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d6d7a740b8a666a7e828dd00da9c0dc290dff53154ea77ac109281de90589b7" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.48" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49922ecae66cc8a249b77e68d1d0623c1b2c514f0060c27cdc68bd62a1219d35" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" dependencies = [ "proc-macro2", "quote", @@ -790,6 +810,12 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + [[package]] name = "unreachable" version = "1.0.0" @@ -821,6 +847,51 @@ dependencies = [ "safe_arch", ] +[[package]] +name = "yoke" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e38c508604d6bbbd292dadb3c02559aa7fff6b654a078a36217cad871636e4" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5e19fb6ed40002bab5403ffa37e53e0e56f914a4450c8765f533018db1db35f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "synstructure", +] + +[[package]] +name = "zerofrom" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655b0814c5c0b19ade497851070c640773304939a6c0fd5f5fb43da0696d05b7" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "synstructure", +] + [[package]] name = "zip" version = "0.6.6" From 2170f04856f684821855b45541dca4b0a6efc68b Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 20 Sep 2023 12:18:15 -0300 Subject: [PATCH 149/185] candlex partial put_slice --- candlex/lib/candlex/backend.ex | 21 ++++++++++++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/Cargo.lock | 4 +-- 
candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 ++++ candlex/test/candlex_test.exs | 42 +++++++++++++++++++++++++++ 6 files changed, 72 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index fa6c039390..729739f9b3 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -301,6 +301,27 @@ defmodule Candlex.Backend do # Indexed + @impl true + def put_slice(%T{} = out, %T{} = t, [start] = _start_indices, slice) do + t + |> from_nx() + |> Native.slice_scatter(from_nx(slice), 0, Nx.to_number(start)) + |> unwrap!() + |> to_nx(out) + end + + def put_slice(%T{} = out, %T{} = t, [start1, start2] = _start_indices, slice) do + if Nx.equal(start1, 0) do + t + |> from_nx() + |> Native.slice_scatter(from_nx(slice), 1, Nx.to_number(start2)) + |> unwrap!() + |> to_nx(out) + else + raise "unsupported" + end + end + @impl true def slice( %T{shape: _output_shape} = out, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 92d7210d1b..bf81273798 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -21,6 +21,7 @@ defmodule Candlex.Native do def concatenate(_tensors, _axis), do: error() def conv1d(_tensor, _kernel), do: error() def conv2d(_tensor, _kernel), do: error() + def slice_scatter(_tensor, _src, _dim, _start), do: error() for op <- [ :abs, diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 337e40ce83..23849227f0 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -47,7 +47,7 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.2.3" -source = "git+https://github.com/mimiquate/candle?branch=cuda#35b062c9cce74545e3c382367709e50fdc0b562f" +source = "git+https://github.com/mimiquate/candle?branch=cuda#cf9094e400bb143202f644a7490e21ca7b3b091b" dependencies = [ "byteorder", "candle-gemm", @@ -194,7 +194,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.2.3" -source = "git+https://github.com/mimiquate/candle?branch=cuda#35b062c9cce74545e3c382367709e50fdc0b562f" +source = "git+https://github.com/mimiquate/candle?branch=cuda#cf9094e400bb143202f644a7490e21ca7b3b091b" dependencies = [ "glob", "rayon", diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 75021333ea..7c3ca60186 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -56,6 +56,7 @@ rustler::init! 
{ tensors::conv1d, tensors::conv2d, tensors::permute, + tensors::slice_scatter, tensors::matmul, tensors::abs, tensors::acos, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 623bc329f0..f39a758c4c 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -211,6 +211,11 @@ pub fn reshape(t: ExTensor, shape: Term) -> Result { Ok(ExTensor::new(t.reshape(tuple_to_vec(shape).unwrap())?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn slice_scatter(t: ExTensor, src: ExTensor, dim: usize, start: usize) -> Result { + Ok(ExTensor::new(t.slice_scatter(src.deref(), dim, start)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn where_cond( t: ExTensor, diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index ad58f5b4dc..58a80e174d 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1450,6 +1450,48 @@ defmodule CandlexTest do ] )) end + + test "put_slice" do + t([0, 1, 2, 3, 4]) + |> Nx.put_slice([2], Nx.tensor([5, 6])) + |> assert_equal(t([0, 1, 5, 6, 4])) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.put_slice([0, 0], t([[7, 8, 9], [10, 11, 12]])) + |> assert_equal(t( + [ + [7, 8, 9], + [10, 11, 12] + ] + )) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.put_slice([0, 1], t([[7, 8], [9, 10]])) + |> assert_equal(t( + [ + [1, 7, 8], + [4, 9, 10] + ] + )) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.put_slice([t(0), t(1)], t([[10.0, 11.0]])) + # |> assert_equal(t( + # [ + # [1.0, 10.0, 11.0], + # [4.0, 5.0, 6.0] + # ] + # )) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.put_slice([1, 1], t([[7, 8], [9, 10]])) + # |> assert_equal(t( + # [ + # [1, 7, 8], + # [4, 9, 10] + # ] + # )) + end end defp t(values, opts \\ []) do From 3641d94e8ac82f5df3dfc88fa88837fb24b5cc3f Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 29 Sep 2023 12:02:30 -0300 Subject: [PATCH 150/185] back to candle-core upstream after PRs fixes merged --- candlex/native/candlex/Cargo.lock | 4 ++-- candlex/native/candlex/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 23849227f0..b0cea90c76 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -47,7 +47,7 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.2.3" -source = "git+https://github.com/mimiquate/candle?branch=cuda#cf9094e400bb143202f644a7490e21ca7b3b091b" +source = "git+https://github.com/huggingface/candle#fc59bc31bf49e707d64ed25fd29a8803f9a12fb4" dependencies = [ "byteorder", "candle-gemm", @@ -194,7 +194,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.2.3" -source = "git+https://github.com/mimiquate/candle?branch=cuda#cf9094e400bb143202f644a7490e21ca7b3b091b" +source = "git+https://github.com/huggingface/candle#fc59bc31bf49e707d64ed25fd29a8803f9a12fb4" dependencies = [ "glob", "rayon", diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 576eaeb042..5600a16472 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -10,7 +10,7 @@ path = "src/lib.rs" crate-type = ["cdylib"] [dependencies] -candle-core = { git = "https://github.com/mimiquate/candle", branch = "cuda" } +candle-core = { git = "https://github.com/huggingface/candle" } half = "2.3.1" num-traits = "0.2.16" rustler = "0.29.1" From 
ccac324a655e937768ad0c473e5de773225eadbd Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 28 Sep 2023 16:15:38 -0300 Subject: [PATCH 151/185] candlex pad --- candlex/lib/candlex/backend.ex | 16 +++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 + candlex/test/candlex_test.exs | 150 ++++++++++++++++++++++++++ 5 files changed, 173 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 729739f9b3..74c2c5ec5f 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -485,6 +485,22 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def pad(%T{} = out, %T{} = t, pad_value, []) do + out + end + def pad(%T{} = out, %T{} = t, %T{shape: {}} = pad_value, [{low, high, 0 = _inner}]) do + if !Nx.equal(pad_value, 0) do + raise "only pad_value=0 supported for now" + end + + t + |> from_nx() + |> Native.pad_with_zeros(low, high) + |> unwrap!() + |> to_nx(out) + end + @impl true def reshape(%T{shape: shape} = out, %T{} = t) do from_nx(t) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index bf81273798..9b7118d9e2 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -22,6 +22,7 @@ defmodule Candlex.Native do def conv1d(_tensor, _kernel), do: error() def conv2d(_tensor, _kernel), do: error() def slice_scatter(_tensor, _src, _dim, _start), do: error() + def pad_with_zeros(_tensor, _left, _right), do: error() for op <- [ :abs, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 7c3ca60186..dcd6cd32e2 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -57,6 +57,7 @@ rustler::init! 
{ tensors::conv2d, tensors::permute, tensors::slice_scatter, + tensors::pad_with_zeros, tensors::matmul, tensors::abs, tensors::acos, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index f39a758c4c..52549e9f33 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -216,6 +216,11 @@ pub fn slice_scatter(t: ExTensor, src: ExTensor, dim: usize, start: usize) -> Re Ok(ExTensor::new(t.slice_scatter(src.deref(), dim, start)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn pad_with_zeros(t: ExTensor, left: usize, right: usize) -> Result { + Ok(ExTensor::new(t.pad_with_zeros(0, left, right)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn where_cond( t: ExTensor, diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 58a80e174d..5889b6f3cc 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1492,6 +1492,156 @@ defmodule CandlexTest do # ] # )) end + + test "pad" do + t(1) + |> Nx.pad(0, []) + |> assert_equal(t(1)) + + t([1, 2, 3], names: [:data]) + |> Nx.pad(0, [{1, 1, 0}]) + |> assert_equal(t([0, 1, 2, 3, 0])) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.pad(0, [{0, 0, 1}, {0, 0, 1}]) + # |> assert_equal(t( + # [ + # [1, 0, 2, 0, 3], + # [0, 0, 0, 0, 0], + # [4, 0, 5, 0, 6] + # ] + # )) + + # Nx.pad(Nx.tensor([[1, 2, 3], [4, 5, 6]]), 0, [{1, 1, 0}, {1, 1, 0}]) + # [ + # [0, 0, 0, 0, 0], + # [0, 1, 2, 3, 0], + # [0, 4, 5, 6, 0], + # [0, 0, 0, 0, 0] + # ] + # > + + # tensor = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + # Nx.pad(tensor, 0, [{0, 2, 0}, {1, 1, 0}, {1, 0, 0}]) + # [ + # [ + # [0, 0, 0], + # [0, 1, 2], + # [0, 3, 4], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 5, 6], + # [0, 7, 8], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ] + # ] + + # tensor = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + # Nx.pad(tensor, 0, [{1, 0, 0}, {1, 1, 0}, {0, 1, 0}]) + # [ + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [1, 2, 0], + # [3, 4, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [5, 6, 0], + # [7, 8, 0], + # [0, 0, 0] + # ] + # ] + + # tensor = Nx.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + # Nx.pad(tensor, 0.0, [{1, 2, 0}, {1, 0, 0}, {0, 1, 0}]) + # [ + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [1.0, 2.0, 0.0], + # [3.0, 4.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [5.0, 6.0, 0.0], + # [7.0, 8.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # ] + # ] + + # Nx.pad(Nx.tensor([0, 1, 2, 3, 0]), 0, [{-1, -1, 0}]) + # [1, 2, 3] + + # tensor = Nx.tensor([ + # [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + # [[0, 0, 0], [1, 2, 0], [3, 4, 0], [0, 0, 0]], + # [[0, 0, 0], [5, 6, 0], [7, 8, 0], [0, 0, 0]] + # ]) + # Nx.pad(tensor, 0, [{-1, 0, 0}, {-1, -1, 0}, {0, -1, 0}]) + # [ + # [ + # [1, 2], + # [3, 4] + # ], + # [ + # [5, 6], + # [7, 8] + # ] + # ] + + # t([[0, 1, 2, 3], [0, 4, 5, 6]]) + # |> Nx.pad(0, [{0, 0, 0}, {-1, 1, 0}]) + # |> assert_equal(t( + # [ + # [1, 2, 3, 0], + # [4, 5, 6, 0] + # ] + # )) + + # t([[0, 1, 2], [3, 4, 5]], type: :f32) + # |> Nx.pad(0, [{-1, 2, 0}, {1, -1, 0}]) + # |> assert_equal(t( + # [ + # [0.0, 3.0, 4.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # 
] + # ) + end end defp t(values, opts \\ []) do From cde65ce65e193aff93cf70de6617e544c1dd19fd Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 29 Sep 2023 17:15:46 -0300 Subject: [PATCH 152/185] candlex take --- candlex/lib/candlex/backend.ex | 13 ++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 ++ candlex/test/candlex_test.exs | 96 +++++++++++++++++++++++++++ 5 files changed, 116 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 74c2c5ec5f..36d371c912 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -338,6 +338,19 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def take(%T{} = out, %T{} = tensor, %T{} = indexes, axis) do + if Nx.rank(indexes) > 1 do + raise "only indexes of rank=1 supported for now" + end + + tensor + |> from_nx() + |> Native.index_select(from_nx(indexes), axis) + |> unwrap!() + |> to_nx(out) + end + @impl true def take_along_axis(%T{} = out, %T{} = tensor, %T{} = indexes, axis) do tensor diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 9b7118d9e2..55ef667b19 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -10,6 +10,7 @@ defmodule Candlex.Native do def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def gather(_tensor, _indexes, _dim), do: error() + def index_select(_tensor, _indexes, _dim), do: error() def chunk(_tensor, _num_chunks), do: error() def squeeze(_tensor, _dim), do: error() def arange(_start, _end, _dtype, _shape, _device), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index dcd6cd32e2..14c91ac79d 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -46,6 +46,7 @@ rustler::init! { tensors::where_cond, tensors::narrow, tensors::gather, + tensors::index_select, tensors::chunk, tensors::squeeze, tensors::arange, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 52549e9f33..225e67794c 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -91,6 +91,11 @@ pub fn gather(t: ExTensor, indexes: ExTensor, dim: usize) -> Result Result { + Ok(ExTensor::new(t.index_select(indexes.deref(), dim)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn chunk(t: ExTensor, num_chunks: usize) -> Result, CandlexError> { Ok(t.chunk(num_chunks, 0)? 
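For reference, the index_select NIF above is the primitive behind Nx.take/3 on this backend; the take/4 callback added earlier in this patch only supports rank-1 index tensors for now. A minimal usage sketch, assuming Candlex.Backend is the selected Nx backend (the expected values mirror the assertions in the test diff that follows):

    # take with rank-1 indices; higher-rank index tensors raise for now.
    t = Nx.tensor([[1, 2], [3, 4]], backend: Candlex.Backend)
    idx = Nx.tensor([1, 0, 1], backend: Candlex.Backend)
    Nx.take(t, idx)          #=> [[3, 4], [1, 2], [3, 4]]
    Nx.take(t, idx, axis: 1) #=> [[2, 1, 2], [4, 3, 4]]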
diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 5889b6f3cc..e0baab29c5 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1642,6 +1642,102 @@ defmodule CandlexTest do # ] # ) end + + test "take" do + t([[1, 2], [3, 4]]) + |> Nx.take(t([1, 0, 1])) + |> assert_equal(t( + [ + [3, 4], + [1, 2], + [3, 4] + ] + )) + + t([[1, 2], [3, 4]]) + |> Nx.take(t([1, 0, 1]), axis: 1) + |> assert_equal(t( + [ + [2, 1, 2], + [4, 3, 4] + ] + )) + + t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) + |> Nx.take(t([1, 0, 1]), axis: 1) + |> assert_equal(t( + [ + [ + [11, 12], + [1, 2], + [11, 12] + ], + [ + [111, 112], + [101, 102], + [111, 112] + ] + ] + )) + + # t([[1, 2], [11, 12]]) + # |> Nx.take(t([[0, 0], [1, 1], [0, 0]]), axis: 1) + # |> assert_equal(t( + # [ + # [ + # [1, 1], + # [2, 2], + # [1, 1] + # ], + # [ + # [11, 11], + # [12, 12], + # [11, 11] + # ] + # ] + # )) + + # t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) + # |> Nx.take(t([[0, 0, 0], [1, 1, 1], [0, 0, 0]]), axis: 1) + # |> assert_equal(t( + # [ + # [ + # [ + # [1, 2], + # [1, 2], + # [1, 2] + # ], + # [ + # [11, 12], + # [11, 12], + # [11, 12] + # ], + # [ + # [1, 2], + # [1, 2], + # [1, 2] + # ] + # ], + # [ + # [ + # [101, 102], + # [101, 102], + # [101, 102] + # ], + # [ + # [111, 112], + # [111, 112], + # [111, 112] + # ], + # [ + # [101, 102], + # [101, 102], + # [101, 102] + # ] + # ] + # ] + # )) + end end defp t(values, opts \\ []) do From 947aefaae048256e098eee4a014aec7fbb78bc28 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 3 Oct 2023 10:46:59 -0300 Subject: [PATCH 153/185] candlex clip --- candlex/lib/candlex/backend.ex | 11 +++++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 +++ candlex/test/candlex_test.exs | 47 +++++++++++++++++++++++++++ 5 files changed, 65 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 36d371c912..4f305f6b6a 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -182,6 +182,17 @@ defmodule Candlex.Backend do # Element-wise + @impl true + def clip(%T{} = out, %T{} = t, %T{} = min, %T{} = max) do + [t, min, max] = maybe_upcast([t, min, max]) + + t + |> from_nx() + |> Native.clamp(from_nx(min), from_nx(max)) + |> unwrap!() + |> to_nx(out) + end + @impl true def select(%T{shape: shape, type: type} = out, pred, on_true, on_false) do on_true = diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 55ef667b19..388be2469f 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -24,6 +24,7 @@ defmodule Candlex.Native do def conv2d(_tensor, _kernel), do: error() def slice_scatter(_tensor, _src, _dim, _start), do: error() def pad_with_zeros(_tensor, _left, _right), do: error() + def clamp(_tensor, _min, _max), do: error() for op <- [ :abs, diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 14c91ac79d..0bbf0ccae4 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -49,6 +49,7 @@ rustler::init! 
{ tensors::index_select, tensors::chunk, tensors::squeeze, + tensors::clamp, tensors::arange, tensors::to_type, tensors::broadcast_to, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 225e67794c..f4a400ee67 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -109,6 +109,11 @@ pub fn squeeze(t: ExTensor, dim: usize) -> Result { Ok(ExTensor::new(t.squeeze(dim)?)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn clamp(t: ExTensor, min_val: ExTensor, max_val: ExTensor) -> Result { + Ok(ExTensor::new(t.clamp(&min_val.broadcast_as(t.shape())?, &max_val.broadcast_as(t.shape())?)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn rsqrt(t: ExTensor) -> Result { Ok(ExTensor::new(t.sqrt()?.recip()?)) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index e0baab29c5..6800e7615c 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1738,6 +1738,53 @@ defmodule CandlexTest do # ] # )) end + + test "clip" do + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.clip(2, 4) + |> assert_equal(t( + [ + [2, 2, 3], + [4, 4, 4] + ] + )) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.clip(2.0, 3) + |> assert_equal(t( + [ + [2.0, 2.0, 3.0], + [3.0, 3.0, 3.0] + ] + )) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.clip(t(2.0), Nx.max(1.0, 3.0)) + |> assert_equal(t( + [ + [2.0, 2.0, 3.0], + [3.0, 3.0, 3.0] + ] + )) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + |> Nx.clip(2, 6.0) + |> assert_equal(t( + [ + [2.0, 2.0, 3.0], + [4.0, 5.0, 6.0] + ] + )) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], type: :f32) + |> Nx.clip(1, 4) + |> assert_equal(t( + [ + [1.0, 2.0, 3.0], + [4.0, 4.0, 4.0] + ] + )) + end end defp t(values, opts \\ []) do From c91d326995f06f0f421da7190ec4e36cdab5069d Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 3 Oct 2023 11:18:05 -0300 Subject: [PATCH 154/185] candlex not_equal --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 22 ++++++++++++++++++++++ 5 files changed, 26 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 4f305f6b6a..d831185fb7 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -257,6 +257,7 @@ defmodule Candlex.Backend do :less_equal, :logical_or, :logical_xor, + :not_equal, :right_shift ] do @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 388be2469f..3b594bfb9f 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -71,6 +71,7 @@ defmodule Candlex.Native do :max, :min, :multiply, + :not_equal, :pow, :right_shift, :subtract diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 0bbf0ccae4..33b597b35e 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -31,6 +31,7 @@ rustler::init! 
{ tensors::max, tensors::min, tensors::equal, + tensors::not_equal, tensors::greater, tensors::greater_equal, tensors::less, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index f4a400ee67..50812359e6 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -378,6 +378,7 @@ binary_nif!(multiply, broadcast_mul); binary_nif!(max, broadcast_maximum); binary_nif!(min, broadcast_minimum); binary_nif!(equal, eq); +binary_nif!(not_equal, ne); binary_nif!(greater, gt); binary_nif!(greater_equal, ge); binary_nif!(less, lt); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 6800e7615c..ef1de8bb92 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1785,6 +1785,28 @@ defmodule CandlexTest do ] )) end + + test "not_equal" do + Nx.not_equal(1, 2) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.not_equal(t(1)) + |> assert_equal(t([0, 1, 1])) + + t([1, 1, 2]) + |> Nx.not_equal(t([1, 2, 3])) + |> assert_equal(t([0, 1, 1])) + + t([[1, 4, 2], [4, 5, 6]]) + |> Nx.not_equal(t([[1, 3, 2], [4, 2, 1]])) + |> assert_equal(t( + [ + [0, 1, 0], + [0, 1, 1] + ] + )) + end end defp t(values, opts \\ []) do From df34c1e68e6feee5606f8a12e4c70bfa016b9fe8 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 3 Oct 2023 11:26:13 -0300 Subject: [PATCH 155/185] fix automatic order after compile --- candlex/native/candlex/src/kernels.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candlex/native/candlex/src/kernels.rs b/candlex/native/candlex/src/kernels.rs index 13317b35c7..c6627efde7 100644 --- a/candlex/native/candlex/src/kernels.rs +++ b/candlex/native/candlex/src/kernels.rs @@ -1,4 +1,4 @@ #[rustfmt::skip] -pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx")); -#[rustfmt::skip] pub const CUSTOM_UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_unary.ptx")); +#[rustfmt::skip] +pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx")); From 6307474f2ec8f1956e79b343033f0a2aef7e055c Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 3 Oct 2023 13:41:07 -0300 Subject: [PATCH 156/185] fix eye when backend options present --- candlex/lib/candlex/backend.ex | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index d831185fb7..c5bc661426 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -66,7 +66,7 @@ defmodule Candlex.Backend do @impl true def eye(%T{shape: shape, type: type} = out, backend_options) do - iota = Nx.iota(shape, backend_options) + iota = Nx.iota(shape, backend: {__MODULE__, backend_options}) Nx.equal(Nx.tril(iota), Nx.triu(iota)) |> Nx.as_type(type) From 02b57e2fa0a5ba83513389cf32751282f1bfe241 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Tue, 3 Oct 2023 16:54:52 -0300 Subject: [PATCH 157/185] memoize Device::new_cuda --- candlex/native/candlex/src/tensors.rs | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 50812359e6..296595f115 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -410,11 +410,22 @@ fn vec_to_tuple(env: Env, vec: Vec) -> 
Result { )) } +static CUDA_DEVICE: std::sync::Mutex> = std::sync::Mutex::new(None); + fn device_from_atom(atom: Atom) -> Result { if atom == atoms::cpu() { Ok(Device::Cpu) } else if atom == atoms::cuda() { - Ok(Device::new_cuda(0)?) + let mut cuda_device = CUDA_DEVICE.lock().unwrap(); + + if let Some(device) = cuda_device.as_ref() { + Ok(device.clone()) + } else { + let new_cuda_device = Device::new_cuda(0)?; + *cuda_device = Some(new_cuda_device.clone()); + + Ok(new_cuda_device) + } } else { Err(CandlexError::Other(format!( "unsupported device {:?}", From 97a91c1939a7469e4f324772b6f412393fb4a516 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:02:05 -0300 Subject: [PATCH 158/185] cargo update --- candlex/native/candlex/Cargo.lock | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index c592921c83..9e6e4d35ec 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -52,8 +52,8 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" -version = "0.2.3" -source = "git+https://github.com/huggingface/candle#fc59bc31bf49e707d64ed25fd29a8803f9a12fb4" +version = "0.3.0" +source = "git+https://github.com/huggingface/candle#c18a856e76cad9626406c3c483a53fb5b7eeef7b" dependencies = [ "byteorder", "candle-gemm", @@ -199,8 +199,8 @@ dependencies = [ [[package]] name = "candle-kernels" -version = "0.2.3" -source = "git+https://github.com/huggingface/candle#fc59bc31bf49e707d64ed25fd29a8803f9a12fb4" +version = "0.3.0" +source = "git+https://github.com/huggingface/candle#c18a856e76cad9626406c3c483a53fb5b7eeef7b" dependencies = [ "glob", "rayon", @@ -284,9 +284,9 @@ dependencies = [ [[package]] name = "dyn-stack" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24269739c7c175bc12130622ef1a60b9ab2d5b30c0b9ce5110cd406d7fd497bc" +checksum = "7fe7f8d7bcc523381d3c437b82cf74805de3931de0da69309ae0fe1bdf7a256e" dependencies = [ "bytemuck", "reborrow", @@ -376,9 +376,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.6.3" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f232d6ef707e1956a43342693d2a31e72989554d58299d7a88738cc95b0d35c" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "memmap2" @@ -585,15 +585,15 @@ dependencies = [ [[package]] name = "reborrow" -version = "0.5.4" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2962bf2e1f971c53ef59b2d7ca51d6a5e5c4a9d2be47eb1f661a321a4da85888" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "regex" -version = "1.9.5" +version = "1.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "697061221ea1b4a94a624f67d0ae2bfe4e22b8a17b6a192afb11046542cc8c47" +checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" dependencies = [ "aho-corasick", "memchr", @@ -603,9 +603,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" dependencies = 
[ "aho-corasick", "memchr", @@ -846,9 +846,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wide" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa469ffa65ef7e0ba0f164183697b89b854253fd31aeb92358b7b6155177d62f" +checksum = "ebecebefc38ff1860b4bc47550bbfa63af5746061cf0d29fcd7fa63171602598" dependencies = [ "bytemuck", "safe_arch", From 0aa97d1ba283d06d0d6dec13461ee24102086514 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:03:27 -0300 Subject: [PATCH 159/185] replace custom rounding ops with recently-added candle-native rounding ops --- candlex/native/candlex/src/kernels/custom_unary.cu | 12 ------------ candlex/native/candlex/src/ops.rs | 3 --- candlex/native/candlex/src/tensors.rs | 10 +++++----- 3 files changed, 5 insertions(+), 20 deletions(-) diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu index 441661066c..a1ded98c67 100644 --- a/candlex/native/candlex/src/kernels/custom_unary.cu +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -18,16 +18,10 @@ DEVICE_FN_FLOAT_WRAPPER(atan) DEVICE_FN_DOUBLE_WRAPPER(atan) DEVICE_FN_FLOAT_WRAPPER(cbrt) DEVICE_FN_DOUBLE_WRAPPER(cbrt) -DEVICE_FN_FLOAT_WRAPPER(ceil) -DEVICE_FN_DOUBLE_WRAPPER(ceil) DEVICE_FN_FLOAT_WRAPPER(erfinv) DEVICE_FN_DOUBLE_WRAPPER(erfinv) DEVICE_FN_FLOAT_WRAPPER(exp) DEVICE_FN_DOUBLE_WRAPPER(exp) -DEVICE_FN_FLOAT_WRAPPER(floor) -DEVICE_FN_DOUBLE_WRAPPER(floor) -DEVICE_FN_FLOAT_WRAPPER(round) -DEVICE_FN_DOUBLE_WRAPPER(round) DEVICE_FN_FLOAT_WRAPPER(log1p) DEVICE_FN_DOUBLE_WRAPPER(log1p) DEVICE_FN_FLOAT_WRAPPER(tan) @@ -72,16 +66,10 @@ CUSTOM_UNARY_OP(uint32_t, bit_not_u32, ~x) CUSTOM_UNARY_OP(int64_t, bit_not_i64, ~x) CUSTOM_UNARY_OP(float, cbrt_f32, cbrtg(x)) CUSTOM_UNARY_OP(double, cbrt_f64, cbrtg(x)) -CUSTOM_UNARY_OP(float, ceil_f32, ceilg(x)) -CUSTOM_UNARY_OP(double, ceil_f64, ceilg(x)) CUSTOM_UNARY_OP(float, erf_inv_f32, erfinvg(x)) CUSTOM_UNARY_OP(double, erf_inv_f64, erfinvg(x)) -CUSTOM_UNARY_OP(float, floor_f32, floorg(x)) -CUSTOM_UNARY_OP(double, floor_f64, floorg(x)) CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x)) CUSTOM_UNARY_OP(double, ln_1p_f64, log1pg(x)) -CUSTOM_UNARY_OP(float, round_f32, roundg(x)) -CUSTOM_UNARY_OP(double, round_f64, roundg(x)) CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + expg(-x))) CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + expg(-x))) CUSTOM_UNARY_OP(float, tan_f32, tang(x)) diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index b2ca0aff7d..b70ffb9e09 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -402,11 +402,8 @@ custom_unary_op!(Asin, "asin", |v| v.asin(), (BF16, F16, F32, F64)); custom_unary_op!(Atan, "atan", |v| v.atan(), (BF16, F16, F32, F64)); custom_unary_op!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); custom_unary_op!(Cbrt, "cbrt", |v| v.cbrt(), (BF16, F16, F32, F64)); -custom_unary_op!(Ceil, "ceil", |v| v.ceil(), (BF16, F16, F32, F64)); custom_unary_op!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); -custom_unary_op!(Floor, "floor", |v| v.floor(), (BF16, F16, F32, F64)); custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64)); -custom_unary_op!(Round, "round", |v| v.round(), (BF16, F16, F32, F64)); custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. 
+ (-v).exp()), (F32, F64)); custom_unary_op!(Tan, "tan", |v| v.tan(), (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 296595f115..6236b35104 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,8 +1,8 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ - Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, Ceil, ErfInv, Floor, IsInf, Log1p, - LogicalOr, LogicalXor, Pow, Round, Shl, Shr, Sigmoid, Tan, + Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, ErfInv, IsInf, Log1p, + LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -351,8 +351,11 @@ macro_rules! custom_binary_nif { unary_nif!(negate, neg); unary_nif!(abs); +unary_nif!(ceil); unary_nif!(cos); unary_nif!(exp); +unary_nif!(floor); +unary_nif!(round); unary_nif!(sin); unary_nif!(log); unary_nif!(sqrt); @@ -363,12 +366,9 @@ custom_unary_nif!(asin, Asin); custom_unary_nif!(atan, Atan); custom_unary_nif!(bitwise_not, BitNot); custom_unary_nif!(cbrt, Cbrt); -custom_unary_nif!(ceil, Ceil); custom_unary_nif!(erf_inv, ErfInv); -custom_unary_nif!(floor, Floor); custom_unary_nif!(is_infinity, IsInf); custom_unary_nif!(log1p, Log1p); -custom_unary_nif!(round, Round); custom_unary_nif!(sigmoid, Sigmoid); custom_unary_nif!(tan, Tan); From 4f3c2af4055297425040c3a425d643a9a5013ab6 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:04:26 -0300 Subject: [PATCH 160/185] cargo fmt --- candlex/native/candlex/build.rs | 25 ++++++++++++------------- candlex/native/candlex/src/tensors.rs | 16 ++++++++++++---- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/candlex/native/candlex/build.rs b/candlex/native/candlex/build.rs index 138bf992c0..86a9c6183c 100644 --- a/candlex/native/candlex/build.rs +++ b/candlex/native/candlex/build.rs @@ -22,23 +22,23 @@ impl KernelDirectories { ptx_file: &std::path::Path, compute_cap: usize, ) -> Result<()> { - let should_compile = - if ptx_file.exists() { - cu_file - .metadata()? - .modified()? - .duration_since(ptx_file.metadata()?.modified()?) - .is_ok() - } else { - true - }; + let should_compile = if ptx_file.exists() { + cu_file + .metadata()? + .modified()? + .duration_since(ptx_file.metadata()?.modified()?) + .is_ok() + } else { + true + }; if should_compile { #[cfg(feature = "cuda")] { let mut command = std::process::Command::new("nvcc"); let out_dir = ptx_file.parent().context("no parent for ptx file")?; - let include_dirs: Vec = self.include_dirs.iter().map(|c| format!("-I{c}")).collect(); + let include_dirs: Vec = + self.include_dirs.iter().map(|c| format!("-I{c}")).collect(); command .arg(format!("--gpu-architecture=sm_{compute_cap}")) @@ -49,8 +49,7 @@ impl KernelDirectories { .args(include_dirs) .arg(cu_file); - let output = - command + let output = command .spawn() .context("failed spawning nvcc")? 
.wait_with_output()?; diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 6236b35104..cef527f429 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,8 +1,8 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ - Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, ErfInv, IsInf, Log1p, - LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Tan, + Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, ErfInv, IsInf, Log1p, LogicalOr, + LogicalXor, Pow, Shl, Shr, Sigmoid, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -111,7 +111,10 @@ pub fn squeeze(t: ExTensor, dim: usize) -> Result { #[rustler::nif(schedule = "DirtyCpu")] pub fn clamp(t: ExTensor, min_val: ExTensor, max_val: ExTensor) -> Result { - Ok(ExTensor::new(t.clamp(&min_val.broadcast_as(t.shape())?, &max_val.broadcast_as(t.shape())?)?)) + Ok(ExTensor::new(t.clamp( + &min_val.broadcast_as(t.shape())?, + &max_val.broadcast_as(t.shape())?, + )?)) } #[rustler::nif(schedule = "DirtyCpu")] @@ -222,7 +225,12 @@ pub fn reshape(t: ExTensor, shape: Term) -> Result { } #[rustler::nif(schedule = "DirtyCpu")] -pub fn slice_scatter(t: ExTensor, src: ExTensor, dim: usize, start: usize) -> Result { +pub fn slice_scatter( + t: ExTensor, + src: ExTensor, + dim: usize, + start: usize, +) -> Result { Ok(ExTensor::new(t.slice_scatter(src.deref(), dim, start)?)) } From 0c954178245815ea1c8d07336f358e243b931d33 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 4 Oct 2023 15:50:10 -0300 Subject: [PATCH 161/185] mix format --- candlex/lib/candlex/backend.ex | 1 + 1 file changed, 1 insertion(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c5bc661426..7bcde68682 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -514,6 +514,7 @@ defmodule Candlex.Backend do def pad(%T{} = out, %T{} = t, pad_value, []) do out end + def pad(%T{} = out, %T{} = t, %T{shape: {}} = pad_value, [{low, high, 0 = _inner}]) do if !Nx.equal(pad_value, 0) do raise "only pad_value=0 supported for now" From b4af44adb499c42b8d8d254376e1bc92b279316b Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 4 Oct 2023 16:12:09 -0300 Subject: [PATCH 162/185] removes unnecessary alias --- candlex/lib/candlex/backend.ex | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 7bcde68682..5b1a27d90e 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -8,7 +8,6 @@ defmodule Candlex.Backend do @behaviour Nx.Backend alias Nx.Tensor, as: T - alias Candlex.Backend, as: CB alias Candlex.Native @device_cuda :cuda @@ -583,7 +582,7 @@ defmodule Candlex.Backend do |> maybe_add_signature(tensor) end - defp maybe_add_signature(result, %T{data: %CB{device: device, resource: ref}}) + defp maybe_add_signature(result, %T{data: %__MODULE__{device: device, resource: ref}}) when is_reference(ref) do Inspect.Algebra.concat([ "Candlex.Backend(#{device})", @@ -725,15 +724,15 @@ defmodule Candlex.Backend do end @doc false - defp from_nx(%T{data: %CB{} = data}), do: data + defp from_nx(%T{data: %__MODULE__{} = data}), do: data defp from_nx(%T{} = tensor) do tensor - |> Nx.backend_transfer(CB) + |> Nx.backend_transfer(__MODULE__) |> from_nx() end - defp to_nx(%CB{resource: ref} = backend_tensor, %T{type: nx_type, shape: nx_shape} 
= t) + defp to_nx(%__MODULE__{resource: ref} = backend_tensor, %T{type: nx_type, shape: nx_shape} = t) when is_reference(ref) do {:ok, candle_dtype} = Native.dtype(backend_tensor) {:ok, candle_shape} = Native.t_shape(backend_tensor) From 1937cdd26c87abd7a20b764febf07f8e71c8fb95 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 4 Oct 2023 16:15:20 -0300 Subject: [PATCH 163/185] fix warnings --- candlex/lib/candlex/backend.ex | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 5b1a27d90e..d15bfc36f2 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -64,7 +64,7 @@ defmodule Candlex.Backend do end @impl true - def eye(%T{shape: shape, type: type} = out, backend_options) do + def eye(%T{shape: shape, type: type} = _out, backend_options) do iota = Nx.iota(shape, backend: {__MODULE__, backend_options}) Nx.equal(Nx.tril(iota), Nx.triu(iota)) @@ -510,7 +510,7 @@ defmodule Candlex.Backend do end @impl true - def pad(%T{} = out, %T{} = t, pad_value, []) do + def pad(%T{} = out, %T{} = _t, _pad_value, []) do out end @@ -696,11 +696,6 @@ defmodule Candlex.Backend do |> unwrap!() remainder > 0 && leftover == :repeat -> - slice_shape = - shape - |> Tuple.delete_at(first_dimension) - |> Tuple.insert_at(first_dimension, remainder) - [ native_tensor, Native.narrow(native_tensor, first_dimension, 0, batch_size - remainder) @@ -762,7 +757,7 @@ defmodule Candlex.Backend do defp to_candle_dtype({:u, 8}), do: "u8" defp to_candle_dtype({:u, 16} = t), do: unsupported_dtype(t) defp to_candle_dtype({:u, 32}), do: "u32" - defp to_candle_dtype({:u, 64} = t), do: "i64" + defp to_candle_dtype({:u, 64}), do: "i64" defp to_candle_dtype({:f, 16}), do: "f16" defp to_candle_dtype({:f, 32}), do: "f32" defp to_candle_dtype({:f, 64}), do: "f64" @@ -802,10 +797,6 @@ defmodule Candlex.Backend do raise("Unsupported candle dtype for #{inspect(t)}") end - defp unsupported_op(op_name) do - raise("Unsupported candlex operation '#{op_name}'") - end - defp unsupported_option!(opts, key, acceptable_default) do if opts[key] != nil and opts[key] != acceptable_default do raise "#{inspect(key)} option with #{inspect(opts[key])} is not supported" From 27b34d5aca2f0ce376b724a4af28a98419c48835 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 5 Oct 2023 14:52:52 -0300 Subject: [PATCH 164/185] cargo update --- candlex/native/candlex/Cargo.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 9e6e4d35ec..78a18543a6 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -53,7 +53,7 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.3.0" -source = "git+https://github.com/huggingface/candle#c18a856e76cad9626406c3c483a53fb5b7eeef7b" +source = "git+https://github.com/huggingface/candle#8f7973958c55324a24f0c514e7ac6ded6681980f" dependencies = [ "byteorder", "candle-gemm", @@ -200,7 +200,7 @@ dependencies = [ [[package]] name = "candle-kernels" version = "0.3.0" -source = "git+https://github.com/huggingface/candle#c18a856e76cad9626406c3c483a53fb5b7eeef7b" +source = "git+https://github.com/huggingface/candle#8f7973958c55324a24f0c514e7ac6ded6681980f" dependencies = [ "glob", "rayon", From 085e4e431a7ee6f4fdaa7c9a2d2341b1b764116a Mon Sep 17 
00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Wed, 4 Oct 2023 10:39:28 -0300 Subject: [PATCH 165/185] test binary ops between tensors in different devices --- candlex/lib/candlex/backend.ex | 63 +++++++++++++++++++++++---- candlex/native/candlex/src/tensors.rs | 6 +-- candlex/test/candlex_test.exs | 26 +++++++++++ 3 files changed, 84 insertions(+), 11 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index d15bfc36f2..b324d5668f 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -73,6 +73,19 @@ defmodule Candlex.Backend do # Backend + @impl true + def backend_transfer(tensor, backend, backend_options) do + if backend == __MODULE__ && same_device?(tensor, device_option(backend_options)) do + tensor + else + try do + backend_copy(tensor, backend, backend_options) + after + backend_deallocate(tensor) + end + end + end + @impl true def backend_copy(%T{} = tensor, Candlex.Backend, backend_options) do tensor @@ -86,13 +99,6 @@ defmodule Candlex.Backend do backend.from_binary(tensor, to_binary(tensor), backend_options) end - @impl true - def backend_transfer(tensor, backend, backend_options) do - backend_copy(tensor, backend, backend_options) - after - backend_deallocate(tensor) - end - @impl true def backend_deallocate(%T{} = _tensor) do true @@ -222,6 +228,7 @@ defmodule Candlex.Backend do for op <- [:add, :divide, :max, :min, :multiply, :subtract] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_transfer_device(left, right) {left, right} = maybe_upcast(left, right) from_nx(left) @@ -261,6 +268,7 @@ defmodule Candlex.Backend do ] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_transfer_device(left, right) {left, right} = maybe_upcast(left, right) {left, right} = maybe_broadcast_bin_args(out.shape, left, right) @@ -677,6 +685,37 @@ defmodule Candlex.Backend do } end + defp maybe_transfer_device( + %T{data: %__MODULE__{device: device}} = l, + %T{data: %__MODULE__{device: device}} = r + ) do + {l, r} + end + + defp maybe_transfer_device( + %T{data: %__MODULE__{device: device}} = l, + %T{data: %__MODULE__{device: _other_device}} = r + ) do + { + l, + r |> Nx.backend_transfer({__MODULE__, device: device}) + } + end + + defp maybe_transfer_device(%T{} = l, %T{data: %__MODULE__{device: device}} = r) do + { + l |> Nx.backend_transfer({__MODULE__, device: device}), + r + } + end + + defp maybe_transfer_device(%T{data: %__MODULE__{device: device}} = l, %T{} = r) do + { + l, + r |> Nx.backend_transfer({__MODULE__, device: device}) + } + end + ## Conversions @impl true @@ -789,7 +828,15 @@ defmodule Candlex.Backend do end end - defp cuda_available? do + defp same_device?(%T{data: %__MODULE__{device: device}}, device) do + true + end + + defp same_device?(_t, _d) do + false + end + + def cuda_available? 
do Native.is_cuda_available() end diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index cef527f429..3d8a518104 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -16,15 +16,15 @@ pub(crate) struct TensorRef(Tensor); #[derive(NifStruct)] #[module = "Candlex.Backend"] pub struct ExTensor { - device: String, + device: Atom, resource: ResourceArc, } impl ExTensor { pub fn new(tensor: Tensor) -> Self { let dev_string = match tensor.device() { - Device::Cpu => String::from("cpu"), - Device::Cuda(_) => String::from("cuda"), + Device::Cpu => atoms::cpu(), + Device::Cuda(_) => atoms::cuda(), }; Self { diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index d1cd2d2a06..b9c790d808 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1807,6 +1807,32 @@ defmodule CandlexTest do ] )) end + + if Candlex.Backend.cuda_available? do + test "different devices" do + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.add(t([10, 20, 30], backend: {Candlex.Backend, device: :cuda})) + |> assert_equal(t([11, 22, 33])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cuda}) + |> Nx.add(t([10, 20, 30], backend: {Candlex.Backend, device: :cpu})) + |> assert_equal(t([11, 22, 33])) + end + end + + test "backend_transfer" do + t([1, 2, 3], backend: Nx.BinaryBackend) + |> Nx.backend_transfer({Candlex.Backend, device: :cpu}) + |> assert_equal(t([1, 2, 3])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.backend_transfer(Nx.BinaryBackend) + |> assert_equal(t([1, 2, 3])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.backend_transfer({Candlex.Backend, device: :cpu}) + |> assert_equal(t([1, 2, 3])) + end end defp t(values, opts \\ []) do From e5e8b9645d203d17489da1495cfd2b64570ce962 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:10:21 -0300 Subject: [PATCH 166/185] add acosh/asinh/atanh/cosh/sinh --- candlex/lib/candlex/backend.ex | 5 +++ candlex/lib/candlex/native.ex | 5 +++ .../candlex/src/kernels/custom_unary.cu | 20 +++++++++ candlex/native/candlex/src/lib.rs | 5 +++ candlex/native/candlex/src/ops.rs | 5 +++ candlex/native/candlex/src/tensors.rs | 9 +++- candlex/test/candlex_test.exs | 45 +++++++++++++++++++ 7 files changed, 92 insertions(+), 2 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index b324d5668f..42272a27ec 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -287,12 +287,16 @@ defmodule Candlex.Backend do for op <- [ :abs, :acos, + :acosh, :asin, + :asinh, :atan, + :atanh, :bitwise_not, :cbrt, :ceil, :cos, + :cosh, :erf_inv, :exp, :floor, @@ -304,6 +308,7 @@ defmodule Candlex.Backend do :rsqrt, :sigmoid, :sin, + :sinh, :sqrt, :tan, :tanh diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 3b594bfb9f..b8516a77a2 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -29,12 +29,16 @@ defmodule Candlex.Native do for op <- [ :abs, :acos, + :acosh, :asin, + :asinh, :atan, + :atanh, :bitwise_not, :cbrt, :ceil, :cos, + :cosh, :erf_inv, :exp, :floor, @@ -46,6 +50,7 @@ defmodule Candlex.Native do :rsqrt, :sigmoid, :sin, + :sinh, :sqrt, :tan, :tanh diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu index a1ded98c67..a8fe8f34b5 100644 --- 
a/candlex/native/candlex/src/kernels/custom_unary.cu +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -12,18 +12,28 @@ DEVICE_FN_FLOAT_WRAPPER(acos) DEVICE_FN_DOUBLE_WRAPPER(acos) +DEVICE_FN_FLOAT_WRAPPER(acosh) +DEVICE_FN_DOUBLE_WRAPPER(acosh) DEVICE_FN_FLOAT_WRAPPER(asin) DEVICE_FN_DOUBLE_WRAPPER(asin) +DEVICE_FN_FLOAT_WRAPPER(asinh) +DEVICE_FN_DOUBLE_WRAPPER(asinh) DEVICE_FN_FLOAT_WRAPPER(atan) DEVICE_FN_DOUBLE_WRAPPER(atan) +DEVICE_FN_FLOAT_WRAPPER(atanh) +DEVICE_FN_DOUBLE_WRAPPER(atanh) DEVICE_FN_FLOAT_WRAPPER(cbrt) DEVICE_FN_DOUBLE_WRAPPER(cbrt) +DEVICE_FN_FLOAT_WRAPPER(cosh) +DEVICE_FN_DOUBLE_WRAPPER(cosh) DEVICE_FN_FLOAT_WRAPPER(erfinv) DEVICE_FN_DOUBLE_WRAPPER(erfinv) DEVICE_FN_FLOAT_WRAPPER(exp) DEVICE_FN_DOUBLE_WRAPPER(exp) DEVICE_FN_FLOAT_WRAPPER(log1p) DEVICE_FN_DOUBLE_WRAPPER(log1p) +DEVICE_FN_FLOAT_WRAPPER(sinh) +DEVICE_FN_DOUBLE_WRAPPER(sinh) DEVICE_FN_FLOAT_WRAPPER(tan) DEVICE_FN_DOUBLE_WRAPPER(tan) @@ -57,21 +67,31 @@ extern "C" __global__ void FN_NAME( \ CUSTOM_UNARY_OP(float, acos_f32, acosg(x)) CUSTOM_UNARY_OP(double, acos_f64, acosg(x)) +CUSTOM_UNARY_OP(float, acosh_f32, acoshg(x)) +CUSTOM_UNARY_OP(double, acosh_f64, acoshg(x)) CUSTOM_UNARY_OP(float, asin_f32, asing(x)) CUSTOM_UNARY_OP(double, asin_f64, asing(x)) +CUSTOM_UNARY_OP(float, asinh_f32, asinhg(x)) +CUSTOM_UNARY_OP(double, asinh_f64, asinhg(x)) CUSTOM_UNARY_OP(float, atan_f32, atang(x)) CUSTOM_UNARY_OP(double, atan_f64, atang(x)) +CUSTOM_UNARY_OP(float, atanh_f32, atanhg(x)) +CUSTOM_UNARY_OP(double, atanh_f64, atanhg(x)) CUSTOM_UNARY_OP(uint8_t, bit_not_u8, ~x) CUSTOM_UNARY_OP(uint32_t, bit_not_u32, ~x) CUSTOM_UNARY_OP(int64_t, bit_not_i64, ~x) CUSTOM_UNARY_OP(float, cbrt_f32, cbrtg(x)) CUSTOM_UNARY_OP(double, cbrt_f64, cbrtg(x)) +CUSTOM_UNARY_OP(float, cosh_f32, coshg(x)) +CUSTOM_UNARY_OP(double, cosh_f64, coshg(x)) CUSTOM_UNARY_OP(float, erf_inv_f32, erfinvg(x)) CUSTOM_UNARY_OP(double, erf_inv_f64, erfinvg(x)) CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x)) CUSTOM_UNARY_OP(double, ln_1p_f64, log1pg(x)) CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + expg(-x))) CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + expg(-x))) +CUSTOM_UNARY_OP(float, sinh_f32, sinhg(x)) +CUSTOM_UNARY_OP(double, sinh_f64, sinhg(x)) CUSTOM_UNARY_OP(float, tan_f32, tang(x)) CUSTOM_UNARY_OP(double, tan_f64, tang(x)) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index c1c85edb5e..8fc8695962 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -66,13 +66,18 @@ rustler::init! { tensors::matmul, tensors::abs, tensors::acos, + tensors::acosh, tensors::asin, + tensors::asinh, tensors::atan, + tensors::atanh, tensors::cbrt, tensors::ceil, tensors::cos, + tensors::cosh, tensors::sigmoid, tensors::sin, + tensors::sinh, tensors::erf_inv, tensors::exp, tensors::floor, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index b70ffb9e09..ef3c655c58 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -398,13 +398,18 @@ macro_rules! 
custom_binary_bool_op { } custom_unary_op!(Acos, "acos", |v| v.acos(), (BF16, F16, F32, F64)); +custom_unary_op!(Acosh, "acosh", |v| v.acosh(), (BF16, F16, F32, F64)); custom_unary_op!(Asin, "asin", |v| v.asin(), (BF16, F16, F32, F64)); +custom_unary_op!(Asinh, "asinh", |v| v.asinh(), (BF16, F16, F32, F64)); custom_unary_op!(Atan, "atan", |v| v.atan(), (BF16, F16, F32, F64)); +custom_unary_op!(Atanh, "atanh", |v| v.atanh(), (BF16, F16, F32, F64)); custom_unary_op!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); custom_unary_op!(Cbrt, "cbrt", |v| v.cbrt(), (BF16, F16, F32, F64)); +custom_unary_op!(Cosh, "cosh", |v| v.cosh(), (BF16, F16, F32, F64)); custom_unary_op!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64)); custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F64)); +custom_unary_op!(Sinh, "sinh", |v| v.sinh(), (BF16, F16, F32, F64)); custom_unary_op!(Tan, "tan", |v| v.tan(), (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 3d8a518104..f659c7afaa 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,8 +1,8 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ - Acos, Asin, Atan, BitAnd, BitNot, BitOr, BitXor, Cbrt, ErfInv, IsInf, Log1p, LogicalOr, - LogicalXor, Pow, Shl, Shr, Sigmoid, Tan, + Acos, Acosh, Asin, Asinh, Atan, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, + IsInf, Log1p, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -370,14 +370,19 @@ unary_nif!(sqrt); unary_nif!(tanh); custom_unary_nif!(acos, Acos); +custom_unary_nif!(acosh, Acosh); custom_unary_nif!(asin, Asin); +custom_unary_nif!(asinh, Asinh); custom_unary_nif!(atan, Atan); +custom_unary_nif!(atanh, Atanh); custom_unary_nif!(bitwise_not, BitNot); custom_unary_nif!(cbrt, Cbrt); +custom_unary_nif!(cosh, Cosh); custom_unary_nif!(erf_inv, ErfInv); custom_unary_nif!(is_infinity, IsInf); custom_unary_nif!(log1p, Log1p); custom_unary_nif!(sigmoid, Sigmoid); +custom_unary_nif!(sinh, Sinh); custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index b9c790d808..125ac47e8b 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -562,6 +562,15 @@ defmodule CandlexTest do |> assert_close(t([0.8414709568023682, 0.9092974066734314, 0.14112000167369843])) end + test "sinh" do + Nx.sinh(1.0) + |> assert_close(t(1.175201177597046)) + + t([1.0, 2, 3]) + |> Nx.sinh() + |> assert_close(t([1.175201177597046, 3.6268603801727295, 10.017874717712402])) + end + test "exp" do Nx.exp(1.0) |> assert_equal(t(2.7182817459106445)) @@ -580,6 +589,15 @@ defmodule CandlexTest do |> assert_close(t([0.5403022766113281, -0.416146844625473, -0.9899924993515015])) end + test "cosh" do + Nx.cosh(1.0) + |> assert_close(t(1.5430806875228882)) + + t([1.0, 2, 3]) + |> Nx.cosh() + |> assert_close(t([1.5430806875228882, 3.762195587158203, 10.067662239074707])) + end + test "log" do Nx.log(1.0) |> assert_equal(t(0.0)) @@ -717,6 +735,15 @@ defmodule CandlexTest do |> assert_equal(t([1.4706288576126099, 1.0471975803375244, 0.4510268568992615])) end + test "acosh" do + Nx.acosh(1.0) + |> assert_equal(t(0.0)) + + t([1.0, 2, 3]) + |> Nx.acosh() + |> assert_close(t([0.0, 
1.316957950592041, 1.7627471685409546])) + end + test "asin" do Nx.asin(0.10000000149011612) |> assert_equal(t(0.1001674234867096)) @@ -726,6 +753,15 @@ defmodule CandlexTest do |> assert_equal(t([0.1001674234867096, 0.5235987901687622, 1.1197694540023804])) end + test "asinh" do + Nx.asinh(1.0) + |> assert_close(t(0.8813735842704773)) + + t([1.0, 2, 3]) + |> Nx.asinh() + |> assert_close(t([0.8813735842704773, 1.4436354637145996, 1.8184465169906616])) + end + test "tan" do Nx.tan(1.0) |> assert_close(t(1.5574077367782593)) @@ -744,6 +780,15 @@ defmodule CandlexTest do |> assert_close(t([0.09966865181922913, 0.46364760398864746, 0.7328150868415833])) end + test "atanh" do + Nx.atanh(0.10000000149011612) + |> assert_close(t(0.10033535212278366)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.atanh() + |> assert_close(t([0.10033535212278366, 0.5493061542510986, 1.4722193479537964])) + end + test "ceil" do t([-1, 0, 1]) |> Nx.ceil() From 4d607cb3e81c624b48b3dd7149f8552eb2ef7b61 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:46:31 -0300 Subject: [PATCH 167/185] candlex is_nan --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + .../candlex/src/kernels/custom_unary.cu | 2 ++ candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 19 +++++++++++++++++++ 7 files changed, 27 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 42272a27ec..d2fc618b01 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -301,6 +301,7 @@ defmodule Candlex.Backend do :exp, :floor, :is_infinity, + :is_nan, :log, :log1p, :negate, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index b8516a77a2..a453d5c9a8 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -43,6 +43,7 @@ defmodule Candlex.Native do :exp, :floor, :is_infinity, + :is_nan, :log, :log1p, :negate, diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu index a8fe8f34b5..19f6da8b36 100644 --- a/candlex/native/candlex/src/kernels/custom_unary.cu +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -97,3 +97,5 @@ CUSTOM_UNARY_OP(double, tan_f64, tang(x)) CUSTOM_UNARY_OP_OUT(float, uint8_t, is_inf_f32, isinf(x) ? 1 : 0) CUSTOM_UNARY_OP_OUT(double, uint8_t, is_inf_f64, isinf(x) ? 1 : 0) +CUSTOM_UNARY_OP_OUT(float, uint8_t, is_nan_f32, isnan(x) ? 1 : 0) +CUSTOM_UNARY_OP_OUT(double, uint8_t, is_nan_f64, isnan(x) ? 1 : 0) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 8fc8695962..b01ddb11da 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -82,6 +82,7 @@ rustler::init! { tensors::exp, tensors::floor, tensors::is_infinity, + tensors::is_nan, tensors::round, tensors::log, tensors::log1p, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index ef3c655c58..5e2add742b 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -412,6 +412,7 @@ custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. 
+ (-v).exp()), (F32, F64)); custom_unary_op!(Sinh, "sinh", |v| v.sinh(), (BF16, F16, F32, F64)); custom_unary_op!(Tan, "tan", |v| v.tan(), (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); +custom_unary_bool_op!(IsNan, "is_nan", is_nan, (F32, F64)); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index f659c7afaa..93d685f24a 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -2,7 +2,7 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ Acos, Acosh, Asin, Asinh, Atan, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, - IsInf, Log1p, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, Tan, + IsInf, IsNan, Log1p, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -380,6 +380,7 @@ custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(cosh, Cosh); custom_unary_nif!(erf_inv, ErfInv); custom_unary_nif!(is_infinity, IsInf); +custom_unary_nif!(is_nan, IsNan); custom_unary_nif!(log1p, Log1p); custom_unary_nif!(sigmoid, Sigmoid); custom_unary_nif!(sinh, Sinh); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 125ac47e8b..2849b8352f 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -970,6 +970,25 @@ defmodule CandlexTest do # |> assert_equal(t([0, 0])) end + test "is_nan" do + t([:nan, 1.0, 0.0]) + |> Nx.is_nan() + |> assert_equal(t([1, 0, 0])) + + t([:nan, :infinity]) + |> Nx.is_nan() + |> assert_equal(t([1, 0])) + + # Complex not yet supported + # t(Complex.new(0, :nan)) + # |> Nx.is_nan() + # |> assert_equal(t(1)) + + t([1.0, 0.0]) + |> Nx.is_nan() + |> assert_equal(t([0, 0])) + end + test "logical_or" do Nx.logical_or(0, t([-1, 0, 1])) |> assert_equal(t([1, 0, 1])) From efbb099d03c74d7cbd72b69e65591c4e460a39ea Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 5 Oct 2023 16:53:57 -0300 Subject: [PATCH 168/185] candlex logical_and --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + .../candlex/src/kernels/custom_binary.cu | 3 +++ candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 6 +++++ candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 25 +++++++++++++++++++ 7 files changed, 39 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index d2fc618b01..840a529f81 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -261,6 +261,7 @@ defmodule Candlex.Backend do :left_shift, :less, :less_equal, + :logical_and, :logical_or, :logical_xor, :not_equal, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index a453d5c9a8..3bfc4bb5b9 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -71,6 +71,7 @@ defmodule Candlex.Native do :left_shift, :less, :less_equal, + :logical_and, :logical_or, :logical_xor, :matmul, diff --git a/candlex/native/candlex/src/kernels/custom_binary.cu b/candlex/native/candlex/src/kernels/custom_binary.cu index 87248122e6..06ddaa099e 100644 --- a/candlex/native/candlex/src/kernels/custom_binary.cu +++ b/candlex/native/candlex/src/kernels/custom_binary.cu @@ -91,6 +91,9 @@ CUSTOM_BINARY_OP(int64_t, shl_i64, x << y) 
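
The predicate kernels added in these two patches (`is_nan`, `logical_and`) both return `u8` tensors of zeros and ones rather than booleans. A minimal usage sketch, assuming a compiled Candlex backend is available; the expected values follow standard Nx semantics:

```elixir
t = Nx.tensor([0.0, 2.0, :nan], backend: Candlex.Backend)

Nx.is_nan(t)
# expected: u8 tensor [0, 0, 1]

Nx.logical_and(
  Nx.tensor([0, 1, 2], backend: Candlex.Backend),
  Nx.tensor([1, 1, 0], backend: Candlex.Backend)
)
# expected: u8 tensor [0, 1, 0] (any non-zero element counts as true)
```
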
CUSTOM_BINARY_OP(uint32_t, shr_u32, x >> y) CUSTOM_BINARY_OP(int64_t, shr_i64, x >> y) +CUSTOM_BINARY_OP_OUT(uint8_t, uint8_t, logical_and_u8, x && y) +CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_and_i64, x && y) +CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_and_f32, x && y) CUSTOM_BINARY_OP_OUT(uint8_t, uint8_t, logical_or_u8, x || y) CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_or_i64, x || y) CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_or_f32, x || y) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index b01ddb11da..b0bb705c4c 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -94,6 +94,7 @@ rustler::init! { tensors::bitwise_and, tensors::bitwise_or, tensors::bitwise_xor, + tensors::logical_and, tensors::logical_or, tensors::logical_xor, tensors::left_shift, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 5e2add742b..668832d475 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -420,6 +420,12 @@ custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64)); custom_binary_op!(Pow, "pow", |v1, v2| v1.powf(v2), (F32, F64)); custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2, (U32, I64)); custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2, (U32, I64)); +custom_binary_bool_op!( + LogicalAnd, + "logical_and", + |v1, v2| if v1 as i8 != 0 && v2 as i8 != 0 { 1 } else { 0 }, + (U8, U32, I64, F32, F64) +); custom_binary_bool_op!( LogicalOr, "logical_or", diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 93d685f24a..611ce55e90 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -2,7 +2,7 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ Acos, Acosh, Asin, Asinh, Atan, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, - IsInf, IsNan, Log1p, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, Tan, + IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -403,6 +403,7 @@ custom_binary_nif!(bitwise_and, BitAnd); custom_binary_nif!(bitwise_or, BitOr); custom_binary_nif!(bitwise_xor, BitXor); custom_binary_nif!(left_shift, Shl); +custom_binary_nif!(logical_and, LogicalAnd); custom_binary_nif!(logical_or, LogicalOr); custom_binary_nif!(logical_xor, LogicalXor); custom_binary_nif!(pow, Pow); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 2849b8352f..a2476cacca 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -989,6 +989,31 @@ defmodule CandlexTest do |> assert_equal(t([0, 0])) end + test "logical_and" do + Nx.logical_and(1, t([-1, 0, 1])) + |> assert_equal(t([1, 0, 1])) + + t([-1, 0, 1]) + |> Nx.logical_and(t([[-1], [0], [1]])) + |> assert_equal(t( + [ + [1, 0, 1], + [0, 0, 0], + [1, 0, 1] + ] + )) + + t([-1.0, 0.0, 1.0]) + |> Nx.logical_and(t([[-1], [0], [1]]))\ + |> assert_equal(t( + [ + [1, 0, 1], + [0, 0, 0], + [1, 0, 1] + ] + )) + end + test "logical_or" do Nx.logical_or(0, t([-1, 0, 1])) |> assert_equal(t([1, 0, 1])) From 8ea2d03d0d4ae886268d8d066861a902ea70f5dc Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 5 Oct 2023 17:04:04 -0300 Subject: [PATCH 169/185] candlex erf/erfc --- candlex/lib/candlex/backend.ex | 2 ++ candlex/lib/candlex/native.ex | 2 ++ .../native/candlex/src/kernels/custom_unary.cu | 4 ++++ 
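
Since erfc is defined as 1 - erf(x), the quickest sanity check for this pair of new kernels is to add them back together. A rough sketch, assuming the backend builds with these ops:

```elixir
x = Nx.tensor([0.5, 1.0, 2.0], backend: Candlex.Backend)

Nx.add(Nx.erf(x), Nx.erfc(x))
# expected: approximately [1.0, 1.0, 1.0] for every element
```
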
candlex/native/candlex/src/lib.rs | 2 ++ candlex/native/candlex/src/ops.rs | 5 +++++ candlex/native/candlex/src/tensors.rs | 4 +++- candlex/test/candlex_test.exs | 16 ++++++++++++++++ 7 files changed, 34 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 840a529f81..7f98dd757a 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -298,6 +298,8 @@ defmodule Candlex.Backend do :ceil, :cos, :cosh, + :erf, + :erfc, :erf_inv, :exp, :floor, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 3bfc4bb5b9..256716e3da 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -39,6 +39,8 @@ defmodule Candlex.Native do :ceil, :cos, :cosh, + :erf, + :erfc, :erf_inv, :exp, :floor, diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu index 19f6da8b36..24d101b408 100644 --- a/candlex/native/candlex/src/kernels/custom_unary.cu +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -26,6 +26,8 @@ DEVICE_FN_FLOAT_WRAPPER(cbrt) DEVICE_FN_DOUBLE_WRAPPER(cbrt) DEVICE_FN_FLOAT_WRAPPER(cosh) DEVICE_FN_DOUBLE_WRAPPER(cosh) +DEVICE_FN_FLOAT_WRAPPER(erfc) +DEVICE_FN_DOUBLE_WRAPPER(erfc) DEVICE_FN_FLOAT_WRAPPER(erfinv) DEVICE_FN_DOUBLE_WRAPPER(erfinv) DEVICE_FN_FLOAT_WRAPPER(exp) @@ -84,6 +86,8 @@ CUSTOM_UNARY_OP(float, cbrt_f32, cbrtg(x)) CUSTOM_UNARY_OP(double, cbrt_f64, cbrtg(x)) CUSTOM_UNARY_OP(float, cosh_f32, coshg(x)) CUSTOM_UNARY_OP(double, cosh_f64, coshg(x)) +CUSTOM_UNARY_OP(float, erfc_f32, erfcg(x)) +CUSTOM_UNARY_OP(double, erfc_f64, erfcg(x)) CUSTOM_UNARY_OP(float, erf_inv_f32, erfinvg(x)) CUSTOM_UNARY_OP(double, erf_inv_f64, erfinvg(x)) CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x)) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index b0bb705c4c..72cdca7b1c 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -78,6 +78,8 @@ rustler::init! { tensors::sigmoid, tensors::sin, tensors::sinh, + tensors::erf, + tensors::erfc, tensors::erf_inv, tensors::exp, tensors::floor, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 668832d475..389326edd2 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -4,6 +4,10 @@ use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape}; use num_traits::cast::FromPrimitive; use num_traits::Float; +fn erfc(v: T) -> T { + FromPrimitive::from_f64(statrs::function::erf::erfc(v.to_f64().unwrap())).unwrap() +} + fn erf_inv(v: T) -> T { FromPrimitive::from_f64(statrs::function::erf::erf_inv(v.to_f64().unwrap())).unwrap() } @@ -406,6 +410,7 @@ custom_unary_op!(Atanh, "atanh", |v| v.atanh(), (BF16, F16, F32, F64)); custom_unary_op!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); custom_unary_op!(Cbrt, "cbrt", |v| v.cbrt(), (BF16, F16, F32, F64)); custom_unary_op!(Cosh, "cosh", |v| v.cosh(), (BF16, F16, F32, F64)); +custom_unary_op!(Erfc, "erfc", |v| erfc(v), (BF16, F16, F32, F64)); custom_unary_op!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64)); custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. 
+ (-v).exp()), (F32, F64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 611ce55e90..86d338290c 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,7 +1,7 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ - Acos, Acosh, Asin, Asinh, Atan, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, + Acos, Acosh, Asin, Asinh, Atan, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, Erfc, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, Tan, }; use candle_core::{DType, Device, Tensor}; @@ -361,6 +361,7 @@ unary_nif!(negate, neg); unary_nif!(abs); unary_nif!(ceil); unary_nif!(cos); +unary_nif!(erf); unary_nif!(exp); unary_nif!(floor); unary_nif!(round); @@ -378,6 +379,7 @@ custom_unary_nif!(atanh, Atanh); custom_unary_nif!(bitwise_not, BitNot); custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(cosh, Cosh); +custom_unary_nif!(erfc, Erfc); custom_unary_nif!(erf_inv, ErfInv); custom_unary_nif!(is_infinity, IsInf); custom_unary_nif!(is_nan, IsNan); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index a2476cacca..c418783341 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1065,6 +1065,22 @@ defmodule CandlexTest do )) end + test "erf" do + Nx.erf(1.0) + |> assert_close(t(0.8427007794380188)) + + Nx.erf(t([1.0, 2, 3])) + |> assert_close(t([0.8427007794380188, 0.9953222870826721, 0.9999778866767883])) + end + + test "erfc" do + Nx.erfc(1.0) + |> assert_close(t(0.15729920566082)) + + Nx.erfc(t([1.0, 2, 3])) + |> assert_close(t([0.15729920566082, 0.004677734803408384, 2.2090496713644825e-5])) + end + test "erf_inv" do Nx.erf_inv(0.10000000149011612) |> assert_close(t(0.08885598927736282)) From b274ae3182ec0f0539c9449b865a126be5db564a Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Thu, 5 Oct 2023 17:20:38 -0300 Subject: [PATCH 170/185] candlex expm1 --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/kernels/custom_unary.cu | 4 ++++ candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 4 +++- candlex/test/candlex_test.exs | 9 +++++++++ 7 files changed, 20 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 7f98dd757a..131c03ac98 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -302,6 +302,7 @@ defmodule Candlex.Backend do :erfc, :erf_inv, :exp, + :expm1, :floor, :is_infinity, :is_nan, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 256716e3da..5f6a93002f 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -43,6 +43,7 @@ defmodule Candlex.Native do :erfc, :erf_inv, :exp, + :expm1, :floor, :is_infinity, :is_nan, diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu index 24d101b408..bc5b737598 100644 --- a/candlex/native/candlex/src/kernels/custom_unary.cu +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -32,6 +32,8 @@ DEVICE_FN_FLOAT_WRAPPER(erfinv) DEVICE_FN_DOUBLE_WRAPPER(erfinv) DEVICE_FN_FLOAT_WRAPPER(exp) DEVICE_FN_DOUBLE_WRAPPER(exp) +DEVICE_FN_FLOAT_WRAPPER(expm1) +DEVICE_FN_DOUBLE_WRAPPER(expm1) DEVICE_FN_FLOAT_WRAPPER(log1p) DEVICE_FN_DOUBLE_WRAPPER(log1p) DEVICE_FN_FLOAT_WRAPPER(sinh) @@ 
-90,6 +92,8 @@ CUSTOM_UNARY_OP(float, erfc_f32, erfcg(x)) CUSTOM_UNARY_OP(double, erfc_f64, erfcg(x)) CUSTOM_UNARY_OP(float, erf_inv_f32, erfinvg(x)) CUSTOM_UNARY_OP(double, erf_inv_f64, erfinvg(x)) +CUSTOM_UNARY_OP(float, expm1_f32, expm1g(x)) +CUSTOM_UNARY_OP(double, expm1_f64, expm1g(x)) CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x)) CUSTOM_UNARY_OP(double, ln_1p_f64, log1pg(x)) CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + expg(-x))) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 72cdca7b1c..21b6c2a9ad 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -82,6 +82,7 @@ rustler::init! { tensors::erfc, tensors::erf_inv, tensors::exp, + tensors::expm1, tensors::floor, tensors::is_infinity, tensors::is_nan, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 389326edd2..ebec9c40e5 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -412,6 +412,7 @@ custom_unary_op!(Cbrt, "cbrt", |v| v.cbrt(), (BF16, F16, F32, F64)); custom_unary_op!(Cosh, "cosh", |v| v.cosh(), (BF16, F16, F32, F64)); custom_unary_op!(Erfc, "erfc", |v| erfc(v), (BF16, F16, F32, F64)); custom_unary_op!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); +custom_unary_op!(Expm1, "expm1", |v| v.exp_m1(), (BF16, F16, F32, F64)); custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64)); custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F64)); custom_unary_op!(Sinh, "sinh", |v| v.sinh(), (BF16, F16, F32, F64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 86d338290c..eeffe5dce9 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -2,7 +2,8 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ Acos, Acosh, Asin, Asinh, Atan, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, Erfc, - IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, Tan, + Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, + Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -381,6 +382,7 @@ custom_unary_nif!(cbrt, Cbrt); custom_unary_nif!(cosh, Cosh); custom_unary_nif!(erfc, Erfc); custom_unary_nif!(erf_inv, ErfInv); +custom_unary_nif!(expm1, Expm1); custom_unary_nif!(is_infinity, IsInf); custom_unary_nif!(is_nan, IsNan); custom_unary_nif!(log1p, Log1p); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index c418783341..d8b1550ba7 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -580,6 +580,15 @@ defmodule CandlexTest do |> assert_equal(t([2.7182817459106445, 7.389056205749512, 20.08553695678711])) end + test "expm1" do + Nx.expm1(1.0) + |> assert_close(t(1.718281865119934)) + + t([1.0, 2, 3]) + |> Nx.expm1() + |> assert_close(t([1.718281865119934, 6.389056205749512, 19.08553695678711])) + end + test "cos" do Nx.cos(1.0) |> assert_close(t(0.5403022766113281)) From 9d009ac8c6cde4fef920df1d4844dc09974f6380 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 10:24:33 -0300 Subject: [PATCH 171/185] candlex remainder --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 1 + .../candlex/src/kernels/custom_binary.cu | 4 ++ candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 6 +++ 
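
The reason `expm1` (like the earlier `log1p`) is wired through as its own kernel instead of being composed from `exp` and a subtraction is numerical stability for inputs near zero. An illustrative sketch, assuming a compiled backend; the quoted values are approximate:

```elixir
x = Nx.tensor([1.0e-8], type: :f64, backend: Candlex.Backend)

Nx.expm1(x)
# expected: about 1.000000005e-8, computed without cancellation

Nx.subtract(Nx.exp(x), 1.0)
# expected: about 1.0e-8, with the low-order digits lost to rounding near 1.0
```
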
candlex/native/candlex/src/tensors.rs | 5 ++- candlex/test/candlex_test.exs | 42 +++++++++++++++++++ 7 files changed, 58 insertions(+), 3 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 131c03ac98..c982c70002 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -238,7 +238,7 @@ defmodule Candlex.Backend do end end - for op <- [:pow] do + for op <- [:pow, :remainder] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 5f6a93002f..13bc3c9bab 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -83,6 +83,7 @@ defmodule Candlex.Native do :multiply, :not_equal, :pow, + :remainder, :right_shift, :subtract ] do diff --git a/candlex/native/candlex/src/kernels/custom_binary.cu b/candlex/native/candlex/src/kernels/custom_binary.cu index 06ddaa099e..1e9d950822 100644 --- a/candlex/native/candlex/src/kernels/custom_binary.cu +++ b/candlex/native/candlex/src/kernels/custom_binary.cu @@ -10,6 +10,8 @@ DEVICE_FN_FLOAT_WRAPPER(pow) DEVICE_FN_DOUBLE_WRAPPER(pow) +DEVICE_FN_FLOAT_WRAPPER(remainder) +DEVICE_FN_DOUBLE_WRAPPER(remainder) #define CUSTOM_BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \ extern "C" __global__ void FN_NAME( \ @@ -86,6 +88,8 @@ CUSTOM_BINARY_OP(uint32_t, bit_xor_u32, x ^ y) CUSTOM_BINARY_OP(int64_t, bit_xor_i64, x ^ y) CUSTOM_BINARY_OP(float, pow_f32, powg(x, y)) CUSTOM_BINARY_OP(double, pow_f64, powg(x, y)) +CUSTOM_BINARY_OP(float, remainder_f32, remainderg(x, y)) +CUSTOM_BINARY_OP(double, remainder_f64, remainderg(x, y)) CUSTOM_BINARY_OP(uint32_t, shl_u32, x << y) CUSTOM_BINARY_OP(int64_t, shl_i64, x << y) CUSTOM_BINARY_OP(uint32_t, shr_u32, x >> y) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 21b6c2a9ad..88dd5f48a0 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -29,6 +29,7 @@ rustler::init! 
{ tensors::subtract, tensors::multiply, tensors::divide, + tensors::remainder, tensors::pow, tensors::max, tensors::min, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index ebec9c40e5..c47be5edfe 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -424,6 +424,12 @@ custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64)); custom_binary_op!(Pow, "pow", |v1, v2| v1.powf(v2), (F32, F64)); +custom_binary_op!( + Remainder, + "remainder", + |v1, v2| v1 % v2, + (U8, I64, F32, F64) +); custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2, (U32, I64)); custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2, (U32, I64)); custom_binary_bool_op!( diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index eeffe5dce9..69aa0ff80b 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -2,8 +2,8 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ Acos, Acosh, Asin, Asinh, Atan, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, Erfc, - Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Shl, Shr, Sigmoid, Sinh, - Tan, + Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Remainder, Shl, Shr, + Sigmoid, Sinh, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -412,6 +412,7 @@ custom_binary_nif!(logical_or, LogicalOr); custom_binary_nif!(logical_xor, LogicalXor); custom_binary_nif!(pow, Pow); custom_binary_nif!(right_shift, Shr); +custom_binary_nif!(remainder, Remainder); fn tuple_to_vec(term: Term) -> Result, rustler::Error> { Ok(rustler::types::tuple::get_tuple(term)? 
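
The growing `for op <- [...]` comprehensions in `Candlex.Backend` and `Candlex.Native` rely on Elixir's compile-time unquote fragments: each atom in the list expands into its own function definition that forwards to the same-named native call. A self-contained toy version of the pattern, with the module name and return value invented purely for illustration:

```elixir
defmodule ToyDispatch do
  # One function is generated per operation name at compile time.
  for op <- [:pow, :remainder] do
    def unquote(op)(left, right) do
      {unquote(op), left, right}
    end
  end
end

ToyDispatch.remainder(11, 2)
# => {:remainder, 11, 2}
```
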
diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index d8b1550ba7..521df3df5a 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -214,6 +214,48 @@ defmodule CandlexTest do )) end + test "remainder" do + Nx.remainder(1, 2) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.remainder(2) + |> assert_equal(t([1, 0, 1])) + + 2 + |> Nx.remainder(t([1.0, 2.0, 3.0])) + |> assert_equal(t([0.0, 0.0, 2.0])) + + t([[10], [20]], names: [:x, :y]) + |> Nx.remainder(t([[3, 4]], names: [nil, :y])) + |> assert_equal(t( + [ + [1, 2], + [2, 0] + ] + )) + + left = t(-11) + right = t(10, type: :u8) + + Nx.remainder(left, right) + |> assert_equal(t(-1)) + + left + |> Nx.add(t(20)) + |> Nx.remainder(right) + |> assert_equal(t(9)) + + positive_left = t(9, type: :u8) + Nx.remainder(positive_left, right) + |> assert_equal(t(9)) + + positive_left + |> Nx.add(Nx.tensor(20, type: :u8)) + |> Nx.remainder(right) + |> assert_equal(t(9)) + end + test "broadcast" do Nx.broadcast(1, {1, 2, 3}) |> assert_equal(t([[[1, 1, 1], [1, 1, 1]]])) From ba264f1dc9333be1e824748cdb5c9be8b547c840 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 10:32:54 -0300 Subject: [PATCH 172/185] candlex atan2 --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 1 + .../candlex/src/kernels/custom_binary.cu | 4 ++++ candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 7 +++--- candlex/test/candlex_test.exs | 22 +++++++++++++++++++ 7 files changed, 34 insertions(+), 4 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c982c70002..bd03e84ca9 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -238,7 +238,7 @@ defmodule Candlex.Backend do end end - for op <- [:pow, :remainder] do + for op <- [:atan2, :pow, :remainder] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 13bc3c9bab..9933676f96 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -64,6 +64,7 @@ defmodule Candlex.Native do for op <- [ :add, + :atan2, :bitwise_and, :bitwise_or, :bitwise_xor, diff --git a/candlex/native/candlex/src/kernels/custom_binary.cu b/candlex/native/candlex/src/kernels/custom_binary.cu index 1e9d950822..9de163cc43 100644 --- a/candlex/native/candlex/src/kernels/custom_binary.cu +++ b/candlex/native/candlex/src/kernels/custom_binary.cu @@ -8,6 +8,8 @@ #define DEVICE_FN_DOUBLE_WRAPPER(FN_NAME) \ __device__ __forceinline__ double FN_NAME##g(double a, double b) { return FN_NAME(a, b); } +DEVICE_FN_FLOAT_WRAPPER(atan2) +DEVICE_FN_DOUBLE_WRAPPER(atan2) DEVICE_FN_FLOAT_WRAPPER(pow) DEVICE_FN_DOUBLE_WRAPPER(pow) DEVICE_FN_FLOAT_WRAPPER(remainder) @@ -80,6 +82,8 @@ extern "C" __global__ void FN_NAME( \ #define CUSTOM_BINARY_OP(TYPENAME, FN_NAME, FUNC) \ CUSTOM_BINARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC) +CUSTOM_BINARY_OP(float, atan2_f32, atan2g(x, y)) +CUSTOM_BINARY_OP(double, atan2_f64, atan2g(x, y)) CUSTOM_BINARY_OP(uint32_t, bit_and_u32, x & y) CUSTOM_BINARY_OP(int64_t, bit_and_i64, x & y) CUSTOM_BINARY_OP(uint32_t, bit_or_u32, x | y) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 88dd5f48a0..8506789512 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ 
-26,6 +26,7 @@ rustler::init! { tensors::from_binary, tensors::to_binary, tensors::add, + tensors::atan2, tensors::subtract, tensors::multiply, tensors::divide, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index c47be5edfe..81dff468c4 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -423,6 +423,7 @@ custom_unary_bool_op!(IsNan, "is_nan", is_nan, (F32, F64)); custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64)); custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64)); custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64)); +custom_binary_op!(Atan2, "atan2", |v1, v2| v1.atan2(v2), (F32, F64)); custom_binary_op!(Pow, "pow", |v1, v2| v1.powf(v2), (F32, F64)); custom_binary_op!( Remainder, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 69aa0ff80b..676f73b2f8 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -1,9 +1,9 @@ use crate::atoms; use crate::error::CandlexError; use crate::ops::{ - Acos, Acosh, Asin, Asinh, Atan, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, Erfc, - Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Remainder, Shl, Shr, - Sigmoid, Sinh, Tan, + Acos, Acosh, Asin, Asinh, Atan, Atan2, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, + ErfInv, Erfc, Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Remainder, + Shl, Shr, Sigmoid, Sinh, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -403,6 +403,7 @@ binary_nif!(less, lt); binary_nif!(less_equal, le); binary_nif!(matmul, broadcast_matmul); +custom_binary_nif!(atan2, Atan2); custom_binary_nif!(bitwise_and, BitAnd); custom_binary_nif!(bitwise_or, BitOr); custom_binary_nif!(bitwise_xor, BitXor); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 521df3df5a..abb939a6b0 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -256,6 +256,28 @@ defmodule CandlexTest do |> assert_equal(t(9)) end + test "atan2" do + Nx.atan2(1.0, 2.0) + |> assert_close(t(0.46364760398864746)) + + t([1.0, 2, 3]) + |> Nx.atan2(1) + |> assert_close(t([0.7853981852531433, 1.1071487665176392, 1.249045729637146])) + + 1.0 + |> Nx.atan2(t([1.0, 2.0, 3.0])) + |> assert_close(t([0.7853981852531433, 0.46364760398864746, 0.32175055146217346])) + + t([[-0.0], [0.0]], type: :f64) + |> Nx.atan2(t([-0.0, 0.0], type: :f64)) + |> assert_close(t( + [ + [-3.141592653589793, -0.0], + [3.141592653589793, 0.0] + ] + )) + end + test "broadcast" do Nx.broadcast(1, {1, 2, 3}) |> assert_equal(t([[[1, 1, 1], [1, 1, 1]]])) From 52e980648ad3acf41b4fc975532e8e8b5402c43d Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 11:21:48 -0300 Subject: [PATCH 173/185] candlex gather for tensor of rank=1 --- candlex/lib/candlex/backend.ex | 13 +++++++++++++ candlex/test/candlex_test.exs | 23 +++++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index bd03e84ca9..0fb79c1e9a 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -330,6 +330,19 @@ defmodule Candlex.Backend do # Indexed + @impl true + def gather(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices) do + tensor + |> from_nx() + |> Native.gather(from_nx(Nx.flatten(indices)), 0) + |> unwrap!() + |> to_nx(out) + end + + def gather(%T{} = 
_out, %T{} = _tensor, %T{} = _indices) do + raise("unsupported gather for tensor of rank greater than 1") + end + @impl true def put_slice(%T{} = out, %T{} = t, [start] = _start_indices, slice) do t diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index abb939a6b0..a13f04c900 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1523,6 +1523,29 @@ defmodule CandlexTest do )) end + test "gather" do + t([1, 2, 3, 4]) + |> Nx.gather(t([[3], [1], [2]])) + |> assert_equal(t([4, 2, 3])) + + # t([[1, 2], [3, 4]]) + # |> Nx.gather(t([[1, 1], [0, 1], [1, 0]])) + # |> assert_equal(t([4, 2, 3])) + + # t([[1, 2], [3, 4]]) + # |> Nx.gather(t([[[1, 1], [0, 0]], [[1, 0], [0, 1]]])) + # |> assert_equal(t( + # [ + # [4, 1], + # [3, 2] + # ] + # )) + + # t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) + # |> Nx.gather(t([[0, 0, 0], [0, 1, 1], [1, 1, 1]])) + # |> assert_equal(t([1, 12, 112])) + end + test "transpose" do t(1) |> Nx.transpose() From c0b8df6010ab9d9f0a31ea7848eabc415aea8496 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 11:45:53 -0300 Subject: [PATCH 174/185] candlex indexed_add for tensor of rank=1 --- candlex/lib/candlex/backend.ex | 15 ++++++++++++++ candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 5 +++++ candlex/test/candlex_test.exs | 28 +++++++++++++++++++++++++++ 5 files changed, 50 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 0fb79c1e9a..60ccb44696 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -343,6 +343,21 @@ defmodule Candlex.Backend do raise("unsupported gather for tensor of rank greater than 1") end + @impl true + def indexed_add(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, %T{} = updates) do + {tensor, updates} = maybe_upcast(tensor, updates) + + tensor + |> from_nx() + |> Native.index_add(from_nx(Nx.flatten(indices)), from_nx(updates), 0) + |> unwrap!() + |> to_nx(out) + end + + def indexed_add(%T{} = _out, %T{} = _tensor, %T{} = _indices, %T{} = _updates) do + raise("unsupported indexed_add for tensor of rank greater than 1") + end + @impl true def put_slice(%T{} = out, %T{} = t, [start] = _start_indices, slice) do t diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 9933676f96..7b16f5cf56 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -11,6 +11,7 @@ defmodule Candlex.Native do def narrow(_tensor, _dim, _start, _length), do: error() def gather(_tensor, _indexes, _dim), do: error() def index_select(_tensor, _indexes, _dim), do: error() + def index_add(_tensor, _indexes, _source, _dim), do: error() def chunk(_tensor, _num_chunks), do: error() def squeeze(_tensor, _dim), do: error() def arange(_start, _end, _dtype, _shape, _device), do: error() diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 8506789512..1d5ee6cedf 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -52,6 +52,7 @@ rustler::init! 
{ tensors::narrow, tensors::gather, tensors::index_select, + tensors::index_add, tensors::chunk, tensors::squeeze, tensors::clamp, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 676f73b2f8..ae8f3930e3 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -97,6 +97,11 @@ pub fn index_select(t: ExTensor, indexes: ExTensor, dim: usize) -> Result Result { + Ok(ExTensor::new(t.index_add(indexes.deref(), source.deref(), dim)?)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn chunk(t: ExTensor, num_chunks: usize) -> Result, CandlexError> { Ok(t.chunk(num_chunks, 0)? diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index a13f04c900..cdb441936a 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1546,6 +1546,34 @@ defmodule CandlexTest do # |> assert_equal(t([1, 12, 112])) end + test "indexed_add" do + t([1.0]) + |> Nx.indexed_add(t([[0], [0]]), t([1, 1])) + |> assert_equal(t([3.0])) + + t([1]) + |> Nx.indexed_add(t([[0], [0]]), t([1.0, 1.0])) + |> assert_equal(t([3.0])) + + t([1], type: :u8) + |> Nx.indexed_add(t([[0], [0]]), t([1, 1], type: :s64)) + |> assert_equal(t([3])) + + # Nx.iota({1, 2, 3}) + # |> Nx.indexed_add( + # t([[0, 0, 0], [0, 1, 1], [0, 0, 0], [0, 0, 2], [0, 1, 2]]), + # t([1, 3, 1, -2, 5]) + # ) + # |> assert_equal(t( + # [ + # [ + # [2, 1, 0], + # [3, 7, 10] + # ] + # ] + # )) + end + test "transpose" do t(1) |> Nx.transpose() From 12c011fe69277c9849f0945edb109ebaef3bc43c Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 12:00:15 -0300 Subject: [PATCH 175/185] candlex reduce_min --- candlex/lib/candlex/backend.ex | 24 +++++++++++++++++++ candlex/lib/candlex/native.ex | 2 +- candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 11 +++++++++ candlex/test/candlex_test.exs | 34 +++++++++++++++++++++++++++ 5 files changed, 71 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 60ccb44696..c6119bb6b0 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -185,6 +185,30 @@ defmodule Candlex.Backend do |> to_nx(out) end + @impl true + def reduce_min(%T{} = out, %T{shape: {}} = tensor, _opts) do + out + |> from_binary(to_binary(tensor), []) + end + + def reduce_min(%T{} = out, %T{} = tensor, opts) do + axis = + case opts[:axes] do + nil -> 0 + [] -> 0 + [axis] -> axis + axes -> raise "doesn't support axes option with more than 1 axis, '#{inspect(axes)}'" + end + + keep_axis = opts[:keep_axes] || false + + tensor + |> from_nx() + |> Native.reduce_min(axis, keep_axis) + |> unwrap!() + |> to_nx(out) + end + # Element-wise @impl true diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 7b16f5cf56..45ba716bea 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -95,7 +95,7 @@ defmodule Candlex.Native do def sum(_tensor, _dims, _keep_dims), do: error() def permute(_tensor, _dims), do: error() - for op <- [:argmax, :argmin, :reduce_max] do + for op <- [:argmax, :argmin, :reduce_max, :reduce_min] do def unquote(op)(_tensor, _dim, _keep_dim), do: error() end diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 1d5ee6cedf..6fcef75cd9 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -47,6 +47,7 @@ rustler::init! 
{ tensors::argmax, tensors::argmin, tensors::reduce_max, + tensors::reduce_min, tensors::negate, tensors::where_cond, tensors::narrow, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index ae8f3930e3..c64335df78 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -200,6 +200,17 @@ pub fn reduce_max( Ok(ExTensor::new(t)) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn reduce_min(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result { + let t = if keep_dim { + ex_tensor.min_keepdim(dim)? + } else { + ex_tensor.min(dim)? + }; + + Ok(ExTensor::new(t)) +} + #[rustler::nif(schedule = "DirtyCpu")] pub fn sum( ex_tensor: ExTensor, diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index cdb441936a..84448b1393 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1485,6 +1485,40 @@ defmodule CandlexTest do # )) end + test "reduce_min" do + Nx.reduce_min(t(42)) + |> assert_equal(t(42)) + + Nx.reduce_min(t(42.0)) + |> assert_equal(t(42.0)) + + Nx.reduce_min(t([1, 2, 3])) + |> assert_equal(t(1)) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_min(axes: [:x]) + |> assert_equal(t([2, 1, 1])) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_min(axes: [:y]) + |> assert_equal(t([1, 1])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_min(axes: [:x, :z]) + # |> assert_equal(t([1, 3])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_min(axes: [:x, :z], keep_axes: true) + # |> assert_equal(t( + # [ + # [ + # [1], + # [3] + # ] + # ] + # )) + end + test "take_along_axis" do t([[1, 2, 3], [4, 5, 6]]) |> Nx.take_along_axis( From ec17e7dfce1c926b984679cc02b80272e81dcee5 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 13:09:37 -0300 Subject: [PATCH 176/185] cargo fmt --- candlex/native/candlex/src/tensors.rs | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index c64335df78..b84d5edc68 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -98,8 +98,17 @@ pub fn index_select(t: ExTensor, indexes: ExTensor, dim: usize) -> Result Result { - Ok(ExTensor::new(t.index_add(indexes.deref(), source.deref(), dim)?)) +pub fn index_add( + t: ExTensor, + indexes: ExTensor, + source: ExTensor, + dim: usize, +) -> Result { + Ok(ExTensor::new(t.index_add( + indexes.deref(), + source.deref(), + dim, + )?)) } #[rustler::nif(schedule = "DirtyCpu")] @@ -201,7 +210,11 @@ pub fn reduce_max( } #[rustler::nif(schedule = "DirtyCpu")] -pub fn reduce_min(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result { +pub fn reduce_min( + ex_tensor: ExTensor, + dim: usize, + keep_dim: bool, +) -> Result { let t = if keep_dim { ex_tensor.min_keepdim(dim)? 
} else { From 15888bab4e8188685d749ef68cbfcd726fe32b36 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 13:09:56 -0300 Subject: [PATCH 177/185] candlex quotient --- candlex/lib/candlex/backend.ex | 2 +- candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/tensors.rs | 1 + candlex/test/candlex_test.exs | 40 +++++++++++++++++++++++++++ 5 files changed, 44 insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index c6119bb6b0..672986f027 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -262,7 +262,7 @@ defmodule Candlex.Backend do end end - for op <- [:atan2, :pow, :remainder] do + for op <- [:atan2, :pow, :quotient, :remainder] do @impl true def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do {left, right} = maybe_upcast(left, right) diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 45ba716bea..d8b2668f99 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -85,6 +85,7 @@ defmodule Candlex.Native do :multiply, :not_equal, :pow, + :quotient, :remainder, :right_shift, :subtract diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 6fcef75cd9..13103fb9e8 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -30,6 +30,7 @@ rustler::init! { tensors::subtract, tensors::multiply, tensors::divide, + tensors::quotient, tensors::remainder, tensors::pow, tensors::max, diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index b84d5edc68..9eba062e8d 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -422,6 +422,7 @@ custom_unary_nif!(tan, Tan); binary_nif!(add, broadcast_add); binary_nif!(subtract, broadcast_sub); binary_nif!(multiply, broadcast_mul); +binary_nif!(quotient, broadcast_div); binary_nif!(max, broadcast_maximum); binary_nif!(min, broadcast_minimum); binary_nif!(equal, eq); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 84448b1393..cf1a6ae697 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -256,6 +256,46 @@ defmodule CandlexTest do |> assert_equal(t(9)) end + test "quotient" do + Nx.quotient(11, 2) + |> assert_equal(t(5)) + + t([2, 4, 5]) + |> Nx.quotient(2) + |> assert_equal(t([1, 2, 2])) + + 10 + |> Nx.quotient(t([1, 2, 3])) + |> assert_equal(t([10, 5, 3])) + + t([[10, 20]], names: [nil, :y]) + |> Nx.quotient(t([[1], [2]], names: [:x, nil])) + |> assert_equal(t( + [ + [10, 20], + [5, 10] + ] + )) + + t([[10, 20]]) + |> Nx.quotient(t([[1], [2]])) + |> assert_equal(t( + [ + [10, 20], + [5, 10] + ] + )) + + t([[10, 20]], type: :u8) + |> Nx.quotient(t([[1], [2]], type: :u32)) + |> assert_equal(t( + [ + [10, 20], + [5, 10] + ] + )) + end + test "atan2" do Nx.atan2(1.0, 2.0) |> assert_close(t(0.46364760398864746)) From b687f175f9c41ce80e85ae832c5ec32d4356715e Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 13:35:33 -0300 Subject: [PATCH 178/185] candlex sign --- candlex/lib/candlex/backend.ex | 1 + candlex/lib/candlex/native.ex | 1 + candlex/native/candlex/src/kernels/custom_unary.cu | 2 ++ candlex/native/candlex/src/lib.rs | 1 + candlex/native/candlex/src/ops.rs | 1 + candlex/native/candlex/src/tensors.rs | 3 ++- candlex/test/candlex_test.exs | 6 ++++++ 7 files changed, 14 
insertions(+), 1 deletion(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 672986f027..46d8b22ab4 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -336,6 +336,7 @@ defmodule Candlex.Backend do :round, :rsqrt, :sigmoid, + :sign, :sin, :sinh, :sqrt, diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index d8b2668f99..4ccc8d3e65 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -54,6 +54,7 @@ defmodule Candlex.Native do :round, :rsqrt, :sigmoid, + :sign, :sin, :sinh, :sqrt, diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu index bc5b737598..1b27dd43dd 100644 --- a/candlex/native/candlex/src/kernels/custom_unary.cu +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -98,6 +98,8 @@ CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x)) CUSTOM_UNARY_OP(double, ln_1p_f64, log1pg(x)) CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + expg(-x))) CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + expg(-x))) +CUSTOM_UNARY_OP(float, sign_f32, signbit(x)) +CUSTOM_UNARY_OP(double, sign_f64, signbit(x)) CUSTOM_UNARY_OP(float, sinh_f32, sinhg(x)) CUSTOM_UNARY_OP(double, sinh_f64, sinhg(x)) CUSTOM_UNARY_OP(float, tan_f32, tang(x)) diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs index 13103fb9e8..8d2ae71514 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -81,6 +81,7 @@ rustler::init! { tensors::cos, tensors::cosh, tensors::sigmoid, + tensors::sign, tensors::sin, tensors::sinh, tensors::erf, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 81dff468c4..36450c556a 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -415,6 +415,7 @@ custom_unary_op!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); custom_unary_op!(Expm1, "expm1", |v| v.exp_m1(), (BF16, F16, F32, F64)); custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64)); custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. 
+ (-v).exp()), (F32, F64)); +custom_unary_op!(Sign, "sign", |v| v.signum(), (I64, BF16, F16, F32, F64)); custom_unary_op!(Sinh, "sinh", |v| v.sinh(), (BF16, F16, F32, F64)); custom_unary_op!(Tan, "tan", |v| v.tan(), (BF16, F16, F32, F64)); custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 9eba062e8d..3ea396a3c5 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -3,7 +3,7 @@ use crate::error::CandlexError; use crate::ops::{ Acos, Acosh, Asin, Asinh, Atan, Atan2, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh, ErfInv, Erfc, Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Remainder, - Shl, Shr, Sigmoid, Sinh, Tan, + Shl, Shr, Sigmoid, Sign, Sinh, Tan, }; use candle_core::{DType, Device, Tensor}; use half::{bf16, f16}; @@ -416,6 +416,7 @@ custom_unary_nif!(is_infinity, IsInf); custom_unary_nif!(is_nan, IsNan); custom_unary_nif!(log1p, Log1p); custom_unary_nif!(sigmoid, Sigmoid); +custom_unary_nif!(sign, Sign); custom_unary_nif!(sinh, Sinh); custom_unary_nif!(tan, Tan); diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index cf1a6ae697..deb033257c 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -296,6 +296,12 @@ defmodule CandlexTest do )) end + test "sign" do + t([-2, -1, 0, 1, 2]) + |> Nx.sign() + |> assert_equal(t([-1, -1, 0, 1, 1])) + end + test "atan2" do Nx.atan2(1.0, 2.0) |> assert_close(t(0.46364760398864746)) From 1d47f1e6c4439a98b77709332352d787be552c05 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 14:37:21 -0300 Subject: [PATCH 179/185] fix remainder in cuda --- candlex/native/candlex/src/kernels/custom_binary.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/candlex/native/candlex/src/kernels/custom_binary.cu b/candlex/native/candlex/src/kernels/custom_binary.cu index 9de163cc43..42f0646275 100644 --- a/candlex/native/candlex/src/kernels/custom_binary.cu +++ b/candlex/native/candlex/src/kernels/custom_binary.cu @@ -10,10 +10,10 @@ DEVICE_FN_FLOAT_WRAPPER(atan2) DEVICE_FN_DOUBLE_WRAPPER(atan2) +DEVICE_FN_FLOAT_WRAPPER(fmod) +DEVICE_FN_DOUBLE_WRAPPER(fmod) DEVICE_FN_FLOAT_WRAPPER(pow) DEVICE_FN_DOUBLE_WRAPPER(pow) -DEVICE_FN_FLOAT_WRAPPER(remainder) -DEVICE_FN_DOUBLE_WRAPPER(remainder) #define CUSTOM_BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \ extern "C" __global__ void FN_NAME( \ @@ -92,8 +92,8 @@ CUSTOM_BINARY_OP(uint32_t, bit_xor_u32, x ^ y) CUSTOM_BINARY_OP(int64_t, bit_xor_i64, x ^ y) CUSTOM_BINARY_OP(float, pow_f32, powg(x, y)) CUSTOM_BINARY_OP(double, pow_f64, powg(x, y)) -CUSTOM_BINARY_OP(float, remainder_f32, remainderg(x, y)) -CUSTOM_BINARY_OP(double, remainder_f64, remainderg(x, y)) +CUSTOM_BINARY_OP(float, remainder_f32, fmodg(x, y)) +CUSTOM_BINARY_OP(double, remainder_f64, fmodg(x, y)) CUSTOM_BINARY_OP(uint32_t, shl_u32, x << y) CUSTOM_BINARY_OP(int64_t, shl_i64, x << y) CUSTOM_BINARY_OP(uint32_t, shr_u32, x >> y) From b4f8fe9e4fe5297ded2e069b00dfdce8faf77dea Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 14:38:28 -0300 Subject: [PATCH 180/185] cuda sign i64 --- candlex/native/candlex/src/kernels/custom_unary.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu index 1b27dd43dd..40bc3aea52 
100644 --- a/candlex/native/candlex/src/kernels/custom_unary.cu +++ b/candlex/native/candlex/src/kernels/custom_unary.cu @@ -98,6 +98,7 @@ CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x)) CUSTOM_UNARY_OP(double, ln_1p_f64, log1pg(x)) CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + expg(-x))) CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + expg(-x))) +CUSTOM_UNARY_OP(int64_t, sign_i64, x > 0 ? 1 : (x == 0 ? 0 : -1)) CUSTOM_UNARY_OP(float, sign_f32, signbit(x)) CUSTOM_UNARY_OP(double, sign_f64, signbit(x)) CUSTOM_UNARY_OP(float, sinh_f32, sinhg(x)) From 1222d86b88114514bf35b37a6184ce285090e37b Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 14:43:13 -0300 Subject: [PATCH 181/185] fix cuda remainder integer --- candlex/native/candlex/src/kernels/custom_binary.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/candlex/native/candlex/src/kernels/custom_binary.cu b/candlex/native/candlex/src/kernels/custom_binary.cu index 42f0646275..d9d9090d73 100644 --- a/candlex/native/candlex/src/kernels/custom_binary.cu +++ b/candlex/native/candlex/src/kernels/custom_binary.cu @@ -92,6 +92,8 @@ CUSTOM_BINARY_OP(uint32_t, bit_xor_u32, x ^ y) CUSTOM_BINARY_OP(int64_t, bit_xor_i64, x ^ y) CUSTOM_BINARY_OP(float, pow_f32, powg(x, y)) CUSTOM_BINARY_OP(double, pow_f64, powg(x, y)) +CUSTOM_BINARY_OP(uint8_t, remainder_u8, x % y) +CUSTOM_BINARY_OP(int64_t, remainder_i64, x % y) CUSTOM_BINARY_OP(float, remainder_f32, fmodg(x, y)) CUSTOM_BINARY_OP(double, remainder_f64, fmodg(x, y)) CUSTOM_BINARY_OP(uint32_t, shl_u32, x << y) From 45f8cf7be51809a2281bc67be79eca52554c6882 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:25:21 -0300 Subject: [PATCH 182/185] candlex put_slice, support longer start_indices --- candlex/lib/candlex/backend.ex | 20 +++++++++----------- candlex/test/candlex_test.exs | 24 ++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 46d8b22ab4..705bdd3500 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -384,23 +384,21 @@ defmodule Candlex.Backend do end @impl true - def put_slice(%T{} = out, %T{} = t, [start] = _start_indices, slice) do - t - |> from_nx() - |> Native.slice_scatter(from_nx(slice), 0, Nx.to_number(start)) - |> unwrap!() - |> to_nx(out) - end + def put_slice(%T{} = out, %T{} = t, [_ | _] = start_indices, slice) do + [last_start_index | leading_start_indices] = Enum.reverse(start_indices) - def put_slice(%T{} = out, %T{} = t, [start1, start2] = _start_indices, slice) do - if Nx.equal(start1, 0) do + if Enum.all?(leading_start_indices, fn i -> Nx.equal(i, 0) end) do t |> from_nx() - |> Native.slice_scatter(from_nx(slice), 1, Nx.to_number(start2)) + |> Native.slice_scatter( + from_nx(slice), + length(start_indices) - 1, + Nx.to_number(last_start_index) + ) |> unwrap!() |> to_nx(out) else - raise "unsupported" + raise "put_slice only supports last start index not to be 0 for now" end end diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index deb033257c..b96c005e1c 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1800,6 +1800,30 @@ defmodule CandlexTest do # [4, 9, 10] # ] # )) + + t([ + [ + [1, 2], + [3, 4] + ], + [ + [4, 5], + [6, 7] + ] + ]) + |> Nx.put_slice([0, 0, 1], t([[[8], [9]], [[10], [11]]])) + |> assert_equal( + t([ + [ + [1, 8], + [3, 9] + ], + [ + [4, 10], + [6, 11] + ] + 
]) + ) end test "pad" do From 94ac2c258abccebf46e2c0563ca72b26944b8608 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 16:46:24 -0300 Subject: [PATCH 183/185] add error impl for missing backend callbacks --- candlex/lib/candlex/backend.ex | 63 ++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 705bdd3500..9a04f9fc4f 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -812,6 +812,69 @@ defmodule Candlex.Backend do |> Stream.map(&to_nx(&1, out)) end + for op <- [ + :cholesky, + :conjugate, + :count_leading_zeros, + :imag, + :population_count, + :real + ] do + @impl true + def unquote(op)(_out, _tensor) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end + end + + for op <- [ + :any, + :argsort, + :eigh, + :fft, + :ifft, + :lu, + :product, + :qr, + :reverse, + :sort, + ] do + @impl true + def unquote(op)(_out, _tensor, _) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end + end + + for op <- [ + :indexed_put, + :map, + :triangular_solve, + :window_max, + :window_min, + :window_product, + :window_sum + ] do + @impl true + def unquote(op)(_out, _tensor, _, _) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end + end + + @impl true + def reduce(_out, _tensor, _, _, _) do + raise "unsupported Candlex.Backend.reduce function" + end + + for op <- [ + :window_reduce, + :window_scatter_max, + :window_scatter_min + ] do + @impl true + def unquote(op)(_out, _tensor, _, _, _, _) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end + end + defp permute(native_tensor, permutation) do native_tensor |> Native.permute(permutation) From e00b9f3ca56324ac9bec1cb5eb0c574ecb667c39 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 6 Oct 2023 16:47:46 -0300 Subject: [PATCH 184/185] fix warning --- candlex/test/candlex_test.exs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index b96c005e1c..99db38c52e 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -2176,7 +2176,7 @@ defmodule CandlexTest do Nx.tensor(values, opts) end - defp check(value, opts \\ []) do + defp check(value, opts) do tensor = t(value, opts) tensor From 405a08182c793ac78f65bc00f0f2559222b04d86 Mon Sep 17 00:00:00 2001 From: Gonzalo <456459+grzuy@users.noreply.github.com> Date: Fri, 3 Nov 2023 11:20:59 -0300 Subject: [PATCH 185/185] update with upstream --- candlex/CHANGELOG.md | 14 + candlex/LICENSE | 177 +++++++ candlex/README.md | 72 ++- candlex/lib/candlex/backend.ex | 108 ++-- candlex/lib/candlex/native.ex | 24 +- candlex/mix.exs | 52 +- candlex/mix.lock | 11 +- candlex/native/candlex/.cargo/config.toml | 4 + candlex/native/candlex/Cargo.lock | 378 +++++++++----- candlex/native/candlex/Cargo.toml | 6 +- candlex/native/candlex/src/kernels.rs | 4 +- candlex/native/candlex/src/lib.rs | 4 + candlex/native/candlex/src/ops.rs | 4 +- candlex/native/candlex/src/tensors.rs | 92 +++- candlex/test/candlex_test.exs | 605 +++++++++++++--------- candlex/test/support/nx_case.ex | 45 ++ 16 files changed, 1145 insertions(+), 455 deletions(-) create mode 100644 candlex/CHANGELOG.md create mode 100644 candlex/LICENSE create mode 100644 candlex/test/support/nx_case.ex diff --git a/candlex/CHANGELOG.md b/candlex/CHANGELOG.md new file mode 100644 index 
0000000000..25203fcdbf --- /dev/null +++ b/candlex/CHANGELOG.md @@ -0,0 +1,14 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.1.2] - 2023-10-30 + +### Added + +- Precompiled binaries for a few CPU targets. diff --git a/candlex/LICENSE b/candlex/LICENSE new file mode 100644 index 0000000000..f433b1a53f --- /dev/null +++ b/candlex/LICENSE @@ -0,0 +1,177 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS diff --git a/candlex/README.md b/candlex/README.md index 83509bb3cb..4805cf5baf 100644 --- a/candlex/README.md +++ b/candlex/README.md @@ -1,6 +1,10 @@ # Candlex -**TODO: Add description** +[![ci](https://github.com/mimiquate/candlex/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/mimiquate/candlex/actions?query=branch%3Amain) +[![Hex.pm](https://img.shields.io/hexpm/v/candlex.svg)](https://hex.pm/packages/candlex) +[![Docs](https://img.shields.io/badge/docs-gray.svg)](https://hexdocs.pm/candlex) + +An `Nx` [backend](https://hexdocs.pm/nx/Nx.html#module-backends) for [candle](https://huggingface.github.io/candle) machine learning minimalist framework ## Installation @@ -10,7 +14,7 @@ by adding `candlex` to your list of dependencies in `mix.exs`: ```elixir def deps do [ - {:candlex, "~> 0.1.0"} + {:candlex, "~> 0.1.2"} ] end ``` @@ -19,3 +23,67 @@ Documentation can be generated with [ExDoc](https://github.com/elixir-lang/ex_do and published on [HexDocs](https://hexdocs.pm). Once published, the docs can be found at . +## Usage + +Just configure Nx to default to Candlex backend in your configuration: + +```elixir +# Possibly config/runtime.exs + +config :nx, default_backend: Candlex.Backend +``` + +or in your scripts, precede all your Nx operations with: + +```elixir +Nx.default_backend(Candlex.Backend) +``` + +More details in [Nx backends](https://hexdocs.pm/nx/Nx.html#module-backends) + +#### `CANDLEX_NIF_BUILD` + +Defaults to `false`. If `true` the native binary is built locally, which may be useful +if no precompiled binary is available for your target environment. Once set, you +must run `mix deps.clean candlex --build` explicitly to force to recompile. +Building has a number of dependencies, see *Building from source* below. + +## Building from source + +To build the native binary locally you need to set `CANDLEX_NIF_BUILD=true`. +Keep in mind that the compilation usually takes time. + +You will need the following installed in your system for the compilation: + + * [Git](https://git-scm.com) for fetching candle-core source + * [Rust](https://www.rust-lang.org) with cargo to compile rustler NIFs + +## Releasing + +To publish a new version of this package: + +1. Update `@version` in `mix.exs` and `project-version` in `.github/workflows/binaries.yml`. +1. `git tag -s ` to create new signed tag. +1. `git push origin ` to push the tag. +1. Wait for the `binaries.yml` GitHub workflow to build all the NIF binaries. +1. `mix rustler_precompiled.download Candlex.Native --all --print` to generate binaries checksums locally. +1. `rm -r native/candlex/target` to leave out rust crate build artifacts from published elixir package. +1. `mix hex.build --unpack` to check the package includes the correct files. +1. Publish the release from draft in GitHub. +1. `mix hex.publish` to publish package to Hex.pm. + +## License + +Copyright 2023 Mimiquate + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
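Taken together, the patches above extend `Candlex.Backend` with `Nx.quotient/2`, `Nx.sign/1`, `put_slice` with a non-zero last start index, vector `dot`, and axis-aware `all`/`any`. The sketch below is illustrative only and is not part of the patch series; it assumes a candlex build that includes these patches and uses the backend selection shown in the README diff above. Expected values are taken from the test cases added in this series.

```elixir
# Illustrative sketch, not part of the patches. Assumes candlex compiled from
# this series is available and set as the default Nx backend.
Nx.default_backend(Candlex.Backend)

Nx.quotient(Nx.tensor([2, 4, 5]), 2)
# integer division, added in patch 177 => [1, 2, 2]

Nx.sign(Nx.tensor([-2, -1, 0, 1, 2]))
# added in patch 178 => [-1, -1, 0, 1, 1]

Nx.dot(Nx.tensor([1, 2, 3]), Nx.tensor([4, 5, 6]))
# vector dot product, added via the new dot NIF in patch 185 => 32

Nx.put_slice(Nx.tensor([[1, 2], [3, 4]]), [0, 1], Nx.tensor([[8], [9]]))
# patch 182: leading start indices must be 0, only the last may be non-zero
# => [[1, 8], [3, 9]]
```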
diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex index 9a04f9fc4f..4e68eba143 100644 --- a/candlex/lib/candlex/backend.ex +++ b/candlex/lib/candlex/backend.ex @@ -118,9 +118,31 @@ defmodule Candlex.Backend do # Aggregates @impl true - def all(%T{} = out, %T{} = tensor, _opts) do - from_nx(tensor) - |> Native.all() + def all(%T{} = out, %T{} = tensor, opts) do + case opts[:axes] do + nil -> + from_nx(tensor) + |> Native.all() + + axes -> + from_nx(tensor) + |> Native.all_within_dims(axes, opts[:keep_axes]) + end + |> unwrap!() + |> to_nx(out) + end + + @impl true + def any(%T{} = out, %T{} = tensor, opts) do + case opts[:axes] do + nil -> + from_nx(tensor) + |> Native.any() + + axes -> + from_nx(tensor) + |> Native.any_within_dims(axes, opts[:keep_axes]) + end |> unwrap!() |> to_nx(out) end @@ -506,6 +528,25 @@ defmodule Candlex.Backend do end @impl true + def dot( + %T{type: _out_type} = out, + %T{shape: left_shape, type: _left_type} = left, + [left_axis] = _left_axes, + [] = _left_batched_axes, + %T{shape: right_shape, type: _right_type} = right, + [0] = _right_axes, + [] = _right_batched_axes + ) + when tuple_size(left_shape) >= 1 and tuple_size(right_shape) == 1 and + left_axis == tuple_size(left_shape) - 1 do + {left, right} = maybe_upcast(left, right) + + from_nx(left) + |> Native.dot(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + def dot( %T{type: _out_type} = out, %T{shape: left_shape, type: _left_type} = left, @@ -516,6 +557,8 @@ defmodule Candlex.Backend do [] = _right_batched_axes ) when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do + {left, right} = maybe_upcast(left, right) + Native.matmul( from_nx(left), from_nx(right) @@ -813,13 +856,13 @@ defmodule Candlex.Backend do end for op <- [ - :cholesky, - :conjugate, - :count_leading_zeros, - :imag, - :population_count, - :real - ] do + :cholesky, + :conjugate, + :count_leading_zeros, + :imag, + :population_count, + :real + ] do @impl true def unquote(op)(_out, _tensor) do raise "unsupported Candlex.Backend.#{unquote(op)} function" @@ -827,17 +870,16 @@ defmodule Candlex.Backend do end for op <- [ - :any, - :argsort, - :eigh, - :fft, - :ifft, - :lu, - :product, - :qr, - :reverse, - :sort, - ] do + :argsort, + :eigh, + :fft, + :ifft, + :lu, + :product, + :qr, + :reverse, + :sort + ] do @impl true def unquote(op)(_out, _tensor, _) do raise "unsupported Candlex.Backend.#{unquote(op)} function" @@ -845,14 +887,14 @@ defmodule Candlex.Backend do end for op <- [ - :indexed_put, - :map, - :triangular_solve, - :window_max, - :window_min, - :window_product, - :window_sum - ] do + :indexed_put, + :map, + :triangular_solve, + :window_max, + :window_min, + :window_product, + :window_sum + ] do @impl true def unquote(op)(_out, _tensor, _, _) do raise "unsupported Candlex.Backend.#{unquote(op)} function" @@ -865,10 +907,10 @@ defmodule Candlex.Backend do end for op <- [ - :window_reduce, - :window_scatter_max, - :window_scatter_min - ] do + :window_reduce, + :window_scatter_max, + :window_scatter_min + ] do @impl true def unquote(op)(_out, _tensor, _, _, _, _) do raise "unsupported Candlex.Backend.#{unquote(op)} function" diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex index 4ccc8d3e65..902bc89b99 100644 --- a/candlex/lib/candlex/native.ex +++ b/candlex/lib/candlex/native.ex @@ -1,12 +1,33 @@ defmodule Candlex.Native do @moduledoc false - use Rustler, otp_app: :candlex, features: Application.compile_env(:candlex, :crate_features, []) + mix_config = 
Mix.Project.config() + version = mix_config[:version] + source_url = mix_config[:package][:links]["GitHub"] + mode = if Mix.env() in [:dev, :test], do: :debug, else: :release + + use RustlerPrecompiled, + otp_app: :candlex, + features: Application.compile_env(:candlex, :crate_features, []), + base_url: "#{source_url}/releases/download/v#{version}", + force_build: System.get_env("CANDLEX_NIF_BUILD") in ["1", "true"], + mode: mode, + version: version, + nif_versions: ["2.16"], + targets: [ + "aarch64-apple-darwin", + "aarch64-unknown-linux-gnu", + "x86_64-apple-darwin", + "x86_64-unknown-linux-gnu" + ] # Rustler will override all the below stub functions with real NIFs def from_binary(_binary, _dtype, _shape, _device), do: error() def to_binary(_tensor), do: error() def all(_tensor), do: error() + def all_within_dims(_tensor, _dims, _keep_dims), do: error() + def any(_tensor), do: error() + def any_within_dims(_tensor, _dims, _keep_dims), do: error() def where_cond(_tensor, _on_true, _on_false), do: error() def narrow(_tensor, _dim, _start, _length), do: error() def gather(_tensor, _indexes, _dim), do: error() @@ -71,6 +92,7 @@ defmodule Candlex.Native do :bitwise_or, :bitwise_xor, :divide, + :dot, :equal, :greater, :greater_equal, diff --git a/candlex/mix.exs b/candlex/mix.exs index ffd08acef8..89774a2cd6 100644 --- a/candlex/mix.exs +++ b/candlex/mix.exs @@ -1,14 +1,21 @@ defmodule Candlex.MixProject do use Mix.Project + @description "An Nx backend for candle machine learning minimalist framework" + @source_url "https://github.com/mimiquate/candlex" + @version "0.1.2" + def project do [ app: :candlex, - version: "0.1.0", - elixir: "~> 1.15", + description: @description, + version: @version, + elixir: "~> 1.14", elixirc_paths: elixirc_paths(Mix.env()), start_permanent: Mix.env() == :prod, - deps: deps() + deps: deps(), + docs: docs(), + package: package() ] end @@ -25,8 +32,43 @@ defmodule Candlex.MixProject do # Run "mix help deps" to learn about dependencies. 
defp deps do [ - {:nx, path: "../nx"}, - {:rustler, "~> 0.29.1"} + # {:nx, "~> 0.6.2"}, + {:nx, git: "https://github.com/elixir-nx/nx", sparse: "nx"}, + {:rustler_precompiled, "~> 0.7.0"}, + + # Optional + {:rustler, "~> 0.29", optional: true}, + + # Dev + {:ex_doc, "~> 0.30.9", only: :dev, runtime: false} + ] + end + + defp docs do + [ + main: "Candlex", + source_url: @source_url, + source_ref: "v#{@version}" + ] + end + + defp package do + [ + files: [ + "lib", + "native", + "priv", + ".formatter.exs", + "mix.exs", + "CHANGELOG.md", + "README.md", + "LICENSE", + "checksum-*.exs" + ], + licenses: ["Apache-2.0"], + links: %{ + "GitHub" => @source_url + } ] end end diff --git a/candlex/mix.lock b/candlex/mix.lock index dd8959ea38..37478055ac 100644 --- a/candlex/mix.lock +++ b/candlex/mix.lock @@ -1,7 +1,16 @@ %{ + "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.37", "2ad73550e27c8946648b06905a57e4d454e4d7229c2dafa72a0348c99d8be5f7", [:mix], [], "hexpm", "6b19783f2802f039806f375610faa22da130b8edc21209d0bff47918bb48360e"}, + "ex_doc": {:hex, :ex_doc, "0.30.9", "d691453495c47434c0f2052b08dd91cc32bc4e1a218f86884563448ee2502dd2", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "d7aaaf21e95dc5cddabf89063327e96867d00013963eadf2c6ad135506a8bc10"}, "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, - "rustler": {:hex, :rustler, "0.29.1", "880f20ae3027bd7945def6cea767f5257bc926f33ff50c0d5d5a5315883c084d", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "109497d701861bfcd26eb8f5801fe327a8eef304f56a5b63ef61151ff44ac9b6"}, + "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, + "makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, + "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, + "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", 
"2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, + "nx": {:git, "https://github.com/elixir-nx/nx", "7706e8601e40916c02f8773df7802b3bfab43054", [sparse: "nx"]}, + "rustler": {:hex, :rustler, "0.30.0", "cefc49922132b072853fa9b0ca4dc2ffcb452f68fb73b779042b02d545e097fb", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "9ef1abb6a7dda35c47cfc649e6a5a61663af6cf842a55814a554a84607dee389"}, + "rustler_precompiled": {:hex, :rustler_precompiled, "0.7.0", "5d0834fc06dbc76dd1034482f17b1797df0dba9b491cef8bb045fcaca94bcade", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "fdf43a6835f4e4de5bfbc4c019bfb8c46d124bd4635fefa3e20d9a2bbbec1512"}, "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, } diff --git a/candlex/native/candlex/.cargo/config.toml b/candlex/native/candlex/.cargo/config.toml index 20f03f3d80..1602afb085 100644 --- a/candlex/native/candlex/.cargo/config.toml +++ b/candlex/native/candlex/.cargo/config.toml @@ -3,3 +3,7 @@ rustflags = [ "-C", "link-arg=-undefined", "-C", "link-arg=dynamic_lookup", ] + +# Provides a small build size, but takes more time to build. +[profile.release] +lto = true diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock index 78a18543a6..bc1f3aa8a7 100644 --- a/candlex/native/candlex/Cargo.lock +++ b/candlex/native/candlex/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + [[package]] name = "aho-corasick" version = "1.1.1" @@ -16,6 +31,9 @@ name = "anyhow" version = "1.0.75" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +dependencies = [ + "backtrace", +] [[package]] name = "approx" @@ -32,6 +50,21 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -43,6 +76,20 @@ name = "bytemuck" version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] [[package]] name = "byteorder" @@ -53,12 +100,12 @@ checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" [[package]] name = "candle-core" version = "0.3.0" -source = "git+https://github.com/huggingface/candle#8f7973958c55324a24f0c514e7ac6ded6681980f" +source = "git+https://github.com/huggingface/candle#4c967b9184834cd1e166dfdd6d88450d16bad8f2" dependencies = [ "byteorder", - "candle-gemm", "candle-kernels", "cudarc", + "gemm", "half", "memmap2", "num-traits", @@ -72,136 +119,12 @@ dependencies = [ "zip", ] -[[package]] -name = "candle-gemm" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef9b07a4b0ba1a304b44432006580980ddff9748c201261c279437e7b11bba68" -dependencies = [ - "candle-gemm-c32", - "candle-gemm-c64", - "candle-gemm-common", - "candle-gemm-f16", - "candle-gemm-f32", - "candle-gemm-f64", - "dyn-stack", - "lazy_static", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "candle-gemm-c32" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f595241dad99811de285e029889f57c29dd98e33de7a8a6b881867b1488d7d4a" -dependencies = [ - "candle-gemm-common", - "dyn-stack", - "lazy_static", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "candle-gemm-c64" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "648f22fd8f5a4f330e29d791845b514966421308a6a2b5fedb949ee07e54c77f" -dependencies = [ - "candle-gemm-common", - "dyn-stack", - "lazy_static", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - 
-[[package]] -name = "candle-gemm-common" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e03c01b4ca3b9d71e4eb89e42946a08f8b0d2f1b861f7fa2ea0966233f1e0b08" -dependencies = [ - "dyn-stack", - "lazy_static", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "candle-gemm-f16" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97f8af2a482131713d28a337abff6debf26c529afa1837caf2ba190909b2107c" -dependencies = [ - "candle-gemm-common", - "candle-gemm-f32", - "dyn-stack", - "half", - "lazy_static", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "candle-gemm-f32" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "938927961e2f0c0a6064fcf3524ea3f7f455fe5708419532a6fea9aea1ab45ae" -dependencies = [ - "candle-gemm-common", - "dyn-stack", - "lazy_static", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - -[[package]] -name = "candle-gemm-f64" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d192d7126e59b81ef4cf13cd9f194e6dbdc09171f65d0074d059dc009ac06775" -dependencies = [ - "candle-gemm-common", - "dyn-stack", - "lazy_static", - "num-complex", - "num-traits", - "paste", - "raw-cpuid", - "rayon", - "seq-macro", -] - [[package]] name = "candle-kernels" version = "0.3.0" -source = "git+https://github.com/huggingface/candle#8f7973958c55324a24f0c514e7ac6ded6681980f" +source = "git+https://github.com/huggingface/candle#4c967b9184834cd1e166dfdd6d88450d16bad8f2" dependencies = [ + "anyhow", "glob", "rayon", ] @@ -219,6 +142,15 @@ dependencies = [ "thiserror", ] +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -284,9 +216,9 @@ dependencies = [ [[package]] name = "dyn-stack" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fe7f8d7bcc523381d3c437b82cf74805de3931de0da69309ae0fe1bdf7a256e" +checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" dependencies = [ "bytemuck", "reborrow", @@ -298,6 +230,123 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +[[package]] +name = "gemm" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c90e866771c2adff6c4603108bb531ea36979b62a23b5e193b9311199182ccc9" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5ffc8e9a442172ec223d4e536f93b59f1d2b34164b37e462e6886d76699e71" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"45cd8172c8ee8beed35cdb7254bec535ddc83cde78d2eeefeae43cfd57fef3d8" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c022dd4893afa781d1ba9912796af63a9a91ac8d8fdadd449483bbabaf7c47c6" +dependencies = [ + "bytemuck", + "dyn-stack", + "half", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f16" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efb4c26de6ddd1125d966dffa48eb609448566a2bd2dd589ea2798f17f689b5" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a78fca5a0ef41202015e099e45e0bdda83e7f49318de3322d58c215a3fec334a" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f18516a29ab14fa4b1714c93ffa36a89a65418f0a2fba702206d4b421061b48" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + [[package]] name = "getrandom" version = "0.2.10" @@ -309,6 +358,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "gimli" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" + [[package]] name = "glob" version = "0.3.1" @@ -321,6 +376,7 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" dependencies = [ + "bytemuck", "cfg-if", "crunchy", "num-traits", @@ -399,6 +455,15 @@ dependencies = [ "autocfg", ] +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + [[package]] name = "nalgebra" version = "0.29.0" @@ -434,6 +499,7 @@ version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" dependencies = [ + "bytemuck", "num-traits", ] @@ -460,9 +526,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.16" +version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f30b0abd723be7e2ffca1272140fac1a2f084c77ec3e123c192b66af1ee9e6c2" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", "libm", @@ -478,6 +544,21 @@ dependencies = [ "libc", ] +[[package]] +name = "object" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + [[package]] name = "paste" version = "1.0.14" @@ -499,6 +580,17 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "pulp" +version = "0.16.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcb0a2934e77dd9aabb33fedcd3bf6ea1d6d3cacd4c5bc07b9a916f7b101c6b7" +dependencies = [ + "bytemuck", + "libm", + "num-complex", +] + [[package]] name = "quote" version = "1.0.33" @@ -618,11 +710,17 @@ version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + [[package]] name = "rustler" -version = "0.29.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0884cb623b9f43d3e2c51f9071c5e96a5acf3e6e6007866812884ff0cb983f1e" +checksum = "c4b4fea69e23de68c42c06769d6624d2d018da550c17244dd4b691f90ced4a7e" dependencies = [ "lazy_static", "rustler_codegen", @@ -631,9 +729,9 @@ dependencies = [ [[package]] name = "rustler_codegen" -version = "0.29.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50e277af754f2560cf4c4ebedb68c1a735292fb354505c6133e47ec406e699cf" +checksum = "406061bd07aaf052c344257afed4988c5ec8efe4d2352b4c2cf27ea7c8575b12" dependencies = [ "heck", "proc-macro2", @@ -787,18 +885,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.49" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.49" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml index 79f7cc2f06..306760b668 100644 --- a/candlex/native/candlex/Cargo.toml +++ b/candlex/native/candlex/Cargo.toml @@ -12,10 +12,10 @@ crate-type = ["cdylib"] [dependencies] candle-core = { git = "https://github.com/huggingface/candle" } half = "2.3.1" -num-traits = "0.2.16" -rustler = "0.29.1" +num-traits = "0.2.17" +rustler = { version = "0.30.0", default-features = false, features = ["derive", "nif_version_2_16"] } statrs = "0.16.0" -thiserror = "1.0.47" +thiserror = "1.0.50" [build-dependencies] anyhow = "1.0.75" diff --git a/candlex/native/candlex/src/kernels.rs b/candlex/native/candlex/src/kernels.rs index c6627efde7..13317b35c7 100644 --- a/candlex/native/candlex/src/kernels.rs +++ b/candlex/native/candlex/src/kernels.rs @@ -1,4 +1,4 @@ #[rustfmt::skip] -pub const CUSTOM_UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_unary.ptx")); -#[rustfmt::skip] pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx")); +#[rustfmt::skip] +pub const CUSTOM_UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_unary.ptx")); diff --git a/candlex/native/candlex/src/lib.rs 
b/candlex/native/candlex/src/lib.rs index 8d2ae71514..ef72f434e8 100644 --- a/candlex/native/candlex/src/lib.rs +++ b/candlex/native/candlex/src/lib.rs @@ -42,6 +42,9 @@ rustler::init! { tensors::less, tensors::less_equal, tensors::all, + tensors::all_within_dims, + tensors::any, + tensors::any_within_dims, tensors::sum, tensors::dtype, tensors::t_shape, @@ -68,6 +71,7 @@ rustler::init! { tensors::permute, tensors::slice_scatter, tensors::pad_with_zeros, + tensors::dot, tensors::matmul, tensors::abs, tensors::acos, diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs index 36450c556a..98bc300f46 100644 --- a/candlex/native/candlex/src/ops.rs +++ b/candlex/native/candlex/src/ops.rs @@ -410,8 +410,8 @@ custom_unary_op!(Atanh, "atanh", |v| v.atanh(), (BF16, F16, F32, F64)); custom_unary_op!(BitNot, "bit_not", |v| !v, (U8, U32, I64)); custom_unary_op!(Cbrt, "cbrt", |v| v.cbrt(), (BF16, F16, F32, F64)); custom_unary_op!(Cosh, "cosh", |v| v.cosh(), (BF16, F16, F32, F64)); -custom_unary_op!(Erfc, "erfc", |v| erfc(v), (BF16, F16, F32, F64)); -custom_unary_op!(ErfInv, "erf_inv", |v| erf_inv(v), (BF16, F16, F32, F64)); +custom_unary_op!(Erfc, "erfc", erfc, (BF16, F16, F32, F64)); +custom_unary_op!(ErfInv, "erf_inv", erf_inv, (BF16, F16, F32, F64)); custom_unary_op!(Expm1, "expm1", |v| v.exp_m1(), (BF16, F16, F32, F64)); custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64)); custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F64)); diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs index 3ea396a3c5..78c94b413f 100644 --- a/candlex/native/candlex/src/tensors.rs +++ b/candlex/native/candlex/src/tensors.rs @@ -115,7 +115,7 @@ pub fn index_add( pub fn chunk(t: ExTensor, num_chunks: usize) -> Result, CandlexError> { Ok(t.chunk(num_chunks, 0)? .into_iter() - .map(|t| ExTensor::new(t)) + .map(ExTensor::new) .collect()) } @@ -154,22 +154,38 @@ pub fn arange( #[rustler::nif(schedule = "DirtyCpu")] pub fn all(ex_tensor: ExTensor) -> Result { - let device = ex_tensor.device(); - let t = ex_tensor.flatten_all()?; - let dims = t.shape().dims(); - let on_true = Tensor::ones(dims, DType::U8, device)?; - let on_false = Tensor::zeros(dims, DType::U8, device)?; - - let bool_scalar = match t - .where_cond(&on_true, &on_false)? - .min(0)? - .to_scalar::()? - { - 0 => 0u8, - _ => 1u8, - }; + Ok(ExTensor::new(_all( + &ex_tensor.flatten_all()?, + vec![0], + false, + )?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn all_within_dims( + ex_tensor: ExTensor, + dims: Vec, + keep_dims: bool, +) -> Result { + Ok(ExTensor::new(_all(ex_tensor.deref(), dims, keep_dims)?)) +} + +#[rustler::nif(schedule = "DirtyCpu")] +pub fn any(ex_tensor: ExTensor) -> Result { + Ok(ExTensor::new(_any( + &ex_tensor.flatten_all()?, + vec![0], + false, + )?)) +} - Ok(ExTensor::new(Tensor::new(bool_scalar, device)?)) +#[rustler::nif(schedule = "DirtyCpu")] +pub fn any_within_dims( + ex_tensor: ExTensor, + dims: Vec, + keep_dims: bool, +) -> Result { + Ok(ExTensor::new(_any(ex_tensor.deref(), dims, keep_dims)?)) } #[rustler::nif(schedule = "DirtyCpu")] @@ -346,6 +362,14 @@ pub fn divide(left: ExTensor, right: ExTensor) -> Result )) } +#[rustler::nif(schedule = "DirtyCpu")] +pub fn dot(left: ExTensor, right: ExTensor) -> Result { + Ok(ExTensor::new( + left.mul(&right.broadcast_as(left.shape())?)? + .sum(left.rank() - 1)?, + )) +} + macro_rules! 
unary_nif { ($nif_name:ident, $native_fn_name:ident) => { #[rustler::nif(schedule = "DirtyCpu")] @@ -446,11 +470,43 @@ custom_binary_nif!(pow, Pow); custom_binary_nif!(right_shift, Shr); custom_binary_nif!(remainder, Remainder); +fn _any(tensor: &Tensor, dims: Vec, keep_dims: bool) -> Result { + let comparison = tensor.ne(&tensor.zeros_like()?)?; + + let result = if keep_dims { + dims.iter() + .rev() + .fold(comparison, |t, dim| t.max_keepdim(*dim).unwrap()) + } else { + dims.iter() + .rev() + .fold(comparison, |t, dim| t.max(*dim).unwrap()) + }; + + Ok(result) +} + +fn _all(tensor: &Tensor, dims: Vec, keep_dims: bool) -> Result { + let comparison = tensor.ne(&tensor.zeros_like()?)?; + + let result = if keep_dims { + dims.iter() + .rev() + .fold(comparison, |t, dim| t.min_keepdim(*dim).unwrap()) + } else { + dims.iter() + .rev() + .fold(comparison, |t, dim| t.min(*dim).unwrap()) + }; + + Ok(result) +} + fn tuple_to_vec(term: Term) -> Result, rustler::Error> { - Ok(rustler::types::tuple::get_tuple(term)? + rustler::types::tuple::get_tuple(term)? .iter() .map(|elem| elem.decode()) - .collect::>()?) + .collect::>() } fn vec_to_tuple(env: Env, vec: Vec) -> Result { diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs index 99db38c52e..58991433cd 100644 --- a/candlex/test/candlex_test.exs +++ b/candlex/test/candlex_test.exs @@ -1,5 +1,5 @@ defmodule CandlexTest do - use Candlex.Case, async: true + use Nx.Case, async: true doctest Candlex describe "creation" do @@ -83,26 +83,26 @@ defmodule CandlexTest do |> assert_equal(t([[0, 1, 2], [3, 4, 5]])) Nx.iota({3, 3}, axis: 1) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0, 1, 2], [0, 1, 2], [0, 1, 2] - ] - )) + ]) + ) Nx.iota({3, 3}, axis: -1) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0, 1, 2], [0, 1, 2], [0, 1, 2] - ] - )) + ]) + ) Nx.iota({3, 4, 3}, axis: 0, type: :f64) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], @@ -121,19 +121,19 @@ defmodule CandlexTest do [2.0, 2.0, 2.0], [2.0, 2.0, 2.0] ] - ] - )) + ]) + ) Nx.iota({1, 3, 2}, axis: 2) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0, 1], [0, 1], [0, 1] ] - ] - )) + ]) + ) end test "max" do @@ -189,12 +189,12 @@ defmodule CandlexTest do t([[1.0], [2]]) |> Nx.divide(t([[10, 20]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0.10000000149011612, 0.05000000074505806], [0.20000000298023224, 0.10000000149011612] - ] - )) + ]) + ) 1 |> Nx.divide(2) @@ -206,12 +206,12 @@ defmodule CandlexTest do t([[1], [2]]) |> Nx.divide(t([[10, 20]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0.10000000149011612, 0.05000000074505806], [0.20000000298023224, 0.10000000149011612] - ] - )) + ]) + ) end test "remainder" do @@ -228,12 +228,12 @@ defmodule CandlexTest do t([[10], [20]], names: [:x, :y]) |> Nx.remainder(t([[3, 4]], names: [nil, :y])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 2], [2, 0] - ] - )) + ]) + ) left = t(-11) right = t(10, type: :u8) @@ -247,6 +247,7 @@ defmodule CandlexTest do |> assert_equal(t(9)) positive_left = t(9, type: :u8) + Nx.remainder(positive_left, right) |> assert_equal(t(9)) @@ -270,30 +271,30 @@ defmodule CandlexTest do t([[10, 20]], names: [nil, :y]) |> Nx.quotient(t([[1], [2]], names: [:x, nil])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [10, 20], [5, 10] - ] - )) + ]) + ) t([[10, 20]]) |> Nx.quotient(t([[1], [2]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [10, 20], [5, 10] - ] - )) + ]) + ) t([[10, 20]], type: :u8) |> Nx.quotient(t([[1], [2]], type: 
:u32)) - |> assert_equal(t( - [ + |> assert_equal( + t([ [10, 20], [5, 10] - ] - )) + ]) + ) end test "sign" do @@ -316,12 +317,12 @@ defmodule CandlexTest do t([[-0.0], [0.0]], type: :f64) |> Nx.atan2(t([-0.0, 0.0], type: :f64)) - |> assert_close(t( - [ + |> assert_close( + t([ [-3.141592653589793, -0.0], [3.141592653589793, 0.0] - ] - )) + ]) + ) end test "broadcast" do @@ -390,12 +391,12 @@ defmodule CandlexTest do t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.greater(t([1, 2, 3])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0, 0, 0], [1, 1, 1] - ] - )) + ]) + ) end test "less" do @@ -510,14 +511,13 @@ defmodule CandlexTest do # Dot product of vectors - # TODO: - # t([1, 2, 3]) - # |> Nx.dot(t([4, 5, 6])) - # |> assert_equal(t(32)) + t([1, 2, 3]) + |> Nx.dot(t([4, 5, 6])) + |> assert_equal(t(32)) - # t([1.0, 2.0, 3.0]) - # |> Nx.dot(t([1, 2, 3])) - # |> assert_equal(t(14.0)) + t([1.0, 2, 3]) + |> Nx.dot(t([1, 2, 3])) + |> assert_equal(t(14.0)) # Dot product of matrices (2-D tensors) @@ -532,24 +532,28 @@ defmodule CandlexTest do # )) t([[1.0, 2, 3], [4, 5, 6]]) - |> Nx.dot(t([[7.0, 8], [9, 10], [11, 12]])) - |> assert_equal(t( - [ + |> Nx.dot(t([[7, 8], [9, 10], [11, 12]])) + |> assert_equal( + t([ [58.0, 64], [139, 154] - ] - )) + ]) + ) # Dot product of vector and n-D tensor - # t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]], names: [:i, :j, :k]) - # |> Nx.dot(t([5.0, 10], names: [:x])) - # |> assert_equal(t( - # [ - # [25, 55], - # [85, 115] - # ] - # )) + t([[0.0]]) + |> Nx.dot(t([55.0])) + |> assert_equal(t([0.0])) + + t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]]) + |> Nx.dot(t([5, 10])) + |> assert_equal( + t([ + [25.0, 55], + [85, 115] + ]) + ) # t([5.0, 10], names: [:x]) # |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j])) @@ -610,31 +614,31 @@ defmodule CandlexTest do t1 |> Nx.dot([0], [], t2, [0], []) - |> assert_equal(t( - [ + |> assert_equal( + t([ [100, 140], [140, 200] - ] - )) + ]) + ) # TODO: t1 |> Nx.dot([0], [], t2, [1], []) - |> assert_equal(t( - [ + |> assert_equal( + t([ [70, 150], [100, 220] - ] - )) + ]) + ) t1 |> Nx.dot([1], [], t2, [0], []) - |> assert_equal(t( - [ + |> assert_equal( + t([ [70, 100], [150, 220] - ] - )) + ]) + ) # t1 # |> Nx.dot([1], [], t2, [1], []) @@ -739,6 +743,10 @@ defmodule CandlexTest do t([-2.0, -1, 0, 1, 2]) |> Nx.abs() |> assert_equal(t([2, 1, 0, 1, 2])) + + t([-2, -1, 0, 1, 2]) + |> Nx.abs() + |> assert_equal(t([2, 1, 0, 1, 2])) end test "sqrt" do @@ -774,34 +782,34 @@ defmodule CandlexTest do t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) |> Nx.argmax(axis: 0) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 0, 0], [1, 1, 0] - ] - )) + ]) + ) t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) |> Nx.argmax(axis: :z) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0, 2], [0, 1] - ] - )) + ]) + ) t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) |> Nx.argmax(axis: :y, keep_axis: true) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0, 0, 0] ], [ [0, 1, 0] ] - ] - )) + ]) + ) end test "argmin" do @@ -819,30 +827,30 @@ defmodule CandlexTest do t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) |> Nx.argmin(axis: 0) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0, 0, 0], [0, 0, 0] - ] - )) + ]) + ) t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) |> Nx.argmin(axis: 1) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 1, 0], [1, 0, 0] - ] - )) + ]) + ) t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) |> 
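The `Nx.dot/6` calls with explicit contracting axes in the tests above generalize plain matrix multiplication. A small sketch of the identity being exercised, with hypothetical 2-D inputs chosen for easy checking (the test's own `t1`/`t2` are defined outside the visible hunk):

```elixir
# Contracting axis 0 of both operands is equivalent to multiplying
# the transpose of the left operand by the right operand.
t1 = Nx.tensor([[1, 2], [3, 4]])
t2 = Nx.tensor([[10, 20], [30, 40]])

a = Nx.dot(t1, [0], [], t2, [0], [])
b = Nx.dot(Nx.transpose(t1), t2)

Nx.equal(a, b) |> Nx.all() |> Nx.to_number()
#=> 1
```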
Nx.argmin(axis: :z) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 1], [1, 2] - ] - )) + ]) + ) end test "acos" do @@ -1114,23 +1122,23 @@ defmodule CandlexTest do t([-1, 0, 1]) |> Nx.logical_and(t([[-1], [0], [1]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 0, 1], [0, 0, 0], [1, 0, 1] - ] - )) + ]) + ) t([-1.0, 0.0, 1.0]) - |> Nx.logical_and(t([[-1], [0], [1]]))\ - |> assert_equal(t( - [ + |> Nx.logical_and(t([[-1], [0], [1]])) + |> assert_equal( + t([ [1, 0, 1], [0, 0, 0], [1, 0, 1] - ] - )) + ]) + ) end test "logical_or" do @@ -1139,23 +1147,23 @@ defmodule CandlexTest do t([-1, 0, 1]) |> Nx.logical_or(t([[-1], [0], [1]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 1, 1], [1, 0, 1], [1, 1, 1] - ] - )) + ]) + ) t([-1.0, 0.0, 1.0]) |> Nx.logical_or(t([[-1], [0], [1]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 1, 1], [1, 0, 1], [1, 1, 1] - ] - )) + ]) + ) end test "logical_xor" do @@ -1165,23 +1173,23 @@ defmodule CandlexTest do t([-1, 0, 1]) |> Nx.logical_xor(t([[-1], [0], [1]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0, 1, 0], [1, 0, 1], [0, 1, 0] - ] - )) + ]) + ) t([-1.0, 0.0, 1.0]) |> Nx.logical_xor(t([[-1], [0], [1]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0, 1, 0], [1, 0, 1], [0, 1, 0] - ] - )) + ]) + ) end test "erf" do @@ -1233,48 +1241,49 @@ defmodule CandlexTest do |> assert_equal(t(10.0)) t = Nx.iota({2, 2, 3}, names: [:x, :y, :z]) + Nx.sum(t, axes: [:x]) - |> assert_equal(t( - [ + |> assert_equal( + t([ [6, 8, 10], [12, 14, 16] - ] - )) + ]) + ) Nx.sum(t, axes: [:y]) - |> assert_equal(t( - [ + |> assert_equal( + t([ [3, 5, 7], [15, 17, 19] - ] - )) + ]) + ) Nx.sum(t, axes: [:z]) - |> assert_equal(t( - [ + |> assert_equal( + t([ [3, 12], [21, 30] - ] - )) + ]) + ) Nx.sum(t, axes: [:x, :z]) |> assert_equal(t([24, 42])) Nx.sum(t, axes: [-3]) - |> assert_equal(t( - [ + |> assert_equal( + t([ [6, 8, 10], [12, 14, 16] - ] - )) + ]) + ) t([[1, 2], [3, 4]], names: [:x, :y]) |> Nx.sum(axes: [:x], keep_axes: true) - |> assert_equal(t( - [ + |> assert_equal( + t([ [4, 6] - ] - )) + ]) + ) end test "to_batched/2" do @@ -1284,24 +1293,24 @@ defmodule CandlexTest do |> Enum.to_list() first - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0, 1], [2, 3] ] - ] - )) + ]) + ) second - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [4, 5], [6, 7] ] - ] - )) + ]) + ) [first, second] = Nx.iota({10}) @@ -1401,8 +1410,8 @@ defmodule CandlexTest do |> Nx.reshape({4, 1, 1, 1}), strides: [1, 1] ) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [ [0.0, 0.0, 0.0], @@ -1425,8 +1434,8 @@ defmodule CandlexTest do [18.0, 21.0, 24.0] ] ] - ] - )) + ]) + ) # input/output permutation @@ -1441,8 +1450,8 @@ defmodule CandlexTest do assert result.shape == {1, 3, 3, 2} result - |> assert_close(t( - [ + |> assert_close( + t([ [ [15.0, 15.0], [51.0, 51.0], @@ -1458,8 +1467,8 @@ defmodule CandlexTest do [267.0, 267.0], [303.0, 303.0] ] - ] - )) + ]) + ) # Nx.iota({9}) # |> Nx.reshape({1, 1, 3, 3}) @@ -1568,39 +1577,35 @@ defmodule CandlexTest do test "take_along_axis" do t([[1, 2, 3], [4, 5, 6]]) |> Nx.take_along_axis( - t( - [ - [0, 0, 2, 2, 1, 1], - [2, 2, 1, 1, 0, 0] - ] - ), + t([ + [0, 0, 2, 2, 1, 1], + [2, 2, 1, 1, 0, 0] + ]), axis: 1 ) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 1, 3, 3, 2, 2], [6, 6, 5, 5, 4, 4] - ] - )) + ]) + ) t([[1, 2, 3], [4, 5, 6]]) |> Nx.take_along_axis( - t( - [ - [0, 1, 1], - [1, 0, 0], - [0, 1, 0] - ] - ), + t([ + [0, 1, 1], + [1, 0, 0], + [0, 1, 0] + ]), axis: 0 ) - |> assert_equal(t( - [ + |> 
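The `take_along_axis` case above selects elements along one axis according to an index tensor of the same rank. A minimal sketch with the same input tensor but a smaller index tensor, for illustration only:

```elixir
# Each index selects, along axis 1, an element from the same row.
t = Nx.tensor([[1, 2, 3], [4, 5, 6]])
idx = Nx.tensor([[2, 1, 0], [0, 1, 2]])

Nx.take_along_axis(t, idx, axis: 1) |> Nx.to_list()
#=> [[3, 2, 1], [4, 5, 6]]
```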
assert_equal( + t([ [1, 5, 6], [4, 2, 3], [1, 5, 3] - ] - )) + ]) + ) end test "gather" do @@ -1661,8 +1666,8 @@ defmodule CandlexTest do Nx.iota({2, 3, 4}, names: [:x, :y, :z]) |> Nx.transpose() - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0, 12], [4, 16], @@ -1683,8 +1688,8 @@ defmodule CandlexTest do [7, 19], [11, 23] ] - ] - )) + ]) + ) t(1) |> Nx.transpose(axes: []) @@ -1692,8 +1697,8 @@ defmodule CandlexTest do Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) |> Nx.transpose(axes: [2, 1, :batch]) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0, 12], [4, 16], @@ -1714,13 +1719,13 @@ defmodule CandlexTest do [7, 19], [11, 23] ] - ] - )) + ]) + ) Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) |> Nx.transpose(axes: [:y, :batch, :x]) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0, 4, 8], [12, 16, 20] @@ -1737,13 +1742,13 @@ defmodule CandlexTest do [3, 7, 11], [15, 19, 23] ] - ] - )) + ]) + ) Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) |> Nx.transpose(axes: [:batch, :y, :x]) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [0, 4, 8], [1, 5, 9], @@ -1756,8 +1761,8 @@ defmodule CandlexTest do [14, 18, 22], [15, 19, 23] ] - ] - )) + ]) + ) end test "put_slice" do @@ -1767,21 +1772,21 @@ defmodule CandlexTest do t([[1, 2, 3], [4, 5, 6]]) |> Nx.put_slice([0, 0], t([[7, 8, 9], [10, 11, 12]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [7, 8, 9], [10, 11, 12] - ] - )) + ]) + ) t([[1, 2, 3], [4, 5, 6]]) |> Nx.put_slice([0, 1], t([[7, 8], [9, 10]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1, 7, 8], [4, 9, 10] - ] - )) + ]) + ) # t([[1, 2, 3], [4, 5, 6]]) # |> Nx.put_slice([t(0), t(1)], t([[10.0, 11.0]])) @@ -1979,27 +1984,27 @@ defmodule CandlexTest do test "take" do t([[1, 2], [3, 4]]) |> Nx.take(t([1, 0, 1])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [3, 4], [1, 2], [3, 4] - ] - )) + ]) + ) t([[1, 2], [3, 4]]) |> Nx.take(t([1, 0, 1]), axis: 1) - |> assert_equal(t( - [ + |> assert_equal( + t([ [2, 1, 2], [4, 3, 4] - ] - )) + ]) + ) t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) |> Nx.take(t([1, 0, 1]), axis: 1) - |> assert_equal(t( - [ + |> assert_equal( + t([ [ [11, 12], [1, 2], @@ -2010,8 +2015,8 @@ defmodule CandlexTest do [101, 102], [111, 112] ] - ] - )) + ]) + ) # t([[1, 2], [11, 12]]) # |> Nx.take(t([[0, 0], [1, 1], [0, 0]]), axis: 1) @@ -2075,48 +2080,48 @@ defmodule CandlexTest do test "clip" do t([[1, 2, 3], [4, 5, 6]]) |> Nx.clip(2, 4) - |> assert_equal(t( - [ + |> assert_equal( + t([ [2, 2, 3], [4, 4, 4] - ] - )) + ]) + ) t([[1, 2, 3], [4, 5, 6]]) |> Nx.clip(2.0, 3) - |> assert_equal(t( - [ + |> assert_equal( + t([ [2.0, 2.0, 3.0], [3.0, 3.0, 3.0] - ] - )) + ]) + ) t([[1, 2, 3], [4, 5, 6]]) |> Nx.clip(t(2.0), Nx.max(1.0, 3.0)) - |> assert_equal(t( - [ + |> assert_equal( + t([ [2.0, 2.0, 3.0], [3.0, 3.0, 3.0] - ] - )) + ]) + ) t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) |> Nx.clip(2, 6.0) - |> assert_equal(t( - [ + |> assert_equal( + t([ [2.0, 2.0, 3.0], [4.0, 5.0, 6.0] - ] - )) + ]) + ) t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], type: :f32) |> Nx.clip(1, 4) - |> assert_equal(t( - [ + |> assert_equal( + t([ [1.0, 2.0, 3.0], [4.0, 4.0, 4.0] - ] - )) + ]) + ) end test "not_equal" do @@ -2133,15 +2138,118 @@ defmodule CandlexTest do t([[1, 4, 2], [4, 5, 6]]) |> Nx.not_equal(t([[1, 3, 2], [4, 2, 1]])) - |> assert_equal(t( - [ + |> assert_equal( + t([ [0, 1, 0], [0, 1, 1] - ] - )) + ]) + ) + end + + test "all" do + t(0) + |> Nx.all() + |> assert_equal(t(0)) + + t(10) + |> Nx.all() + |> assert_equal(t(1)) + + t([0, 1, 2]) + |> Nx.all() + |> 
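`Nx.put_slice/3`, covered a little further up in this test module, overwrites a region of the input starting at the given indices. A sketch reproducing the second case from that test, for illustration only:

```elixir
# The slice [[7, 8], [9, 10]] is written starting at row 0, column 1.
t = Nx.tensor([[1, 2, 3], [4, 5, 6]])

Nx.put_slice(t, [0, 1], Nx.tensor([[7, 8], [9, 10]])) |> Nx.to_list()
#=> [[1, 7, 8], [4, 9, 10]]
```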
assert_equal(t(0)) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:x]) + |> assert_equal(t([1, 0, 1])) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:y]) + |> assert_equal(t([0, 1])) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:y], keep_axes: true) + |> assert_equal( + t([ + [0], + [1] + ]) + ) + + tensor = Nx.tensor([[[1, 2], [0, 4]], [[5, 6], [7, 8]]], names: [:x, :y, :z]) + + tensor + |> Nx.all(axes: [:x, :y]) + |> assert_equal(t([0, 1])) + + tensor + |> Nx.all(axes: [:y, :z]) + |> assert_equal(t([0, 1])) + + tensor + |> Nx.all(axes: [:x, :z]) + |> assert_equal(t([1, 0])) + + tensor + |> Nx.all(axes: [:x, :y], keep_axes: true) + |> assert_equal( + t([ + [ + [0, 1] + ] + ]) + ) + + tensor + |> Nx.all(axes: [:y, :z], keep_axes: true) + |> assert_equal( + t([ + [ + [0] + ], + [ + [1] + ] + ]) + ) + + tensor + |> Nx.all(axes: [:x, :z], keep_axes: true) + |> assert_equal( + t([ + [ + [1], + [0] + ] + ]) + ) end - if Candlex.Backend.cuda_available? do + test "any" do + t([0, 1, 2]) + |> Nx.any() + |> assert_equal(t(1)) + + t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + |> Nx.any(axes: [:x]) + |> assert_equal(t([0, 1, 1])) + + t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + |> Nx.any(axes: [:y]) + |> assert_equal(t([1, 1])) + + tensor = t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + + tensor + |> Nx.any(axes: [:x], keep_axes: true) + |> assert_equal(t([[0, 1, 1]])) + + tensor + |> Nx.any(axes: [:y], keep_axes: true) + |> assert_equal(t([[1], [1]])) + end + + if Candlex.Backend.cuda_available?() do test "different devices" do t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) |> Nx.add(t([10, 20, 30], backend: {Candlex.Backend, device: :cuda})) @@ -2182,6 +2290,7 @@ defmodule CandlexTest do tensor # |> IO.inspect() |> Nx.to_binary() + # |> IO.inspect() opts = diff --git a/candlex/test/support/nx_case.ex b/candlex/test/support/nx_case.ex new file mode 100644 index 0000000000..949a54915b --- /dev/null +++ b/candlex/test/support/nx_case.ex @@ -0,0 +1,45 @@ +defmodule Nx.Case do + @moduledoc """ + Test case for tensor assertions + """ + + use ExUnit.CaseTemplate + + using do + quote do + import Nx.Case + end + end + + def assert_equal(left, right) do + equals = + left + |> Nx.equal(right) + # |> Nx.logical_or(Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))) + |> Nx.all() + |> Nx.to_number() + + if equals != 1 || Nx.shape(left) != Nx.shape(right) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end + + def assert_close(left, right) do + equals = + left + |> Nx.all_close(right, atol: 1.0e-4, rtol: 1.0e-4) + |> Nx.backend_transfer(Nx.BinaryBackend) + + if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end +end
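The `Nx.Case` template added above wraps the two tensor assertions the suite relies on: `assert_equal/2` does an exact `Nx.equal`/`Nx.all` check plus a shape check, and `assert_close/2` delegates to `Nx.all_close/3` with a 1.0e-4 tolerance. A minimal usage sketch (hypothetical module name, not part of the patch):

```elixir
defmodule Candlex.SketchTest do
  use Nx.Case, async: true

  test "tensor assertions" do
    # Exact comparison: Nx.equal/2 reduced with Nx.all/1, plus a shape check.
    assert_equal(Nx.tensor([1, 2, 3]), Nx.tensor([1, 2, 3]))

    # Approximate comparison: Nx.all_close/3 with atol/rtol of 1.0e-4.
    assert_close(Nx.tensor([1.0000001, 2.0]), Nx.tensor([1.0, 2.0]))
  end
end
```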