
Conversation


@sputti-czi sputti-czi commented Oct 19, 2025

Diffusion Bottlenecks:

This commit addresses a major performance bottleneck in the diffusion function, which was iterating over each graph in the batch (B) using a Python for loop. This approach failed to leverage the GPU's parallelism and incurred significant overhead by launching many small, sequential operations.
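For context, the replaced pattern looked roughly like the sketch below (the helper name diffuse_per_graph and the diffuse_one callable are placeholders standing in for the old chebyshev_diffusion_per_sample call):

```python
import torch

def diffuse_per_graph(E, edge_index_list, num_nodes_list, diffuse_one):
    """Sketch of the replaced pattern: one small, sequential diffusion call per
    graph in the batch, which under-utilizes the GPU and pays per-call launch
    overhead.

    E:               [B, S, H, D] padded embeddings
    edge_index_list: list of B [2, E_b] edge index tensors
    num_nodes_list:  list of B node counts (each <= S)
    diffuse_one:     callable(E_b, edge_index_b) -> diffused E_b, standing in
                     for the old chebyshev_diffusion_per_sample
    """
    out = E.clone()
    for b, (edge_index, n) in enumerate(zip(edge_index_list, num_nodes_list)):
        out[b, :n] = diffuse_one(out[b, :n], edge_index)  # one small op per graph
    return out
```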

The Chebyshev approximation path (else block in diffusion) has been refactored to process the entire batch at once. This is achieved through a new "unpad-batch-diffuse-repad" strategy (a combined sketch follows the four steps below):

Unpad: The padded embedding tensor E (shape [B, S, H, D]) is converted into an unpadded tensor unpadded_E (shape [total_nodes, H, D]) using a boolean mask derived from the num_nodes_list.

Batch: All edge_index_list tensors are collated into a single, large, disconnected graph using manual offset-based batching (avoiding PyG Batch.from_data_list overhead). Each graph's edge indices are offset by the cumulative sum of previous node counts, then concatenated into a single batched_edge_index.

Diffuse: The function chebyshev_diffusion_per_sample has been renamed to _chebyshev_diffusion_batch and is now called only once on the entire batched graph and the unpadded_E tensor.

Repad: The resulting diffused_unpadded_E tensor is "scattered" back into a new, zeroed tensor of the original [B, S, H, D] shape using the same boolean mask.
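Putting the four steps together, a minimal, self-contained sketch of the flow (the helper name unpad_batch_diffuse_repad and the diffuse_fn callable are illustrative; in the actual code the diffusion step calls _chebyshev_diffusion_batch):

```python
import torch

def unpad_batch_diffuse_repad(E, edge_index_list, num_nodes_list, diffuse_fn):
    """Minimal sketch of the unpad-batch-diffuse-repad flow.

    E:               [B, S, H, D] padded embeddings
    edge_index_list: list of B [2, E_b] edge index tensors (per-graph node indices)
    num_nodes_list:  list of B node counts (each <= S)
    diffuse_fn:      callable(batched_edge_index, unpadded_E) -> diffused unpadded_E
    """
    B, S, H, D = E.shape
    device = E.device

    # Unpad: boolean mask [B, S] marking the real (non-padded) nodes.
    num_nodes = torch.as_tensor(num_nodes_list, device=device)
    mask = torch.arange(S, device=device)[None, :] < num_nodes[:, None]
    unpadded_E = E[mask]                                   # [total_nodes, H, D]

    # Batch: offset each graph's edge indices by the cumulative node count,
    # then concatenate into one large disconnected graph.
    offsets = torch.cumsum(num_nodes, dim=0) - num_nodes   # exclusive prefix sum
    batched_edge_index = torch.cat(
        [ei.to(device) + off for ei, off in zip(edge_index_list, offsets)], dim=1
    )

    # Diffuse: a single call over the entire batched graph.
    diffused = diffuse_fn(batched_edge_index, unpadded_E)

    # Repad: scatter the diffused nodes back into a zeroed [B, S, H, D] tensor.
    out = torch.zeros_like(E)
    out[mask] = diffused
    return out
```

The boolean mask derived from num_nodes_list is reused for both the unpad and the repad step, so no per-graph Python indexing remains on the hot path.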

Additional Optimizations:

CSR Sparse Format: The Laplacian matrix is converted from COO (Coordinate) to CSR (Compressed Sparse Row) format, providing 2-5x speedup in sparse matrix multiplications due to better memory access patterns and cache locality.
Chebyshev Coefficient Caching: Coefficients are computed once and cached per (K, beta, device) combination, eliminating redundant computation across forward passes.
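As a rough illustration of these two optimizations (the Laplacian construction and the cache key/structure below are assumptions made for the sketch, not the exact code in this PR):

```python
import torch

def build_laplacian_csr(edge_index, num_nodes):
    """Sketch: assemble a symmetrically normalized Laplacian (I - D^-1/2 A D^-1/2)
    in COO form, then convert to CSR so repeated sparse-dense products benefit
    from better memory access patterns and cache locality."""
    device = edge_index.device
    row, col = edge_index
    deg = torch.bincount(row, minlength=num_nodes).clamp(min=1).float()
    norm = deg.rsqrt()
    off_diag_vals = -norm[row] * norm[col]
    diag_idx = torch.arange(num_nodes, device=device).repeat(2, 1)
    indices = torch.cat([edge_index, diag_idx], dim=1)
    values = torch.cat([off_diag_vals, torch.ones(num_nodes, device=device)])
    L = torch.sparse_coo_tensor(indices, values, (num_nodes, num_nodes)).coalesce()
    return L.to_sparse_csr()

# Cache keyed on (K, beta, device); coefficients are reused across forward passes.
_cheb_coeff_cache = {}

def chebyshev_coefficients(K, beta, device):
    """Sketch of per-(K, beta, device) coefficient caching. The placeholder
    exp(-beta * k) weights stand in for whatever kernel the model approximates."""
    key = (K, float(beta), str(device))
    if key not in _cheb_coeff_cache:
        k = torch.arange(K, device=device, dtype=torch.float32)
        _cheb_coeff_cache[key] = torch.exp(-beta * k)
    return _cheb_coeff_cache[key]
```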

Separated Recurrence Loop: The Chebyshev recurrence is extracted into a separate function for potential torch.compile optimization (currently disabled but infrastructure in place).
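A minimal sketch of what the factored-out recurrence might look like (the function name, signature, and the commented-out torch.compile decorator are illustrative):

```python
import torch

# @torch.compile  # infrastructure in place, currently disabled
def chebyshev_recurrence(L_csr, x, coeffs):
    """Sketch of the factored-out Chebyshev recurrence
    T_0 = x, T_1 = L x, T_k = 2 L T_{k-1} - T_{k-2},
    accumulated with per-order coefficients (assumes len(coeffs) >= 2).

    L_csr:  [N, N] sparse CSR (rescaled) Laplacian
    x:      [N, F] dense node features (heads/dims flattened into F)
    coeffs: [K] Chebyshev coefficients
    """
    T_prev, T_curr = x, L_csr @ x
    out = coeffs[0] * T_prev + coeffs[1] * T_curr
    for k in range(2, coeffs.shape[0]):
        T_next = 2.0 * (L_csr @ T_curr) - T_prev
        out = out + coeffs[k] * T_next
        T_prev, T_curr = T_curr, T_next
    return out
```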

This change replaces B small chebyshev_diffusion calls with a single, large batched operation, allowing for massive parallelization on the GPU.

Impact: This optimization provides a ~7x speedup on H100 GPUs and a ~1.4x speedup on L40 GPUs compared to the baseline. The performance difference between GPU architectures is due to the H100's superior sparse tensor cores, better CSR format support, and higher memory bandwidth.

GPU utilization fluctuated between 75-88% before parallelization and now stays pinned at ~99-100% after parallelization.

Note: Please make sure the functionality is still relevant.

Inference Bottlenecks:

| Dataset | Batch Size | Seq Length | Old Time | New Time | Speedup |
| --- | --- | --- | --- | --- | --- |
| Small (overhead) | 16 | 128 | 1.15 ms | 0.34 ms | 3.40x faster ⚡ |
| Typical inference | 64 | 512 | 52.77 ms | 0.67 ms | 78.31x faster 🚀 |
| Large batch | 256 | 1,024 | 436.45 ms | 1.60 ms | 273.61x faster 🚀 |
| Very large | 512 | 2,048 | 2,330.72 ms | 6.54 ms | 356.14x faster 🚀 |
| Production scale | 1,024 | 4,096 | 11,881.96 ms | 37.92 ms | 313.33x faster 🚀 |

Optimized get_gene_embeddings() by replacing nested Python loops with GPU-vectorized grouping. The new version uses torch.unique with return_inverse=True for grouping, index_add_ for segment summation, and torch.bincount for counts, all on the GPU, and then transfers results to the CPU once. This reduces transfers from O(tokens) to O(unique genes) and leverages GPU parallelism. It fixes the performance bottleneck in gene embedding aggregation, achieving a 78-356x speedup on production workloads (batch sizes of 256 and above); production-scale inference (1024×4096) improves from 11.9 seconds to 38 milliseconds (~313x faster), while maintaining numerical equivalence within floating-point precision.
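A minimal sketch of the vectorized grouping pattern described above (the function name, signature, and mean aggregation are assumptions; only the torch.unique / index_add_ / bincount pattern comes from the description):

```python
import torch

def mean_gene_embeddings(gene_ids, token_embeddings):
    """Sketch: average token embeddings per gene entirely on the GPU.

    gene_ids:         [num_tokens] long tensor of gene identifiers
    token_embeddings: [num_tokens, D] float tensor of token embeddings
    Returns (unique_gene_ids, mean_embeddings), one row per unique gene.
    """
    # Group tokens by gene: `inverse` maps each token to its unique-gene index.
    unique_genes, inverse = torch.unique(gene_ids, return_inverse=True)

    # Segment sum of embeddings per gene via index_add_.
    D = token_embeddings.shape[1]
    sums = torch.zeros(unique_genes.shape[0], D,
                       device=token_embeddings.device, dtype=token_embeddings.dtype)
    sums.index_add_(0, inverse, token_embeddings)

    # Token counts per gene via bincount, then a single division.
    counts = torch.bincount(inverse, minlength=unique_genes.shape[0])
    means = sums / counts.unsqueeze(1).to(sums.dtype)

    # One transfer to CPU at the end: O(unique genes), not O(tokens).
    return unique_genes.cpu(), means.cpu()
```

Keeping the aggregation on the GPU until a single final .cpu() call is what collapses the per-token transfer cost described above.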

@sputti-czi sputti-czi requested a review from mingkz October 20, 2025 20:55