@@ -218,6 +218,7 @@ void rotate_backward_out_cuda(const scalar_t *entity, const scalar_t *relation,
218
218
const int64_t *h_index, const int64_t *t_index, const int64_t *r_index,
219
219
const scalar_t *score_grad, scalar_t *entity_grad, scalar_t *relation_grad,
220
220
int64_t num_entity, int64_t num_relation, int64_t embedding_dim, int64_t num_sample) {
221
+ const float kEpsilon = 1e-15 ; // 1e-15 from GraphVite
221
222
const int thread_id = blockIdx .x * blockDim .x + threadIdx .x ;
222
223
const int lane_id = thread_id % warpSize ;
223
224
const int num_thread = gridDim .x * blockDim .x ;
@@ -313,7 +314,7 @@ void simple_backward_out_cuda(const scalar_t *entity, const scalar_t *relation,
313
314
#define DECLARE_FORWARD_IMPL (NAME ) \
314
315
Tensor NAME##_forward_cuda(const Tensor &entity_, const Tensor &relation_, const Tensor &h_index_, \
315
316
const Tensor &t_index_, const Tensor &r_index_) { \
316
- constexpr const char *fn_name = #NAME" _forward_cuda" ; \
317
+ constexpr const char *fn_name = #NAME" _forward_cuda" ; \
317
318
TensorArg entity_arg (entity_, " entity" , 1 ), relation_arg (relation_, " relation" , 2 ), \
318
319
h_index_arg (h_index_, " h_index" , 3 ), r_index_arg (r_index_, " r_index" , 4 ), \
319
320
t_index_arg (t_index_, " t_index" , 5 ); \
@@ -353,7 +354,7 @@ void simple_backward_out_cuda(const scalar_t *entity, const scalar_t *relation,
353
354
std::tuple<Tensor, Tensor> NAME##_backward_cuda( \
354
355
const Tensor &entity_, const Tensor &relation_, const Tensor &h_index_, \
355
356
const Tensor &t_index_, const Tensor &r_index_, const Tensor &score_grad_) { \
356
- constexpr const char *fn_name = #NAME" _backward_cuda" ; \
357
+ constexpr const char *fn_name = #NAME" _backward_cuda" ; \
357
358
TensorArg entity_arg (entity_, " entity" , 1 ), relation_arg (relation_, " relation" , 2 ), \
358
359
h_index_arg (h_index_, " h_index" , 3 ), r_index_arg (r_index_, " r_index" , 4 ), \
359
360
t_index_arg (t_index_, " t_index" , 5 ), score_grad_arg (score_grad_, " score_grad" , 6 ); \
@@ -384,7 +385,7 @@ void simple_backward_out_cuda(const scalar_t *entity, const scalar_t *relation,
384
385
NAME##_backward_out_cuda<scalar_t ><<<4096 , 512 , 0 , stream>>> ( \
385
386
entity.data_ptr <scalar_t >(), relation .data_ptr <scalar_t >(), \
386
387
h_index.data_ptr <int64_t >(), t_index.data_ptr <int64_t >(), r_index.data_ptr <int64_t >(), \
387
- score_grad.data_ptr <scalar_t >(), \
388
+ score_grad.data_ptr <scalar_t >(), \
388
389
entity_grad.data_ptr <scalar_t >(), relation_grad.data_ptr <scalar_t >(), \
389
390
num_entity, num_relation, embedding_dim, num_sample \
390
391
); \
0 commit comments