Weekly GitHub Report for JAX: February 15, 2026 - February 22, 2026
Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.
Table of Contents
I. News
1.1 Recent Version Releases:
The current version of this repository is jax-v0.9.0.1
1.2 Version Information:
On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements or fixes without introducing major changes.
Click here to view the full release notes!
II. Issues
2.1 Top 5 Active Issues:
We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.
- [BUG] Shard map return type hint: This issue addresses the lack of a return type hint for the shard_map function in JAX, which causes problems for type checkers like pyright when building packages. The user suggests adding a return type hint similar to those in pmap and vmap to improve type inference and compatibility with static type checkers.
- The comments discuss attempts to fix the issue by adding overloads and type variables to shard_map, with contributors noting the difficulty of changing widely used APIs due to numerous downstream type errors; multiple pull requests are referenced, and the conversation highlights the complexity of maintaining type stability in a large ecosystem.
- Number of comments this week: 6
- [ENHANCEMENT] Add lax.associative_reduce (parallel tree reduction without downsweep): This issue proposes adding a new lax.associative_reduce primitive to perform parallel tree reduction without the downsweep phase, addressing inefficiencies in existing JAX primitives when reducing non-scalar elements like matrices. The goal is to reduce computational overhead and memory usage in operations such as multiplying chains of matrices, which is important for modern ML workloads like linear RNNs and SSMs.
- The comments express general support for the proposal but suggest considering generalizing the existing lax.reduce to handle non-scalar reductions by leveraging StableHLO capabilities; attempts to bypass current scalar-only restrictions lead to XLA errors, and there is discussion about the complexity of supporting multiple reduction dimensions for associative but non-commutative operations.
- Number of comments this week: 6
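For context, the workaround available today is jax.lax.associative_scan, which computes all prefix results in parallel and therefore does include the downsweep work that the proposed lax.associative_reduce would skip; taking the last element of the scan yields the reduction. A small sketch for a matrix chain:

```python
import jax
import jax.numpy as jnp

def matmul_chain(mats):
    # mats has shape (n, d, d); matmul is associative but not
    # commutative, so an ordered parallel reduction is required.
    # associative_scan returns every prefix product; the final
    # element is the full chain product.
    prefix = jax.lax.associative_scan(jnp.matmul, mats)
    return prefix[-1]

mats = jnp.stack([2.0 * jnp.eye(2)] * 3)  # chain product is 8 * I
out = matmul_chain(mats)
```

The proposal's motivation is that for large per-element state (e.g. d x d matrices in linear RNNs/SSMs), materializing all prefixes wastes memory and compute when only the final product is needed.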
- [BUG] schur (CPU) segfaults when numpy and scipy are linked to MKL: This issue reports a segmentation fault occurring when using the schur function from jax.scipy.linalg on the CPU with numpy and scipy both linked to the Intel MKL library, which does not happen when linked to OpenBLAS. The user has identified specific versions of jax and jaxlib where the problem was introduced and provided detailed reproduction steps, including building numpy and scipy with MKL and running the code in a controlled environment.
- The comments confirm the segfault also occurs with other linear algebra functions like sqrtm and on different platforms including Windows, with additional test failures reported, indicating the issue is reproducible beyond the original environment and affects multiple jax versions when linked against MKL.
- Number of comments this week: 2
- [BUG] [TYPE:BUG] Skip CUDA PJRT init when JAX_PLATFORMS=cpu: This issue reports a bug where importing JAX with CUDA PJRT plugins installed triggers CUDA initialization and error logs even when the environment is explicitly set to use only the CPU platform, which should prevent any CUDA-related checks. The expected behavior is that the CUDA plugin should not initialize or perform CUDA validation if CUDA or GPU platforms are not requested, thereby avoiding unnecessary error logs.
- The single comment suggests a workaround by setting an environment variable to skip CUDA constraint checks, which removes the error but is noted as an unusual requirement.
- Number of comments this week: 1
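The configuration the issue describes can be sketched as follows: JAX_PLATFORMS must be set before jax is first imported, and with it set to "cpu" only CPU devices should be visible; the bug is that the CUDA plugin still runs its initialization checks and logs errors in this configuration.

```python
import os

# Must be set before the first `import jax`; restricts backend
# selection to CPU. Per the issue, this should also skip CUDA
# plugin validation, but currently does not.
os.environ["JAX_PLATFORMS"] = "cpu"

import jax

platforms = {d.platform for d in jax.devices()}  # {"cpu"}
```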
Since there were fewer than 5 open issues, all of the open issues have been listed above.
2.2 Top 5 Stale Issues:
We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.
As of our latest update, there are no stale issues for the project this week.
2.3 Open Issues
This section lists, groups, and then summarizes issues that were created within the last week in the repository.
Issues Opened This Week: 4
Summarized Issues:
- Type Annotations and Compatibility: This topic covers issues related to missing or inconsistent type hints in JAX functions, which cause problems for type checkers like pyright during package builds. The discussion includes proposed fixes to align type annotations with similar functions to improve developer experience and code reliability.
- issues/35101
- Platform Initialization and Configuration: This topic addresses the problem where JAX initializes the CUDA PJRT plugin even when configured to use only the CPU platform, leading to unnecessary CUDA checks and error logs despite explicit disabling of CUDA. This causes confusion and inefficiency during environment setup.
- issues/35105
- Performance Optimization in Reduction Operations: This topic involves the proposal of a new lax.associative_reduce primitive to perform parallel tree reduction without the downsweep phase, targeting inefficiencies in current JAX reduction methods for non-scalar elements like matrix chains. The goal is to reduce computational overhead and memory usage in modern machine learning workloads.
- issues/35118
- Segmentation Faults with MKL-linked Libraries: This topic describes a segmentation fault occurring in the schur function from jax.scipy.linalg when both numpy and scipy are linked to the Intel MKL library. The issue is reproducible across different environments and versions starting from jax==0.4.38 and jaxlib==0.4.38, indicating a critical stability problem.
- issues/35134
2.4 Closed Issues
This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.
Issues Closed This Week: 5
Summarized Issues:
- vmap and batching issues: Multiple problems arise when using vmap in JAX, including an AttributeError related to a missing 'cur_qdd' attribute in the 'BatchTrace' object when using a custom Box type, and a silent output corruption bug caused by XLA buffer aliasing when combining gather and jnp.where over batched inputs of different sizes. These issues affect the correctness and stability of batched computations and require specific workarounds like returning gathered intermediates to avoid buffer aliasing.
- issues/34758, issues/35252
- Memory and system configuration errors: JAX compilation can fail with an LLVM out-of-memory error despite ample free system memory, which was resolved by increasing the Linux kernel parameter /proc/sys/vm/max_map_count. This indicates that system-level limits on memory mappings can impact JAX's ability to compile large computations.
- issues/35121
- Sharding and differentiation bugs: Using explicit sharding combined with reverse-mode automatic differentiation on arrays with a singleton axis causes a ValueError due to mismatched cotangent types, affecting workflows such as RMSNorm with batch size 1. This bug appears in JAX versions 0.8.2 and 0.9.0 but not in 0.8.1, highlighting a regression in handling sharded arrays during differentiation.
- issues/35181
- Device placement and boolean masking issues: Boolean masking on CPU devices can incorrectly produce zero-sized arrays placed on CUDA devices instead of the selected CPU device, caused by faulty device propagation logic in the lax.full_like function when handling empty arrays. This leads to unexpected device placement behavior that can affect computations relying on explicit device control.
- issues/35273
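The failure mode above can be illustrated with a mask that selects no elements; the result is a zero-sized array, and the report is that its device can silently differ from the input's when CUDA devices are present (on a CPU-only install the devices match). A sketch:

```python
import jax
import jax.numpy as jnp

# Place the input explicitly on a CPU device, then produce a
# zero-sized array via boolean masking. Per the issue, the empty
# result could end up on a CUDA device instead of the chosen CPU.
cpu = jax.devices("cpu")[0]
x = jax.device_put(jnp.arange(3.0), cpu)
empty = x[x > 10.0]  # no elements selected -> shape (0,)
same_device = empty.devices() == x.devices()  # expected: True
```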
2.5 Issue Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.
III. Pull Requests
3.1 Open Pull Requests
This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Opened This Week: 20
Key Open Pull Requests
1. scipy.linalg: add hankel special matrix: This pull request adds the jax.scipy.linalg.hankel special matrix function along with corresponding tests, includes several fixes and improvements based on code reviews, and updates the documentation to incorporate the new function.
- URL: pull/35223
2. [ROCm] Implement Mosaic GPU detection and Auto-Skips: This pull request implements detection of Mosaic GPU tests and automatically skips them when running on ROCm by adding a pytest marker and classification logic based on test file paths, source usage signals, and default Mosaic behavior, thereby improving the reliability of ROCm test runs without affecting non-ROCm environments.
- URL: pull/35288
3. [ROCm] Set release rpaths to rocm so targets: This pull request introduces support for setting release wheel rpaths based on the rpath_type configuration from XLA, specifically adding ROCm release rpaths and implementing a "link_only" option that strips out all solib and custom rpaths while adding specific wheel rpaths to the released binaries.
- URL: pull/35102
Other Open Pull Requests
- ROCm and GPU support enhancements: Multiple pull requests improve ROCm support and GPU functionality, including adding ROCm support to eigh export tests with platform-specific data, migrating rocm-jax changes upstream for building ROCm plugins, and introducing a multi-GPU splash attention implementation for Mosaic GPUs that outperforms CUDNN with causal masks. These changes enhance compatibility and performance across different GPU platforms and configurations.
- pull/35111, pull/35251, pull/35114
- Bug fixes and correctness improvements: Several pull requests address bugs and correctness issues, such as fixing crashes in Triton autotuning caused by improper buffer alias handling, correcting the gradient computation of the ref_swap function, fixing trigamma and digamma function computations for negative inputs, and resolving a bug in the Pallas backend related to shape polymorphism and dimension bounds. These fixes improve stability and mathematical correctness in JAX.
- pull/35218, pull/35217, pull/35307, pull/35243
- New features and API additions: New functionality includes the introduction of a jax.scipy.linalg.qr_multiply function and ormqr primitive for efficient least squares solutions, adding vmap support for the hijax Box type to enable broadcasting, and support for scalar arguments in as_torch_kernel via static_argnums. These additions expand JAX's capabilities and improve usability.
- pull/35104, pull/35276, pull/35116
- Tracing, tagging, and lowering improvements: Pull requests clarify and fix tracing behavior by ensuring LinearizeTrace.tag matches tangent_trace.tag, and add a failing test highlighting the missing MLIR translation rule for the call_hi_primitive in hijax on CPU. These changes improve internal consistency and expose unimplemented lowering paths.
- pull/35219, pull/35214
- Device and hardware configuration controls: A new flag jax_sort_devices_by_process_index is introduced to control device assignment order based on network topology and process locality, addressing compatibility with updated XLA strategies. Additionally, a pre-launch check for multicast support on Mosaic GPUs improves error messaging and test suite robustness on systems without multicast capabilities.
- pull/35178, pull/35184
- Testing and workflow improvements: Enhancements to the ROCm test workflow include generating structured pytest results, capturing run-manifest information, archiving test logs, and uploading them to a shared S3 location, improving test diagnostics and result accessibility.
- pull/35283
- Numerical precision improvements: The gradient precision of the jnp.sinc function near zero is improved by widening the Taylor series approximation region and replacing a constant helper with a 3-term Taylor expansion, fixing catastrophic cancellation and ensuring correct gradients for small inputs.
- pull/35305
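The behavior this fix targets can be checked directly: jnp.sinc is smooth and even, so its derivative at zero is exactly 0 and should remain finite and near zero for tiny inputs rather than blowing up from catastrophic cancellation.

```python
import jax
import jax.numpy as jnp

# Gradient of the normalized sinc function near zero; the fix widens
# the Taylor-series region so these values stay well-behaved.
grad_sinc = jax.grad(jnp.sinc)
g_zero = grad_sinc(0.0)    # exactly 0 for an even, smooth function
g_small = grad_sinc(1e-7)  # should be tiny, not garbage
```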
- Documentation and blog improvements: Minor improvements to the fault tolerant mcjax blog include fixing a typo and adding a note about Pathways, enhancing clarity and accuracy of the documentation.
- pull/35133
3.2 Closed Pull Requests
This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Closed This Week: 82
Key Closed Pull Requests
1. Improve ROCm pytest results handling: This pull request aims to enhance the ROCm testing workflow by generating structured pytest result outputs, capturing run-manifest information, packaging test logs into an archive, and uploading these logs to a shared S3 storage for improved test result handling and accessibility.
- URL: pull/35282
- Associated Commits: e64e7, c467a, 3a75c, 89add, 87360, 63fb9, cbf4e, bb279, e5525, 68a2a, ea5c1, d8cdc, 62486, b0f83, 9ce70, e2a8a, 041fd, 3b185, 256b4, 874a2, a3c05, 6e5ed, 16020, 2e64c, 25a1a, 6232c, cedc8, 4309a, 1b2c2, 64257, 4c12f, 951ba, e4db3, 8a7c2, 5037c, 98f76, b4696, 73f60, b75bd, 4045c, 6413b, 1b1d6
2. [ROCm] rocm CI job with a job that executes the tests: This pull request proposes adjusting the ROCm continuous integration job to execute Bazel tests under remote build execution (RBE) with ROCm plugins as a dependency, although it was not merged.
- URL: pull/35190
3. Faster jnp.trapezoid when dx is a scalar: This pull request improves the performance of the jnp.trapezoid function in the JAX library by implementing a faster computation path when the dx parameter is a scalar, resulting in speedups that align its efficiency with that of jnp.sum * dx and optimizing the handling of broadcasting for dx.
- URL: pull/34943
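A usage sketch of the optimized path: when dx is a Python scalar, jnp.trapezoid applies uniform spacing without materializing a broadcast spacing array, which is where the speedup comes from.

```python
import jax.numpy as jnp

# Trapezoidal rule with uniform spacing given as a scalar dx.
# Interval averages are 0.5, 1.5, 2.5; their sum times dx is 2.25.
y = jnp.array([0.0, 1.0, 2.0, 3.0])
area = jnp.trapezoid(y, dx=0.5)  # = 2.25
```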
Other Closed Pull Requests
- Bug fix for boolean masking and sharding in JAX: This pull request fixes a bug where boolean masking resulting in a zero-sized dimension incorrectly drops the array's sharding and device placement by passing an empty mesh sharding to jax.lax.full_like. The fix removes the abstract slice sharding with an empty mesh before calling full_like to ensure the array correctly inherits its original physical device placement.
- Fix for crash in Box.get() within jax.vmap: This pull request resolves a crash caused by calling Box.get() inside a jax.vmap context by modifying the function to directly access the cur_qdd attribute on concrete Box instances. This change fixes the AttributeError related to the missing cur_qdd attribute on BatchTrace without altering existing behavior for other JAX transformations.
- Documentation improvements for numerical behavior and function arguments: These pull requests add warnings about batch invariance in dot_general and the non-associativity of floating-point arithmetic, clarifying hardware-level implications and guidance for deterministic accuracy. Additionally, documentation was added to lax.reduce and lax.reduce_window clarifying that init_values and init_value must be scalars, including examples and error messages for non-scalar inputs.
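The documented contract for lax.reduce can be shown with a small example: the init value and the computation both operate on scalars, even though the operand is an array.

```python
import jax
import jax.numpy as jnp

# lax.reduce takes a *scalar* init value and a computation over
# scalars; the reduction is applied along the listed dimensions.
x = jnp.arange(6.0).reshape(2, 3)
row_sums = jax.lax.reduce(x, 0.0, jax.lax.add, dimensions=(1,))
# row_sums == [3., 12.]
```

Passing a non-scalar init value here is exactly the case the new documentation calls out as an error.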
- ROCm support and backward compatibility testing: Multiple pull requests add ROCm-specific backward compatibility test data and methods for LU decomposition, enhance unit tests with ROCm device support and serialization formats, and add the hip_threefry2x32_ffi function to the stable custom call targets list to ensure ROCm random number generation compatibility. These changes ensure stability and compatibility of ROCm features across versions.
- Testing improvements for device detection and hijax module: One pull request improves testing efficiency by skipping multi-device tests when only a single device is detected, facilitating easier testing on single-GPU systems. Another introduces tests for the vmap function, hijax types, and the hijax primitive within the hijax module, expanding test coverage.
- Shape polymorphism fix in pallas_call_batching: This pull request fixes the handling of shape polymorphism in the pallas_call_batching function by correctly managing symbolic batch dimensions, which were previously unsupported. This ensures proper batching behavior with symbolic shapes.
- Pre-commit hook and formatting enhancements: This pull request improves pre-commit hooks by adding end-of-line-fixer and trailing-whitespace checks on C++ source files and automatically formatting BUILD and BUILD.bazel files using buildifier, enhancing code style consistency.
- Introduction of shmap-of-hitypes and HipSpec in hijax: This pull request adds the functions {,un}shard_aval and the HipSpec specification to the hijax module, introducing a basic shmap-of-hitypes to the codebase.
- Typing fixes in pyrefly component: These pull requests address additional typing fixes in the pyrefly component to improve code correctness and maintainability.
- Addition of cur_qdd attribute to BatchTrace in hijax: This pull request adds the cur_qdd attribute to the BatchTrace class in the hijax module, addressing a specific issue and enabling related functionality.
- Fix for LLVM integration inconsistencies in autodidax files: This pull request fixes inconsistencies caused by LLVM integration between autodidax.py, autodidax.md, and autodidax.ipynb by correcting the jupytext lint at the head of these files.
- Automated refactor of GitHub Actions workflows: Multiple pull requests propose automated refactors of the project's GitHub Actions to comply with the latest standards outlined in the internal directive b/485167538. These changes aim to upgrade workflow configurations and may be force merged by the GHSS team if not accepted voluntarily.
3.3 Pull Request Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.
IV. Contributors
4.1 Contributors
Active Contributors:
We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.
If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.
| Contributor | Commits | Pull Requests | Issues | Comments |
|---|---|---|---|---|
| jakevdp | 29 | 9 | 0 | 36 |
| benknutson-google | 44 | 0 | 0 | 0 |
| google-admin | 0 | 44 | 0 | 0 |
| alekstheod | 34 | 3 | 0 | 4 |
| Ashutosh0x | 17 | 14 | 0 | 7 |
| mattjj | 19 | 8 | 0 | 4 |
| Harshadev-24 | 7 | 1 | 0 | 20 |
| magaonka-amd | 24 | 0 | 0 | 0 |
| AratiGanesh | 18 | 3 | 0 | 0 |
| gulsumgudukbay | 18 | 2 | 0 | 0 |
Access Last Week's Newsletter: