Skip to content

Commit

Permalink
I tried so hard... And got so far...
Browse files Browse the repository at this point in the history
  • Loading branch information
mratsim committed Jan 10, 2021
1 parent ade3e1c commit 9e000fa
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 15 deletions.
38 changes: 31 additions & 7 deletions flambeau/raw_bindings/data_api.nim
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ type
[Batch] = object

func next*(it: var TorchDataIterator) {.importcpp: "(++#)".}
func get*[Batch](it: var TorchDataIterator[Batch]): lent Batch {.importcpp: "(*#)".}
func get*[Batch](it: var TorchDataIterator[Batch]): Batch {.importcpp: "(*#)".}
# TODO: this should be lent?
func `==`*(it1, it2: TorchDataIterator): bool {.importcpp: "# == #".}

# #######################################################################
Expand Down Expand Up @@ -247,13 +248,13 @@ type
StatelessDataLoader*
{.bycopy, pure,
importcpp: "torch::data::StatelessDataLoader".}
[Dataset, Sampler]
[D, S] # Dataset, Sampler
= object of DataLoaderBase

StatefulDataLoader*
{.bycopy, pure,
importcpp: "torch::data::StatefulDataLoader".}
[Dataset]
[D] # Dataset
= object of DataLoaderBase

DataLoaderOptions*
Expand All @@ -265,7 +266,7 @@ type
# and https://github.com/nim-lang/Nim/issues/16655
# BatchDataset and Dataset have no generics attached
# and so we can't infer their Iterator type :/
func start*(dl: DataLoaderBase
func start*(dl: StatelessDataLoader
): TorchDataIterator[Example[Tensor, Tensor]]
{.importcpp: "#.begin()".}
## Start an iterator
Expand All @@ -274,13 +275,36 @@ func start*(dl: DataLoaderBase
## and so the output is fixed to Example[Tensor, Tensor]
## which is the output of the Stack transform

func stop*(dl: DataLoaderBase
func stop*(dl: StatelessDataLoader
): TorchDataIterator[Example[Tensor, Tensor]]
{.importcpp: "#.end()".}
## Returns a sentinel value that denotes
## the end of an iterator

iterator items*(dl: DataLoaderBase): Example[Tensor, Tensor] =
func start*[D, S](
dl: CppUniquePtr[StatelessDataLoader[D, S]]
): TorchDataIterator[Example[Tensor, Tensor]]
{.importcpp: "#->begin()".}
## Start an iterator
## Note: due to compiler bugs with C++ interop
## we can't attach the DataLoaderBase generic type,
## and so the output is fixed to Example[Tensor, Tensor]
## which is the output of the Stack transform
##
## Overload as StatelessDataLoader has no default constructors
## So we don't want Nim to use temporaries

func stop*[D, S](
dl: CppUniquePtr[StatelessDataLoader[D, S]]
): TorchDataIterator[Example[Tensor, Tensor]]
{.importcpp: "#->end()".}
## Returns a sentinel value that denotes
## the end of an iterator
##
## Overload as StatelessDataLoader has no default constructors
## So we don't want Nim to use temporaries

iterator items*(dl: StatelessDataLoader or CppUniquePtr[StatelessDataLoader]): Example[Tensor, Tensor] =
# TODO: lent Example[Tensor, Tensor],
# borrow checker complains about 'cur' escaping it's frame
# but `cur.get()` already returns a borrowed view
Expand All @@ -290,7 +314,7 @@ iterator items*(dl: DataLoaderBase): Example[Tensor, Tensor] =
yield cur.get()
cur.next()

iterator pairs*(dl: DataLoaderBase): tuple[index: int, value: Example[Tensor, Tensor]] =
iterator pairs*(dl: StatelessDataLoader or CppUniquePtr[StatelessDataLoader]): tuple[index: int, value: Example[Tensor, Tensor]] =
# TODO: lent Example[Tensor, Tensor]
# borrow checker complains about 'cur' escaping it's frame
# but `cur.get()` already returns a borrowed view
Expand Down
2 changes: 1 addition & 1 deletion flambeau/raw_bindings/neural_nets.nim
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ func reset_parameters*(linear: Linear){.importcpp: "#.reset_parameters()".}

# pretty_print

func forward*(linear: Linear, input: Tensor): Tensor {.importcpp: "#.forward(#)".}
func forward*(linear: Linear, input: Tensor): Tensor {.importcpp: "#->forward(#)".}
## Transforms the ``input`` tensor
## by multiplying with the ``weight``
## and optionally adding the ``bias``,
Expand Down
1 change: 1 addition & 0 deletions flambeau/raw_bindings/optimizers.nim
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ func init*(
{.constructor, importcpp:"torch::optim::SGD(@)".}

func step*(optim: var SGD){.importcpp: "#.step()".}
func zero_grad*(optim: var SGD){.importcpp: "#.zero_grad()".}
5 changes: 5 additions & 0 deletions flambeau/raw_bindings/tensors.nim
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ const torchHeader* = torchHeadersPath / "torch/torch.h"

{.push header: torchHeader.}

{.passC: "-Wfatal-errors".} # The default "-fmax-errors=3" is unreadable

# Assumptions
# -----------------------------------------------------------------------
#
Expand Down Expand Up @@ -322,6 +324,9 @@ func vulkan*(a: Tensor): Tensor {.importcpp: "#.vulkan()".}
# libtorch/include/ATen/TensorIndexing.h
# and https://pytorch.org/cppdocs/notes/tensor_indexing.html

func item*(a: Tensor, T: typedesc): T {.importcpp: "#.item<'0>()".}
## Extract the scalar from a 0-dimensional tensor

# Unsure what those corresponds to in Python
# func `[]`*(a: Tensor, index: Scalar): Tensor {.importcpp: "#[#]".}
# func `[]`*(a: Tensor, index: Tensor): Tensor {.importcpp: "#[#]".}
Expand Down
37 changes: 30 additions & 7 deletions proof_of_concepts/poc09_end_to_end.nim
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# This is a port of the C++ end-to-end example
# at https://pytorch.org/cppdocs/frontend.html

import ../flambeau
import
../flambeau,
std/[enumerate, strformat]

# Argh, need Linear{nullptr} in the codegen
# so we cheat by inlining C++
Expand All @@ -20,15 +22,16 @@ struct Net: public torch::nn::Module {
};
"""].}

type Net{.importcpp.} = object of Module
type Net{.pure, importcpp.} = object of Module
fc1: Linear
fc2: Linear
fc3: Linear

proc init(T: type Net): Net =
result.fc1 = result.register_module("fc1", Linear.init(784, 64))
result.fc2 = result.register_module("fc2", Linear.init(64, 32))
result.fc3 = result.register_module("fc3", Linear.init(32, 10))
proc init(net: var Net) =
# Note: PyTorch Model serialization requires shared_ptr
net.fc1 = net.register_module("fc1", Linear.init(784, 64))
net.fc2 = net.register_module("fc2", Linear.init(64, 32))
net.fc3 = net.register_module("fc3", Linear.init(32, 10))

func forward*(net: Net, x: Tensor): Tensor =
var x = x
Expand All @@ -39,7 +42,8 @@ func forward*(net: Net, x: Tensor): Tensor =
return x

proc main() =
let net = Net.init() # TODO: make_shared
let net = make_shared(Net)
net.init()

let data_loader = make_data_loader(
mnist("build/mnist").map(Stack[Example[Tensor, Tensor]].init()),
Expand All @@ -51,4 +55,23 @@ proc main() =
learning_rate = 0.01
)

for epoch in 1 .. 10:
# Iterate the data loader to yield batches from the dataset.
for batch_index, batch in data_loader.pairs():
# Reset gradients.
optimizer.zero_grad()
# Execute the model on the input data.
let prediction = net.forward(batch.data)
# Compute a loss value to judge the prediction of our model.
var loss = nll_loss(prediction, batch.target)
# Compute the gradients of the loss w.r.t. the parameters of our model.
loss.backward()
# Update the parameters based on the calculated gradients.
optimizer.step()
# output the loss and checkpoint every 100 batches.
if batch_index mod 100 == 0:
echo &"Epoch: {epoch} | Batch: {batch_index} | Loss: {loss.item(float32)}"
# Serialize your model periodically as a checkpoint.
net.save("net.pt")

main()

0 comments on commit 9e000fa

Please sign in to comment.