
Conversation

@factnn factnn commented Oct 24, 2025

PR Category

Operator

Type of Change

New Feature

Description

This PR implements torch.nn.functional.avg_pool2d using Triton with dynamic code generation for optimal performance.

Development Tool:

  • This operator was developed with Triton-Copilot, an AI-powered tool for Triton kernel development.

Implementation highlights:

  • Dynamic code generation with adaptive loop unrolling for optimal performance
  • Smart strategy: full unroll for small kernels (≤4x4), compact code for large kernels (>4x4)
  • Full parameter support: kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
  • Comprehensive test coverage: 157 accuracy tests, 24 performance benchmarks
  • Autotune with multiple BLOCK_SIZE configurations for different input shapes

Performance characteristics:

  • Small kernels (2x2, 3x3): 1.3-1.7x speedup (float32), 1.1-1.6x (float16/bfloat16)
  • Large kernels (5x5+): Competitive performance with reduced register pressure
  • Covers CNN architectures: ResNet, VGG, MobileNet, ImageNet standard shapes

Technical approach:
The implementation uses dynamic code generation to create specialized kernels for each kernel size:

  • ≤4x4 kernels: Full loop unrolling with unique variables for maximum instruction-level parallelism
  • >4x4 kernels: Compact code with variable reuse to reduce register pressure while maintaining unrolling benefits
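For illustration, here is a minimal Python sketch (hypothetical code, not the actual FlagGems generator) of how such a generator could branch between the two strategies; the h_in_0_0-style naming mirrors the description above:

def build_pool_body(kernel_h, kernel_w):
    """Illustrative sketch: emit the unrolled accumulation body as source text."""
    lines = []
    if kernel_h <= 4 and kernel_w <= 4:
        # Small kernels: fully unroll with unique temporaries per window position,
        # exposing more instruction-level parallelism.
        for kh in range(kernel_h):
            for kw in range(kernel_w):
                lines += [
                    f"h_in_{kh}_{kw} = h_start + {kh}",
                    f"w_in_{kh}_{kw} = w_start + {kw}",
                    f"valid_{kh}_{kw} = (h_in_{kh}_{kw} >= 0) & (h_in_{kh}_{kw} < H)"
                    f" & (w_in_{kh}_{kw} >= 0) & (w_in_{kh}_{kw} < W) & mask",
                    f"val_{kh}_{kw} = tl.load(input_ptr + base + h_in_{kh}_{kw} * W"
                    f" + w_in_{kh}_{kw}, mask=valid_{kh}_{kw}, other=0.0)",
                    f"acc += val_{kh}_{kw}",
                ]
    else:
        # Large kernels: still unrolled, but the same temporaries are reused for
        # every window position to keep register pressure down.
        for kh in range(kernel_h):
            for kw in range(kernel_w):
                lines += [
                    f"h_in = h_start + {kh}",
                    f"w_in = w_start + {kw}",
                    "valid = (h_in >= 0) & (h_in < H) & (w_in >= 0) & (w_in < W) & mask",
                    "val = tl.load(input_ptr + base + h_in * W + w_in, mask=valid, other=0.0)",
                    "acc += val",
                ]
    return "\n".join("    " + line for line in lines)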

Issue

N/A - This is a new operator implementation to expand FlagGems' coverage of PyTorch operations.

Progress

  • Change is fully covered by a UT (157 accuracy tests, 24 performance tests).
  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change is responded to an issue.

Additional Notes

Files changed:

  • src/flag_gems/ops/avg_pool2d.py - Operator implementation (280 lines)
  • src/flag_gems/ops/__init__.py - Operator registration
  • src/flag_gems/__init__.py - PyTorch dispatcher registration
  • tests/test_reduction_ops.py - 7 test functions (~157 test cases)
  • benchmark/test_special_perf.py - Performance benchmarks (24 test cases)
  • benchmark/core_shapes.yaml - Benchmark shape configurations

Testing:
All tests pass successfully:

pytest tests/test_reduction_ops.py -m avg_pool2d  # 157 passed
pytest benchmark/test_special_perf.py::test_perf_avg_pool2d  # 1 passed (24 benchmarks)

Performance

Operator: avg_pool2d  Performance Test (dtype=torch.float16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.032800            0.022720               1.444          ([torch.Size([4, 64, 112, 112])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.027888            0.022592               1.234          ([torch.Size([8, 128, 56, 56])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.023840            0.017600               1.355          ([torch.Size([16, 256, 28, 28])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.039104            0.024352               1.606          ([torch.Size([2, 64, 224, 224])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.010240            0.008896               1.151          ([torch.Size([4, 512, 14, 14])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.049088            0.069952               0.702          ([torch.Size([8, 32, 112, 112])], {'kernel_size': 7, 'stride': 2, 'padding': 3})
SUCCESS               0.010720            0.008352               1.284          ([torch.Size([1, 3, 224, 224])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.039872            0.026928               1.481          ([torch.Size([32, 64, 56, 56])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
Operator: avg_pool2d  Performance Test (dtype=torch.float32, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.039232            0.024288               1.615          ([torch.Size([4, 64, 112, 112])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.040816            0.024256               1.683          ([torch.Size([8, 128, 56, 56])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.037568            0.023968               1.567          ([torch.Size([16, 256, 28, 28])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.062400            0.036608               1.705          ([torch.Size([2, 64, 224, 224])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.011824            0.010704               1.105          ([torch.Size([4, 512, 14, 14])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.064608            0.064736               0.998          ([torch.Size([8, 32, 112, 112])], {'kernel_size': 7, 'stride': 2, 'padding': 3})
SUCCESS               0.010816            0.009552               1.132          ([torch.Size([1, 3, 224, 224])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.063936            0.038432               1.664          ([torch.Size([32, 64, 56, 56])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
Operator: avg_pool2d  Performance Test (dtype=torch.bfloat16, mode=kernel,level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.026688            0.021376               1.249          ([torch.Size([4, 64, 112, 112])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.028096            0.021440               1.310          ([torch.Size([8, 128, 56, 56])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.023808            0.016704               1.425          ([torch.Size([16, 256, 28, 28])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.039168            0.023776               1.647          ([torch.Size([2, 64, 224, 224])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.009952            0.008896               1.119          ([torch.Size([4, 512, 14, 14])], {'kernel_size': 2, 'stride': 2, 'padding': 0})
SUCCESS               0.048832            0.069216               0.706          ([torch.Size([8, 32, 112, 112])], {'kernel_size': 7, 'stride': 2, 'padding': 3})
SUCCESS               0.011072            0.009344               1.185          ([torch.Size([1, 3, 224, 224])], {'kernel_size': 3, 'stride': 2, 'padding': 1})
SUCCESS               0.040160            0.027168               1.478          ([torch.Size([32, 64, 56, 56])], {'kernel_size': 2, 'stride': 2, 'padding': 0})

@gemini-code-assist

Summary of Changes

Hello @factnn, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new, highly optimized implementation of the torch.nn.functional.avg_pool2d operator within the FlagGems framework. By leveraging Triton with dynamic code generation and compile-time loop unrolling, this change aims to significantly boost the performance of average pooling operations, particularly for common kernel sizes, while maintaining full functional parity with PyTorch's native implementation. This expands FlagGems' coverage of PyTorch operations with a performant new feature.

Highlights

  • Dynamic Code Generation: Implemented avg_pool2d using Triton with dynamic code generation, featuring compile-time loop unrolling for all kernel sizes to optimize performance.
  • Adaptive Strategy for Large Kernels: Introduced an adaptive strategy for larger kernels (5x5+) that generates compact code with variable reuse to reduce register pressure while retaining the benefits of loop unrolling.
  • Full Parameter Support: The new avg_pool2d operator fully supports all parameters, including kernel_size, stride, padding, ceil_mode, count_include_pad, and divisor_override, ensuring functional parity with torch.nn.functional.avg_pool2d.
  • Performance Improvements: Achieved significant speedups: 1.3-1.7x for Float32 and 1.1-1.6x for Float16/BFloat16 on common kernel sizes (2x2, 3x3), covering typical CNN architectures like ResNet, VGG, and MobileNet.
  • Comprehensive Test Coverage: Added extensive test coverage, including 157 accuracy tests and 24 performance benchmarks, validating basic operations, edge cases, various input sizes (small, medium, large, non-square), and all float data types (float16, float32, bfloat16).
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature                 Command                  Description
---------------------------------------------------------------------------------------------------
Code Review             /gemini review           Performs a code review for the current pull request in its current state.
Pull Request Summary    /gemini summary          Provides a summary of the current pull request in its current state.
Comment                 @gemini-code-assist      Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                    /gemini help             Displays a list of available commands.

Customization

To customize your Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces an avg_pool2d operator using dynamic Triton code generation, which is a great approach for optimizing performance on specific kernel sizes. The implementation is well-structured, and the accompanying accuracy and performance tests are very comprehensive, covering a wide range of parameters and edge cases.

I have a few suggestions to further improve the implementation:

  • The adaptive strategy for large kernels mentioned in the description seems to be missing, which impacts performance for larger kernel sizes like 7x7.
  • The temporary files generated for kernels are not cleaned up, which could be an issue in long-running processes.
  • There is some redundancy in the benchmark shape definitions between the Python test file and the YAML configuration.

Details are in the specific comments. Overall, this is a solid contribution.

Comment on lines 22 to 32
def generate_avg_pool2d_kernel_code(kernel_h, kernel_w, count_include_pad):
    """
    Generate Triton kernel source code with unrolled loops for specific kernel size.

    Args:
        kernel_h: Kernel height
        kernel_w: Kernel width
        count_include_pad: Whether to include padding in count

    Returns:
        str: Generated kernel source code
    """


high

The PR description mentions an "adaptive strategy" with "compact code for large kernels (5x5+) to reduce register pressure". However, this implementation uses the same full unrolling strategy for all kernel sizes. This approach, which creates unique variables for each position in the kernel window (e.g., h_in_0_0, h_in_0_1), can lead to excessive register usage and poor performance for larger kernels, as confirmed by the benchmark results for 7x7 kernels.

Please implement the adaptive strategy. For kernels larger than a certain threshold (e.g., 4x4), you should generate code that reuses temporary variables inside the unrolled loop to reduce register pressure.

For example, for large kernels, you could generate code like this:

# In the generated Triton kernel for large kernels
# ...
h_in = h_start + {kh}
w_in = w_start + {kw}
valid = (h_in >= 0) & (h_in < H) & (w_in >= 0) & (w_in < W) & mask
# ...

This avoids creating unique variables for each kernel element.

Comment on lines 174 to 209
# Create temporary file for the generated kernel
code_hash = hashlib.md5(kernel_code.encode()).hexdigest()[:8]
temp_dir = tempfile.gettempdir()
kernel_file = os.path.join(
    temp_dir,
    f"avg_pool2d_{kernel_h}x{kernel_w}_pad{int(count_include_pad)}_{code_hash}.py",
)


medium

The function get_or_create_codegen_kernel creates temporary Python files for the generated Triton kernels in the system's temporary directory. These files are not deleted, which can lead to an accumulation of files over time, especially in long-running applications or during extensive testing.

To ensure these temporary files are cleaned up properly, you could register a cleanup function using the atexit module. This function would be called upon script exit to remove all generated files.

Here's a conceptual example:

import atexit
import os

_temp_files_to_clean = []

def _cleanup_temp_files():
    for f in _temp_files_to_clean:
        try:
            os.remove(f)
        except OSError:
            pass

atexit.register(_cleanup_temp_files)

# In get_or_create_codegen_kernel:
# ...
# after creating kernel_file path
_temp_files_to_clean.append(kernel_file)
with open(kernel_file, "w") as f:
    f.write(kernel_code)
# ...

Comment on lines 586 to 601
# Common shapes from popular CNN architectures
return [
    # ResNet-like feature maps
    (4, 64, 112, 112, 3, 2, 1),   # Early layer pooling
    (8, 128, 56, 56, 3, 2, 1),    # Mid layer pooling
    (16, 256, 28, 28, 2, 2, 0),   # Late layer pooling
    # VGG-like
    (2, 64, 224, 224, 2, 2, 0),   # Max pooling equivalent size
    (4, 512, 14, 14, 2, 2, 0),    # Deep layer
    # MobileNet-like
    (8, 32, 112, 112, 7, 2, 3),   # Large kernel
    # ImageNet standard input
    (1, 3, 224, 224, 3, 2, 1),    # Single image
    # Large batch
    (32, 64, 56, 56, 3, 2, 1),    # Batch training
]


medium

The shapes for the AvgPool2dBenchmark are defined in both benchmark/core_shapes.yaml and here in the set_more_shapes method. This leads to redundancy, as 7 out of 8 shapes are identical in both places. The GenericBenchmark class already loads shapes from the YAML file, and set_more_shapes is intended for adding extra shapes.

To improve maintainability, you should only include shapes in this method that are not already present in core_shapes.yaml. This avoids confusion and makes it easier to manage benchmark configurations.

        # Most shapes are defined in benchmark/core_shapes.yaml.
        # This method adds shapes not present in the YAML file.
        return [
            # Large batch with 3x3 kernel, stride 2, padding 1
            (32, 64, 56, 56, 3, 2, 1),
        ]

@factnn force-pushed the add-avg-pool2d-operator branch from 7b6c4b5 to ba29ff2 on October 24, 2025 at 09:52

factnn commented Oct 24, 2025

@ gemini-code-assist check the latest PR please.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new avg_pool2d operator, implemented with dynamic Triton code generation for optimized performance. The implementation is impressive, featuring an adaptive strategy for different kernel sizes and comprehensive support for all avg_pool2d parameters. The performance gains are well-documented, and the operator is thoroughly covered by 157 accuracy tests and 24 performance benchmarks. The code is clean, well-structured, and the changes across the repository are correctly handled. I have one suggestion to improve the code generation mechanism by using exec() instead of writing to temporary files, which would make it more robust and self-contained.

Comment on lines 203 to 222
# Create temporary file for the generated kernel
code_hash = hashlib.md5(kernel_code.encode()).hexdigest()[:8]
temp_dir = tempfile.gettempdir()
kernel_file = os.path.join(
    temp_dir,
    f"avg_pool2d_{kernel_h}x{kernel_w}_pad{int(count_include_pad)}_{code_hash}.py",
)

# Write kernel code to file
with open(kernel_file, "w") as f:
    f.write(kernel_code)

# Dynamically import the kernel
spec = importlib.util.spec_from_file_location(f"kernel_{code_hash}", kernel_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Get the kernel function
kernel_name = f"avg_pool2d_kernel_{kernel_h}x{kernel_w}_pad{int(count_include_pad)}"
kernel_func = getattr(module, kernel_name)


medium

The current implementation for dynamic kernel generation involves writing the generated code to a temporary file and then importing it using importlib. This approach has a few drawbacks:

  • Filesystem I/O: It introduces filesystem operations, which can be slower and a source of errors (e.g., permissions, disk full).
  • Cleanup: Temporary files might not always be cleaned up properly, especially if the process crashes.
  • Complexity: It adds complexity with tempfile, os.path, and importlib.

A cleaner and more robust approach is to use exec() to compile the generated code in memory. This avoids filesystem interaction entirely, making the code generation process more efficient and self-contained. The kernel_code string already includes the necessary imports, so they will be available in the new namespace after exec().

Suggested change
Before:

# Create temporary file for the generated kernel
code_hash = hashlib.md5(kernel_code.encode()).hexdigest()[:8]
temp_dir = tempfile.gettempdir()
kernel_file = os.path.join(
    temp_dir,
    f"avg_pool2d_{kernel_h}x{kernel_w}_pad{int(count_include_pad)}_{code_hash}.py",
)

# Write kernel code to file
with open(kernel_file, "w") as f:
    f.write(kernel_code)

# Dynamically import the kernel
spec = importlib.util.spec_from_file_location(f"kernel_{code_hash}", kernel_file)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)

# Get the kernel function
kernel_name = f"avg_pool2d_kernel_{kernel_h}x{kernel_w}_pad{int(count_include_pad)}"
kernel_func = getattr(module, kernel_name)

After:

# Dynamically compile the kernel in memory using exec(). This avoids file I/O,
# making it cleaner and more robust.
kernel_namespace = {}
exec(kernel_code, kernel_namespace)

# Get the kernel function from the namespace.
kernel_name = f"avg_pool2d_kernel_{kernel_h}x{kernel_w}_pad{int(count_include_pad)}"
kernel_func = kernel_namespace[kernel_name]

@factnn force-pushed the add-avg-pool2d-operator branch from ba29ff2 to 2062e76 on October 27, 2025 at 07:02

factnn commented Oct 27, 2025

@gemini-code-assist I have revised the code as you suggested; please check it now.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new avg_pool2d operator implemented with Triton, using dynamic code generation to achieve significant performance speedups. The implementation is impressive, featuring an adaptive strategy for different kernel sizes and comprehensive support for all avg_pool2d parameters. The test coverage is also excellent, with numerous accuracy and performance tests.

My review focuses on improving the clarity and maintainability of the new implementation. I've suggested updating the docstrings to be more accurate, refactoring the main functions for simplicity, and a potential optimization related to the divisor_override parameter. Overall, this is a high-quality contribution.

Comment on lines 1 to 4
"""
Dynamic code generation version of avg_pool2d for small kernels (2x2, 3x3)
This version unrolls the pooling loops at compile time for better performance.
"""


medium

The docstring is misleading. It states that this implementation is for "small kernels (2x2, 3x3)", but the code actually implements an adaptive strategy that works for all kernel sizes. The docstring should be updated to accurately describe the implementation.

Suggested change
"""
Dynamic code generation version of avg_pool2d for small kernels (2x2, 3x3)
This version unrolls the pooling loops at compile time for better performance.
"""
"""
Implementation of avg_pool2d using dynamic Triton code generation.
This version unrolls the pooling loops at compile time and uses an adaptive strategy
for both small and large kernels to optimize performance.
"""

Comment on lines 17 to 32
def generate_avg_pool2d_kernel_code(kernel_h, kernel_w, count_include_pad):
    """
    Generate Triton kernel source code with unrolled loops for specific kernel size.

    Adaptive strategy:
    - Small kernels (<=4x4): Full unroll with unique variables for better ILP
    - Large kernels (>4x4): Compact code with reused variables to reduce register pressure

    Args:
        kernel_h: Kernel height
        kernel_w: Kernel width
        count_include_pad: Whether to include padding in count

    Returns:
        str: Generated kernel source code
    """


medium

The current implementation calculates count inside the Triton kernel even when divisor_override is provided. This is unnecessary work. Since divisor_override is known before kernel generation, you can optimize this by generating different kernels based on whether divisor_override is used.

Consider adding a use_divisor_override flag to generate_avg_pool2d_kernel_code and the kernel cache key. If True, the generated kernel can skip all count related calculations and directly use divisor_override as the divisor. This will result in a more efficient kernel for that specific case.
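A minimal sketch of that specialization, assuming a hypothetical build_divisor_code helper and use_divisor_override flag (illustrative only, not the existing implementation):

def build_divisor_code(use_divisor_override):
    """Illustrative sketch: pick the divisor expression at code-generation time."""
    if use_divisor_override:
        # divisor_override is known before generation, so the generated kernel can
        # take it as a scalar argument and skip every count-related computation.
        return "divisor = divisor_override"
    # Otherwise keep the existing per-window count logic from the template.
    return "divisor = count"

# The flag would also need to become part of the kernel cache key, e.g. (hypothetical):
# cache_key = (kernel_h, kernel_w, count_include_pad, use_divisor_override)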

Comment on lines 223 to 334
def avg_pool2d_codegen(
    input,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """
    Code-generated version of avg_pool2d for small kernels.

    Args:
        input: Input tensor [N, C, H, W]
        kernel_size: Size of pooling kernel
        stride: Stride of pooling operation
        padding: Padding to add to input
        ceil_mode: When True, use ceil instead of floor for output size
        count_include_pad: When True, include padding in average calculation
        divisor_override: If specified, override divisor for averaging

    Returns:
        Output tensor after avg pooling
    """
    logger.debug("GEMS AVG_POOL2D (CodeGen)")

    # Handle kernel_size
    if isinstance(kernel_size, int):
        kernel_h = kernel_w = kernel_size
    else:
        kernel_h, kernel_w = kernel_size

    # Handle stride
    if stride is None:
        stride_h = kernel_h
        stride_w = kernel_w
    elif isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    # Handle padding
    if isinstance(padding, int):
        pad_h = pad_w = padding
    else:
        pad_h, pad_w = padding

    # Get input dimensions
    N, C, H, W = input.shape

    # Calculate output dimensions
    if ceil_mode:
        H_out = (H + 2 * pad_h - kernel_h + stride_h - 1) // stride_h + 1
        W_out = (W + 2 * pad_w - kernel_w + stride_w - 1) // stride_w + 1
    else:
        H_out = (H + 2 * pad_h - kernel_h) // stride_h + 1
        W_out = (W + 2 * pad_w - kernel_w) // stride_w + 1

    # Create output tensor
    output = torch.empty((N, C, H_out, W_out), device=input.device, dtype=input.dtype)

    # Handle divisor_override
    if divisor_override is None:
        divisor_override = -1

    # Get or generate kernel
    kernel_func = get_or_create_codegen_kernel(kernel_h, kernel_w, count_include_pad)

    # Launch kernel with autotune
    grid = lambda meta: (triton.cdiv(N * C * H_out * W_out, meta["BLOCK_SIZE"]),)

    kernel_func[grid](
        input,
        output,
        N,
        C,
        H,
        W,
        H_out,
        W_out,
        stride_h,
        stride_w,
        pad_h,
        pad_w,
        divisor_override,
    )

    return output


def avg_pool2d(
    input,
    kernel_size,
    stride=None,
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """
    Average pooling using pure code generation with unrolled loops.
    Works for all kernel sizes.
    """
    return avg_pool2d_codegen(
        input,
        kernel_size,
        stride,
        padding,
        ceil_mode,
        count_include_pad,
        divisor_override,
    )


medium

The function avg_pool2d is a simple wrapper around avg_pool2d_codegen, which adds an unnecessary layer of indirection. Additionally, the docstring for avg_pool2d_codegen is misleading, as it states the implementation is only for "small kernels", while the code supports all kernel sizes.

To improve clarity and simplify the code, I suggest:

  1. Removing the avg_pool2d wrapper function.
  2. Renaming avg_pool2d_codegen to avg_pool2d.
  3. Updating the docstring of the main function to be accurate.

@factnn force-pushed the add-avg-pool2d-operator branch from 2062e76 to 0fe009e on October 27, 2025 at 07:28

factnn commented Oct 27, 2025

@ gemini-code-assist check

@gemini-code-assist

Hello @factnn, I've received your request to check the pull request. I am now initiating a new review of the current changes. You will receive my feedback shortly. For future reviews, you can use the /gemini review command.

@factnn force-pushed the add-avg-pool2d-operator branch from 0fe009e to 91c47b2 on October 28, 2025 at 03:04
This PR implements torch.nn.functional.avg_pool2d using Triton with dynamic
code generation for optimal performance.

Implementation highlights:
- Dynamic code generation with loop unrolling for all kernel sizes
- Adaptive strategy: compact code for large kernels (5x5+) to reduce register pressure
- Full parameter support: kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override
- Comprehensive test coverage: 157 accuracy tests, 24 performance benchmarks

Performance results:
- Float32: 1.3-1.7x speedup for common kernel sizes (2x2, 3x3)
- Float16/BFloat16: 1.1-1.6x speedup
- Covers CNN architectures: ResNet, VGG, MobileNet, ImageNet

Test coverage:
- Basic operations (kernel_size, stride)
- Edge cases (padding, ceil_mode, count_include_pad, divisor_override)
- Various input sizes (small, medium, large, non-square)
- All float dtypes (float16, float32, bfloat16)
@factnn force-pushed the add-avg-pool2d-operator branch from 91c47b2 to 636526d on October 29, 2025 at 03:43
Following FlagGems architectural standards for code-generated operators:

- Add runtime import to generated kernel code
- Use runtime.get_tuned_config("avg_pool2d") instead of hardcoded configs
- Add avg_pool2d entry to tune_configs.yaml with 4 configurations
- Maintains key=["N", "C", "H_out", "W_out"] for optimal performance

This follows the same pattern as index_put operator which also uses
dynamic code generation with runtime-loaded autotune configurations.

Changes:
- Simplified autotune decorator generation (9 lines → 4 lines)
- Added avg_pool2d configs to runtime/backend/_nvidia/tune_configs.yaml
- Preserved all existing functionality and performance characteristics
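
As a rough sketch of that pattern (assuming, per the commit message above, that the generated source imports runtime from flag_gems; the helper name and exact layout are illustrative, not the actual template):

def build_autotune_decorator():
    """Illustrative sketch: emit the autotune decorator text for the generated kernel."""
    return (
        "@triton.autotune(\n"
        '    configs=runtime.get_tuned_config("avg_pool2d"),\n'
        '    key=["N", "C", "H_out", "W_out"],\n'
        ")\n"
        "@triton.jit\n"
    )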