Weekly Project News

Archives

Weekly GitHub Report for Jax: March 16, 2026 - March 23, 2026 (19:47:03)

Weekly GitHub Report for Jax

Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.


Table of Contents

  • I. News
    • 1.1. Recent Version Releases
    • 1.2. Version Information
  • II. Issues
    • 2.1. Top 5 Active Issues
    • 2.2. Top 5 Stale Issues
    • 2.3. Open Issues
    • 2.4. Closed Issues
    • 2.5. Issue Discussion Insights
  • III. Pull Requests
    • 3.1. Open Pull Requests
    • 3.2. Closed Pull Requests
    • 3.3. Pull Request Discussion Insights
  • IV. Contributors
    • 4.1. Contributors

I. News

1.1 Recent Version Releases:

The current version of this repository is jax-v0.9.0.1.

1.2 Version Information:

On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements or fixes without introducing major changes.

Click here to view the full release notes!

II. Issues

2.1 Top 5 Active Issues:

We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.

  1. [BUG] cho_solve does not replicate scipy batch broadcasting for batched c + unbatched vector b: This issue reports that the jax.scipy.linalg.cho_solve function does not replicate the batch broadcasting behavior of scipy.linalg.cho_solve when solving with a batched Cholesky factor and an unbatched right-hand-side vector, leading to a shape mismatch error. The user provides a minimal reproducible example, explains the root cause related to dimension checks in the underlying triangular solve, and suggests a potential fix involving explicit broadcasting of inputs to a common batch shape before solving.

    • The comments acknowledge the issue as a limitation in JAX compared to newer SciPy versions, express willingness to contribute a fix, and discuss the implications of changing broadcasting behavior due to differences in related NumPy functions, highlighting the need for careful implementation and possible deprecation strategies.
    • Number of comments this week: 3
  2. [BUG] Excessive memory usage for jacfwd(vectorize(grad(f))): This issue describes a problem with excessive memory usage when computing the Jacobian of a vectorized gradient function using JAX, where the expected memory footprint is vastly underestimated, leading to out-of-memory errors. The user suspects that the combination of grad and vectorize is not efficiently handling the scalar derivatives, causing intermediate arrays to balloon in size, and seeks advice on how to reduce memory consumption for this operation.

    • The comments analyze the JAX intermediate representation (jaxpr) to identify unexpectedly large intermediate arrays causing the memory blowup and suggest that the memory issue arises from propagating the Jacobian through a dot product. They recommend alternative approaches such as rewriting the function using scan or explicitly leveraging the diagonal structure of the vectorized function via coloring methods or vectorizing last, which can reduce memory usage by avoiding pushing large tangent vectors through the gradient computations.
    • Number of comments this week: 2
  3. [ENHANCEMENT] Tracing and pure_callback with black-box Python objects: This issue discusses the challenge of enabling JAX to trace and work with black-box Python objects that do not have a fixed-size pytree representation but expose a pure functional interface, such as variable-sized or recursive data structures. It proposes using pure_callback to handle these objects within jitted code, allowing encapsulated control flow without restructuring the program, and explores the feasibility of implementing this via reference counting and a new dtype to manage Python objects in device memory.

    • The comments highlight the difficulty of representing opaque Python objects in JAX's current array-based model without rewriting XLA's object system, suggesting that trace-time solutions like HiJax might help but still require conversion to fixed-size arrays. Another comment proposes using memory addresses as integer representations of Python objects with host callbacks to manage reference counting, indicating a potential path forward.
    • Number of comments this week: 2
  4. [BUG] Compilation time increases from seconds to 9min between 0.9.0.1 and 0.9.1: This issue reports a significant regression in compilation time for JIT code between versions 0.9.0.1 and 0.9.1, where compilation time increased from approximately 3 seconds to 9 minutes. The user suspects the problem may be related to a previous issue involving memory spikes and provides detailed system information along with example code to reproduce the behavior without requiring download of attached files.

    • The commenters discuss concerns about downloading ZIP files and request a more direct reproduction method; the original poster then shares a minimal reproducible example using their Python package to demonstrate the issue clearly.
    • Number of comments this week: 2
  5. [BUG] jax.nn.initializers.orthogonal crashes on zero-sized dimensions: This issue reports a ZeroDivisionError occurring in the jax.nn.initializers.orthogonal function when it is called with shapes that include zero-sized dimensions, which happens during the initialization of recurrent layers like LSTMCell in Flax. The problem arises because the initializer performs a division by a dimension size without checking if that dimension is zero, and the user suggests adding a guard to raise a clear error when zero-sized dimensions are present to prevent this crash.

    • The comments note that the underlying random orthogonal function supports zero-sized dimensions and argue that the initializer should as well; a related pull request has been opened to fix the issue.
    • Number of comments this week: 2
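
The fix proposed in the first issue above amounts to computing a common batch shape for the Cholesky factor and the right-hand side before running the triangular solves. A minimal pure-Python sketch of that shape computation, following NumPy-style broadcasting rules (the helper name and structure are illustrative, not the actual JAX patch):

```python
def broadcast_batch_shapes(*shapes):
    """Compute the common batch shape of several shapes under
    NumPy-style broadcasting rules (right-aligned; size-1 dims stretch)."""
    ndim = max(len(s) for s in shapes)
    # Left-pad every shape with 1s so they all have the same rank.
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible batch dimensions: {dims}")
        out.append(sizes.pop() if sizes else 1)
    return tuple(out)

# Batched factor c with batch shape (4,), unbatched vector b (no batch dims):
c_batch, b_batch = (4,), ()
common = broadcast_batch_shapes(c_batch, b_batch)
# Both inputs would then be broadcast to batch shape (4,) before solving,
# avoiding the reported shape-mismatch error.
```

Explicitly broadcasting first, as the issue suggests, keeps the underlying triangular solve's strict dimension checks unchanged while matching newer SciPy behavior at the API boundary.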

2.2 Top 5 Stale Issues:

We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.

As of our latest update, there are no stale issues for the project this week.

2.3 Open Issues

This section lists, groups, and then summarizes issues that were created within the last week in the repository.

Issues Opened This Week: 11

Summarized Issues:

  • Memory usage and out-of-memory errors: Excessive memory usage and out-of-memory errors occur when computing the Jacobian of a vectorized gradient function due to large intermediate arrays generated by the composition of grad and vectorize. Additionally, GPU allocation OOM errors are not properly propagated as Python exceptions, causing misleading success status codes despite error messages.
  • issues/35936, issues/35994
  • JIT compilation and performance regressions: A significant regression in JIT compilation time was observed between versions 0.9.0.1 and 0.9.1, with compilation time increasing from about 3 seconds to 9 minutes, potentially linked to memory spikes. This regression is supported by attached XLA dump files for further investigation.
  • issues/35958
  • Handling of references and memory in shard_map: Using jax.new_ref with jax.shard_map on an 8-TPU mesh causes errors related to memory space specification and limitations in handling references without the jit decorator. This indicates challenges in managing memory and references within eager shard_map executions.
  • issues/36000
  • Gradient and sharding propagation errors: The backward pass of jax.nn.dot_product_attention with implementation="cudnn" inside a jax.shard_map fails to propagate manual sharding axis annotations on gradient outputs, resulting in a ValueError due to a type mismatch at the custom_vjp boundary. This highlights issues in gradient computation and sharding metadata handling.
  • issues/36008
  • Under-specification of reduction operations: The axis traversal order in multi-dimensional lax.reduce operations is under-specified when using non-commutative reduction functions, causing inconsistent results across CPU, GPU, and TPU hardware platforms. This inconsistency affects reproducibility and correctness of reductions.
  • issues/36011
  • Broadcasting and tree mapping enhancements: A proposal suggests modifying jax.tree.map to allow broadcasting of additional input trees into the main tree when possible, enhancing versatility while preserving current behavior for valid inputs. This change aims to improve usability in tree-structured data transformations.
  • issues/36037
  • Batch broadcasting inconsistencies in linear algebra: jax.scipy.linalg.cho_solve does not replicate the batch broadcasting behavior of scipy.linalg.cho_solve when solving with a batched Cholesky factor and an unbatched right-hand-side vector, leading to shape mismatch errors. A fix involving explicit broadcasting is proposed to align batch dimensions correctly.
  • issues/36083
  • Crashes due to fusion pass failures: JAX crashes when the NestGemmFusion pass fails because of a mismatch in symbolic map dimensions while fusing 3D Dense GEMMs with 4D einsum operations across reshapes. This issue is triggered by combining Dense layers and einsum contractions in GPU-accelerated models.
  • issues/36095
  • Initializer failures with zero-sized dimensions: The jax.nn.initializers.orthogonal function crashes with a ZeroDivisionError when called with shapes containing zero-sized dimensions, as it attempts to divide by a dimension size that can be zero. This failure affects scenarios like initializing recurrent layers with dynamic shapes.
  • issues/35993
  • Tracing and pure_callback support for opaque Python objects: A proposal aims to enable tracing and use of pure_callback with black-box Python objects that have variable-sized or recursive data structures by treating them as opaque entities in jitted code. This would manage their lifetimes and reference counting through a new dtype or mechanism holding memory addresses and deferring operations to callbacks.
  • issues/35950
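
The traversal-order under-specification called out for multi-dimensional lax.reduce (issues/36011) is easy to see with any combiner where order matters. A pure-Python illustration (not JAX code) using string concatenation, a non-commutative operation, as the reduction:

```python
from functools import reduce

# String concatenation makes traversal order directly visible.
def combine(a, b):
    return a + b

grid = [["a", "b"], ["c", "d"]]  # a 2x2 "array"

# Reducing along rows first, then combining the row results:
rows_first = reduce(combine, (reduce(combine, row) for row in grid))
# Reducing along columns first, then combining the column results:
cols = [reduce(combine, (grid[i][j] for i in range(2))) for j in range(2)]
cols_first = reduce(combine, cols)

print(rows_first)  # "abcd"
print(cols_first)  # "acbd"
```

The same ambiguity affects floating-point sums, whose non-associativity means different hardware backends can legally produce different results unless the spec pins down an order.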

2.4 Closed Issues

This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.

Issues Closed This Week: 6

Summarized Issues:

  • Documentation and User Guidance Issues: Several issues highlight missing or unclear documentation and unexpected function behaviors that confuse users. One issue points out the lack of training data and limitations in the README, while another notes unexpected behavior in jnp.arange due to hardcoded compile-time constants instead of expected XLA lowering, both of which hinder user understanding and proper usage.
  • issues/35930, issues/35953
  • Function Behavior and Output Inconsistencies: There are problems with function implementations producing incorrect or inconsistent results. For example, jax.scipy.fft.dct ignores the imaginary part of complex inputs, leading to outputs that differ from SciPy’s behavior, and jax.tree.map does not broadcast arguments as documented, causing type errors during execution.
  • issues/35973, issues/35996
  • Runtime Errors and Crashes in Specific Environments: Some operations cause crashes or errors depending on the hardware or operation sequence. A notable case is a SIGABRT crash on GPU and TPU when using lax.reduce with lax.cond, while the same code runs fine on CPU, and another involves an LLVM error in Shardy triggered by resharded results with reduced axes feeding into a shard_map during backward passes.
  • issues/35934, issues/36009
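
The dct complex-input issue (issues/35973) comes down to linearity: a complex DCT can be computed by transforming the real and imaginary parts separately and recombining, rather than silently dropping the imaginary part. A naive pure-Python sketch of that idea (not the actual fix in pull/35974):

```python
import math

def dct2_naive(x):
    """Naive unnormalized type-II DCT of a real sequence:
    X[k] = sum_n x[n] * cos(pi/N * (n + 0.5) * k)."""
    N = len(x)
    return [sum(x[n] * math.cos(math.pi / N * (n + 0.5) * k) for n in range(N))
            for k in range(N)]

def dct2_complex(z):
    """Handle complex input by linearity: transform real and imaginary
    parts separately and recombine, instead of dropping imag."""
    re = dct2_naive([v.real for v in z])
    im = dct2_naive([v.imag for v in z])
    return [r + 1j * i for r, i in zip(re, im)]

z = [1 + 2j, 3 - 1j, 0 + 4j]
full = dct2_complex(z)                    # full[0] == 4 + 5j
dropped = dct2_naive([v.real for v in z]) # what dropping imag would give: 4.0
```

The real parts of the two results agree, but only the complex-aware version preserves the imaginary component, which is what makes the output match SciPy for complex inputs.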

2.5 Issue Discussion Insights

This section analyzes the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.


III. Pull Requests

3.1 Open Pull Requests

This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Opened This Week: 9

Key Open Pull Requests

1. Check for escaped tracers during jit.lower: This pull request addresses issue #35799 by implementing a check for escaped tracers during the jit.lower process in the JAX project.

  • URL: pull/35981
  • Associated Commits: 5d228

2. [ROCm] Add lowering for ScaledMatmul, ScaledDot: This pull request adds ROCm support for lowering the ScaledMatmul and ScaledDot operations by implementing a dedicated translation that delegates to lax.scaled_dot, fixing failing tests caused by the missing MLIR translation rule on ROCm while preserving existing CUDA behavior.

  • URL: pull/35995
  • Associated Commits: 24685

3. [hijax] add a hijax primitive for jnp.nonzero: This pull request introduces a new hijax primitive for the jnp.nonzero function that simplifies batching and automatic differentiation rules, aiming to provide a well-defined, composable implementation with clearer semantics compared to the existing nonzero implementation in JAX.

  • URL: pull/36053
  • Associated Commits: 1f624
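
The escaped-tracer check in the first key pull request can be illustrated in miniature: a tracer is only valid while its trace is live, so any tracer still reachable after tracing finishes has "escaped". A pure-Python sketch of that idea using weak references (hypothetical names; not the JAX implementation):

```python
import gc
import weakref

class Tracer:
    """Stand-in for a JAX tracer: an abstract value threaded through
    a user function during tracing."""
    def __init__(self, name):
        self.name = name

leaked = []  # simulates user code stashing a tracer in an outer scope

def trace(fn, n_args):
    """Trace fn with fresh tracers, then verify none of them escaped."""
    tracers = [Tracer(f"arg{i}") for i in range(n_args)]
    refs = [weakref.ref(t) for t in tracers]
    fn(*tracers)
    del tracers
    gc.collect()
    escaped = [r() for r in refs if r() is not None]
    if escaped:
        raise RuntimeError(f"escaped tracer(s): {[t.name for t in escaped]}")

trace(lambda a, b: None, 2)                  # fine: nothing escapes
try:
    trace(lambda a, b: leaked.append(a), 2)  # 'a' escapes into `leaked`
except RuntimeError as e:
    print(e)  # escaped tracer(s): ['arg0']
```

Running the check at lowering time, as the pull request does for jit.lower, surfaces the leak where it happens instead of as a confusing error much later.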

Other Open Pull Requests

  • ROCm Bazel configuration improvements: These pull requests enhance the ROCm Bazel setup by adding a flag to compress offloaded device code, significantly reducing the PJRT wheel size without affecting runtime performance. Additionally, they limit the number of concurrent jobs on the ROCm RBE cluster to prevent overload.
    [pull/36055, pull/36061]
  • Bug fix in nn.initializers.orthogonal: This pull request addresses a ZeroDivisionError in the nn.initializers.orthogonal function, resolving the issue reported in the related GitHub issue. The fix ensures stability and correctness in the orthogonal initializer.
    [pull/36062]
  • Enhancements to hijax scanning and QDD tracing: This pull request enables scanning over Box objects by allowing traversal of their contents like pytrees, adds helper functions inc_rank and dec_rank for tracing jaxprs on AvalQDDs, and includes minor fixes. It also temporarily disables scan residual hoisting optimization due to partial evaluation limitations with QDD.
    [pull/36090]
  • Extension of lax.mul with preferred element type: This pull request introduces an optional preferred_element_type argument to the lax.mul primitive, allowing behavior similar to dot_general for specifying element types while maintaining default functionality. This enhances flexibility in element type handling during multiplication operations.
    [pull/36092]
  • SciPy-style auto-batching support in jsp.linalg: This pull request adds support for SciPy-style auto-batching in the jsp.linalg module by implementing a broadcasting approach for maximal batched dimensions. It improves batch dimension handling in linear algebra functions through iterative vmap application and helper functions, with tests updated to match SciPy 1.16's auto-batching features.
    [pull/36093]
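
The orthogonal-initializer fix addresses a ZeroDivisionError that the matching issue traces to dividing the total element count by a dimension size that can be zero. A hedged sketch of the guard idea (hypothetical helper, not the merged patch in pull/36062):

```python
import math

def orthogonal_init_checked(shape, column_axis=-1):
    """Sketch of a zero-size guard for an orthogonal initializer.

    The initializer needs n_rows = prod(shape) // shape[column_axis],
    which raises ZeroDivisionError when the column axis has size 0;
    checking up front turns the crash into a well-defined result."""
    if 0 in shape:
        # An array with a zero-sized dimension is empty, so there is
        # nothing to orthogonalize: short-circuit instead of dividing by 0.
        return shape  # placeholder for an empty array of this shape
    n_cols = shape[column_axis]
    n_rows = math.prod(shape) // n_cols
    return (n_rows, n_cols)  # the 2-D shape the QR factorization would use

print(orthogonal_init_checked((4, 0)))  # (4, 0): no ZeroDivisionError
print(orthogonal_init_checked((4, 8)))  # (4, 8)
```

Returning an empty result (rather than raising) matches the commenters' observation that the underlying random orthogonal function already supports zero-sized dimensions.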

3.2 Closed Pull Requests

This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Closed This Week: 32

Key Closed Pull Requests

1. Yanywang/issue 3627 fix jax: This pull request addresses the issue where JAX ROCm wheels fail to initialize the GPU backend by adding the rocm_sysdeps/lib subdirectory to the _WHEEL_RPATHS in jaxlib/rocm/rocm_rpath.bzl. This ensures that the dynamic linker can correctly resolve all required librocm_sysdeps_*.so libraries without needing LD_LIBRARY_PATH, fixing the library path resolution problem described in ROCm TheRock issue 3627.

  • URL: pull/35971
  • Associated Commits: e8a00, aeb5a, 6c61f, fc688, 73077, dd38b, 9f144, a409d, 383a5, 6e99b, dc69c, 2ec72, ced26, 2c542, 83c22, 9f458, 0001f, 03708, af4d1, 28108, 3ddfe, 3d2c7, 5af4a, 49474, 31a8e, 98e03, 0f2cb, cb7f8, 9c1e2, 54ae6, 00496, 9d1d6, 7957d, 6dba3, 766e8, f97e8, ee3d2, 85e78, eec41, a0cd2, 9b089

2. Rocm/s3 wheel downloads: This pull request replaces the GitHub CLI-based download of ROCm plugin and PJRT wheels from GitHub Releases with direct S3 downloads from the jax-ci-amd bucket, introduces a LATEST pointer file to simplify locating the most recent build, expands the ROCm build matrix to cover multiple Python versions, removes the dependency on GH_TOKEN, and updates related workflows to create a self-contained, streamlined build-and-test pipeline.

  • URL: pull/35932
  • Associated Commits: 51b81, 65331, 595b4, af6ae, eb600

3. Postrelease JAX v0.9.2.: This pull request finalizes the JAX v0.9.2 release by incorporating critical bug fixes, test guards for TPU library compatibility, upstream XLA and Shardy patches applied via Bazel, and preparatory changes to ensure a stable post-release state.

  • URL: pull/36023
  • Associated Commits: 6ca53, 43195, e0bb8, a6597, a7d12

Other Closed Pull Requests

  • ROCm Testing and Compatibility Fixes: Multiple pull requests improve ROCm support by adding prebuilt ROCm plugin dependencies for testing, skipping failing tests on ROCm devices, updating continuous wheel test pipelines to use ROCm plugin wheels, and adding ROCm xdist device pinning for pytest workers. These changes enhance test reliability and device management specifically for ROCm hardware.
    • pull/35516, pull/35611, pull/35821, pull/35917, pull/35944
  • Export Feature Documentation and Serialization Updates: Pull requests clarify the compatibility guarantees of the jax.Export feature and reformat export.md for line length consistency. Additionally, serialization of old shardings fields is removed in favor of newer fields with backward compatibility maintained through updated tests.
    • pull/35967, pull/35965, pull/35879
  • Bug Fixes and Code Cleanup: Several pull requests fix bugs such as the drop_fields parameter in tree_util.register_dataclass, the equality check for primal_tangent_dtype, and nested jit compilation in hijax boxes. Code clarity is improved by removing unnecessary calls like lax.asarray.
    • pull/35938, pull/35939, pull/35972, pull/35984
  • ROCm Wheel Build and Runtime Path Fixes: Two pull requests add the rocm_sysdeps/lib subdirectory to the RUNPATH entries in the JAX wheel build configuration to ensure dynamic linker can locate ROCm system dependency libraries without requiring LD_LIBRARY_PATH. This fixes linking issues with multiple librocm_sysdeps_*.so libraries.
    • pull/35977, pull/35978
  • Triton Autotuning Crash Fix: A pull request fixes crashes in the Triton autotuning process caused by improper handling of input-output buffer aliases by modifying the restore loop to iterate only over actually shared buffers at runtime. It also adds a test to ensure autotuning completes correctly when aliased inputs remain live after kernel calls.
    • pull/35218
  • Test Skips for Hardware-Specific Failures: Some pull requests add conditional skips for tests failing on specific hardware, such as skipping .cta_group::2 tests on non-tcgen05 hardware and skipping a failing tridiagonal solve gradient test on ROCm devices. These prevent test failures on unsupported or problematic hardware configurations.
    • pull/35959, pull/35917
  • JIT and Callback Enhancements: A pull request introduces support for PyObjectType in pure_callback and adds the py_traced_argnums parameter to jit, allowing Python objects to be passed as JIT arguments without recompilation. This uses a global registry keyed by a uint32 counter to safely track Python objects through XLA, with comprehensive tests ensuring correct behavior and garbage collection safety.
    • pull/35968
  • Build and CI Improvements: Pull requests improve the build process by avoiding repeated Bazel flags, updating lockfiles for version 0.9.2 compatibility, adding a wheel-version-suffix input for ROCm artifact builds to support post-release versions, and implementing a dynamic spawn strategy for ROCm RBE CI tests to run locally if the RBE pool is busy.
    • pull/36027, pull/36044, pull/36046, pull/36056
  • Pallas Module and FFT Improvements: One pull request improves the pl.loop functionality in the Pallas module by preserving concrete bounds, while another enhances the jax.scipy.fft module by adding support for complex inputs in the discrete cosine transform and related functions.
    • pull/36012, pull/35974
  • Repository Cleanup: A pull request removes the obsolete submodule jax.experimental.slab from the JAX project repository, cleaning up the codebase.
    • pull/35824
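
The Python-object registry described in the JIT and callback pull request (a global table keyed by a uint32 counter so opaque objects can ride through array-only machinery) can be sketched in pure Python. Names and structure here are illustrative only, not the code in pull/35968:

```python
import itertools

class ObjectRegistry:
    """Map opaque Python objects to small integer handles so they can
    pass through array-only machinery, and be released when done."""
    def __init__(self):
        self._counter = itertools.count()
        self._table = {}

    def register(self, obj):
        handle = next(self._counter) & 0xFFFFFFFF  # keep it uint32-sized
        self._table[handle] = obj
        return handle

    def lookup(self, handle):
        return self._table[handle]

    def release(self, handle):
        # Dropping the table entry lets the object be garbage collected.
        del self._table[handle]

registry = ObjectRegistry()
h = registry.register({"model": "opaque state"})
assert registry.lookup(h)["model"] == "opaque state"
registry.release(h)
```

Keeping a single authoritative table with explicit release is what makes garbage-collection safety testable: an object stays alive exactly as long as a handle to it is registered.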

3.3 Pull Request Discussion Insights

This section analyzes the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.


IV. Contributors

4.1 Contributors

Active Contributors:

We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.

If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.

Contributor      Commits  Pull Requests  Issues  Comments
jakevdp          44       8              2       20
alekstheod       38       6              0       4
superbobry       21       2              0       11
mattjj           18       4              0       3
magaonka-amd     20       1              0       0
danielsuo        15       2              0       1
Ashutosh0x       15       0              0       0
gulsumgudukbay   12       3              0       0
gnecula          10       4              0       0
cj401-amd        11       1              0       0

Don't miss what's next. Subscribe to Weekly Project News:
Powered by Buttondown, the easiest way to start and grow your newsletter.