Skip to content

Commit 105065e

Browse files
Merge pull request #289 from InfiniTensor/issue/288_improve_torch_implementation_compatibility
issue/288: Improve the Compatibility of the Torch Implementations
2 parents a0abcb2 + c132b4c commit 105065e

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

test/infiniop/gemm.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,14 @@ class GemmDescriptor(Structure):
5757

5858
# PyTorch implementation for matrix multiplication
5959
def gemm(d, _c, beta, _a, _b, alpha):
60-
if _c.ndim == 2:
61-
torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
62-
elif _c.ndim == 3:
63-
torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
64-
else:
60+
try:
61+
if _c.ndim == 2:
62+
torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
63+
elif _c.ndim == 3:
64+
torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d)
65+
else:
66+
raise
67+
except Exception:
6568
torch.matmul(_a, _b, out=d)
6669
d.mul_(alpha).add_(_c, alpha=beta)
6770

test/infiniop/random_sample.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,13 @@ def random_sample(data, random_val, topp, topk, voc, temperature):
6767

6868
k_index = min(topk, voc) - 1
6969
threshold = min(cum_probs[k_index], topp) * random_val
70-
71-
idx = torch.searchsorted(cum_probs, threshold)
70+
71+
try:
72+
idx = torch.searchsorted(cum_probs, threshold)
73+
except Exception:
74+
# Fallback for manual search if torch.searchsorted is not supported
75+
indices = (cum_probs >= threshold).nonzero(as_tuple=True)[0]
76+
idx = indices[0] if indices.numel() > 0 else torch.tensor(len(cum_probs)-1, device=cum_probs.device)
7277
return sorted_indices[idx]
7378

7479
return torch.argmax(data)

test/infiniop/rearrange.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,11 @@ def column_major_strides(shape):
116116
NUM_ITERATIONS = 1000
117117

118118

119-
class RerrangeDescriptor(Structure):
119+
class RearrangeDescriptor(Structure):
120120
_fields_ = [("device", c_int32)]
121121

122122

123-
infiniopRearrangeDescriptor_t = POINTER(RerrangeDescriptor)
123+
infiniopRearrangeDescriptor_t = POINTER(RearrangeDescriptor)
124124

125125

126126
def rearrange_torch(x, x_shape, y_stride):

0 commit comments

Comments
 (0)