Weekly Project News


Weekly GitHub Report for Jax: March 23, 2026 - March 30, 2026 (22:23:21)

Weekly GitHub Report for Jax

Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.


Table of Contents

  • I. News
    • 1.1. Recent Version Releases
    • 1.2. Version Information
  • II. Issues
    • 2.1. Top 5 Active Issues
    • 2.2. Top 5 Stale Issues
    • 2.3. Open Issues
    • 2.4. Closed Issues
    • 2.5. Issue Discussion Insights
  • III. Pull Requests
    • 3.1. Open Pull Requests
    • 3.2. Closed Pull Requests
    • 3.3. Pull Request Discussion Insights
  • IV. Contributors
    • 4.1. Contributors

I. News

1.1 Recent Version Releases:

The current version of this repository is jax-v0.9.0.1.

1.2 Version Information:

On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements or fixes without introducing major changes.

Click here to view the full release notes!

II. Issues

2.1 Top 5 Active Issues:

We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.

  1. [BUG] Pallas TPU Lowering crash: _***_lowering_rule() missing 1 required positional argument: 'accuracy': This issue reports a crash occurring during the lowering process on Pallas TPU when using jax.scipy.special.erf and similar functions, due to a missing required positional argument accuracy in the _lowering_rule function. The user provides a reproducible script and error traceback demonstrating that the problem affects multiple lowering rules that expect an accuracy argument but do not receive it.

    • The single comment simply tags a contributor for awareness, indicating that the issue has been noted but no further discussion or resolution steps have been provided yet.
    • Number of comments this week: 1
  2. [BUG] Data loss on device 1 during multi-GPU sharded device_put: This issue reports a problem where data sharded across two GPUs using JAX 0.9.1 results in the data on the second GPU being replaced with zeros when converted back to the host, specifically on a multi-GPU NVIDIA A30 setup. The user provides a minimal reproducible script demonstrating that while data on the first GPU is correct, the second GPU's data is lost, highlighting a potential bug in multi-GPU sharded device_put operations.

    • A single comment expresses interest in working on the issue and requests assignment, indicating the start of engagement but no further discussion or resolution yet.
    • Number of comments this week: 1

Since there were fewer than 5 active issues this week, all of them are listed above.
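The round-trip check at the heart of the sharded device_put report can be sketched on any machine by forcing multiple host (CPU) devices as stand-ins for the reporter's two A30 GPUs; the device-count flag, array shape, and mesh axis name below are illustrative, not the reporter's actual script:

```python
# Force two host devices as GPU stand-ins; XLA_FLAGS must be set before
# jax is imported for this to take effect.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = jax.devices()[:2]
mesh = Mesh(np.array(devices), axis_names=("x",))
sharding = NamedSharding(mesh, P("x"))

host = np.arange(8, dtype=np.float32)
sharded = jax.device_put(host, sharding)  # split across the two devices
roundtrip = np.asarray(sharded)           # gather back to the host

# In the bug report, the half on the second GPU came back as zeros;
# on a healthy backend the round-trip is lossless.
print(np.array_equal(host, roundtrip))
```

On a correctly functioning backend this prints True; the issue reports that the second device's shard is zeroed instead.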

2.2 Top 5 Stale Issues:

We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.

As of our latest update, there are no stale issues for the project this week.

2.3 Open Issues

This section lists, groups, and then summarizes issues that were created within the last week in the repository.

Issues Opened This Week: 2

Summarized Issues:

  • Lowering rule crashes due to missing arguments: The lowering of certain functions like jax.scipy.special.erf on Pallas TPU crashes because the required positional argument accuracy is missing in the _erf_lowering_rule() function. This issue also affects other functions that have an accuracy argument in their lowering rules, causing failures during compilation.
  • issues/36149
  • Multi-GPU data sharding errors: When sharding data across two GPUs and converting it back to the host on a multi-GPU NVIDIA A30 setup, data corresponding to the second GPU is replaced with zeros. This problem occurs specifically in JAX version 0.9.1, leading to incorrect data handling in multi-GPU environments.
  • issues/36308

2.4 Closed Issues

This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.

Issues Closed This Week: 3

Summarized Issues:

  • Batch broadcasting and shape mismatch in cho_solve: The jax.scipy.linalg.cho_solve function fails to correctly replicate SciPy's batch broadcasting behavior when solving with a batched Cholesky factor and an unbatched right-hand-side vector, resulting in a shape mismatch error. This issue arises due to improper handling of batch dimensions during the solve operation.
  • issues/36083
  • Pallas kernel incorrect results on TPU in non-interpret mode: A Pallas kernel that uses reshape and indexing operations produces incorrect results on TPU hardware when run in non-interpret mode, while the same kernel works correctly in interpret mode. This discrepancy indicates a problem with kernel execution or optimization on TPU outside of interpret mode.
  • issues/36287
  • Persistent cache warning flood in jax 0.9.1+: Enabling persistent cache in jax version 0.9.1 and later causes repeated warning messages related to pjrt_executable.cc to flood stderr on each cache hit. This issue did not occur in earlier versions, indicating a regression or new bug introduced with persistent cache support.
  • issues/36294
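The cho_solve batch-broadcasting case above (batched Cholesky factor, unbatched right-hand side) can be handled explicitly with jax.vmap; this is a minimal sketch with illustrative shapes, not the fix from the issue itself:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve

# Batch of symmetric positive-definite matrices (illustrative shapes).
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (3, 4, 4))
spd = a @ jnp.swapaxes(a, -1, -2) + 4.0 * jnp.eye(4)

b = jnp.ones(4)  # unbatched right-hand side

# jnp.linalg.cholesky is already batched; vmap the solve over the batch of
# factors while holding the single right-hand side fixed.
c = jnp.linalg.cholesky(spd)  # (3, 4, 4) lower-triangular factors
x = jax.vmap(lambda ci: cho_solve((ci, True), b))(c)
print(x.shape)  # (3, 4)
```

Mapping only over the factor axis reproduces the SciPy-style broadcasting the issue asks for without relying on implicit batch handling in cho_solve.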

2.5 Issue Discussion Insights

This section analyzes the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.


III. Pull Requests

3.1 Open Pull Requests

This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Opened This Week: 5

Key Open Pull Requests

1. Fix cuDNN attention backward pass partitioning with sharded inputs: This pull request fixes the cuDNN attention backward pass SPMD partitioning by updating the partition function to ensure that the gradient output and forward output shardings match the query's sharding when inputs are sharded, addressing issue #25986 and adding a regression test to verify the correction.

  • URL: pull/36169
  • Associated Commits: 8ff3e, d42e1

2. Use hermetic clang to compile jax: This pull request updates the JAX build process to use the hermetic clang compiler provided by the XLA toolchains for compilation.

  • URL: pull/36283
  • Associated Commits: 1e512

3. Removed a few unused attributes: This pull request removes a few unused attributes from the codebase to improve code cleanliness and maintainability.

  • URL: pull/36303
  • Associated Commits: 72376

Other Open Pull Requests

  • ROCm Configuration Updates: This pull request updates the ROCm configuration to disable the command_buffer feature specifically for pallas tests. The change targets improving compatibility or stability for tests running on ROCm platforms.
  • pull/36304
  • New Experimental Primitive Implementation: This pull request introduces a new experimental primitive called lora_dot with a custom vector-Jacobian product (custom_vjp) implementation to the JAX library. The addition aims to extend JAX's capabilities with new differentiable operations.
  • pull/36307

3.2 Closed Pull Requests

This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Closed This Week: 41

Key Closed Pull Requests

1. [ROCm] Enable hipBLASLt in ROCm pytest CI: This pull request enables hipBLASLt support in the ROCm pytest continuous integration by adding the flag --xla_gpu_enable_cublaslt=true to XLA_FLAGS in the run_pytest_rocm.sh script, which is necessary for TF32 GEMM support and to prevent autotuner failures for TF32 dot algorithms on ROCm.

  • URL: pull/35902
  • Associated Commits: a76fa, f01e2

2. feat: Add warning for large closure captures in JAX compilation: This pull request introduces a configurable warning in JAX that alerts users during function compilation when large arrays are captured in closures, which can lead to out-of-memory errors, by adding a new config flag and helper function to emit these warnings and provide guidance on avoiding such issues.

  • URL: pull/36305
  • Associated Commits: 24665, 672c0

3. don't print sharding/out_sharding in dot, broadcast, reshard: This pull request improves the JAX codebase by modifying the dot, broadcast, and reshard operations to exclude printing of sharding and out_sharding information, thereby streamlining debug output and enhancing readability.

  • URL: pull/35826
  • Associated Commits: a6e22
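The closure-capture pattern targeted by the warning PR above can be sketched as follows; the flag added by the PR is not named here, and the shapes are illustrative:

```python
import jax
import jax.numpy as jnp

big = jnp.ones((1000, 1000))  # a large array living on device

# Captured in the closure: `big` is embedded in the traced computation as
# a constant, which is the situation the proposed warning flags.
@jax.jit
def scale_closure(x):
    return x * big

# Preferred pattern: pass the array as an explicit argument so it stays a
# runtime input rather than a baked-in constant.
@jax.jit
def scale_arg(x, w):
    return x * w

x = jnp.full((1000, 1000), 2.0)
print(bool(jnp.allclose(scale_closure(x), scale_arg(x, big))))  # True
```

Both functions compute the same result, but the argument-passing form avoids duplicating the large array inside the compiled executable, which is what can lead to the out-of-memory errors the PR warns about.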

Other Closed Pull Requests

  • ROCm support and optimization: Multiple pull requests enhance ROCm support by adding lowering for ScaledMatmul and ScaledDot operations, and by compressing offloaded device code to reduce PJRT wheel size by about 40% without runtime impact. These changes improve ROCm backend functionality and efficiency in the JAX project.
    • pull/35995, pull/36055
  • Batching and differentiation improvements: Pull requests introduce a new hijax primitive for jnp.nonzero to simplify batching and automatic differentiation, and implement SciPy 1.16-style auto-batching in jsp.linalg to enhance batch processing consistency. These updates provide better composability and compatibility with existing scientific computing standards.
    • pull/36053, pull/36093
  • Constant handling and simplified JAXPR constants: Several pull requests update constant argument handling in pjit by converting them to MetaTy, fix AOT compilation to account for closed over constants, and modify Jaxpr constant handling to prevent treating constants as literals during autodiff. These changes enable and support the JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS feature for improved constant management.
    • pull/36131, pull/36141, pull/36186
  • Code cleanup and deprecated API removal: Multiple pull requests remove deprecated APIs and symbols, update error behaviors for removed symbols, eliminate leftover mypy references, and remove outdated functions to improve code clarity and maintainability. These removals align the codebase with recent version changes and reduce technical debt.
    • pull/36154, pull/36165, pull/36166, pull/36167, pull/36216, pull/36207
  • Documentation and build stability improvements: A pull request addresses flaky Sphinx notebook builds caused by race conditions in Jupyter kernel port assignment by introducing a dedicated portserver, ensuring conflict-free port allocation during documentation builds. Another removes examples in favor of documentation coverage for nested vmaps.
    • pull/36130, pull/36138
  • Validation and GPU lowering adjustments: One pull request disables check_vma validation inside core_map and mosaic GPU lowering to enable mosaic GPU kernel usage, adding tests to ensure correct ragged dot kernel lowering within a shard_map. This change facilitates GPU kernel compatibility and correctness.
    • pull/36180
  • Pytrees and gradient computation enhancements: Pull requests introduce compatibility of DidntWant and GradRef with pytrees, implement a skip mechanism in bilinear transpose, and enable rematerialization with reference gradient accumulators, enhancing gradient computation flexibility and correctness.
    • pull/36181, pull/36184
  • MLIR and type checking improvements: Several pull requests fix MLIR type stubs, ensure passing type checks with updated bindings, propose new helper APIs for dense i32 and i64 array attributes with clearer element type enforcement, and update Pyrefly settings to Python 3.11 with suppressions for MLIR Buffer type issues. These changes improve type safety and compatibility with tooling.
    • pull/36189, pull/36200, pull/36204
  • Lint and test suite fixes: Pull requests fix lint errors and address a lint issue in the test suite while adding functionality to track the total number of equations in a jaxpr for metrics gathering, correlating trace latency with equation count. These updates improve code quality and observability.
    • pull/36200, pull/36189
  • x64 API updates: One pull request updates the x64 context test to replace removed experimental APIs with the new jax.enable_x64 API introduced in version 0.10.0, ensuring tests reflect current API usage.
    • pull/36167
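The batching pain point behind the jnp.nonzero work listed above stems from its data-dependent output shape; a common sketch (with an illustrative static size) pins the output size so nonzero composes with vmap:

```python
import jax
import jax.numpy as jnp

x = jnp.array([[0., 1., 0., 2.],
               [3., 0., 0., 4.]])

# nonzero's output shape depends on the data, so a static `size` must be
# supplied for it to work under vmap/jit.
idx = jax.vmap(lambda row: jnp.nonzero(row, size=2)[0])(x)
print(idx.tolist())  # [[1, 3], [0, 3]]
```

A dedicated primitive, as the PR proposes, would make this kind of composition with batching and autodiff less manual.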

3.3 Pull Request Discussion Insights

This section analyzes the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.


IV. Contributors

4.1 Contributors

Active Contributors:

We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.

If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.

Contributor      Commits   Pull Requests   Issues   Comments
jakevdp          40        11              0        2
superbobry       26        10              0        3
magaonka-amd     20        1               0        0
alekstheod       17        3               0        0
mattjj           14        6               0        0
gnecula          13        3               0        0
cj401-amd        12        2               0        0
gulsumgudukbay   12        1               0        0
hawkinsp         6         2               0        3
danielsuo        8         0               0        0

Don't miss what's next. Subscribe to Weekly Project News:
Powered by Buttondown, the easiest way to start and grow your newsletter.