Weekly GitHub Report for Jax: February 01, 2026 - February 08, 2026 (15:57:29)
Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.
I. News
1.1 Recent Version Releases:
The current version of this repository is jax-v0.9.0.1
1.2 Version Information:
Released on February 3, 2026, JAX v0.9.0.1 is a patch update to v0.9.0 that integrates four specific pull requests from the OpenXLA repository, focusing on incremental improvements and bug fixes without introducing major new features.
II. Issues
2.1 Top 5 Active Issues:
We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.
- [ENHANCEMENT] Logging decisions awkward to control: This issue discusses the difficulty in controlling logging behavior within the JAX library, particularly how JAX adds its own logging handler that outputs logs to STDERR, which conflicts with application-level logging configurations. The user seeks clarification on the design decisions behind JAX's logging setup and expresses interest in contributing a fix to allow more conventional and flexible logging control.
- The comments reveal frustration with the current logging design, sharing a complex monkey-patch workaround to suppress unwanted STDERR logs, and discussing potential improvements such as conditionally skipping handler setup when the logging level is NOTSET; maintainers acknowledge the issue and are open to a PR that simplifies and improves logging behavior.
- Number of comments this week: 5
- [ENHANCEMENT] Make shard_map in_specs/out_specs definitions more flexible: This issue proposes making the shard_map's in_specs and out_specs definitions more flexible by allowing partition specifications that are agnostic to mesh axis naming conventions, enabling users to define shard mappings based on positional indices or a special sentinel value to represent all mesh axes. The goal is to improve library reusability and API ergonomics by supporting numerical partition specs like P(0) for the first mesh axis and P(-1) to shard across all mesh axes, with clear semantics and error handling for various mesh and input configurations.
- The comments discuss the usability challenges of tying PartitionSpec to mesh axis names and explore the proposed semantics of positional specs and the sentinel -1, clarifying that integers refer to mesh axes by position and -1 expands to all unmentioned axes; they agree on error cases and the idea of implementing this as a preprocessing step, with offers to help formalize the specification and develop tests or prototypes.
- Number of comments this week: 4
- [ENHANCEMENT] Support jax.lax.ragged_all_to_all on XLA:CPU for testing purposes: This issue requests support for the `jax.lax.ragged_all_to_all` operation on CPU devices within the XLA backend to enable testing and development of Mixture of Experts (MoE) models without requiring TPU hardware. The user highlights that CPU support would facilitate fast iteration and correctness testing by allowing emulation of larger device meshes through configurable CPU devices, which is currently unsupported and results in a runtime error.
- The comments discuss the rationale for CPU support focusing on correctness and testing rather than performance, consider possible fallback implementations or improved error messaging, and offer to help create minimal reproducer tests and documentation to guide development and clarify device support.
- Number of comments this week: 3
- [BUG] [XLA] bool to int32 conversion broken / inconsistent: This issue reports a problem with the conversion of boolean arrays to int32 in JAX, where the conversion appears to incorrectly interpret the underlying memory bytes, leading to inconsistent results. The user provides a minimal code example demonstrating that converting a numpy boolean array to a JAX array and then to int32 yields unexpected values, suggesting a bug in how JAX handles this conversion.
- The comments confirm that the root cause lies upstream in the XLA compiler, specifically with its implementation of the StableHLO convert operation, and the issue has been forwarded to the XLA project for resolution.
- Number of comments this week: 2
- [BUG] 'BatchTrace' object has no attribute 'cur_qdd': This issue reports an AttributeError encountered when using `jax.vmap` with a custom `Box.get` method, where the 'BatchTrace' object lacks the attribute 'cur_qdd'. The user is seeking guidance on the correct approach to implement batching rules for this method, noting that a simplistic patch seems to work but questioning if there is an official recommended solution.
- The comments discuss whether the issue should be closed, with one participant asserting it seems like a bug, and the original poster clarifying they were initially unsure if it was a bug or a knowledge gap, leading them to move the discussion to a different forum before deciding to reopen the issue.
- Number of comments this week: 2
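For the logging issue listed first in this section, one conventional way for an application to take over a library's log output is to detach the handlers attached to that library's logger and let records propagate to the application's own configuration. The sketch below uses only the standard `logging` module. The logger name `"jax"` and the helper name `take_over_logger` are assumptions for illustration; this is not the workaround shared in the issue thread.

```python
import io
import logging


def take_over_logger(name: str, level: int = logging.WARNING) -> logging.Logger:
    """Detach a library logger's own handlers so the application controls output.

    Hypothetical helper: `name` would be "jax" for this issue (an assumption).
    """
    logger = logging.getLogger(name)
    for handler in list(logger.handlers):
        logger.removeHandler(handler)  # drop e.g. a library-installed STDERR handler
    logger.propagate = True            # hand records to the application's handlers
    logger.setLevel(level)
    return logger


# Simulate a library that installed its own STDERR handler.
lib_logger = logging.getLogger("jax")
lib_logger.addHandler(logging.StreamHandler())

# Application-level configuration: one handler on the root logger.
app_stream = io.StringIO()
app_handler = logging.StreamHandler(app_stream)
logging.getLogger().addHandler(app_handler)

take_over_logger("jax", logging.WARNING)
lib_logger.warning("compilation cache miss")   # reaches the application handler
lib_logger.info("suppressed below WARNING")    # filtered out by the logger level
```

After the call, the only output path for the library's records is the handler the application installed, so level and destination are both under application control.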
2.2 Top 5 Stale Issues:
We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.
- [QUESTION] Unexpected host-to-device transfer when slicing: This issue reports that slicing an on-device array in JAX unexpectedly triggers a host-to-device data transfer, which is disallowed under certain transfer guard settings and results in an error. The user provides a minimal code example and traceback demonstrating the problem, highlighting that this behavior may be related to but distinct from previously reported issues.
- [ENHANCEMENT] Feature Request: Direct ONNX Exporter for JAX: This issue proposes the development of a direct exporter to convert JAX models to the ONNX format, aiming to simplify the current multi-step conversion process that involves TensorFlow as an intermediate. It highlights the challenges of mapping JAX's dynamic representations and operations to ONNX's static graph structure, and suggests community collaboration to explore feasible approaches for improving model interoperability and deployment efficiency.
- [BUG] `linalg.lstsq` incorrectly returns nans: This issue reports that the function `jax.numpy.linalg.lstsq` sometimes returns NaN values incorrectly when solving least squares problems, as demonstrated with a zero matrix example where the output contains NaNs instead of zeros. The issue contrasts this behavior with the correct results obtained using the pseudoinverse function `jax.numpy.linalg.pinv`, highlighting a potential bug in the `lstsq` implementation.
- [BUG] Documentation missing for `drop_fields` argument of `tree_util.register_dataclass`: This issue highlights the absence of documentation for the `drop_fields` argument in the `tree_util.register_dataclass` function within the JAX library. It points out the need for clear explanations to help users understand how to properly use this argument when working with dataclasses in JAX.
- [BUG] `check_tracer_leaks` sometimes raises `IndexError` when parent is a cell from a closure.: This issue describes a problem where the `check_tracer_leaks` function in JAX sometimes raises an `IndexError` instead of the expected exception about a leaked tracer, caused by the `parents(parent)` call returning an empty list when the parent is a cell from a closure. The reporter provides a minimal example demonstrating this behavior, along with debugging output and notes that further investigation is complicated by interactions between the debugger and the garbage collector.
2.3 Open Issues
This section lists, groups, and then summarizes issues that were created within the last week in the repository.
Issues Opened This Week: 11
Summarized Issues:
- Performance issues with data types and hardware: Several issues report performance degradations related to specific data types and hardware configurations. The "solve" function is extremely slow with bfloat16 on CUDA GPUs due to sensitivity to ill-conditioned samples, and argsort performs inconsistently worse with 64-bit integers compared to 32-bit, suggesting a need for optimized lowering strategies.
- issues/34748, issues/34912
- Attention and GPU optimization limitations: There is a request to support a head dimension of 256 in `jax.nn.dot_product_attention` optimized for A100 GPUs, as the current cuDNN flash attention backend lacks this support while PyTorch's Flash Attention2 does. This highlights gaps in GPU-specific optimizations for attention mechanisms.
- issues/34750
- Data type conversion and handling errors: Converting JAX numpy boolean arrays to int32 results in incorrect values due to improper memory handling, and treating strings as trivial pytrees is proposed to avoid TypeErrors during JIT tracing. These issues indicate challenges in consistent and correct data type conversions within JAX.
- issues/34751, issues/34873
- Sharding and partitioning flexibility and correctness: Proposals and bug reports focus on improving shard_map specs to be mesh-name-agnostic for better reusability, and fixing `custom_partitioning` callbacks that incorrectly convert mesh axis types to `AxisType.Auto`, causing sharding mismatches. These issues affect distributed computation correctness and usability.
- issues/34752, issues/34811
- Support for CPU operations and testing: There is a request to add support for `jax.lax.ragged_all_to_all` on XLA:CPU to enable testing and development of Mixture of Experts models on CPU devices, as current lack of support limits CPU-based multi-device simulation.
- issues/34755
- Batching and tracing errors: An AttributeError occurs when using `jax.vmap` with a custom `Box.get` method due to a missing 'cur_qdd' attribute in the 'BatchTrace' object, indicating difficulties in implementing correct batching rules for custom types.
- issues/34758
- Compiler and TPU kernel failures: The Mosaic compiler fails to compile a TPU kernel on TPU v6e-8 due to an invalid vector register cast from i16 to f16, causing a JaxRuntimeError during bitcast operations. This points to hardware-specific compilation issues.
- issues/34886
- Documentation and performance tuning: Documentation updates are requested to emphasize the impact of CPU count per task on multi-GPU communication performance, especially for low-data all-to-all operations, based on observed improvements with increased CPU allocation.
- issues/34896
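The mesh-name-agnostic shard_map specs proposed in the sharding item above could be handled in a preprocessing step that runs before any sharding API sees the spec. The following is a minimal plain-Python sketch of that idea, using the semantics summarized in the issue discussion: a non-negative integer names a mesh axis by position, and -1 expands to every axis not mentioned elsewhere in the spec. The function name `resolve_positional_spec` and the use of plain tuples in place of `PartitionSpec` are assumptions; none of this is existing JAX API.

```python
def resolve_positional_spec(spec, axis_names):
    """Translate positional entries in `spec` into mesh axis names.

    Hypothetical preprocessing helper (not part of JAX). Entries may be None,
    an axis name, a non-negative int (axis position), a tuple of those, or the
    top-level sentinel -1 (all axes not mentioned elsewhere in the spec).
    """
    axis_names = tuple(axis_names)

    def by_position(entry):
        if isinstance(entry, int):
            if not 0 <= entry < len(axis_names):
                raise ValueError(
                    f"axis position {entry} is out of range for mesh axes {axis_names}")
            return axis_names[entry]
        return entry  # None or an explicit axis name

    resolved = []
    for entry in spec:
        if isinstance(entry, int) and entry == -1:
            resolved.append(-1)  # sentinel, expanded once all other axes are known
        elif isinstance(entry, tuple):
            resolved.append(tuple(by_position(e) for e in entry))
        else:
            resolved.append(by_position(entry))

    if resolved.count(-1) > 1:
        raise ValueError("the -1 sentinel may appear at most once")

    # Collect every axis name that is already used, rejecting duplicates.
    mentioned = []
    for entry in resolved:
        names = entry if isinstance(entry, tuple) else (entry,)
        mentioned.extend(n for n in names if isinstance(n, str))
    if len(set(mentioned)) != len(mentioned):
        raise ValueError(f"mesh axis used more than once in {spec}")

    remaining = tuple(n for n in axis_names if n not in mentioned)
    return tuple(remaining if entry == -1 else entry for entry in resolved)


print(resolve_positional_spec((0, None), ("data", "model")))  # ('data', None)
print(resolve_positional_spec((-1,), ("data", "model")))      # (('data', 'model'),)
print(resolve_positional_spec((1, -1), ("x", "y", "z")))      # ('y', ('x', 'z'))
```

Keeping the resolution separate from the sharding call means a library can be written against positions while the caller's mesh supplies whatever axis names it likes.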
2.4 Closed Issues
This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.
Issues Closed This Week: 7
Summarized Issues:
- Indexing and Masking Behavior: This issue highlights a silent failure when indexing and updating a JAX array using a tuple of boolean values, where the operation produces no error but leaves the array unchanged. The user seeks clarification on the expected behavior and alternatives for using static boolean masks within JIT-compiled functions.
- issues/34127
- Type and Shape Handling Errors: An AttributeError occurs when using `jax.make_jaxpr` with `return_shape=True` on functions returning a HiType, due to the returned TupTy object lacking a 'shape' attribute. This indicates a gap in handling shape information for certain return types in JAX's internal representation.
- issues/34193
- Performance and Training Enhancements: A comprehensive pull request proposes a universal high-performance training suite for JAX, introducing cultural-aware precision adjustments, cost-efficient FP8 quantization, a novel SOAP-Lite optimizer, and automated sharding diagnostics. These improvements aim to enhance memory efficiency, training stability, and convergence speed for large language models and scientific AI.
- issues/34773
- Sharding and Compilation Errors: Several issues describe problems related to sharding and compilation: `lax.pcast` fails with an unbound axis name error inside `jax.jit` because it is intended for `jax.shard_map`, and explicit sharding introduces unimplemented reshards causing lowering errors in Mosaic GPU kernels when using instructions like `jnp.where`. These highlight challenges in correct sharding usage and kernel lowering in JAX.
- issues/34801, issues/34808
- Error Handling Improvements: The replacement of `assert False` statements with proper `ValueError` exceptions in the deserialization code improves error handling by preventing silent failures when assertions are disabled and providing more semantic error messages. This change enhances robustness and debuggability during export serialization.
- issues/34828
- JIT Compilation Output Changes: Unexpected behavior in JAX 0.9.0's JIT compiler causes a nested function for array chunking to output a JitTracer object instead of a concrete array as in version 0.8.2. This change affects performance and functionality, particularly for geometric applications relying on concrete array outputs.
- issues/34866
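The error-handling item above matters because `assert` statements are removed when Python runs with the `-O` flag, so an `assert False` guarding an unexpected value can silently fall through. A minimal sketch of the replacement pattern follows. The tag table and the function name `decode_kind` are hypothetical and are not taken from the JAX export code.

```python
_KINDS = {0: "array", 1: "token"}  # hypothetical tag table for illustration


def decode_kind(tag: int) -> str:
    """Map a serialized tag to a kind name, rejecting unknown tags explicitly."""
    try:
        return _KINDS[tag]
    except KeyError:
        # An explicit exception still fires under `python -O`, unlike `assert False`.
        raise ValueError(f"unknown kind tag in serialized data: {tag!r}") from None
```

Callers can then catch `ValueError` for corrupted input instead of relying on an `AssertionError` that may never be raised.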
2.5 Issue Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.
III. Pull Requests
3.1 Open Pull Requests
This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Opened This Week: 25
Key Open Pull Requests
1. Document that buffer donation is disabled under debug_nans mode: This pull request documents that buffer donation is disabled when the debug_nans mode is enabled by updating the jax_debug_nans config option help text and adding a new "Limitations" section to the buffer_donation.md documentation explaining the reasons for this behavior, its impact on compiled functions, how to restore buffer donation by clearing the function cache, and providing example code demonstrating the behavior.
- URL: pull/34789
2. Fix IndexError in raise_if_error when receiving invalid error_code: This pull request fixes an IndexError in the raise_if_error function by adding bounds checking and validation to handle invalid error_code values received from corrupted AOT serialization data, ensuring safer error handling and preventing out-of-bounds access in the _error_list.
- URL: pull/34788
3. [ROCm] Add ROCm backward compatibility test for lu_pivots_to_permutation: This pull request adds a backward compatibility test for the lu_pivots_to_permutation operation on the ROCm platform, ensuring that serialized MLIR modules from older JAX versions can still be correctly loaded and executed, and includes new test data along with removing the operation from the ignored targets list to enable coverage verification.
- URL: pull/34870
Other Open Pull Requests
- ROCm GPU Support Enhancements: Multiple pull requests enable and improve ROCm GPU support across various tests and functionalities, including deviceless AOT compile tests, LOBPCG eigenvalue solver tests, lax backend scipy sparse linear solver tests, and memory space export tests. These changes ensure that tests previously skipped on ROCm or all GPUs now run correctly on ROCm hardware, improving platform compatibility and test coverage.
- [pull/34893, pull/34768, pull/34774, pull/34802, pull/34884]
- Backward Compatibility Tests for ROCm: Several pull requests add or restore ROCm-specific backward compatibility test data and update unit tests for key linear algebra functions such as LU decomposition, GEQRF, hipsolver_gesvd, cholesky solver, and the threefry2x32 algorithm. These updates ensure stability and correctness of ROCm custom calls and MLIR modules across JAX versions without manual test exclusions.
- [pull/34822, pull/34829, pull/34862, pull/34869, pull/34875, pull/34894]
- Boolean Conversion Fix: A pull request fixes incorrect conversion of non-canonical boolean values to signed integers by canonicalizing booleans using `hlo.select()`. This change includes updated documentation and comprehensive tests to verify correct behavior across conversion paths and JIT compilation.
- [pull/34763]
- Convolution Function Docstring Improvements: One pull request improves the docstrings of JAX convolution functions by clarifying default dimension orderings and differences such as TensorFlow-style ordering in `conv_transpose()`. These clarifications help reduce user confusion without changing any functionality.
- [pull/34764]
- `jax.numpy.pad` Dictionary Support: A pull request extends the `jax.numpy.pad` function to accept a `pad_width` argument specified as a dictionary, aligning its behavior with NumPy.
- [pull/34782]
- `jax.jit` Type Hint Fix: One pull request improves the type hints for the `jax.jit` function to resolve a type checker issue discussed in issue #34697, implementing a minimal change.
- [pull/34783]
- Pyrefly Configuration and Error Reduction: A pull request introduces configuration changes and local fixes for pyrefly, reducing the number of reported errors from 3151 to 1673 by ignoring many files.
- [pull/34794]
- Memory Space Validation Loosening: One pull request loosens the default memory space validation rule to allow operations with mixed host and device memory spaces to proceed with a warning instead of an error. This enables memory optimization patterns involving mixed memory kinds such as host-offloaded optimizer states and device parameters.
- [pull/34825]
- Test Efficiency Improvement for Single Device: A pull request improves testing efficiency by skipping multi-device tests early when only a single device is available, facilitating easier full test suite runs on single-GPU systems.
- [pull/34884]
- Numerical Precision Issue Workaround on ROCm: One pull request addresses numerical precision issues in the hipSolver library on ROCm devices by adding a decorator to skip the failing `testEighTinyNorm` test in `tests/linalg_test.py`.
- [pull/34902]
- Duplicate Global Static Registry Fix: A pull request fixes the issue of duplicate global static registries in JAX by removing the use of the internal XLA `ffi_api.h` header.
- [pull/34908]
- Ruff Unused-Variable Rule Enablement: One pull request enables the unused-variable rule in Ruff's configuration by inspecting all existing unused variables and either removing them or prefixing their names with an underscore.
- [pull/34911]
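The boolean conversion fix listed above turns on the difference between widening a boolean's storage byte and selecting on its truth value. The pure-Python sketch below shows that difference on raw bytes. It only illustrates the idea and is not the lowering used in JAX or XLA; both function names are made up for this example.

```python
def widen_raw_bytes(raw: bytes) -> list[int]:
    """Naive conversion: reinterpret each storage byte as an integer value."""
    return [int(b) for b in raw]


def select_canonical(raw: bytes) -> list[int]:
    """Select-style conversion: any non-zero byte is True and maps to 1."""
    return [1 if b != 0 else 0 for b in raw]


# Bytes 0x02 and 0xFF are "true" but are not the canonical encoding 0x01.
storage = bytes([0, 1, 2, 255])
print(widen_raw_bytes(storage))   # [0, 1, 2, 255]
print(select_canonical(storage))  # [0, 1, 1, 1]
```

The two functions agree whenever every true value is stored as exactly 1, which is why the problem only shows up with non-canonical inputs.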
3.2 Closed Pull Requests
This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.
Pull Requests Closed This Week: 20
Key Closed Pull Requests
1. [ROCm] Add support to dynamically set test deps for jax as an external repo: This pull request adds the capability to dynamically inject test dependencies when including JAX as an external repository, enabling projects like ROCm to initialize and manage their specific test dependencies cleanly without modifying the core JAX codebase.
- URL: pull/34641
- Associated Commits: 1f4c3, c6ad6, 2b34a, f7894, a8b28, 87707, e615e, d02fc, a50a4, 0a2f4, b2be9, 6cb71, 40462, a58e5, d583f, 4441a, d1519, d2154, 0c856
2. [ROCm] Upstream ROCm CI Nightly Wheel Testing: This pull request adds a comprehensive CI/CD infrastructure for automated nightly testing of JAX with ROCm support, including a reusable ROCm pytest workflow, multi-GPU and multi-Python version test matrices, robust wheel downloading actions, environment configuration, and integration into existing nightly release workflows to ensure thorough validation of JAX on AMD GPU platforms.
- URL: pull/34450
3. [ROCm] Add support of building jax under umbrella workspace: This pull request aims to add support for building JAX targets and executing JAX unit tests within an umbrella workspace, facilitating the creation of jaxlib and simplifying the testing and integration of ROCm plugins with JAX in projects that depend on JAX.
- URL: pull/34462
Other Closed Pull Requests
- ROCm JAX Plugin Versioning and Integration: These pull requests relax strict version requirements for ROCm JAX Plugin wheels and add support for building JAX targets and running tests within a parent workspace using Bazel. They enable easier integration and testing of ROCm plugins with JAX while allowing users to override version checks if needed.
[pull/32115, pull/34770]
- Pallas and Mosaic Backend Enhancements: Multiple pull requests improve the Pallas module by moving `jax.experimental.pallas.ds` to `jax.ds` and adding dynamic slicing support in `jax.numpy` indexing. Additionally, initial ROCm platform support is introduced in the Mosaic backend with vendor-agnostic code changes and a new reshard rule for the mosaic GPU.
[pull/34734, pull/34759, pull/34809]
- Dynamic and Static Slice Refactoring: A pull request refactors `to_static_slice` and `to_dynamic_slice` functions by separating indexing recipe generation from execution to improve safety, efficiency, and code reuse. This change also prepares for future integration with pallas/ref indexing by removing redundant logic.
[pull/34776]
- Test and Architecture Support Adjustments: One pull request enables previously skipped tcgen05 tests by disabling NVVM verification for certain architectures and adding targeted test skips for unsupported cases on sm_103a. Another pull request titled "Test" was closed without merging and contains no commits.
[pull/34682, pull/34780]
- Error Handling and Warnings Improvements: Several pull requests address error handling by fixing an AttributeError in `jax.make_jaxpr` with HiType outputs, replacing `assert False` with `ValueError` in serialization code, and adding a DeprecationWarning for unsafe integer overflow in `lax.convert_element_type`.
[pull/34791, pull/34827, pull/34797]
- MLIR Type and Export Schema Enhancements: One pull request proposes using the built-in `isinstance` function with MLIR types for better type checking, while another adds a file identifier and extension to the export schema to improve file handling during export processes.
[pull/28973, pull/34075]
- Code Quality and Documentation Fixes: Pull requests fix duplicate word typos in documentation and resolve linting issues to maintain code quality at the current repository head.
[pull/34790, pull/34798]
- Mosaic GPU Collective Communication Fix: A pull request fixes the passing of correct device groups when requesting a collective clique in the mosaic GPU module to prevent inefficient resource allocation by NCCL.
[pull/34848]
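The deprecation listed above for unsafe integer overflow can be pictured as a range check that warns before a value is wrapped into a narrower integer type. This is a plain-Python sketch under that assumption. `to_int32_with_warning` is a hypothetical helper, and the actual check and message in `lax.convert_element_type` may differ.

```python
import warnings

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1


def to_int32_with_warning(value: int) -> int:
    """Wrap a Python int into the int32 range, warning when it does not fit."""
    if not INT32_MIN <= value <= INT32_MAX:
        warnings.warn(
            f"{value} overflows int32 and will wrap; this may become an error",
            DeprecationWarning,
            stacklevel=2,
        )
    # Two's-complement wraparound into [-2**31, 2**31 - 1].
    return (value + 2**31) % 2**32 - 2**31
```

Using a warning first gives callers a release cycle to find overflowing conversions before the behavior changes.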
3.3 Pull Request Discussion Insights
This section will analyze the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.
Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.
IV. Contributors
4.1 Contributors
Active Contributors:
We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.
If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.
| Contributor | Commits | Pull Requests | Issues | Comments |
|---|---|---|---|---|
| jakevdp | 28 | 25 | 0 | 104 |
| alekstheod | 29 | 4 | 0 | 24 |
| JehandadKhan | 2 | 0 | 46 | 9 |
| hrideymarwah15 | 25 | 11 | 1 | 3 |
| samanklesaria | 27 | 3 | 0 | 1 |
| magaonka-amd | 14 | 12 | 0 | 4 |
| mattjj | 13 | 10 | 0 | 7 |
| tsrw2048 | 12 | 10 | 0 | 7 |
| abdulwahabahmedkhanyusufzai | 19 | 1 | 0 | 8 |
| AratiGanesh | 12 | 11 | 0 | 1 |