diff --git a/candlex/.formatter.exs b/candlex/.formatter.exs
new file mode 100644
index 0000000000..d9f22bc671
--- /dev/null
+++ b/candlex/.formatter.exs
@@ -0,0 +1,5 @@
+# Used by "mix format"
+[
+  import_deps: [:nx],
+  inputs: ["{mix,.formatter}.exs", "{config,lib,test}/**/*.{ex,exs}"]
+]
diff --git a/candlex/.gitignore b/candlex/.gitignore
new file mode 100644
index 0000000000..1fb4d171fd
--- /dev/null
+++ b/candlex/.gitignore
@@ -0,0 +1,29 @@
+# The directory Mix will write compiled artifacts to.
+/_build/
+
+# If you run "mix test --cover", coverage assets end up here.
+/cover/
+
+# The directory Mix downloads your dependencies sources to.
+/deps/
+
+# Where third-party dependencies like ExDoc output generated docs.
+/doc/
+
+# Ignore .fetch files in case you like to edit your project deps locally.
+/.fetch
+
+# If the VM crashes, it generates a dump, let's ignore it too.
+erl_crash.dump
+
+# Also ignore archive artifacts (built via "mix archive.build").
+*.ez
+
+# Ignore package tarball (built via "mix hex.build").
+candlex-*.tar
+
+# Temporary files, for example, from tests.
+/tmp/
+
+# Shared objects built by Rust.
+*.so
diff --git a/candlex/CHANGELOG.md b/candlex/CHANGELOG.md
new file mode 100644
index 0000000000..25203fcdbf
--- /dev/null
+++ b/candlex/CHANGELOG.md
@@ -0,0 +1,14 @@
+# Changelog
+
+All notable changes to this project will be documented in this file.
+
+The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+## [Unreleased]
+
+## [0.1.2] - 2023-10-30
+
+### Added
+
+- Precompiled binaries for a few CPU targets.
diff --git a/candlex/LICENSE b/candlex/LICENSE
new file mode 100644
index 0000000000..f433b1a53f
--- /dev/null
+++ b/candlex/LICENSE
@@ -0,0 +1,177 @@
+
+                                 Apache License
+                           Version 2.0, January 2004
+                        http://www.apache.org/licenses/
+
+   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+   1. Definitions.
+
+      "License" shall mean the terms and conditions for use, reproduction,
+      and distribution as defined by Sections 1 through 9 of this document.
+
+      "Licensor" shall mean the copyright owner or entity authorized by
+      the copyright owner that is granting the License.
+
+      "Legal Entity" shall mean the union of the acting entity and all
+      other entities that control, are controlled by, or are under common
+      control with that entity. For the purposes of this definition,
+      "control" means (i) the power, direct or indirect, to cause the
+      direction or management of such entity, whether by contract or
+      otherwise, or (ii) ownership of fifty percent (50%) or more of the
+      outstanding shares, or (iii) beneficial ownership of such entity.
+
+      "You" (or "Your") shall mean an individual or Legal Entity
+      exercising permissions granted by this License.
+
+      "Source" form shall mean the preferred form for making modifications,
+      including but not limited to software source code, documentation
+      source, and configuration files.
+
+      "Object" form shall mean any form resulting from mechanical
+      transformation or translation of a Source form, including but
+      not limited to compiled object code, generated documentation,
+      and conversions to other media types.
+
+      "Work" shall mean the work of authorship, whether in Source or
+      Object form, made available under the License, as indicated by a
+      copyright notice that is included in or attached to the work
+      (an example is provided in the Appendix below).
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory,
+      whether in tort (including negligence), contract, or otherwise,
+      unless required by applicable law (such as deliberate and grossly
+      negligent acts) or agreed to in writing, shall any Contributor be
+      liable to You for damages, including any direct, indirect, special,
+      incidental, or consequential damages of any character arising as a
+      result of this License or out of the use or inability to use the
+      Work (including but not limited to damages for loss of goodwill,
+      work stoppage, computer failure or malfunction, or any and all
+      other commercial damages or losses), even if such Contributor
+      has been advised of the possibility of such damages.
+
+   9. Accepting Warranty or Additional Liability. While redistributing
+      the Work or Derivative Works thereof, You may choose to offer,
+      and charge a fee for, acceptance of support, warranty, indemnity,
+      or other liability obligations and/or rights consistent with this
+      License. However, in accepting such obligations, You may act only
+      on Your own behalf and on Your sole responsibility, not on behalf
+      of any other Contributor, and only if You agree to indemnify,
+      defend, and hold each Contributor harmless for any liability
+      incurred by, or claims asserted against, such Contributor by reason
+      of your accepting any such warranty or additional liability.
+
+   END OF TERMS AND CONDITIONS
diff --git a/candlex/README.md b/candlex/README.md
new file mode 100644
index 0000000000..4805cf5baf
--- /dev/null
+++ b/candlex/README.md
@@ -0,0 +1,89 @@
+# Candlex
+
+[![ci](https://github.com/mimiquate/candlex/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/mimiquate/candlex/actions?query=branch%3Amain)
+[![Hex.pm](https://img.shields.io/hexpm/v/candlex.svg)](https://hex.pm/packages/candlex)
+[![Docs](https://img.shields.io/badge/docs-gray.svg)](https://hexdocs.pm/candlex)
+
+An `Nx` [backend](https://hexdocs.pm/nx/Nx.html#module-backends) for [candle](https://huggingface.github.io/candle), a minimalist machine learning framework.
+
+## Installation
+
+The package is [available in Hex](https://hex.pm/packages/candlex) and can be installed
+by adding `candlex` to your list of dependencies in `mix.exs`:
+
+```elixir
+def deps do
+  [
+    {:candlex, "~> 0.1.2"}
+  ]
+end
+```
+
+Documentation is generated with [ExDoc](https://github.com/elixir-lang/ex_doc)
+and published on [HexDocs](https://hexdocs.pm). The docs can be found at
+<https://hexdocs.pm/candlex>.
+
+## Usage
+
+Just configure Nx to default to the Candlex backend in your configuration:
+
+```elixir
+# Possibly config/runtime.exs
+
+config :nx, default_backend: Candlex.Backend
+```
+
+or, in your scripts, precede all your Nx operations with:
+
+```elixir
+Nx.default_backend(Candlex.Backend)
+```
+
+More details in [Nx backends](https://hexdocs.pm/nx/Nx.html#module-backends).
+
+#### `CANDLEX_NIF_BUILD`
+
+Defaults to `false`. If `true`, the native binary is built locally, which may be useful
+if no precompiled binary is available for your target environment. Once set, you
+must run `mix deps.clean candlex --build` explicitly to force a recompile.
+Building has a number of dependencies; see *Building from source* below.
+
+## Building from source
+
+To build the native binary locally, you need to set `CANDLEX_NIF_BUILD=true`.
+Keep in mind that the compilation usually takes time.
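+
+Once it builds, you can quickly smoke-test the NIF from `iex`. This is a
+minimal sketch (any small Nx operation exercises the native bindings, and
+`Candlex.Backend.cuda_available?/0` from this package reports whether the
+binary was compiled with CUDA support):
+
+```elixir
+Nx.default_backend(Candlex.Backend)
+
+# Round-trips a trivial operation through the candle NIF.
+Nx.tensor([1.0, 2.0, 3.0]) |> Nx.multiply(2) |> IO.inspect()
+
+Candlex.Backend.cuda_available?() |> IO.inspect()
+```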
+
+You will need the following installed on your system for the compilation:
+
+ * [Git](https://git-scm.com) for fetching the candle-core source
+ * [Rust](https://www.rust-lang.org) with cargo to compile the Rustler NIFs
+
+## Releasing
+
+To publish a new version of this package:
+
+1. Update `@version` in `mix.exs` and `project-version` in `.github/workflows/binaries.yml`.
+1. `git tag -s <tag>` to create a new signed tag.
+1. `git push origin <tag>` to push the tag.
+1. Wait for the `binaries.yml` GitHub workflow to build all the NIF binaries.
+1. `mix rustler_precompiled.download Candlex.Native --all --print` to generate the binary checksums locally.
+1. `rm -r native/candlex/target` to leave Rust crate build artifacts out of the published Elixir package.
+1. `mix hex.build --unpack` to check that the package includes the correct files.
+1. Publish the release from its draft in GitHub.
+1. `mix hex.publish` to publish the package to Hex.pm.
+
+## License
+
+Copyright 2023 Mimiquate
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/candlex/config/config.exs b/candlex/config/config.exs
new file mode 100644
index 0000000000..aecd8eacf6
--- /dev/null
+++ b/candlex/config/config.exs
@@ -0,0 +1,17 @@
+import Config
+
+enable_cuda =
+  case System.get_env("CUDA") do
+    nil -> System.find_executable("nvcc") && System.find_executable("nvidia-smi")
+    "false" -> false
+    _ -> true
+  end
+
+crate_features =
+  if enable_cuda do
+    [:cuda]
+  else
+    []
+  end
+
+config :candlex, crate_features: crate_features
diff --git a/candlex/examples/linear_regression.exs b/candlex/examples/linear_regression.exs
new file mode 100644
index 0000000000..e80949a8fe
--- /dev/null
+++ b/candlex/examples/linear_regression.exs
@@ -0,0 +1,77 @@
+defmodule LinearRegression do
+  import Nx.Defn
+
+  # y = mx + b
+  defn init_random_params do
+    key = Nx.Random.key(42)
+    {m, new_key} = Nx.Random.normal(key, 0.0, 0.1, shape: {1, 1})
+    {b, _new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {1})
+    {m, b}
+  end
+
+  defn predict({m, b}, inp) do
+    Nx.dot(inp, m) + b
+  end
+
+  # MSE Loss
+  defn loss({m, b}, inp, tar) do
+    preds = predict({m, b}, inp)
+    Nx.mean(Nx.pow(tar - preds, 2))
+  end
+
+  defn update({m, b} = params, inp, tar, step) do
+    {grad_m, grad_b} = grad(params, &loss(&1, inp, tar))
+    {m - grad_m * step, b - grad_b * step}
+  end
+
+  def train(params, epochs, lin_fn) do
+    data =
+      Stream.repeatedly(fn -> for _ <- 1..32, do: :rand.uniform() * 10 end)
+      |> Stream.map(fn x -> Enum.zip(x, Enum.map(x, lin_fn)) end)
+
+    for _ <- 1..epochs, reduce: params do
+      acc ->
+        data
+        |> Enum.take(200)
+        |> Enum.reduce(
+          acc,
+          fn batch, cur_params ->
+            {inp, tar} = Enum.unzip(batch)
+            x = Nx.reshape(Nx.tensor(inp), {32, 1})
+            y = Nx.reshape(Nx.tensor(tar), {32, 1})
+            update(cur_params, x, y, 0.001)
+          end
+        )
+    end
+  end
+end
+
+Nx.default_backend(Candlex.Backend)
+# Nx.default_backend(Nx.BinaryBackend)
+
+params = LinearRegression.init_random_params()
+m = :rand.normal(0.0, 10.0)
+b = :rand.normal(0.0, 5.0)
+IO.puts("Target m: #{m} Target b: #{b}\n")
+
+lin_fn = fn x -> m * x + b end
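+
+# Each epoch draws 200 fresh batches of 32 random points from `lin_fn` and
+# takes one gradient-descent step per batch with learning rate 0.001.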
+epochs = 100 + +# These will be very close to the above coefficients +{time, {trained_m, trained_b}} = :timer.tc(LinearRegression, :train, [params, epochs, lin_fn]) + +trained_m = + trained_m + |> Nx.squeeze() + |> Nx.backend_transfer() + |> Nx.to_number() + +trained_b = + trained_b + |> Nx.squeeze() + |> Nx.backend_transfer() + |> Nx.to_number() + +IO.puts("Trained in #{time / 1_000_000} sec.") +IO.puts("Trained m: #{trained_m} Trained b: #{trained_b}\n") +IO.puts("Accuracy m: #{m - trained_m} Accuracy b: #{b - trained_b}") diff --git a/candlex/examples/mnist.exs b/candlex/examples/mnist.exs new file mode 100644 index 0000000000..9e0254a509 --- /dev/null +++ b/candlex/examples/mnist.exs @@ -0,0 +1,170 @@ +defmodule MNIST do + import Nx.Defn + + defn init_random_params do + key = Nx.Random.key(42) + {w1, new_key} = Nx.Random.normal(key, 0.0, 0.1, shape: {784, 128}, names: [:input, :layer]) + {b1, new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {128}, names: [:layer]) + + {w2, new_key} = + Nx.Random.normal(new_key, 0.0, 0.1, shape: {128, 10}, names: [:layer, :output]) + + {b2, _new_key} = Nx.Random.normal(new_key, 0.0, 0.1, shape: {10}, names: [:output]) + {w1, b1, w2, b2} + end + + defn softmax(logits) do + Nx.exp(logits) / Nx.sum(Nx.exp(logits), axes: [:output], keep_axes: true) + end + + defn predict({w1, b1, w2, b2}, batch) do + batch + |> Nx.dot(w1) + |> Nx.add(b1) + |> Nx.sigmoid() + |> Nx.dot(w2) + |> Nx.add(b2) + |> softmax() + end + + defn accuracy({w1, b1, w2, b2}, batch_images, batch_labels) do + Nx.mean( + Nx.equal( + Nx.argmax(batch_labels, axis: :output), + Nx.argmax(predict({w1, b1, w2, b2}, batch_images), axis: :output) + ) + ) + end + + defn loss({w1, b1, w2, b2}, batch_images, batch_labels) do + preds = predict({w1, b1, w2, b2}, batch_images) + -Nx.sum(Nx.mean(Nx.log(preds) * batch_labels, axes: [:output])) + end + + defn update({w1, b1, w2, b2} = params, batch_images, batch_labels, step) do + {grad_w1, grad_b1, grad_w2, grad_b2} = grad(params, &loss(&1, batch_images, batch_labels)) + + { + w1 - grad_w1 * step, + b1 - grad_b1 * step, + w2 - grad_w2 * step, + b2 - grad_b2 * step + } + end + + defn update_with_averages({_, _, _, _} = cur_params, imgs, tar, avg_loss, avg_accuracy, total) do + batch_loss = loss(cur_params, imgs, tar) + batch_accuracy = accuracy(cur_params, imgs, tar) + avg_loss = avg_loss + batch_loss / total + avg_accuracy = avg_accuracy + batch_accuracy / total + {update(cur_params, imgs, tar, 0.01), avg_loss, avg_accuracy} + end + + defp unzip_cache_or_download(zip) do + base_url = ~c"https://storage.googleapis.com/cvdf-datasets/mnist/" + path = Path.join("tmp", zip) + + data = + if File.exists?(path) do + IO.puts("Using #{zip} from tmp/\n") + File.read!(path) + else + IO.puts("Fetching #{zip} from https://storage.googleapis.com/cvdf-datasets/mnist/\n") + :inets.start() + :ssl.start() + + {:ok, {_status, _response, data}} = :httpc.request(base_url ++ zip) + File.mkdir_p!("tmp") + File.write!(path, data) + + data + end + + :zlib.gunzip(data) + end + + def download(images, labels) do + <<_::32, n_images::32, n_rows::32, n_cols::32, images::binary>> = + unzip_cache_or_download(images) + + train_images = + images + |> Nx.from_binary({:u, 8}) + |> Nx.reshape({n_images, n_rows * n_cols}, names: [:batch, :input]) + |> Nx.divide(255.0) + |> Nx.to_batched(30) + + IO.puts("#{n_images} #{n_rows}x#{n_cols} images\n") + + <<_::32, n_labels::32, labels::binary>> = unzip_cache_or_download(labels) + + train_labels = + labels + |> Nx.from_binary({:u, 8}) + |> 
Nx.reshape({n_labels, 1}, names: [:batch, :output]) + |> Nx.equal(Nx.tensor(Enum.to_list(0..9))) + |> Nx.to_batched(30) + + IO.puts("#{n_labels} labels\n") + + {train_images, train_labels} + end + + def train_epoch(cur_params, imgs, labels) do + total_batches = Enum.count(imgs) + + imgs + |> Stream.zip(labels) + |> Enum.reduce({cur_params, Nx.tensor(0.0), Nx.tensor(0.0)}, fn + {imgs, tar}, {cur_params, avg_loss, avg_accuracy} -> + update_with_averages(cur_params, imgs, tar, avg_loss, avg_accuracy, total_batches) + end) + end + + def train(imgs, labels, params, opts \\ []) do + epochs = opts[:epochs] || 5 + + for epoch <- 1..epochs, reduce: params do + cur_params -> + {time, {new_params, epoch_avg_loss, epoch_avg_acc}} = + :timer.tc(__MODULE__, :train_epoch, [cur_params, imgs, labels]) + + epoch_avg_loss = + epoch_avg_loss + |> Nx.backend_transfer() + |> Nx.to_number() + + epoch_avg_acc = + epoch_avg_acc + |> Nx.backend_transfer() + |> Nx.to_number() + + IO.puts("Epoch #{epoch} Time: #{time / 1_000_000}s") + IO.puts("Epoch #{epoch} average loss: #{inspect(epoch_avg_loss)}") + IO.puts("Epoch #{epoch} average accuracy: #{inspect(epoch_avg_acc)}") + IO.puts("\n") + new_params + end + end +end + +Nx.global_default_backend(Candlex.Backend) + +{train_images, train_labels} = + MNIST.download(~c"train-images-idx3-ubyte.gz", ~c"train-labels-idx1-ubyte.gz") + +IO.puts("Initializing parameters...\n") +params = MNIST.init_random_params() + +IO.puts("Training MNIST for 10 epochs...\n\n") +final_params = MNIST.train(train_images, train_labels, params, epochs: 10) + +IO.puts("The result of the first batch") + +IO.inspect( + MNIST.predict(final_params, hd(Enum.to_list(train_images))) + |> Nx.argmax(axis: :output) +) + +IO.puts("Labels for the first batch") +IO.inspect(hd(Enum.to_list(train_labels)) |> Nx.argmax(axis: :output)) diff --git a/candlex/lib/candlex.ex b/candlex/lib/candlex.ex new file mode 100644 index 0000000000..ee5cfce363 --- /dev/null +++ b/candlex/lib/candlex.ex @@ -0,0 +1,5 @@ +defmodule Candlex do + @moduledoc """ + Documentation for `Candlex`. + """ +end diff --git a/candlex/lib/candlex/backend.ex b/candlex/lib/candlex/backend.ex new file mode 100644 index 0000000000..4e68eba143 --- /dev/null +++ b/candlex/lib/candlex/backend.ex @@ -0,0 +1,1021 @@ +defmodule Candlex.Backend do + @moduledoc """ + An opaque Nx backend with bindings to candle. 
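+
+  A minimal usage sketch (per the README, set it as the Nx default backend):
+
+      Nx.default_backend(Candlex.Backend)
+      Nx.iota({3})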
+ """ + + defstruct [:device, :resource] + + @behaviour Nx.Backend + + alias Nx.Tensor, as: T + alias Candlex.Native + + @device_cuda :cuda + @device_cpu :cpu + + @impl true + def init(opts) do + Keyword.validate!(opts, [:device]) + end + + # Creation + + @impl true + def constant(%T{} = tensor, scalar, backend_options) do + tensor + |> Nx.BinaryBackend.constant(scalar, []) + |> Nx.BinaryBackend.backend_transfer(__MODULE__, backend_options) + end + + @impl true + def from_binary(%T{shape: shape, type: type} = tensor, binary, backend_options) do + binary + |> Native.from_binary(to_candle_dtype(type), shape, device_option(backend_options)) + |> unwrap!() + |> to_nx(tensor) + end + + @impl true + def iota(%T{shape: {}} = out, nil, backend_options) do + constant(out, 0, backend_options) + end + + def iota(%T{shape: shape, type: type} = out, nil, backend_options) do + Native.arange(0, Nx.size(shape), to_candle_dtype(type), shape, device_option(backend_options)) + |> unwrap!() + |> to_nx(out) + end + + def iota(%T{shape: shape, type: type} = out, axis, backend_options) do + # Build in one dimension, then broadcast + axis_size = elem(shape, axis) + + Native.arange( + 0, + axis_size, + to_candle_dtype(type), + Tuple.duplicate(1, Nx.rank(shape)) |> put_elem(axis, axis_size), + device_option(backend_options) + ) + |> unwrap!() + |> Native.broadcast_to(shape) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def eye(%T{shape: shape, type: type} = _out, backend_options) do + iota = Nx.iota(shape, backend: {__MODULE__, backend_options}) + + Nx.equal(Nx.tril(iota), Nx.triu(iota)) + |> Nx.as_type(type) + end + + # Backend + + @impl true + def backend_transfer(tensor, backend, backend_options) do + if backend == __MODULE__ && same_device?(tensor, device_option(backend_options)) do + tensor + else + try do + backend_copy(tensor, backend, backend_options) + after + backend_deallocate(tensor) + end + end + end + + @impl true + def backend_copy(%T{} = tensor, Candlex.Backend, backend_options) do + tensor + |> from_nx() + |> Native.to_device(device_option(backend_options)) + |> unwrap!() + |> to_nx(tensor) + end + + def backend_copy(%T{} = tensor, backend, backend_options) do + backend.from_binary(tensor, to_binary(tensor), backend_options) + end + + @impl true + def backend_deallocate(%T{} = _tensor) do + true + end + + # Conversion + + @impl true + def to_binary(tensor, _limit \\ nil) do + # TODO: don't ignore limit + + from_nx(tensor) + |> Native.to_binary() + |> unwrap!() + end + + # Aggregates + + @impl true + def all(%T{} = out, %T{} = tensor, opts) do + case opts[:axes] do + nil -> + from_nx(tensor) + |> Native.all() + + axes -> + from_nx(tensor) + |> Native.all_within_dims(axes, opts[:keep_axes]) + end + |> unwrap!() + |> to_nx(out) + end + + @impl true + def any(%T{} = out, %T{} = tensor, opts) do + case opts[:axes] do + nil -> + from_nx(tensor) + |> Native.any() + + axes -> + from_nx(tensor) + |> Native.any_within_dims(axes, opts[:keep_axes]) + end + |> unwrap!() + |> to_nx(out) + end + + @impl true + def sum(%T{type: out_type} = out, %T{} = t, opts) do + axes = opts[:axes] || Nx.axes(t) + keep_axes = opts[:keep_axes] || false + + t + |> from_nx() + |> Native.sum(axes, keep_axes) + |> unwrap!() + |> Native.to_type(to_candle_dtype(out_type)) + |> unwrap!() + |> to_nx(out) + end + + for op <- [:argmax, :argmin] do + @impl true + def unquote(op)(%T{} = out, %T{shape: {}} = _tensor, _opts) do + out + |> constant(0, []) + end + + def unquote(op)(%T{type: type} = out, %T{} = tensor, opts) do + axis = 
opts[:axis] || -1 + keep_axis = opts[:keep_axis] || false + + tensor + |> from_nx() + |> Native.unquote(op)(axis, keep_axis) + |> unwrap!() + # candle argmax/argmin changes to u32 + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> to_nx(out) + end + end + + @impl true + def reduce_max(%T{} = out, %T{shape: {}} = tensor, _opts) do + out + |> from_binary(to_binary(tensor), []) + end + + def reduce_max(%T{} = out, %T{} = tensor, opts) do + axis = + case opts[:axes] do + nil -> 0 + [] -> 0 + [axis] -> axis + axes -> raise "doesn't support axes option with more than 1 axis, '#{inspect(axes)}'" + end + + keep_axis = opts[:keep_axes] || false + + tensor + |> from_nx() + |> Native.reduce_max(axis, keep_axis) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def reduce_min(%T{} = out, %T{shape: {}} = tensor, _opts) do + out + |> from_binary(to_binary(tensor), []) + end + + def reduce_min(%T{} = out, %T{} = tensor, opts) do + axis = + case opts[:axes] do + nil -> 0 + [] -> 0 + [axis] -> axis + axes -> raise "doesn't support axes option with more than 1 axis, '#{inspect(axes)}'" + end + + keep_axis = opts[:keep_axes] || false + + tensor + |> from_nx() + |> Native.reduce_min(axis, keep_axis) + |> unwrap!() + |> to_nx(out) + end + + # Element-wise + + @impl true + def clip(%T{} = out, %T{} = t, %T{} = min, %T{} = max) do + [t, min, max] = maybe_upcast([t, min, max]) + + t + |> from_nx() + |> Native.clamp(from_nx(min), from_nx(max)) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def select(%T{shape: shape, type: type} = out, pred, on_true, on_false) do + on_true = + on_true + |> from_nx() + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> Native.broadcast_to(shape) + |> unwrap!() + + on_false = + on_false + |> from_nx() + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> Native.broadcast_to(shape) + |> unwrap!() + + pred + |> from_nx() + |> Native.where_cond(on_true, on_false) + |> unwrap!() + |> to_nx(out) + end + + # Binary ops + + for op <- [:add, :divide, :max, :min, :multiply, :subtract] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_transfer_device(left, right) + {left, right} = maybe_upcast(left, right) + + from_nx(left) + |> Native.unquote(op)(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + end + + for op <- [:atan2, :pow, :quotient, :remainder] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_upcast(left, right) + {left, right} = maybe_broadcast_bin_args(out.shape, left, right) + + left + |> Native.unquote(op)(right) + |> unwrap!() + |> to_nx(out) + end + end + + for op <- [ + :bitwise_and, + :bitwise_or, + :bitwise_xor, + :equal, + :greater, + :greater_equal, + :left_shift, + :less, + :less_equal, + :logical_and, + :logical_or, + :logical_xor, + :not_equal, + :right_shift + ] do + @impl true + def unquote(op)(%T{} = out, %T{} = left, %T{} = right) do + {left, right} = maybe_transfer_device(left, right) + {left, right} = maybe_upcast(left, right) + {left, right} = maybe_broadcast_bin_args(out.shape, left, right) + + left + |> Native.unquote(op)(right) + |> unwrap!() + # TODO: Do this conditionally or as part of native op + |> Native.to_type(to_candle_dtype(out.type)) + |> unwrap!() + |> to_nx(out) + end + end + + # Unary ops + + for op <- [ + :abs, + :acos, + :acosh, + :asin, + :asinh, + :atan, + :atanh, + :bitwise_not, + :cbrt, + :ceil, + :cos, + :cosh, + :erf, + :erfc, + :erf_inv, + :exp, + :expm1, + :floor, + :is_infinity, + :is_nan, + 
:log, + :log1p, + :negate, + :round, + :rsqrt, + :sigmoid, + :sign, + :sin, + :sinh, + :sqrt, + :tan, + :tanh + ] do + @impl true + def unquote(op)(%T{} = out, %T{} = tensor) do + tensor + |> from_nx() + |> Native.unquote(op)() + |> unwrap!() + |> to_nx(out) + end + end + + # Indexed + + @impl true + def gather(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices) do + tensor + |> from_nx() + |> Native.gather(from_nx(Nx.flatten(indices)), 0) + |> unwrap!() + |> to_nx(out) + end + + def gather(%T{} = _out, %T{} = _tensor, %T{} = _indices) do + raise("unsupported gather for tensor of rank greater than 1") + end + + @impl true + def indexed_add(%T{} = out, %T{shape: {_}} = tensor, %T{} = indices, %T{} = updates) do + {tensor, updates} = maybe_upcast(tensor, updates) + + tensor + |> from_nx() + |> Native.index_add(from_nx(Nx.flatten(indices)), from_nx(updates), 0) + |> unwrap!() + |> to_nx(out) + end + + def indexed_add(%T{} = _out, %T{} = _tensor, %T{} = _indices, %T{} = _updates) do + raise("unsupported indexed_add for tensor of rank greater than 1") + end + + @impl true + def put_slice(%T{} = out, %T{} = t, [_ | _] = start_indices, slice) do + [last_start_index | leading_start_indices] = Enum.reverse(start_indices) + + if Enum.all?(leading_start_indices, fn i -> Nx.equal(i, 0) end) do + t + |> from_nx() + |> Native.slice_scatter( + from_nx(slice), + length(start_indices) - 1, + Nx.to_number(last_start_index) + ) + |> unwrap!() + |> to_nx(out) + else + raise "put_slice only supports last start index not to be 0 for now" + end + end + + @impl true + def slice( + %T{shape: _output_shape} = out, + %T{shape: input_shape} = t, + starts, + lengths, + _strides + ) do + t + |> from_nx() + |> narrow(starts, lengths, 0, input_shape) + # TODO: Support strides + # |> stride(output_shape, lengths, strides) + |> to_nx(out) + end + + @impl true + def take(%T{} = out, %T{} = tensor, %T{} = indexes, axis) do + if Nx.rank(indexes) > 1 do + raise "only indexes of rank=1 supported for now" + end + + tensor + |> from_nx() + |> Native.index_select(from_nx(indexes), axis) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def take_along_axis(%T{} = out, %T{} = tensor, %T{} = indexes, axis) do + tensor + |> from_nx() + |> Native.gather(from_nx(indexes), axis) + |> unwrap!() + |> to_nx(out) + end + + # N-dim + + @impl true + def concatenate(%T{} = out, tensors, axis) do + tensors + |> maybe_upcast() + |> Enum.map(&from_nx/1) + |> Native.concatenate(axis) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def conv(%T{type: out_type} = out, %T{shape: shape} = tensor, %T{} = kernel, opts) do + # TODO: Support more opts + unsupported_option!(opts, :batch_group_size, 1) + unsupported_option!(opts, :feature_group_size, 1) + + # For now we assume: + # strides = opts[:strides] # [1, 1] + # padding = opts[:padding] # [{0, 0}, {0, 0}] + # input_dilation = opts[:input_dilation] # [1, 1] + # kernel_dilation = opts[:kernel_dilation] # [1, 1] + + input_permutation = opts[:input_permutation] + kernel_permutation = opts[:kernel_permutation] + + output_permutation = + case opts[:output_permutation] do + nil -> + nil + + l -> + # The permutation that Nx.Shape expects is actually the reverse permutation + # for the given input + l |> Enum.with_index() |> Enum.sort() |> Enum.map(&elem(&1, 1)) + end + + native_tensor = + tensor + |> from_nx() + |> permute(input_permutation) + |> Native.to_type(to_candle_dtype(out_type)) + |> unwrap!() + + native_kernel = + kernel + |> from_nx() + |> permute(kernel_permutation) + |> 
Native.to_type(to_candle_dtype(out_type)) + |> unwrap!() + + native_result = + case Nx.rank(shape) do + 3 -> Native.conv1d(native_tensor, native_kernel) + 4 -> Native.conv2d(native_tensor, native_kernel) + rank -> raise("unsupported conv for tensor of rank #{rank}, only 3 or 4 supported") + end + + native_result + |> unwrap!() + |> permute(output_permutation) + |> to_nx(out) + end + + @impl true + def dot( + %T{type: _out_type} = out, + %T{shape: left_shape, type: _left_type} = left, + [left_axis] = _left_axes, + [] = _left_batched_axes, + %T{shape: right_shape, type: _right_type} = right, + [0] = _right_axes, + [] = _right_batched_axes + ) + when tuple_size(left_shape) >= 1 and tuple_size(right_shape) == 1 and + left_axis == tuple_size(left_shape) - 1 do + {left, right} = maybe_upcast(left, right) + + from_nx(left) + |> Native.dot(from_nx(right)) + |> unwrap!() + |> to_nx(out) + end + + def dot( + %T{type: _out_type} = out, + %T{shape: left_shape, type: _left_type} = left, + [1] = _left_axes, + [] = _left_batched_axes, + %T{shape: right_shape, type: _right_type} = right, + [0] = _right_axes, + [] = _right_batched_axes + ) + when tuple_size(left_shape) == 2 and tuple_size(right_shape) == 2 do + {left, right} = maybe_upcast(left, right) + + Native.matmul( + from_nx(left), + from_nx(right) + ) + |> unwrap!() + |> to_nx(out) + end + + def dot( + out, + %T{shape: left_shape} = left, + [0], + left_batched_axes, + right, + right_axes, + right_batched_axes + ) + when tuple_size(left_shape) == 2 do + dot( + out, + left |> Nx.transpose(axes: [1, 0]), + [1], + left_batched_axes, + right, + right_axes, + right_batched_axes + ) + end + + def dot( + out, + left, + left_axes, + left_batched_axes, + %T{shape: right_shape} = right, + [1], + right_batched_axes + ) + when tuple_size(right_shape) == 2 do + dot( + out, + left, + left_axes, + left_batched_axes, + right |> Nx.transpose(axes: [1, 0]), + [0], + right_batched_axes + ) + end + + # Shape + + @impl true + def broadcast(out, %T{} = t, shape, axes) do + t + |> maybe_reshape(shape, axes) + |> from_nx() + |> Native.broadcast_to(shape) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def pad(%T{} = out, %T{} = _t, _pad_value, []) do + out + end + + def pad(%T{} = out, %T{} = t, %T{shape: {}} = pad_value, [{low, high, 0 = _inner}]) do + if !Nx.equal(pad_value, 0) do + raise "only pad_value=0 supported for now" + end + + t + |> from_nx() + |> Native.pad_with_zeros(low, high) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def reshape(%T{shape: shape} = out, %T{} = t) do + from_nx(t) + |> Native.reshape(shape) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def squeeze(%T{} = out, %T{} = t, axes) do + # sort the axes desc so we don't have to decrease the axis numbers after each squeeze + for axis <- Enum.sort(axes, :desc), reduce: from_nx(t) do + ref -> + ref + |> Native.squeeze(axis) + |> unwrap!() + end + |> to_nx(out) + end + + @impl true + def transpose(out, %T{} = t, axes) do + from_nx(t) + |> Native.permute(axes) + |> unwrap!() + |> to_nx(out) + end + + # Type + + @impl true + def as_type(%T{type: type} = out, %T{} = t) do + from_nx(t) + |> Native.to_type(to_candle_dtype(type)) + |> unwrap!() + |> to_nx(out) + end + + @impl true + def bitcast(out, tensor) do + out + |> from_binary(to_binary(tensor), []) + end + + # Inspect + + @impl true + def inspect(%T{} = tensor, inspect_opts) do + limit = if inspect_opts.limit == :infinity, do: :infinity, else: inspect_opts.limit + 1 + + tensor + |> to_binary(min(limit, Nx.size(tensor))) + |> 
then(&Nx.Backend.inspect(tensor, &1, inspect_opts)) + |> maybe_add_signature(tensor) + end + + defp maybe_add_signature(result, %T{data: %__MODULE__{device: device, resource: ref}}) + when is_reference(ref) do + Inspect.Algebra.concat([ + "Candlex.Backend(#{device})", + Inspect.Algebra.line(), + result + ]) + end + + defp narrow(t, [start | starts], [length | lengths], axis, shape) do + dim = elem(shape, axis) + start = min(start, dim - length) + + if start == 0 and length == dim do + # Nothing to narrow at this step + t + else + t + |> Native.narrow(axis, start, length) + |> unwrap!() + end + |> narrow(starts, lengths, axis + 1, shape) + end + + defp narrow(t, [], [], _axis, _shape), do: t + + defp maybe_reshape(%T{shape: {}} = t, target_shape, _axes) do + shape = + 1 + |> List.duplicate(tuple_size(target_shape)) + |> List.to_tuple() + + t + |> Nx.reshape(shape) + end + + defp maybe_reshape(%T{shape: shape} = t, target_shape, axes) do + base_broadcast_shape = 1 |> List.duplicate(tuple_size(target_shape)) |> List.to_tuple() + + new_shape = + shape + |> Tuple.to_list() + |> Enum.zip(axes) + |> Enum.reduce(base_broadcast_shape, fn {dim_size, target_axis}, shape_acc -> + shape_acc + |> Tuple.delete_at(target_axis) + |> Tuple.insert_at(target_axis, dim_size) + end) + + t + |> Nx.reshape(new_shape) + end + + defp maybe_upcast(%T{type: t} = left, %T{type: t} = right) do + {left, right} + end + + defp maybe_upcast(left, right) do + type = Nx.Type.merge(left.type, right.type) + + {Nx.as_type(left, type), Nx.as_type(right, type)} + end + + defp maybe_upcast([first | _] = tensors) do + type = + tensors + |> Enum.reduce( + first.type, + fn tensor, type -> + Nx.Type.merge(type, tensor.type) + end + ) + + tensors + |> Enum.map(fn tensor -> + Nx.as_type(tensor, type) + end) + end + + defp maybe_broadcast_bin_args(out_shape, l, r) do + { + case l.shape do + ^out_shape -> + from_nx(l) + + _ -> + l |> from_nx() |> Native.broadcast_to(out_shape) |> unwrap!() + end, + case r.shape do + ^out_shape -> from_nx(r) + _ -> r |> from_nx() |> Native.broadcast_to(out_shape) |> unwrap!() + end + } + end + + defp maybe_transfer_device( + %T{data: %__MODULE__{device: device}} = l, + %T{data: %__MODULE__{device: device}} = r + ) do + {l, r} + end + + defp maybe_transfer_device( + %T{data: %__MODULE__{device: device}} = l, + %T{data: %__MODULE__{device: _other_device}} = r + ) do + { + l, + r |> Nx.backend_transfer({__MODULE__, device: device}) + } + end + + defp maybe_transfer_device(%T{} = l, %T{data: %__MODULE__{device: device}} = r) do + { + l |> Nx.backend_transfer({__MODULE__, device: device}), + r + } + end + + defp maybe_transfer_device(%T{data: %__MODULE__{device: device}} = l, %T{} = r) do + { + l, + r |> Nx.backend_transfer({__MODULE__, device: device}) + } + end + + ## Conversions + + @impl true + def to_batched(%T{shape: out_shape} = out, %T{shape: shape} = t, opts) do + leftover = opts[:leftover] + first_dimension = 0 + batch_size = elem(out_shape, first_dimension) + axis_total = elem(shape, first_dimension) + remainder = rem(axis_total, batch_size) + num_batches = div(axis_total, batch_size) + native_tensor = from_nx(t) + + cond do + remainder == 0 -> + native_tensor + |> Native.chunk(num_batches) + |> unwrap!() + + remainder > 0 && leftover == :repeat -> + [ + native_tensor, + Native.narrow(native_tensor, first_dimension, 0, batch_size - remainder) + |> unwrap!() + ] + |> Native.concatenate(first_dimension) + |> unwrap!() + |> Native.chunk(num_batches + 1) + |> unwrap!() + + true -> + raise "not 
implemented" + end + |> Stream.map(&to_nx(&1, out)) + end + + for op <- [ + :cholesky, + :conjugate, + :count_leading_zeros, + :imag, + :population_count, + :real + ] do + @impl true + def unquote(op)(_out, _tensor) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end + end + + for op <- [ + :argsort, + :eigh, + :fft, + :ifft, + :lu, + :product, + :qr, + :reverse, + :sort + ] do + @impl true + def unquote(op)(_out, _tensor, _) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end + end + + for op <- [ + :indexed_put, + :map, + :triangular_solve, + :window_max, + :window_min, + :window_product, + :window_sum + ] do + @impl true + def unquote(op)(_out, _tensor, _, _) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end + end + + @impl true + def reduce(_out, _tensor, _, _, _) do + raise "unsupported Candlex.Backend.reduce function" + end + + for op <- [ + :window_reduce, + :window_scatter_max, + :window_scatter_min + ] do + @impl true + def unquote(op)(_out, _tensor, _, _, _, _) do + raise "unsupported Candlex.Backend.#{unquote(op)} function" + end + end + + defp permute(native_tensor, permutation) do + native_tensor + |> Native.permute(permutation) + |> unwrap!() + end + + @doc false + defp from_nx(%T{data: %__MODULE__{} = data}), do: data + + defp from_nx(%T{} = tensor) do + tensor + |> Nx.backend_transfer(__MODULE__) + |> from_nx() + end + + defp to_nx(%__MODULE__{resource: ref} = backend_tensor, %T{type: nx_type, shape: nx_shape} = t) + when is_reference(ref) do + {:ok, candle_dtype} = Native.dtype(backend_tensor) + {:ok, candle_shape} = Native.t_shape(backend_tensor) + + case {nx_type, from_candle_dtype(candle_dtype)} do + {{:u, 64}, {:s, 64}} -> + :ok + + {type, type} -> + :ok + + {type, other_type} -> + raise "tensor type mismatch, Nx (#{inspect(type)}) and Candle (#{inspect(other_type)})" + end + + if nx_shape != candle_shape do + raise "tensor shape mismatch, Nx (#{inspect(nx_shape)}) and Candle (#{inspect(candle_shape)})" + end + + %{t | data: backend_tensor} + end + + defp to_candle_dtype({:s, 8} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:s, 16} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:s, 32} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:s, 64}), do: "i64" + defp to_candle_dtype({:u, 8}), do: "u8" + defp to_candle_dtype({:u, 16} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:u, 32}), do: "u32" + defp to_candle_dtype({:u, 64}), do: "i64" + defp to_candle_dtype({:f, 16}), do: "f16" + defp to_candle_dtype({:f, 32}), do: "f32" + defp to_candle_dtype({:f, 64}), do: "f64" + defp to_candle_dtype({:bf, 16}), do: "bf16" + defp to_candle_dtype({:c, 64} = t), do: unsupported_dtype(t) + defp to_candle_dtype({:c, 128} = t), do: unsupported_dtype(t) + + defp from_candle_dtype("i64"), do: {:s, 64} + defp from_candle_dtype("u8"), do: {:u, 8} + defp from_candle_dtype("u32"), do: {:u, 32} + defp from_candle_dtype("f16"), do: {:f, 16} + defp from_candle_dtype("bf16"), do: {:bf, 16} + defp from_candle_dtype("f32"), do: {:f, 32} + defp from_candle_dtype("f64"), do: {:f, 64} + + defp device_option(nil) do + default_device() + end + + defp device_option(backend_options) do + backend_options[:device] || default_device() + end + + defp default_device do + if cuda_available?() do + @device_cuda + else + @device_cpu + end + end + + defp same_device?(%T{data: %__MODULE__{device: device}}, device) do + true + end + + defp same_device?(_t, _d) do + false + end + + def cuda_available? 
do + Native.is_cuda_available() + end + + defp unsupported_dtype(t) do + raise("Unsupported candle dtype for #{inspect(t)}") + end + + defp unsupported_option!(opts, key, acceptable_default) do + if opts[key] != nil and opts[key] != acceptable_default do + raise "#{inspect(key)} option with #{inspect(opts[key])} is not supported" + end + end + + defp unwrap!({:ok, result}), do: result + defp unwrap!({:error, error}), do: raise("Candlex: #{error}") +end diff --git a/candlex/lib/candlex/native.ex b/candlex/lib/candlex/native.ex new file mode 100644 index 0000000000..902bc89b99 --- /dev/null +++ b/candlex/lib/candlex/native.ex @@ -0,0 +1,130 @@ +defmodule Candlex.Native do + @moduledoc false + + mix_config = Mix.Project.config() + version = mix_config[:version] + source_url = mix_config[:package][:links]["GitHub"] + mode = if Mix.env() in [:dev, :test], do: :debug, else: :release + + use RustlerPrecompiled, + otp_app: :candlex, + features: Application.compile_env(:candlex, :crate_features, []), + base_url: "#{source_url}/releases/download/v#{version}", + force_build: System.get_env("CANDLEX_NIF_BUILD") in ["1", "true"], + mode: mode, + version: version, + nif_versions: ["2.16"], + targets: [ + "aarch64-apple-darwin", + "aarch64-unknown-linux-gnu", + "x86_64-apple-darwin", + "x86_64-unknown-linux-gnu" + ] + + # Rustler will override all the below stub functions with real NIFs + def from_binary(_binary, _dtype, _shape, _device), do: error() + def to_binary(_tensor), do: error() + def all(_tensor), do: error() + def all_within_dims(_tensor, _dims, _keep_dims), do: error() + def any(_tensor), do: error() + def any_within_dims(_tensor, _dims, _keep_dims), do: error() + def where_cond(_tensor, _on_true, _on_false), do: error() + def narrow(_tensor, _dim, _start, _length), do: error() + def gather(_tensor, _indexes, _dim), do: error() + def index_select(_tensor, _indexes, _dim), do: error() + def index_add(_tensor, _indexes, _source, _dim), do: error() + def chunk(_tensor, _num_chunks), do: error() + def squeeze(_tensor, _dim), do: error() + def arange(_start, _end, _dtype, _shape, _device), do: error() + def broadcast_to(_tensor, _shape), do: error() + def reshape(_tensor, _shape), do: error() + def to_type(_tensor, _dtype), do: error() + def dtype(_tensor), do: error() + def t_shape(_tensor), do: error() + def concatenate(_tensors, _axis), do: error() + def conv1d(_tensor, _kernel), do: error() + def conv2d(_tensor, _kernel), do: error() + def slice_scatter(_tensor, _src, _dim, _start), do: error() + def pad_with_zeros(_tensor, _left, _right), do: error() + def clamp(_tensor, _min, _max), do: error() + + for op <- [ + :abs, + :acos, + :acosh, + :asin, + :asinh, + :atan, + :atanh, + :bitwise_not, + :cbrt, + :ceil, + :cos, + :cosh, + :erf, + :erfc, + :erf_inv, + :exp, + :expm1, + :floor, + :is_infinity, + :is_nan, + :log, + :log1p, + :negate, + :round, + :rsqrt, + :sigmoid, + :sign, + :sin, + :sinh, + :sqrt, + :tan, + :tanh + ] do + def unquote(op)(_tensor), do: error() + end + + for op <- [ + :add, + :atan2, + :bitwise_and, + :bitwise_or, + :bitwise_xor, + :divide, + :dot, + :equal, + :greater, + :greater_equal, + :left_shift, + :less, + :less_equal, + :logical_and, + :logical_or, + :logical_xor, + :matmul, + :max, + :min, + :multiply, + :not_equal, + :pow, + :quotient, + :remainder, + :right_shift, + :subtract + ] do + def unquote(op)(_left, _right), do: error() + end + + def sum(_tensor, _dims, _keep_dims), do: error() + def permute(_tensor, _dims), do: error() + + for op <- [:argmax, :argmin, 
:reduce_max, :reduce_min] do + def unquote(op)(_tensor, _dim, _keep_dim), do: error() + end + + def is_cuda_available(), do: error() + def to_device(_tensor, _device), do: error() + + defp error(), do: :erlang.nif_error(:nif_not_loaded) +end diff --git a/candlex/mix.exs b/candlex/mix.exs new file mode 100644 index 0000000000..89774a2cd6 --- /dev/null +++ b/candlex/mix.exs @@ -0,0 +1,74 @@ +defmodule Candlex.MixProject do + use Mix.Project + + @description "An Nx backend for candle machine learning minimalist framework" + @source_url "https://github.com/mimiquate/candlex" + @version "0.1.2" + + def project do + [ + app: :candlex, + description: @description, + version: @version, + elixir: "~> 1.14", + elixirc_paths: elixirc_paths(Mix.env()), + start_permanent: Mix.env() == :prod, + deps: deps(), + docs: docs(), + package: package() + ] + end + + # Run "mix help compile.app" to learn about applications. + def application do + [ + extra_applications: [:logger] + ] + end + + defp elixirc_paths(:test), do: ["lib", "test/support"] + defp elixirc_paths(_), do: ["lib"] + + # Run "mix help deps" to learn about dependencies. + defp deps do + [ + # {:nx, "~> 0.6.2"}, + {:nx, git: "https://github.com/elixir-nx/nx", sparse: "nx"}, + {:rustler_precompiled, "~> 0.7.0"}, + + # Optional + {:rustler, "~> 0.29", optional: true}, + + # Dev + {:ex_doc, "~> 0.30.9", only: :dev, runtime: false} + ] + end + + defp docs do + [ + main: "Candlex", + source_url: @source_url, + source_ref: "v#{@version}" + ] + end + + defp package do + [ + files: [ + "lib", + "native", + "priv", + ".formatter.exs", + "mix.exs", + "CHANGELOG.md", + "README.md", + "LICENSE", + "checksum-*.exs" + ], + licenses: ["Apache-2.0"], + links: %{ + "GitHub" => @source_url + } + ] + end +end diff --git a/candlex/mix.lock b/candlex/mix.lock new file mode 100644 index 0000000000..37478055ac --- /dev/null +++ b/candlex/mix.lock @@ -0,0 +1,16 @@ +%{ + "castore": {:hex, :castore, "1.0.4", "ff4d0fb2e6411c0479b1d965a814ea6d00e51eb2f58697446e9c41a97d940b28", [:mix], [], "hexpm", "9418c1b8144e11656f0be99943db4caf04612e3eaecefb5dae9a2a87565584f8"}, + "complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"}, + "earmark_parser": {:hex, :earmark_parser, "1.4.37", "2ad73550e27c8946648b06905a57e4d454e4d7229c2dafa72a0348c99d8be5f7", [:mix], [], "hexpm", "6b19783f2802f039806f375610faa22da130b8edc21209d0bff47918bb48360e"}, + "ex_doc": {:hex, :ex_doc, "0.30.9", "d691453495c47434c0f2052b08dd91cc32bc4e1a218f86884563448ee2502dd2", [:mix], [{:earmark_parser, "~> 1.4.31", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "d7aaaf21e95dc5cddabf89063327e96867d00013963eadf2c6ad135506a8bc10"}, + "jason": {:hex, :jason, "1.4.1", "af1504e35f629ddcdd6addb3513c3853991f694921b1b9368b0bd32beb9f1b63", [:mix], [{:decimal, "~> 1.0 or ~> 2.0", [hex: :decimal, repo: "hexpm", optional: true]}], "hexpm", "fbb01ecdfd565b56261302f7e1fcc27c4fb8f32d56eab74db621fc154604a7a1"}, + "makeup": {:hex, :makeup, "1.1.0", "6b67c8bc2882a6b6a445859952a602afc1a41c2e08379ca057c0f525366fc3ca", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "0a45ed501f4a8897f580eabf99a2e5234ea3e75a4373c8a52824f6e873be57a6"}, + "makeup_elixir": 
{:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"}, + "makeup_erlang": {:hex, :makeup_erlang, "0.1.2", "ad87296a092a46e03b7e9b0be7631ddcf64c790fa68a9ef5323b6cbb36affc72", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f3f5a1ca93ce6e092d92b6d9c049bcda58a3b617a8d888f8e7231c85630e8108"}, + "nimble_parsec": {:hex, :nimble_parsec, "1.3.1", "2c54013ecf170e249e9291ed0a62e5832f70a476c61da16f6aac6dca0189f2af", [:mix], [], "hexpm", "2682e3c0b2eb58d90c6375fc0cc30bc7be06f365bf72608804fb9cffa5e1b167"}, + "nx": {:git, "https://github.com/elixir-nx/nx", "7706e8601e40916c02f8773df7802b3bfab43054", [sparse: "nx"]}, + "rustler": {:hex, :rustler, "0.30.0", "cefc49922132b072853fa9b0ca4dc2ffcb452f68fb73b779042b02d545e097fb", [:mix], [{:jason, "~> 1.0", [hex: :jason, repo: "hexpm", optional: false]}, {:toml, "~> 0.6", [hex: :toml, repo: "hexpm", optional: false]}], "hexpm", "9ef1abb6a7dda35c47cfc649e6a5a61663af6cf842a55814a554a84607dee389"}, + "rustler_precompiled": {:hex, :rustler_precompiled, "0.7.0", "5d0834fc06dbc76dd1034482f17b1797df0dba9b491cef8bb045fcaca94bcade", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: false]}, {:rustler, "~> 0.23", [hex: :rustler, repo: "hexpm", optional: true]}], "hexpm", "fdf43a6835f4e4de5bfbc4c019bfb8c46d124bd4635fefa3e20d9a2bbbec1512"}, + "telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"}, + "toml": {:hex, :toml, "0.7.0", "fbcd773caa937d0c7a02c301a1feea25612720ac3fa1ccb8bfd9d30d822911de", [:mix], [], "hexpm", "0690246a2478c1defd100b0c9b89b4ea280a22be9a7b313a8a058a2408a2fa70"}, +} diff --git a/candlex/native/candlex/.cargo/config.toml b/candlex/native/candlex/.cargo/config.toml new file mode 100644 index 0000000000..1602afb085 --- /dev/null +++ b/candlex/native/candlex/.cargo/config.toml @@ -0,0 +1,9 @@ +[target.'cfg(target_os = "macos")'] +rustflags = [ + "-C", "link-arg=-undefined", + "-C", "link-arg=dynamic_lookup", +] + +# Provides a small build size, but takes more time to build. +[profile.release] +lto = true diff --git a/candlex/native/candlex/.gitignore b/candlex/native/candlex/.gitignore new file mode 100644 index 0000000000..ea8c4bf7f3 --- /dev/null +++ b/candlex/native/candlex/.gitignore @@ -0,0 +1 @@ +/target diff --git a/candlex/native/candlex/Cargo.lock b/candlex/native/candlex/Cargo.lock new file mode 100644 index 0000000000..bc1f3aa8a7 --- /dev/null +++ b/candlex/native/candlex/Cargo.lock @@ -0,0 +1,1009 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "addr2line" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a30b2e23b9e17a9f90641c7ab1549cd9b44f296d3ccbf309d2863cfe398a0cb" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "aho-corasick" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea5d730647d4fadd988536d06fecce94b7b4f2a7efdae548f1cf4b63205518ab" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" +dependencies = [ + "backtrace", +] + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "backtrace" +version = "0.3.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2089b7e3f35b9dd2d0ed921ead4f6d318c27680d4a5bd167b3ee120edb105837" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bytemuck" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "candle-core" +version = "0.3.0" +source = "git+https://github.com/huggingface/candle#4c967b9184834cd1e166dfdd6d88450d16bad8f2" +dependencies = [ + "byteorder", + "candle-kernels", + "cudarc", + "gemm", + "half", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "yoke", + "zip", +] + +[[package]] +name = "candle-kernels" +version = "0.3.0" +source = "git+https://github.com/huggingface/candle#4c967b9184834cd1e166dfdd6d88450d16bad8f2" +dependencies = [ + "anyhow", + "glob", + "rayon", +] + +[[package]] +name = "candlex" +version = "0.1.0" +dependencies = [ + "anyhow", + "candle-core", + "half", + "num-traits", + "rustler", + "statrs", + "thiserror", +] + +[[package]] +name = "cc" +version = "1.0.83" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] + +[[package]] +name = 
"cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "cudarc" +version = "0.9.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc4cab390f4a32340211f015292a4551742a63e528e9ade9e0bde0d1a989d2a1" +dependencies = [ + "half", +] + +[[package]] +name = "dyn-stack" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" +dependencies = [ + "bytemuck", + "reborrow", +] + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "gemm" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c90e866771c2adff6c4603108bb531ea36979b62a23b5e193b9311199182ccc9" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d5ffc8e9a442172ec223d4e536f93b59f1d2b34164b37e462e6886d76699e71" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45cd8172c8ee8beed35cdb7254bec535ddc83cde78d2eeefeae43cfd57fef3d8" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c022dd4893afa781d1ba9912796af63a9a91ac8d8fdadd449483bbabaf7c47c6" +dependencies = [ + "bytemuck", + "dyn-stack", + "half", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f16" 
+version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1efb4c26de6ddd1125d966dffa48eb609448566a2bd2dd589ea2798f17f689b5" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a78fca5a0ef41202015e099e45e0bdda83e7f49318de3322d58c215a3fec334a" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.16.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f18516a29ab14fa4b1714c93ffa36a89a65418f0a2fba702206d4b421061b48" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "gimli" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fb8d784f27acf97159b40fc4db5ecd8aa23b9ad5ef69cdd136d3bc80665f0c0" + +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" + +[[package]] +name = "itoa" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.148" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdc71e17332e86d2e1d38c1f99edcb6288ee11b815fb1a4b049eaa2114d369b" + +[[package]] +name = "libm" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7012b1bbb0719e1097c47611d3898568c546d597c2e74d66f6087edd5233ff4" + +[[package]] +name = "matrixmultiply" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7574c1cf36da4798ab73da5b215bbf444f50718207754cb522201d78d1cd0ff2" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name = "memchr" +version = "2.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" + 
+[[package]] +name = "memmap2" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", + "stable_deref_trait", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] +name = "nalgebra" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d506eb7e08d6329505faa8a3a00a5dcc6de9f76e0c77e4b75763ae3c770831ff" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "rand", + "rand_distr", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01fcc0b8149b4632adc89ac3b7b31a12fb6099a0317a4eb2ebff574ef7de7218" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "bytemuck", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "object" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cf5f9dd3933bd50a9e1f149ec995f39ae2c496d31fd772c1fd45ebc27e902b0" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.67" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d433d9f1a3e8c1263d9456598b16fec66f4acc9a74dacffd35c7bb09b3a1328" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pulp" +version = "0.16.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcb0a2934e77dd9aabb33fedcd3bf6ea1d6d3cacd4c5bc07b9a916f7b101c6b7" +dependencies = [ + "bytemuck", + "libm", + "num-complex", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + +[[package]] +name = "regex" +version = "1.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebee201405406dbf528b8b672104ae6d6d63e6d118cb10e4d51abbc7b58044ff" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"dbb5fb1acd8a1a18b3dd5be62d25485eb770e05afb408a9627d14d451bae12da" + +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + +[[package]] +name = "rustler" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4b4fea69e23de68c42c06769d6624d2d018da550c17244dd4b691f90ced4a7e" +dependencies = [ + "lazy_static", + "rustler_codegen", + "rustler_sys", +] + +[[package]] +name = "rustler_codegen" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "406061bd07aaf052c344257afed4988c5ec8efe4d2352b4c2cf27ea7c8575b12" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "rustler_sys" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7c0740e5322b64e2b952d8f0edce5f90fcf6f6fe74cca3f6e78eb3de5ea858" +dependencies = [ + "regex", + "unreachable", +] + +[[package]] +name = "ryu" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" + +[[package]] +name = "safe_arch" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f398075ce1e6a179b46f51bd88d0598b92b00d3551f1a2d4ac49e771b56ac354" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + +[[package]] +name = "serde" +version = "1.0.188" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.188" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "serde_json" +version = "1.0.107" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "simba" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0b7840f121a46d63066ee7a99fc81dcabbc6105e437cae43528cea199b5a05f" +dependencies = [ + "approx", + "num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "statrs" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"2d08e5e1748192713cc281da8b16924fb46be7b0c2431854eadc785823e5696e" +dependencies = [ + "approx", + "lazy_static", + "nalgebra", + "num-traits", + "rand", +] + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7303ef2c05cd654186cb250d29049a24840ca25d2747c25c0381c8d9e2f582e8" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "285ba80e733fac80aa4270fbcdf83772a79b80aa35c97075320abfee4a915b06" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "unicode-xid", +] + +[[package]] +name = "thiserror" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.50" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "unicode-xid" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" + +[[package]] +name = "unreachable" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "382810877fe448991dfc7f0dd6e3ae5d58088fd0ea5e35189655f84e6814fa56" +dependencies = [ + "void", +] + +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wide" +version = "0.7.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebecebefc38ff1860b4bc47550bbfa63af5746061cf0d29fcd7fa63171602598" +dependencies = [ + "bytemuck", + "safe_arch", +] + +[[package]] +name = "yoke" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61e38c508604d6bbbd292dadb3c02559aa7fff6b654a078a36217cad871636e4" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5e19fb6ed40002bab5403ffa37e53e0e56f914a4450c8765f533018db1db35f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "synstructure", +] + +[[package]] 
+name = "zerofrom" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655b0814c5c0b19ade497851070c640773304939a6c0fd5f5fb43da0696d05b7" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.37", + "synstructure", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", +] diff --git a/candlex/native/candlex/Cargo.toml b/candlex/native/candlex/Cargo.toml new file mode 100644 index 0000000000..306760b668 --- /dev/null +++ b/candlex/native/candlex/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "candlex" +version = "0.1.0" +authors = [] +edition = "2021" + +[lib] +name = "candlex" +path = "src/lib.rs" +crate-type = ["cdylib"] + +[dependencies] +candle-core = { git = "https://github.com/huggingface/candle" } +half = "2.3.1" +num-traits = "0.2.17" +rustler = { version = "0.30.0", default-features = false, features = ["derive", "nif_version_2_16"] } +statrs = "0.16.0" +thiserror = "1.0.50" + +[build-dependencies] +anyhow = "1.0.75" + +[features] +cuda = ["candle-core/cuda"] diff --git a/candlex/native/candlex/build.rs b/candlex/native/candlex/build.rs new file mode 100644 index 0000000000..86a9c6183c --- /dev/null +++ b/candlex/native/candlex/build.rs @@ -0,0 +1,250 @@ +#![allow(unused)] +use anyhow::{Context, Result}; +use std::io::Write; +use std::path::PathBuf; + +struct KernelDirectories { + kernel_dir: &'static str, + rust_target: &'static str, + include_dirs: &'static [&'static str], +} + +const DIRS: [KernelDirectories; 1] = [KernelDirectories { + kernel_dir: "src/kernels/", + rust_target: "src/kernels.rs", + include_dirs: &[], +}]; + +impl KernelDirectories { + fn maybe_build_ptx( + &self, + cu_file: &std::path::Path, + ptx_file: &std::path::Path, + compute_cap: usize, + ) -> Result<()> { + let should_compile = if ptx_file.exists() { + cu_file + .metadata()? + .modified()? + .duration_since(ptx_file.metadata()?.modified()?) + .is_ok() + } else { + true + }; + + if should_compile { + #[cfg(feature = "cuda")] + { + let mut command = std::process::Command::new("nvcc"); + let out_dir = ptx_file.parent().context("no parent for ptx file")?; + let include_dirs: Vec = + self.include_dirs.iter().map(|c| format!("-I{c}")).collect(); + + command + .arg(format!("--gpu-architecture=sm_{compute_cap}")) + .arg("--ptx") + .args(["--default-stream", "per-thread"]) + .args(["--output-directory", out_dir.to_str().unwrap()]) + .arg(format!("-I/{}", self.kernel_dir)) + .args(include_dirs) + .arg(cu_file); + + let output = command + .spawn() + .context("failed spawning nvcc")? 
+            #[cfg(not(feature = "cuda"))]
+            std::fs::OpenOptions::new()
+                .create(true)
+                .write(true)
+                .open(ptx_file)?;
+        }
+
+        Ok(())
+    }
+
+    fn process(&self, out_dir: &std::path::Path, compute_cap: usize) -> Result<()> {
+        println!("cargo:rerun-if-changed={}", self.kernel_dir);
+        let kernel_dir = PathBuf::from(self.kernel_dir);
+        let out_dir = out_dir.join(self.kernel_dir);
+        if !out_dir.exists() {
+            std::fs::create_dir_all(&out_dir)?;
+        }
+        let mut cu_files = vec![];
+        let mut cuh_files = vec![];
+        for file in std::fs::read_dir(kernel_dir)?.flatten() {
+            let file = file.path();
+            match file.extension().and_then(|v| v.to_str()) {
+                Some("cu") => cu_files.push(file),
+                Some("cuh") => cuh_files.push(file),
+                _ => {}
+            }
+        }
+
+        let mut ptx_paths = vec![];
+        for cu_file in cu_files.iter() {
+            let file_stem = cu_file
+                .file_stem()
+                .with_context(|| format!("no stem {cu_file:?}"))?;
+            let file_stem = file_stem.to_string_lossy().into_owned();
+            let ptx_file = out_dir.join(&format!("{file_stem}.ptx"));
+            self.maybe_build_ptx(cu_file, &ptx_file, compute_cap)?;
+            ptx_paths.push(ptx_file);
+        }
+
+        let regenerate_rs_file = true;
+        if regenerate_rs_file {
+            let mut file = std::fs::File::create(self.rust_target)?;
+            for ptx_path in ptx_paths {
+                let name = ptx_path
+                    .file_stem()
+                    .context("empty stem")?
+                    .to_string_lossy();
+                file.write_all(b"#[rustfmt::skip]\n")?;
+                let const_definition = format!(
+                    r#"pub const {}: &str = include_str!(concat!(env!("OUT_DIR"), "/{}/{name}.ptx"));"#,
+                    name.to_uppercase().replace('.', "_"),
+                    self.kernel_dir,
+                );
+                file.write_all(const_definition.as_bytes())?;
+                file.write_all(b"\n")?;
+            }
+        }
+
+        Ok(())
+    }
+}
+
+fn main() -> Result<()> {
+    println!("cargo:rerun-if-changed=build.rs");
+
+    let out_dir = std::env::var("OUT_DIR").context("OUT_DIR not set")?;
+    let out_dir = PathBuf::from(out_dir);
+    #[cfg(feature = "cuda")]
+    set_cuda_include_dir()?;
+    #[cfg(feature = "cuda")]
+    let compute_cap = compute_cap()?;
+    #[cfg(not(feature = "cuda"))]
+    let compute_cap = 0;
+
+    for dir in DIRS {
+        dir.process(&out_dir, compute_cap)?
+    }
+
+    Ok(())
+}
+
+fn set_cuda_include_dir() -> Result<()> {
+    // NOTE: copied from cudarc build.rs.
+    let env_vars = [
+        "CUDA_PATH",
+        "CUDA_ROOT",
+        "CUDA_TOOLKIT_ROOT_DIR",
+        "CUDNN_LIB",
+    ];
+    let env_vars = env_vars
+        .into_iter()
+        .map(std::env::var)
+        .filter_map(Result::ok)
+        .map(Into::<PathBuf>::into);
+
+    let roots = [
+        "/usr",
+        "/usr/local/cuda",
+        "/opt/cuda",
+        "/usr/lib/cuda",
+        "C:/Program Files/NVIDIA GPU Computing Toolkit",
+        "C:/CUDA",
+    ];
+    let roots = roots.into_iter().map(Into::<PathBuf>::into);
+    let root = env_vars
+        .chain(roots)
+        .find(|path| path.join("include").join("cuda.h").is_file())
+        .context("cannot find include/cuda.h")?;
+    println!(
+        "cargo:rustc-env=CUDA_INCLUDE_DIR={}",
+        root.join("include").display()
+    );
+    Ok(())
+}
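+// Added note (not an upstream comment): of the detection steps below, the
+// CUDA_COMPUTE_CAP env var is consulted last, so it overrides both the
+// nvidia-smi probe and the capping against the local nvcc's targets.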
+#[allow(unused)]
+fn compute_cap() -> Result<usize> {
+    // Grab compute code from nvidia-smi
+    let mut compute_cap = {
+        let out = std::process::Command::new("nvidia-smi")
+            .arg("--query-gpu=compute_cap")
+            .arg("--format=csv")
+            .output()
+            .context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
+        let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
+        let mut lines = out.lines();
+        assert_eq!(
+            lines.next().context("missing line in stdout")?,
+            "compute_cap"
+        );
+        let cap = lines
+            .next()
+            .context("missing line in stdout")?
+            .replace('.', "");
+        cap.parse::<usize>()
+            .with_context(|| format!("cannot parse as int {cap}"))?
+    };
+
+    // Grab available GPU codes from nvcc and select the highest one
+    let max_nvcc_code = {
+        let out = std::process::Command::new("nvcc")
+            .arg("--list-gpu-code")
+            .output()
+            .expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
+        let out = std::str::from_utf8(&out.stdout).unwrap();
+
+        let out = out.lines().collect::<Vec<&str>>();
+        let mut codes = Vec::with_capacity(out.len());
+        for code in out {
+            let code = code.split('_').collect::<Vec<&str>>();
+            if !code.is_empty() && code.contains(&"sm") {
+                if let Ok(num) = code[1].parse::<usize>() {
+                    codes.push(num);
+                }
+            }
+        }
+        codes.sort();
+        if !codes.contains(&compute_cap) {
+            anyhow::bail!(
+                "nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
+            );
+        }
+        *codes.last().unwrap()
+    };
+
+    // If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
+    // then choose the highest gpu code in nvcc
+    if compute_cap > max_nvcc_code {
+        println!(
+            "cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
+        );
+        compute_cap = max_nvcc_code;
+    }
+
+    println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
+
+    if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
+        compute_cap = compute_cap_str
+            .parse::<usize>()
+            .with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
+        println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
+    }
+    println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
+    Ok(compute_cap)
+}
diff --git a/candlex/native/candlex/src/devices.rs b/candlex/native/candlex/src/devices.rs
new file mode 100644
index 0000000000..43a615a328
--- /dev/null
+++ b/candlex/native/candlex/src/devices.rs
@@ -0,0 +1,4 @@
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn is_cuda_available() -> bool {
+    candle_core::utils::cuda_is_available()
+}
diff --git a/candlex/native/candlex/src/error.rs b/candlex/native/candlex/src/error.rs
new file mode 100644
index 0000000000..6cdccceb38
--- /dev/null
+++ b/candlex/native/candlex/src/error.rs
@@ -0,0 +1,21 @@
+use rustler::{Encoder, Env, Term};
+use thiserror::Error;
+
+// Defines the atoms for each value of CandlexError.
+rustler::atoms! {
+    candle,
+}
+
+#[derive(Error, Debug)]
+pub enum CandlexError {
+    #[error("Candle Error: {0}")]
+    Candle(#[from] candle_core::Error),
+    #[error("Generic Error: {0}")]
+    Other(String),
+}
+
+impl Encoder for CandlexError {
+    fn encode<'b>(&self, env: Env<'b>) -> Term<'b> {
+        format!("{self}").encode(env)
+    }
+}
diff --git a/candlex/native/candlex/src/kernels.rs b/candlex/native/candlex/src/kernels.rs
new file mode 100644
index 0000000000..13317b35c7
--- /dev/null
+++ b/candlex/native/candlex/src/kernels.rs
@@ -0,0 +1,4 @@
+#[rustfmt::skip]
+pub const CUSTOM_BINARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_binary.ptx"));
+#[rustfmt::skip]
+pub const CUSTOM_UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/src/kernels//custom_unary.ptx"));
diff --git a/candlex/native/candlex/src/kernels/custom_binary.cu b/candlex/native/candlex/src/kernels/custom_binary.cu
new file mode 100644
index 0000000000..d9d9090d73
--- /dev/null
+++ b/candlex/native/candlex/src/kernels/custom_binary.cu
@@ -0,0 +1,111 @@
+#include <cmath>
+#include <stdint.h>
+#include "strides.cuh"
+
+#define DEVICE_FN_FLOAT_WRAPPER(FN_NAME) \
+  __device__ __forceinline__ float FN_NAME##g(float a, float b) { return FN_NAME##f(a, b); }
+
+#define DEVICE_FN_DOUBLE_WRAPPER(FN_NAME) \
+  __device__ __forceinline__ double FN_NAME##g(double a, double b) { return FN_NAME(a, b); }
+
+DEVICE_FN_FLOAT_WRAPPER(atan2)
+DEVICE_FN_DOUBLE_WRAPPER(atan2)
+DEVICE_FN_FLOAT_WRAPPER(fmod)
+DEVICE_FN_DOUBLE_WRAPPER(fmod)
+DEVICE_FN_FLOAT_WRAPPER(pow)
+DEVICE_FN_DOUBLE_WRAPPER(pow)
+
+#define CUSTOM_BINARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \
+extern "C" __global__ void FN_NAME( \
+  const size_t numel, \
+  const size_t num_dims, \
+  const size_t *dims_and_strides, \
+  const TYPENAME *lhs, \
+  const TYPENAME *rhs, \
+  OUT_TYPENAME *out \
+) { \
+  const size_t *dims = dims_and_strides; \
+  const size_t *lhs_strides = dims_and_strides + 1 * num_dims; \
+  const size_t *rhs_strides = dims_and_strides + 2 * num_dims; \
+  bool lhs_cont = is_contiguous(num_dims, dims, lhs_strides); \
+  bool rhs_cont = is_contiguous(num_dims, dims, rhs_strides); \
+  if (lhs_cont && rhs_cont) { \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+      TYPENAME x = lhs[i]; \
+      TYPENAME y = rhs[i]; \
+      out[i] = FUNC; \
+    } \
+  } else if (lhs_cont) { \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+      unsigned int tmp_i = i; \
+      unsigned int rhs_i = 0; \
+      for (int d = num_dims - 1; d >= 0; d--) { \
+        unsigned int i_dim = tmp_i % dims[d]; \
+        rhs_i += i_dim * rhs_strides[d]; \
+        tmp_i /= dims[d]; \
+      } \
+      TYPENAME x = lhs[i]; \
+      TYPENAME y = rhs[rhs_i]; \
+      out[i] = FUNC; \
+    } \
+  } else if (rhs_cont) { \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+      unsigned int tmp_i = i; \
+      unsigned int lhs_i = 0; \
+      for (int d = num_dims - 1; d >= 0; d--) { \
+        unsigned int i_dim = tmp_i % dims[d]; \
+        lhs_i += i_dim * lhs_strides[d]; \
+        tmp_i /= dims[d]; \
+      } \
+      TYPENAME x = lhs[lhs_i]; \
+      TYPENAME y = rhs[i]; \
+      out[i] = FUNC; \
+    } \
+  } else { \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+      unsigned int tmp_i = i; \
+      unsigned int lhs_i = 0; \
+      unsigned int rhs_i = 0; \
+      for (int d = num_dims - 1; d >= 0; d--) { \
+        unsigned int i_dim = tmp_i % dims[d]; \
+        lhs_i += i_dim * lhs_strides[d]; \
+        rhs_i += i_dim * rhs_strides[d]; \
+        tmp_i /= dims[d]; \
+      } \
+      TYPENAME x = lhs[lhs_i]; \
+      TYPENAME y = rhs[rhs_i]; \
+      out[i] = FUNC; \
+    } \
+  } \
+} \
+
+#define CUSTOM_BINARY_OP(TYPENAME, FN_NAME, FUNC) \
+  CUSTOM_BINARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC)
+
+CUSTOM_BINARY_OP(float, atan2_f32, atan2g(x, y))
+CUSTOM_BINARY_OP(double, atan2_f64, atan2g(x, y))
+CUSTOM_BINARY_OP(uint32_t, bit_and_u32, x & y)
+CUSTOM_BINARY_OP(int64_t, bit_and_i64, x & y)
+CUSTOM_BINARY_OP(uint32_t, bit_or_u32, x | y)
+CUSTOM_BINARY_OP(int64_t, bit_or_i64, x | y)
+CUSTOM_BINARY_OP(uint32_t, bit_xor_u32, x ^ y)
+CUSTOM_BINARY_OP(int64_t, bit_xor_i64, x ^ y)
+CUSTOM_BINARY_OP(float, pow_f32, powg(x, y))
+CUSTOM_BINARY_OP(double, pow_f64, powg(x, y))
+CUSTOM_BINARY_OP(uint8_t, remainder_u8, x % y)
+CUSTOM_BINARY_OP(int64_t, remainder_i64, x % y)
+CUSTOM_BINARY_OP(float, remainder_f32, fmodg(x, y))
+CUSTOM_BINARY_OP(double, remainder_f64, fmodg(x, y))
+CUSTOM_BINARY_OP(uint32_t, shl_u32, x << y)
+CUSTOM_BINARY_OP(int64_t, shl_i64, x << y)
+CUSTOM_BINARY_OP(uint32_t, shr_u32, x >> y)
+CUSTOM_BINARY_OP(int64_t, shr_i64, x >> y)
+
+CUSTOM_BINARY_OP_OUT(uint8_t, uint8_t, logical_and_u8, x && y)
+CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_and_i64, x && y)
+CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_and_f32, x && y)
+CUSTOM_BINARY_OP_OUT(uint8_t, uint8_t, logical_or_u8, x || y)
+CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_or_i64, x || y)
+CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_or_f32, x || y)
+CUSTOM_BINARY_OP_OUT(int64_t, uint8_t, logical_xor_i64, !x != !y)
+CUSTOM_BINARY_OP_OUT(float, uint8_t, logical_xor_f32, !x != !y)
diff --git a/candlex/native/candlex/src/kernels/custom_unary.cu b/candlex/native/candlex/src/kernels/custom_unary.cu
new file mode 100644
index 0000000000..40bc3aea52
--- /dev/null
+++ b/candlex/native/candlex/src/kernels/custom_unary.cu
@@ -0,0 +1,112 @@
+#define _USE_MATH_DEFINES
+#include <math.h>
+#include <cmath>
+#include <stdint.h>
+#include "strides.cuh"
+
+#define DEVICE_FN_FLOAT_WRAPPER(FN_NAME) \
+  __device__ __forceinline__ float FN_NAME##g(float a) { return FN_NAME##f(a); }
+
+#define DEVICE_FN_DOUBLE_WRAPPER(FN_NAME) \
+  __device__ __forceinline__ double FN_NAME##g(double a) { return FN_NAME(a); }
+
+DEVICE_FN_FLOAT_WRAPPER(acos)
+DEVICE_FN_DOUBLE_WRAPPER(acos)
+DEVICE_FN_FLOAT_WRAPPER(acosh)
+DEVICE_FN_DOUBLE_WRAPPER(acosh)
+DEVICE_FN_FLOAT_WRAPPER(asin)
+DEVICE_FN_DOUBLE_WRAPPER(asin)
+DEVICE_FN_FLOAT_WRAPPER(asinh)
+DEVICE_FN_DOUBLE_WRAPPER(asinh)
+DEVICE_FN_FLOAT_WRAPPER(atan)
+DEVICE_FN_DOUBLE_WRAPPER(atan)
+DEVICE_FN_FLOAT_WRAPPER(atanh)
+DEVICE_FN_DOUBLE_WRAPPER(atanh)
+DEVICE_FN_FLOAT_WRAPPER(cbrt)
+DEVICE_FN_DOUBLE_WRAPPER(cbrt)
+DEVICE_FN_FLOAT_WRAPPER(cosh)
+DEVICE_FN_DOUBLE_WRAPPER(cosh)
+DEVICE_FN_FLOAT_WRAPPER(erfc)
+DEVICE_FN_DOUBLE_WRAPPER(erfc)
+DEVICE_FN_FLOAT_WRAPPER(erfinv)
+DEVICE_FN_DOUBLE_WRAPPER(erfinv)
+DEVICE_FN_FLOAT_WRAPPER(exp)
+DEVICE_FN_DOUBLE_WRAPPER(exp)
+DEVICE_FN_FLOAT_WRAPPER(expm1)
+DEVICE_FN_DOUBLE_WRAPPER(expm1)
+DEVICE_FN_FLOAT_WRAPPER(log1p)
+DEVICE_FN_DOUBLE_WRAPPER(log1p)
+DEVICE_FN_FLOAT_WRAPPER(sinh)
+DEVICE_FN_DOUBLE_WRAPPER(sinh)
+DEVICE_FN_FLOAT_WRAPPER(tan)
+DEVICE_FN_DOUBLE_WRAPPER(tan)
+
+#define CUSTOM_UNARY_OP_OUT(TYPENAME, OUT_TYPENAME, FN_NAME, FUNC) \
+extern "C" __global__ void FN_NAME( \
+  const size_t numel, \
+  const size_t num_dims, \
+  const size_t *info, \
+  const TYPENAME *inp, \
+  OUT_TYPENAME *out \
+) { \
+  const size_t *dims = info; \
+  const size_t *strides = info + num_dims; \
+  if (is_contiguous(num_dims, dims, strides)) { \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+      TYPENAME x = inp ? inp[i] : out[i]; \
+      out[i] = FUNC; \
+    } \
+  } \
+  else { \
+    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
+      unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
+      TYPENAME x = inp ? inp[strided_i] : out[i]; \
+      out[i] = FUNC; \
+    } \
+  } \
+} \
+
+#define CUSTOM_UNARY_OP(TYPENAME, FN_NAME, FUNC) \
+  CUSTOM_UNARY_OP_OUT(TYPENAME, TYPENAME, FN_NAME, FUNC)
+
+CUSTOM_UNARY_OP(float, acos_f32, acosg(x))
+CUSTOM_UNARY_OP(double, acos_f64, acosg(x))
+CUSTOM_UNARY_OP(float, acosh_f32, acoshg(x))
+CUSTOM_UNARY_OP(double, acosh_f64, acoshg(x))
+CUSTOM_UNARY_OP(float, asin_f32, asing(x))
+CUSTOM_UNARY_OP(double, asin_f64, asing(x))
+CUSTOM_UNARY_OP(float, asinh_f32, asinhg(x))
+CUSTOM_UNARY_OP(double, asinh_f64, asinhg(x))
+CUSTOM_UNARY_OP(float, atan_f32, atang(x))
+CUSTOM_UNARY_OP(double, atan_f64, atang(x))
+CUSTOM_UNARY_OP(float, atanh_f32, atanhg(x))
+CUSTOM_UNARY_OP(double, atanh_f64, atanhg(x))
+CUSTOM_UNARY_OP(uint8_t, bit_not_u8, ~x)
+CUSTOM_UNARY_OP(uint32_t, bit_not_u32, ~x)
+CUSTOM_UNARY_OP(int64_t, bit_not_i64, ~x)
+CUSTOM_UNARY_OP(float, cbrt_f32, cbrtg(x))
+CUSTOM_UNARY_OP(double, cbrt_f64, cbrtg(x))
+CUSTOM_UNARY_OP(float, cosh_f32, coshg(x))
+CUSTOM_UNARY_OP(double, cosh_f64, coshg(x))
+CUSTOM_UNARY_OP(float, erfc_f32, erfcg(x))
+CUSTOM_UNARY_OP(double, erfc_f64, erfcg(x))
+CUSTOM_UNARY_OP(float, erf_inv_f32, erfinvg(x))
+CUSTOM_UNARY_OP(double, erf_inv_f64, erfinvg(x))
+CUSTOM_UNARY_OP(float, expm1_f32, expm1g(x))
+CUSTOM_UNARY_OP(double, expm1_f64, expm1g(x))
+CUSTOM_UNARY_OP(float, ln_1p_f32, log1pg(x))
+CUSTOM_UNARY_OP(double, ln_1p_f64, log1pg(x))
+CUSTOM_UNARY_OP(float, sigmoid_f32, 1.0 / (1.0 + expg(-x)))
+CUSTOM_UNARY_OP(double, sigmoid_f64, 1.0 / (1.0 + expg(-x)))
+CUSTOM_UNARY_OP(int64_t, sign_i64, x > 0 ? 1 : (x == 0 ? 0 : -1))
+CUSTOM_UNARY_OP(float, sign_f32, signbit(x))
+CUSTOM_UNARY_OP(double, sign_f64, signbit(x))
+CUSTOM_UNARY_OP(float, sinh_f32, sinhg(x))
+CUSTOM_UNARY_OP(double, sinh_f64, sinhg(x))
+CUSTOM_UNARY_OP(float, tan_f32, tang(x))
+CUSTOM_UNARY_OP(double, tan_f64, tang(x))
+
+CUSTOM_UNARY_OP_OUT(float, uint8_t, is_inf_f32, isinf(x) ? 1 : 0)
+CUSTOM_UNARY_OP_OUT(double, uint8_t, is_inf_f64, isinf(x) ? 1 : 0)
+CUSTOM_UNARY_OP_OUT(float, uint8_t, is_nan_f32, isnan(x) ? 1 : 0)
+CUSTOM_UNARY_OP_OUT(double, uint8_t, is_nan_f64, isnan(x) ? 1 : 0)
diff --git a/candlex/native/candlex/src/kernels/strides.cuh b/candlex/native/candlex/src/kernels/strides.cuh
new file mode 100644
index 0000000000..c95123de51
--- /dev/null
+++ b/candlex/native/candlex/src/kernels/strides.cuh
@@ -0,0 +1,34 @@
+// TODO: This is often used to check that the data is contiguous so that
+// kernels can be easily mapped. However this only returns true for row
+// major, if all the inputs are column major, we could apply the fast path
+// too (but we wouldn't if some of them are row major and some column major).
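+// Added illustration (not from the original source): for dims = {2, 3, 4},
+// the row-major strides {12, 4, 1} pass the check below, since each stride
+// equals the product of the dims to its right; the column-major strides
+// {1, 2, 6} fail it already at the innermost dimension (1 != 6).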
+__device__ bool is_contiguous(
+  const size_t num_dims,
+  const size_t *dims,
+  const size_t *strides
+) {
+  size_t acc = 1;
+  for (unsigned int d = 0; d < num_dims; d++) {
+    unsigned int dim_idx = num_dims - 1 - d;
+    if (acc != strides[dim_idx]) {
+      return false;
+    }
+    acc *= dims[dim_idx];
+  }
+  return true;
+}
+
+__device__ unsigned int get_strided_index(
+  unsigned int idx,
+  const size_t num_dims,
+  const size_t *dims,
+  const size_t *strides
+) {
+  unsigned int strided_i = 0;
+  for (unsigned int d = 0; d < num_dims; d++) {
+    unsigned int dim_idx = num_dims - 1 - d;
+    strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
+    idx /= dims[dim_idx];
+  }
+  return strided_i;
+}
diff --git a/candlex/native/candlex/src/lib.rs b/candlex/native/candlex/src/lib.rs
new file mode 100644
index 0000000000..ef72f434e8
--- /dev/null
+++ b/candlex/native/candlex/src/lib.rs
@@ -0,0 +1,119 @@
+mod atoms {
+    rustler::atoms! {
+        cpu,
+        cuda
+    }
+}
+
+mod devices;
+mod error;
+#[cfg(feature = "cuda")]
+mod kernels;
+mod ops;
+mod tensors;
+
+use rustler::{Env, Term};
+use tensors::TensorRef;
+
+fn load(env: Env, _info: Term) -> bool {
+    rustler::resource!(TensorRef, env);
+    true
+}
+
+rustler::init! {
+    "Elixir.Candlex.Native",
+    [
+        tensors::from_binary,
+        tensors::to_binary,
+        tensors::add,
+        tensors::atan2,
+        tensors::subtract,
+        tensors::multiply,
+        tensors::divide,
+        tensors::quotient,
+        tensors::remainder,
+        tensors::pow,
+        tensors::max,
+        tensors::min,
+        tensors::equal,
+        tensors::not_equal,
+        tensors::greater,
+        tensors::greater_equal,
+        tensors::less,
+        tensors::less_equal,
+        tensors::all,
+        tensors::all_within_dims,
+        tensors::any,
+        tensors::any_within_dims,
+        tensors::sum,
+        tensors::dtype,
+        tensors::t_shape,
+        tensors::argmax,
+        tensors::argmin,
+        tensors::reduce_max,
+        tensors::reduce_min,
+        tensors::negate,
+        tensors::where_cond,
+        tensors::narrow,
+        tensors::gather,
+        tensors::index_select,
+        tensors::index_add,
+        tensors::chunk,
+        tensors::squeeze,
+        tensors::clamp,
+        tensors::arange,
+        tensors::to_type,
+        tensors::broadcast_to,
+        tensors::reshape,
+        tensors::concatenate,
+        tensors::conv1d,
+        tensors::conv2d,
+        tensors::permute,
+        tensors::slice_scatter,
+        tensors::pad_with_zeros,
+        tensors::dot,
+        tensors::matmul,
+        tensors::abs,
+        tensors::acos,
+        tensors::acosh,
+        tensors::asin,
+        tensors::asinh,
+        tensors::atan,
+        tensors::atanh,
+        tensors::cbrt,
+        tensors::ceil,
+        tensors::cos,
+        tensors::cosh,
+        tensors::sigmoid,
+        tensors::sign,
+        tensors::sin,
+        tensors::sinh,
+        tensors::erf,
+        tensors::erfc,
+        tensors::erf_inv,
+        tensors::exp,
+        tensors::expm1,
+        tensors::floor,
+        tensors::is_infinity,
+        tensors::is_nan,
+        tensors::round,
+        tensors::log,
+        tensors::log1p,
+        tensors::rsqrt,
+        tensors::sqrt,
+        tensors::tan,
+        tensors::tanh,
+        tensors::bitwise_not,
+        tensors::bitwise_and,
+        tensors::bitwise_or,
+        tensors::bitwise_xor,
+        tensors::logical_and,
+        tensors::logical_or,
+        tensors::logical_xor,
+        tensors::left_shift,
+        tensors::right_shift,
+        tensors::to_device,
+        devices::is_cuda_available
+    ],
+    load = load
+}
diff --git a/candlex/native/candlex/src/ops.rs b/candlex/native/candlex/src/ops.rs
new file mode 100644
index 0000000000..98bc300f46
--- /dev/null
+++ b/candlex/native/candlex/src/ops.rs
@@ -0,0 +1,458 @@
+#[cfg(feature = "cuda")]
+use candle_core::CudaStorage;
+use candle_core::{CpuStorage, CustomOp1, CustomOp2, Error, Layout, Shape};
+use num_traits::cast::FromPrimitive;
+use num_traits::Float;
+
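+// Added note (assumption, not an upstream comment): statrs implements erfc and
+// erf_inv only for f64, so these helpers round-trip through f64 and convert
+// the result back to the tensor's element type.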
+fn erfc<T: Float + FromPrimitive>(v: T) -> T {
+    FromPrimitive::from_f64(statrs::function::erf::erfc(v.to_f64().unwrap())).unwrap()
+}
+
+fn erf_inv<T: Float + FromPrimitive>(v: T) -> T {
+    FromPrimitive::from_f64(statrs::function::erf::erf_inv(v.to_f64().unwrap())).unwrap()
+}
+
+macro_rules! custom_unary_op {
+    ($struct_name:ident, $name:expr, $cpu_closure:expr, ($($dtypes:ident),+)) => {
+        pub(crate) struct $struct_name;
+
+        impl CustomOp1 for $struct_name {
+            // Box<dyn> does not support const yet, so use a function to get the name.
+            fn name(&self) -> &'static str {
+                $name
+            }
+
+            /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+            /// offsets etc so the associated layout should be used to access it.
+            fn cpu_fwd(
+                &self,
+                storage: &CpuStorage,
+                layout: &Layout,
+            ) -> Result<(CpuStorage, Shape), candle_core::Error> {
+                use candle_core::backend::BackendStorage;
+
+                match storage {
+                    $(
+                        CpuStorage::$dtypes(vec) => {
+                            Ok(
+                                (
+                                    CpuStorage::$dtypes(candle_core::cpu_backend::unary_map(vec, layout, $cpu_closure)),
+                                    layout.shape().clone()
+                                )
+                            )
+                        }
+                    )*
+                    s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?
+                }
+            }
+
+            #[cfg(feature = "cuda")]
+            fn cuda_fwd(
+                &self,
+                storage: &CudaStorage,
+                layout: &Layout,
+            ) -> Result<(CudaStorage, Shape), candle_core::Error> {
+                use crate::kernels;
+                use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig};
+                use candle_core::cuda_backend::{kernel_name, Map1, WrapErr};
+                use candle_core::{CudaDevice, WithDType};
+
+                impl Map1 for $struct_name {
+                    fn f<T: DeviceRepr + WithDType>(
+                        &self,
+                        src: &CudaSlice<T>,
+                        device: &CudaDevice,
+                        layout: &Layout,
+                    ) -> Result<CudaSlice<T>, candle_core::Error> {
+                        let src = src.slice(layout.start_offset()..);
+                        let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_UNARY)?;
+                        let dims = layout.shape().dims();
+                        let elem_count = layout.shape().elem_count();
+                        let launch_config = LaunchConfig::for_num_elems(elem_count as u32);
+                        let dims_and_strides = device.htod_copy([dims, layout.stride()].concat()).w()?;
+                        // SAFETY: Set later by running the kernel.
+                        let dst = unsafe { device.alloc::<T>(elem_count) }.w()?;
+                        let params = (elem_count, dims.len(), &dims_and_strides, &src, &dst);
+                        // SAFETY: ffi.
+                        unsafe { func.launch(launch_config, params) }.w()?;
+
+                        Ok(dst)
+                    }
+                }
+
+                use candle_core::backend::BackendStorage;
+                let device = storage.device();
+                let slice = $struct_name.map(&storage.slice, device, layout)?;
+
+                Ok(
+                    (
+                        CudaStorage {
+                            slice,
+                            device: device.clone(),
+                        },
+                        layout.shape().clone()
+                    )
+                )
+            }
+        }
+    };
+}
+
+macro_rules! custom_unary_bool_op {
+    ($struct_name:ident, $name:expr, $fn_name:ident, ($($dtypes:ident),+)) => {
+        pub(crate) struct $struct_name;
+
+        impl CustomOp1 for $struct_name {
+            // Box<dyn> does not support const yet, so use a function to get the name.
+            fn name(&self) -> &'static str {
+                $name
+            }
+
+            /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+            /// offsets etc so the associated layout should be used to access it.
+            fn cpu_fwd(
+                &self,
+                storage: &CpuStorage,
+                layout: &Layout,
+            ) -> Result<(CpuStorage, Shape), candle_core::Error> {
+                use candle_core::backend::BackendStorage;
+
+                match storage {
+                    $(
+                        CpuStorage::$dtypes(vec) => {
+                            Ok(
+                                (
+                                    CpuStorage::U8(
+                                        candle_core::cpu_backend::unary_map(vec, layout, |v| u8::from(v.$fn_name()))
+                                    ),
+                                    layout.shape().clone()
+                                )
+                            )
+                        }
+                    )*
+                    s => Err(Error::UnsupportedDTypeForOp(s.dtype(), $name).bt())?
+                }
+            }
+
+            #[cfg(feature = "cuda")]
+            fn cuda_fwd(
+                &self,
+                storage: &CudaStorage,
+                layout: &Layout,
+            ) -> Result<(CudaStorage, Shape), candle_core::Error> {
+                use crate::kernels;
+                use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits};
+                use candle_core::cuda_backend::{kernel_name, CudaStorageSlice, Map1Any, WrapErr};
+                use candle_core::{CudaDevice, WithDType};
+
+                impl Map1Any for $struct_name {
+                    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits, W: Fn(CudaSlice<T>) -> CudaStorageSlice>(
+                        &self,
+                        src: &CudaSlice<T>,
+                        device: &CudaDevice,
+                        layout: &Layout,
+                        _wrap: W,
+                    ) -> Result<CudaStorageSlice, candle_core::Error> {
+                        let src = src.slice(layout.start_offset()..);
+                        let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_UNARY)?;
+                        let dims = layout.shape().dims();
+                        let elem_count = layout.shape().elem_count();
+                        let launch_config = LaunchConfig::for_num_elems(elem_count as u32);
+                        let dims_and_strides = device.htod_copy([dims, layout.stride()].concat()).w()?;
+                        // SAFETY: Set later by running the kernel.
+                        let dst = unsafe { device.alloc::<u8>(elem_count) }.w()?;
+                        let params = (elem_count, dims.len(), &dims_and_strides, &src, &dst);
+                        // SAFETY: ffi.
+                        unsafe { func.launch(launch_config, params) }.w()?;
+
+                        Ok(CudaStorageSlice::U8(dst))
+                    }
+                }
+
+                use candle_core::backend::BackendStorage;
+                let device = storage.device();
+                let slice = $struct_name.map(&storage.slice, device, layout)?;
+
+                Ok(
+                    (
+                        CudaStorage {
+                            slice,
+                            device: device.clone(),
+                        },
+                        layout.shape().clone()
+                    )
+                )
+            }
+        }
+    };
+}
+
+macro_rules! custom_binary_op {
+    ($struct_name:ident, $name:literal, $cpu_closure:expr, ($($dtypes:ident),+)) => {
+        pub(crate) struct $struct_name;
+
+        impl CustomOp2 for $struct_name {
+            fn name(&self) -> &'static str {
+                $name
+            }
+
+            /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+            /// offsets etc so the associated layout should be used to access it.
+            fn cpu_fwd(
+                &self,
+                s1: &CpuStorage,
+                l1: &Layout,
+                s2: &CpuStorage,
+                l2: &Layout,
+            ) -> Result<(CpuStorage, Shape), candle_core::Error> {
+                use candle_core::backend::BackendStorage;
+
+                match (s1, s2) {
+                    $(
+                        (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => {
+                            Ok(
+                                (
+                                    CpuStorage::$dtypes(
+                                        candle_core::cpu_backend::binary_map(l1, l2, lhs, rhs, $cpu_closure)
+                                    ),
+                                    l1.shape().clone()
+                                )
+                            )
+                        }
+                    )*
+                    _ => {
+                        Err(Error::DTypeMismatchBinaryOp {
+                            lhs: s1.dtype(),
+                            rhs: s2.dtype(),
+                            op: self.name(),
+                        }
+                        .bt())
+                    }
+                }
+            }
+
+            #[cfg(feature = "cuda")]
+            fn cuda_fwd(
+                &self,
+                s1: &CudaStorage,
+                l1: &Layout,
+                s2: &CudaStorage,
+                l2: &Layout,
+            ) -> Result<(CudaStorage, Shape), candle_core::Error> {
+                use crate::kernels;
+                use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits};
+                use candle_core::cuda_backend::{kernel_name, Map2, WrapErr};
+                use candle_core::{CudaDevice, WithDType};
+
+                impl Map2 for $struct_name {
+                    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+                        &self,
+                        src1: &CudaSlice<T>,
+                        layout1: &Layout,
+                        src2: &CudaSlice<T>,
+                        layout2: &Layout,
+                        device: &CudaDevice,
+                    ) -> Result<CudaSlice<T>, candle_core::Error> {
+                        let shape1 = layout1.shape();
+                        let dims1 = shape1.dims();
+                        let elem_count1 = shape1.elem_count();
+                        let launch_config = LaunchConfig::for_num_elems(elem_count1 as u32);
+                        let dims_and_strides = device
+                            .htod_copy([dims1, layout1.stride(), layout2.stride()].concat())
+                            .w()?;
+                        let src1 = src1.slice(layout1.start_offset()..);
+                        let src2 = src2.slice(layout2.start_offset()..);
+                        let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_BINARY)?;
+                        // SAFETY: Set later by running the kernel.
+                        let out = unsafe { device.alloc::<T>(elem_count1) }.w()?;
+                        let params = (elem_count1, dims1.len(), &dims_and_strides, &src1, &src2, &out);
+                        // SAFETY: ffi
+                        unsafe { func.launch(launch_config, params) }.w()?;
+
+                        Ok(out)
+                    }
+                }
+
+                use candle_core::backend::BackendStorage;
+                let device = s1.device();
+                let slice = $struct_name.map(&s1.slice, l1, &s2.slice, l2, device)?;
+
+                Ok(
+                    (
+                        CudaStorage {
+                            slice,
+                            device: device.clone(),
+                        },
+                        l1.shape().clone()
+                    )
+                )
+            }
+        }
+    }
+}
+
+macro_rules! custom_binary_bool_op {
+    ($struct_name:ident, $name:literal, $cpu_closure:expr, ($($dtypes:ident),+)) => {
+        pub(crate) struct $struct_name;
+
+        impl CustomOp2 for $struct_name {
+            fn name(&self) -> &'static str {
+                $name
+            }
+
+            /// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
+            /// offsets etc so the associated layout should be used to access it.
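+            // Added note (assumption about intent, not an upstream comment): the
+            // bool variants always emit U8 storage, matching how Nx represents
+            // boolean/predicate tensors as u8 regardless of the input dtype.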
+            fn cpu_fwd(
+                &self,
+                s1: &CpuStorage,
+                l1: &Layout,
+                s2: &CpuStorage,
+                l2: &Layout,
+            ) -> Result<(CpuStorage, Shape), candle_core::Error> {
+                use candle_core::backend::BackendStorage;
+
+                match (s1, s2) {
+                    $(
+                        (CpuStorage::$dtypes(lhs), CpuStorage::$dtypes(rhs)) => {
+                            Ok(
+                                (
+                                    CpuStorage::U8(
+                                        candle_core::cpu_backend::binary_map(
+                                            l1,
+                                            l2,
+                                            lhs,
+                                            rhs,
+                                            |v1, v2| u8::from($cpu_closure(v1, v2))
+                                        )
+                                    ),
+                                    l1.shape().clone()
+                                )
+                            )
+                        }
+                    )*
+                    _ => {
+                        Err(Error::DTypeMismatchBinaryOp {
+                            lhs: s1.dtype(),
+                            rhs: s2.dtype(),
+                            op: self.name(),
+                        }
+                        .bt())
+                    }
+                }
+            }
+
+            #[cfg(feature = "cuda")]
+            fn cuda_fwd(
+                &self,
+                s1: &CudaStorage,
+                l1: &Layout,
+                s2: &CudaStorage,
+                l2: &Layout,
+            ) -> Result<(CudaStorage, Shape), candle_core::Error> {
+                use crate::kernels;
+                use candle_core::cuda_backend::cudarc::driver::{CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig, ValidAsZeroBits};
+                use candle_core::cuda_backend::{kernel_name, CudaStorageSlice, Map2Any, WrapErr};
+                use candle_core::{CudaDevice, WithDType};
+
+                impl Map2Any for $struct_name {
+                    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+                        &self,
+                        src1: &CudaSlice<T>,
+                        layout1: &Layout,
+                        src2: &CudaSlice<T>,
+                        layout2: &Layout,
+                        device: &CudaDevice,
+                    ) -> Result<CudaStorageSlice, candle_core::Error> {
+                        let shape1 = layout1.shape();
+                        let dims1 = shape1.dims();
+                        let elem_count1 = shape1.elem_count();
+                        let launch_config = LaunchConfig::for_num_elems(elem_count1 as u32);
+                        let dims_and_strides = device
+                            .htod_copy([dims1, layout1.stride(), layout2.stride()].concat())
+                            .w()?;
+                        let src1 = src1.slice(layout1.start_offset()..);
+                        let src2 = src2.slice(layout2.start_offset()..);
+                        let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_BINARY)?;
+                        // SAFETY: Set later by running the kernel.
+                        let out = unsafe { device.alloc::<u8>(elem_count1) }.w()?;
+                        let params = (elem_count1, dims1.len(), &dims_and_strides, &src1, &src2, &out);
+                        // SAFETY: ffi
+                        unsafe { func.launch(launch_config, params) }.w()?;
+
+                        Ok(CudaStorageSlice::U8(out))
+                    }
+                }
+
+                use candle_core::backend::BackendStorage;
+                let device = s1.device();
+                let slice = $struct_name.map(&s1.slice, l1, &s2.slice, l2, device)?;
+
+                Ok(
+                    (
+                        CudaStorage {
+                            slice,
+                            device: device.clone(),
+                        },
+                        l1.shape().clone()
+                    )
+                )
+            }
+        }
+    }
+}
+
+custom_unary_op!(Acos, "acos", |v| v.acos(), (BF16, F16, F32, F64));
+custom_unary_op!(Acosh, "acosh", |v| v.acosh(), (BF16, F16, F32, F64));
+custom_unary_op!(Asin, "asin", |v| v.asin(), (BF16, F16, F32, F64));
+custom_unary_op!(Asinh, "asinh", |v| v.asinh(), (BF16, F16, F32, F64));
+custom_unary_op!(Atan, "atan", |v| v.atan(), (BF16, F16, F32, F64));
+custom_unary_op!(Atanh, "atanh", |v| v.atanh(), (BF16, F16, F32, F64));
+custom_unary_op!(BitNot, "bit_not", |v| !v, (U8, U32, I64));
+custom_unary_op!(Cbrt, "cbrt", |v| v.cbrt(), (BF16, F16, F32, F64));
+custom_unary_op!(Cosh, "cosh", |v| v.cosh(), (BF16, F16, F32, F64));
+custom_unary_op!(Erfc, "erfc", erfc, (BF16, F16, F32, F64));
+custom_unary_op!(ErfInv, "erf_inv", erf_inv, (BF16, F16, F32, F64));
+custom_unary_op!(Expm1, "expm1", |v| v.exp_m1(), (BF16, F16, F32, F64));
+custom_unary_op!(Log1p, "ln_1p", |v| v.ln_1p(), (BF16, F16, F32, F64));
+custom_unary_op!(Sigmoid, "sigmoid", |v| 1. / (1. + (-v).exp()), (F32, F64));
+custom_unary_op!(Sign, "sign", |v| v.signum(), (I64, BF16, F16, F32, F64));
+custom_unary_op!(Sinh, "sinh", |v| v.sinh(), (BF16, F16, F32, F64));
+custom_unary_op!(Tan, "tan", |v| v.tan(), (BF16, F16, F32, F64));
+custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64));
+custom_unary_bool_op!(IsNan, "is_nan", is_nan, (F32, F64));
+
+custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64));
+custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64));
+custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64));
+custom_binary_op!(Atan2, "atan2", |v1, v2| v1.atan2(v2), (F32, F64));
+custom_binary_op!(Pow, "pow", |v1, v2| v1.powf(v2), (F32, F64));
+custom_binary_op!(
+    Remainder,
+    "remainder",
+    |v1, v2| v1 % v2,
+    (U8, I64, F32, F64)
+);
+custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2, (U32, I64));
+custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2, (U32, I64));
+custom_binary_bool_op!(
+    LogicalAnd,
+    "logical_and",
+    |v1, v2| if v1 as i8 != 0 && v2 as i8 != 0 { 1 } else { 0 },
+    (U8, U32, I64, F32, F64)
+);
+custom_binary_bool_op!(
+    LogicalOr,
+    "logical_or",
+    |v1, v2| if v1 as i8 == 0 && v2 as i8 == 0 { 0 } else { 1 },
+    (U8, U32, I64, F32, F64)
+);
+custom_binary_bool_op!(
+    LogicalXor,
+    "logical_xor",
+    |v1, v2| if (v1 as i8 != 0) == (v2 as i8 != 0) {
+        0
+    } else {
+        1
+    },
+    (U8, U32, I64, F32, F64)
+);
diff --git a/candlex/native/candlex/src/tensors.rs b/candlex/native/candlex/src/tensors.rs
new file mode 100644
index 0000000000..78c94b413f
--- /dev/null
+++ b/candlex/native/candlex/src/tensors.rs
@@ -0,0 +1,583 @@
+use crate::atoms;
+use crate::error::CandlexError;
+use crate::ops::{
+    Acos, Acosh, Asin, Asinh, Atan, Atan2, Atanh, BitAnd, BitNot, BitOr, BitXor, Cbrt, Cosh,
+    ErfInv, Erfc, Expm1, IsInf, IsNan, Log1p, LogicalAnd, LogicalOr, LogicalXor, Pow, Remainder,
+    Shl, Shr, Sigmoid, Sign, Sinh, Tan,
+};
+use candle_core::{DType, Device, Tensor};
+use half::{bf16, f16};
+use rustler::{Atom, Binary, Encoder, Env, NewBinary, NifStruct, ResourceArc, Term};
+use std::ops::Deref;
+use std::result::Result;
+use std::str::FromStr;
+
+pub(crate) struct TensorRef(Tensor);
+
+#[derive(NifStruct)]
+#[module = "Candlex.Backend"]
+pub struct ExTensor {
+    device: Atom,
+    resource: ResourceArc<TensorRef>,
+}
+
+impl ExTensor {
+    pub fn new(tensor: Tensor) -> Self {
+        let dev_string = match tensor.device() {
+            Device::Cpu => atoms::cpu(),
+            Device::Cuda(_) => atoms::cuda(),
+        };
+
+        Self {
+            device: dev_string,
+            resource: ResourceArc::new(TensorRef(tensor)),
+        }
+    }
+}
+
+// Implement Deref so we can call `Tensor` functions directly from an `ExTensor` struct.
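+// Added illustration (not an upstream comment): thanks to this Deref, calls such
+// as `ex_tensor.flatten_all()` or `t.shape()` in the NIFs below resolve directly
+// to the underlying `candle_core::Tensor` methods.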
+impl Deref for ExTensor {
+    type Target = Tensor;
+
+    fn deref(&self) -> &Self::Target {
+        &self.resource.0
+    }
+}
+
+// NIFs are scheduled on dirty CPU schedulers: tensor work can easily exceed the
+// ~1ms budget of a regular BEAM scheduler, so keep it off the normal run queues.
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn from_binary(
+    binary: Binary,
+    dtype_str: &str,
+    shape: Term,
+    device: Atom,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(Tensor::from_raw_buffer(
+        binary.as_slice(),
+        // TODO: Handle DTypeParseError
+        DType::from_str(dtype_str).unwrap(),
+        // TODO: Handle rustler::Error
+        &tuple_to_vec(shape).unwrap(),
+        &device_from_atom(device)?,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn to_device(ex_tensor: ExTensor, device: Atom) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        ex_tensor.to_device(&device_from_atom(device)?)?,
+    ))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn to_binary(env: Env, ex_tensor: ExTensor) -> Result<Binary, CandlexError> {
+    let bytes = tensor_bytes(ex_tensor.flatten_all()?)?;
+    let mut binary = NewBinary::new(env, bytes.len());
+    binary.as_mut_slice().copy_from_slice(bytes.as_slice());
+
+    Ok(binary.into())
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn narrow(
+    t: ExTensor,
+    dim: usize,
+    start: usize,
+    length: usize,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.narrow(dim, start, length)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn gather(t: ExTensor, indexes: ExTensor, dim: usize) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.gather(indexes.deref(), dim)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn index_select(t: ExTensor, indexes: ExTensor, dim: usize) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.index_select(indexes.deref(), dim)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn index_add(
+    t: ExTensor,
+    indexes: ExTensor,
+    source: ExTensor,
+    dim: usize,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.index_add(
+        indexes.deref(),
+        source.deref(),
+        dim,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn chunk(t: ExTensor, num_chunks: usize) -> Result<Vec<ExTensor>, CandlexError> {
+    Ok(t.chunk(num_chunks, 0)?
+        .into_iter()
+        .map(ExTensor::new)
+        .collect())
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn squeeze(t: ExTensor, dim: usize) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.squeeze(dim)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn clamp(t: ExTensor, min_val: ExTensor, max_val: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.clamp(
+        &min_val.broadcast_as(t.shape())?,
+        &max_val.broadcast_as(t.shape())?,
+    )?))
+}
+
+// candle has no dedicated rsqrt, so compose it as 1 / sqrt(x).
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn rsqrt(t: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.sqrt()?.recip()?))
+}
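+// Illustrative mapping (an assumption about the Elixir-side caller, not verified
+// here): `Nx.iota({2, 3}, type: :f32)` would reach this NIF roughly as
+// `arange(0, 6, "f32", {2, 3}, :cpu)` and take its final shape from the reshape.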
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn arange(
+    start: i64,
+    end: i64,
+    dtype_str: &str,
+    shape: Term,
+    device: Atom,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        Tensor::arange(start, end, &device_from_atom(device)?)?
+            .to_dtype(DType::from_str(dtype_str).unwrap())?
+            .reshape(tuple_to_vec(shape).unwrap())?,
+    ))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn all(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(_all(
+        &ex_tensor.flatten_all()?,
+        vec![0],
+        false,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn all_within_dims(
+    ex_tensor: ExTensor,
+    dims: Vec<usize>,
+    keep_dims: bool,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(_all(ex_tensor.deref(), dims, keep_dims)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn any(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(_any(
+        &ex_tensor.flatten_all()?,
+        vec![0],
+        false,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn any_within_dims(
+    ex_tensor: ExTensor,
+    dims: Vec<usize>,
+    keep_dims: bool,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(_any(ex_tensor.deref(), dims, keep_dims)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn argmax(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dim {
+        ex_tensor.argmax_keepdim(dim)?
+    } else {
+        ex_tensor.argmax(dim)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn argmin(ex_tensor: ExTensor, dim: usize, keep_dim: bool) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dim {
+        ex_tensor.argmin_keepdim(dim)?
+    } else {
+        ex_tensor.argmin(dim)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn reduce_max(
+    ex_tensor: ExTensor,
+    dim: usize,
+    keep_dim: bool,
+) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dim {
+        ex_tensor.max_keepdim(dim)?
+    } else {
+        ex_tensor.max(dim)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn reduce_min(
+    ex_tensor: ExTensor,
+    dim: usize,
+    keep_dim: bool,
+) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dim {
+        ex_tensor.min_keepdim(dim)?
+    } else {
+        ex_tensor.min(dim)?
+    };
+
+    Ok(ExTensor::new(t))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn sum(
+    ex_tensor: ExTensor,
+    dims: Vec<usize>,
+    keep_dims: bool,
+) -> Result<ExTensor, CandlexError> {
+    let t = if keep_dims {
+        ex_tensor.sum_keepdim(dims)?
+    } else {
+        ex_tensor.sum(dims)?
+    };
+
+    Ok(ExTensor::new(t))
+}
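+// The reductions above mirror Nx's `keep_axes`/`keep_axis` options by choosing
+// between candle's `*_keepdim` variants and the rank-reducing ones; e.g.
+// (illustrative) `Nx.argmax(t, axis: 1, keep_axis: true)` ends up as
+// `argmax(t, 1, true)` and calls `argmax_keepdim(1)`.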
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn permute(ex_tensor: ExTensor, dims: Vec<usize>) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(ex_tensor.permute(dims)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn broadcast_to(t: ExTensor, shape: Term) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.broadcast_as(tuple_to_vec(shape).unwrap())?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn reshape(t: ExTensor, shape: Term) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.reshape(tuple_to_vec(shape).unwrap())?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn slice_scatter(
+    t: ExTensor,
+    src: ExTensor,
+    dim: usize,
+    start: usize,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.slice_scatter(src.deref(), dim, start)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn pad_with_zeros(t: ExTensor, left: usize, right: usize) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.pad_with_zeros(0, left, right)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn where_cond(
+    t: ExTensor,
+    on_true: ExTensor,
+    on_false: ExTensor,
+) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(t.where_cond(&on_true, &on_false)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn to_type(t: ExTensor, dtype_str: &str) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        t.to_dtype(DType::from_str(dtype_str).unwrap())?,
+    ))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn dtype(t: ExTensor) -> Result<&'static str, CandlexError> {
+    Ok(t.dtype().as_str())
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn t_shape(env: Env, t: ExTensor) -> Result<Term, CandlexError> {
+    Ok(vec_to_tuple(env, t.shape().clone().into_dims()).unwrap())
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn concatenate(ex_tensors: Vec<ExTensor>, dim: usize) -> Result<ExTensor, CandlexError> {
+    let tensors = ex_tensors
+        .iter()
+        .map(|t| t.deref())
+        .collect::<Vec<&Tensor>>();
+    Ok(ExTensor::new(Tensor::cat(&tensors[..], dim)?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn conv1d(tensor: ExTensor, kernel: ExTensor) -> Result<ExTensor, CandlexError> {
+    let padding = 0;
+    let stride = 1;
+    let dilation = 1;
+    let groups = 1;
+
+    Ok(ExTensor::new(tensor.conv1d(
+        kernel.deref(),
+        padding,
+        stride,
+        dilation,
+        groups,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn conv2d(tensor: ExTensor, kernel: ExTensor) -> Result<ExTensor, CandlexError> {
+    let padding = 0;
+    let stride = 1;
+    let dilation = 1;
+    let groups = 1;
+
+    Ok(ExTensor::new(tensor.conv2d(
+        kernel.deref(),
+        padding,
+        stride,
+        dilation,
+        groups,
+    )?))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn divide(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        // Need to force float in case we receive integers, given
+        // candle rounds down integer division.
+        left.to_dtype(DType::F32)?
+            .broadcast_div(&right.to_dtype(DType::F32)?)?,
+    ))
+}
+
+#[rustler::nif(schedule = "DirtyCpu")]
+pub fn dot(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+    Ok(ExTensor::new(
+        left.mul(&right.broadcast_as(left.shape())?)?
+            .sum(left.rank() - 1)?,
+    ))
+}
+
+macro_rules! unary_nif {
+    ($nif_name:ident, $native_fn_name:ident) => {
+        #[rustler::nif(schedule = "DirtyCpu")]
+        pub fn $nif_name(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
+            Ok(ExTensor::new(ex_tensor.$native_fn_name()?))
+        }
+    };
+    ($nif_name:ident) => {
+        unary_nif!($nif_name, $nif_name);
+    };
+}
+
+macro_rules! binary_nif {
+    ($nif_name:ident, $native_fn_name:ident) => {
+        #[rustler::nif(schedule = "DirtyCpu")]
+        pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+            Ok(ExTensor::new(left.$native_fn_name(right.deref())?))
+        }
+    };
+}
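+// A sketch of what these macros emit (derived from the bodies above):
+// `binary_nif!(add, broadcast_add)` expands to roughly
+//
+//     #[rustler::nif(schedule = "DirtyCpu")]
+//     pub fn add(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+//         Ok(ExTensor::new(left.broadcast_add(right.deref())?))
+//     }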
+macro_rules! custom_unary_nif {
+    ($nif_name:ident, $custom_op_name:ident) => {
+        #[rustler::nif(schedule = "DirtyCpu")]
+        pub fn $nif_name(ex_tensor: ExTensor) -> Result<ExTensor, CandlexError> {
+            Ok(ExTensor::new(ex_tensor.apply_op1_no_bwd(&$custom_op_name)?))
+        }
+    };
+}
+
+macro_rules! custom_binary_nif {
+    ($nif_name:ident, $custom_op_name:ident) => {
+        #[rustler::nif(schedule = "DirtyCpu")]
+        pub fn $nif_name(left: ExTensor, right: ExTensor) -> Result<ExTensor, CandlexError> {
+            Ok(ExTensor::new(
+                left.apply_op2_no_bwd(right.deref(), &$custom_op_name)?,
+            ))
+        }
+    };
+}
+
+unary_nif!(negate, neg);
+unary_nif!(abs);
+unary_nif!(ceil);
+unary_nif!(cos);
+unary_nif!(erf);
+unary_nif!(exp);
+unary_nif!(floor);
+unary_nif!(round);
+unary_nif!(sin);
+unary_nif!(log);
+unary_nif!(sqrt);
+unary_nif!(tanh);
+
+custom_unary_nif!(acos, Acos);
+custom_unary_nif!(acosh, Acosh);
+custom_unary_nif!(asin, Asin);
+custom_unary_nif!(asinh, Asinh);
+custom_unary_nif!(atan, Atan);
+custom_unary_nif!(atanh, Atanh);
+custom_unary_nif!(bitwise_not, BitNot);
+custom_unary_nif!(cbrt, Cbrt);
+custom_unary_nif!(cosh, Cosh);
+custom_unary_nif!(erfc, Erfc);
+custom_unary_nif!(erf_inv, ErfInv);
+custom_unary_nif!(expm1, Expm1);
+custom_unary_nif!(is_infinity, IsInf);
+custom_unary_nif!(is_nan, IsNan);
+custom_unary_nif!(log1p, Log1p);
+custom_unary_nif!(sigmoid, Sigmoid);
+custom_unary_nif!(sign, Sign);
+custom_unary_nif!(sinh, Sinh);
+custom_unary_nif!(tan, Tan);
+
+binary_nif!(add, broadcast_add);
+binary_nif!(subtract, broadcast_sub);
+binary_nif!(multiply, broadcast_mul);
+binary_nif!(quotient, broadcast_div);
+binary_nif!(max, broadcast_maximum);
+binary_nif!(min, broadcast_minimum);
+binary_nif!(equal, eq);
+binary_nif!(not_equal, ne);
+binary_nif!(greater, gt);
+binary_nif!(greater_equal, ge);
+binary_nif!(less, lt);
+binary_nif!(less_equal, le);
+binary_nif!(matmul, broadcast_matmul);
+
+custom_binary_nif!(atan2, Atan2);
+custom_binary_nif!(bitwise_and, BitAnd);
+custom_binary_nif!(bitwise_or, BitOr);
+custom_binary_nif!(bitwise_xor, BitXor);
+custom_binary_nif!(left_shift, Shl);
+custom_binary_nif!(logical_and, LogicalAnd);
+custom_binary_nif!(logical_or, LogicalOr);
+custom_binary_nif!(logical_xor, LogicalXor);
+custom_binary_nif!(pow, Pow);
+custom_binary_nif!(right_shift, Shr);
+custom_binary_nif!(remainder, Remainder);
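+// `any`/`all` are reductions over a nonzero mask: `ne(zeros)` yields a u8 tensor of
+// 0s and 1s, so `any` is a max over the requested dims and `all` is a min.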
+fn _any(tensor: &Tensor, dims: Vec<usize>, keep_dims: bool) -> Result<Tensor, CandlexError> {
+    let comparison = tensor.ne(&tensor.zeros_like()?)?;
+
+    let result = if keep_dims {
+        dims.iter()
+            .rev()
+            .fold(comparison, |t, dim| t.max_keepdim(*dim).unwrap())
+    } else {
+        dims.iter()
+            .rev()
+            .fold(comparison, |t, dim| t.max(*dim).unwrap())
+    };
+
+    Ok(result)
+}
+
+fn _all(tensor: &Tensor, dims: Vec<usize>, keep_dims: bool) -> Result<Tensor, CandlexError> {
+    let comparison = tensor.ne(&tensor.zeros_like()?)?;
+
+    let result = if keep_dims {
+        dims.iter()
+            .rev()
+            .fold(comparison, |t, dim| t.min_keepdim(*dim).unwrap())
+    } else {
+        dims.iter()
+            .rev()
+            .fold(comparison, |t, dim| t.min(*dim).unwrap())
+    };
+
+    Ok(result)
+}
+
+fn tuple_to_vec(term: Term) -> Result<Vec<usize>, rustler::Error> {
+    rustler::types::tuple::get_tuple(term)?
+        .iter()
+        .map(|elem| elem.decode())
+        .collect::<Result<Vec<usize>, rustler::Error>>()
+}
+
+fn vec_to_tuple(env: Env, vec: Vec<usize>) -> Result<Term, rustler::Error> {
+    Ok(rustler::types::tuple::make_tuple(
+        env,
+        &vec.into_iter()
+            .map(|elem| elem.encode(env))
+            .collect::<Vec<Term>>(),
+    ))
+}
+
+// Reuse a single CUDA device handle across NIF calls so all tensors share one
+// context instead of re-initializing the driver on every call.
+static CUDA_DEVICE: std::sync::Mutex<Option<Device>> = std::sync::Mutex::new(None);
+
+fn device_from_atom(atom: Atom) -> Result<Device, CandlexError> {
+    if atom == atoms::cpu() {
+        Ok(Device::Cpu)
+    } else if atom == atoms::cuda() {
+        let mut cuda_device = CUDA_DEVICE.lock().unwrap();
+
+        if let Some(device) = cuda_device.as_ref() {
+            Ok(device.clone())
+        } else {
+            let new_cuda_device = Device::new_cuda(0)?;
+            *cuda_device = Some(new_cuda_device.clone());
+
+            Ok(new_cuda_device)
+        }
+    } else {
+        Err(CandlexError::Other(format!(
+            "unsupported device {:?}",
+            atom
+        )))
+    }
+}
+
+fn tensor_bytes(tensor: Tensor) -> Result<Vec<u8>, CandlexError> {
+    Ok(match tensor.dtype() {
+        DType::I64 => tensor
+            .to_vec1::<i64>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::U8 => tensor
+            .to_vec1::<u8>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::U32 => tensor
+            .to_vec1::<u32>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::F16 => tensor
+            .to_vec1::<f16>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::F32 => tensor
+            .to_vec1::<f32>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::F64 => tensor
+            .to_vec1::<f64>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+        DType::BF16 => tensor
+            .to_vec1::<bf16>()?
+            .iter()
+            .flat_map(|val| val.to_ne_bytes())
+            .collect(),
+    })
+}
diff --git a/candlex/test/candlex_test.exs b/candlex/test/candlex_test.exs
new file mode 100644
index 0000000000..58991433cd
--- /dev/null
+++ b/candlex/test/candlex_test.exs
@@ -0,0 +1,2303 @@
+defmodule CandlexTest do
+  use Nx.Case, async: true
+  doctest Candlex
+
+  describe "creation" do
+    test "tensor" do
+      check(255, type: :u8)
+      check(100_002, type: :u32)
+      check(100_102, type: :u64)
+      check(-101, type: :s64)
+      check(1.16, type: :f16)
+      check(1.32, type: :f32)
+      check([1, 2, 3], type: :f32)
+      check(-0.002, type: :f64)
+      check([1, 2], type: :u32)
+      check([[1, 2], [3, 4]], type: :u32)
+      check([[1, 2, 3, 4], [5, 6, 7, 8]], type: :u32)
+      check([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], type: :u32)
+      check([0, 255], type: :u8)
+      check([-0.5, 0.88], type: :f32)
+      check([-0.5, 0.88], type: :f64)
+      check(2.16, type: :bf16)
+    end
+
+    test "named dimensions" do
+      check([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
+
+      t([[1, 2, 3], [4, 5, 6]], names: [:x, :y])
+      |> assert_equal(t([[1, 2, 3], [4, 5, 6]]))
+    end
+
+    test "tensor tensor" do
+      t(t([1, 2, 3]))
+      |> assert_equal(t([1, 2, 3]))
+    end
+
+    test "tril" do
+      t([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+      |> Nx.tril()
+      |> assert_equal(t([[1, 0, 0], [4, 5, 0], [7, 8, 9]]))
+    end
+
+    test "triu" do
+      t([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+      |> Nx.triu()
+      |> assert_equal(t([[1, 2, 3], [0, 5, 6], [0, 0, 9]]))
+    end
+
+    test "addition" do
+      t([1, 2, 3])
+      |> Nx.add(t([10, 20, 30]))
+      |> assert_equal(t([11, 22, 33]))
+
+      t([1, 2, 3], type: :u64)
+      |> Nx.add(t([10, 20, 30], type: :u64))
+      |> assert_equal(t([11, 22, 33]))
+
+      Nx.add(1, 2.2)
+      |> assert_equal(t(3.2))
+
+      t([1, 2, 3])
+      |> Nx.add(1.0)
+      |> assert_equal(t([2.0, 3.0, 4.0]))
+    end
+
+    test "iota" do
+      Nx.iota({})
+      |> assert_equal(t(0))
+
+      Nx.iota({}, type: :f32)
+      |> assert_equal(t(0.0))
+
+      Nx.iota({5})
+      |> assert_equal(t([0, 1, 2, 3, 4]))
+
+      Nx.iota({5}, type: :u64)
+      |> assert_equal(t([0, 1, 2, 3, 4]))
+
+      Nx.iota({5}, type: :f32)
+      |>
assert_equal(t([0.0, 1.0, 2.0, 3.0, 4.0])) + + Nx.iota({2, 3}) + |> assert_equal(t([[0, 1, 2], [3, 4, 5]])) + + Nx.iota({3, 3}, axis: 1) + |> assert_equal( + t([ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2] + ]) + ) + + Nx.iota({3, 3}, axis: -1) + |> assert_equal( + t([ + [0, 1, 2], + [0, 1, 2], + [0, 1, 2] + ]) + ) + + Nx.iota({3, 4, 3}, axis: 0, type: :f64) + |> assert_equal( + t([ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0] + ], + [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0] + ], + [ + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0], + [2.0, 2.0, 2.0] + ] + ]) + ) + + Nx.iota({1, 3, 2}, axis: 2) + |> assert_equal( + t([ + [ + [0, 1], + [0, 1], + [0, 1] + ] + ]) + ) + end + + test "max" do + Nx.max(1, 2) + |> assert_equal(t(2)) + + Nx.max(1, t([1.0, 2.0, 3.0], names: [:data])) + |> assert_equal(t([1.0, 2.0, 3.0])) + + t([[1], [2]], type: :f32, names: [:x, nil]) + |> Nx.max(t([[10, 20]], type: :f32, names: [nil, :y])) + |> assert_equal(t([[10.0, 20.0], [10.0, 20.0]])) + end + + test "min" do + Nx.min(1, 2) + |> assert_equal(t(1)) + + Nx.min(1, t([1.0, 2.0, 3.0], names: [:data])) + |> assert_equal(t([1.0, 1.0, 1.0])) + + t([[1], [2]], type: :f32, names: [:x, nil]) + |> Nx.min(t([[10, 20]], type: :f32, names: [nil, :y])) + |> assert_equal(t([[1.0, 1.0], [2.0, 2.0]])) + end + + test "multiply" do + t([1, 2]) + |> Nx.multiply(t([3, 4])) + |> assert_equal(t([3, 8])) + + t([1, 2], type: :u64) + |> Nx.multiply(t([3, 4], type: :u64)) + |> assert_equal(t([3, 8])) + + t([[1], [2]]) + |> Nx.multiply(t([3, 4])) + |> assert_equal(t([[3, 4], [6, 8]])) + + t([1, 2]) + |> Nx.multiply(t([[3], [4]])) + |> assert_equal(t([[3, 6], [4, 8]])) + end + + test "divide/2" do + 1.0 + |> Nx.divide(2) + |> assert_equal(t(0.5)) + + t([1.0, 2, 3]) + |> Nx.divide(1) + |> assert_equal(t([1.0, 2.0, 3.0])) + + t([[1.0], [2]]) + |> Nx.divide(t([[10, 20]])) + |> assert_equal( + t([ + [0.10000000149011612, 0.05000000074505806], + [0.20000000298023224, 0.10000000149011612] + ]) + ) + + 1 + |> Nx.divide(2) + |> assert_equal(t(0.5)) + + t([1, 2, 3]) + |> Nx.divide(2) + |> assert_equal(t([0.5, 1.0, 1.5])) + + t([[1], [2]]) + |> Nx.divide(t([[10, 20]])) + |> assert_equal( + t([ + [0.10000000149011612, 0.05000000074505806], + [0.20000000298023224, 0.10000000149011612] + ]) + ) + end + + test "remainder" do + Nx.remainder(1, 2) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.remainder(2) + |> assert_equal(t([1, 0, 1])) + + 2 + |> Nx.remainder(t([1.0, 2.0, 3.0])) + |> assert_equal(t([0.0, 0.0, 2.0])) + + t([[10], [20]], names: [:x, :y]) + |> Nx.remainder(t([[3, 4]], names: [nil, :y])) + |> assert_equal( + t([ + [1, 2], + [2, 0] + ]) + ) + + left = t(-11) + right = t(10, type: :u8) + + Nx.remainder(left, right) + |> assert_equal(t(-1)) + + left + |> Nx.add(t(20)) + |> Nx.remainder(right) + |> assert_equal(t(9)) + + positive_left = t(9, type: :u8) + + Nx.remainder(positive_left, right) + |> assert_equal(t(9)) + + positive_left + |> Nx.add(Nx.tensor(20, type: :u8)) + |> Nx.remainder(right) + |> assert_equal(t(9)) + end + + test "quotient" do + Nx.quotient(11, 2) + |> assert_equal(t(5)) + + t([2, 4, 5]) + |> Nx.quotient(2) + |> assert_equal(t([1, 2, 2])) + + 10 + |> Nx.quotient(t([1, 2, 3])) + |> assert_equal(t([10, 5, 3])) + + t([[10, 20]], names: [nil, :y]) + |> Nx.quotient(t([[1], [2]], names: [:x, nil])) + |> assert_equal( + t([ + [10, 20], + [5, 10] + ]) + ) + + t([[10, 20]]) + |> Nx.quotient(t([[1], [2]])) + |> assert_equal( + t([ + [10, 20], + [5, 10] + ]) + ) + + 
t([[10, 20]], type: :u8) + |> Nx.quotient(t([[1], [2]], type: :u32)) + |> assert_equal( + t([ + [10, 20], + [5, 10] + ]) + ) + end + + test "sign" do + t([-2, -1, 0, 1, 2]) + |> Nx.sign() + |> assert_equal(t([-1, -1, 0, 1, 1])) + end + + test "atan2" do + Nx.atan2(1.0, 2.0) + |> assert_close(t(0.46364760398864746)) + + t([1.0, 2, 3]) + |> Nx.atan2(1) + |> assert_close(t([0.7853981852531433, 1.1071487665176392, 1.249045729637146])) + + 1.0 + |> Nx.atan2(t([1.0, 2.0, 3.0])) + |> assert_close(t([0.7853981852531433, 0.46364760398864746, 0.32175055146217346])) + + t([[-0.0], [0.0]], type: :f64) + |> Nx.atan2(t([-0.0, 0.0], type: :f64)) + |> assert_close( + t([ + [-3.141592653589793, -0.0], + [3.141592653589793, 0.0] + ]) + ) + end + + test "broadcast" do + Nx.broadcast(1, {1, 2, 3}) + |> assert_equal(t([[[1, 1, 1], [1, 1, 1]]])) + + t([1, 2, 3]) + |> Nx.broadcast({3, 2}, axes: [0]) + |> assert_equal(t([[1, 1], [2, 2], [3, 3]])) + end + + test "access" do + tensor = t([[1, 2], [3, 4]]) + + assert_equal(tensor[0], t([1, 2])) + assert_equal(tensor[1], t([3, 4])) + end + + test "concatenate" do + [t([1, 2, 3])] + |> Nx.concatenate() + |> assert_equal(t([1, 2, 3])) + + [t([1, 2, 3]), t([4, 5, 6])] + |> Nx.concatenate() + |> assert_equal(t([1, 2, 3, 4, 5, 6])) + + t1 = Nx.iota({2, 2, 2}, names: [:x, :y, :z], type: :f32) + t2 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :u8) + t3 = Nx.iota({1, 2, 2}, names: [:x, :y, :z], type: :s64) + + [t1, t2, t3] + |> Nx.concatenate(axis: :x) + |> assert_equal( + t([ + [ + [0.0, 1.0], + [2.0, 3.0] + ], + [ + [4.0, 5.0], + [6.0, 7.0] + ], + [ + [0.0, 1.0], + [2.0, 3.0] + ], + [ + [0.0, 1.0], + [2.0, 3.0] + ] + ]) + ) + end + + test "greater" do + Nx.greater(1, 2) + |> assert_equal(t(0)) + + Nx.greater(1, t([1, 2, 3])) + |> assert_equal(t([0, 0, 0])) + + t([1, 2, 3]) + |> Nx.greater(t([1, 2, 2])) + |> assert_equal(t([0, 0, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + |> Nx.greater(t([1, 2, 3])) + |> assert_equal( + t([ + [0, 0, 0], + [1, 1, 1] + ]) + ) + end + + test "less" do + Nx.less(1, 2) + |> assert_equal(t(1)) + + Nx.less(1, t([1, 2, 3])) + |> assert_equal(t([0, 1, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 2.0, 1.0]]) + |> Nx.less(t([1, 2, 3])) + |> assert_equal(t([[0, 0, 0], [0, 0, 1]])) + end + + test "less_equal" do + Nx.less_equal(1, 2) + |> assert_equal(t(1)) + + Nx.less_equal(1, t([1, 2, 3])) + |> assert_equal(t([1, 1, 1])) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + |> Nx.less_equal(t([1, 2, 3])) + |> assert_equal(t([[1, 1, 1], [0, 0, 0]])) + end + + test "bitcast" do + t([0, 0, 0], type: :s64) + |> Nx.bitcast(:f64) + |> assert_equal(t([0.0, 0.0, 0.0])) + + t([0, 0, 0], type: :u32) + |> Nx.bitcast(:f32) + |> assert_equal(t([0.0, 0.0, 0.0])) + + t([0, 0, 0], type: :u32) + |> Nx.bitcast(:u32) + |> assert_equal(t([0, 0, 0])) + end + + test "eye" do + Nx.eye(2) + |> assert_equal(t([[1, 0], [0, 1]])) + + Nx.eye(3, type: :f32) + |> assert_equal( + t([ + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0] + ]) + ) + + Nx.eye({1, 2}) + |> assert_equal(t([[1, 0]])) + + Nx.eye({2, 4, 3}) + |> assert_equal( + t([ + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, 0, 0] + ], + [ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [0, 0, 0] + ] + ]) + ) + + # assert_equal doesn't yet work with vectorized axes + # Nx.eye({3}, vectorized_axes: [x: 1, y: 2]) + # |> assert_equal(t( + # [ + # [ + # [1, 0, 0], + # [1, 0, 0] + # ] + # ] + # )) + + # Nx.eye({2, 3}, vectorized_axes: [x: 2]) + # |> assert_equal(t( + # [ + # [ + # [1, 0, 0], + # [0, 1, 0] + # ], + # [ + # [1, 0, 0], + # 
[0, 1, 0] + # ] + # ] + # )) + end + + test "dot/2" do + # Dot product of scalars + + Nx.dot(5, 5) + |> assert_equal(t(25)) + + Nx.dot(-2.0, 5.0) + |> assert_equal(t(-10.0)) + + Nx.dot(2, 2.0) + |> assert_equal(t(4.0)) + + # Dot product of vectors + + t([1, 2, 3]) + |> Nx.dot(t([4, 5, 6])) + |> assert_equal(t(32)) + + t([1.0, 2, 3]) + |> Nx.dot(t([1, 2, 3])) + |> assert_equal(t(14.0)) + + # Dot product of matrices (2-D tensors) + + # TODO: Candle matmul doesn't support integers yet + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.dot(t([[7, 8], [9, 10], [11, 12]])) + # |> assert_equal(t( + # [ + # [58, 64], + # [139, 154] + # ] + # )) + + t([[1.0, 2, 3], [4, 5, 6]]) + |> Nx.dot(t([[7, 8], [9, 10], [11, 12]])) + |> assert_equal( + t([ + [58.0, 64], + [139, 154] + ]) + ) + + # Dot product of vector and n-D tensor + + t([[0.0]]) + |> Nx.dot(t([55.0])) + |> assert_equal(t([0.0])) + + t([[[1.0, 2], [3, 4]], [[5, 6], [7, 8]]]) + |> Nx.dot(t([5, 10])) + |> assert_equal( + t([ + [25.0, 55], + [85, 115] + ]) + ) + + # t([5.0, 10], names: [:x]) + # |> Nx.dot(t([[1.0, 2, 3], [4, 5, 6]], names: [:i, :j])) + # |> assert_equal(t( + # [45, 60, 75] + # )) + + # t([[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]], names: [:shard, :batch, :x, :y, :z]) + # |> Nx.dot(t([2.0, 2.0], names: [:data])) + # |> assert_equal(t( + # [ + # [ + # [ + # [6.0, 14.0], + # [22.0, 30.0] + # ] + # ] + # ] + # )) + + # Dot product of n-D and m-D tensors + + # t([[[1.0, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]], names: [:x, :y, :z]) + # |> Nx.dot(t([[[1.0, 2, 3], [3, 4, 5], [5, 6, 7]]], names: [:i, :j, :k])) + # |> assert_equal(t( + # [ + # [ + # [ + # [22, 28, 34] + # ], + # [ + # [49, 64, 79] + # ], + # [ + # [76, 100, 124] + # ] + # ], + # [ + # [ + # [22, 28, 34] + # ], + # [ + # [49, 64, 79] + # ], + # [ + # [76, 100, 124] + # ] + # ] + # ] + # )) + end + + test "dot/6" do + # Contracting along axes + + t1 = t([[1.0, 2], [3, 4]], names: [:x, :y]) + t2 = t([[10.0, 20], [30, 40]], names: [:height, :width]) + + t1 + |> Nx.dot([0], [], t2, [0], []) + |> assert_equal( + t([ + [100, 140], + [140, 200] + ]) + ) + + # TODO: + t1 + |> Nx.dot([0], [], t2, [1], []) + |> assert_equal( + t([ + [70, 150], + [100, 220] + ]) + ) + + t1 + |> Nx.dot([1], [], t2, [0], []) + |> assert_equal( + t([ + [70, 100], + [150, 220] + ]) + ) + + # t1 + # |> Nx.dot([1], [], t2, [1], []) + # |> assert_equal(t( + # [ + # [50, 110], + # [110, 250] + # ] + # )) + + # t1 + # |> Nx.dot([0, 1], [], t2, [0, 1], []) + # |> assert_equal(t(300)) + end + + test "negate" do + # TODO: candle doesn't support unary functions for integers yet + # Nx.negate(1) + # |> assert_equal(t(-1)) + + Nx.negate(1.0) + |> assert_equal(t(-1.0)) + + t([1.0, 2.0, -3.0], type: :f32) + |> Nx.negate() + |> assert_equal(t([-1.0, -2.0, 3.0])) + end + + test "sin" do + Nx.sin(1.0) + |> assert_close(t(0.8414709568023682)) + + t([1.0, 2.0, 3.0]) + |> Nx.sin() + |> assert_close(t([0.8414709568023682, 0.9092974066734314, 0.14112000167369843])) + end + + test "sinh" do + Nx.sinh(1.0) + |> assert_close(t(1.175201177597046)) + + t([1.0, 2, 3]) + |> Nx.sinh() + |> assert_close(t([1.175201177597046, 3.6268603801727295, 10.017874717712402])) + end + + test "exp" do + Nx.exp(1.0) + |> assert_equal(t(2.7182817459106445)) + + t([1.0, 2, 3]) + |> Nx.exp() + |> assert_equal(t([2.7182817459106445, 7.389056205749512, 20.08553695678711])) + end + + test "expm1" do + Nx.expm1(1.0) + |> assert_close(t(1.718281865119934)) + + t([1.0, 2, 3]) + |> Nx.expm1() + |> 
assert_close(t([1.718281865119934, 6.389056205749512, 19.08553695678711])) + end + + test "cos" do + Nx.cos(1.0) + |> assert_close(t(0.5403022766113281)) + + t([1.0, 2, 3]) + |> Nx.cos() + |> assert_close(t([0.5403022766113281, -0.416146844625473, -0.9899924993515015])) + end + + test "cosh" do + Nx.cosh(1.0) + |> assert_close(t(1.5430806875228882)) + + t([1.0, 2, 3]) + |> Nx.cosh() + |> assert_close(t([1.5430806875228882, 3.762195587158203, 10.067662239074707])) + end + + test "log" do + Nx.log(1.0) + |> assert_equal(t(0.0)) + + t([1.0, 2, 3]) + |> Nx.log() + |> assert_equal(t([0.0, 0.6931471824645996, 1.0986123085021973])) + end + + test "tanh" do + Nx.tanh(1.0) + |> assert_equal(t(0.7615941762924194)) + + t([1.0, 2, 3]) + |> Nx.tanh() + |> assert_equal(t([0.7615941762924194, 0.9640275835990906, 0.9950547814369202])) + end + + test "abs" do + t([-2.0, -1, 0, 1, 2]) + |> Nx.abs() + |> assert_equal(t([2, 1, 0, 1, 2])) + + t([-2, -1, 0, 1, 2]) + |> Nx.abs() + |> assert_equal(t([2, 1, 0, 1, 2])) + end + + test "sqrt" do + Nx.sqrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.sqrt() + |> assert_equal(t([1.0, 1.4142135381698608, 1.7320507764816284])) + end + + test "rsqrt" do + Nx.rsqrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.rsqrt() + |> assert_equal(t([1.0, 0.7071067690849304, 0.5773502588272095])) + end + + test "argmax" do + Nx.argmax(4) + |> assert_equal(t(0)) + + # TODO: Support argmax without specific axis + # t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + # |> Nx.argmax() + # |> assert_equal(t(10)) + + # t([2.0, 4.0]) + # |> Nx.argmax() + # |> assert_equal(t(1)) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + |> Nx.argmax(axis: 0) + |> assert_equal( + t([ + [1, 0, 0], + [1, 1, 0] + ]) + ) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmax(axis: :z) + |> assert_equal( + t([ + [0, 2], + [0, 1] + ]) + ) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmax(axis: :y, keep_axis: true) + |> assert_equal( + t([ + [ + [0, 0, 0] + ], + [ + [0, 1, 0] + ] + ]) + ) + end + + test "argmin" do + Nx.argmin(4) + |> assert_equal(t(0)) + + # TODO: Support argmin without specific axis + # t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + # |> Nx.argmin() + # |> assert_equal(t(4)) + + # t([2.0, 4.0]) + # |> Nx.argmin() + # |> assert_equal(t(0)) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]]) + |> Nx.argmin(axis: 0) + |> assert_equal( + t([ + [0, 0, 0], + [0, 0, 0] + ]) + ) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmin(axis: 1) + |> assert_equal( + t([ + [1, 1, 0], + [1, 0, 0] + ]) + ) + + t([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]], names: [:x, :y, :z]) + |> Nx.argmin(axis: :z) + |> assert_equal( + t([ + [1, 1], + [1, 2] + ]) + ) + end + + test "acos" do + Nx.acos(0.10000000149011612) + |> assert_equal(t(1.4706288576126099)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.acos() + |> assert_equal(t([1.4706288576126099, 1.0471975803375244, 0.4510268568992615])) + end + + test "acosh" do + Nx.acosh(1.0) + |> assert_equal(t(0.0)) + + t([1.0, 2, 3]) + |> Nx.acosh() + |> assert_close(t([0.0, 1.316957950592041, 1.7627471685409546])) + end + + test "asin" do + Nx.asin(0.10000000149011612) + |> assert_equal(t(0.1001674234867096)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.asin() + |> assert_equal(t([0.1001674234867096, 0.5235987901687622, 1.1197694540023804])) + end + + test "asinh" do 
+ Nx.asinh(1.0) + |> assert_close(t(0.8813735842704773)) + + t([1.0, 2, 3]) + |> Nx.asinh() + |> assert_close(t([0.8813735842704773, 1.4436354637145996, 1.8184465169906616])) + end + + test "tan" do + Nx.tan(1.0) + |> assert_close(t(1.5574077367782593)) + + t([1.0, 2, 3]) + |> Nx.tan() + |> assert_close(t([1.5574077367782593, -2.185039758682251, -0.14254654943943024])) + end + + test "atan" do + Nx.atan(0.10000000149011612) + |> assert_close(t(0.09966865181922913)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.atan() + |> assert_close(t([0.09966865181922913, 0.46364760398864746, 0.7328150868415833])) + end + + test "atanh" do + Nx.atanh(0.10000000149011612) + |> assert_close(t(0.10033535212278366)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.atanh() + |> assert_close(t([0.10033535212278366, 0.5493061542510986, 1.4722193479537964])) + end + + test "ceil" do + t([-1, 0, 1]) + |> Nx.ceil() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.ceil() + |> assert_equal(t([-1.0, 0.0, 1.0, 2.0])) + end + + test "floor" do + t([-1, 0, 1]) + |> Nx.floor() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.floor() + |> assert_equal(t([-2.0, -1.0, 0.0, 1.0])) + end + + test "round" do + t([-1, 0, 1]) + |> Nx.round() + |> assert_equal(t([-1, 0, 1])) + + t([-1.5, -0.5, 0.5, 1.5]) + |> Nx.round() + |> assert_equal(t([-2.0, -1.0, 1.0, 2.0])) + end + + test "cbrt" do + Nx.cbrt(1.0) + |> assert_equal(t(1.0)) + + t([1.0, 2, 3]) + |> Nx.cbrt() + |> assert_equal(t([1.0, 1.2599210739135742, 1.4422495365142822])) + end + + test "log1p" do + Nx.log1p(1.0) + |> assert_equal(t(0.6931471824645996)) + + t([1.0, 2, 3]) + |> Nx.log1p() + |> assert_equal(t([0.6931471824645996, 1.0986123085021973, 1.3862943649291992])) + end + + test "bitwise_and" do + Nx.bitwise_and(1, 0) + |> assert_equal(t(0)) + + t([0, 1, 2]) + |> Nx.bitwise_and(1) + |> assert_equal(t([0, 1, 0])) + + t([0, -1, -2]) + |> Nx.bitwise_and(-1) + |> assert_equal(t([0, -1, -2])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_and(t([0, 1, 0, 1])) + |> assert_equal(t([0, 0, 0, 1])) + end + + test "bitwise_or" do + Nx.bitwise_or(1, 0) + |> assert_equal(t(1)) + + t([0, 1, 2]) + |> Nx.bitwise_or(1) + |> assert_equal(t([1, 1, 3])) + + t([0, -1, -2]) + |> Nx.bitwise_or(-1) + |> assert_equal(t([-1, -1, -1])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_or(t([0, 1, 0, 1])) + |> assert_equal(t([0, 1, 1, 1])) + end + + test "bitwise_xor" do + Nx.bitwise_xor(1, 0) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.bitwise_xor(2) + |> assert_equal(t([3, 0, 1])) + + t([1, 2, 3], type: :u32) + |> Nx.bitwise_xor(2) + |> assert_equal(t([3, 0, 1])) + + t([-1, -2, -3]) + |> Nx.bitwise_xor(2) + |> assert_equal(t([-3, -4, -1])) + + t([0, 0, 1, 1]) + |> Nx.bitwise_xor(t([0, 1, 0, 1])) + |> assert_equal(t([0, 1, 1, 0])) + end + + test "bitwise_not" do + Nx.bitwise_not(1) + |> assert_equal(t(-2)) + + t([-1, 0, 1]) + |> Nx.bitwise_not() + |> assert_equal(t([0, -1, -2])) + + t([0, 1, 254, 255], type: :u8) + |> Nx.bitwise_not() + |> assert_equal(t([255, 254, 1, 0])) + end + + test "left_shift" do + Nx.left_shift(1, 0) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.left_shift(2) + |> assert_equal(t([4, 8, 12])) + + t([1, 1, -1, -1]) + |> Nx.left_shift(t([1, 2, 3, 4])) + |> assert_equal(t([2, 4, -8, -16])) + + t([1, 2, 3], type: :u32) + |> Nx.left_shift(2) + |> assert_equal(t([4, 8, 12])) + + t([1, 2, 3], type: :u32) + |> Nx.left_shift(t(2, type: :u8)) + |> assert_equal(t([4, 8, 12])) + + t([1, 1, 0, 0], type: :u32) + |> 
Nx.left_shift(t([1, 2, 3, 4])) + |> assert_equal(t([2, 4, 0, 0])) + + t([1, 1, 0, 0], type: :u32) + |> Nx.left_shift(t([1, 2, 3, 4], type: :u8)) + |> assert_equal(t([2, 4, 0, 0])) + end + + test "right_shift" do + Nx.right_shift(1, 0) + |> assert_equal(t(1)) + + t([2, 4, 8]) + |> Nx.right_shift(2) + |> assert_equal(t([0, 1, 2])) + + t([16, 32, -64, -128]) + |> Nx.right_shift(t([1, 2, 3, 4])) + |> assert_equal(t([8, 8, -8, -8])) + + t([2, 4, 8], type: :u32) + |> Nx.right_shift(2) + |> assert_equal(t([0, 1, 2])) + + t([16, 32, -64, -128], type: :u32) + |> Nx.right_shift(t([1, 2, 3, 4])) + |> assert_equal(t([8, 8, 536_870_904, 268_435_448])) + end + + test "is_infinity" do + t([:infinity, :nan, :neg_infinity, 1, 0]) + |> Nx.is_infinity() + |> assert_equal(t([1, 0, 1, 0, 0])) + + t([:infinity, 1, :neg_infinity]) + |> Nx.is_infinity() + |> assert_equal(t([1, 0, 1])) + + # TODO: Not supported for :s64 + # t([1, 0]) + # |> Nx.is_infinity() + # |> assert_equal(t([0, 0])) + end + + test "is_nan" do + t([:nan, 1.0, 0.0]) + |> Nx.is_nan() + |> assert_equal(t([1, 0, 0])) + + t([:nan, :infinity]) + |> Nx.is_nan() + |> assert_equal(t([1, 0])) + + # Complex not yet supported + # t(Complex.new(0, :nan)) + # |> Nx.is_nan() + # |> assert_equal(t(1)) + + t([1.0, 0.0]) + |> Nx.is_nan() + |> assert_equal(t([0, 0])) + end + + test "logical_and" do + Nx.logical_and(1, t([-1, 0, 1])) + |> assert_equal(t([1, 0, 1])) + + t([-1, 0, 1]) + |> Nx.logical_and(t([[-1], [0], [1]])) + |> assert_equal( + t([ + [1, 0, 1], + [0, 0, 0], + [1, 0, 1] + ]) + ) + + t([-1.0, 0.0, 1.0]) + |> Nx.logical_and(t([[-1], [0], [1]])) + |> assert_equal( + t([ + [1, 0, 1], + [0, 0, 0], + [1, 0, 1] + ]) + ) + end + + test "logical_or" do + Nx.logical_or(0, t([-1, 0, 1])) + |> assert_equal(t([1, 0, 1])) + + t([-1, 0, 1]) + |> Nx.logical_or(t([[-1], [0], [1]])) + |> assert_equal( + t([ + [1, 1, 1], + [1, 0, 1], + [1, 1, 1] + ]) + ) + + t([-1.0, 0.0, 1.0]) + |> Nx.logical_or(t([[-1], [0], [1]])) + |> assert_equal( + t([ + [1, 1, 1], + [1, 0, 1], + [1, 1, 1] + ]) + ) + end + + test "logical_xor" do + 0 + |> Nx.logical_xor(t([-1, 0, 1])) + |> assert_equal(t([1, 0, 1])) + + t([-1, 0, 1]) + |> Nx.logical_xor(t([[-1], [0], [1]])) + |> assert_equal( + t([ + [0, 1, 0], + [1, 0, 1], + [0, 1, 0] + ]) + ) + + t([-1.0, 0.0, 1.0]) + |> Nx.logical_xor(t([[-1], [0], [1]])) + |> assert_equal( + t([ + [0, 1, 0], + [1, 0, 1], + [0, 1, 0] + ]) + ) + end + + test "erf" do + Nx.erf(1.0) + |> assert_close(t(0.8427007794380188)) + + Nx.erf(t([1.0, 2, 3])) + |> assert_close(t([0.8427007794380188, 0.9953222870826721, 0.9999778866767883])) + end + + test "erfc" do + Nx.erfc(1.0) + |> assert_close(t(0.15729920566082)) + + Nx.erfc(t([1.0, 2, 3])) + |> assert_close(t([0.15729920566082, 0.004677734803408384, 2.2090496713644825e-5])) + end + + test "erf_inv" do + Nx.erf_inv(0.10000000149011612) + |> assert_close(t(0.08885598927736282)) + + t([0.10000000149011612, 0.5, 0.8999999761581421]) + |> Nx.erf_inv() + |> assert_close(t([0.08885598927736282, 0.4769362807273865, 1.163087010383606])) + + t(0.10000000149011612, type: :f64) + |> Nx.erf_inv() + |> assert_close(t(0.08885598927736282, type: :f64)) + + t([0.10000000149011612, 0.5, 0.8999999761581421], type: :f64) + |> Nx.erf_inv() + |> assert_close( + t([0.0888559891358877, 0.47693629334671295, 1.1630870196442271], type: :f64) + ) + end + + test "sum/2" do + t(42) + |> Nx.sum() + |> assert_equal(t(42)) + + t([1, 2, 3]) + |> Nx.sum() + |> assert_equal(t(6)) + + t([[1.0, 2.0], [3.0, 4.0]]) + |> Nx.sum() + |> 
assert_equal(t(10.0)) + + t = Nx.iota({2, 2, 3}, names: [:x, :y, :z]) + + Nx.sum(t, axes: [:x]) + |> assert_equal( + t([ + [6, 8, 10], + [12, 14, 16] + ]) + ) + + Nx.sum(t, axes: [:y]) + |> assert_equal( + t([ + [3, 5, 7], + [15, 17, 19] + ]) + ) + + Nx.sum(t, axes: [:z]) + |> assert_equal( + t([ + [3, 12], + [21, 30] + ]) + ) + + Nx.sum(t, axes: [:x, :z]) + |> assert_equal(t([24, 42])) + + Nx.sum(t, axes: [-3]) + |> assert_equal( + t([ + [6, 8, 10], + [12, 14, 16] + ]) + ) + + t([[1, 2], [3, 4]], names: [:x, :y]) + |> Nx.sum(axes: [:x], keep_axes: true) + |> assert_equal( + t([ + [4, 6] + ]) + ) + end + + test "to_batched/2" do + [first, second] = + Nx.iota({2, 2, 2}) + |> Nx.to_batched(1) + |> Enum.to_list() + + first + |> assert_equal( + t([ + [ + [0, 1], + [2, 3] + ] + ]) + ) + + second + |> assert_equal( + t([ + [ + [4, 5], + [6, 7] + ] + ]) + ) + + [first, second] = + Nx.iota({10}) + |> Nx.to_batched(5) + |> Enum.to_list() + + first + |> assert_equal(Nx.tensor([0, 1, 2, 3, 4])) + + second + |> assert_equal(Nx.tensor([5, 6, 7, 8, 9])) + + [first, second, third, fourth] = + Nx.iota({10}) + |> Nx.to_batched(3) + |> Enum.to_list() + + first + |> assert_equal(Nx.tensor([0, 1, 2])) + + second + |> assert_equal(Nx.tensor([3, 4, 5])) + + third + |> assert_equal(Nx.tensor([6, 7, 8])) + + fourth + |> assert_equal(Nx.tensor([9, 0, 1])) + + # TODO: Implement with discard + # [first, second] = + # Nx.iota({10}) + # |> Nx.to_batched(4, leftover: :discard) + # |> Enum.to_list() + + # first + # |> assert_equal(Nx.tensor([0, 1, 2, 3])) + + # second + # |> assert_equal(Nx.tensor([4, 5, 6, 7])) + end + + test "sigmoid/1" do + Nx.sigmoid(1.0) + |> assert_close(t(0.7310585975646973)) + + t([1.0, 2, 3]) + |> Nx.sigmoid() + |> assert_close(t([0.7310585975646973, 0.8807970881462097, 0.9525741338729858])) + end + + test "mean/1" do + t(42) + |> Nx.mean() + |> assert_equal(t(42.0)) + + t([1, 2, 3]) + |> Nx.mean() + |> assert_equal(t(2.0)) + + t([0.1, 0.2, 0.3]) + |> Nx.mean() + |> assert_equal(t(0.2)) + end + + test "pow" do + # Nx.pow(2, 4) + # |> assert_equal(t(16)) + + # t([1, 2, 3], type: :u32) + # |> Nx.pow(t(2, type: :u32)) + # |> assert_equal(t([1, 4, 9])) + + t([1.0, 2.0, 3.0]) + |> Nx.pow(2) + |> assert_equal(t([1.0, 4.0, 9.0])) + + 2 + |> Nx.pow(t([1.0, 2.0, 3.0])) + |> assert_equal(t([2.0, 4.0, 8.0])) + + # t([[2], [3]]) + # |> Nx.pow(t([[4, 5]])) + # |> assert_equal(t( + # [ + # [16, 32], + # [81, 243] + # ] + # )) + end + + test "conv" do + Nx.iota({9}) + |> Nx.reshape({1, 1, 3, 3}) + |> Nx.conv( + Nx.iota({4}) + |> Nx.reshape({4, 1, 1, 1}), + strides: [1, 1] + ) + |> assert_equal( + t([ + [ + [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0] + ], + [ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0] + ], + [ + [0.0, 2.0, 4.0], + [6.0, 8.0, 10.0], + [12.0, 14.0, 16.0] + ], + [ + [0.0, 3.0, 6.0], + [9.0, 12.0, 15.0], + [18.0, 21.0, 24.0] + ] + ] + ]) + ) + + # input/output permutation + + result = + Nx.iota({1, 3, 3, 6}) + |> Nx.conv( + 1 |> Nx.broadcast({2, 6, 1, 1}), + input_permutation: [0, 3, 1, 2], + output_permutation: [0, 3, 1, 2] + ) + + assert result.shape == {1, 3, 3, 2} + + result + |> assert_close( + t([ + [ + [15.0, 15.0], + [51.0, 51.0], + [87.0, 87.0] + ], + [ + [123.0, 123.0], + [159.0, 159.0], + [195.0, 195.0] + ], + [ + [231.0, 231.0], + [267.0, 267.0], + [303.0, 303.0] + ] + ]) + ) + + # Nx.iota({9}) + # |> Nx.reshape({1, 1, 3, 3}) + # |> Nx.conv( + # Nx.iota({8}) + # |> Nx.reshape({4, 1, 2, 1}), + # strides: 2, + # padding: :same, + # kernel_dilation: [2, 1] + # ) 
+ # |> assert_equal(t( + # [ + # [ + # [ + # [3.0, 5.0], + # [0.0, 0.0] + # ], + # [ + # [9.0, 15.0], + # [6.0, 10.0] + # ], + # [ + # [15.0, 25.0], + # [12.0, 20.0] + # ], + # [ + # [21.0, 35.0], + # [18.0, 30.0] + # ] + # ] + # ] + # )) + end + + test "reduce_max" do + t(42) + |> Nx.reduce_max() + |> assert_equal(t(42)) + + t(42.0) + |> Nx.reduce_max() + |> assert_equal(t(42.0)) + + t([1, 2, 3]) + |> Nx.reduce_max() + |> assert_equal(t(3)) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_max(axes: [:x]) + |> assert_equal(t([3, 1, 4])) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_max(axes: [:y]) + |> assert_equal(t([4, 2])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_max(axes: [:x, :z]) + # |> assert_equal(t([4, 8])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_max(axes: [:x, :z], keep_axes: true) + # |> assert_equal(t( + # [ + # [ + # [4], + # [8] + # ] + # ] + # )) + end + + test "reduce_min" do + Nx.reduce_min(t(42)) + |> assert_equal(t(42)) + + Nx.reduce_min(t(42.0)) + |> assert_equal(t(42.0)) + + Nx.reduce_min(t([1, 2, 3])) + |> assert_equal(t(1)) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_min(axes: [:x]) + |> assert_equal(t([2, 1, 1])) + + t([[3, 1, 4], [2, 1, 1]], names: [:x, :y]) + |> Nx.reduce_min(axes: [:y]) + |> assert_equal(t([1, 1])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_min(axes: [:x, :z]) + # |> assert_equal(t([1, 3])) + + # t([[[1, 2], [4, 5]], [[2, 4], [3, 8]]], names: [:x, :y, :z]) + # |> Nx.reduce_min(axes: [:x, :z], keep_axes: true) + # |> assert_equal(t( + # [ + # [ + # [1], + # [3] + # ] + # ] + # )) + end + + test "take_along_axis" do + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.take_along_axis( + t([ + [0, 0, 2, 2, 1, 1], + [2, 2, 1, 1, 0, 0] + ]), + axis: 1 + ) + |> assert_equal( + t([ + [1, 1, 3, 3, 2, 2], + [6, 6, 5, 5, 4, 4] + ]) + ) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.take_along_axis( + t([ + [0, 1, 1], + [1, 0, 0], + [0, 1, 0] + ]), + axis: 0 + ) + |> assert_equal( + t([ + [1, 5, 6], + [4, 2, 3], + [1, 5, 3] + ]) + ) + end + + test "gather" do + t([1, 2, 3, 4]) + |> Nx.gather(t([[3], [1], [2]])) + |> assert_equal(t([4, 2, 3])) + + # t([[1, 2], [3, 4]]) + # |> Nx.gather(t([[1, 1], [0, 1], [1, 0]])) + # |> assert_equal(t([4, 2, 3])) + + # t([[1, 2], [3, 4]]) + # |> Nx.gather(t([[[1, 1], [0, 0]], [[1, 0], [0, 1]]])) + # |> assert_equal(t( + # [ + # [4, 1], + # [3, 2] + # ] + # )) + + # t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) + # |> Nx.gather(t([[0, 0, 0], [0, 1, 1], [1, 1, 1]])) + # |> assert_equal(t([1, 12, 112])) + end + + test "indexed_add" do + t([1.0]) + |> Nx.indexed_add(t([[0], [0]]), t([1, 1])) + |> assert_equal(t([3.0])) + + t([1]) + |> Nx.indexed_add(t([[0], [0]]), t([1.0, 1.0])) + |> assert_equal(t([3.0])) + + t([1], type: :u8) + |> Nx.indexed_add(t([[0], [0]]), t([1, 1], type: :s64)) + |> assert_equal(t([3])) + + # Nx.iota({1, 2, 3}) + # |> Nx.indexed_add( + # t([[0, 0, 0], [0, 1, 1], [0, 0, 0], [0, 0, 2], [0, 1, 2]]), + # t([1, 3, 1, -2, 5]) + # ) + # |> assert_equal(t( + # [ + # [ + # [2, 1, 0], + # [3, 7, 10] + # ] + # ] + # )) + end + + test "transpose" do + t(1) + |> Nx.transpose() + |> assert_equal(t(1)) + + Nx.iota({2, 3, 4}, names: [:x, :y, :z]) + |> Nx.transpose() + |> assert_equal( + t([ + [ + [0, 12], + [4, 16], + [8, 20] + ], + [ + [1, 13], + [5, 17], + [9, 21] + ], + [ + [2, 14], + [6, 18], + [10, 22] + ], + [ + [3, 15], + [7, 19], + [11, 23] + ] + ]) + ) + + t(1) 
+ |> Nx.transpose(axes: []) + |> assert_equal(t(1)) + + Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) + |> Nx.transpose(axes: [2, 1, :batch]) + |> assert_equal( + t([ + [ + [0, 12], + [4, 16], + [8, 20] + ], + [ + [1, 13], + [5, 17], + [9, 21] + ], + [ + [2, 14], + [6, 18], + [10, 22] + ], + [ + [3, 15], + [7, 19], + [11, 23] + ] + ]) + ) + + Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) + |> Nx.transpose(axes: [:y, :batch, :x]) + |> assert_equal( + t([ + [ + [0, 4, 8], + [12, 16, 20] + ], + [ + [1, 5, 9], + [13, 17, 21] + ], + [ + [2, 6, 10], + [14, 18, 22] + ], + [ + [3, 7, 11], + [15, 19, 23] + ] + ]) + ) + + Nx.iota({2, 3, 4}, names: [:batch, :x, :y]) + |> Nx.transpose(axes: [:batch, :y, :x]) + |> assert_equal( + t([ + [ + [0, 4, 8], + [1, 5, 9], + [2, 6, 10], + [3, 7, 11] + ], + [ + [12, 16, 20], + [13, 17, 21], + [14, 18, 22], + [15, 19, 23] + ] + ]) + ) + end + + test "put_slice" do + t([0, 1, 2, 3, 4]) + |> Nx.put_slice([2], Nx.tensor([5, 6])) + |> assert_equal(t([0, 1, 5, 6, 4])) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.put_slice([0, 0], t([[7, 8, 9], [10, 11, 12]])) + |> assert_equal( + t([ + [7, 8, 9], + [10, 11, 12] + ]) + ) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.put_slice([0, 1], t([[7, 8], [9, 10]])) + |> assert_equal( + t([ + [1, 7, 8], + [4, 9, 10] + ]) + ) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.put_slice([t(0), t(1)], t([[10.0, 11.0]])) + # |> assert_equal(t( + # [ + # [1.0, 10.0, 11.0], + # [4.0, 5.0, 6.0] + # ] + # )) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.put_slice([1, 1], t([[7, 8], [9, 10]])) + # |> assert_equal(t( + # [ + # [1, 7, 8], + # [4, 9, 10] + # ] + # )) + + t([ + [ + [1, 2], + [3, 4] + ], + [ + [4, 5], + [6, 7] + ] + ]) + |> Nx.put_slice([0, 0, 1], t([[[8], [9]], [[10], [11]]])) + |> assert_equal( + t([ + [ + [1, 8], + [3, 9] + ], + [ + [4, 10], + [6, 11] + ] + ]) + ) + end + + test "pad" do + t(1) + |> Nx.pad(0, []) + |> assert_equal(t(1)) + + t([1, 2, 3], names: [:data]) + |> Nx.pad(0, [{1, 1, 0}]) + |> assert_equal(t([0, 1, 2, 3, 0])) + + # t([[1, 2, 3], [4, 5, 6]]) + # |> Nx.pad(0, [{0, 0, 1}, {0, 0, 1}]) + # |> assert_equal(t( + # [ + # [1, 0, 2, 0, 3], + # [0, 0, 0, 0, 0], + # [4, 0, 5, 0, 6] + # ] + # )) + + # Nx.pad(Nx.tensor([[1, 2, 3], [4, 5, 6]]), 0, [{1, 1, 0}, {1, 1, 0}]) + # [ + # [0, 0, 0, 0, 0], + # [0, 1, 2, 3, 0], + # [0, 4, 5, 6, 0], + # [0, 0, 0, 0, 0] + # ] + # > + + # tensor = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + # Nx.pad(tensor, 0, [{0, 2, 0}, {1, 1, 0}, {1, 0, 0}]) + # [ + # [ + # [0, 0, 0], + # [0, 1, 2], + # [0, 3, 4], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 5, 6], + # [0, 7, 8], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ] + # ] + + # tensor = Nx.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) + # Nx.pad(tensor, 0, [{1, 0, 0}, {1, 1, 0}, {0, 1, 0}]) + # [ + # [ + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [1, 2, 0], + # [3, 4, 0], + # [0, 0, 0] + # ], + # [ + # [0, 0, 0], + # [5, 6, 0], + # [7, 8, 0], + # [0, 0, 0] + # ] + # ] + + # tensor = Nx.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]) + # Nx.pad(tensor, 0.0, [{1, 2, 0}, {1, 0, 0}, {0, 1, 0}]) + # [ + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [1.0, 2.0, 0.0], + # [3.0, 4.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [5.0, 6.0, 0.0], + # [7.0, 8.0, 0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 
0.0] + # ], + # [ + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # ] + # ] + + # Nx.pad(Nx.tensor([0, 1, 2, 3, 0]), 0, [{-1, -1, 0}]) + # [1, 2, 3] + + # tensor = Nx.tensor([ + # [[0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], + # [[0, 0, 0], [1, 2, 0], [3, 4, 0], [0, 0, 0]], + # [[0, 0, 0], [5, 6, 0], [7, 8, 0], [0, 0, 0]] + # ]) + # Nx.pad(tensor, 0, [{-1, 0, 0}, {-1, -1, 0}, {0, -1, 0}]) + # [ + # [ + # [1, 2], + # [3, 4] + # ], + # [ + # [5, 6], + # [7, 8] + # ] + # ] + + # t([[0, 1, 2, 3], [0, 4, 5, 6]]) + # |> Nx.pad(0, [{0, 0, 0}, {-1, 1, 0}]) + # |> assert_equal(t( + # [ + # [1, 2, 3, 0], + # [4, 5, 6, 0] + # ] + # )) + + # t([[0, 1, 2], [3, 4, 5]], type: :f32) + # |> Nx.pad(0, [{-1, 2, 0}, {1, -1, 0}]) + # |> assert_equal(t( + # [ + # [0.0, 3.0, 4.0], + # [0.0, 0.0, 0.0], + # [0.0, 0.0, 0.0] + # ] + # ) + end + + test "take" do + t([[1, 2], [3, 4]]) + |> Nx.take(t([1, 0, 1])) + |> assert_equal( + t([ + [3, 4], + [1, 2], + [3, 4] + ]) + ) + + t([[1, 2], [3, 4]]) + |> Nx.take(t([1, 0, 1]), axis: 1) + |> assert_equal( + t([ + [2, 1, 2], + [4, 3, 4] + ]) + ) + + t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) + |> Nx.take(t([1, 0, 1]), axis: 1) + |> assert_equal( + t([ + [ + [11, 12], + [1, 2], + [11, 12] + ], + [ + [111, 112], + [101, 102], + [111, 112] + ] + ]) + ) + + # t([[1, 2], [11, 12]]) + # |> Nx.take(t([[0, 0], [1, 1], [0, 0]]), axis: 1) + # |> assert_equal(t( + # [ + # [ + # [1, 1], + # [2, 2], + # [1, 1] + # ], + # [ + # [11, 11], + # [12, 12], + # [11, 11] + # ] + # ] + # )) + + # t([[[1, 2], [11, 12]], [[101, 102], [111, 112]]]) + # |> Nx.take(t([[0, 0, 0], [1, 1, 1], [0, 0, 0]]), axis: 1) + # |> assert_equal(t( + # [ + # [ + # [ + # [1, 2], + # [1, 2], + # [1, 2] + # ], + # [ + # [11, 12], + # [11, 12], + # [11, 12] + # ], + # [ + # [1, 2], + # [1, 2], + # [1, 2] + # ] + # ], + # [ + # [ + # [101, 102], + # [101, 102], + # [101, 102] + # ], + # [ + # [111, 112], + # [111, 112], + # [111, 112] + # ], + # [ + # [101, 102], + # [101, 102], + # [101, 102] + # ] + # ] + # ] + # )) + end + + test "clip" do + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.clip(2, 4) + |> assert_equal( + t([ + [2, 2, 3], + [4, 4, 4] + ]) + ) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.clip(2.0, 3) + |> assert_equal( + t([ + [2.0, 2.0, 3.0], + [3.0, 3.0, 3.0] + ]) + ) + + t([[1, 2, 3], [4, 5, 6]]) + |> Nx.clip(t(2.0), Nx.max(1.0, 3.0)) + |> assert_equal( + t([ + [2.0, 2.0, 3.0], + [3.0, 3.0, 3.0] + ]) + ) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + |> Nx.clip(2, 6.0) + |> assert_equal( + t([ + [2.0, 2.0, 3.0], + [4.0, 5.0, 6.0] + ]) + ) + + t([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], type: :f32) + |> Nx.clip(1, 4) + |> assert_equal( + t([ + [1.0, 2.0, 3.0], + [4.0, 4.0, 4.0] + ]) + ) + end + + test "not_equal" do + Nx.not_equal(1, 2) + |> assert_equal(t(1)) + + t([1, 2, 3]) + |> Nx.not_equal(t(1)) + |> assert_equal(t([0, 1, 1])) + + t([1, 1, 2]) + |> Nx.not_equal(t([1, 2, 3])) + |> assert_equal(t([0, 1, 1])) + + t([[1, 4, 2], [4, 5, 6]]) + |> Nx.not_equal(t([[1, 3, 2], [4, 2, 1]])) + |> assert_equal( + t([ + [0, 1, 0], + [0, 1, 1] + ]) + ) + end + + test "all" do + t(0) + |> Nx.all() + |> assert_equal(t(0)) + + t(10) + |> Nx.all() + |> assert_equal(t(1)) + + t([0, 1, 2]) + |> Nx.all() + |> assert_equal(t(0)) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:x]) + |> assert_equal(t([1, 0, 1])) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:y]) + |> assert_equal(t([0, 1])) + + t([[-1, 0, 1], [2, 3, 4]], names: [:x, :y]) + |> Nx.all(axes: [:y], keep_axes: true) + |> 
assert_equal( + t([ + [0], + [1] + ]) + ) + + tensor = Nx.tensor([[[1, 2], [0, 4]], [[5, 6], [7, 8]]], names: [:x, :y, :z]) + + tensor + |> Nx.all(axes: [:x, :y]) + |> assert_equal(t([0, 1])) + + tensor + |> Nx.all(axes: [:y, :z]) + |> assert_equal(t([0, 1])) + + tensor + |> Nx.all(axes: [:x, :z]) + |> assert_equal(t([1, 0])) + + tensor + |> Nx.all(axes: [:x, :y], keep_axes: true) + |> assert_equal( + t([ + [ + [0, 1] + ] + ]) + ) + + tensor + |> Nx.all(axes: [:y, :z], keep_axes: true) + |> assert_equal( + t([ + [ + [0] + ], + [ + [1] + ] + ]) + ) + + tensor + |> Nx.all(axes: [:x, :z], keep_axes: true) + |> assert_equal( + t([ + [ + [1], + [0] + ] + ]) + ) + end + + test "any" do + t([0, 1, 2]) + |> Nx.any() + |> assert_equal(t(1)) + + t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + |> Nx.any(axes: [:x]) + |> assert_equal(t([0, 1, 1])) + + t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + |> Nx.any(axes: [:y]) + |> assert_equal(t([1, 1])) + + tensor = t([[0, 1, 0], [0, 1, 2]], names: [:x, :y]) + + tensor + |> Nx.any(axes: [:x], keep_axes: true) + |> assert_equal(t([[0, 1, 1]])) + + tensor + |> Nx.any(axes: [:y], keep_axes: true) + |> assert_equal(t([[1], [1]])) + end + + if Candlex.Backend.cuda_available?() do + test "different devices" do + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.add(t([10, 20, 30], backend: {Candlex.Backend, device: :cuda})) + |> assert_equal(t([11, 22, 33])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cuda}) + |> Nx.add(t([10, 20, 30], backend: {Candlex.Backend, device: :cpu})) + |> assert_equal(t([11, 22, 33])) + end + end + + test "backend_transfer" do + t([1, 2, 3], backend: Nx.BinaryBackend) + |> Nx.backend_transfer({Candlex.Backend, device: :cpu}) + |> assert_equal(t([1, 2, 3])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.backend_transfer(Nx.BinaryBackend) + |> assert_equal(t([1, 2, 3])) + + t([1, 2, 3], backend: {Candlex.Backend, device: :cpu}) + |> Nx.backend_transfer({Candlex.Backend, device: :cpu}) + |> assert_equal(t([1, 2, 3])) + end + end + + defp t(values, opts \\ []) do + opts = + [backend: Candlex.Backend] + |> Keyword.merge(opts) + + Nx.tensor(values, opts) + end + + defp check(value, opts) do + tensor = t(value, opts) + + tensor + # |> IO.inspect() + |> Nx.to_binary() + + # |> IO.inspect() + + opts = + [backend: Nx.BinaryBackend] + |> Keyword.merge(opts) + + assert Nx.backend_copy(tensor) == t(value, opts) + assert Nx.backend_transfer(tensor) == t(value, opts) + end +end diff --git a/candlex/test/support/candlex_case.ex b/candlex/test/support/candlex_case.ex new file mode 100644 index 0000000000..3f11eae98d --- /dev/null +++ b/candlex/test/support/candlex_case.ex @@ -0,0 +1,45 @@ +defmodule Candlex.Case do + @moduledoc """ + Test case for tensor assertions + """ + + use ExUnit.CaseTemplate + + using do + quote do + import Candlex.Case + end + end + + def assert_equal(left, right) do + equals = + left + |> Nx.equal(right) + # |> Nx.logical_or(Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))) + |> Nx.all() + |> Nx.to_number() + + if equals != 1 || Nx.shape(left) != Nx.shape(right) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end + + def assert_close(left, right) do + equals = + left + |> Nx.all_close(right, atol: 1.0e-4, rtol: 1.0e-4) + |> Nx.backend_transfer(Nx.BinaryBackend) + + if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do + flunk(""" + Tensor assertion failed. 
+ left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end +end diff --git a/candlex/test/support/nx_case.ex b/candlex/test/support/nx_case.ex new file mode 100644 index 0000000000..949a54915b --- /dev/null +++ b/candlex/test/support/nx_case.ex @@ -0,0 +1,45 @@ +defmodule Nx.Case do + @moduledoc """ + Test case for tensor assertions + """ + + use ExUnit.CaseTemplate + + using do + quote do + import Nx.Case + end + end + + def assert_equal(left, right) do + equals = + left + |> Nx.equal(right) + # |> Nx.logical_or(Nx.is_nan(left) |> Nx.logical_and(Nx.is_nan(right))) + |> Nx.all() + |> Nx.to_number() + + if equals != 1 || Nx.shape(left) != Nx.shape(right) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end + + def assert_close(left, right) do + equals = + left + |> Nx.all_close(right, atol: 1.0e-4, rtol: 1.0e-4) + |> Nx.backend_transfer(Nx.BinaryBackend) + + if equals != Nx.tensor(1, type: {:u, 8}, backend: Nx.BinaryBackend) do + flunk(""" + Tensor assertion failed. + left: #{inspect(left)} + right: #{inspect(right)} + """) + end + end +end diff --git a/candlex/test/test_helper.exs b/candlex/test/test_helper.exs new file mode 100644 index 0000000000..13695bf434 --- /dev/null +++ b/candlex/test/test_helper.exs @@ -0,0 +1,2 @@ +Application.put_env(:nx, :default_backend, Candlex.Backend) +ExUnit.start()