Skip to content

Commit 9d30c3a

Browse files
killeentsoumith
authored andcommitted
Advanced Indexing Part 1 -- Purely Integer Array Indexing
1 parent e391789 commit 9d30c3a

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

lib/TH/generic/THTensor.c

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,42 @@ void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *sizes) {
315315
THFree(expandedStrides);
316316
}
317317

318+
319+
void THTensor_(expandNd)(THTensor **rets, THTensor **ops, int count) {
320+
for (int i = 0; i < count; ++i) {
321+
THArgCheck(THTensor_(nDimension)(ops[i]) > 0, i, "can't expand empty tensor %d", i);
322+
}
323+
324+
long *op_sizes[count];
325+
long op_dims[count];
326+
327+
for (int i = 0; i < count; ++i) {
328+
op_sizes[i] = ops[i]->size;
329+
op_dims[i] = ops[i]->nDimension;
330+
}
331+
332+
THLongStorage *sizes = THLongStorage_new();
333+
char error_buffer[1024];
334+
int ret = THLongStorage_inferSizeN(sizes,
335+
count,
336+
op_sizes,
337+
op_dims,
338+
error_buffer,
339+
1024);
340+
341+
if(ret != 0) {
342+
THLongStorage_free(sizes);
343+
THError(error_buffer);
344+
return;
345+
}
346+
347+
for (int i = 0; i < count; ++i) {
348+
THTensor_(expand)(rets[i], ops[i], sizes);
349+
}
350+
351+
THLongStorage_free(sizes);
352+
}
353+
318354
void THTensor_(set)(THTensor *self, THTensor *src)
319355
{
320356
if(self != src)

lib/TH/generic/THTensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ TH_API THTensor *THTensor_(newView)(THTensor *tensor, THLongStorage *size);
7272
TH_API THTensor *THTensor_(newExpand)(THTensor *tensor, THLongStorage *size);
7373

7474
TH_API void THTensor_(expand)(THTensor *r, THTensor *tensor, THLongStorage *size);
75+
TH_API void THTensor_(expandNd)(THTensor **rets, THTensor **ops, int count);
7576

7677
TH_API void THTensor_(resize)(THTensor *tensor, THLongStorage *size, THLongStorage *stride);
7778
TH_API void THTensor_(resizeAs)(THTensor *tensor, THTensor *src);

0 commit comments

Comments
 (0)