
[Metax]Fix metax backend bugs #432

Merged: 1 commit merged into FlagOpen:master on Jan 24, 2025
Conversation

mx-flaggems-user
Collaborator

PR Category

Operator

Type of Change

Bug Fix

Description

  1. update the _metax multi-backend code
  2. fix the argmin op, which could fail tests for integer dtypes when dim=None (see the sketch below)
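
For context, a minimal sketch of the dim=None failure mode being guarded against, assuming the reduction pads masked lanes with an identity value; float("inf") is not a valid identity for integer dtypes, so a dtype-specific maximum works instead. Names here are illustrative, not the FlagGems implementation:

```python
# Hypothetical illustration, not the FlagGems kernel: with dim=None the whole
# tensor is flattened and out-of-bounds lanes need a padding value that can
# never win the min. float("inf") works for floats but is invalid for integer
# dtypes, so the dtype's own maximum is used instead.
import torch

def argmin_padding_value(dtype: torch.dtype):
    if dtype.is_floating_point:
        return float("inf")
    return torch.iinfo(dtype).max

x = torch.randint(-100, 100, (7,), dtype=torch.int16)
pad = argmin_padding_value(x.dtype)                            # 32767, not inf
padded = torch.cat([x, torch.full((1,), pad, dtype=x.dtype)])
assert padded.argmin() == x.argmin()                           # padding never wins
```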

1. update multi-backend code
2. fix argmin op that might fail tests under int types
@mx-flaggems-user changed the title from [Operator]Fix metax backend bugs to [Metax]Fix metax backend bugs on Jan 23, 2025
Collaborator

@Galaxy1458 left a comment


LGTM

@Galaxy1458 Galaxy1458 merged commit 676963b into FlagOpen:master Jan 24, 2025
8 of 9 checks passed
DuanYaQi pushed a commit that referenced this pull request Feb 11, 2025
1. update multi-backend code
2. fix argmin op might test failed under int types

Co-authored-by: mx-flaggems-user <[email protected]>
machuanjiang pushed a commit that referenced this pull request Feb 25, 2025
1. update multi-backend code
2. fix argmin op might test failed under int types

Co-authored-by: mx-flaggems-user <[email protected]>
machuanjiang pushed a commit that referenced this pull request Feb 25, 2025
1. update multi-backend code
2. fix argmin op might test failed under int types

Co-authored-by: mx-flaggems-user <[email protected]>
Galaxy1458 added a commit that referenced this pull request Feb 27, 2025
* add triton_musa submodule

* Modify testcase from cuda to musa.

* Workaround for musa testcase.

* modify test_unary_pointwise_ops from cuda to musa

* modify test_reduction_ops from cuda to musa

* Fix bug of reduceOp and shared memory.

* fix dropout bug.

* fix softmax exceeds shared memory error

* Promote cpu reference accuracy to float32

* Modify benchmark performance test script to musa

* Add torch_musa unsupported op test case.

* Support bert model.

* Update submodule url.

* Comment v3/v4 test case.

* fix: vectornorm upcast to fp64

* fix: group_norm modify case because of out-of-shared-memory

* Promote golden accuracy from fp32 to fp64.

* Rebase on master commit 1e49d6.

* Update triton_musa submodule.

* align perf utils to profiling pack 0717

* rebase on master commit 9000685

Signed-off-by: Jian Li <[email protected]>

* Support op cumsum.
config: {BLOCK_M: 8, num_warps: 8} will cause the number of registers
within a single thread to be exceeded when the tensor shape is 4096 * 2304,
so reduce BLOCK_M to 4 to support cumsum.
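
As an illustration of that tuning change (the table layout and BLOCK_N values are assumptions, not the actual FlagGems config):

```python
# Illustrative tuning table only, not the real kernel config: halving BLOCK_M
# keeps the per-thread register count of the cumsum kernel within limits on
# large rows such as 4096 x 2304.
import triton

CUMSUM_CONFIGS = [
    # {BLOCK_M: 8, num_warps: 8} spilled registers on large rows, so the
    # rows-per-program tile is reduced to 4 instead.
    triton.Config({"BLOCK_M": 4, "BLOCK_N": 1024}, num_warps=8),
    triton.Config({"BLOCK_M": 4, "BLOCK_N": 2048}, num_warps=8),
]
```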

* fix embedding tensor usage

Signed-off-by: Jian Li <[email protected]>

* uncomment supported op test

Signed-off-by: Jian Li <[email protected]>

* Support isclose() and allclose()

- Torch_musa does not support fp64 input type, so CPU is used as a reference
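
A minimal sketch of that testing pattern, assuming a plain PyTorch helper (cpu_fp64_reference is a hypothetical name):

```python
# Build the golden value on CPU in float64 when the device backend cannot run
# fp64, then cast back to the operator's dtype for the comparison.
import torch

def cpu_fp64_reference(fn, x: torch.Tensor) -> torch.Tensor:
    ref = fn(x.to("cpu", torch.float64))   # full-precision golden value
    return ref.to(x.dtype)                 # compare in the operator's dtype

x = torch.randn(128, dtype=torch.float32)
golden = cpu_fp64_reference(torch.exp, x)
assert torch.allclose(torch.exp(x), golden, rtol=1e-5, atol=1e-6)
```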

* Open up some tests that have already passed

- Does not support test_accuracy_groupnorm

- Some use cases have accuracy issues in test_embedding

* rebase on master commit 801377f

Signed-off-by: Jian Li <[email protected]>

* rebase on master commit 2e55d66

Signed-off-by: Jian Li <[email protected]>

* fix distribution ops warps num

Signed-off-by: Jian Li <[email protected]>

* rebase on master commit a156268

Signed-off-by: Jian Li <[email protected]>

* Support op argmax.

* adapt masked_fill and dropout resolution

Signed-off-by: Jian Li <[email protected]>

* fix _log2 in topk op

Signed-off-by: Jian Li <[email protected]>

* rebase on master commit a1138bf

Signed-off-by: Jian Li <[email protected]>

* Open BF16

* rebase on master commit edb09f0

Signed-off-by: Jian Li <[email protected]>

* fix div_trunc and div_floor

Signed-off-by: Jian Li <[email protected]>

* fix div_trunc cont.

Signed-off-by: Jian Li <[email protected]>

* rebase on master commit 2c9ae67

Signed-off-by: Jian Li <[email protected]>

* rebase on master commit 204f3d4

Signed-off-by: Jian Li <[email protected]>

* Support op vstack

* rebase on master commit 4e5081

rebase on master commit 8f669d

* SW-46066: remove submodule(triton_musa) in flaggems

Signed-off-by: chuanjiang.ma <[email protected]>

* [Operator] Fix vstack (#237)

Modify the function parameter type declaration so that it can run in python 3.8

---------

Co-authored-by: zhengyang <[email protected]>

* [Operator] Add upsample_nearest2d op [MooreThreads] (#193)

* SW-46093: rebase on master commit 98924c

Signed-off-by: jiaqi.wang <[email protected]>

* succeed for norm (#233)

Add _weight_norm op, while the original _weight_norm op changed to _weight_norm_interface op.

* [Operator] slice&select scatter (#143)

* add Ops & UT & Bench

* add full zero ones Ops & UT & Bench

* split normal op

* [Operator] init slice&select scatter

* code format

* PR comment

* split test_special_ops

* add K-S test

* split special perf

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* resolve conflict

* Use int64 indexing when needed & fix argmax (#146)

1. fix amax, argmax and triu: use int64 indexing when the largest tensor's size_in_bytes exceeds int32's max (sketched below);
2. change the tiling scheme for argmax to loop over the reduction dimension instead of using a data-size-dependent tile size
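
A hedged sketch of the indexing-width dispatch in point 1 (helper name is illustrative, not the FlagGems API):

```python
# Switch offset arithmetic to int64 only when some tensor is big enough that
# 32-bit offsets could overflow.
import torch

INT32_MAX = 2**31 - 1

def needs_int64_indexing(*tensors: torch.Tensor) -> bool:
    # size_in_bytes = numel * element_size, compared against int32's maximum
    return any(t.numel() * t.element_size() > INT32_MAX for t in tensors)

small = torch.empty(1024, dtype=torch.float32)
print(needs_int64_indexing(small))   # False: 4 KiB of data fits in int32 offsets
```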

* test for op

* test for op

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* add argparse

* fix desc

* fix num

* Update test_specific_ops.py

* split UT files

* fix

* fix

* [Operator] Optimize CrossEntropyLoss (#131)

reimplement cross_entropy_loss forward and backward;
support indices/probabilities/weight/reduction/ignore_index/label_smoothing; performs better than torch eager on large-scale tensors

* Exponential added. (#138)

* exponential added.
* Added K-S tests to exponential_, fp64 corrected.
* aligned with aten prototype
* Exponential_ uses uint64 offsets in Triton kernel.
* Update pyproject config for new test dependencies.

* Use int64 indexing when needed & fix argmax (#146)

 1. fix amax, armax and triu, use int64 indexing when the largest tensor's size_in_bytes exceed int32's max;
2. change the tiling scheme for argmax to loop in the reduction dimension, instead of data-size-dependent-tile-size

* Making libentry thread safe (#136)

* libentry now is lock protected.

* Add multithreading tests for libentry.

* polish code.

* [Test] Test for op (#151)

* [chore] solve slice&select scatter's test cases

* [fix] fix slice&select scatter's test cases

* [chore] remove out-of-range indices in select_scatter's test cases

* [chore] simplify slice_scatter's test cases

* [fix] Added range that is deleted by mistake

* Merge branch 'master' into slice&select_scatter

* [chore] reformat

* [fix] typo

* [chore] Considering perf, pause the replacement of some aTen operators
* slice_scatter
* select_scatter
* index_select

* [fix] Add libentry in op.cumsum

* [fix] Del slice&select scatter's perf tests

* [Chore] Add pytest mark for slice&select scatter's test

* [Fix] Correct slice_scatter test

* [Fix] Replace CPU Tensor

---------

Co-authored-by: Bowen12992 <[email protected]>
Co-authored-by: Tongxin Bai <[email protected]>
Co-authored-by: Clement Chan <[email protected]>
Co-authored-by: Bowen <[email protected]>
Co-authored-by: StrongSpoon <[email protected]>

* [Operator] Add slice&select_scatter's benchmark (#262)

* benchmark fix (#229)

* benchmark fix

*  add seven new testing parameters

* move shapes info to yaml file

* Added the BenchmarkMetrics & BenchmarkResult  abstraction

* Specializing slice_scatter. (#270)

* specializing slice_scatter. WIP.

* polish and refine 2d_inner cases.

* fix slice_scatter error on 1d inputs.

* test slice_scatter fallback

* Enhance Benchmarking for repeat_interleave Operation (#274)

* Relocate select and slice benchmarks to test_select_and_slice_perf.py

* sort keys for summary result

* clean cuda cache after benchmark

* fix repeat_interleave

* modify format for summary info

* SW-46093: rebase on master commit 2bd92c

Signed-off-by: jiaqi.wang <[email protected]>

* add test entry & use skip to avoid error

Signed-off-by: jiaqi.wang <[email protected]>

* SW-46093: update mt test script by waiting frontend

Signed-off-by: chuanjiang.ma <[email protected]>

* SW-46093: adjust the skipped ops

Signed-off-by: jiaqi.wang <[email protected]>

* SW-46093: currently test accuracy ref to cpu

Signed-off-by: chuanjiang.ma <[email protected]>

* SW-47833: skip one case for half type op not supported by torch_musa

Signed-off-by: chuanjiang.ma <[email protected]>

* mthreads: update test suite setup

Signed-off-by: machuanjiang <[email protected]>

* mthreads: fix review comments

Signed-off-by: chuanjiang.ma <[email protected]>

* mthreads: skip weight_nor testing for driver complaining

1. one test in special_op test change the device type from cuda to musa

Signed-off-by: chuanjiang.ma <[email protected]>

* fix #304 (#308)

fix max: input tensor with big shape may cause "numel (2097152) exceeds triton maximum tensor numel"
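
A two-stage reduction sketched in plain PyTorch illustrates the workaround (the real fix lives in the Triton kernel):

```python
# Reduce fixed-size chunks first, then reduce the per-chunk partials, so no
# single program sees more elements than Triton's per-tensor numel limit.
import torch

def chunked_max(x: torch.Tensor, block: int = 1 << 20) -> torch.Tensor:
    flat = x.reshape(-1)
    partials = [flat[i:i + block].max() for i in range(0, flat.numel(), block)]
    return torch.stack(partials).max()

x = torch.randn(4 * (1 << 20))             # bigger than a single block
assert torch.equal(chunked_max(x), x.max())
```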

* add program_id & num_programs that returns tensor in tl.int64  (#327)

* add program_id & num_programs that returns tensor in tl.int64 to avoid integer overflow
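
A hedged sketch of such wrappers (the names are assumptions; the commit's actual helpers may differ):

```python
# Cast the 32-bit program index to int64 before it is multiplied by a block
# size, so offset math on very large tensors cannot wrap around.
import triton
import triton.language as tl

@triton.jit
def program_id_i64(axis: tl.constexpr):
    return tl.program_id(axis).to(tl.int64)

@triton.jit
def num_programs_i64(axis: tl.constexpr):
    return tl.num_programs(axis).to(tl.int64)
```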

* fix bmm

* fix vector norm, treat empty dim list as full-reduce; use tle in fused ops, too

* block pointer requires offsets/block shape to be int32

* [bugfix] replace tle with tl for operators using philox (#337)

* [bugfix] replace tle with tl for operators using philox

* [bugfix] tle in rand

* [bugfix] uniform

* Remove INT_DTYPES from isfinite operation in benchmark and tests (#336)

* [Operator] Fix full & full_like op (#338)

Co-authored-by: zhengyang <[email protected]>
Co-authored-by: Bowen <[email protected]>

* [Operator] Add diag_embed (#288)

* diag_embed operator

* add diag_embed ops

* resolve conflict

* SW-46093: add test python script for 50 ops

Signed-off-by: jiaqi.wang <[email protected]>

* SW-46093: support output compared_speedup for benchmark plot

Signed-off-by: root <chuanjiang.ma>

* SW-46093: add an optin to control not output compared speedup

Signed-off-by: root <chuanjiang.ma>

* SW-46093: cherry-pick updates from commit 54a471 to ca13b7

Signed-off-by: jiaqi.wang <[email protected]>

* SW-49945: Adapt Multiple backends code for MUSA

Signed-off-by: jiaqi.wang <[email protected]>

* SW-46093: Adjust skipped ops

Signed-off-by: jiaqi.wang <[email protected]>

* SW-49945: Update _mthreads backend config

Signed-off-by: jiaqi.wang <[email protected]>

* add pytest alluredir options for generating report

* [Operator] Add sort op (#322)

* [Operator] Add sort op

---------

Co-authored-by: zhengyang <[email protected]>
Co-authored-by: MARD1NO <[email protected]>

* [Muti_backend]part_2_device (#344)

* new feature, muti_backend

* update auto_tune_module

* update auto_tune_module

* update auto_tune_module

* update __init__

* rebase

* fix bug

* modifiy auto_tune_config

* fix bug

* fix bug

* update

* update

* update scatter&gather

* fix auto_tune

* add gen_torch_device_fn

* fix codestyle

* fix codestyle

* Modify code based on comments

* Modify gen_impl with loops instead of recursion

* Update code structure

* Polish code

* update

* Polish code

* Modify code based on comments

* modify based on comment

* Modify code based on comments

* update

* final fix

* modify_device_unify

* fix bug

* update

* remove '-> list'

* modify

* modify

* modify

* fix bug

* muti_backend_part_3 enhancement of ability (#361)

* muti_backend_part_3 enhancement of ability

* fix bug

* modify

* modify

* fix bug

* modify

* SW-49945: adjust merged code

Signed-off-by: jiaqi.wang <[email protected]>

* Reimplement slice_scatter/select_scatter using pointwise_dynamic (#358)

* merged with origin.
* updated
* renames copy_ to copy, fixes benchmark code.
* fixes benchmark.
* handle self-overlapping input in select_scatter and slice_scatter

---------

Co-authored-by: Clement Chan <[email protected]>

* modify code (#371)

* modify code

* fix bug

* fix bug

* modify

* fix

* fix

* [Polish code] Adjust some directory structures and variable names (#373)

* modify code

* fix bug

* fix bug

* modify

* fix

* fix

* polish code

* polish code

* polish code

* [Fix bugs] Variable reference dependency error (#374)

* fix benchmark bugs

* fix

* [Muti backend] add heuristics_config for muti_bacnkend (#377)

* muti_backend

* [Fix bugs] Variable reference dependency error (#374)

* fix benchmark bugs

* fix

* [bugfix] disable more shapes from blas benchmark (#375)

* add heuristics_config

* modify

* modify

* modify

* modify

---------

Co-authored-by: StrongSpoon <[email protected]>

* SW-49945: update configuration for mthreads

Signed-off-by: jiaqi.wang <[email protected]>

* [benchmark] skip perf test of cummin when triton < 3.0 (#385)

* SW-46093: fix test configuration

Signed-off-by: jiaqi.wang <[email protected]>

* SW-50265: optimize the perf of gelu, tanh op and most other pointwise ops

* Gather bwd (#382)

* add gather_backward op

* add debug log in

* perf gather backward

* rebased with master

* scatter rewrite done.

* scatter handling internally overlapping input.

* Scatter reduce now uses atomics.

* remove fp16 from scatter reduce UT.

* sets threadblock size to 128 for scatter.

* Change atomic memory order to relaxed in scatter.

---------

Co-authored-by: awayzjj <[email protected]>
Co-authored-by: StrongSpoon <[email protected]>

* [Operator]Fix metax backend bugs (#432)

1. update multi-backend code
2. fix argmin op might test failed under int types

Co-authored-by: mx-flaggems-user <[email protected]>

* [TEST] fix error in argmin UT when dtype=int16 (#431)

Co-authored-by: junjian.zhan <[email protected]>

* SW-49945: Update configurations for multi-backends

Signed-off-by: jiaqi.wang <[email protected]>

* SW-47833: Fix case resolve_conj and fill.
1. resolve_conj: ref to this link: https://jira.mthreads.com/browse/MTAI-1530
2. fill: torch_musa does not support case torch.fill(dtype=cpu, dtype=musa).

* add backward of conv2d (#365)

* add backward of conv2d

* delete useless code

* format code of tests

* modify configs for tuning

* modify autotune config

* delete test flag

* delete useless type convert

---------

Co-authored-by: Jiang Bin <[email protected]>

* SW-52470: Cherry-pick updates from commit f1ba20c to 5f31f35.

* SW-49945: rebase for multi-backends

Signed-off-by: jiaqi.wang <[email protected]>

* SW-52470: adjust skipped ops for merge

Signed-off-by: jiaqi.wang <[email protected]>

* Mthreads: adjust ops for multi-backends merge

Signed-off-by: jiaqi.wang <jiaqi.wang @mthreads.com>

* MTHREADS: Checking Code Format

Signed-off-by: jiaqi.wang <jiaqi.wang @mthreads.com>

* MTHREADS: fix bugs for multi-backends ops

Signed-off-by: jiaqi.wang <jiaqi.wang @mthreads.com>

* MTHREADS: fix according to review issue

Signed-off-by: chuanjiang.ma <[email protected]>

* MTHREADS: add annotation to explain the log-file-cross calc

Signed-off-by: chuanjiang.ma <[email protected]>

---------

Signed-off-by: Jian Li <[email protected]>
Signed-off-by: chuanjiang.ma <[email protected]>
Signed-off-by: jiaqi.wang <[email protected]>
Signed-off-by: machuanjiang <[email protected]>
Signed-off-by: root <chuanjiang.ma>
Signed-off-by: jiaqi.wang <jiaqi.wang @mthreads.com>
Co-authored-by: yuzhe-wu <[email protected]>
Co-authored-by: Buzz <[email protected]>
Co-authored-by: hang.zhang <[email protected]>
Co-authored-by: lingfeng.qiu <[email protected]>
Co-authored-by: jialuo.bai <[email protected]>
Co-authored-by: Jian Li <[email protected]>
Co-authored-by: Tianjie Ling <[email protected]>
Co-authored-by: jiaqi.wang <[email protected]>
Co-authored-by: zhzhcookie <[email protected]>
Co-authored-by: zhengyang <[email protected]>
Co-authored-by: zaccur <[email protected]>
Co-authored-by: TZWX-0 <[email protected]>
Co-authored-by: Hiujin Gwok <[email protected]>
Co-authored-by: Bowen12992 <[email protected]>
Co-authored-by: Tongxin Bai <[email protected]>
Co-authored-by: Clement Chan <[email protected]>
Co-authored-by: Bowen <[email protected]>
Co-authored-by: StrongSpoon <[email protected]>
Co-authored-by: kiddyjinjin <[email protected]>
Co-authored-by: haowen-han <[email protected]>
Co-authored-by: niu_he <[email protected]>
Co-authored-by: root <chuanjiang.ma>
Co-authored-by: yu.huang <[email protected]>
Co-authored-by: MARD1NO <[email protected]>
Co-authored-by: Galaxy1458 <[email protected]>
Co-authored-by: Sheng Wang <[email protected]>
Co-authored-by: awayzjj <[email protected]>
Co-authored-by: StrongSpoon <[email protected]>
Co-authored-by: mx-flaggems-user <[email protected]>
Co-authored-by: mx-flaggems-user <[email protected]>
Co-authored-by: junjian.zhan <[email protected]>
Co-authored-by: FatJhon <[email protected]>
Co-authored-by: Jiang Bin <[email protected]>
Co-authored-by: jiaqi.wang <jiaqi.wang @mthreads.com>
StrongSpoon added a commit that referenced this pull request Mar 17, 2025