Weekly GitHub Report for JAX: February 16, 2026 - February 23, 2026
Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.
Table of Contents
I. News
1.1 Recent Version Releases:
The current version of this repository is jax-v0.9.0.1
1.2 Version Information:
Released on February 3, 2026, JAX v0.9.0.1 is a patch update identical to v0.9.0 but includes fixes and improvements from four specific pull requests, enhancing stability and functionality without introducing major new features.
Click here to view the full release notes!
II. Issues
2.1 Top 5 Active Issues:
We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.
- [ENHANCEMENT] Add lax.associative_reduce (parallel tree reduction without downsweep): This issue proposes adding a new lax.associative_reduce primitive to perform parallel tree reduction without the downsweep phase, addressing inefficiencies in existing JAX primitives when reducing non-scalar elements like matrices. The goal is to reduce computational overhead and memory usage in operations such as multiplying chains of matrices, which is important for modern ML workloads, by implementing a more memory- and compute-efficient parallel reduction.
  - The comments express general support for the proposal but suggest exploring a generalization of the existing lax.reduce to handle non-scalar reductions by leveraging StableHLO capabilities; attempts to bypass the current scalar-only restrictions lead to XLA errors, and there is discussion about the complexity of supporting multiple reduction dimensions for associative but non-commutative operations.
  - Number of comments this week: 6
- [BUG] schur (CPU) segfaults when numpy and scipy are linked to MKL: This issue reports a segmentation fault occurring when using the schur function from jax.scipy.linalg on the CPU with numpy and scipy both linked to the Intel MKL library, which does not happen when they are linked to OpenBLAS. The user has identified the specific jax and jaxlib versions where the problem was introduced and provided detailed reproduction steps, including building numpy and scipy with MKL and running the code in a controlled environment.
  - The comments confirm the segfault also occurs with other linear algebra functions like sqrtm and on different platforms including Windows, with additional test failures reported; the issue is reproducible in conda environments despite the lack of official support for conda jax wheels.
  - Number of comments this week: 2
- [BUG] jax.numpy.linalg.multi_dot doesn't work as expected under vmap: This issue reports that the jax.numpy.linalg.multi_dot function does not behave as expected when used under vmap, specifically that it always performs the matrix multiplication in the order B@C first rather than A@B first as intended. The user suggests that adding a custom batching rule to multi_dot could fix this behavior and seeks feedback on whether this approach is desirable before submitting a pull request.
  - The comments agree that a custom batching rule is likely the best interim solution, with a longer-term fix involving making jnp.multi_dot a hijax primitive; however, since hijax is still experimental and may take time to fully support, there is openness to an interim fix if the hijax timeline exceeds three months.
  - Number of comments this week: 2
- [ENHANCEMENT] [TYPE:FEATURE] Proposal: Add example for implementing Principal Component Analysis (PCA) from scratch in JAX: This issue proposes adding an example that implements Principal Component Analysis (PCA) from scratch using JAX to demonstrate its capabilities in unsupervised learning, particularly focusing on SVD decomposition, vectorization, and JIT compilation. The goal is to provide a native JAX example that helps users, especially beginners, learn how to build custom machine learning primitives without relying on external libraries like scikit-learn.
  - The commenter expressed interest in self-assigning the issue and starting work on a draft pull request, aiming to submit it soon and inviting feedback or collaboration from others.
  - Number of comments this week: 1
Since there were fewer than 5 active issues this week, all of them are listed above.
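As context for the lax.associative_reduce proposal above, a tree-style matrix-chain reduction can already be approximated with jax.lax.associative_scan. The sketch below computes all prefix products and keeps the last one, so it still performs the downsweep work the proposal aims to eliminate; the helper name is our own, not from the issue:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Reduce a chain of matrices A[0] @ A[1] @ ... @ A[n-1].
# A sequential fold needs O(n) dependent matmuls; a tree reduction
# needs only O(log n) dependent steps. associative_scan uses a tree
# schedule to compute all prefix products, so the last element is a
# tree-style reduction (paying extra for the unused prefixes).

def tree_matmul_reduce(mats):
    prefix = jax.lax.associative_scan(jnp.matmul, mats)
    return prefix[-1]

mats = jnp.stack([jnp.eye(3) * 2.0 for _ in range(8)])  # eight copies of 2*I
result = tree_matmul_reduce(mats)
print(np.allclose(result, np.eye(3) * 2.0**8))  # True: (2I)^8 = 256*I
```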
2.2 Top 5 Stale Issues:
We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.
As of our latest update, there are no stale issues for the project this week.
2.3 Open Issues
This section lists, groups, and then summarizes issues that were created within the last week in the repository.
Issues Opened This Week: 4
Summarized Issues:
- Reduction and matrix operation inefficiencies: Several issues highlight inefficiencies and unexpected behaviors in JAX's matrix operations, such as the need for a new lax.associative_reduce primitive to optimize parallel tree reduction without downsweep, and the suboptimal multiplication order in jax.numpy.linalg.multi_dot under vmap. These problems affect computational overhead and memory usage, particularly in handling non-scalar elements like matrix chains, and suggest improvements through custom batching rules or new primitives.
- [issues/35118, issues/35308]
- Segmentation fault with MKL linkage: A segmentation fault occurs in the schur function from jax.scipy.linalg when both numpy and scipy are linked against Intel's MKL, reproducible across different environments and JAX versions starting from 0.4.38. This issue points to a critical stability problem related to MKL integration that impacts users relying on these libraries for linear algebra computations.
- [issues/35134]
- PCA example implementation request: There is a proposal to add a comprehensive example demonstrating PCA implementation from scratch using JAX, covering SVD decomposition, JIT compilation, vectorization, dataset loading, visualization, benchmarking, testing, and documentation. This example aims to assist users in learning unsupervised dimensionality reduction techniques within the JAX framework.
- [issues/35310]
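As a rough illustration of what the proposed PCA example (issues/35310) might cover, here is a minimal SVD-based PCA in JAX; the function name and structure are hypothetical, not taken from the eventual pull request:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="n_components")
def pca(x, n_components=2):
    # Center the data; the right singular vectors of the centered
    # matrix are the principal axes.
    x_centered = x - x.mean(axis=0)
    _, _, vt = jnp.linalg.svd(x_centered, full_matrices=False)
    components = vt[:n_components]          # (n_components, n_features)
    return x_centered @ components.T, components

data = jax.random.normal(jax.random.PRNGKey(0), (100, 5))
projected, components = pca(data, n_components=2)
print(projected.shape, components.shape)  # (100, 2) (2, 5)
```

Marking n_components static lets the slice size stay concrete under jit, which is the kind of JIT-compilation detail the proposed tutorial would presumably explain.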
2.4 Closed Issues
This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.
Issues Closed This Week: 5
Summarized Issues:
- vmap and custom type errors: An AttributeError occurs when using jax.vmap with a custom Box type because the 'BatchTrace' object lacks the 'cur_qdd' attribute, indicating an improper implementation of Box.get within vmap. This issue highlights the need for guidance on correctly handling custom types in vectorized mappings to avoid such attribute errors.
- issues/34758
- Memory and system configuration issues: LLVM reports out-of-memory errors during JAX compilation despite ample free system memory, which was resolved by increasing the Linux kernel parameter /proc/sys/vm/max_map_count. This suggests that system-level limits, rather than actual memory availability, can cause compilation failures in JAX.
- issues/35121
- Sharding and autodiff failures on singleton axes: Using explicit sharding combined with reverse-mode automatic differentiation fails on singleton axes, causing a ValueError due to mismatched cotangent types. This bug affects workflows like RMSNorm with batch size 1 and is present in JAX versions 0.8.2 and 0.9.0 but not in 0.8.1, indicating a regression.
- issues/35181
- Silent output corruption in vmap with gather and jnp.where: A silent output corruption bug occurs in vmap when combining gather and jnp.where over batched inputs of different sizes without returning the gathered intermediate. The root cause is an XLA buffer aliasing problem, and the issue can be worked around by returning the gathered intermediate to keep its buffer alive.
- issues/35252
- Device placement errors with boolean masking on CPU: Boolean masking on a CPU device unexpectedly produces a zero-sized array placed on a CUDA device instead of the selected CPU device. This is caused by incorrect device propagation logic in the lax.full_like function when handling empty arrays, leading to device placement inconsistencies.
- issues/35273
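For the compilation out-of-memory issue above (issues/35121), the reported remedy was raising the kernel's per-process memory-map limit. The appropriate value is workload-dependent, so the number below is only illustrative:

```shell
# Inspect the current per-process limit on memory-mapped regions.
cat /proc/sys/vm/max_map_count

# Raise it for the running system (illustrative value; requires root).
sudo sysctl -w vm.max_map_count=1048576

# Persist the setting across reboots.
echo 'vm.max_map_count=1048576' | sudo tee -a /etc/sysctl.conf
```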
2.5 Issue Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.
III. Pull Requests
3.1 Open Pull Requests
This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Opened This Week: 19
Key Open Pull Requests
1. Fix jax utests with rocm: This pull request fixes the JAX unit tests to work correctly with ROCm by incorporating changes from a previous PR, adding scripts for running Bazel tests on ROCm, migrating ROCm-specific repository changes back to the upstream JAX project, injecting ROCm plugin dependencies, and refining test execution and build configurations.
- URL: pull/35315
2. scipy.linalg: add hankel special matrix: This pull request adds the jax.scipy.linalg.hankel special matrix function along with corresponding tests, includes various fixes and improvements based on code reviews, and updates the documentation to incorporate the new function.
- URL: pull/35223
3. [ROCm] Implement Mosaic GPU detection and Auto-Skips: This pull request implements detection of Mosaic GPU tests and automatically skips them when running on ROCm by adding a pytest marker and classification logic based on test file paths, source usage signals, and default Pallas behavior, thereby improving the reliability of ROCm test runs without affecting non-ROCm environments.
- URL: pull/35288
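For readers unfamiliar with Hankel matrices, the new jax.scipy.linalg.hankel presumably mirrors SciPy's scipy.linalg.hankel (we have not inspected the PR's exact signature); the SciPy behavior it would match looks like this:

```python
import numpy as np
from scipy.linalg import hankel

# A Hankel matrix has constant anti-diagonals: it is built from a
# first column `c` and a last row `r`, with r[0] overridden by c[-1].
H = hankel([1, 2, 3], [3, 4, 5])
print(H)
# [[1 2 3]
#  [2 3 4]
#  [3 4 5]]
```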
Other Open Pull Requests
- ROCm Support and Integration: Multiple pull requests focus on enhancing ROCm support within JAX, including migrating changes from the rocm-jax repository to enable building ROCm plugins and PJRT wheels directly in the main repo, adding ROCm-specific test data for eigh export backwards compatibility tests, and improving the ROCm test workflow by generating structured pytest outputs and uploading logs to S3. These changes collectively improve ROCm platform compatibility and testing infrastructure.
- Triton Autotuning and Kernel Improvements: One pull request fixes crashes in the Triton autotuning process by correctly handling input-output buffer aliases and adds tests to ensure autotuning completes properly when aliased inputs remain live. Another introduces a multi-GPU splash attention implementation for the Mosaic GPU backend, demonstrating improved kernel efficiency over CUDNN with causal masks.
- Device Management and Multicast Support: Pull requests introduce a new flag to control device assignment order based on network topology and process locality, improving compatibility with updated XLA strategies, and add a pre-launch check for multicast support on Mosaic GPUs to improve error messaging and ensure test suite compatibility on systems without multicast capabilities.
- Bug Fixes and Trace Improvements: Several pull requests address bugs and improve tracing behavior, including fixing a gradient computation bug in the ref_swap function, ensuring the LinearizeTrace.tag attribute matches tangent_trace.tag to prevent undefined tags, and fixing a bug in the Pallas module to avoid collisions between large constant dimension values and dynamic shape placeholders.
- Hijax Extension Enhancements: Pull requests add vmap support for the hijax Box type to enable broadcasting as part of vmap on other argument axes and add a failing test demonstrating the lack of an MLIR translation rule for the call_hi_primitive with core.call_p, highlighting an unimplemented lowering path on CPU.
- Miscellaneous Improvements: Other pull requests include adding support for scalar arguments in as_torch_kernel using static_argnums, making small documentation improvements to the fault tolerant mcjax blog, and bundling additional .pyi files with jaxlib to fix errors in the Pyrefly component of Pallas.
3.2 Closed Pull Requests
This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Closed This Week: 83
Key Closed Pull Requests
1. Improve ROCm pytest results handling: This pull request aims to enhance the ROCm testing workflow by generating structured pytest result outputs, capturing run-manifest information, packaging test logs into an archive, and uploading these logs to a shared S3 storage for improved test result handling and accessibility.
- URL: pull/35282
- Associated Commits: e64e7, c467a, 3a75c, 89add, 87360, 63fb9, cbf4e, bb279, e5525, 68a2a, ea5c1, d8cdc, 62486, b0f83, 9ce70, e2a8a, 041fd, 3b185, 256b4, 874a2, a3c05, 6e5ed, 16020, 2e64c, 25a1a, 6232c, cedc8, 4309a, 1b2c2, 64257, 4c12f, 951ba, e4db3, 8a7c2, 5037c, 98f76, b4696, 73f60, b75bd, 4045c, 6413b, 1b1d6
2. [ROCm] rocm CI job with a job that executes the tests: This pull request proposes adjusting the ROCm continuous integration job to execute Bazel tests under remote build execution (RBE) with ROCm plugins as a dependency, although it was not merged.
- URL: pull/35190
3. Faster jnp.trapezoid when dx is a scalar: This pull request improves the performance of the jnp.trapezoid function in the JAX library by implementing a faster computation path when the dx parameter is a scalar, resulting in speedups that align its efficiency with that of jnp.sum * dx.
- URL: pull/34943
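The scalar-dx fast path described in the jnp.trapezoid pull request likely exploits the closed form of the trapezoid rule on a uniform grid; a conceptual sketch under that assumption (not the PR's actual code):

```python
import jax.numpy as jnp
import numpy as np

# On a uniform grid the trapezoid rule collapses to a weighted sum:
#   integral ≈ dx * (sum(y) - 0.5 * (y[0] + y[-1]))
# avoiding the pairwise-average formulation, so the scalar-dx path can
# run at roughly the cost of jnp.sum(y) * dx.

def trapezoid_scalar_dx(y, dx):
    return dx * (jnp.sum(y) - 0.5 * (y[0] + y[-1]))

y = jnp.sin(jnp.linspace(0.0, jnp.pi, 101))  # integral of sin on [0, pi] is 2
dx = jnp.pi / 100
fast = trapezoid_scalar_dx(y, dx)
print(np.allclose(fast, jnp.trapezoid(y, dx=dx), atol=1e-4))  # True
```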
Other Closed Pull Requests
- Bug fixes in array sharding and batching contexts: These pull requests fix critical bugs related to array sharding and batching in JAX. One addresses a bug where boolean masking with zero-sized dimensions dropped array sharding, while another fixes a crash caused by calling Box.get() within a jax.vmap context by modifying attribute access to avoid an AttributeError.
- [pull/35293, pull/35099]
- Documentation improvements: These pull requests enhance JAX documentation by adding warnings about batch invariance and floating-point non-associativity in dot_general, and by clarifying scalar requirements for lax.reduce and lax.reduce_window initial values with examples.
- [pull/35287, pull/35095]
- Testing and shape polymorphism fixes: These pull requests improve testing efficiency by skipping multi-device tests on single-device systems and fix shape polymorphism handling in the pallas_call_batching function to support symbolic batch dimensions.
- [pull/34884, pull/34988]
- Pre-commit and formatting improvements: This pull request introduces enhancements to pre-commit hooks, including running end-of-line and trailing-whitespace checks on C++ files and automatic formatting of BUILD files using buildifier.
- [pull/35072]
- Hijax module enhancements: These pull requests add new functionality and tests to the hijax module, including shmap-of-hitypes functions, tests for vmap and hijax primitives, and the cur_qdd attribute on the BatchTrace component.
- [pull/35083, pull/35136, pull/35137]
- Typing fixes in pyrefly component: These pull requests address additional typing fixes in the pyrefly component to improve code correctness and maintainability.
- [pull/35087, pull/35138]
- ROCm random number generation compatibility: This pull request adds the hip_threefry2x32_ffi function to the stable custom call targets list to ensure ROCm random number generation compatibility and enable passing of the ROCm threefry2x32 backward compatibility test.
- [pull/35115]
- Restoration of scan length check: This pull request restores a scan length check that was lost during a previous refactor, ensuring correctness in scan operations.
- [pull/35131]
- Fixes for LLVM integration inconsistencies: This pull request corrects inconsistencies introduced by a recent LLVM integration across multiple autodidax files by fixing the jupytext lint at the file head.
- [pull/35139]
- Automated GitHub Actions refactor: These pull requests propose an automated refactor of the project's GitHub Actions workflows to align with the latest internal standards (b/485167538), facilitating an upgrade process that may be force merged by the GHSS team if not accepted voluntarily.
- [pull/35141, pull/35142, pull/35143, pull/35144, pull/35145, pull/35146, pull/35147, pull/35148, pull/35149, pull/35150]
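As background for the lax.reduce documentation clarification summarized above, the initial value must be a scalar of the operand's dtype and the reduction computation must map a pair of scalars to a scalar; a minimal example:

```python
import jax.numpy as jnp
from jax import lax

# lax.reduce requires scalar init values and a (scalar, scalar) -> scalar
# computation; non-scalar init values are the restriction discussed in
# the associative_reduce proposal.
x = jnp.arange(6.0).reshape(2, 3)
total = lax.reduce(x, 0.0, lax.add, dimensions=(0, 1))
print(total)  # 15.0
```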
3.3 Pull Request Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.
IV. Contributors
4.1 Contributors
Active Contributors:
We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.
If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.
| Contributor | Commits | Pull Requests | Issues | Comments |
|---|---|---|---|---|
| jakevdp | 28 | 9 | 0 | 37 |
| alekstheod | 42 | 4 | 0 | 2 |
| benknutson-google | 44 | 0 | 0 | 0 |
| google-admin | 0 | 44 | 0 | 0 |
| Ashutosh0x | 17 | 14 | 0 | 5 |
| mattjj | 19 | 8 | 0 | 4 |
| Harshadev-24 | 7 | 1 | 0 | 20 |
| magaonka-amd | 24 | 0 | 0 | 0 |
| AratiGanesh | 18 | 3 | 0 | 0 |
| gulsumgudukbay | 18 | 2 | 0 | 0 |