Skip to content

vllm.v1.attention.backends.turboquant_attn

TurboQuant attention backend for vLLM.

Prefill: standard scaled dot-product attention on the uncompressed K/V. Separately, K is quantized and K+V are stored into the combined cache slot.

Decode: compute TQ attention scores from the compressed cache, unpack the stored values, then apply softmax and a weighted sum.

Cache layout (no leading 2 dimension): (num_blocks, block_size, num_kv_heads, slot_size) where slot_size = key_packed_size + value_fp16_size

Per-head per-position slot layout

[key_packed (kps bytes) | value_fp16 (D*2 bytes)] For turboquant_k3v4_nc head_dim=256: [100 bytes key | 512 bytes value] = 612

TurboQuantAttentionBackend

Bases: AttentionBackend

Attention backend using TurboQuant KV-cache compression.

Source code in vllm/v1/attention/backends/turboquant_attn.py
class TurboQuantAttentionBackend(AttentionBackend):
    """Attention backend that keeps the KV cache in TurboQuant-compressed form.

    K and V for a given (position, kv-head) pair share one packed byte slot,
    so the cache tensor has no leading K/V dimension (see
    ``get_kv_cache_shape``). Only decoder self-attention is supported.
    """

    accept_output_buffer: bool = True
    forward_includes_kv_cache_update: bool = False

    # Activation dtypes this backend accepts for Q/K/V.
    supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
    # TurboQuant cache presets explicitly advertised by this backend.
    supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
        "turboquant_k8v4",
        "turboquant_4bit_nc",
        "turboquant_k3v4_nc",
        "turboquant_3bit_nc",
    ]

    @staticmethod
    def get_name() -> str:
        backend_name = "TURBOQUANT"
        return backend_name

    @staticmethod
    def get_supported_kernel_block_sizes() -> list[int | MultipleOf]:
        block_sizes: list[int | MultipleOf] = [16, 32, 64, 128]
        return block_sizes

    @classmethod
    def supports_attn_type(cls, attn_type: str) -> bool:
        # Decoder self-attention only.
        return attn_type == AttentionType.DECODER

    @classmethod
    def supports_per_head_quant_scales(cls) -> bool:
        return False

    @staticmethod
    def get_impl_cls() -> type["TurboQuantAttentionImpl"]:
        return TurboQuantAttentionImpl

    @staticmethod
    def get_builder_cls() -> type["TurboQuantMetadataBuilder"]:
        return TurboQuantMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        cache_dtype_str: str = "turboquant_4bit_nc",
    ) -> tuple[int, ...]:
        """Return the combined K+V cache shape (no leading 2 dimension).

        Other backends allocate ``(2, num_blocks, block_size, num_kv_heads,
        head_dim)`` and use the leading axis to separate K from V. TurboQuant
        interleaves K and V inside a single slot per head per position:

            (num_blocks, block_size, num_kv_heads, slot_size_aligned)

        where each slot is ``[key_packed | value_packed | padding]``. This
        layout is private to TQ: layers that fall back to a native dtype via
        ``kv_cache_dtype_skip_layers`` get their own standard-shaped cache.

        ``head_size`` is the model's real head_dim; the aligned slot width is
        taken from the TurboQuant config so allocation is correct for any
        head dim.
        """
        from vllm.model_executor.layers.quantization.turboquant.config import (
            TurboQuantConfig,
        )

        config = TurboQuantConfig.from_cache_dtype(cache_dtype_str, head_size)
        slot_bytes = config.slot_size_aligned
        return (num_blocks, block_size, num_kv_heads, slot_bytes)

    @classmethod
    def supports_kv_cache_dtype(cls, kv_cache_dtype: CacheDType | None) -> bool:
        # Any "turboquant_*" preset is accepted; None (unset) is rejected.
        return kv_cache_dtype is not None and kv_cache_dtype.startswith(
            "turboquant_"
        )

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        # The spec passes an effective head size (padded_slot // 2) rather
        # than the model's head_dim, so any positive value is acceptable.
        return head_size > 0

get_kv_cache_shape staticmethod

get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "turboquant_4bit_nc",
) -> tuple[int, ...]

Combined K+V cache shape — no leading 2 dimension.

Standard attention backends use (2, num_blocks, block_size, num_kv_heads, head_dim) with a leading 2 to separate K and V. TurboQuant packs K+V into a single interleaved slot per head per position, so the cache is:

(num_blocks, block_size, num_kv_heads, slot_size_aligned)

Each slot = [key_packed | value_packed | padding]. This is safe because TQ has its own get_kv_cache_shape override and never shares cache tensors with other backends. Layers that fall back to native dtype via kv_cache_dtype_skip_layers get their own standard-shaped cache allocation.

head_size is the model's real head_dim. slot_size_aligned is computed from the TQ config to ensure correct cache allocation for all head dims.

Source code in vllm/v1/attention/backends/turboquant_attn.py
@staticmethod
def get_kv_cache_shape(
    num_blocks: int,
    block_size: int,
    num_kv_heads: int,
    head_size: int,
    cache_dtype_str: str = "turboquant_4bit_nc",
) -> tuple[int, ...]:
    """Return the combined K+V cache shape (no leading 2 dimension).

    Other backends allocate ``(2, num_blocks, block_size, num_kv_heads,
    head_dim)`` and use the leading axis to separate K from V. TurboQuant
    interleaves K and V inside a single slot per head per position:

        (num_blocks, block_size, num_kv_heads, slot_size_aligned)

    where each slot is ``[key_packed | value_packed | padding]``. This layout
    is private to TQ: layers that fall back to a native dtype via
    ``kv_cache_dtype_skip_layers`` get their own standard-shaped cache.

    ``head_size`` is the model's real head_dim; the aligned slot width is
    taken from the TurboQuant config so allocation is correct for any head
    dim.
    """
    from vllm.model_executor.layers.quantization.turboquant.config import (
        TurboQuantConfig,
    )

    config = TurboQuantConfig.from_cache_dtype(cache_dtype_str, head_size)
    slot_bytes = config.slot_size_aligned
    return (num_blocks, block_size, num_kv_heads, slot_bytes)

TurboQuantAttentionImpl

Bases: AttentionImpl['TurboQuantMetadata']

TurboQuant attention implementation.

Implementation: K/V are quantized and stored into the combined cache by a fused Triton kernel; decode runs a Triton TurboQuant attention kernel directly over the compressed cache.

