Weekly GitHub Report for JAX: February 08, 2026 - February 15, 2026 (15:14:24)
Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.
Table of Contents
I. News
1.1 Recent Version Releases:
The current version of this repository is jax-v0.9.0.1
1.2 Version Information:
On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements without introducing major changes. This release highlights a focus on incremental fixes and refinements.
Click here to view the full release notes!
II. Issues
2.1 Top 5 Active Issues:
We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.
- [ENHANCEMENT] Faster jnp.trapezoid when dx is a scalar: This issue proposes optimizing the jnp.trapezoid function for the case when dx is a scalar by using a more efficient computation method that leverages the uniform grid property, resulting in significant performance improvements. The discussion includes benchmarking results demonstrating the speedup and a suggestion to generalize the optimization for N-dimensional cases, with the original poster preparing a pull request incorporating the feedback. The comments include benchmarking evidence supporting the optimization, positive feedback encouraging a pull request, a question about extending the optimization to N-dimensional cases, and updates indicating the pull request is ready after addressing review comments.
  - Number of comments this week: 5
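For context, the uniform-grid shortcut discussed in the issue can be sketched as follows; the helper name and implementation here are illustrative, not the actual patch:

```python
import jax.numpy as jnp

def trapezoid_scalar_dx(y, dx=1.0, axis=-1):
    # Illustrative fast path for scalar dx: on a uniform grid the trapezoid
    # rule reduces to a plain sum with the two endpoints weighted by 1/2,
    # avoiding the pairwise averaging needed for arbitrary spacing.
    y = jnp.moveaxis(y, axis, -1)
    return dx * (0.5 * y[..., 0] + y[..., 1:-1].sum(axis=-1) + 0.5 * y[..., -1])

y = jnp.array([1.0, 2.0, 3.0, 4.0])
fast = trapezoid_scalar_dx(y, dx=0.5)
ref = jnp.trapezoid(y, dx=0.5)  # general implementation, for comparison
```

Both paths integrate the same samples; the fast path simply exploits that every interior point appears in exactly two trapezoids.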
- [ENHANCEMENT] Clarify differences in methods for moving data to host in documentation: This issue requests clarification in the documentation regarding the differences between three methods for moving data from a GPU to the host in JAX: jax.device_put with a CPU device, jax.copy_to_host_async, and jax.device_get. The user seeks to understand the distinctions in blocking behavior and use cases for these methods, particularly why multiple approaches exist and how they differ in terms of synchronization and performance implications. The first comment provides a detailed explanation distinguishing the three APIs by their intended use cases and blocking behavior, suggesting clearer documentation and a comparison table; the second comment asks for further clarification on the subtle differences between non-blocking asynchronous copies and lazy data movement, indicating ongoing discussion to refine understanding.
  - Number of comments this week: 2
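As a rough sketch of the three APIs being compared (behavior as summarized in the issue thread; the requested documentation would pin down the exact semantics):

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(4.0)  # lives on the default device

# jax.device_get: copies to host, blocking until the value is ready,
# and returns a NumPy array.
host = jax.device_get(x)

# jax.device_put with a CPU device: returns a jax.Array committed to
# the CPU backend, dispatched like any other device transfer.
on_cpu = jax.device_put(x, jax.devices("cpu")[0])

# jax.Array.copy_to_host_async: starts a non-blocking device-to-host
# copy so that a later device_get has less (or nothing) to wait for.
x.copy_to_host_async()
host2 = jax.device_get(x)
```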
- [QUESTION] Nested jit changes behavior of static arguments: This issue describes a problem where nested JIT compilation in JAX causes static arguments to be treated as traced values, leading to a ConcretizationTypeError when a static argument is passed through an inner JIT-compiled function. The user is confused because they expected the static-versus-traced determination to happen only at the outermost JIT boundary, but the inner JIT compilation appears to change the behavior of static arguments. The first comment explains that this behavior is expected because the output of JIT-compiled functions cannot be static, and suggests not JIT-compiling the inner function if static behavior is needed; the second comment is from a newcomer expressing interest in contributing to the issue.
  - Number of comments this week: 2
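The reported behavior can be reproduced with a small sketch like the following (assuming a recent JAX release, where coercing a tracer to bool raises a ConcretizationTypeError or a subclass of it):

```python
from functools import partial
import jax
import jax.numpy as jnp

def body(x, flag):
    # Python-level `if` needs `flag` to be a concrete (static) value.
    return x + 1 if flag else x - 1

# One jit boundary with static_argnums: works, `flag` stays static.
single = jax.jit(body, static_argnums=1)
ok = single(jnp.float32(0.0), True)

# Nested jit: the inner jit has no static_argnums, so `flag` becomes a
# traced value inside it, even though it is static to `outer`.
inner = jax.jit(body)

@partial(jax.jit, static_argnums=1)
def outer(x, flag):
    return inner(x, flag)

raised = False
try:
    outer(jnp.float32(0.0), True)
except (jax.errors.ConcretizationTypeError,
        jax.errors.TracerBoolConversionError):
    raised = True  # static-ness did not survive the inner jit boundary
```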
- [BUG] jax.lax.scan crashes on GPU when jax_default_matmul_precision is set to any non-default value: This issue reports that setting the jax_default_matmul_precision configuration to any non-default value causes the jax.lax.scan function to crash on GPU with an internal error, despite the same operations working fine outside of scan and scan working fine without the precision configuration. The problem occurs specifically on a system using JAX version 0.9.0 with an NVIDIA L4 GPU and CUDA 13.0, resulting in an unsupported parameter error during execution. A contributor has submitted a pull request that resolves the issue and is awaiting review or feedback before it can be merged and the issue closed.
  - Number of comments this week: 1
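A minimal reproducer in the spirit of the report might look like this; it runs fine on CPU, and per the issue the internal error only surfaces on GPU:

```python
import jax
import jax.numpy as jnp

# Any non-default precision reportedly triggers the GPU crash inside scan.
jax.config.update("jax_default_matmul_precision", "highest")

def step(carry, _):
    return carry @ carry, None  # a matmul inside the scan body

init = jnp.eye(4)
final, _ = jax.lax.scan(step, init, xs=None, length=3)
```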
- [BUG] jax parallelism and tensorflow: This issue describes a runtime error encountered when attempting to use JAX's parallelism features alongside TensorFlow, specifically resulting in an NCCL operation failure during a replicated computation on multiple NVIDIA A100 GPUs. The user provides a minimal reproducible example and detailed system information, noting that a previously suggested fix from a related issue does not resolve the problem. A single comment requests the user to rerun the code with the environment variable NCCL_DEBUG=WARN enabled to gather more detailed logs for diagnosing the NCCL error.
  - Number of comments this week: 1
2.2 Top 5 Stale Issues:
We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.
As of our latest update, there are no stale issues for the project this week.
2.3 Open Issues
This section lists, groups, and then summarizes issues that were created within the last week in the repository.
Issues Opened This Week: 5
Summarized Issues:
- Performance optimization in numerical functions: This issue proposes optimizing the jnp.trapezoid function to compute more efficiently when the dx parameter is a scalar by leveraging a simplified cumulative trapezoid rule. This optimization reduces computation time significantly on both CPU and GPU. - issues/34915
- GPU runtime errors and configuration issues: Several issues describe runtime errors occurring on GPU related to configuration and backend compatibility. One issue involves a crash in jax.lax.scan when jax_default_matmul_precision is set to a non-default value, causing an internal error, while another issue reports an NCCL operation failure during replicated computation on multiple NVIDIA A100 GPUs using TensorFlow as the backend. - issues/34917, issues/34918
- Documentation clarity for data transfer methods: There is a request for clearer documentation explaining the differences, use cases, and blocking behavior of three methods for moving data from GPU to host in JAX: jax.device_put with a CPU device, jax.copy_to_host_async, and jax.device_get. This aims to help users understand when and why to use each method effectively. - issues/35054
- Static argument handling in nested JIT compilation: An issue describes a problem where nesting JIT-compiled functions causes static arguments to be treated as traced values, resulting in a ConcretizationTypeError when a static argument is expected inside a nested JIT context. This highlights challenges in managing static arguments within nested JITs. - issues/35090
2.4 Closed Issues
This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.
Issues Closed This Week: 7
Summarized Issues:
- Error handling and validation issues: Several issues describe crashes and errors caused by improper validation or handling of error codes and indexing. These include an IndexError in raise_if_error() due to out-of-bounds error codes during serialization in AOT compilation, and unintuitive shard_map errors caused by out-of-bounds indexing behaving inconsistently across devices. - issues/34370, issues/35013
- Duplicate request checks and user guidance: Multiple issues request implementing checks for duplicate requests and emphasize the need for users to provide clear goals and motivating examples or code snippets. These requests aim to improve issue triaging and clarity for the ModelOdi feature and other user inquiries. - issues/34928, issues/34961
- Compilation and backend crashes: There are reports of compilation failures and backend crashes related to specific operations and hardware. These include a Mosaic compiler failure to legalize a func.func operation when loading a JAX function calling a Pallas TPU kernel, and a crash in the MPS backend caused by unsupported gather patterns generated by Flax's scan transform during parameter broadcasting. - issues/34936, issues/35063
- Serialization compatibility issues: One issue highlights a serialization problem where pickling a JAX PRNGKeyArray fails with a TypeError on Python 3.11 but works on Python 3.13, indicating a version-specific compatibility problem with the standard pickle module. - issues/35065
2.5 Issue Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.
III. Pull Requests
3.1 Open Pull Requests
This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Opened This Week: 19
Key Open Pull Requests
1. [ROCm] fix the performance issue when n=1 or 2: This pull request addresses a performance issue in the slogdet function on ROCm (and CUDA) platforms when the input matrix size n is 1 or 2 by ensuring that jnp.linalg.slogdet(x) uses an analytic path exclusively, with gradients computed via a custom JVP to avoid invoking slower backend algorithms.
- URL: pull/34971
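For intuition, the 2x2 analytic path amounts to the closed-form determinant ad - bc; the sketch below is illustrative and not the pull request's actual code:

```python
import jax.numpy as jnp

def slogdet_2x2(a):
    # Closed-form determinant of a 2x2 matrix: ad - bc. For n <= 2 this
    # avoids dispatching to a factorization-based backend routine entirely.
    det = a[0, 0] * a[1, 1] - a[0, 1] * a[1, 0]
    return jnp.sign(det), jnp.log(jnp.abs(det))

m = jnp.array([[3.0, 1.0], [2.0, 4.0]])
sign, logabs = slogdet_2x2(m)
ref_sign, ref_logabs = jnp.linalg.slogdet(m)  # library result, for comparison
```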
2. Faster jnp.trapezoid when dx is a scalar: This pull request improves the performance of the jnp.trapezoid function in the JAX library by implementing a faster computation path when the dx parameter is a scalar, resulting in speedups comparable to jnp.sum * dx and optimizing the handling of broadcasting and dimensionality for dx.
- URL: pull/34943
3. fixed scan crash on GPU with non-default matmul precision: This pull request fixes a runtime crash in JAX's GPU backend caused by non-default matrix multiplication precision settings inside jax.lax.scan by normalizing precision attributes during lowering specifically for GPU, ensuring compatibility without altering CPU or TPU behavior.
- URL: pull/35091
Other Open Pull Requests
- Bug fixes in tensor operations and lowering: Several pull requests address critical bugs including a crash in the cuDNN attention vmap batcher caused by incorrect reshaping of broadcasted bias or mask tensors, and an infinite recursion issue during lowering when converting boolean arrays to int32 in the pallas experimental module. These fixes improve stability and correctness in tensor handling and compilation processes.
- [pull/34920, pull/35083]
- Documentation improvements: Multiple pull requests enhance documentation by adding detailed guides and clarifications. This includes a new guide for Pallas SparseCore, expanded documentation on host CPU allocation impact on multi-GPU collective communication with Slurm examples, and clarifications on scalar requirements for lax.reduce and lax.reduce_window initial values with examples and error messages. - [pull/34962, pull/35015, pull/35092]
- New features and API additions: New functionalities are introduced, such as five special matrix construction functions added to jax.scipy.linalg with JIT compatibility and thorough testing, support for numeric PartitionSpec in shard_map with positional indices and wildcard support, and a basic shmap-of-hitypes implementation in the hijax module. These additions expand the library's capabilities and usability. - [pull/35030, pull/35075, pull/35072]
- Code refactoring and testing enhancements: Refactoring efforts include moving core indexing functionality to a shared file to enable reuse and improvements in pre-commit hooks that add formatting and whitespace checks for C++ and Bazel files. Additionally, unit tests are extended to support ROCm devices by verifying memory space serialization with ROCm-specific data.
- [pull/34929, pull/35087, pull/35072]
- Performance and optimization proposals: A new dead code elimination (DCE) rule for while_loop constructs is proposed to optimize code execution. This aims to improve runtime efficiency by removing unnecessary code in loops.
- [pull/34945]
- Bug fixes in shape polymorphism and initializers: Fixes include correct handling of symbolic batch dimensions in pallas_call_batching for shape polymorphism and a fix for a division-by-zero error in the variance_scaling initializer to support zero-dimension input shapes. These changes enhance robustness in shape handling and initialization. - [pull/34988, pull/35096]
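For reference, variance_scaling lives in jax.nn.initializers; a minimal usage sketch (the zero-dimension fan case is the edge the fix hardens):

```python
import jax
import jax.numpy as jnp

# variance_scaling scales the sampling distribution by the fan-in or
# fan-out of the weight shape; a zero-sized fan previously divided by
# zero according to the PR description.
init = jax.nn.initializers.variance_scaling(
    scale=1.0, mode="fan_in", distribution="truncated_normal")
w = init(jax.random.key(0), (3, 4), jnp.float32)
```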
- Typing fixes: Additional typing corrections are applied to the pyrefly component to improve type safety and code quality.
- [pull/35086]
3.2 Closed Pull Requests
This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Closed This Week: 72
Key Closed Pull Requests
1. [ROCm] Added support for GESVDJ on ROCm devices: This pull request adds support for the GESVDJ FFI call on ROCm devices by defining the hipSolver version and helper functions, refactoring existing handlers to be vendor-independent outside CUDA guards, and fixing type usage, thereby enabling previously failing unit tests for SVD on ROCm to pass.
- URL: pull/34474
2. [ROCm] Enable cuda array interface test on ROCm: This pull request enables the cuda_array_interface test to run on ROCm devices by adding ROCm device support in the relevant buffer interface, updating the test decorator to include both CUDA and ROCm GPUs, and correcting the skip condition to properly check the jaxlib extension version.
- URL: pull/34501
3. [ROCm] Enabled RNN unit test "test_no_workspace_overflow" for ROCm.: This pull request enables the previously disabled RNN unit test test_no_workspace_overflow to run on ROCm devices by modifying its test decorator, confirming that the test passes successfully on these devices.
- URL: pull/34577
Other Closed Pull Requests
- ROCm Platform Support and Test Enablement: Multiple pull requests focus on enabling tests and operations on the ROCm platform that were previously skipped or limited to CUDA devices. These changes include modifying skip decorators, adding ROCm support checks, and fixing platform-specific issues to ensure tests such as SVD, ToeplitzSymmetricConstruction, reduce_window, and LOBPCG eigenvalue solver run successfully on ROCm GPUs.
- [pull/34470, pull/34494, pull/34500, pull/34534, pull/34560, pull/34561, pull/34567, pull/34571, pull/34574, pull/34575, pull/34600, pull/34603, pull/34662, pull/34674, pull/34689, pull/34774]
- ROCm GPU Architecture and Backend Handling: Some pull requests address ROCm GPU architecture detection and backend routing issues by redirecting ROCm devices to the Triton backend instead of the NVIDIA-specific Mosaic GPU backend. These fixes include adding safety checks for ROCm architecture strings and updating test logic to prevent errors related to ROCm GPU identifiers.
- [pull/34674, pull/31768]
- ROCm Support in Collective Operations and Lowering: Several pull requests add or improve ROCm support in collective operations and lowering mechanisms, such as enabling psend and precv operations on ROCm GPUs and adding support for lowering Pallas Triton calls to HSACO through the PJRT_Triton_Extension for ROCm. - [pull/34304, pull/31768]
- Test Fixes and Improvements for ROCm Compatibility: Various pull requests fix test errors and improve compatibility on ROCm by addressing issues like KeyErrors in memory statistics, IndexErrors in error handling, and undefined behavior in bitshift unit tests. These fixes ensure safer error handling and consistent test results across platforms.
- [pull/34788, pull/33157, pull/34500]
- Dynamic Platform Detection and Test Adaptation: One pull request enables the deviceless ahead-of-time (AOT) compile test to run on ROCm GPUs by dynamically detecting the GPU platform instead of hardcoding "cuda," allowing the test to support both CUDA and ROCm environments.
- [pull/34893]
- Hijax Library Enhancements for Compatibility: A pull request improves the hijax library by adding support for vectorized mapping (vmap) and scanning (scan) operations on hi types, making hijax vmappable again and enhancing its compatibility with JAX transformations.
- [pull/34977]
- Skip Reason Message Updates: A pull request updates skip reason messages for CPU-only and TPU-only tests to more accurately reflect device specificity, replacing generic messages to improve clarity and categorization of skipped tests.
- [pull/34675]
3.3 Pull Request Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.
IV. Contributors
4.1 Contributors
Active Contributors:
We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.
If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.
| Contributor | Commits | Pull Requests | Issues | Comments |
|---|---|---|---|---|
| jakevdp | 30 | 14 | 0 | 40 |
| tsrw2048 | 14 | 11 | 0 | 5 |
| alekstheod | 29 | 0 | 0 | 0 |
| superbobry | 3 | 3 | 0 | 22 |
| hrideymarwah15 | 25 | 3 | 0 | 0 |
| mattjj | 15 | 8 | 0 | 5 |
| magaonka-amd | 14 | 11 | 0 | 0 |
| AratiGanesh | 13 | 9 | 0 | 0 |
| IvyZX | 2 | 1 | 0 | 17 |
| samanklesaria | 20 | 0 | 0 | 0 |
Access Last Week's Newsletter: