Weekly Project News


Weekly GitHub Report for Jax: February 15, 2026 - February 22, 2026 (14:49:24)

Weekly GitHub Report for Jax

Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.


Table of Contents

  • I. News
    • 1.1. Recent Version Releases
    • 1.2. Version Information
  • II. Issues
    • 2.1. Top 5 Active Issues
    • 2.2. Top 5 Stale Issues
    • 2.3. Open Issues
    • 2.4. Closed Issues
    • 2.5. Issue Discussion Insights
  • III. Pull Requests
    • 3.1. Open Pull Requests
    • 3.2. Closed Pull Requests
    • 3.3. Pull Request Discussion Insights
  • IV. Contributors
    • 4.1. Contributors

I. News

1.1 Recent Version Releases:

The current version of this repository is jax-v0.9.0.1.

1.2 Version Information:

On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements or fixes without introducing major changes.

Click here to view the full release notes!

II. Issues

2.1 Top 5 Active Issues:

We consider active issues to be those that have been commented on most frequently within the last week. Bot comments are omitted.

  1. [BUG] Shard map return type hint: This issue addresses the lack of a return type hint for the shard_map function in JAX, which causes problems for type checkers like pyright when building packages. The user suggests adding a return type hint similar to those in pmap and vmap to improve type inference and compatibility with static type checkers.

    • The comments discuss attempts to fix the issue by adding overloads and type variables to shard_map, with contributors noting the difficulty of changing widely used APIs due to numerous downstream type errors; multiple pull requests are referenced, and the conversation highlights the complexity of maintaining type stability in a large ecosystem.
    • Number of comments this week: 6
  2. [ENHANCEMENT] Add lax.associative_reduce (parallel tree reduction without downsweep): This issue proposes adding a new lax.associative_reduce primitive to perform parallel tree reduction without the downsweep phase, addressing inefficiencies in existing JAX primitives when reducing non-scalar elements like matrices. The goal is to reduce computational overhead and memory usage in operations such as multiplying chains of matrices, which is important for modern ML workloads like linear RNNs and SSMs.

    • The comments express general support for the proposal but suggest considering generalizing the existing lax.reduce to handle non-scalar reductions by leveraging StableHLO capabilities; attempts to bypass current scalar-only restrictions lead to XLA errors, and there is discussion about the complexity of supporting multiple reduction dimensions for associative but non-commutative operations.
    • Number of comments this week: 6
  3. [BUG] schur (CPU) segfaults when numpy and scipy are linked to MKL: This issue reports a segmentation fault occurring when using the schur function from jax.scipy.linalg on the CPU with numpy and scipy both linked to the Intel MKL library, which does not happen when linked to OpenBLAS. The user has identified specific versions of jax and jaxlib where the problem was introduced and provided detailed reproduction steps, including building numpy and scipy with MKL and running the code in a controlled environment.

    • The comments confirm the segfault also occurs with other linear algebra functions like sqrtm and on different platforms including Windows, with additional test failures reported, indicating the issue is reproducible beyond the original environment and affects multiple jax versions when linked against MKL.
    • Number of comments this week: 2
  4. [BUG] [TYPE:BUG] Skip CUDA PJRT init when JAX_PLATFORMS=cpu: This issue reports a bug where importing JAX with CUDA PJRT plugins installed triggers CUDA initialization and error logs even when the environment is explicitly set to use only the CPU platform, which should prevent any CUDA-related checks. The expected behavior is that the CUDA plugin should not initialize or perform CUDA validation if CUDA or GPU platforms are not requested, thereby avoiding unnecessary error logs.

    • The single comment suggests a workaround by setting an environment variable to skip CUDA constraints checks, which removes the error but is noted as an unusual requirement.
    • Number of comments this week: 1

Since there were fewer than five active issues this week, all of them are listed above.
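The return-type-hint fix discussed in the first issue above can be sketched with a typed decorator. The names below (`shard_map_typed`, `add`) are hypothetical and the real `shard_map` signature differs; this only illustrates the typing pattern (annotating the returned callable, as `pmap` and `vmap` do) that lets checkers like pyright infer result types.

```python
# Illustrative sketch of the return-type-annotation pattern discussed in the
# shard_map issue above. `shard_map_typed` is a hypothetical stand-in, not
# the real JAX API; the point is that annotating the decorator's return type
# preserves the wrapped function's callable type for static checkers.
from typing import Any, Callable, TypeVar

F = TypeVar("F", bound=Callable[..., Any])

def shard_map_typed(f: F) -> F:
    """Return the wrapped function with its callable type preserved."""
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return f(*args, **kwargs)
    return wrapped  # type: ignore[return-value]

@shard_map_typed
def add(x: int, y: int) -> int:
    return x + y

print(add(2, 3))  # a checker now infers the result as int; prints 5
```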
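The `lax.associative_reduce` proposal above asks for a pairwise tree reduction without the downsweep phase of a scan. A minimal pure-Python sketch of that reduction shape (the name `tree_reduce` is a hypothetical stand-in, not a JAX API):

```python
# Minimal pure-Python sketch of the upsweep-only tree reduction that the
# proposed lax.associative_reduce would perform; `tree_reduce` is a
# hypothetical stand-in. Each pass combines adjacent pairs, so an
# associative (not necessarily commutative) op needs only O(log n)
# sequential steps instead of n - 1.
from functools import reduce

def tree_reduce(op, xs):
    xs = list(xs)
    while len(xs) > 1:
        # Combine adjacent pairs; carry a trailing odd element through.
        paired = [op(xs[i], xs[i + 1]) for i in range(0, len(xs) - 1, 2)]
        if len(xs) % 2:
            paired.append(xs[-1])
        xs = paired
    return xs[0]

# String concatenation is associative but not commutative, so this also
# checks that element order is preserved.
result = tree_reduce(lambda a, b: a + b, "abcdefg")
print(result)  # abcdefg
assert result == reduce(lambda a, b: a + b, "abcdefg")
```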

2.2 Top 5 Stale Issues:

We consider stale issues to be those that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.

As of our latest update, there are no stale issues for the project this week.

2.3 Open Issues

This section lists, groups, and then summarizes issues that were created within the last week in the repository.

Issues Opened This Week: 4

Summarized Issues:

  • Type Annotations and Compatibility: This topic covers issues related to missing or inconsistent type hints in JAX functions, which cause problems for type checkers like pyright during package builds. The discussion includes proposed fixes to align type annotations with similar functions to improve developer experience and code reliability.
  • issues/35101
  • Platform Initialization and Configuration: This topic addresses the problem where JAX initializes the CUDA PJRT plugin even when configured to use only the CPU platform, leading to unnecessary CUDA checks and error logs despite explicit disabling of CUDA. This causes confusion and inefficiency during environment setup.
  • issues/35105
  • Performance Optimization in Reduction Operations: This topic involves the proposal of a new lax.associative_reduce primitive to perform parallel tree reduction without the downsweep phase, targeting inefficiencies in current JAX reduction methods for non-scalar elements like matrix chains. The goal is to reduce computational overhead and memory usage in modern machine learning workloads.
  • issues/35118
  • Segmentation Faults with MKL-linked Libraries: This topic describes a segmentation fault occurring in the schur function from jax.scipy.linalg when both numpy and scipy are linked to the Intel MKL library. The issue is reproducible across different environments and versions starting from jax==0.4.38 and jaxlib==0.4.38, indicating a critical stability problem.
  • issues/35134
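The platform-initialization issue above centers on the `JAX_PLATFORMS` environment variable, which is a real JAX knob for restricting backend initialization. A minimal usage sketch (the `import jax` lines are commented out so the snippet stands alone without JAX installed):

```python
# JAX_PLATFORMS restricts which PJRT backends JAX initializes; the issue
# above reports that the CUDA plugin still runs its validation checks even
# in this CPU-only configuration. The variable must be set before JAX is
# imported; the surrounding script is just a minimal usage sketch.
import os

os.environ["JAX_PLATFORMS"] = "cpu"  # must be set before `import jax`

# import jax                  # with the variable set, jax.devices()
# print(jax.devices())        # should report only CPU devices

print(os.environ["JAX_PLATFORMS"])
```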

2.4 Closed Issues

This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.

Issues Closed This Week: 5

Summarized Issues:

  • vmap and batching issues: Multiple problems arise when using vmap in JAX, including an AttributeError related to a missing 'cur_qdd' attribute in the 'BatchTrace' object when using a custom Box type, and a silent output corruption bug caused by XLA buffer aliasing when combining gather and jnp.where over batched inputs of different sizes. These issues affect the correctness and stability of batched computations and require specific workarounds like returning gathered intermediates to avoid buffer aliasing.
  • issues/34758, issues/35252
  • Memory and system configuration errors: JAX compilation can fail with an LLVM out-of-memory error despite ample free system memory; in the reported case this was resolved by increasing the Linux kernel parameter /proc/sys/vm/max_map_count. This indicates that system-level limits on memory mappings can impact JAX's ability to compile large computations.
  • issues/35121
  • Sharding and differentiation bugs: Using explicit sharding combined with reverse-mode automatic differentiation on arrays with a singleton axis causes a ValueError due to mismatched cotangent types, affecting workflows such as RMSNorm with batch size 1. This bug appears in JAX versions 0.8.2 and 0.9.0 but not in 0.8.1, highlighting a regression in handling sharded arrays during differentiation.
  • issues/35181
  • Device placement and boolean masking issues: Boolean masking on CPU devices can incorrectly produce zero-sized arrays placed on CUDA devices instead of the selected CPU device, caused by faulty device propagation logic in the lax.full_like function when handling empty arrays. This leads to unexpected device placement behavior that can affect computations relying on explicit device control.
  • issues/35273
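The device-placement issue above involves boolean masks that select nothing, producing zero-sized arrays. NumPy has no device concept, so the snippet below only illustrates the shape behavior of that masking pattern; the JAX bug was that the empty result could land on the wrong device.

```python
# NumPy analogue of the masking pattern in the device-placement issue above:
# a boolean mask that matches no elements yields a zero-sized array. Plain
# NumPy has no devices, so this shows only the shape behavior, not the
# JAX-specific device-propagation bug.
import numpy as np

x = np.arange(4)
empty = x[x > 10]   # no element satisfies the mask

print(empty.shape)  # (0,)
print(empty.dtype)
```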
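For the memory-mapping limit mentioned in the out-of-memory item above, the kernel parameter can be inspected and raised with `sysctl`. The value shown is only an example; the issue does not specify what limit resolved the failure.

```shell
# Inspect and raise the kernel's memory-map limit that the LLVM OOM issue
# above ran into. 1048576 is an example value, not one taken from the issue;
# pick a limit appropriate for your workload.
cat /proc/sys/vm/max_map_count          # current limit
sudo sysctl -w vm.max_map_count=1048576 # takes effect immediately, not persistent
```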

2.5 Issue Discussion Insights

This section analyzes the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.


III. Pull Requests

3.1 Open Pull Requests

This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Opened This Week: 20

Key Open Pull Requests

1. scipy.linalg: add hankel special matrix: This pull request adds the jax.scipy.linalg.hankel special matrix function along with corresponding tests, includes several fixes and improvements based on code reviews, and updates the documentation to incorporate the new function.

  • URL: pull/35223
  • Associated Commits: eda5c, 8c8c2, 3d6bb, 3a68f, 70728, 2ddfb, c2f80

2. [ROCm] Implement Mosaic GPU detection and Auto-Skips: This pull request implements detection of Mosaic GPU tests and automatically skips them when running on ROCm by adding a pytest marker and classification logic based on test file paths, source usage signals, and default Mosaic behavior, thereby improving the reliability of ROCm test runs without affecting non-ROCm environments.

  • URL: pull/35288
  • Associated Commits: be9e3, 9cd64, 76239, 1f3a2, 15c1a, df0db

3. [ROCm] Set release rpaths to rocm so targets: This pull request introduces support for setting release wheel rpaths based on the rpath_type configuration from XLA, specifically adding ROCm release rpaths and implementing a "link_only" option that strips out all solib and custom rpaths while adding specific wheel rpaths to the released binaries.

  • URL: pull/35102
  • Associated Commits: 511e7, a35f5
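The new `jax.scipy.linalg.hankel` in the first key pull request mirrors SciPy's `scipy.linalg.hankel` semantics: entry (i, j) is `c[i + j]` along the anti-diagonals, continuing into `r` (which defaults to zeros) once `c` is exhausted. A pure-NumPy sketch of that construction, not the PR's actual code:

```python
# Pure-NumPy sketch of the Hankel-matrix semantics that the new
# jax.scipy.linalg.hankel mirrors from SciPy: H[i, j] = c[i + j] while
# i + j < len(c), then values continue from the tail of r (default zeros).
# Illustration only, not the PR's implementation.
import numpy as np

def hankel_sketch(c, r=None):
    c = np.asarray(c)
    r = np.zeros_like(c) if r is None else np.asarray(r)
    # Anti-diagonal values: the first column c, then the last row's tail.
    vals = np.concatenate([c, r[1:]])
    i, j = np.ogrid[:len(c), :len(r)]
    return vals[i + j]

print(hankel_sketch([1, 2, 3]))
# [[1 2 3]
#  [2 3 0]
#  [3 0 0]]
```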

Other Open Pull Requests

  • ROCm and GPU support enhancements: Multiple pull requests improve ROCm support and GPU functionality, including adding ROCm support to eigh export tests with platform-specific data, migrating rocm-jax changes upstream for building ROCm plugins, and introducing a multi-GPU splash attention implementation for Mosaic GPUs that outperforms CUDNN with causal masks. These changes enhance compatibility and performance across different GPU platforms and configurations.
  • pull/35111, pull/35251, pull/35114
  • Bug fixes and correctness improvements: Several pull requests address bugs and correctness issues, such as fixing crashes in Triton autotuning caused by improper buffer alias handling, correcting the gradient computation of the ref_swap function, fixing trigamma and digamma function computations for negative inputs, and resolving a bug in the Pallas backend related to shape polymorphism and dimension bounds. These fixes improve stability and mathematical correctness in JAX.
  • pull/35218, pull/35217, pull/35307, pull/35243
  • New features and API additions: New functionality includes the introduction of a jax.scipy.linalg.qr_multiply function and ormqr primitive for efficient least squares solutions, adding vmap support for the hijax Box type to enable broadcasting, and support for scalar arguments in as_torch_kernel via static_argnums. These additions expand JAX's capabilities and improve usability.
  • pull/35104, pull/35276, pull/35116
  • Tracing, tagging, and lowering improvements: Pull requests clarify and fix tracing behavior by ensuring LinearizeTrace.tag matches tangent_trace.tag, and add a failing test highlighting the missing MLIR translation rule for the call_hi_primitive in hijax on CPU. These changes improve internal consistency and expose unimplemented lowering paths.
  • pull/35219, pull/35214
  • Device and hardware configuration controls: A new flag jax_sort_devices_by_process_index is introduced to control device assignment order based on network topology and process locality, addressing compatibility with updated XLA strategies. Additionally, a pre-launch check for multicast support on Mosaic GPUs improves error messaging and test suite robustness on systems without multicast capabilities.
  • pull/35178, pull/35184
  • Testing and workflow improvements: Enhancements to the ROCm test workflow include generating structured pytest results, capturing run-manifest information, archiving test logs, and uploading them to a shared S3 location, improving test diagnostics and result accessibility.
  • pull/35283
  • Numerical precision improvements: The gradient precision of the jnp.sinc function near zero is improved by widening the Taylor series approximation region and replacing a constant helper with a 3-term Taylor expansion, fixing catastrophic cancellation and ensuring correct gradients for small inputs.
  • pull/35305
  • Documentation and blog improvements: Minor improvements to the fault tolerant mcjax blog include fixing a typo and adding a note about Pathways, enhancing clarity and accuracy of the documentation.
  • pull/35133
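The `jnp.sinc` gradient fix above targets catastrophic cancellation: the naive quotient-rule derivative of sinc(x) = sin(pi x)/(pi x) subtracts two nearly equal terms, while a Taylor expansion stays accurate near zero. A float32 demonstration of the numerical issue (illustration only, not JAX's actual code):

```python
# Float32 demonstration of the catastrophic cancellation addressed by the
# jnp.sinc gradient fix above. The naive quotient-rule derivative subtracts
# two nearly equal quantities; the Taylor form -pi^2 x/3 + pi^4 x^3/30 is
# stable for small x. Illustration only, not JAX's implementation.
import numpy as np

x = np.float32(1e-4)
pi = np.float32(np.pi)
t = pi * x

# Naive derivative: (t*pi*cos(t) - pi*sin(t)) / t**2 -- the numerator is a
# difference of terms that agree to ~7 decimal digits, so float32 rounding
# noise can dominate the true value.
naive = (t * pi * np.cos(t) - pi * np.sin(t)) / (t * t)

# Leading Taylor terms of the derivative, stable for small x.
taylor = -pi**2 * x / 3 + pi**4 * x**3 / 30

reference = -np.pi**2 * 1e-4 / 3  # float64 ground truth, to leading order
print(float(naive), float(taylor), reference)
```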

3.2 Closed Pull Requests

This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Closed This Week: 82

Key Closed Pull Requests

1. Improve ROCm pytest results handling: This pull request aims to enhance the ROCm testing workflow by generating structured pytest result outputs, capturing run-manifest information, packaging test logs into an archive, and uploading these logs to a shared S3 storage for improved test result handling and accessibility.

  • URL: pull/35282
  • Associated Commits: e64e7, c467a, 3a75c, 89add, 87360, 63fb9, cbf4e, bb279, e5525, 68a2a, ea5c1, d8cdc, 62486, b0f83, 9ce70, e2a8a, 041fd, 3b185, 256b4, 874a2, a3c05, 6e5ed, 16020, 2e64c, 25a1a, 6232c, cedc8, 4309a, 1b2c2, 64257, 4c12f, 951ba, e4db3, 8a7c2, 5037c, 98f76, b4696, 73f60, b75bd, 4045c, 6413b, 1b1d6

2. [ROCm] rocm CI job with a job that executes the tests: This pull request proposes adjusting the ROCm continuous integration job to execute Bazel tests under remote build execution (RBE) with ROCm plugins as a dependency, although it was not merged.

  • URL: pull/35190
  • Associated Commits: c2168, 66893, 72b93, 85b3f

3. Faster jnp.trapezoid when dx is a scalar: This pull request improves the performance of the jnp.trapezoid function in the JAX library by implementing a faster computation path when the dx parameter is a scalar, resulting in speedups that align its efficiency with that of jnp.sum * dx and optimizing the handling of broadcasting for dx.

  • URL: pull/34943
  • Associated Commits: 5323b, b7229, e985c
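The `jnp.trapezoid` speedup above exploits an identity: with uniform spacing, the trapezoidal rule collapses to a single sum plus an endpoint correction, avoiding the pairwise-average intermediate. A NumPy sketch of the identity (not the PR's implementation):

```python
# Why a scalar dx admits a faster path for jnp.trapezoid (PR above): with
# uniform spacing the trapezoidal rule equals one sum plus an endpoint
# correction. NumPy sketch of the identity, not the PR's implementation.
import numpy as np

y = np.linspace(0.0, 1.0, 101) ** 2   # sample f(x) = x^2 on [0, 1]
dx = 0.01

general = np.sum((y[1:] + y[:-1]) * 0.5) * dx    # textbook trapezoid rule
fast = dx * (np.sum(y) - 0.5 * (y[0] + y[-1]))   # scalar-dx shortcut

# The two forms are algebraically equal; only rounding differs.
print(abs(general - fast) < 1e-12)
```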

Other Closed Pull Requests

  • Bug fix for boolean masking and sharding in JAX: This pull request fixes a bug where boolean masking resulting in a zero-sized dimension incorrectly drops the array's sharding and device placement by passing an empty mesh sharding to jax.lax.full_like. The fix removes the abstract slice sharding with an empty mesh before calling full_like to ensure the array correctly inherits its original physical device placement.
    • pull/35293
  • Fix for crash in Box.get() within jax.vmap: This pull request resolves a crash caused by calling Box.get() inside a jax.vmap context by modifying the function to directly access the cur_qdd attribute on concrete Box instances. This change fixes the AttributeError related to the missing cur_qdd attribute on BatchTrace without altering existing behavior for other JAX transformations.
    • pull/35099
  • Documentation improvements for numerical behavior and function arguments: These pull requests add warnings about batch invariance in dot_general and the non-associativity of floating-point arithmetic, clarifying hardware-level implications and guidance for deterministic accuracy. Additionally, documentation was added to lax.reduce and lax.reduce_window clarifying that init_values and init_value must be scalars, including examples and error messages for non-scalar inputs.
    • pull/35287, pull/35095
  • ROCm support and backward compatibility testing: Multiple pull requests add ROCm-specific backward compatibility test data and methods for LU decomposition, enhance unit tests with ROCm device support and serialization formats, and add the hip_threefry2x32_ffi function to the stable custom call targets list to ensure ROCm random number generation compatibility. These changes ensure stability and compatibility of ROCm features across versions.
    • pull/34829, pull/34929, pull/35115
  • Testing improvements for device detection and hijax module: One pull request improves testing efficiency by skipping multi-device tests when only a single device is detected, facilitating easier testing on single-GPU systems. Another introduces tests for the vmap function, hijax types, and the hijax primitive within the hijax module, expanding test coverage.
    • pull/34884, pull/35136
  • Shape polymorphism fix in pallas_call_batching: This pull request fixes the handling of shape polymorphism in the pallas_call_batching function by correctly managing symbolic batch dimensions, which were previously unsupported. This ensures proper batching behavior with symbolic shapes.
    • pull/34988
  • Pre-commit hook and formatting enhancements: This pull request improves pre-commit hooks by adding end-of-line-fixer and trailing-whitespace checks on C++ source files and automatically formatting BUILD and BUILD.bazel files using buildifier, enhancing code style consistency.
    • pull/35072
  • Introduction of shmap-of-hitypes and HipSpec in hijax: This pull request adds the functions {,un}shard_aval and the HipSpec specification to the hijax module, introducing a basic shmap-of-hitypes to the codebase.
    • pull/35083
  • Typing fixes in pyrefly component: These pull requests address additional typing fixes in the pyrefly component to improve code correctness and maintainability.
    • pull/35087, pull/35138
  • Addition of cur_qdd attribute to BatchTrace in hijax: This pull request adds the cur_qdd attribute to the BatchTrace class in the hijax module, addressing a specific issue and enabling related functionality.
    • pull/35137
  • Fix for LLVM integration inconsistencies in autodidax files: This pull request fixes inconsistencies caused by LLVM integration between autodidax.py, autodidax.md, and autodidax.ipynb by correcting the jupytext lint at the head of these files.
    • pull/35139
  • Automated refactor of GitHub Actions workflows: Multiple pull requests propose automated refactors of the project's GitHub Actions to comply with the latest standards outlined in the internal directive b/485167538. These changes aim to upgrade workflow configurations and may be force merged by the GHSS team if not accepted voluntarily.
    • pull/35141, pull/35142, pull/35143, pull/35144, pull/35145, pull/35146, pull/35147, pull/35148
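The documentation PR above warns that floating-point addition is not associative, which is why reduction order (and hence batching or hardware tiling in `dot_general`) can change results at the last bit. A one-line demonstration using plain Python floats:

```python
# Floating-point addition is not associative, the fact behind the
# dot_general batch-invariance warning documented in the PR above:
# regrouping the same sum changes the rounded result.
a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c
right = a + (b + c)

print(left == right)   # False
print(left, right)     # 0.6000000000000001 0.6
```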

3.3 Pull Request Discussion Insights

This section analyzes the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.


IV. Contributors

4.1 Contributors

Active Contributors:

We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.

If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.

Contributor          Commits   Pull Requests   Issues   Comments
jakevdp                   29               9        0         36
benknutson-google         44               0        0          0
google-admin               0              44        0          0
alekstheod                34               3        0          4
Ashutosh0x                17              14        0          7
mattjj                    19               8        0          4
Harshadev-24               7               1        0         20
magaonka-amd              24               0        0          0
AratiGanesh               18               3        0          0
gulsumgudukbay            18               2        0          0

Access Last Week's Newsletter:

  • Link
Don't miss what's next. Subscribe to Weekly Project News:
Powered by Buttondown, the easiest way to start and grow your newsletter.