Skip to content

Commit a6ccd72

Browse files
committed
Converted texture fetch operations to abstract IR
(Though only CUDA is supported in practice)
1 parent 50632a3 commit a6ccd72

File tree

6 files changed

+135
-133
lines changed

6 files changed

+135
-133
lines changed

src/cuda_eval.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* $Q uint64_t `00000000000004d2` Hex. number, 0-filled (64 bit)
1818
* --------------------------------------------------------------------------
1919
* $s const char * `foo` Zero-terminated string
20+
* $c char `f` A single ASCII character
2021
* --------------------------------------------------------------------------
2122
* $t Variable `f32` Variable type
2223
* --------------------------------------------------------------------------
@@ -720,6 +721,26 @@ static void jitc_cuda_render_var(uint32_t index, const Variable *v) {
720721
jitc_cuda_render_printf(index, v, a0);
721722
break;
722723

724+
case VarKind::TexLookup:
725+
fmt(" .reg.v4.f32 $v_v4;\n", v);
726+
if (a3)
727+
fmt(" tex.3d.v4.f32.f32 $v_v4, [$v, {$v, $v, $v, $v}];\n", v, a0, a1, a2, a3, a3);
728+
else if (a2)
729+
fmt(" tex.2d.v4.f32.f32 $v_v4, [$v, {$v, $v}];\n", v, a0, a1, a2);
730+
else
731+
fmt(" tex.1d.v4.f32.f32 $v_v4, [$v, {$v}];\n", v, a0, a1);
732+
break;
733+
734+
case VarKind::TexFetchBilerp:
735+
fmt(" .reg.v4.f32 $v_v4;\n"
736+
" tld4.$c.2d.v4.f32.f32 $v_v4, [$v, {$v, $v}];\n",
737+
v, "rgba"[v->literal], v, a0, a1, a2);
738+
break;
739+
740+
case VarKind::TexExtract:
741+
fmt(" mov.$t $v, $v_v4.$c;\n", v, v, a0, "xyzw"[v->literal]);
742+
break;
743+
723744
default:
724745
jitc_fail("jitc_cuda_render_var(): unhandled variable kind \"%s\"!",
725746
var_kind_name[(uint32_t) v->kind]);

src/cuda_tex.cpp

Lines changed: 83 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "log.h"
44
#include "var.h"
55
#include "op.h"
6+
#include "eval.h"
67
#include <string.h>
78
#include <memory>
89
#include <atomic>
@@ -463,143 +464,106 @@ void jitc_cuda_tex_memcpy_t2d(size_t ndim, const size_t *shape,
463464
src_texture);
464465
}
465466

