@@ -136,7 +136,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return res;
 #else

-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_HIPBLAS)
     cudaError_t err;
     if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
     {
@@ -149,7 +149,7 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device)
     return err;
 #else
     return cudaMalloc(ptr, size);
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
+#endif // !defined(GGML_USE_HIPBLAS)

 #endif
 }
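
Dropping !defined(GGML_USE_MUSA) in the first two hunks means MUSA builds now take the same allocation path as plain CUDA builds, including the optional GGML_CUDA_ENABLE_UNIFIED_MEMORY branch. A minimal sketch of why that is safe, assuming the MUSA vendor compatibility header aliases the CUDA runtime allocation API to the MUSA runtime (the macro list below is illustrative, not copied from that header):

// Illustrative sketch only: under GGML_USE_MUSA the CUDA allocation calls used
// above are assumed to resolve to their MUSA counterparts via a vendor header.
#ifdef GGML_USE_MUSA
#include <musa_runtime.h>
#define cudaError_t        musaError_t
#define cudaSuccess        musaSuccess
#define cudaMalloc         musaMalloc
#define cudaMallocManaged  musaMallocManaged
#endif // GGML_USE_MUSA

With such a mapping in place, the unified-memory branch compiles for MUSA as well, so a separate MUSA-only allocation path is no longer needed here.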
@@ -2830,6 +2830,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
                     return false;
                 }
+#ifdef GGML_USE_MUSA
+                if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
+                    !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
+                    return false;
+                }
+#endif // GGML_USE_MUSA
                 switch (a->type) {
                     case GGML_TYPE_F32:
                     case GGML_TYPE_F16:
@@ -2853,6 +2859,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                     case GGML_TYPE_IQ3_XXS:
                     case GGML_TYPE_IQ4_NL:
                     case GGML_TYPE_IQ4_XS:
+#ifdef GGML_USE_MUSA
+                        if (a->type == GGML_TYPE_Q3_K) {
+                            return false;
+                        }
+#endif // GGML_USE_MUSA
                         return true;
                     default:
                         return false;
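
For context, these per-type answers are consumed through the generic backend interface: a scheduler or application asks whether the backend supports a node and falls back to another backend when it does not. A minimal usage sketch (the helper name is hypothetical; only ggml_backend_supports_op() is the real API):

#include "ggml-backend.h"

// Hypothetical helper: for the CUDA/MUSA backend this call ends up in
// ggml_backend_cuda_supports_op(), i.e. the checks patched above.
static bool can_offload_node(ggml_backend_t backend, const struct ggml_tensor * node) {
    return ggml_backend_supports_op(backend, node);
}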
@@ -2978,6 +2989,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_RWKV_WKV:
             return true;
         case GGML_OP_FLASH_ATTN_EXT: {
+#ifndef FLASH_ATTN_AVAILABLE
+            return false;
+#endif
             if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
                 return true;
             }
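
GGML_OP_FLASH_ATTN_EXT is now gated on FLASH_ATTN_AVAILABLE, which is not defined in this hunk and is expected to come from a shared header. A hypothetical sketch of such a gate, so the backend only advertises flash attention where the kernels actually build; the condition below is an assumption, not the upstream definition:

// Hypothetical definition in a shared header (e.g. common.cuh); assumption:
// MUSA builds are treated as lacking usable flash-attention kernels.
#if !defined(GGML_USE_MUSA)
#define FLASH_ATTN_AVAILABLE
#endif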