Skip to content

Commit f288665

Browse files
fantessileht
authored andcommitted
feat(torch): nbeats
feat(nbeats): make nbeats able to handle signals that are more than 1D feat(nbeats): expose nbeats net definition in api
1 parent a471b82 commit f288665

15 files changed

+1289
-86
lines changed

CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ if (USE_TORCH)
773773
if (NOT USE_CPU_ONLY AND CUDA_FOUND)
774774
list(APPEND TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libc10_cuda.so ${TORCH_LOCATION}/lib/libtorch_cuda.so)
775775
else()
776-
list(APPEND TORCH_LIB_DEPS iomp5)
776+
list(APPEND TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libtorch_cpu.so iomp5)
777777
endif()
778778

779779
set(TORCH_INC_DIR ${TORCH_LOCATION}/include/ ${TORCH_LOCATION}/include/torch/csrc/api/include/ ${CMAKE_BINARY_DIR}/pytorch/src/pytorch/torch/include/torch/csrc/api/include ${TORCH_LOCATION}/.. ${CMAKE_BINARY_DIR}/src)

examples/all/sinus/gen.py

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
os.remove("predict/"+f)
3232
os.rmdir("predict")
3333

34-
3534
os.mkdir("train")
3635
os.mkdir("test")
3736
os.mkdir("predict")

src/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ if (USE_TORCH)
8181
backends/torch/torchinputconns.cc
8282
backends/torch/db.cpp
8383
backends/torch/db_lmdb.cpp
84+
backends/torch/native/templates/nbeats.cc
8485
basegraph.cc
8586
caffegraphinput.cc
8687
backends/torch/torchgraphbackend.cc

src/backends/torch/native/native.h

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#ifndef NATIVE_H
2+
#define NATIVE_H
3+
4+
#include "native_net.h"
5+
#include "native_factory.h"
6+
7+
#endif
+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#ifndef NATIVE_FACTORY_H
2+
#define NATIVE_FACTORY_H
3+
4+
#include "native_net.h"
5+
#include "./templates/nbeats.h"
6+
#include "../torchinputconns.h"
7+
#include "apidata.h"
8+
9+
namespace dd
10+
{
11+
class NativeFactory
12+
{
13+
public:
14+
template <class TInputConnectorStrategy>
15+
static NativeModule *from_template(const std::string tdef,
16+
const APIData template_params,
17+
const TInputConnectorStrategy &inputc)
18+
{
19+
(void)(tdef);
20+
(void)(template_params);
21+
(void)(inputc);
22+
return nullptr;
23+
}
24+
25+
static bool valid_template_def(std::string tdef)
26+
{
27+
if (tdef.find("nbeats") != std::string::npos)
28+
return true;
29+
return false;
30+
}
31+
32+
static bool is_timeserie(std::string tdef)
33+
{
34+
if (tdef.find("nbeats") != std::string::npos)
35+
return true;
36+
return false;
37+
}
38+
};
39+
40+
template <>
41+
NativeModule *NativeFactory::from_template<CSVTSTorchInputFileConn>(
42+
const std::string tdef, const APIData template_params,
43+
const CSVTSTorchInputFileConn &inputc)
44+
{
45+
if (tdef.find("nbeats") != std::string::npos)
46+
{
47+
std::vector<std::string> p = template_params.get("template_params")
48+
.get<std::vector<std::string>>();
49+
return new NBeats(inputc, p);
50+
}
51+
else
52+
return nullptr;
53+
}
54+
}
55+
#endif
+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
#ifndef NATIVE_NET_H
2+
#define NATIVE_NET_H
3+
4+
#include "torch/torch.h"
5+
6+
namespace dd
7+
{
8+
9+
class NativeModule : public torch::nn::Module
10+
{
11+
public:
12+
virtual torch::Tensor forward(torch::Tensor x) = 0;
13+
virtual ~NativeModule()
14+
{
15+
}
16+
/**
17+
* \brief see torch::module::to
18+
* @param device cpu / gpu
19+
* @param non_blocking
20+
*/
21+
virtual void to(torch::Device device, bool non_blocking = false)
22+
{
23+
torch::nn::Module::to(device, non_blocking);
24+
_device = device;
25+
}
26+
27+
/**
28+
* \brief see torch::module::to
29+
* @param dtype : torch::kFloat32 or torch::kFloat64
30+
* @param non_blocking
31+
*/
32+
virtual void to(torch::Dtype dtype, bool non_blocking = false)
33+
{
34+
torch::nn::Module::to(dtype, non_blocking);
35+
_dtype = dtype;
36+
}
37+
38+
/**
39+
* \brief see torch::module::to
40+
* @param device cpu / gpu
41+
* @param dtype : torch::kFloat32 or torch::kFloat64
42+
* @param non_blocking
43+
*/
44+
virtual void to(torch::Device device, torch::Dtype dtype,
45+
bool non_blocking = false)
46+
{
47+
torch::nn::Module::to(device, dtype, non_blocking);
48+
_device = device;
49+
_dtype = dtype;
50+
}
51+
52+
virtual torch::Tensor cleanup_output(torch::Tensor output)
53+
{
54+
return output;
55+
}
56+
57+
virtual torch::Tensor loss(std::string loss, torch::Tensor input,
58+
torch::Tensor output, torch::Tensor target)
59+
= 0;
60+
61+
virtual void update_input_connector(TorchInputInterface &inputc)
62+
{
63+
(void)(inputc);
64+
}
65+
66+
protected:
67+
torch::Dtype _dtype
68+
= torch::kFloat32; /**< type of data stored in tensors */
69+
torch::Device _device
70+
= torch::DeviceType::CPU; /**< device to compute on */
71+
};
72+
}
73+
74+
#endif

0 commit comments

Comments
 (0)