466-
void jitc_cuda_tex_lookup(size_t ndim, const void *texture_handle,
467-
const uint32_t *pos, uint32_t *out) {
468-
if (ndim < 1 || ndim > 3)
469-
jitc_raise("jit_cuda_tex_lookup(): invalid texture dimension!");
470-
467+
Variable jitc_cuda_tex_check(size_t ndim, const uint32_t *pos) {
471468
// Validate input types, determine size of the operation
472469
uint32_t size = 0;
470+
bool dirty = false, placeholder = false;
471+
JitBackend backend = JitBackend::Invalid;
472+
473+
if (ndim < 1 || ndim > 3)
474+
jitc_raise("jit_cuda_tex_check(): invalid texture dimension!");
475+
473476
for (size_t i = 0; i < ndim; ++i) {
474477
const Variable *v = jitc_var(pos[i]);
475478
if ((VarType) v->type != VarType::Float32)
476-
jitc_raise("jit_cuda_tex_lookup(): type mismatch for arg. %zu (got "
479+
jitc_raise("jit_cuda_tex_check(): type mismatch for arg. %zu (got "
477480
"%s, expected %s)", i, type_name[v->type],
478481
type_name[(int) VarType::Float32]);
479482
size = std::max(size, v->size);
483+
dirty |= v->is_dirty();
484+
placeholder |= (bool) v->placeholder;
485+
backend = (JitBackend) v->backend;
480486
}
481487

482-
DrJitCudaTexture &texture = *((DrJitCudaTexture *) texture_handle);
483-
484-
for (size_t tex = 0; tex < texture.n_textures; ++tex) {
485-
uint32_t dep[2] = {
486-
texture.indices[tex],
487-
pos[0]
488-
};
489-
490-
if (ndim >= 2) {
491-
const char *stmt_1[2] = {
492-
".reg.v2.f32 $r0$n"
493-
"mov.v2.f32 $r0, { $r1, $r2 }",
494-
".reg.v4.f32 $r0$n"
495-
"mov.v4.f32 $r0, { $r1, $r2, $r3, $r3 }"
496-
};
497-
dep[1] = jitc_var_stmt(JitBackend::CUDA, VarType::Void,
498-
stmt_1[ndim - 2], 1, (unsigned int) ndim,
499-
pos);
500-
} else {
501-
jitc_var_inc_ref(dep[1]);
488+
if (dirty) {
489+
jitc_eval(thread_state(backend));
490+
for (size_t i = 0; i < ndim; ++i) {
491+
if (jitc_var(pos[i])->is_dirty())
492+
jitc_fail("jit_cuda_tex_check(): operand r%u remains dirty "
493+
"following evaluation!", pos[i]);
502494
}
495+
}
503496

504-
const char *stmt_2[3] = {
505-
".reg.v4.f32 $r0$n"
506-
"tex.1d.v4.f32.f32 $r0, [$r1, {$r2}]",
507-
508-
".reg.v4.f32 $r0$n"
509-
"tex.2d.v4.f32.f32 $r0, [$r1, $r2]",
510-
511-
".reg.v4.f32 $r0$n"
512-
"tex.3d.v4.f32.f32 $r0, [$r1, $r2]"
513-
};
514-
515-
uint32_t lookup = jitc_var_stmt(JitBackend::CUDA, VarType::Void,
516-
stmt_2[ndim - 1], 1, 2, dep);
517-
jitc_var_dec_ref(dep[1]);
518-
519-
const char *stmt_3[4] = {
520-
"mov.f32 $r0, $r1.r",
521-
"mov.f32 $r0, $r1.g",
522-
"mov.f32 $r0, $r1.b",
523-
"mov.f32 $r0, $r1.a"
524-
};
497+
Variable v;
498+
v.size = size;
499+
v.backend = (uint32_t) backend;
500+
v.placeholder = placeholder;
501+
v.type = (uint32_t) VarType::Float32;
502+
return v;
503+
}
525504

526-
for (size_t ch = 0; ch < texture.channels(tex); ++ch) {
527-
uint32_t lookup_result_index = jitc_var_stmt(
528-
JitBackend::CUDA, VarType::Float32, stmt_3[ch], 1, 1, &lookup);
529-
out[tex * 4 + ch] = lookup_result_index;
505+
void jitc_cuda_tex_lookup(size_t ndim, const void *texture_handle,
506+
const uint32_t *pos, uint32_t *out) {
507+
DrJitCudaTexture &tex = *((DrJitCudaTexture *) texture_handle);
508+
Variable v = jitc_cuda_tex_check(ndim, pos);
509+
510+
for (size_t ti = 0; ti < tex.n_textures; ++ti) {
511+
// Perform a fetch per texture ..
512+
v.kind = VarKind::TexLookup;
513+
v.literal = 0;
514+
memset(v.dep, 0, sizeof(v.dep));
515+
v.dep[0] = tex.indices[ti];
516+
jitc_var_inc_ref(tex.indices[ti]);
517+
for (size_t j = 0; j < ndim; ++j) {
518+
v.dep[j + 1] = pos[j];
519+
jitc_var_inc_ref(pos[j]);
520+
}
521+
Ref tex_load = steal(jitc_var_new(v));
522+
523+
// .. and then extract components
524+
v.kind = VarKind::TexExtract;
525+
memset(v.dep, 0, sizeof(v.dep));
526+
for (size_t ch = 0; ch < tex.channels(ti); ++ch) {
527+
v.literal = (uint64_t) ch;
528+
v.dep[0] = tex_load;
529+
jitc_var_inc_ref(tex_load);
530+
*out++ = jitc_var_new(v);
530531
}
531-
532-
jitc_var_dec_ref(lookup);
533532
}
534533
}
535534

536535
void jitc_cuda_tex_bilerp_fetch(size_t ndim, const void *texture_handle,
537536
const uint32_t *pos, uint32_t *out) {
538537
if (ndim != 2)
539-
jitc_raise("jitc_cuda_tex_bilerp_fetch(): invalid texture dimension, "
540-
"only 2D textures are supported!");
541-
542-
// Validate input types, determine size of the operation
543-
uint32_t size = 0;
544-
for (size_t i = 0; i < ndim; ++i) {
545-
const Variable *v = jitc_var(pos[i]);
546-
if ((VarType) v->type != VarType::Float32)
547-
jitc_raise("jitc_cuda_tex_bilerp_fetch(): type mismatch for arg. "
548-
"%zu (got %s, expected %s)",
549-
i, type_name[v->type],
550-
type_name[(int) VarType::Float32]);
551-
size = std::max(size, v->size);
552-
}
553-
554-
DrJitCudaTexture &texture = *((DrJitCudaTexture *) texture_handle);
555-
556-
for (size_t tex = 0; tex < texture.n_textures; ++tex) {
557-
uint32_t dep[2] = {
558-
texture.indices[tex],
559-
pos[0]
560-
};
561-
562-
const char *stmt_1 = ".reg.v2.f32 $r0$n"
563-
"mov.v2.f32 $r0, { $r1, $r2 }";
564-
dep[1] = jitc_var_stmt(JitBackend::CUDA, VarType::Void, stmt_1, 1,
565-
(unsigned int) ndim, pos);
566-
567-
const char *stmt_2[4] = {
568-
".reg.v4.f32 $r0$n"
569-
"tld4.r.2d.v4.f32.f32 $r0, [$r1, $r2]",
570-
571-
".reg.v4.f32 $r0$n"
572-
"tld4.g.2d.v4.f32.f32 $r0, [$r1, $r2]",
573-
574-
".reg.v4.f32 $r0$n"
575-
"tld4.b.2d.v4.f32.f32 $r0, [$r1, $r2]",
576-
577-
".reg.v4.f32 $r0$n"
578-
"tld4.a.2d.v4.f32.f32 $r0, [$r1, $r2]"
579-
};
580-
581-
const char *stmt_3[4] = {
582-
"mov.f32 $r0, $r1.x",
583-
"mov.f32 $r0, $r1.y",
584-
"mov.f32 $r0, $r1.z",
585-
"mov.f32 $r0, $r1.w"
586-
};
587-
588-
for (size_t ch = 0; ch < texture.channels(tex); ++ch) {
589-
uint32_t fetch_channel = jitc_var_stmt(
590-
JitBackend::CUDA, VarType::Void, stmt_2[ch], 1, 2, dep);
591-
592-
for (size_t i = 0; i < 4; ++i) {
593-
uint32_t result_index =
594-
jitc_var_stmt(JitBackend::CUDA, VarType::Float32,
595-
stmt_3[i], 1, 1, &fetch_channel);
596-
out[(i * texture.n_channels) + (tex * 4 + ch)] = result_index;
538+
jitc_raise("jitc_cuda_tex_bilerp_fetch(): only 2D textures are supported!");
539+
540+
DrJitCudaTexture &tex = *((DrJitCudaTexture *) texture_handle);
541+
Variable v = jitc_cuda_tex_check(ndim, pos);
542+
543+
for (size_t ti = 0; ti < tex.n_textures; ++ti) {
544+
for (size_t ch = 0; ch < tex.channels(ti); ++ch) {
545+
// Perform a fetch per texture and channel..
546+
v.kind = VarKind::TexFetchBilerp;
547+
v.literal = ch;
548+
memset(v.dep, 0, sizeof(v.dep));
549+
v.dep[0] = tex.indices[ti];
550+
jitc_var_inc_ref(tex.indices[ti]);
551+
for (size_t j = 0; j < ndim; ++j) {
552+
v.dep[j + 1] = pos[j];
553+
jitc_var_inc_ref(pos[j]);
554+
}
555+
Ref tex_load = steal(jitc_var_new(v));
556+
557+
memset(v.dep, 0, sizeof(v.dep));
558+
v.kind = VarKind::TexExtract;
559+
for (uint32_t j = 0; j < 4; ++j) {
560+
// .. and then extract components
561+
v.literal = (uint64_t) j;
562+
v.dep[0] = tex_load;
563+
jitc_var_inc_ref(tex_load);
564+
*out++ = jitc_var_new(v);
597565
}
598-
599-
jitc_var_dec_ref(fetch_channel);
600566
}
601-
602-
jitc_var_dec_ref(dep[1]);
603567
}
604568
}
605569

@@ -612,11 +576,10 @@ void jitc_cuda_tex_destroy(void *texture_handle) {
612576

613577
DrJitCudaTexture *texture = (DrJitCudaTexture *) texture_handle;
614578

615-
// The `texture` struct can potentially be deleted when decreasing the
616-
// reference count of the individual textures. We must hoist the number of
617-
// textures out of the loop condition.
579+
/* The `texture` struct can potentially be deleted when decreasing the
580+
reference count of the individual textures. We must hoist the number
581+
of textures out of the loop condition. */
618582
const size_t n_textures = texture->n_textures;
619-
for (size_t tex = 0; tex < n_textures; ++tex) {
583+
for (size_t tex = 0; tex < n_textures; ++tex)
620584
jitc_var_dec_ref(texture->indices[tex]);
621-
}
622585
}

src/cuda_tex.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
#include "cuda_api.h"
22

3-
#if defined(DRJIT_DYNAMIC_CUDA)
4-
5-
#endif
6-
73
extern void *jitc_cuda_tex_create(size_t ndim, const size_t *shape,
84
size_t n_channels, int filter_mode,
95
int wrap_mode);

src/internal.h

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,13 +90,23 @@ enum VarKind : uint32_t {
9090
// Specialized nodes for vcalls
9191
VCallMask, VCallSelf,
9292

93-
/// Counter node to determine the current lane ID
93+
// Counter node to determine the current lane ID
9494
Counter,
9595

96-
/// Recorded 'printf' instruction for debugging purposes
96+
// Recorded 'printf' instruction for debugging purposes
9797
Printf,
9898

99-
Count /// Denotes the number of different node types
99+
// Perform a standard texture lookup (CUDA)
100+
TexLookup,
101+
102+
// Load all texels used for bilinear interpolation (CUDA)
103+
TexFetchBilerp,
104+
105+
// Extract a component from a preceding texture lookup (CUDA)
106+
TexExtract,
107+
108+
// Denotes the number of different node types
109+
Count
100110
};
101111

102112
#pragma pack(push, 1)

src/strbuf.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ void StringBuffer::fmt_cuda(size_t nargs, const char *fmt, ...) {
252252
case 'Q':
253253
case 'X': len += MAXSIZE_X64; (void) va_arg(args, uint64_t); arg++; break;
254254

255+
case 'c': len++; (void) va_arg(args, int); arg++; break;
256+
255257
case 's':
256258
len += strlen(va_arg(args, const char *));
257259
arg++;
@@ -316,6 +318,8 @@ void StringBuffer::fmt_cuda(size_t nargs, const char *fmt, ...) {
316318
case 'X': put_x64_unchecked(va_arg(args2, uint64_t)); break;
317319
case 'Q': put_q64_unchecked(va_arg(args2, uint64_t)); break;
318320

321+
case 'c': *m_cur++ = (char) va_arg(args2, int); break;
322+
319323
case 's': {
320324
const char *s = va_arg(args2, const char *);
321325
put(s, strlen(s));

src/var.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,11 +144,20 @@ const char *var_kind_name[(int) VarKind::Count] {
144144
// Specialized nodes for vcalls
145145
"vcall_mask", "self",
146146

147-
/// Counter node to determine the current lane ID
147+
// Counter node to determine the current lane ID
148148
"counter",
149149

150-
/// Recorded 'printf' instruction for debugging purposes
151-
"printf"
150+
// Recorded 'printf' instruction for debugging purposes
151+
"printf",
152+
153+
// Perform a standard texture lookup (CUDA)
154+
"tex_lookup",
155+
156+
// Load all texels used for bilinear interpolation (CUDA)
157+
"tex_fetch_bilerp",
158+
159+
// Extract a component from a preceding texture lookup (CUDA)
160+
"tex_extract",
152161
};
153162

154163

@@ -533,8 +542,8 @@ uint32_t jitc_var_literal(JitBackend backend, VarType type, const void *value,
533542
jitc_check_size("jit_var_literal", size);
534543

535544
/* When initializing a value pointer array while recording a virtual
536-
function, we can leverage the already available `self` variable instead
537-
of creating a new one. */
545+
function, we can leverage the already available `self` variable
546+
instead of creating a new one. */
538547
if (is_class) {
539548
ThreadState *ts = thread_state(backend);
540549
if (ts->vcall_self_value &&
@@ -853,7 +862,6 @@ uint32_t jitc_var_new_node_4(JitBackend backend, VarKind kind, VarType vt,
853862
return jitc_var_new(v);
854863
}
855864

856-
857865
void jitc_var_set_callback(uint32_t index,
858866
void (*callback)(uint32_t, int, void *),
859867
void *callback_data) {

0 commit comments

Comments
 (0)