Configuration for TurboQuant KV-cache quantization.
Uses PolarQuant (WHT rotation + Lloyd-Max scalar quantization) for keys and uniform quantization for values. QJL is intentionally omitted — community consensus (5+ independent groups) found it hurts attention quality by amplifying variance through softmax.
Named presets (use via --kv-cache-dtype):

- turboquant_k8v4: FP8 keys + 4-bit values, 2.6x compression, +1.17% PPL
- turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
- turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
- turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| head_dim | int | Attention head dimension (e.g. 64, 96, 128). | 128 |
| key_quant_bits | int | Bits for key quantization. 8 = FP8 keys (no rotation/MSE). 3-4 = Lloyd-Max MSE quantized keys. | 3 |
| value_quant_bits | int | Bits per value dimension for uniform quantization. 3 = 8 levels, 4 = 16 levels (default). | 4 |
| seed | int | Base seed for deterministic random matrix generation. Actual seed per layer = seed + layer_idx * 1337. | 42 |
| norm_correction | bool | Re-normalize centroid vectors to unit norm before inverse rotation during dequant. Fixes quantization-induced norm distortion, improving PPL by ~0.8% at 4-bit. | False |
Source code in vllm/model_executor/layers/quantization/turboquant/config.py
@dataclass
class TurboQuantConfig:
    """Configuration for TurboQuant KV-cache quantization.

    Keys go through PolarQuant (WHT rotation followed by Lloyd-Max scalar
    quantization); values use plain uniform quantization. QJL is deliberately
    omitted — community consensus (5+ independent groups) found it hurts
    attention quality by amplifying variance through softmax.

    Named presets (selected via --kv-cache-dtype):
        turboquant_k8v4:    FP8 keys + 4-bit values,            2.6x, +1.17% PPL
        turboquant_4bit_nc: 4-bit MSE keys + 4-bit values + NC, 3.8x, +2.71%
        turboquant_k3v4_nc: 3-bit MSE keys + 4-bit values + NC, ~3.5x, +10.63%
        turboquant_3bit_nc: 3-bit MSE keys + 3-bit values + NC, 4.9x, +20.59%

    Args:
        head_dim: Attention head dimension (e.g. 64, 96, 128).
        key_quant_bits: Bits for key quantization. 8 selects FP8 keys
            (no rotation/MSE); 3-4 selects Lloyd-Max MSE quantized keys.
        value_quant_bits: Bits per value dimension for uniform quantization.
            3 = 8 levels, 4 = 16 levels (default).
        seed: Base seed for deterministic random matrix generation; the
            per-layer seed is seed + layer_idx * 1337.
        norm_correction: Re-normalize centroid vectors to unit norm before
            the inverse rotation during dequant. Fixes quantization-induced
            norm distortion, improving PPL by ~0.8% at 4-bit.
    """

    head_dim: int = 128
    # 3-4 selects Lloyd-Max MSE keys; 8 selects FP8 keys.
    key_quant_bits: int = 3
    # 3-4 bit uniform quantization for values.
    value_quant_bits: int = 4
    seed: int = 42
    norm_correction: bool = False

    @property
    def key_fp8(self) -> bool:
        """True when keys are stored as FP8 — no rotation/quantization needed."""
        return self.key_quant_bits == 8

    @property
    def mse_bits(self) -> int:
        """MSE quantizer bit-width (centroid count is 2**mse_bits).

        Equals key_quant_bits in MSE key modes. In FP8 key mode it falls
        back to value_quant_bits, since centroids are still needed for
        continuation-prefill dequant and decode kernel params.
        """
        return self.value_quant_bits if self.key_fp8 else self.key_quant_bits

    @property
    def key_mse_bits(self) -> int:
        """MSE bits actually spent on key quantization (0 for FP8 keys)."""
        return 0 if self.key_fp8 else self.key_quant_bits

    @property
    def centroid_bits(self) -> int:
        """Bit-width used for centroid generation — never zero."""
        return self.mse_bits

    @property
    def n_centroids(self) -> int:
        """Number of Lloyd-Max centroids (2**mse_bits)."""
        return 1 << self.mse_bits

    @property
    def key_packed_size(self) -> int:
        """Packed byte count for a single KEY vector.

        FP8 mode (key_quant_bits=8): head_dim bytes, one per element,
        no overhead. TQ mode: ceil(head_dim * key_mse_bits / 8) bytes of
        MSE indices plus 2 bytes for the fp16 vec_norm.
        """
        if self.key_fp8:
            return self.head_dim
        # Integer ceiling division: -(-a // b) == ceil(a / b) for a, b > 0.
        index_bytes = -(-(self.head_dim * self.key_mse_bits) // 8)
        return index_bytes + 2  # +2: fp16 vec_norm

    @property
    def effective_value_quant_bits(self) -> int:
        """Bits actually used for value storage."""
        return self.value_quant_bits

    @property
    def value_packed_size(self) -> int:
        """Packed byte count for a single VALUE vector.

        Uniform quantization: ceil(head_dim * bits / 8) data bytes plus
        4 bytes of metadata (fp16 scale + fp16 zero point).
        """
        data_bytes = -(-(self.head_dim * self.value_quant_bits) // 8)
        return data_bytes + 4  # +2 scale(fp16) +2 zero(fp16)

    @property
    def slot_size(self) -> int:
        """Total packed bytes per head per position.

        Layout: [key_packed | value_packed].
        """
        return self.key_packed_size + self.value_packed_size

    @property
    def slot_size_aligned(self) -> int:
        """slot_size rounded up to the next even number.

        Evenness guarantees effective_head_size = slot_size_aligned // 2
        is integral.
        """
        return (self.slot_size + 1) // 2 * 2

    @staticmethod
    def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
        """Layer indices that should skip TQ compression (boundary protection).

        Returns the first N and last N layer indices as strings, suitable
        for kv_cache_dtype_skip_layers.
        """
        if num_layers <= 0 or n <= 0:
            return []
        # Never protect more than half of the stack.
        width = min(n, num_layers // 2)
        protected = set(range(width))
        protected.update(range(num_layers - width, num_layers))
        return [str(idx) for idx in sorted(protected)]

    @staticmethod
    def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
        """Build a config from a named preset.

        Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.

        Raises:
            ValueError: if cache_dtype is not a known preset name.
        """
        if cache_dtype not in TQ_PRESETS:
            raise ValueError(
                f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
                f"Valid presets: {', '.join(TQ_PRESETS.keys())}"
            )
        chosen = TQ_PRESETS[cache_dtype]
        return TurboQuantConfig(
            head_dim=head_dim,
            key_quant_bits=chosen["key_quant_bits"],
            value_quant_bits=chosen["value_quant_bits"],
            norm_correction=chosen["norm_correction"],
        )
|
centroid_bits property
Bits for centroid generation — always non-zero.
effective_value_quant_bits property
effective_value_quant_bits: int
Actual bits used for value storage.
key_fp8 property
Whether keys are stored as FP8 — no rotation/quantization needed.
key_mse_bits property
MSE bits actually used for key quantization (0 if FP8 keys).
key_packed_size property
Packed bytes for a single KEY vector.
FP8 mode (key_quant_bits=8): head_dim bytes (1 byte per element, no overhead).
TQ mode:
- MSE indices: ceil(head_dim * key_mse_bits / 8) bytes
- vec_norm: 2 bytes (float16)
mse_bits property
MSE quantizer bit-width (determines centroid count: 2^mse_bits).
For MSE key modes, equals key_quant_bits. For FP8 key mode, falls back to value_quant_bits (centroids are still needed for continuation-prefill dequant and decode kernel params).
slot_size property
Total packed bytes per head per position (key + value combined).
Layout: [key_packed | value_packed]
slot_size_aligned property
Slot size rounded up to next even number.
Even-number is required so effective_head_size = slot_size_aligned // 2 is integral.
value_packed_size property
Packed bytes for a single VALUE vector.
Uniform quantization: ceil(head_dim * bits / 8) + 4 bytes (scale + zero fp16).
from_cache_dtype staticmethod
Create config from a named preset.
Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.
Source code in vllm/model_executor/layers/quantization/turboquant/config.py
@staticmethod
def from_cache_dtype(cache_dtype: str, head_dim: int) -> "TurboQuantConfig":
    """Build a TurboQuantConfig from a named preset.

    Valid presets: turboquant_k8v4, turboquant_4bit_nc, etc.

    Raises:
        ValueError: if cache_dtype is not a known preset name.
    """
    if cache_dtype not in TQ_PRESETS:
        raise ValueError(
            f"Unknown TurboQuant cache dtype: {cache_dtype!r}. "
            f"Valid presets: {', '.join(TQ_PRESETS.keys())}"
        )
    chosen = TQ_PRESETS[cache_dtype]
    return TurboQuantConfig(
        head_dim=head_dim,
        key_quant_bits=chosen["key_quant_bits"],
        value_quant_bits=chosen["value_quant_bits"],
        norm_correction=chosen["norm_correction"],
    )
get_boundary_skip_layers staticmethod
get_boundary_skip_layers(
num_layers: int, n: int = 2
) -> list[str]
Get layer indices to skip TQ compression (boundary protection).
Returns first N and last N layer indices as strings, suitable for kv_cache_dtype_skip_layers.
Source code in vllm/model_executor/layers/quantization/turboquant/config.py
| @staticmethod
def get_boundary_skip_layers(num_layers: int, n: int = 2) -> list[str]:
"""Get layer indices to skip TQ compression (boundary protection).
Returns first N and last N layer indices as strings, suitable for
kv_cache_dtype_skip_layers.
"""
if n <= 0 or num_layers <= 0:
return []
n = min(n, num_layers // 2) # don't skip more than half
first = list(range(n))
last = list(range(num_layers - n, num_layers))
# Deduplicate (if num_layers <= 2*n)
indices = sorted(set(first + last))
return [str(i) for i in indices]
|