Source code in vllm/v1/attention/backends/turboquant_attn.py
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
class TurboQuantAttentionImpl(AttentionImpl["TurboQuantMetadata"]):
    """TurboQuant attention implementation.

    K/V are quantized and written into the combined TQ cache by a fused
    Triton kernel (``do_kv_cache_update`` -> ``_store_kv``). Decode runs a
    Triton TurboQuant attention kernel directly over the compressed cache.
    Prefill attends over the current chunk's raw (uncompressed) K/V;
    continuation chunks additionally read earlier K/V from the TQ cache.
    """

    supports_quant_query_input: bool = False

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int | None = None,
        alibi_slopes: list[float] | None = None,
        sliding_window: int | None = None,
        kv_cache_dtype: str = "auto",
        logits_soft_cap: float | None = None,
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: str | None = None,
        **kwargs,
    ):
        # NOTE(review): alibi_slopes, sliding_window, logits_soft_cap and
        # kv_sharing_target_layer_name are accepted for interface
        # compatibility but are not used anywhere in this class.
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = scale
        self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        from vllm.model_executor.layers.quantization.turboquant.config import (
            TurboQuantConfig,
        )

        self.tq_config = TurboQuantConfig.from_cache_dtype(kv_cache_dtype, head_size)

        # Pre-compute kernel constants from config (avoid repeated arithmetic)
        cfg = self.tq_config
        self._mse_bytes = (
            math.ceil(head_size * cfg.key_mse_bits / 8)
            if not cfg.key_fp8
            else head_size
        )
        self._val_data_bytes = math.ceil(head_size * cfg.effective_value_quant_bits / 8)
        self._n_centroids = cfg.n_centroids if not cfg.key_fp8 else 1

        # Fixed NUM_KV_SPLITS (grid dims must be constant for cudagraph,
        # and benchmarks show no regression vs dynamic in eager mode).
        vllm_config = get_current_vllm_config()
        self.max_num_kv_splits = (
            vllm_config.attention_config.tq_max_kv_splits_for_cuda_graph
        )

    def _ensure_on_device(self, layer, device):
        """One-time derivation of TQ buffers (rotation matrices, midpoints).

        Registered buffers (_tq_signs, _tq_centroids) are already on the
        correct device via register_buffer + model.to(device).
        """
        if not hasattr(layer, "_tq_cached"):
            D = layer._tq_signs.shape[0]
            signs = layer._tq_signs.to(device=device, dtype=torch.float32)

            # WHT rotation: orthonormal + self-inverse, enabling future
            # in-kernel butterfly fusion and trivial inverse for continuation.
            H = _build_hadamard(D, str(device))
            layer._tq_PiT = (signs.unsqueeze(1) * H).contiguous()
            layer._tq_Pi = layer._tq_PiT.T.contiguous()

            c = layer._tq_centroids.to(device=device, dtype=torch.float32)
            # Precompute midpoints for threshold-based quantization
            c_sorted, _ = c.sort()
            layer._tq_midpoints = (c_sorted[:-1] + c_sorted[1:]) / 2
            # Decode buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf)
            # are pre-allocated via register_buffer in Attention.__init__
            # and moved to GPU by model.to(device) — no allocation needed
            # here.  The memory profiler sees them before KV cache sizing.
            layer._tq_cached = True

    def do_kv_cache_update(
        self,
        layer: torch.nn.Module,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
    ) -> None:
        """Store compressed K/V into the combined TQ cache.

        Called as a separate custom op (unified_kv_cache_update) BEFORE
        the attention forward, matching FlashAttention's split pattern.
        slot_mapping is already sliced to num_actual_tokens by the caller.
        """
        N = slot_mapping.shape[0]
        if N <= 0:
            return

        device = key.device
        self._ensure_on_device(layer, device)

        k = key[:N].view(N, self.num_kv_heads, self.head_size)
        v = value[:N].view(N, self.num_kv_heads, self.head_size)
        self._store_kv(k, v, kv_cache, slot_mapping, layer)

    def forward(
        self,
        layer: AttentionLayer,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: "TurboQuantMetadata",
        output: torch.Tensor | None = None,
        output_scale: torch.Tensor | None = None,
        output_block_scale: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run attention for one step; the KV cache is already updated.

        Decode requests come first in the batch (reorder_batch_threshold=1),
        so a mixed batch is split at ``num_decode_tokens``: the decode part
        uses the TQ decode kernel, the prefill part uses
        ``_prefill_attention``.

        Returns ``output`` (allocated as zeros if not supplied). Its first
        ``num_actual_tokens`` rows hold the result; it is zero-filled when
        ``attn_metadata`` is None or there are no actual tokens.
        """
        num_tokens = query.shape[0]

        if output is None:
            output = torch.zeros(
                num_tokens,
                self.num_heads * self.head_size,
                dtype=query.dtype,
                device=query.device,
            )

        if attn_metadata is None:
            return output.fill_(0)

        # Slice to actual tokens
        N = attn_metadata.num_actual_tokens
        if N <= 0:
            return output.fill_(0)

        q = query[:N].view(N, self.num_heads, self.head_size)

        # Get TQ buffers, ensure on device (one-time migration).
        # Use Any-typed alias for dynamic _tq_* attrs set by _ensure_on_device.
        tq_layer: Any = layer
        device = q.device
        self._ensure_on_device(tq_layer, device)
        Pi = tq_layer._tq_Pi
        PiT = tq_layer._tq_PiT
        centroids = tq_layer._tq_centroids

        # Compute attention (KV cache was already updated by do_kv_cache_update)
        # With reorder_batch_threshold=1, decodes come first in the batch.
        # num_decodes/num_decode_tokens from metadata give the split point.
        num_decodes = attn_metadata.num_decodes
        num_decode_tokens = attn_metadata.num_decode_tokens

        if not attn_metadata.is_prefill:
            # Pure decode batch — fast path
            attn_out = self._decode_attention(
                q, kv_cache, attn_metadata, Pi, centroids, PiT, layer
            )
        elif num_decodes == 0:
            # Pure prefill batch
            k = key[:N].view(N, self.num_kv_heads, self.head_size)
            v = value[:N].view(N, self.num_kv_heads, self.head_size)
            attn_out = self._prefill_attention(
                q,
                k,
                v,
                kv_cache,
                attn_metadata,
                Pi,
                centroids,
                PiT,
                layer=layer,
            )
        else:
            # Mixed batch: decodes first (guaranteed by reorder_batch).
            attn_out = torch.zeros(
                N, self.num_heads, self.head_size, device=device, dtype=q.dtype
            )

            # --- Decode portion (first num_decodes requests) ---
            # Use full-batch max_seq_len as safe upper bound (no GPU sync).
            decode_meta = TurboQuantMetadata(
                seq_lens=attn_metadata.seq_lens[:num_decodes],
                slot_mapping=attn_metadata.slot_mapping[:num_decode_tokens],
                block_table=attn_metadata.block_table[:num_decodes],
                query_start_loc=attn_metadata.query_start_loc[: num_decodes + 1],
                num_actual_tokens=num_decode_tokens,
                max_query_len=1,
                max_seq_len=attn_metadata.max_seq_len,
                is_prefill=False,
            )
            attn_out[:num_decode_tokens] = self._decode_attention(
                q[:num_decode_tokens], kv_cache, decode_meta, Pi, centroids, PiT, layer
            )

            # --- Prefill portion (remaining requests) ---
            # Use a prefill-specific max_seq_len: decode requests would
            # otherwise inflate it and the flash_attn fast path in
            # _prefill_attention (which also requires max_query_len ==
            # max_seq_len) could never be considered for first-chunk
            # prefills.
            prefill_seq_lens = attn_metadata.seq_lens[num_decodes:]
            # max_seq_len must be a Python int; .tolist() costs one
            # GPU->CPU sync (same as .item() would).
            prefill_max_seq = max(prefill_seq_lens.tolist())
            prefill_qsl = (
                attn_metadata.query_start_loc[num_decodes:] - num_decode_tokens
            )
            prefill_meta = TurboQuantMetadata(
                seq_lens=prefill_seq_lens,
                slot_mapping=attn_metadata.slot_mapping[num_decode_tokens:N],
                block_table=attn_metadata.block_table[num_decodes:],
                query_start_loc=prefill_qsl,
                num_actual_tokens=N - num_decode_tokens,
                max_query_len=attn_metadata.max_query_len,
                max_seq_len=prefill_max_seq,
                is_prefill=True,
            )
            k = key[:N].view(N, self.num_kv_heads, self.head_size)
            v = value[:N].view(N, self.num_kv_heads, self.head_size)
            attn_out[num_decode_tokens:] = self._prefill_attention(
                q[num_decode_tokens:],
                k[num_decode_tokens:],
                v[num_decode_tokens:],
                kv_cache,
                prefill_meta,
                Pi,
                centroids,
                PiT,
                layer=layer,
            )

        # Write into output buffer: attn_out is (N, Hq, D)
        # output may be 2D (N, Hq*D) or 3D (N, Hq, D)
        if output.ndim == 3:
            output[:N] = attn_out.to(output.dtype)
        else:
            output[:N] = attn_out.reshape(N, -1).to(output.dtype)
        return output

    # ------------------------------------------------------------------ #
    #  Store K/V into combined cache (vectorized)                         #
    # ------------------------------------------------------------------ #
    def _store_kv(
        self,
        key: torch.Tensor,  # (N, Hk, D)
        value: torch.Tensor,  # (N, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        slot_mapping: torch.Tensor,
        layer: Any,
    ):
        """Quantize + store via fused Triton kernel."""
        triton_turboquant_store(
            key,
            value,
            kv_cache,
            slot_mapping,
            layer._tq_PiT,
            layer._tq_midpoints,
            mse_bits=self.tq_config.key_mse_bits,
            key_packed_size=self.tq_config.key_packed_size,
            value_quant_bits=self.tq_config.effective_value_quant_bits,
            key_fp8=self.tq_config.key_fp8,
        )

    # ------------------------------------------------------------------ #
    #  Prefill: SDPA on raw Q/K/V with causal mask                        #
    # ------------------------------------------------------------------ #
    def _prefill_attention(
        self,
        query: torch.Tensor,  # (N, Hq, D)
        key: torch.Tensor,  # (N, Hk, D)
        value: torch.Tensor,  # (N, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        attn_metadata: TurboQuantMetadata,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
        PiT: torch.Tensor | None = None,
        layer: Any = None,
    ) -> torch.Tensor:
        """Causal prefill attention for the requests described by metadata.

        First-chunk requests (q_len == seq_len) attend only to the current
        chunk's raw K/V. Continuation requests (q_len < seq_len) also attend
        to earlier tokens read back from the TQ cache.

        Returns a (N, Hq, D) tensor in ``query.dtype``.
        """
        N, Hq, D = query.shape
        query_start_loc = attn_metadata.query_start_loc
        num_reqs = query_start_loc.shape[0] - 1

        # Fast path: one varlen flash_attn call over the whole batch using
        # only the current chunk's K/V. Valid only if EVERY request is a
        # first-chunk prefill (q_len == seq_len). max_query_len ==
        # max_seq_len (Python ints, no GPU sync) is necessary but NOT
        # sufficient: it only proves the longest request has no cached
        # prefix. A shorter continuation chunk / prefix-cache hit
        # (q_len < seq_len) can share the batch and would silently lose
        # its cached context, so the per-request condition is verified.
        if (
            _HAS_FLASH_ATTN
            and attn_metadata.max_query_len == attn_metadata.max_seq_len
            and self._all_requests_first_chunk(attn_metadata, num_reqs)
        ):
            output = torch.empty(N, Hq, D, device=query.device, dtype=query.dtype)
            flash_attn_varlen_func(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=query_start_loc,
                cu_seqlens_k=query_start_loc,
                max_seqlen_q=attn_metadata.max_query_len,
                max_seqlen_k=attn_metadata.max_query_len,
                softmax_scale=self.scale,
                causal=True,
                out=output,
            )
            return output

        # Continuation or no flash_attn: per-request attention.
        # For continuation chunks (seq_len > q_len), we must attend to
        # previously cached K/V from the TQ cache, not just the current
        # chunk's raw K/V.
        Hk = key.shape[1]
        use_gqa = Hk < Hq

        output = torch.zeros(N, Hq, D, device=query.device, dtype=query.dtype)

        # Convert to Python lists once (single CPU-GPU sync) instead of
        # per-request .item() calls that each force a sync.
        qsl = query_start_loc.tolist()
        seq_lens_list = attn_metadata.seq_lens.tolist()

        # Pre-allocate cu_seqlens for single-request flash_attn calls
        # to avoid per-request host→device tensor creation.
        _cu_2 = torch.zeros(2, device=query.device, dtype=torch.int32)

        for i in range(num_reqs):
            q_start = qsl[i]
            q_end = qsl[i + 1]
            q_len = q_end - q_start
            if q_len <= 0:
                continue

            seq_len = seq_lens_list[i]
            q_seq = query[q_start:q_end]  # (q_len, Hq, D)
            k_seq = key[q_start:q_end]  # (q_len, Hk, D)
            v_seq = value[q_start:q_end]  # (q_len, Hk, D)

            if q_len == seq_len:
                # First-chunk prefill: all K/V are in the current batch.
                if _HAS_FLASH_ATTN:
                    out = torch.empty_like(q_seq)
                    _cu_2[1] = q_len
                    cu = _cu_2
                    flash_attn_varlen_func(
                        q=q_seq,
                        k=k_seq,
                        v=v_seq,
                        cu_seqlens_q=cu,
                        cu_seqlens_k=cu,
                        max_seqlen_q=q_len,
                        max_seqlen_k=q_len,
                        softmax_scale=self.scale,
                        causal=True,
                        out=out,
                    )
                else:
                    q_t = q_seq.transpose(0, 1).contiguous()
                    k_t = k_seq.transpose(0, 1).contiguous()
                    v_t = v_seq.transpose(0, 1).contiguous()
                    out = F.scaled_dot_product_attention(
                        q_t,
                        k_t,
                        v_t,
                        is_causal=True,
                        scale=self.scale,
                        enable_gqa=use_gqa,
                    ).transpose(0, 1)
                output[q_start:q_end] = out.to(query.dtype)
            else:
                # Continuation chunk: tokens already stored to TQ cache
                # by do_kv_cache_update. Use decode kernel directly to
                # avoid O(cached_len) full-dequant per continuation.
                # For large continuations, fall back to _continuation_prefill.
                cached_len = seq_len - q_len
                if q_len <= _CONTINUATION_DECODE_THRESHOLD:
                    # Fast path: treat each query as a decode request
                    # with incremental seq_lens for causal masking.
                    synth_seq_lens = torch.arange(
                        cached_len + 1,
                        seq_len + 1,
                        device=query.device,
                        dtype=attn_metadata.seq_lens.dtype,
                    )
                    synth_bt = attn_metadata.block_table[i : i + 1].expand(q_len, -1)
                    out = triton_turboquant_decode_attention(
                        query=q_seq,
                        kv_cache=kv_cache,
                        block_table=synth_bt,
                        seq_lens=synth_seq_lens,
                        Pi=Pi,
                        centroids=centroids,
                        scale=self.scale,
                        mse_bits=self.tq_config.key_mse_bits,
                        key_packed_size=self.tq_config.key_packed_size,
                        value_quant_bits=(self.tq_config.effective_value_quant_bits),
                        key_fp8=self.tq_config.key_fp8,
                        norm_correction=self.tq_config.norm_correction,
                        PiT=PiT,
                    )
                else:
                    # Large continuation: dequant cached K/V and use
                    # flash_attn for better throughput.
                    out = self._continuation_prefill(
                        layer,
                        q_seq,
                        k_seq,
                        v_seq,
                        kv_cache,
                        attn_metadata.block_table[i : i + 1],
                        cached_len,
                        seq_len,
                        Pi,
                        centroids,
                    )
                output[q_start:q_end] = out.to(query.dtype)

        return output

    @staticmethod
    def _all_requests_first_chunk(
        attn_metadata: TurboQuantMetadata, num_reqs: int
    ) -> bool:
        """Return True iff every request has q_len == seq_len (no cached K/V).

        Callers must already have checked max_query_len == max_seq_len. With
        at most one request that check alone implies q_len == seq_len, so no
        device work or sync is needed. Otherwise the per-request query
        lengths (from query_start_loc) are compared with seq_lens, which
        costs one GPU->CPU sync for the resulting bool.
        """
        if num_reqs <= 1:
            return True
        qsl = attn_metadata.query_start_loc
        seq_lens = attn_metadata.seq_lens[:num_reqs]
        q_lens = (qsl[1:] - qsl[:-1]).to(device=seq_lens.device, dtype=seq_lens.dtype)
        return bool(torch.equal(q_lens, seq_lens))

    def _continuation_prefill(
        self,
        layer: Any,
        query: torch.Tensor,  # (q_len, Hq, D)
        key_chunk: torch.Tensor,  # (q_len, Hk, D)
        val_chunk: torch.Tensor,  # (q_len, Hk, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        block_table: torch.Tensor,  # (1, max_num_blocks)
        cached_len: int,
        seq_len: int,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
    ) -> torch.Tensor:
        """Handle a continuation chunk by dequantizing cached K/V.

        Dequantizes the previously cached K/V from the TQ cache,
        concatenates them with the current chunk's raw K/V, then runs
        causal attention (flash_attn when available, otherwise SDPA with an
        explicit causal mask). Returns a (q_len, Hq, D) tensor.
        """
        q_len, Hq, D = query.shape
        Hk = key_chunk.shape[1]
        device = query.device
        block_size = kv_cache.shape[1]
        BLOCK_D = triton.next_power_of_2(D)

        mse_bytes = self._mse_bytes
        val_data_bytes = self._val_data_bytes

        # Dequant cached K/V from TQ cache
        # Allocate slightly over to align to block_size for the grid.
        # Reuse cached buffers to avoid per-call allocation (~16MB at 8K).
        alloc_len = math.ceil(cached_len / block_size) * block_size
        buf_shape = (1, Hk, alloc_len, D)
        k_buf = getattr(layer, "_tq_k_dequant_buf", None)
        if k_buf is None or k_buf.shape[2] < alloc_len:
            k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
            v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
            layer._tq_k_dequant_buf = k_buf
            layer._tq_v_dequant_buf = v_buf
        else:
            v_buf = layer._tq_v_dequant_buf
        # Slicing a larger reused buffer gives non-contiguous views; the
        # kernel receives explicit strides, so that is fine.
        k_cached = k_buf[:, :, :alloc_len, :].zero_()
        v_cached = v_buf[:, :, :alloc_len, :].zero_()

        grid = (alloc_len, 1 * Hk)
        _tq_full_dequant_kv[grid](
            kv_cache,
            block_table,
            centroids,
            k_cached,
            v_cached,
            k_cached.stride(0),
            k_cached.stride(1),
            k_cached.stride(2),
            v_cached.stride(0),
            v_cached.stride(1),
            v_cached.stride(2),
            kv_cache.stride(0),
            kv_cache.stride(1),
            kv_cache.stride(2),
            block_table.stride(0),
            HEAD_DIM=D,
            BLOCK_SIZE=block_size,
            NUM_KV_HEADS=Hk,
            MSE_BYTES=mse_bytes,
            KPS=self.tq_config.key_packed_size,
            VQB=self.tq_config.effective_value_quant_bits,
            VAL_DATA_BYTES=val_data_bytes,
            MSE_BITS=self.tq_config.key_mse_bits,
            KEY_FP8=1 if self.tq_config.key_fp8 else 0,
            BLOCK_D=BLOCK_D,
            NORM_CORRECTION=1 if self.tq_config.norm_correction else 0,
            FP8_E4B15=_use_fp8_e4b15(device.index or 0),
            num_warps=4,
        )

        # Inverse-rotate MSE keys back to original space
        if not self.tq_config.key_fp8:
            k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float()
            k_flat = k_flat @ Pi
            k_cached_trim = (
                k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1)
            )  # (cached_len, Hk, D)
        else:
            k_cached_trim = (
                k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
            )  # (cached_len, Hk, D)

        v_cached_trim = (
            v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
        )  # (cached_len, Hk, D)

        # Concatenate cached + current chunk K/V (match query dtype)
        qdtype = query.dtype
        k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0)
        v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0)

        # Attention: q_len queries attending to seq_len K/V with causal mask
        if _HAS_FLASH_ATTN:
            output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype)
            cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
            cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
            flash_attn_varlen_func(
                q=query,
                k=k_full,
                v=v_full,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=seq_len,
                softmax_scale=self.scale,
                causal=True,
                out=output,
            )
            return output
        else:
            # SDPA fallback: expand KV for GQA, build causal mask
            q_t = query.transpose(0, 1).unsqueeze(0)  # (1, Hq, q_len, D)
            k_t = k_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
            v_t = v_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
            # Build causal mask: query position p can attend to K position j
            # where j <= cached_len + p (p is 0-indexed within chunk)
            q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
            k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
            mask = k_pos <= q_pos  # (q_len, seq_len)
            out = F.scaled_dot_product_attention(
                q_t,
                k_t,
                v_t,
                attn_mask=mask,
                scale=self.scale,
                enable_gqa=(Hk < Hq),
            )  # (1, Hq, q_len, D)
            return out[0].transpose(0, 1)  # (q_len, Hq, D)

    # ------------------------------------------------------------------ #
    #  Decode: Triton TQ decode attention                                 #
    # ------------------------------------------------------------------ #
    def _decode_attention(
        self,
        query: torch.Tensor,  # (B, Hq, D)
        kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
        attn_metadata: TurboQuantMetadata,
        Pi: torch.Tensor,
        centroids: torch.Tensor,
        PiT: torch.Tensor | None = None,
        layer: torch.nn.Module | None = None,
    ) -> torch.Tensor:
        """Decode attention over the compressed TQ cache (Triton kernel)."""
        # Reuse the layer's decode scratch buffers when present (see the
        # note in _ensure_on_device); None lets the kernel wrapper decide.
        mid_o_buf = output_buf = lse_buf = None
        if layer is not None:
            mid_o_buf = getattr(layer, "_tq_mid_o_buf", None)
            output_buf = getattr(layer, "_tq_output_buf", None)
            lse_buf = getattr(layer, "_tq_lse_buf", None)

        result = triton_turboquant_decode_attention(
            query=query,
            kv_cache=kv_cache,
            block_table=attn_metadata.block_table,
            seq_lens=attn_metadata.seq_lens,
            Pi=Pi,
            centroids=centroids,
            scale=self.scale,
            mse_bits=self.tq_config.key_mse_bits,
            key_packed_size=self.tq_config.key_packed_size,
            value_quant_bits=self.tq_config.effective_value_quant_bits,
            key_fp8=self.tq_config.key_fp8,
            norm_correction=self.tq_config.norm_correction,
            PiT=PiT,
            mid_o_buf=mid_o_buf,
            output_buf=output_buf,
            lse_buf=lse_buf,
            buf_holder=layer,
            max_num_kv_splits=self.max_num_kv_splits,
        )
        return result

_continuation_prefill

_continuation_prefill(
    layer: Any,
    query: Tensor,
    key_chunk: Tensor,
    val_chunk: Tensor,
    kv_cache: Tensor,
    block_table: Tensor,
    cached_len: int,
    seq_len: int,
    Pi: Tensor,
    centroids: Tensor,
) -> Tensor

Handle continuation chunk by dequanting cached K/V from TQ cache.

Dequantizes previously cached K/V, concatenates them with the current chunk's raw K/V, then runs causal attention (flash_attn when available, otherwise SDPA with an explicit causal mask).

Source code in vllm/v1/attention/backends/turboquant_attn.py
def _continuation_prefill(
    self,
    layer: Any,
    query: torch.Tensor,  # (q_len, Hq, D)
    key_chunk: torch.Tensor,  # (q_len, Hk, D)
    val_chunk: torch.Tensor,  # (q_len, Hk, D)
    kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
    block_table: torch.Tensor,  # (1, max_num_blocks)
    cached_len: int,
    seq_len: int,
    Pi: torch.Tensor,
    centroids: torch.Tensor,
) -> torch.Tensor:
    """Attend a continuation prefill chunk to its cached prefix plus itself.

    The first ``cached_len`` tokens of this request already live in the
    compressed TurboQuant cache. They are dequantized back to dense K/V by
    the ``_tq_full_dequant_kv`` Triton kernel. Unless the cache stores FP8
    keys, the MSE-quantized keys are also multiplied by ``Pi`` to undo the
    rotation. The result is concatenated in front of the chunk's raw
    ``key_chunk`` / ``val_chunk``, and the chunk's queries are attended
    against the combined sequence with causal masking.

    Args:
        layer: Attention layer. It caches the dequant scratch buffers
            ``_tq_k_dequant_buf`` / ``_tq_v_dequant_buf`` across calls.
        query: Queries of the current chunk.
        key_chunk: Raw (not yet quantized) keys of the current chunk.
        val_chunk: Raw (not yet quantized) values of the current chunk.
        kv_cache: Combined TQ cache holding packed keys and values.
        block_table: Block table row for this single request.
        cached_len: Number of this request's tokens already in the cache.
        seq_len: Key/value length the chunk attends to.
            NOTE(review): this is assumed to equal ``cached_len + q_len``
            (the row count of the concatenated K/V below). It is not
            checked here.
        Pi: Matrix applied to dequantized MSE keys to undo the store-side
            rotation.
        centroids: Key codebook passed to the dequant kernel.

    Returns:
        Attention output of shape (q_len, Hq, D) in ``query.dtype``.
        The flash-attn path returns a freshly allocated contiguous tensor.
        The SDPA fallback returns a transposed (non-contiguous) view.
    """
    q_len, Hq, D = query.shape
    Hk = key_chunk.shape[1]
    device = query.device
    block_size = kv_cache.shape[1]
    # Triton tile width must be a power of two. Presumably the kernel masks
    # lanes >= HEAD_DIM when D is not a power of two.
    BLOCK_D = triton.next_power_of_2(D)

    mse_bytes = self._mse_bytes
    val_data_bytes = self._val_data_bytes

    # Dequant cached K/V from TQ cache
    # Allocate slightly over to align to block_size for the grid.
    # Reuse cached buffers to avoid per-call allocation (~16MB at 8K).
    # The buffers are (1, Hk, alloc_len, D) fp16. They grow when too small
    # and are never shrunk. K and V are always (re)allocated together, so
    # the V buffer exists whenever a large-enough K buffer does.
    alloc_len = math.ceil(cached_len / block_size) * block_size
    buf_shape = (1, Hk, alloc_len, D)
    k_buf = getattr(layer, "_tq_k_dequant_buf", None)
    if k_buf is None or k_buf.shape[2] < alloc_len:
        k_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
        v_buf = torch.empty(buf_shape, dtype=torch.float16, device=device)
        layer._tq_k_dequant_buf = k_buf
        layer._tq_v_dequant_buf = v_buf
    else:
        v_buf = layer._tq_v_dequant_buf
    # Zero only the prefix used by this call. Presumably this keeps padding
    # positions past cached_len (in the last block) deterministic; only
    # [:cached_len] is read below.
    k_cached = k_buf[:, :, :alloc_len, :].zero_()
    v_cached = v_buf[:, :, :alloc_len, :].zero_()

    # Grid is (alloc_len, Hk), presumably one program per (position, KV head).
    # Positional arguments must match the kernel's parameter order exactly.
    grid = (alloc_len, 1 * Hk)
    _tq_full_dequant_kv[grid](
        kv_cache,
        block_table,
        centroids,
        k_cached,
        v_cached,
        k_cached.stride(0),
        k_cached.stride(1),
        k_cached.stride(2),
        v_cached.stride(0),
        v_cached.stride(1),
        v_cached.stride(2),
        kv_cache.stride(0),
        kv_cache.stride(1),
        kv_cache.stride(2),
        block_table.stride(0),
        HEAD_DIM=D,
        BLOCK_SIZE=block_size,
        NUM_KV_HEADS=Hk,
        MSE_BYTES=mse_bytes,
        KPS=self.tq_config.key_packed_size,
        VQB=self.tq_config.effective_value_quant_bits,
        VAL_DATA_BYTES=val_data_bytes,
        MSE_BITS=self.tq_config.key_mse_bits,
        KEY_FP8=1 if self.tq_config.key_fp8 else 0,
        BLOCK_D=BLOCK_D,
        NORM_CORRECTION=1 if self.tq_config.norm_correction else 0,
        # FP8 flavour is chosen per device (device.index None -> 0).
        FP8_E4B15=_use_fp8_e4b15(device.index or 0),
        num_warps=4,
    )

    # Inverse-rotate MSE keys back to original space
    if not self.tq_config.key_fp8:
        # Rotation is done as one fp32 GEMM over all heads/positions, then
        # cast back to fp16.
        # NOTE(review): for bf16 models this rounds fp32 -> fp16 -> bf16
        # (via .to(qdtype) below); casting straight to the query dtype would
        # avoid the double rounding.
        k_flat = k_cached[0, :, :cached_len, :].reshape(-1, D).float()
        k_flat = k_flat @ Pi
        k_cached_trim = (
            k_flat.to(torch.float16).reshape(Hk, cached_len, D).transpose(0, 1)
        )  # (cached_len, Hk, D)
    else:
        # No inverse rotation on the FP8 path. Presumably FP8 keys are
        # stored unrotated — TODO confirm against the store kernel.
        k_cached_trim = (
            k_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
        )  # (cached_len, Hk, D)

    v_cached_trim = (
        v_cached[0, :, :cached_len, :].transpose(0, 1).contiguous()
    )  # (cached_len, Hk, D)

    # Concatenate cached + current chunk K/V (match query dtype)
    qdtype = query.dtype
    k_full = torch.cat([k_cached_trim.to(qdtype), key_chunk], dim=0)
    v_full = torch.cat([v_cached_trim.to(qdtype), val_chunk], dim=0)

    # Attention: q_len queries attending to seq_len K/V with causal mask
    if _HAS_FLASH_ATTN:
        # Single-sequence varlen call. With q_len < seq_len, flash-attn
        # (>= 2.1) aligns the causal mask to the bottom-right. Query i
        # therefore sees keys [0, cached_len + i], which matches the
        # explicit SDPA mask in the fallback below. GQA (Hk < Hq) is
        # handled natively.
        output = torch.empty(q_len, Hq, D, device=device, dtype=query.dtype)
        cu_seqlens_q = torch.tensor([0, q_len], device=device, dtype=torch.int32)
        cu_seqlens_k = torch.tensor([0, seq_len], device=device, dtype=torch.int32)
        flash_attn_varlen_func(
            q=query,
            k=k_full,
            v=v_full,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=q_len,
            max_seqlen_k=seq_len,
            softmax_scale=self.scale,
            causal=True,
            out=output,
        )
        return output
    else:
        # SDPA fallback: expand KV for GQA, build causal mask
        q_t = query.transpose(0, 1).unsqueeze(0)  # (1, Hq, q_len, D)
        k_t = k_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
        v_t = v_full.transpose(0, 1).unsqueeze(0)  # (1, Hk, seq_len, D)
        # Build causal mask: query position p can attend to K position j
        # where j <= cached_len + p (p is 0-indexed within chunk)
        q_pos = torch.arange(q_len, device=device).unsqueeze(1) + cached_len
        k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
        # Boolean mask: True marks positions that may be attended to.
        mask = k_pos <= q_pos  # (q_len, seq_len)
        # enable_gqa (PyTorch >= 2.5) lets SDPA broadcast Hk KV heads to Hq.
        out = F.scaled_dot_product_attention(
            q_t,
            k_t,
            v_t,
            attn_mask=mask,
            scale=self.scale,
            enable_gqa=(Hk < Hq),
        )  # (1, Hq, q_len, D)
        # Returned as a transposed view (not made contiguous).
        return out[0].transpose(0, 1)  # (q_len, Hq, D)

_ensure_on_device

_ensure_on_device(layer, device)

One-time derivation of TQ buffers (rotation matrices, midpoints).

Registered buffers (_tq_signs, _tq_centroids) are already on the correct device via register_buffer + model.to(device).

Source code in vllm/v1/attention/backends/turboquant_attn.py
def _ensure_on_device(self, layer, device):
    """Lazily derive this layer's TurboQuant helper tensors (runs once).

    From the registered ``_tq_signs`` vector this builds the sign-flipped
    Hadamard rotation ``_tq_PiT`` and its transpose ``_tq_Pi``. From the
    ``_tq_centroids`` codebook it builds the quantization thresholds
    ``_tq_midpoints``. It then sets ``_tq_cached`` so every later call
    returns immediately.

    ``_tq_signs`` and ``_tq_centroids`` are registered buffers that
    ``model.to(device)`` already placed on the right device; here they are
    only cast to float32.
    """
    if hasattr(layer, "_tq_cached"):
        return

    head_dim = layer._tq_signs.shape[0]
    sign_vec = layer._tq_signs.to(device=device, dtype=torch.float32)

    # Randomized WHT: diag(signs) @ H is orthonormal, so its inverse is just
    # its transpose (used to undo the rotation for continuation prefill).
    hadamard = _build_hadamard(head_dim, str(device))
    rotation_t = (sign_vec.unsqueeze(1) * hadamard).contiguous()
    layer._tq_PiT = rotation_t
    layer._tq_Pi = layer._tq_PiT.T.contiguous()

    # Thresholds halfway between neighbouring sorted centroids, consumed by
    # the threshold-based quantization in the store kernel.
    codebook = layer._tq_centroids.to(device=device, dtype=torch.float32)
    ordered = codebook.sort().values
    layer._tq_midpoints = (ordered[:-1] + ordered[1:]) / 2

    # The decode scratch buffers (_tq_mid_o_buf, _tq_output_buf, _tq_lse_buf)
    # are registered in Attention.__init__ and moved by model.to(device), so
    # the memory profiler accounts for them before KV-cache sizing; nothing
    # to allocate here.
    layer._tq_cached = True

_store_kv

_store_kv(
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    slot_mapping: Tensor,
    layer: Any,
)

Quantize + store via fused Triton kernel.

Source code in vllm/v1/attention/backends/turboquant_attn.py
def _store_kv(
    self,
    key: torch.Tensor,  # (N, Hk, D)
    value: torch.Tensor,  # (N, Hk, D)
    kv_cache: torch.Tensor,  # (num_blocks, block_size, Hk, slot_size)
    slot_mapping: torch.Tensor,
    layer: Any,
):
    """Compress ``key``/``value`` and scatter them into ``kv_cache``.

    Thin wrapper around the fused Triton store kernel. It forwards the
    layer's rotation (``_tq_PiT``) and quantization thresholds
    (``_tq_midpoints``) together with the bit-width settings from
    ``self.tq_config``. ``slot_mapping`` selects the destination cache slot
    of each token.
    """
    rotation_t = layer._tq_PiT
    thresholds = layer._tq_midpoints
    cfg = self.tq_config
    quant_options = dict(
        mse_bits=cfg.key_mse_bits,
        key_packed_size=cfg.key_packed_size,
        value_quant_bits=cfg.effective_value_quant_bits,
        key_fp8=cfg.key_fp8,
    )
    triton_turboquant_store(
        key, value, kv_cache, slot_mapping, rotation_t, thresholds, **quant_options
    )

do_kv_cache_update

do_kv_cache_update(
    layer: Module,
    key: Tensor,
    value: Tensor,
    kv_cache: Tensor,
    slot_mapping: Tensor,
) -> None

Store compressed K/V into the combined TQ cache.

Called as a separate custom op (unified_kv_cache_update) BEFORE the attention forward, matching FlashAttention's split pattern. slot_mapping is already sliced to num_actual_tokens by the caller.

Source code in vllm/v1/attention/backends/turboquant_attn.py
def do_kv_cache_update(
    self,
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Quantize this step's K/V and write them into the combined TQ cache.

    This runs as its own custom op (``unified_kv_cache_update``) ahead of
    the attention forward, mirroring FlashAttention's split design. The
    caller has already trimmed ``slot_mapping`` to ``num_actual_tokens``.
    Its length decides how many leading rows of ``key``/``value`` are
    stored; any extra (padding) rows are ignored. An empty ``slot_mapping``
    makes this a no-op.
    """
    num_tokens = slot_mapping.shape[0]
    if num_tokens == 0:
        return

    self._ensure_on_device(layer, key.device)

    # Expose the per-head layout expected by the store kernel.
    per_head_shape = (num_tokens, self.num_kv_heads, self.head_size)
    self._store_kv(
        key[:num_tokens].view(per_head_shape),
        value[:num_tokens].view(per_head_shape),
        kv_cache,
        slot_mapping,
        layer,
    )

TurboQuantMetadata dataclass

Bases: AttentionMetadata

Metadata for TurboQuant attention.

Source code in vllm/v1/attention/backends/turboquant_attn.py
@dataclass
class TurboQuantMetadata(AttentionMetadata):
    """Per-batch metadata consumed by the TurboQuant attention backend.

    Produced by ``TurboQuantMetadataBuilder.build``. Requests are ordered
    decodes-first, so the first ``num_decodes`` requests (and their
    ``num_decode_tokens`` tokens) form the decode portion of the batch and
    the rest are prefill / continuation requests.
    """

    seq_lens: torch.Tensor  # (num_reqs,) — total context length per request
    slot_mapping: torch.Tensor  # (num_tokens,) — cache slot for each token
    block_table: torch.Tensor  # (num_reqs, max_num_blocks)
    query_start_loc: torch.Tensor  # (num_reqs + 1,) — cu_seqlens for queries
    num_actual_tokens: int = 0  # actual tokens (excluding padding)
    max_query_len: int = 0  # longest query in batch
    max_seq_len: int = 0  # longest context in batch
    # True when any request has more than one query token
    # (the builder sets it as max_query_len > 1).
    is_prefill: bool = False
    num_decodes: int = 0  # number of decode requests (first in batch)
    num_decode_tokens: int = 0  # tokens from decode requests

TurboQuantMetadataBuilder

Bases: AttentionMetadataBuilder[TurboQuantMetadata]

Builds TurboQuantMetadata from scheduler output.

Source code in vllm/v1/attention/backends/turboquant_attn.py
class TurboQuantMetadataBuilder(AttentionMetadataBuilder[TurboQuantMetadata]):
    """Translate the runner's common attention metadata into TurboQuantMetadata.

    Batches are reordered with a decode threshold of 1, so decode requests
    come first. Each build records where that decode prefix ends.
    """

    _cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH

    def __init__(self, kv_cache_spec, layer_names, vllm_config, device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device)
        # Request decodes-first reordering with threshold 1 and without
        # treating speculative tokens as decodes.
        self._init_reorder_batch_threshold(1, supports_spec_as_decode=False)

    def build_for_cudagraph_capture(
        self, common_attn_metadata: CommonAttentionMetadata
    ) -> TurboQuantMetadata:
        """Build capture-time metadata with every sequence length set to 1.

        Short sequences keep graph capture fast; the real lengths are written
        at replay time. ``fill_`` writes in place, and ``seq_lens`` aliases
        ``common_attn_metadata.seq_lens``, so that tensor is overwritten too.
        """
        capture_meta = self.build(0, common_attn_metadata)
        capture_meta.seq_lens.fill_(1)
        return capture_meta

    def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
        """Assemble TurboQuantMetadata for one scheduled batch.

        ``common_prefix_len`` and ``fast_build`` are accepted for interface
        compatibility and ignored. Tensor fields are passed through by
        reference from ``common_attn_metadata`` (no copies).
        """
        common = common_attn_metadata
        decode_threshold = self.reorder_batch_threshold
        # Set in __init__; the runner relies on it to put decodes first.
        assert decode_threshold is not None
        # Find the decode/prefill boundary (on CPU-side metadata, so no GPU
        # sync). Prefill counts are not needed.
        num_decodes, _, num_decode_tokens, _ = split_decodes_and_prefills(
            common, decode_threshold=decode_threshold
        )

        longest_query = common.max_query_len
        return TurboQuantMetadata(
            seq_lens=common.seq_lens,
            slot_mapping=common.slot_mapping,
            block_table=common.block_table_tensor,
            query_start_loc=common.query_start_loc,
            num_actual_tokens=common.num_actual_tokens,
            max_query_len=longest_query,
            max_seq_len=common.max_seq_len,
            # Any request with more than one query token implies prefill work.
            is_prefill=longest_query > 1,
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
        )

build

build(
    common_prefix_len,
    common_attn_metadata,
    fast_build=False,
)

Build TurboQuantMetadata from common attention metadata.

Source code in vllm/v1/attention/backends/turboquant_attn.py
def build(self, common_prefix_len, common_attn_metadata, fast_build=False):
    """Assemble TurboQuantMetadata for one scheduled batch.

    ``common_prefix_len`` and ``fast_build`` are accepted for interface
    compatibility and ignored. Tensor fields are passed through by
    reference from ``common_attn_metadata`` (no copies).
    """
    common = common_attn_metadata
    decode_threshold = self.reorder_batch_threshold
    # Set in __init__; the runner relies on it to put decodes first.
    assert decode_threshold is not None
    # Find the decode/prefill boundary (on CPU-side metadata, so no GPU
    # sync). Prefill counts are not needed.
    num_decodes, _, num_decode_tokens, _ = split_decodes_and_prefills(
        common, decode_threshold=decode_threshold
    )

    longest_query = common.max_query_len
    return TurboQuantMetadata(
        seq_lens=common.seq_lens,
        slot_mapping=common.slot_mapping,
        block_table=common.block_table_tensor,
        query_start_loc=common.query_start_loc,
        num_actual_tokens=common.num_actual_tokens,
        max_query_len=longest_query,
        max_seq_len=common.max_seq_len,
        # Any request with more than one query token implies prefill work.
        is_prefill=longest_query > 1,
        num_decodes=num_decodes,
        num_decode_tokens=num_decode_tokens,
    )

_build_hadamard

_build_hadamard(d: int, device_str: str) -> Tensor

Orthonormal Hadamard matrix (Sylvester construction), cached per (d, device).

Precomputed D×D matrix enables matmul-based WHT — single cuBLAS GEMM instead of log2(D) butterfly kernel launches. 64KB for D=128.

Source code in vllm/v1/attention/backends/turboquant_attn.py
def _canonical_device_str(device_str: str) -> str:
    """Return a canonical device string suitable as a cache key.

    ``str(torch.device("cuda"))`` is ``"cuda"`` (the index is None), so
    round-tripping through ``torch.device`` alone does NOT make ``"cuda"``
    and ``"cuda:0"`` equal. An index-less CUDA device is therefore pinned to
    the current CUDA device when CUDA is available. Any other device string
    is returned as ``str(torch.device(device_str))``.
    """
    dev = torch.device(device_str)
    if dev.type == "cuda" and dev.index is None and torch.cuda.is_available():
        # Resolve bare "cuda" to the concrete device it denotes right now so
        # it shares a cache entry with the explicit "cuda:N" spelling (and the
        # cached tensor lives on the device the caller actually means).
        dev = torch.device("cuda", torch.cuda.current_device())
    return str(dev)


def _build_hadamard(d: int, device_str: str) -> torch.Tensor:
    """Orthonormal Hadamard matrix (Sylvester construction), cached per (d, device).

    Precomputed D×D matrix enables matmul-based WHT — single cuBLAS GEMM
    instead of log2(D) butterfly kernel launches (64 KB for D=128 at 4 bytes
    per element).

    Args:
        d: Matrix size (the head dim). Presumably a power of two, as the
            Sylvester construction requires.
        device_str: Target device, e.g. ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.

    Returns:
        The cached ``d x d`` Hadamard matrix on the requested device.
    """
    # Normalize the device so equivalent spellings ("cuda" vs "cuda:0")
    # hit the same cache entry.
    return _build_hadamard_cached(d, _canonical_device_str(device_str))