Weekly Project News

Archives

Weekly GitHub Report for Jax: March 09, 2026 - March 16, 2026 (19:42:55)

Weekly GitHub Report for Jax

Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.


Table of Contents

  • I. News
    • 1.1. Recent Version Releases
    • 1.2. Other Noteworthy Updates
  • II. Issues
    • 2.1. Top 5 Active Issues
    • 2.2. Top 5 Stale Issues
    • 2.3. Open Issues
    • 2.4. Closed Issues
    • 2.5. Issue Discussion Insights
  • III. Pull Requests
    • 3.1. Open Pull Requests
    • 3.2. Closed Pull Requests
    • 3.3. Pull Request Discussion Insights
  • IV. Contributors
    • 4.1. Contributors

I. News

1.1 Recent Version Releases:

The current version of this repository is jax-v0.9.0.1.

1.2 Other Noteworthy Updates:

On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements without introducing major changes. This release highlights a focus on incremental fixes and refinements.

Click here to view the full release notes!

II. Issues

2.1 Top 5 Active Issues:

We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.

  1. jnp.sum of all-True boolean array returns wrong result under jax.jit on GPU: This issue reports a bug in JAX where summing an all-True boolean array captured as a closure constant under jax.jit on a GPU returns 255 times the expected count. The problem arises because the GPU backend incorrectly interprets the i1 (boolean) splat constant as a full byte with value 0xFF during conversion to i32, so each True element contributes 255 rather than 1 to the sum.

    • The comments discuss the issue's relation to an upstream XLA bug, confirm its introduction in JAX 0.9.1, and share diagnostic details including HLO and StableHLO analysis; a temporary mitigation in JAX is proposed to avoid the faulty conversion, and an upstream fix is being pursued to address the root cause in GPU code generation and runtime.
    • Number of comments this week: 7
  2. [ENHANCEMENT] Please build against ILP64 so Eigh (LAPACK dsyevd) can be run on matrices larger than 32766x32766: This issue requests building JAX against the ILP64 interface so that the LAPACK dsyevd function can handle eigenvalue decomposition on matrices larger than 32766x32766, as the current ILP32 build limits matrix size due to integer overflow in workspace allocation. The user demonstrates this limitation with a 37000x37000 identity matrix, which fails with a runtime error because the workspace size exceeds the maximum representable value for the current integer type.

    • The comments discuss the current dependency on SciPy's LAPACK, noting that SciPy's ILP64 support is experimental but available, and suggest that if SciPy is built with ILP64, JAX might automatically link to it; however, there is uncertainty about the direct exposure of ILP64 LAPACK kernels through SciPy's Cython interfaces and reluctance to maintain a separate Fortran build toolchain.
    • Number of comments this week: 3
  3. [BUG] Crash with multi-device reshape + vmap + jacrev: This issue describes a crash occurring when using jnp.split combined with jnp.reshape inside jax.jacrev during differentiation, specifically when the Jacobian is multiplied by the input vector and vmapped over a merged and sharded dimension on multiple GPUs. The problem is triggered by reshaping a split slice to a scalar and results in a slice index count mismatch error, while a workaround using direct indexing instead of split and reshape avoids the crash.

    • The comments confirm the issue is reproducible and acknowledge the provided minimal repro; a fix is currently in progress as indicated by a linked pull request.
    • Number of comments this week: 2
  4. [BUG] Delay Kernel time out warning: This issue reports a warning message that appears when starting a JAX program, indicating a delay kernel timeout with sub-optimal accuracy in timing measurements, possibly due to a missing warmup execution. The user provides system details including OS, Python version, JAX library version, CUDA version, and NVIDIA driver version to help diagnose the problem.

    • The comment suggests that the warning will likely be resolved by a specific pull request in the related XLA repository, implying that a fix is already in progress.
    • Number of comments this week: 1
  5. [BUG] Empty bitcode string provided for eigen: This issue reports a warning message indicating that an empty bitcode string is provided for Eigen, which disables optimizations relying on this intermediate representation when running a simple JAX example on a Mac with an aarch64 architecture. The user is seeking clarification or assistance regarding this warning that appears during execution of a JAX-compiled function involving matrix operations and nonlinearities.

    • The comment clarifies that the warning is harmless and will be downgraded to an informational message, explaining that Eigen bitcode generation is not currently supported on Mac systems.
    • Number of comments this week: 1
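The failure mode in issue 1 above can be sketched with a minimal, hypothetical repro (the issue's exact reproducer may differ); on an unaffected backend such as CPU this prints the correct count:

```python
import jax
import jax.numpy as jnp

# An all-True boolean array captured as a closure constant inside a
# jitted function, as described in the report.
mask = jnp.ones((8,), dtype=bool)

@jax.jit
def count_true():
    # The mask becomes an i1 splat constant in the lowered program; the
    # affected GPU path reportedly read it as 0xFF, so each True
    # contributed 255 instead of 1 to the sum.
    return jnp.sum(mask)

# Expected count is 8; the issue reports 8 * 255 = 2040 on affected GPUs.
print(int(count_true()))
```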
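The overflow in issue 2 is simple arithmetic. Assuming the LAPACK-documented dsyevd workspace bound of 1 + 6N + 2N^2 for the eigenvector path, the 37000x37000 case already exceeds a 32-bit integer:

```python
# Why the 37000x37000 case fails under ILP32: the workspace size is
# passed through a 32-bit integer, and for n = 37000 the required
# real workspace (in doubles) overflows it.
n = 37000
lwork = 1 + 6 * n + 2 * n * n   # dsyevd bound for JOBZ='V'
int32_max = 2**31 - 1           # 2147483647

print(lwork)              # 2738222001, well past the ILP32 limit
print(lwork > int32_max)  # True; ILP64 uses 64-bit integers instead
```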
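The crash pattern and workaround in issue 3 can be sketched as follows (hypothetical toy functions; reproducing the crash itself requires multiple sharded GPUs):

```python
import jax
import jax.numpy as jnp

def f_split(x):
    # Reported crash pattern: split, then reshape a slice to a scalar.
    a, b = jnp.split(x, 2)
    return jnp.reshape(a[:1], ()) * jnp.sum(b)

def f_index(x):
    # Workaround from the discussion: index directly instead of
    # using split + reshape.
    return x[0] * jnp.sum(x[x.shape[0] // 2:])

x = jnp.arange(4.0)
# On a single device both formulations agree, including under jacrev.
print(bool(jnp.allclose(jax.jacrev(f_split)(x), jax.jacrev(f_index)(x))))
```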

2.2 Top 5 Stale Issues:

We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.

As of our latest update, there are no stale issues for the project this week.

2.3 Open Issues

This section lists, groups, and then summarizes issues that were created within the last week in the repository.

Issues Opened This Week: 10

Summarized Issues:

  • GPU computation errors and incorrect results: Several issues report incorrect numerical results or crashes related to GPU computations in JAX. These include a boolean sum bug causing incorrect multiplication by 255 during GPU code generation and a crash triggered by jnp.split and jnp.reshape inside jax.jacrev on multiple GPUs due to slice index rank mismatch errors.
  • issues/35762, issues/35815
  • Distributed and multi-GPU setup failures: Problems arise in distributed and multi-GPU environments, including overflow errors with sharded x64 arange on multi-GPU setups and hangs in jax.distributed.initialize() on TPU multislice setups caused by firewall blocking coordinator port changes.
  • issues/35896, issues/35905
  • Warnings and optimization issues during startup and compilation: Startup warnings and optimization disabling occur due to kernel timeout delays and empty bitcode strings on specific architectures, affecting JAX program initialization and Eigen optimizations.
  • issues/35751, issues/35800
  • Tracer and compilation discrepancies: There is a discrepancy in tracer leak detection behavior between direct jit(f)(*args) calls and the lower-then-compile approach, leading to different runtime errors.
  • issues/35799
  • Numerical correctness issues on TPU hardware: The jax.nn.dot_product_attention function produces silently incorrect results on TPU v6e hardware across all tested configurations, causing incorrect model training and inference without warnings.
  • issues/35916
  • Feature requests for improved functionality and performance: Requests include a native JAX implementation of linear cross-entropy loss optimized for language model training and building JAX against ILP64 to support larger matrix eigenvalue decompositions beyond current ILP32 limits.
  • issues/35906, issues/35919

2.4 Closed Issues

This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.

Issues Closed This Week: 3

Summarized Issues:

  • Performance regressions in Triton kernels: A significant 5-10x slowdown was observed when upgrading JAX from version 0.6.2 to 0.8.0 or later, specifically impacting Triton kernel execution using pl.pallas_call with static unrolling and boolean one-hot masks. The regression is suspected to stem from changes in MLIR/Triton scheduling or register allocation affecting heavily unrolled flat DAGs.
  • issues/35529
  • Module deprecation and import errors: The jax.ad_checkpoint module is no longer exposed on the default jax namespace in version 0.9.1, causing an AttributeError for users expecting it to be available by default. Users must now import jax.ad_checkpoint explicitly to avoid this error.
  • issues/35873
  • TPU kernel crashes with experimental scheduler: Using use_experimental_scheduler=True in the Splash Attention kernel on v6e TPU hardware causes crashes during multi-host training due to invalid schedules generated by the LP-based instruction scheduler. This leads to hardware check failures and can be mitigated by disabling the experimental scheduler.
  • issues/35925
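The ad_checkpoint fix amounts to one explicit import. A minimal sketch of the working pattern:

```python
import math
import jax
import jax.ad_checkpoint  # must be imported explicitly as of 0.9.1
import jax.numpy as jnp

# Use the module via its explicit import rather than expecting it on
# the default jax namespace.
@jax.ad_checkpoint.checkpoint
def f(x):
    return jnp.sin(x) * 2.0

g = float(jax.grad(f)(1.0))             # d/dx 2*sin(x) = 2*cos(x)
print(abs(g - 2 * math.cos(1.0)) < 1e-5)
```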

2.5 Issue Discussion Insights

This section will analyze the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.


III. Pull Requests

3.1 Open Pull Requests

This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Opened This Week: 15

Key Open Pull Requests

1. Add memory-efficient fused linear cross-entropy loss (solves issue #35906): This pull request introduces a memory-efficient fused linear cross-entropy loss in JAX that computes the loss without materializing the full logits matrix. By chunking computations and using custom forward and backward passes based on the "Cut Your Losses" algorithm, it significantly reduces memory usage for large-vocabulary language models while maintaining accuracy, JIT compatibility, and support for CPU, GPU, and TPU.

  • URL: pull/35915
  • Associated Commits: a9961, 9c4b6, de02f, 1bb94, 48a57, 14dde

2. [ROCm] Enable hipBLASLt in ROCm pytest CI: This pull request enables hipBLASLt support in the ROCm pytest continuous integration by adding the flag --xla_gpu_enable_cublaslt=true to XLA_FLAGS in the run_pytest_rocm.sh script, which is necessary for TF32 GEMM support and to prevent autotuner failures for TF32 dot algorithms on ROCm.

  • URL: pull/35902
  • Associated Commits: a76fa, f01e2

3. [ROCm] Fix and simplify jax rocm plugin init script: This pull request improves the JAX ROCm plugin initialization script by adding system requirements checks such as GPU availability and shared memory size, simplifying GPU detection through a new hardware utilities module, generating informative logs and warnings, and includes comprehensive unit tests to ensure robust functionality.

  • URL: pull/35785
  • Associated Commits: cf03f
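The chunking idea behind the fused cross-entropy pull request (the first key PR above) can be sketched in a toy form; the names and chunking helper here are illustrative, not the PR's actual API:

```python
import jax
import jax.numpy as jnp

def chunked_cross_entropy(hidden, embed, labels, chunk=2):
    # Compute logits a few rows at a time so the full (tokens, vocab)
    # matrix is never materialized at once.
    losses = []
    for start in range(0, hidden.shape[0], chunk):
        h = hidden[start:start + chunk]
        y = labels[start:start + chunk]
        logits = h @ embed.T                       # (chunk, vocab) only
        losses.append(jax.nn.log_softmax(logits)[jnp.arange(h.shape[0]), y])
    return -jnp.concatenate(losses).mean()

rng = jax.random.key(0)
k1, k2 = jax.random.split(rng)
hidden = jax.random.normal(k1, (6, 4))
embed = jax.random.normal(k2, (10, 4))             # vocab size 10
labels = jnp.array([1, 3, 5, 7, 9, 0])

# Matches the loss computed from the fully materialized logits matrix.
full = -jax.nn.log_softmax(hidden @ embed.T)[jnp.arange(6), labels].mean()
print(bool(jnp.allclose(chunked_cross_entropy(hidden, embed, labels), full)))
```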

Other Open Pull Requests

  • Precision and Accuracy Fixes in Primitives: Multiple pull requests address precision and accuracy issues in JAX primitives. One fixes a precision degradation bug in the Pallas reciprocal primitive by preventing incorrect downcasting of float64 inputs, while another updates the erf and atan2 primitives to properly pass the accuracy parameter, fixing a TypeError in Pallas TPU kernels and aligning with other math primitives.
  • pull/35791, pull/35809
  • Device and Array Reference Handling Improvements: Enhancements to device and array reference management are made to ensure correct propagation and avoid lifetime or ownership issues. This includes improved handling of IFRT array references in PyArray::BatchedCopyToDeviceWithSharding and the introduction of the experimental OneAPI PJRT plugin for Intel GPUs with necessary runtime and device discovery support.
  • pull/35808, pull/35849
  • Sharding Serialization and Logging Updates: Changes are made to sharding serialization and logging behavior to improve clarity and compatibility. Serialization of in_hlo_shardings and out_hlo_shardings is removed in favor of newer fields with backward compatibility maintained, and printing of sharding information during dot, broadcast, and reshard operations is suppressed.
  • pull/35826, pull/35879
  • Testing and Pipeline Enhancements: Updates to testing pipelines and test stability are introduced. The continuous wheel test pipeline for ROCm now uses previously built plugin and PJRT wheels instead of downloading released wheels, and a deadlock and timeout issue in the image_test single_gpu test is fixed by cleaning up the backend clique cache during teardown.
  • pull/35821, pull/35918
  • New Features in Lax Library: New functionality is added to the lax library to enhance its capabilities. This includes the addition of lax.associative_reduce and an unroll parameter for lax.while_loop to address specific issues reported in the project.
  • pull/35904, pull/35913
  • Bug Fixes and Cleanup: Various bug fixes and cleanup tasks are performed, including the deletion of the obsolete jax.experimental.slab submodule and conditional skipping of a rocSparse numerical precision test on ROCm devices to avoid failures while preserving other test variants.
  • pull/35824, pull/35917
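For context on the proposed unroll parameter for lax.while_loop: lax.fori_loop already exposes the same knob (assuming a recent JAX release), trading larger compiled code for fewer loop iterations at runtime. A minimal sketch:

```python
import jax
from jax import lax

def body(i, acc):
    # Accumulate the loop index; unrolling does not change the result,
    # only how the loop is lowered.
    return acc + i

total = lax.fori_loop(0, 10, body, 0, unroll=2)
print(int(total))  # 0 + 1 + ... + 9 = 45
```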

3.2 Closed Pull Requests

This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Closed This Week: 37

Key Closed Pull Requests

1. [ROCm] Implement Mosaic GPU detection and Auto-Skips: This pull request implements detection of Mosaic GPU tests and automatically skips them during ROCm test runs by adding a pytest marker and classification logic based on test file paths, source usage, and default behaviors, thereby improving the reliability of ROCm testing without affecting non-ROCm environments.

  • URL: pull/35288
  • Associated Commits: be9e3, 9cd64, 76239, 1f3a2, 15c1a, df0db

2. [ROCm] ToT ROCm Unit Test Skips: This pull request skips multiple currently failing ROCm-related unit tests in the JAX project using parameterized checks, as these tests are under triage and linked to specific issues, while also later unskipping a ROCm test for numpy function signatures.

  • URL: pull/34722
  • Associated Commits: c9084, 7bc4f, b8518, 0cd37, 5d6f8

3. [ROCm] Add individual skips to Mosaic tests: This pull request adds targeted skips to Mosaic tests specifically for ROCm environments, aiming to minimize the number of skipped test nodes by only excluding those that utilize the Mosaic GPU.

  • URL: pull/35741
  • Associated Commits: e84aa, 0a602, 76341

Other Closed Pull Requests

  • ROCm Support and Build Improvements: Multiple pull requests enhance ROCm GPU support by extending test skips, adding missing initialization scripts, improving test workflows with structured outputs and log uploads, fixing linker errors via configuration updates, and installing AWS CLI in the build process. These changes collectively ensure better ROCm integration, build reliability, and testing infrastructure.
    • pull/35533, pull/35668, pull/35283, pull/35710
  • Type Checking Enhancements: Several pull requests add or extend type checking across various modules including jax/experimental/key_reuse, multiple other experimental modules, jax/experimental/sparse, top-level experimental files, jax2tf, and jax.experimental/array_serialization plus colocated_python. These improvements aim to increase code reliability, maintainability, and developer tooling support.
    • pull/35716, pull/35727, pull/35747, pull/35750, pull/35775, pull/35779
  • Explicit Sharding and Serialization Updates: Pull requests introduce explicit sharding rules for argsort and bincount, implement a shard map without linear utility, and add support for serializing AbstractMesh.abstract_device with a new export serialization version and backward compatibility tests. These changes improve sharding functionality and serialization robustness.
    • pull/35650, pull/35745, pull/35793
  • Bug Fixes and Test Stability: A pull request fixes a test hang in the Mosaic:GPU component by reverting a previous commit and implementing explicit stream-ordered memset and synchronization, thereby re-enabling the test. This fix addresses stability issues related to memory zeroing in GPU tests.
    • pull/35716
  • Function Behavior and API Improvements: Updates include adding a warning to the pl.reciprocal function docstring about zero arguments, changing the ad.backward_pass wrapper to return a Zero object instead of None, and making the accuracy parameter in jax.lax keyword-only with added type annotations and documentation. These changes improve API clarity and consistency.
    • pull/35696, pull/35748, pull/35818
  • Triton Lowering and Hardware-Specific Fixes: A pull request updates the Triton lowering for dot_general to use MMA instructions by relaxing dimension checks for float64 tensors, fixes output accumulator dtype handling, enforces hardware-specific constraints to prevent compiler errors, and adds regression tests. This enhances performance and correctness on NVIDIA hardware.
    • pull/35654
  • Code Policy and Documentation Updates: One pull request updates the AI code policy documentation in the project, ensuring guidelines remain current and clear for contributors.
    • pull/35747
  • MLIR Extensions and Conversion Updates: Pull requests add type stubs for additional MLIR-related extensions in jaxlib and update the jax2tf conversion process to use new exported NamedShardings fields, removing deprecated fields. These changes improve type safety and modernize conversion internals.
    • pull/35823, pull/35841
  • Complex Fix and Version Bump: One merged pull request delivers a complex fix authored by Yash Katariya and Robert Dyro, while another updates Pyrefly to version 0.56, removing suppressions and migrating contextmanager-decorated functions to return Generator instead of the deprecated Iterator.
    • pull/35787, pull/35842

3.3 Pull Request Discussion Insights

This section will analyze the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.


IV. Contributors

4.1 Contributors

Active Contributors:

We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.

If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.

Contributor         Commits   Pull Requests   Issues   Comments
jakevdp                  45              13        0          9
alekstheod               44               7        0         16
benknutson-google        44               0        0          0
Arech8                    3               0        0         24
superbobry               20               5        0          2
gulsumgudukbay           14               3        0          7
mattjj                   18               4        0          1
magaonka-amd             21               2        0          0
Ashutosh0x               17               0        0          0
AratiGanesh              11               2        0          0

Don't miss what's next. Subscribe to Weekly Project News:
Powered by Buttondown, the easiest way to start and grow your newsletter.