Conversation

sdh1014 (Contributor) commented on Dec 3, 2025

Background

PyTorch exposes more than 2000 operators. Implementing a custom lowering for every single operator in a backend would be prohibitively expensive. Fortunately, PyTorch provides decomposition facilities: complex high-level operators in PyTorch (such as torch.nn.functional.softmax or torch.batch_norm) can be decomposed and normalized by the AOTAutograd component into Core Aten IR.

According to the official documentation for torch.compiler_ir (https://docs.pytorch.org/docs/2.8/torch.compiler_ir.html), decomposed graphs are expressed using both Core Aten IR and Prims IR. Compared to Core Aten IR, Prims IR is a lower-level, more fine-grained IR.

On PyTorch 2.8.0, this change targets all 191 Core Aten IR operators listed in that document and provides full coverage for them in our backend.
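
The snippet below is a minimal, standalone sketch of these decomposition facilities (independent of this PR), using the public `make_fx` and `core_aten_decompositions` entry points; the DynamoCompiler frontend in this repository reaches the same Core Aten IR through Dynamo and AOTAutograd.

```python
# Minimal sketch (independent of this PR): decompose a high-level op
# toward Core Aten IR with PyTorch's public decomposition utilities.
import torch
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.softmax(x, dim=-1)

x = torch.randn(2, 8)

# make_fx traces f and applies the Core ATen decomposition table, so the
# resulting graph contains only Core Aten IR operators (for softmax this is
# aten._softmax or its amax/sub/exp/sum/div primitives, depending on the
# table contents of the installed PyTorch version).
gm = make_fx(f, decomposition_table=core_aten_decompositions())(x)
print(gm.graph)
```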


New Operator Implementations

Frontend Core (frontend.py, operation.py)

  • Added new Op classes for Core Aten IR operators.
  • Updated operator mappings for renamed ops (e.g., LessThanOp → LtTensorOp, GreaterThanOp → GtTensorOp).
  • Added mapping for the _embedding_bag_forward_only operator.
  • Updated documentation for the aot_autograd_decomposition parameter.

TOSA Dialect (tosa.py)

  • Implemented 171 new operator lowering functions (a representative sketch follows this list).
  • Math ops: abs, log, ceil, floor, log10, log2, log1p, exp2, expm1, sqrt, rsqrt, sign, digamma, lgamma, i0, etc.
  • Comparison ops: eq, ne, gt, ge, lt, le (both Tensor and Scalar variants).
  • Logical ops: logical_and, logical_or, logical_xor, logical_not.
  • Bitwise ops: bitwise_not, bitwise_and, bitwise_or, bitwise_xor.
  • Reduction ops: sum, prod, mean, std, var, amin, amax, argmax, argmin, cumsum, cumprod, all, any, norm.
  • Pooling ops: avg_pool1d/2d/3d, max_pool1d/3d, adaptive_avg_pool1d/2d/3d, adaptive_max_pool1d/2d.
  • Activation ops: gelu, elu, selu, celu, leaky_relu, hardsigmoid, hardswish, hardtanh, hardshrink, softshrink, softplus, mish, prelu, rrelu, threshold.
  • Tensor manipulation: reshape, squeeze, unsqueeze, permute, transpose, expand, slice, select, gather, scatter, flip, roll, tile, unfold, narrow, split_with_sizes.
  • Normalization: native_layer_norm, native_group_norm, and native_batch_norm variants.
  • Other: upsample, pad, clamp, where, masked_fill, sort, topk, embedding, diagonal, etc.
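
As a rough illustration of the shape these lowering functions take, here is a hypothetical sketch for aten.abs. The symbol-table convention, result-type handling, and registry key are assumptions modeled on the existing functions in tosa.py, not the exact code added by this PR.

```python
# Hypothetical sketch of a TOSA lowering function; names and the
# symbol-table convention are assumptions, not the code in this PR.
from mlir.dialects import tosa

def abs_op(node, symbol_table):
    # Look up the MLIR value produced for the op's single tensor input
    # (assumed symbol-table keying convention).
    input_value = symbol_table.get((str(node.args[0]), 0))
    # tosa.abs is elementwise, so the result type matches the input type.
    return tosa.AbsOp(input_value.type, input_value)

# Lowering functions are exposed to the frontend through a registry keyed
# by the Op class name (assumed convention).
ops_registry = {"AbsOp": abs_op}
```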

Linalg Dialect (linalg.py)

  • Added 25 new operator implementations.
  • Restored special broadcast-pattern handling in index_op for cases like idx0: (1, 1) + idx1: (N,) (see the example after this list).
  • Fixed ops_registry mappings for LtTensorOp and SqueezeDimOp.
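
For reference, the broadcast case mentioned above corresponds to PyTorch advanced indexing with index tensors of different ranks:

```python
import torch

x = torch.randn(4, 5, 6)
idx0 = torch.tensor([[2]])       # shape (1, 1)
idx1 = torch.tensor([0, 1, 3])   # shape (3,)

# The two index tensors broadcast to a common shape (1, 3), so the gathered
# result has shape (1, 3, 6); index_op must materialize this broadcast when
# emitting the linalg lowering.
y = x[idx0, idx1]
print(y.shape)  # torch.Size([1, 3, 6])
```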

Math Dialect (math.py)

  • Added implementations for sinh, cosh, tan, erf, exp2, and other math operations.

Test Coverage

  • Added 198 new test files in tests/Python/AtenOps.
  • All tests run with DynamoCompiler and use FileCheck for IR validation (a minimal template is sketched below).
  • The new tests provide comprehensive coverage for all newly implemented operators.
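
The tests follow roughly the pattern below. This is a sketch modeled on the existing cases under tests/Python; helper names such as `importer` and `lower_to_top_level_ir`, and parameter names other than `aot_autograd_decomposition`, follow the current DynamoCompiler frontend and may differ in detail.

```python
# RUN: %PYTHON %s 2>&1 | FileCheck %s
# Minimal test sketch modeled on the existing tests/Python cases.
import torch
from torch._inductor.decomposition import decompositions as inductor_decomp

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import tosa

def foo(x):
    return torch.abs(x)

in1 = torch.randn(10)

dynamo_compiler = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=inductor_decomp,
)

# Import the traced graph, lower it to top-level IR, and print it so that
# FileCheck can verify the emitted operations.
graphs = dynamo_compiler.importer(foo, in1)
graph = graphs[0]
graph.lower_to_top_level_ir()
print(graph._imported_module)

# CHECK: module
# CHECK: tosa.abs
```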

Known Issues

In PyTorch 2.8.0, the interaction between AOTAutograd and Aten IR still has some limitations. Under the default inductor_decomp decomposition rules, certain operators cannot be lowered correctly and require custom decomposition settings:

  • tests/Python/AtenOps/test_reflection_pad1d.py must set aot_autograd_decomposition=None.
    Otherwise, the Inductor decompositions produce a complex sequence of
    tosa.const + tosa.abs + tosa.sub + linalg.generic + tensor.extract,
    even though aten.reflection_pad1d is marked as a Core Aten IR operator in the documentation.

  • In tests/Python/AtenOps/test_max_pool3d_linalg.py, max_pool3d must use core_aten_decompositions() instead of the default inductor_decomp rules.
    The Inductor decomposition for max_pool3d triggers a tracing error, while core_aten_decompositions() keeps the operator intact so that buddy-mlir can directly use its own implementation.

These issues are limitations of the current PyTorch 2.8.0 decomposition behavior rather than of the backend changes in this PR; the two overrides are sketched below.
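
A hedged sketch of the two overrides (the registry choices are illustrative; `core_aten_decompositions` comes from `torch._decomp`):

```python
from torch._decomp import core_aten_decompositions

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import linalg, tosa

# test_reflection_pad1d.py: disable AOTAutograd decompositions entirely so
# that aten.reflection_pad1d reaches the backend as a single op.
compiler_no_decomp = DynamoCompiler(
    primary_registry=tosa.ops_registry,
    aot_autograd_decomposition=None,
)

# test_max_pool3d_linalg.py: use the Core Aten table, which keeps max_pool3d
# intact, instead of the default inductor_decomp rules.
compiler_core_aten = DynamoCompiler(
    primary_registry=linalg.ops_registry,
    aot_autograd_decomposition=core_aten_decompositions(),
)
```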

Add 100+ new operator implementations to enhance PyTorch model compilation:

**frontend.py:**
- Add new operator mappings for Core Aten IR ops
- Replace LessThanOp/GreaterThanOp with LtTensorOp/GtTensorOp
- Add pooling ops (MaxPool1d/3d, AvgPool1d/2d/3d, AdaptivePool variants)
- Add math ops (abs, log, ceil, floor, log10, log2, log1p, expm1, etc.)
- Add comparison ops (eq, ne, gt, ge, lt, le for Tensor)
- Add logical/bitwise ops (not, and, or, xor)
- Add reduction ops (amin, prod)
- Update documentation for aot_autograd_decomposition parameter

**operation.py:**
- Add new Op classes for all Core Aten IR operators
- Include SignOp, DiagonalOp, SplitWithSizesOp, ClampMaxTensorOp
- Add various math, comparison, logical, and pooling Op classes

**linalg.py:**
- Add new operator implementations using linalg dialect
- Fix ops_registry: add LtTensorOp and SqueezeDimOp mappings
- Enhance existing implementations for better Core Aten IR coverage

**math.py:**
- Add math dialect implementations (sinh, cosh, tan, erf, exp2, etc.)

**tosa.py:**
- Implement 100+ TOSA dialect lowerings for Core Aten IR operators
- Add comprehensive tensor manipulation, math, and comparison ops
- Fix various edge cases and type handling

This enables better support for PyTorch 2.x Core Aten IR, improving
model compilation coverage and compatibility.

Add 198 new test files for Core Aten IR operators in tests/Python/AtenOps/:

- Math ops: abs, log, log2, log10, log1p, exp2, expm1, ceil, floor,
  round, trunc, sin, sinh, cos, cosh, tan, asin, asinh, acos, acosh,
  atan, atanh, atan2, erf, erfc, sqrt, sign, digamma, lgamma, i0, etc.
- Comparison ops: eq, ne, gt, ge, lt, le (Tensor and Scalar variants)
- Logical ops: logical_and, logical_or, logical_xor, logical_not
- Bitwise ops: bitwise_not, bitwise_and, bitwise_or, bitwise_xor
- Reduction ops: sum, prod, mean, std, var, amin, argmax, argmin,
  cumsum, cumprod, cummax, cummin, all, any, norm
- Pooling ops: avg_pool1d/2d/3d, max_pool3d, adaptive_avg_pool1d/2d/3d
- Activation ops: relu, silu, gelu, elu, selu, celu, leaky_relu,
  hardsigmoid, hardswish, hardtanh, hardshrink, softshrink, softplus,
  mish, prelu, rrelu, threshold
- Tensor manipulation: squeeze, stack, unbind, split_with_sizes,
  gather, scatter, index_select, flip, roll, tile, unfold, narrow
- Normalization: native_layer_norm, native_group_norm, native_batch_norm
- Other: upsample, pad, clamp, where, masked_fill, sort, topk, etc.

All tests use DynamoCompiler with FileCheck validation.