Weekly GitHub Report for JAX: April 06, 2026 - April 13, 2026 (19:27:16)
Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.
Table of Contents
I. News
1.1 Recent Version Releases:
The current version of this repository is jax-v0.9.0.1
1.2 Version Information:
On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements without introducing major changes. This release highlights a focus on incremental fixes and refinements.
Click here to view the full release notes!
II. Issues
2.1 Top 5 Active Issues:
We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.
- [BUG] ffi_call with empty array inside lax.scan crashes during lowering when JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS=True: This issue describes a crash occurring during the lowering phase of a JAX computation when using jax.ffi.ffi_call with an empty array inside jax.lax.scan while the environment variable JAX_USE_SIMPLIFIED_JAXPR_CONSTANTS is set to True. The error arises because a TypedNdArray object lacks a sharding attribute, causing an AttributeError in the constant hoisting logic during compilation.
- The first comment states that the issue appears fixed in JAX version 0.9.2, but a subsequent comment contradicts this by reproducing the same error with the exact reproducer on JAX 0.9.2, providing detailed system information to support the report.
- Number of comments this week: 2
- [BUG] jax.jit produces different results compared to eager execution: This issue reports that compiling a model with jax.jit produces different numerical results than running the model in eager execution mode, with a relative L2 error of about 4%. The discrepancy arises in a model involving normalization, negation operations, LayerNorm, and clamp operations, suggesting a potential problem in XLA's optimization passes that may alter the numerical semantics of the computation.
- The single comment references another related issue for further context, indicating that the discussion is being continued elsewhere rather than resolved within this thread.
- Number of comments this week: 1
- [BUG] jax.jit produces incorrect results vs eager execution on CPU: This issue reports that using jax.jit on CPU produces significantly incorrect results compared to eager execution, with a relative L2 error of approximately 48%. The problem appears to stem from XLA optimizations such as fusion or constant propagation altering the computation semantics in a model involving constant folding, fused matmul with bias and relu, and layernorm with relu, and it is reproducible with fixed inputs.
- The single comment references another issue for further discussion, indicating that the conversation has been redirected rather than resolved within this thread.
- Number of comments this week: 1
- [BUG] Silent numerical inconsistency under compiler optimization: This issue reports a silent numerical inconsistency observed under compiler optimization in version 0.9.2, where outputs that appear identical at default precision actually exhibit significant numerical divergence, with a relative L2 norm of approximately 0.17. The user provides a Python script attachment to demonstrate the problem and highlights that the discrepancy is not immediately apparent without detailed analysis.
- The single comment directs the user to a related discussion in another issue, suggesting that the problem may be connected or previously addressed elsewhere.
- Number of comments this week: 1
- [BUG] jax.jit (XLA) produces wrong output: This issue reports that jax.jit (XLA) compilation produces incorrect numerical output compared to the expected results from PyTorch eager execution, with a significant relative L2 difference. The user provided an external Python script to demonstrate the discrepancy but did not include the code directly in the issue description.
- The comment questions the usefulness of opening multiple issues with only external attachments and requests that the reproducing code be included directly in the issue description for easier review. It also notes that some numerical differences are expected in floating point computations and that determining if the reported discrepancy is abnormal will require further investigation.
- Number of comments this week: 1
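Several of the reports above quantify the compiled-vs-eager mismatch as a relative L2 error. As a point of reference (the exact formula used by the reporters is an assumption), that metric is typically computed as:

```python
import numpy as np

def relative_l2_error(actual, reference):
    # ||actual - reference||_2 / ||reference||_2, with `reference`
    # taken as the trusted (eager-mode) output.
    actual = np.asarray(actual, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    return np.linalg.norm(actual - reference) / np.linalg.norm(reference)
```

By this measure, the ~4% and ~48% figures quoted above correspond to values of roughly 0.04 and 0.48.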
2.2 Top 5 Stale Issues:
We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.
As of our latest update, there are no stale issues for the project this week.
2.3 Open Issues
This section lists, groups, and then summarizes issues that were created within the last week in the repository.
Issues Opened This Week: 12
Summarized Issues:
- Numerical inconsistencies and non-determinism in JAX/JIT/XLA outputs: Multiple issues report that JAX's JIT compilation and XLA optimizations cause numerical inconsistencies and non-deterministic outputs compared to eager execution or expected results from PyTorch. These problems manifest in normalization, LayerNorm, fused operations, and near-zero input processing, leading to significant relative errors and reproducibility challenges.
- issues/36570, issues/36618, issues/36619, issues/36620, issues/36622, issues/36623, issues/36624, issues/36625
- Crashes and errors during JAX lowering and compilation phases: Some issues describe crashes and internal errors occurring during the lowering phase or TPU lowering in JAX, triggered by specific inputs or mesh configurations. These errors include AttributeErrors related to missing sharding information and unexpected Call operations that should have been eliminated earlier in compilation.
- issues/36607, issues/36691
- Performance bottlenecks due to device-to-host memory copies: One issue highlights a significant performance bottleneck where each device-to-host memcpy operation is followed by a 16 ms stall waiting for host callback execution, severely limiting throughput and scaling with memcpy size.
- issues/36651
- XLA verifier errors triggered by specific combined operations in parallel mapping contexts: An issue reports that combining jax.lax.top_k results inside a jnp.where within a pmap(vmap(...)) context causes an XLA HLO verifier error after the topk-decomposer pass, which does not occur in earlier JAX versions.
- issues/36703
2.4 Closed Issues
This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.
Issues Closed This Week: 4
Summarized Issues:
- Use-after-free bug in asynchronous compilation: Two issues report a use-after-free bug occurring when moving a cloned ModuleOp during asynchronous compilation due to improper ownership transfer. The suggested fix involves setting allow_in_place_mlir_modification to true and transferring ownership to HloProgram using OwningOpRef so that the cloned module cannot go out of scope prematurely.
- issues/36516, issues/36517
- vmap failure with sharded PRNG keys in mesh context: One issue describes a bug where using vmap with jax.random.categorical on sharded PRNG keys and logits inside an explicit mesh context fails because the abstract mesh context is empty inside vmap. The only current workaround is replicating inputs, which is memory inefficient.
- issues/36562
- Data corruption from FFI handlers with cudaMalloc in jax.lax.scan: One issue reports a bug where FFI handlers calling cudaMalloc or cudaMallocAsync cause closure-captured float64 constant buffers within jax.lax.scan to be zeroed out due to uninitialized device buffers. This leads to data corruption under specific conditions involving wrapped FFI functions and JIT compilation on CUDA devices.
- issues/36580
2.5 Issue Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.
III. Pull Requests
3.1 Open Pull Requests
This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Opened This Week: 15
Key Open Pull Requests
1. nn: add implementation='stable' to dot_product_attention: This pull request introduces a new implementation='stable' option for the dot_product_attention function in JAX. It replaces variable-length softmax reductions with a fixed-shape tiled scan using an online softmax recurrence, ensuring bit-exact deterministic outputs across different key-value cache lengths and fixing the floating-point non-determinism that previously caused variability in model predictions.
- URL: pull/36571
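The online softmax recurrence at the heart of this change can be sketched in plain NumPy (this is an illustrative single-query sketch, not the PR's actual tiled-scan implementation; the function name and tiling scheme are assumptions):

```python
import numpy as np

def online_softmax_attention(q, k, v, tile=4):
    # q: (d,), k: (n, d), v: (n, dv). Processes k/v in fixed-size
    # tiles while maintaining a running max, running softmax
    # denominator, and running (unnormalized) weighted sum.
    m = -np.inf                    # running max of attention scores
    l = 0.0                        # running softmax denominator
    o = np.zeros(v.shape[1])       # running unnormalized output
    for start in range(0, k.shape[0], tile):
        s = k[start:start + tile] @ q      # scores for this tile
        m_new = max(m, s.max())
        corr = np.exp(m - m_new)           # rescale earlier partials
        p = np.exp(s - m_new)
        l = l * corr + p.sum()
        o = o * corr + p @ v[start:start + tile]
        m = m_new
    return o / l
```

Because each tile only updates the running triple (max, denominator, weighted sum), the recurrence is mathematically equivalent to materializing the full softmax, which is what allows a fixed tile shape to replace a variable-length reduction.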
2. lru_cache: skip full dir scan in _evict_if_needed when cache is under budget: This pull request optimizes the _evict_if_needed function in the lru_cache by introducing a fast-path that skips a full directory scan when the cache is under budget, thereby avoiding costly atime reads on every put() call, and also adds error handling for missing atime files to prevent crashes during eviction.
- URL: pull/36708
3. Add explicit sharding support to cuDNN attention abstract eval: This pull request adds explicit sharding support to the cuDNN attention abstract evaluation in the JAX project.
- URL: pull/36539
- Associated Commits: 768a9
Other Open Pull Requests
- ROCm and Cross-Platform Compatibility: Multiple pull requests improve ROCm support and cross-platform functionality by fixing lowering bugs, aligning LSTM execution with cuDNN-style weight packing, and updating plugin download processes to simplify CI setup. These changes ensure better handling of ROCm-specific cases and improve performance and usability across different platforms.
- Performance Improvements in Linear Algebra: A pull request replaces the cusolver ormqr implementation with a faster pure JAX block Householder algorithm, achieving up to 10x speedups while maintaining accuracy. This includes detailed tuning and benchmarking against existing CUDA kernels.
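For context on what an ormqr-style routine computes: it applies an orthogonal factor Q, stored implicitly as a sequence of Householder reflectors, to another matrix without ever materializing Q. A minimal NumPy sketch of that operation (not the PR's blocked algorithm, and the function name is illustrative):

```python
import numpy as np

def apply_q(vs, taus, c):
    # Compute Q @ c, where Q = H_0 H_1 ... H_{k-1} and each reflector
    # is H_i = I - tau_i * v_i v_i^T. Reflectors are applied
    # right-to-left so Q itself is never formed.
    for v, tau in zip(vs[::-1], taus[::-1]):
        c = c - tau * np.outer(v, v @ c)
    return c
```

A blocked Householder variant groups several reflectors into one matrix-matrix product (the compact WY representation), trading a few extra flops for much better memory locality, which is the usual source of the kind of speedup reported here.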
- Bug Fixes and Correctness Enhancements: Several pull requests fix bugs such as preventing system hangs on multiple devices, correcting the cuDNN attention backward batcher for unbatched residuals, and fixing docstring errors for jax.lax.clz. These fixes improve stability and correctness in various JAX components.
- Error Handling and Type Checking Improvements: Pull requests introduce a new ProcessFailureError for better failure identification in live_devices, ensure cache verification errors are always raised, and replace broad type: ignore directives with specific Pyrefly suppressions to improve type checking precision.
- CI Security Enhancements: One pull request improves CI workflow security by adding guard conditions to prevent self-hosted runner jobs from running on pull requests from forked repositories, restricting these jobs to PRs from the same repository only.
3.2 Closed Pull Requests
This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Closed This Week: 27
Key Closed Pull Requests
1. Add complex-input support to jax.scipy.special.gamma: This pull request adds support for complex-valued inputs to the jax.scipy.special.gamma function by implementing a Lanczos approximation-based complex gamma function with reflection formula handling, safe masking to prevent NaN gradient contamination, and comprehensive tests to ensure parity with SciPy’s behavior, including correct handling of poles and branch cuts.
- URL: pull/36521
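The Lanczos approximation with reflection-formula handling mentioned above can be sketched in pure Python. This is a generic textbook implementation using a standard published coefficient set, not the PR's actual code:

```python
import cmath
import math

# Standard Lanczos coefficients for g = 7, n = 9.
_LANCZOS = [
    0.99999999999980993, 676.5203681218851, -1259.1392167224028,
    771.32342877765313, -176.61502916214059, 12.507343278686905,
    -0.13857109526572012, 9.9843695780195716e-6, 1.5056327351493116e-7,
]

def complex_gamma(z):
    """Gamma(z) for complex z via the Lanczos approximation."""
    z = complex(z)
    if z.real < 0.5:
        # Reflection formula: Gamma(z) Gamma(1-z) = pi / sin(pi z),
        # so the series is only evaluated in the stable half-plane.
        return math.pi / (cmath.sin(math.pi * z) * complex_gamma(1 - z))
    z -= 1
    x = _LANCZOS[0]
    for i, c in enumerate(_LANCZOS[1:], start=1):
        x += c / (z + i)
    t = z + 7.5
    return cmath.sqrt(2 * math.pi) * t ** (z + 0.5) * cmath.exp(-t) * x
```

The reflection step is also where poles (non-positive integers) and the safe-masking concerns from the PR arise, since sin(pi z) vanishes there.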
2. [ROCm] fix: wire up clone_main_xla for ROCm builds and tests: This pull request fixes the missing integration of the clone_main_xla workflow input in the ROCm CI pipeline by adding necessary environment variable mappings, sourcing required environment scripts, creating a dedicated build script for ROCm artifacts, and switching the Bazel ROCm CI to use the updated jax-dev container with dynamic digest resolution to ensure builds and tests correctly use the latest XLA code.
- URL: pull/36355
- Associated Commits: 661d3
3. Fix/remove JAX-Toolbox links: This pull request fixes and removes outdated JAX-Toolbox links, updating one that was relocated and removing another that became obsolete following the retirement of paxml-related files.
- URL: pull/36360
- Associated Commits: 845b1
Other Closed Pull Requests
- Type specificity improvements: Multiple pull requests improve type clarity by refining type aliases and adding missing type annotations. These changes replace broad Any aliases with more specific types and add forgotten annotations to enhance code safety and readability.
[pull/36512, pull/36601]
- Pyrefly dependency updates: Several pull requests upgrade the Pyrefly library to newer versions and fix related errors. These updates include version bumps to 0.59.1, 0.60.0, and 0.60.2, along with fixes or silencing of errors related to jnp.where.
[pull/36549, pull/36559, pull/36694]
- Array dtype conversion deprecation: Pull requests introduce a named deprecation for array dtype conversion and later update it to raise a TypeError. This prepares the codebase for upcoming NumPy 2.4 requirements by aligning behavior with future NumPy errors.
[pull/36535, pull/36596]
- Bazel ROCm CI and wheel management: Updates to the Bazel ROCm continuous integration workflows improve wheel usage by downloading from S3 or PyPI, adjusting permissions, and prioritizing these wheels. Additionally, auditwheel validation is updated to accept manylinux_2_28 wheels to fix ROCm plugin compatibility without changing build environments.
[pull/36522, pull/36621]
- JAX compilation and constants handling: A pull request fixes Ahead-Of-Time (AOT) compilation issues by extending simplified handling of closed-over constants to cover dead inputs. This fix only applies when a specific JAX flag is enabled, improving compilation correctness.
[pull/36525]
- API refactoring for IR constants and types: The mlir.ir_constant and aval_to_ir_type APIs are refactored to separate singular and plural forms, returning specific types instead of unions. This change enhances type checking and allows explicit declaration of expected return types at call sites.
[pull/36542]
- Documentation improvements: Multiple pull requests update documentation, including autodiff functionality for the eig() function and adding references to jax.lax.empty in the jax.numpy.empty docstring. Another fixes rendering issues and removes incorrect links in method docs.
[pull/36551, pull/36657, pull/36631]
- Key reuse and GPU allocator enhancements: One pull request adds key reuse rules for sharding primitives to improve key management, while another adds support for a new CUDA-only virtual memory management GPU allocator. The latter enables an environment variable option and documents it as experimental.
[pull/36569, pull/36579]
- Codebase cleanup and export fixes: A pull request fixes unintentional exports of private names by relocating files and adding shims to reexport only specific names. Another removes suppressions for unused-ignore directives previously needed for mypy.
[pull/36643, pull/36713]
- ROCm platform test adjustments: To address issues with the hipsparseSgtsv2 function on ROCm, tests are skipped when certain conditions cause NaN returns. A TODO notes re-enabling these tests once the underlying problem is fixed.
[pull/36644]
- Hijax module improvements: Several pull requests enhance the hijax module by adding support for effects in VJPHiPrimitive, integrating flattrees data structures into lowering paths, and adding logging functionality restricted to top-level access.
[pull/36657, pull/36681, pull/36687]
3.3 Pull Request Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.
IV. Contributors
4.1 Contributors
Active Contributors:
We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.
If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.
| Contributor | Commits | Pull Requests | Issues | Comments |
|---|---|---|---|---|
| jakevdp | 29 | 9 | 0 | 13 |
| superbobry | 22 | 6 | 0 | 2 |
| mattjj | 20 | 4 | 0 | 4 |
| kanglant | 12 | 0 | 0 | 0 |
| gnecula | 10 | 1 | 0 | 0 |
| hawkinsp | 5 | 1 | 0 | 3 |
| olupton | 5 | 2 | 0 | 0 |
| yashk2810 | 5 | 0 | 0 | 2 |
| tsrw2048 | 4 | 2 | 0 | 1 |
| psanal35 | 2 | 2 | 0 | 3 |