Weekly GitHub Report for JAX: March 16, 2026 - March 23, 2026 (19:47:03)
Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.
Table of Contents
I. News
1.1 Recent Version Releases:
The current version of this repository is jax-v0.9.0.1
1.2 Version Information:
On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements or fixes without introducing major changes.
Click here to view the full release notes!
II. Issues
2.1 Top 5 Active Issues:
We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.
- [BUG] `cho_solve` does not replicate scipy batch broadcasting for batched c + unbatched vector b: This issue reports that the `jax.scipy.linalg.cho_solve` function does not replicate the batch broadcasting behavior of `scipy.linalg.cho_solve` when solving with a batched Cholesky factor and an unbatched right-hand-side vector, leading to a shape mismatch error. The user provides a minimal reproducible example, explains the root cause related to dimension checks in the underlying triangular solve, and suggests a potential fix involving explicit broadcasting of inputs to a common batch shape before solving.
  - The comments acknowledge the issue as a limitation in JAX compared to newer SciPy versions, express willingness to contribute a fix, and discuss the implications of changing broadcasting behavior due to differences in related NumPy functions, highlighting the need for careful implementation and possible deprecation strategies.
  - Number of comments this week: 3
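Pending a fix, the batched factor can be mapped over explicitly. A minimal sketch of the workaround (shapes chosen purely for illustration, not taken from the issue):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# A batch of 4 identical SPD matrices (2 * I) and one unbatched RHS vector.
a = jnp.broadcast_to(2.0 * jnp.eye(3), (4, 3, 3))
b = jnp.ones(3)

c, lower = cho_factor(a)  # batched Cholesky factor, shape (4, 3, 3)

# cho_solve((c, lower), b) hits a shape mismatch in affected versions;
# vmapping over the batch dimension sidesteps the broadcasting gap.
x = jax.vmap(lambda ci: cho_solve((ci, lower), b))(c)  # shape (4, 3)
```

The fix proposed in the issue would instead broadcast `b` to the common batch shape inside `cho_solve` itself.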
- [BUG] Excessive memory usage for `jacfwd(vectorize(grad(f)))`: This issue describes a problem with excessive memory usage when computing the Jacobian of a vectorized gradient function using JAX, where the expected memory footprint is vastly underestimated, leading to out-of-memory errors. The user suspects that the combination of `grad` and `vectorize` is not efficiently handling the scalar derivatives, causing intermediate arrays to balloon in size, and seeks advice on how to reduce memory consumption for this operation.
  - The comments analyze the JAX intermediate representation (jaxpr) to identify unexpectedly large intermediate arrays causing the memory blowup and suggest that the memory issue arises from propagating the Jacobian through a dot product. They recommend alternative approaches such as rewriting the function using `scan` or explicitly leveraging the diagonal structure of the vectorized function via coloring methods or vectorizing last, which can reduce memory usage by avoiding pushing large tangent vectors through the gradient computations.
  - Number of comments this week: 2
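The "vectorize last" suggestion from the comments can be illustrated with a toy scalar function (this example is a sketch, not the reporter's code): the full Jacobian of an elementwise gradient is diagonal, so only its diagonal needs to be materialized.

```python
import jax
import jax.numpy as jnp

# f: scalar -> scalar; the goal is the derivative of grad(f) at many points.
def f(x):
    return jnp.sin(x) ** 2

xs = jnp.linspace(0.0, 1.0, 8)

# The pattern from the issue: a full (8, 8) Jacobian of a vectorized
# gradient, even though only its diagonal is nonzero.
jac = jax.jacfwd(jnp.vectorize(jax.grad(f)))(xs)

# "Vectorizing last" exploits the elementwise structure directly and only
# materializes the (8,)-shaped diagonal.
diag = jax.vmap(jax.grad(jax.grad(f)))(xs)
```

At realistic sizes the second form avoids pushing a full tangent basis through the gradient computation.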
- [ENHANCEMENT] Tracing and `pure_callback` with black-box Python objects: This issue discusses the challenge of enabling JAX to trace and work with black-box Python objects that do not have a fixed-size pytree representation but expose a pure functional interface, such as variable-sized or recursive data structures. It proposes using `pure_callback` to handle these objects within jitted code, allowing encapsulated control flow without restructuring the program, and explores the feasibility of implementing this via reference counting and a new dtype to manage Python objects in device memory.
  - The comments highlight the difficulty of representing opaque Python objects in JAX's current array-based model without rewriting XLA's object system, suggesting that trace-time solutions like HiJax might help but still require conversion to fixed-size arrays. Another comment proposes using memory addresses as integer representations of Python objects with host callbacks to manage reference counting, indicating a potential path forward.
  - Number of comments this week: 2
- [BUG] Compilation time increases from seconds to 9min between 0.9.0.1 and 0.9.1: This issue reports a significant regression in compilation time for JIT code between versions 0.9.0.1 and 0.9.1, where compilation time increased from approximately 3 seconds to 9 minutes. The user suspects the problem may be related to a previous issue involving memory spikes and provides detailed system information along with example code to reproduce the behavior without requiring download of attached files.
  - The commenters discuss concerns about downloading ZIP files and request a more direct reproduction method; the original poster then shares a minimal reproducible example using their Python package to demonstrate the issue clearly.
  - Number of comments this week: 2
- [BUG] `jax.nn.initializers.orthogonal` crashes on zero-sized dimensions: This issue reports a `ZeroDivisionError` occurring in the `jax.nn.initializers.orthogonal` function when it is called with shapes that include zero-sized dimensions, which happens during the initialization of recurrent layers like `LSTMCell` in Flax. The problem arises because the initializer performs a division by a dimension size without checking if that dimension is zero, and the user suggests adding a guard to raise a clear error when zero-sized dimensions are present to prevent this crash.
  - The comments note that the underlying random orthogonal function supports zero-sized dimensions and argue that the initializer should as well; a related pull request has been opened to fix the issue.
  - Number of comments this week: 2
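Until the fix lands, a user-side guard avoids the crash. A minimal sketch (the `safe_orthogonal` wrapper is hypothetical, not the merged fix; returning zeros for empty shapes is one reasonable convention):

```python
import jax
import jax.numpy as jnp

init = jax.nn.initializers.orthogonal()
key = jax.random.key(0)

def safe_orthogonal(key, shape, dtype=jnp.float32):
    # Zero-sized shapes trigger ZeroDivisionError in affected versions;
    # an empty array is a safe stand-in since it has no entries anyway.
    if 0 in shape:
        return jnp.zeros(shape, dtype)
    return init(key, shape, dtype)

w = safe_orthogonal(key, (0, 4))   # would crash without the guard
v = safe_orthogonal(key, (3, 3))   # normal orthogonal initialization
```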
2.2 Top 5 Stale Issues:
We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.
As of our latest update, there are no stale issues for the project this week.
2.3 Open Issues
This section lists, groups, and then summarizes issues that were created within the last week in the repository.
Issues Opened This Week: 11
Summarized Issues:
- Memory usage and out-of-memory errors: Excessive memory usage and out-of-memory errors occur when computing the Jacobian of a vectorized gradient function due to large intermediate arrays generated by the composition of `grad` and `vectorize`. Additionally, GPU allocation OOM errors are not properly propagated as Python exceptions, causing misleading success status codes despite error messages. - issues/35936, issues/35994
- JIT compilation and performance regressions: A significant regression in JIT compilation time was observed between versions 0.9.0.1 and 0.9.1, with compilation time increasing from about 3 seconds to 9 minutes, potentially linked to memory spikes. This regression is supported by attached XLA dump files for further investigation.
- issues/35958
- Handling of references and memory in shard_map: Using `jax.new_ref` with `jax.shard_map` on an 8-TPU mesh causes errors related to memory space specification and limitations in handling references without the `jit` decorator. This indicates challenges in managing memory and references within eager shard_map executions. - issues/36000
- Gradient and sharding propagation errors: The backward pass of `jax.nn.dot_product_attention` with `implementation="cudnn"` inside a `jax.shard_map` fails to propagate manual sharding axis annotations on gradient outputs, resulting in a ValueError due to a type mismatch at the custom_vjp boundary. This highlights issues in gradient computation and sharding metadata handling. - issues/36008
- Under-specification of reduction operations: The axis traversal order in multi-dimensional `lax.reduce` operations is under-specified when using non-commutative reduction functions, causing inconsistent results across CPU, GPU, and TPU hardware platforms. This inconsistency affects reproducibility and correctness of reductions. - issues/36011
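The ordering problem can be made concrete with subtraction, which is neither commutative nor associative; an explicitly ordered fold is one way to pin down a single well-defined answer (a sketch, with values chosen for illustration):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(6.0).reshape(2, 3)

# With a non-commutative function like subtraction, the result of reducing
# over both axes depends on the (unspecified) traversal order, so it may
# differ between CPU, GPU, and TPU.
r = lax.reduce(x, 0.0, lax.sub, dimensions=(0, 1))

# An explicitly ordered fold fixes one well-defined answer:
# 0 - 0 - 1 - 2 - 3 - 4 - 5 = -15.
flat = x.reshape(-1)
ordered = lax.fori_loop(0, flat.size,
                        lambda i, acc: acc - flat[i], jnp.zeros(()))
```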
- Broadcasting and tree mapping enhancements: A proposal suggests modifying `jax.tree.map` to allow broadcasting of additional input trees into the main tree when possible, enhancing versatility while preserving current behavior for valid inputs. This change aims to improve usability in tree-structured data transformations. - issues/36037
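For context on what the proposal would simplify, a small sketch of today's behavior: all mapped trees must share a structure, so "broadcasting" a scalar into a tree currently has to be done by hand, e.g. via closure or a matching tree.

```python
import jax

tree = {"a": 1, "b": (2, 3)}
scale = 10

# Current idiom: capture the extra argument by closure.
out = jax.tree.map(lambda x: x * scale, tree)

# Or build a structurally matching second tree by hand; the proposal would
# allow passing the scalar directly as an extra argument instead.
out2 = jax.tree.map(lambda x, y: x * y, tree, {"a": 10, "b": (10, 10)})
```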
- Batch broadcasting inconsistencies in linear algebra: `jax.scipy.linalg.cho_solve` does not replicate the batch broadcasting behavior of `scipy.linalg.cho_solve` when solving with a batched Cholesky factor and an unbatched right-hand-side vector, leading to shape mismatch errors. A fix involving explicit broadcasting is proposed to align batch dimensions correctly. - issues/36083
- Crashes due to fusion pass failures: JAX crashes when the NestGemmFusion pass fails because of a mismatch in symbolic map dimensions while fusing 3D Dense GEMMs with 4D einsum operations across reshapes. This issue is triggered by combining Dense layers and einsum contractions in GPU-accelerated models.
- issues/36095
- Initializer failures with zero-sized dimensions: The `jax.nn.initializers.orthogonal` function crashes with a `ZeroDivisionError` when called with shapes containing zero-sized dimensions, as it attempts to divide by a dimension size that can be zero. This failure affects scenarios like initializing recurrent layers with dynamic shapes. - issues/35993
- Tracing and pure_callback support for opaque Python objects: A proposal aims to enable tracing and use of `pure_callback` with black-box Python objects that have variable-sized or recursive data structures by treating them as opaque entities in jitted code. This would manage their lifetimes and reference counting through a new dtype or mechanism holding memory addresses and deferring operations to callbacks. - issues/35950
2.4 Closed Issues
This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.
Issues Closed This Week: 6
Summarized Issues:
- Documentation and User Guidance Issues: Several issues highlight missing or unclear documentation and unexpected function behaviors that confuse users. One issue points out the lack of training data and limitations in the README, while another notes unexpected behavior in `jnp.arange` due to hardcoded compile-time constants instead of expected XLA lowering, both of which hinder user understanding and proper usage. - issues/35930, issues/35953
- Function Behavior and Output Inconsistencies: There are problems with function implementations producing incorrect or inconsistent results. For example, `jax.scipy.fft.dct` ignores the imaginary part of complex inputs, leading to outputs that differ from SciPy's behavior, and `jax.tree.map` does not broadcast arguments as documented, causing type errors during execution. - issues/35973, issues/35996
- Runtime Errors and Crashes in Specific Environments: Some operations cause crashes or errors depending on the hardware or operation sequence. A notable case is a `SIGABRT` crash on GPU and TPU when using `lax.reduce` with `lax.cond`, while the same code runs fine on CPU, and another involves an LLVM error in Shardy triggered by resharded results with reduced axes feeding into a `shard_map` during backward passes. - issues/35934, issues/36009
2.5 Issue Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.
III. Pull Requests
3.1 Open Pull Requests
This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Opened This Week: 9
Key Open Pull Requests
1. Check for escaped tracers during jit.lower: This pull request addresses issue #35799 by implementing a check for escaped tracers during the jit.lower process in the JAX project.
- URL: pull/35981
- Associated Commits: 5d228
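For context, a minimal sketch of the escaped-tracer pattern such a check targets (illustrative only; the PR's change would surface the leak during `jit.lower` rather than at later use):

```python
import jax
import jax.numpy as jnp

leaked = []

@jax.jit
def f(x):
    leaked.append(x)  # the tracer escapes the traced function
    return x + 1

f(jnp.ones(2))

# Using the escaped tracer after tracing has finished raises
# jax.errors.UnexpectedTracerError.
try:
    leaked[0] + 1
    caught = False
except jax.errors.UnexpectedTracerError:
    caught = True
```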
2. [ROCm] Add lowering for ScaledMatmul, ScaledDot: This pull request adds ROCm support for lowering the ScaledMatmul and ScaledDot operations by implementing a dedicated translation that delegates to lax.scaled_dot, fixing failing tests caused by the missing MLIR translation rule on ROCm while preserving existing CUDA behavior.
- URL: pull/35995
- Associated Commits: 24685
3. [hijax] add a hijax primitive for jnp.nonzero: This pull request introduces a new hijax primitive for the jnp.nonzero function that simplifies batching and automatic differentiation rules, aiming to provide a well-defined, composable implementation with clearer semantics compared to the existing nonzero implementation in JAX.
- URL: pull/36053
- Associated Commits: 1f624
Other Open Pull Requests
- ROCm Bazel configuration improvements: These pull requests enhance the ROCm Bazel setup by adding a flag to compress offloaded device code, significantly reducing the PJRT wheel size without affecting runtime performance. Additionally, they limit the number of concurrent jobs on the ROCm RBE cluster to prevent overload.
[pull/36055, pull/36061]
- Bug fix in nn.initializers.orthogonal: This pull request addresses a ZeroDivisionError in the `nn.initializers.orthogonal` function, resolving the issue reported in the related GitHub issue. The fix ensures stability and correctness in the orthogonal initializer.
[pull/36062]
- Enhancements to hijax scanning and QDD tracing: This pull request enables scanning over Box objects by allowing traversal of their contents like pytrees, adds helper functions `inc_rank` and `dec_rank` for tracing jaxprs on AvalQDDs, and includes minor fixes. It also temporarily disables scan residual hoisting optimization due to partial evaluation limitations with QDD.
[pull/36090]
- Extension of lax.mul with preferred element type: This pull request introduces an optional `preferred_element_type` argument to the `lax.mul` primitive, allowing behavior similar to `dot_general` for specifying element types while maintaining default functionality. This enhances flexibility in element type handling during multiplication operations.
[pull/36092]
- SciPy-style auto-batching support in jsp.linalg: This pull request adds support for SciPy-style auto-batching in the `jsp.linalg` module by implementing a broadcasting approach for maximal batched dimensions. It improves batch dimension handling in linear algebra functions through iterative `vmap` application and helper functions, with tests updated to match SciPy 1.16's auto-batching features.
[pull/36093]
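The iterative-`vmap` idea can be illustrated with a small helper (hypothetical names, not the PR's actual code): each leading batch dimension beyond the function's core dimensionality gets one `vmap` wrapper.

```python
import jax
import jax.numpy as jnp

def auto_batch(fn, x, core_ndim):
    # Sketch of SciPy-style auto-batching: wrap fn in one vmap per leading
    # batch dimension, then apply it to the full batched array.
    for _ in range(x.ndim - core_ndim):
        fn = jax.vmap(fn)
    return fn(x)

# A (2, 4) batch of 3x3 identity matrices; inverting each leaves it unchanged.
a = jnp.broadcast_to(jnp.eye(3), (2, 4, 3, 3))
out = auto_batch(jnp.linalg.inv, a, core_ndim=2)
```

The real implementation must additionally broadcast several operands to a common batch shape first; this sketch handles only the single-operand case.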
3.2 Closed Pull Requests
This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Closed This Week: 32
Key Closed Pull Requests
1. Yanywang/issue 3627 fix jax: This pull request addresses the issue where JAX ROCm wheels fail to initialize the GPU backend by adding the rocm_sysdeps/lib subdirectory to the _WHEEL_RPATHS in jaxlib/rocm/rocm_rpath.bzl, ensuring that the dynamic linker can correctly resolve all required librocm_sysdeps_*.so libraries without needing LD_LIBRARY_PATH, thereby fixing the library path resolution problem described in ROCm TheRock issue 3627.
- URL: pull/35971
- Associated Commits: e8a00, aeb5a, 6c61f, fc688, 73077, dd38b, 9f144, a409d, 383a5, 6e99b, dc69c, 2ec72, ced26, 2c542, 83c22, 9f458, 0001f, 03708, af4d1, 28108, 3ddfe, 3d2c7, 5af4a, 49474, 31a8e, 98e03, 0f2cb, cb7f8, 9c1e2, 54ae6, 00496, 9d1d6, 7957d, 6dba3, 766e8, f97e8, ee3d2, 85e78, eec41, a0cd2, 9b089
2. Rocm/s3 wheel downloads: This pull request replaces the GitHub CLI-based download of ROCm plugin and PJRT wheels from GitHub Releases with direct S3 downloads from the jax-ci-amd bucket, introduces a LATEST pointer file to simplify locating the most recent build, expands the ROCm build matrix to cover multiple Python versions, removes the dependency on GH_TOKEN, and updates related workflows to create a self-contained, streamlined build-and-test pipeline.
- URL: pull/35932
3. Postrelease JAX v0.9.2.: This pull request finalizes the JAX v0.9.2 release by incorporating critical bug fixes, test guards for TPU library compatibility, upstream XLA and Shardy patches applied via Bazel, and preparatory changes to ensure a stable post-release state.
- URL: pull/36023
Other Closed Pull Requests
- ROCm Testing and Compatibility Fixes: Multiple pull requests improve ROCm support by adding prebuilt ROCm plugin dependencies for testing, skipping failing tests on ROCm devices, updating continuous wheel test pipelines to use ROCm plugin wheels, and adding ROCm xdist device pinning for pytest workers. These changes enhance test reliability and device management specifically for ROCm hardware.
- Export Feature Documentation and Serialization Updates: Pull requests clarify the compatibility guarantees of the jax.Export feature and reformat export.md for line length consistency. Additionally, serialization of old shardings fields is removed in favor of newer fields with backward compatibility maintained through updated tests.
- Bug Fixes and Code Cleanup: Several pull requests fix bugs such as the `drop_fields` parameter in `tree_util.register_dataclass`, the equality check for `primal_tangent_dtype`, and nested jit compilation in hijax boxes. Code clarity is improved by removing unnecessary calls like `lax.asarray`.
- ROCm Wheel Build and Runtime Path Fixes: Two pull requests add the `rocm_sysdeps/lib` subdirectory to the RUNPATH entries in the JAX wheel build configuration to ensure the dynamic linker can locate ROCm system dependency libraries without requiring `LD_LIBRARY_PATH`. This fixes linking issues with multiple `librocm_sysdeps_*.so` libraries.
- Triton Autotuning Crash Fix: A pull request fixes crashes in the Triton autotuning process caused by improper handling of input-output buffer aliases by modifying the restore loop to iterate only over actually shared buffers at runtime. It also adds a test to ensure autotuning completes correctly when aliased inputs remain live after kernel calls.
- Test Skips for Hardware-Specific Failures: Some pull requests add conditional skips for tests failing on specific hardware, such as skipping `.cta_group::2` tests on non-tcgen05 hardware and skipping a failing tridiagonal solve gradient test on ROCm devices. These prevent test failures on unsupported or problematic hardware configurations.
- JIT and Callback Enhancements: A pull request introduces support for `PyObjectType` in `pure_callback` and adds the `py_traced_argnums` parameter to `jit`, allowing Python objects to be passed as JIT arguments without recompilation. This uses a global registry keyed by a `uint32` counter to safely track Python objects through XLA, with comprehensive tests ensuring correct behavior and garbage collection safety.
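The registry mechanism described above can be approximated in user code today; a hedged sketch with invented names (`register`, `apply_object`, and `Scaler` are not JAX APIs): objects live in a host-side table and only a `uint32` key crosses into the jitted computation, with `pure_callback` doing the host-side lookup.

```python
import numpy as np
import jax
import jax.numpy as jnp

_registry = {}          # host-side table of opaque Python objects
_counter = iter(range(2**32))

class Scaler:
    def __init__(self, scale):
        self.scale = scale

def register(obj):
    key = next(_counter)
    _registry[key] = obj
    return jnp.uint32(key)  # only this key is visible to jitted code

@jax.jit
def apply_object(key, x):
    def host_call(k, v):
        # Runs on the host: look the object up by key and use it.
        return np.asarray(_registry[int(k)].scale * v, dtype=np.float32)
    return jax.pure_callback(
        host_call, jax.ShapeDtypeStruct(x.shape, jnp.float32), key, x)

key = register(Scaler(3.0))
y = apply_object(key, jnp.ones(2, dtype=jnp.float32))
```

Unlike the PR, this sketch never frees registry entries; reference counting and garbage-collection safety are exactly the hard parts the real change addresses.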
- Build and CI Improvements: Pull requests improve the build process by avoiding repeated Bazel flags, updating lockfiles for version 0.9.2 compatibility, adding a wheel-version-suffix input for ROCm artifact builds to support post-release versions, and implementing a dynamic spawn strategy for ROCm RBE CI tests to run locally if the RBE pool is busy.
- Pallas Module and FFT Improvements: One pull request improves the `pl.loop` functionality in the Pallas module by preserving concrete bounds, while another enhances the `jax.scipy.fft` module by adding support for complex inputs in the discrete cosine transform and related functions.
- Repository Cleanup: A pull request removes the obsolete submodule `jax.experimental.slab` from the JAX project repository, cleaning up the codebase.
3.3 Pull Request Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.
IV. Contributors
4.1 Contributors
Active Contributors:
We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.
If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.
| Contributor | Commits | Pull Requests | Issues | Comments |
|---|---|---|---|---|
| jakevdp | 44 | 8 | 2 | 20 |
| alekstheod | 38 | 6 | 0 | 4 |
| superbobry | 21 | 2 | 0 | 11 |
| mattjj | 18 | 4 | 0 | 3 |
| magaonka-amd | 20 | 1 | 0 | 0 |
| danielsuo | 15 | 2 | 0 | 1 |
| Ashutosh0x | 15 | 0 | 0 | 0 |
| gulsumgudukbay | 12 | 3 | 0 | 0 |
| gnecula | 10 | 4 | 0 | 0 |
| cj401-amd | 11 | 1 | 0 | 0 |