Weekly Project News


Weekly GitHub Report for JAX: March 30, 2026 - April 06, 2026 (18:22:34)

Weekly GitHub Report for JAX

Thank you for subscribing to our weekly newsletter! Each week, we deliver a comprehensive summary of your GitHub project's latest activity right to your inbox, including an overview of your project's issues, pull requests, contributors, and commit activity.


Table of Contents

  • I. News
    • 1.1. Recent Version Releases
    • 1.2. Version Information
  • II. Issues
    • 2.1. Top 5 Active Issues
    • 2.2. Top 5 Stale Issues
    • 2.3. Open Issues
    • 2.4. Closed Issues
    • 2.5. Issue Discussion Insights
  • III. Pull Requests
    • 3.1. Open Pull Requests
    • 3.2. Closed Pull Requests
    • 3.3. Pull Request Discussion Insights
  • IV. Contributors
    • 4.1. Contributors

I. News

1.1 Recent Version Releases:

The current version of this repository is jax-v0.9.0.1.

1.2 Version Information:

On February 3, 2026, JAX v0.9.0.1 was released as a patch update to v0.9.0, incorporating four specific pull requests from the OpenXLA repository to address targeted improvements without introducing major changes. This release highlights a focus on incremental fixes and refinements.

Click here to view the full release notes!

II. Issues

2.1 Top 5 Active Issues:

We consider active issues to be issues that have been commented on most frequently within the last week. Bot comments are omitted.

  1. [BUG] Concatenating a primal and tangent: This issue discusses a problem in JAX where concatenating primal and tangent values inside a custom_jvp function leads to an uninformative error, complicating the use of vectorized operations that combine these two types of values. The user requests either support for such concatenation to simplify code and improve performance or at least a clearer error message explaining why this operation fails, as the current error message is vague and does not clarify the underlying autodiff limitations.

    • The comments clarify that concatenating primal and tangent values is problematic due to JAX's autodiff heuristic that operations on tangent values yield tangent values, making nonlinear operations on tangents unsupported; a minimal reproducible example is provided, and while a better error message is agreed upon as helpful, fundamentally changing JAX to support this use case would require significant redesign and added complexity.
    • Number of comments this week: 11
  2. [BUG] JIT segfault: conv + reshape + matmul backward on gfx1100 (ROCm 7.2): This issue reports a segmentation fault occurring during the backward pass compilation of a JIT-compiled function in JAX that involves a convolution followed by a reshape and a matrix multiplication on AMD Radeon RX 7900 XTX GPUs using ROCm 7.2. The crash happens specifically during XLA compilation of the gradient computation and does not occur in eager mode or forward-only JIT, indicating a problem with the fusion of operations in the backward pass on the gfx1100 architecture.

    • The comments reference related issues in the ROCm JAX repository that may be connected to this problem, providing additional context and cross-links for further investigation.
    • Number of comments this week: 2
  3. Mosaic TPU compile failure in SplashAttention: failed to legalize arith.cmpi on vector i8 compare: This issue reports a compilation failure when running the SplashAttention kernel test on a TPU using Mosaic, specifically due to an inability to legalize an 'arith.cmpi' operation comparing vector i8 types during kernel compilation. The user is seeking guidance on how to resolve this internal compiler error that arises from the JAX Mosaic backend when handling int8 vector comparisons.

    • The comment requests the user's JAX version and suggests enabling mosaic dumps for further diagnosis, hypothesizing that the error stems from an int8 comparison operation and recommending tracing and possibly casting the comparison to int32 to avoid the compilation failure.
    • Number of comments this week: 1
  4. [BUG] The RegularGridInterpolator requires that the points value be strictly increasing, not merely monotonic, not specified in docs and differs from scipy: This issue highlights that the RegularGridInterpolator function in JAX requires its input points to be strictly increasing rather than just monotonic, which is a deviation from SciPy's implementation and is not documented. The user reports that reversing the order of axis values causes the interpolator to return NaNs, and suggests either updating the documentation or modifying the function to handle monotonic inputs, offering to contribute a fix.

    • A single comment expresses interest in working on the issue and requests to be assigned to it.
    • Number of comments this week: 1

Since there were fewer than 5 active issues this week, all of them have been listed above.
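The primal/tangent limitation behind the first issue can be illustrated with a minimal custom_jvp rule (a hedged sketch; the function and names are illustrative, not taken from the issue thread):

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # Linear operations on the tangent are fine: the tangent rule
    # below is linear in x_dot, which is what autodiff expects.
    out_tangent = jnp.cos(x) * x_dot
    # Concatenating the primal with the tangent and applying a
    # nonlinear op to the combined array is the pattern the issue
    # reports as failing with an uninformative error, because JAX's
    # heuristic treats any result touching a tangent as a tangent.
    return f(x), out_tangent

primal_out, tangent_out = jax.jvp(
    f, (jnp.array([0.0, 1.0]),), (jnp.ones(2),)
)
```

Keeping the tangent computation linear in `x_dot`, as above, stays within what JAX's autodiff supports today.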

2.2 Top 5 Stale Issues:

We consider stale issues to be issues that have had no activity within the last 30 days. The team should work together to get these issues resolved and closed as soon as possible.

As of our latest update, there are no stale issues for the project this week.

2.3 Open Issues

This section lists, groups, and then summarizes issues that were created within the last week in the repository.

Issues Opened This Week: 7

Summarized Issues:

  • Compilation and Runtime Failures on Specialized Hardware: Several issues describe failures related to compilation or runtime on specific hardware platforms. These include a TPU compilation failure due to the Mosaic compiler's inability to legalize an arithmetic comparison on 8-bit integer vectors, and a segmentation fault during backward pass compilation on AMD Radeon RX 7900 XTX GPUs triggered by XLA's fusion of operations on the gfx1100 architecture.
  • issues/36324, issues/36490
  • Ownership and Use-After-Free Bugs in Asynchronous Compilation: Two issues report use-after-free bugs caused by improper ownership transfer when moving cloned ModuleOps during asynchronous compilation. Both suggest enabling allow_in_place_mlir_modification and transferring ownership with OwningOpRef to prevent premature module scope loss.
  • issues/36516, issues/36517
  • Performance and Dispatch Strategy Concerns: One issue highlights unexpected performance degradation when using vmap for multiplying tall skinny matrices on NVIDIA A100 GPUs, hypothesizing that CuBLAS's cublasDgemmStridedBatched may be suboptimal. It suggests that XLA could benefit from heuristics to select alternative dispatch strategies or batched GEMM implementations.
  • issues/36374
  • Functionality and Documentation Gaps in Interpolator Behavior: An issue reports that JAX's RegularGridInterpolator requires strictly increasing input points rather than merely monotonic, differing from SciPy's behavior and causing NaNs when axis values are reversed. This undocumented requirement suggests a need for internal handling improvements or documentation updates.
  • issues/36499
  • Enhancements for Fault-Tolerant Training Diagnostics: One issue proposes improving the live_devices function to include failed process IDs in exceptions, enabling users to better identify and respond to specific process failures during fault-tolerant training loops.
  • issues/36372
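For the interpolator issue above, a workaround consistent with JAX's current strictly-increasing requirement is to flip a decreasing axis (and the corresponding values) before constructing the interpolator. This is a sketch of that idea, not code from the issue:

```python
import jax.numpy as jnp
from jax.scipy.interpolate import RegularGridInterpolator

# Axis values in decreasing order, the case the issue reports as
# returning NaNs.
x_desc = jnp.array([3.0, 2.0, 1.0, 0.0])
values = x_desc ** 2

# Workaround sketch: reverse the axis and the values so the grid
# points are strictly increasing, as JAX currently requires.
x_asc = x_desc[::-1]
interp = RegularGridInterpolator((x_asc,), values[::-1])

# Linear interpolation of x**2 between the grid points 1.0 and 2.0.
result = interp(jnp.array([[1.5]]))
```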

2.4 Closed Issues

This section lists, groups, and then summarizes issues that were closed within the last week in the repository. This section also links the associated pull requests if applicable.

Issues Closed This Week: 4

Summarized Issues:

  • TPU-specific numerical errors in attention kernels: The jax.nn.dot_product_attention function produces silently incorrect numerical results on TPU v6e hardware across all tested data types, shapes, and implementation modes, causing models to train or infer incorrectly without any errors or warnings. The root cause is a fused attention kernel in the XLA compilation path that fails specifically on this TPU version.
  • issues/35916
  • GEMM fusion crashes due to dimension mismatches: JAX crashes when the NestGemmFusion pass fails because of a symbolic map dimension mismatch (3 vs. 4) during fusion of Dense GEMMs with 4D einsum operations involving reshapes. This issue is triggered by models using multiple Dense layers and einsum contractions, with a known workaround being to disable Triton GEMM fusion.
  • issues/36095
  • Errors with varying inputs in jnp.full_like and shard_map: Using jnp.full_like combined with shard_map fails when both the input array and fill value are varying, causing an error due to an incorrect attempt to re-cast the result as varying. The issue might be resolved by removing the final cast operation.
  • issues/36365
  • Runtime errors in jax.numpy.diff with append on sharded arrays: Calling jax.numpy.diff with the append parameter on sharded arrays causes a runtime error due to an internal NCCL operation failure during multi-GPU execution. This indicates a problem with handling appended values in distributed array computations.
  • issues/36439
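The sharded-array failure above exercises the `append` parameter of jax.numpy.diff; on a single device the semantics are straightforward. A sketch of the API being exercised, not a reproduction of the NCCL failure:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 3.0, 6.0])
# `append` extends the array before differencing, so this computes
# the diff of [1, 3, 6, 10].
d = jnp.diff(x, append=10.0)
```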

2.5 Issue Discussion Insights

This section will analyze the tone and sentiment of discussions within this project's open and closed issues that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed issues from the past week.


III. Pull Requests

3.1 Open Pull Requests

This section provides a summary of pull requests that were opened in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Opened This Week: 15

Key Open Pull Requests

1. bisect an oss bazel cuda failure: This pull request aims to bisect and identify the cause of a Bazel CUDA failure in the open-source JAX project by testing various commits and adjusting the checkout process to isolate the problematic change.

  • URL: pull/36348
  • Associated Commits: 425f2, 216c1, 242c0, a489e, cbb82, e4a81, 386b6, a7aab, d539c, b73c2

2. Add distributed MNIST training example for Kubernetes: This pull request adds a comprehensive end-to-end example demonstrating distributed MNIST training using JAX on TPUs within a Kubernetes environment, featuring data-parallel training with gradient averaging via jax.lax.pmean, zero-dependency data loading from GCS, and orchestration through a Kubernetes JobSet, along with supporting files such as a Dockerfile, RBAC configurations, and detailed setup documentation.

  • URL: pull/36523
  • Associated Commits: 1aa76, 1174f, 7caa1, 0c239, 23c5e

3. Add examples/pca_from_scratch.py demonstrating SVD, jit, and vmap: This pull request adds a standalone example script implementing Principal Component Analysis (PCA) from scratch in JAX, showcasing efficient use of jnp.linalg.svd for matrix decomposition, jax.jit with static argument handling to compile the PCA function, and jax.vmap for vectorizing batch processing, along with structured synthetic data generation and optional plotting to demonstrate variance capture.

  • URL: pull/36381
  • Associated Commits: 3f7f1, df2d9
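The PCA pull request combines jnp.linalg.svd, jax.jit with a static argument, and jax.vmap. A minimal sketch of that pattern (illustrative only; this is not the PR's actual code, and the function names are assumptions):

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames="n_components")
def pca(x, n_components):
    # Center the data; the principal axes are the rows of Vt.
    x_centered = x - x.mean(axis=0)
    _, _, vt = jnp.linalg.svd(x_centered, full_matrices=False)
    # Project onto the leading principal components.
    return x_centered @ vt[:n_components].T

key = jax.random.PRNGKey(0)
batch = jax.random.normal(key, (4, 50, 6))  # batch of 4 datasets

# vmap vectorizes the jitted PCA over the leading batch dimension.
scores = jax.vmap(lambda x: pca(x, n_components=2))(batch)
```

Marking `n_components` static lets the slice shape be fixed at trace time, which is why the PR description mentions static argument handling.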

Other Open Pull Requests

  • Profiling and Logging Enhancements: These pull requests improve profiling and internal logging capabilities in JAX. One enables profiling logs to be generated directly from the command line using the JAX_PROFILE environment variable, while the other adds internal logging within the hijax module to facilitate value extraction from JAX functions, especially those involving scan operations on the backward pass.
    • pull/36398, pull/36494
  • Hardware and CI Pipeline Fixes: These pull requests address hardware-specific issues and improve continuous integration workflows. One fixes timeout errors on H100 hardware by bisecting and resolving conditional logic issues, while others fix missing integration in the ROCm CI pipeline and update Bazel ROCm workflows to properly use JAX and Jaxlib wheels with improved environment setup and build scripts.
    • pull/36501, pull/36355, pull/36522
  • Mathematical Function and Operation Improvements: These pull requests enhance mathematical functionality and performance in JAX. One adds support for complex-valued inputs to jax.scipy.special.gamma with a Lanczos approximation and safe masking, and another introduces a new Jacobian-vector product rule for qr_multiply to improve performance and numerical stability without explicitly materializing the Q matrix.
    • pull/36521, pull/36357
  • Data Handling and Type Improvements: These pull requests fix data loss issues and improve type definitions. One resolves data loss on secondary devices during device_put operations of sharded arrays, and another updates the codebase to remove IrValues and IrTypes aliases of the Any type, enhancing type specificity and clarity.
    • pull/36524, pull/36512
  • Documentation and Link Maintenance: These pull requests update documentation and external links. One adds a new tutorial on writing high-performance GPU kernels with CuTe DSL and JAX, including a notebook and helper module, while another updates JAX-Toolbox links by fixing relocated links and removing obsolete ones after retiring paxml-related files.
    • pull/36496, pull/36360
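The gamma-function pull request above targets complex-valued inputs; for real inputs the existing log-gamma path already works. A small sketch of the current real-valued route (an assumption about usage, not code from the PR):

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln

# For positive reals, gamma(x) can be recovered from log-gamma.
x = jnp.array([1.0, 2.0, 5.0])
vals = jnp.exp(gammaln(x))  # gamma(1)=1, gamma(2)=1, gamma(5)=24
```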

3.2 Closed Pull Requests

This section provides a summary of pull requests that were closed in the repository over the past week. The top three pull requests with the highest number of commits are highlighted as 'key' pull requests. Other pull requests are grouped based on similar characteristics for easier analysis. Up to 25 pull requests are displayed in this section, while any remaining pull requests beyond this limit are omitted for brevity.

Pull Requests Closed This Week: 40

Key Closed Pull Requests

1. [ROCm] bring gesdd for computing SVD on ROCm: This pull request introduces the use of the gesdd algorithm for computing singular value decomposition (SVD) on ROCm platforms, replacing the previous default gesvdj with gesdd due to its significantly improved performance and lower latency across various data types and matrix sizes, as demonstrated by benchmarking on MI250 hardware with ROCm 7.1.0.

  • URL: pull/35534
  • Associated Commits: 6cc09, cc4fc, 4c744, 4af6f, 6aee7, 23f7d, ee706, 40e9c, ea749, c514f, c02b5

2. Fix cuDNN attention backward pass partitioning with sharded inputs: This pull request fixes the cuDNN attention backward pass by updating the SPMD partitioning logic to ensure that the gradient output and forward output tensors have matching sharding with the query input, thereby resolving a runtime error when using sharded inputs and adding a regression test to verify the fix.

  • URL: pull/36169
  • Associated Commits: 8ff3e, d42e1

3. [shard_map] Raise TypeError for eager new_ref tracing: This pull request aims to add a safety check in ShardMapTrace that raises a clear TypeError when stateful operations like new_ref are attempted in eager mode, replacing a previously confusing NotImplementedError and guiding users to use @jax.jit instead.

  • URL: pull/36338
  • Associated Commits: 14a24, d3c9c

Other Closed Pull Requests

  • Author and Citation Updates: This topic includes pull requests that update the author list and citation files to reflect recent changes in the JAX project. These updates ensure that the project's metadata remains current and accurate.
    [pull/36506]
  • Primitive and Function Enhancements: Multiple pull requests introduce new arguments and simplify existing functions such as adding a preferred_element_type argument to lax.mul and simplifying lax.scan by removing parameters and changing its implementation. These changes improve functionality and maintainability of core JAX primitives.
    [pull/36092, pull/36429]
  • Deprecation and Cleanup: Several pull requests focus on cleaning up deprecated features and obsolete parameters, including removing deprecated mlir.custom_call symbols and obsolete parameters like linear and split_transpose from the scan function. These efforts help maintain a clean and modern codebase.
    [pull/36330, pull/36344, pull/36346]
  • Serialization and Compatibility: This group of pull requests adds serialization version 10 and related tests to optimize export serialization by avoiding duplicate serialization of abstract meshes and shardings. They also ensure backward compatibility and prepare for forward compatibility with new serialization formats.
    [pull/36354, pull/36375]
  • Sharding Documentation and Fixes: Pull requests in this topic add new documentation for sharding and fix dangling references and documentation errors related to sharding in the JAX project. These improvements enhance the clarity and accuracy of project documentation.
    [pull/36402, pull/36409, pull/36428]
  • Bug Fixes and Debugging: This topic covers pull requests that fix bugs such as the interaction between dropvars and remat-of-shmap in shard-map functionality and restore original SVD algorithms on ROCm due to test failures. It also includes an unmerged attempt to debug ReadTheDocs build issues.
    [pull/36419, pull/36422, pull/36426]
  • Testing and Validation Improvements: Pull requests here update tests to validate collective semantics according to HLO v3 standards and add backwards compatibility tests for multiple meshes. These changes improve test robustness and compatibility with evolving standards.
    [pull/36353, pull/36362]
  • Codebase Consistency and Refactoring: This includes pull requests that replace identifiers for clarity, remove unnecessary suppressions, and update dependencies like protobuf and uv pin versions. These changes contribute to code consistency and dependency management.
    [pull/36304, pull/36311, pull/36320, pull/36330, pull/36350, pull/36373, pull/36384]
  • Cross-compilation and Auto-tuning: A pull request enables the use of a real local client with a compatible backend during cross-compilation to facilitate auto-tuning at compile time, linking to related changes in the XLA project. This enhances performance optimization workflows.
    [pull/36341]
  • CI and Lint Workflow Updates: This pull request updates the continuous integration configuration to break the lint environment cache daily, ensuring that non-pinned dependencies are refreshed regularly to pick up new versions. This helps maintain up-to-date linting environments.
    [pull/36378]
  • Feature Disabling for Testing: One pull request proposes disabling the command_buffer feature on ROCm specifically for pallas tests to address issues or improve test performance. This targeted disabling helps stabilize testing environments.
    [pull/36304]
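Several of the pull requests above touch jax.lax.scan internals (removing the `linear` and `split_transpose` parameters and simplifying the implementation); user-facing usage is unchanged. A minimal cumulative-sum sketch of the scan API:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    new_carry = carry + x
    # scan collects the second return value at every step.
    return new_carry, new_carry

total, cumsum = jax.lax.scan(step, 0.0, jnp.arange(1.0, 5.0))
# total is the final carry; cumsum is the running sum at each step.
```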

3.3 Pull Request Discussion Insights

This section will analyze the tone and sentiment of discussions within this project's open and closed pull requests that occurred within the past week. It aims to identify potentially heated exchanges and to maintain a constructive project environment.

Based on our analysis, there are no instances of toxic discussions in the project's open or closed pull requests from the past week.


IV. Contributors

4.1 Contributors

Active Contributors:

We consider an active contributor in this project to be any contributor who has made at least 1 commit, opened at least 1 issue, created at least 1 pull request, or made more than 2 comments in the last month.

If there are more than 10 active contributors, the list is truncated to the top 10 based on contribution metrics for better clarity.

Contributor      Commits  Pull Requests  Issues  Comments
jakevdp               33              8       0        27
superbobry            22              5       0         0
mattjj                21              6       0         0
magaonka-amd          18              1       0         3
gnecula               13              3       0         5
ahmedtaha100           3              3       0        13
yashk2810              5              5       0         5
kanglant              12              2       0         0
gulsumgudukbay        12              0       0         0
mminutoli              8              1       0         2
