Flash Attention Transformer LLM GPU Inference Bellek Optimizasyonu

Flash Attention Nedir? Transformer Bellek Optimizasyonu

Orta
person Yapay Zeka Uzmanı

Flash Attention Nedir? Transformer Bellek Optimizasyonu kapak görseli

Büyük bir dil modeline yüzlerce satırlık kod ya da uzun bir belge yapıştırdınızda yanıt gelmeden önce kısa bir bekleme yaşanır. Bu beklemenin önemli bir kısmı transformer mimarisinin içindeki dikkat hesabından kaynaklanır.

Attention mekanizması her yeni token üretmek için girişteki tüm önceki tokenlara bakar. Bu ilişki hesabı GPU’da yapılırken belirleyici bir darboğaz ortaya çıkar: hesaplama birimi olan SRAM küçüktür, depolama katmanı olan HBM büyük ama yavaştır. Klasik implementasyonlar büyük matris sonuçlarını sürekli HBM’e yazıp okuyarak bu yavaş köprüde bant genişliği tüketir.

Flash Attention bu problemi 2022’de Tri Dao ve ekibinin yayımladığı çalışmayla adresler. Algoritma dikkat hesabını GPU bellek hiyerarşisine göre yeniden düzenleyerek hem eğitim hem çıkarım hızını gözle görülür biçimde artırır. GPT-4, Claude 3, Llama 3 ve Gemma 2 dahil neredeyse tüm modern LLM’ler bu algoritmayı üretimde çalıştırıyor.

Standart Attention’ın Sorunu: GPU Bellek Hiyerarşisi

Transformer modellerinin attention katmanı üç matris üretir: Query (Q), Key (K) ve Value (V). Formül tanıdıktır:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

Ama bu hesabın GPU’da nasıl gerçekleştiğine bakılınca ciddi bir darboğaz görülür.

Modern bir NVIDIA GPU’sunda iki temel bellek bölgesi vardır. HBM (High Bandwidth Memory) büyüktür; A100’de 80 GB kapasiteye ulaşır. Hesaplama birimine bitişik olan SRAM ise yaklaşık 20 MB ile kısıtlıdır ama birkaç kat daha hızlı çalışır.

Klasik attention implementasyonu şöyle ilerler: tüm QKᵀ çarpım matrisini önce HBM’e yazar, softmax için geri okur, ardından V ile çarpmak için tekrar HBM’e gönderir. Bu süreçte N × N boyutlu matrisler defalarca yavaş HBM köprüsünden geçer. Dizi uzadıkça bu trafik karesiyle büyür; 1000 token için 1 milyon eleman taşınırken 8000 token için 64 milyon eleman taşınır.

Bu yüzden standart attention bellek karmaşıklığı açısından O(N²) davranır. 1 milyon token bağlam penceresiyle çalışmak klasik implementasyonla fiilen imkânsızlaşır; bellek gereksinimi milyarlarca elemanı içeren matrislerle patlar.

Flash Attention’ın Çözümü: Tiling ve Recomputation

Flash Attention bu darboğazı iki temel teknikle çözer.

Tiling (Blok Bölümleme)

Büyük N × N dikkat matrisini tek seferde hesaplamak yerine Flash Attention girişi küçük bloklara, yani tile’lara böler. Her blok SRAM’e sığacak büyüklükte tutulur. Blok SRAM içinde hesaplanır, geçici sonuç birleştirilir, ardından HBM’e yazılır. Sonraki blok için SRAM temizlenip yeniden kullanılır.

Bu yaklaşımın getirdiği kritik kazanım şudur: dikkat matrisinin büyük ara versiyonları HBM’e hiç yazılmaz. Hesaplama doğrudan hızlı SRAM içinde tamamlanır.

Online Softmax

Normal softmax hesaplamak için satırın tamamını görmek gerekir; oysa tiling’de matrisin yalnızca bir bölümü görülür. Flash Attention, bloğu tamamlamadan önce doğru softmax değerini üretebilmek için online softmax adı verilen numerik tekniği kullanır. Bu teknik log-sum-exp numarasıyla işler: her yeni blok işlendiğinde önceki bloğun normalizasyon faktörü güncellenir ve nihai sonuç tutarlı kalır.

Recomputation (Yeniden Hesaplama)

Backward pass sırasında, yani eğitimde, ara aktivasyonların saklanması gerekir. Klasik implementasyon bu değerleri HBM’de tutar. Flash Attention onları depolamak yerine gerektiğinde SRAM’de yeniden hesaplar. Bu kulağa verimsiz gelir; ama HBM’de depolamanın ve okumanın maliyeti SRAM’de yeniden hesaplamanın maliyetini geçtiğinde bellek tasarrufu baskın çıkar.

Net etki: bellek karmaşıklığı O(N²) düzeyinden O(N) düzeyine iner. HBM I/O trafiği ciddi biçimde azalır. Hesaplama süresi artsa da toplam duvar saati süresi düşer çünkü asıl darboğaz HBM bant genişliğiydi.

FlashAttention-1, 2 ve 3: Sürüm Farkları

Flash Attention tek bir çalışmayla bitmedi; her sürüm belirgin bir kazanım getirdi.

FlashAttention v1 (2022)

Tri Dao, Daniel Y. Fu ve arkadaşlarının NeurIPS 2022’de yayımladığı orijinal makale. A100 GPU’sunda standart attention’a göre 3x hız artışı ve 10x bellek tasarrufu gösterdi. Temel inovasyon tiling ile recomputation kombinasyonuydu.

FlashAttention v2 (2023)

İlk sürüm warp düzeyinde hesaplamanın önemli bir bölümünü boşa harcıyordu. v2 paralel hesap düzenini yeniden organize ederek thread bloklarındaki koordinasyonu düzeltti. A100’de v1’e kıyasla yaklaşık 2x ek hız kazandı. Multi-Head Attention’ın yanı sıra Grouped Query Attention (GQA) desteği de eklendi. Bu sürüm Llama 2’den itibaren neredeyse tüm açık modellerin varsayılan attention implementasyonu haline geldi.

FlashAttention v3 (2024-2025)

H100 GPU’sunun yeni donanım özelliklerine göre baştan yazıldı. Asenkron yürütme, producer-consumer pipeline ve FP8 hassasiyet desteği getirdi. H100 Tensor Core’larını tam verimde kullanan bu sürüm, A100’de v2’ye kıyasla yaklaşık 2x ek kazanım sunar. FP8 pipeline ile H100’de teorik tepe throughput’un yüzde yetmişbeşine ulaşmak mümkün.

Hangi Modeller Flash Attention Kullanıyor?

2022’den bu yana Flash Attention fiilen sektör standardı haline geldi.

GPT-4 serisi, Claude 3 ve sonrası (Haiku, Sonnet, Opus), Meta’nın Llama 2, Llama 3 ve Llama 3.1 modelleri, Google’ın Gemma 2, Mistral ve Mixtral, TII’ın Falcon 2 ile Qwen2 serisi Flash Attention veya türevi bir implementasyon çalıştırıyor.

PyTorch 2.0’dan itibaren torch.nn.functional.scaled_dot_product_attention fonksiyonu Flash Attention’ı otomatik devreye alıyor. CUDA ortamında float16 ya da bfloat16 tensor kullanıldığında ayrı bir paket kurmaya gerek kalmıyor.

import torch
import torch.nn.functional as F

# PyTorch 2.0+ — Flash Attention, CUDA + float16 varsa otomatik devreye girer
q = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.float16)

with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False):
    out = F.scaled_dot_product_attention(q, k, v)

Bağımsız flash_attn paketi ise daha düşük seviyeli kontrol, GQA veya sliding window attention gibi özellikler için tercih edilir. Paket doğrudan pip install flash-attn ile kurulur; derleme süresi uzundur.

Flash Attention, KV Cache ve Uzun Bağlam

Flash Attention ile KV Cache farklı sorunları ele alır ama birbirini tamamlar.

KV Cache, decode aşamasında token’ların Key ve Value vektörlerini GPU belleğinde saklayarak yeniden hesaplamayı önler. Flash Attention ise dikkat matrisinin kendisini hesaplarken HBM trafiğini azaltır. Biri matris hesabını optimize ederken diğeri vektör depolamayı optimize eder; çatışmak yerine yan yana çalışırlar.

Long Context LLM modellerinin var olmasında Flash Attention’ın rolü belirleyicidir. Standart attention 1 milyon token için petabayt düzeyinde HBM trafiği üretir; Flash Attention bu trafiği O(N) düzeyine çekerek uzun bağlamı çalıştırılabilir kılar. 100k+ token pencereli Gemini Ultra ve Claude 3 bu altyapı olmadan var olamazdı.

Transformer mimarisini derinlemesine anlıyorsanız Flash Attention’ın tam olarak nerede devreye girdiğini görmek kolaylaşır: her attention katmanının forward ve backward geçişinde, her token-token ilişki hesabında etkin.

Sınırlılıklar ve Alternatifler

Flash Attention evrensel bir çözüm değil; belirli kısıtları var.

CUDA bağımlılığı: Orijinal implementasyon NVIDIA GPU’larına özgü CUDA kernel’ları üzerine kurulu. Apple Silicon (Metal) ve AMD (ROCm) desteği sonradan eklendi ya da topluluk tarafından sürdürülüyor; kararlılık ve performans ana implementasyona henüz yetişmedi.

Mask karmaşıklığı: Causal masking autoregressive modeller için doğal biçimde desteklenir. Ancak belirli cross-attention konfigürasyonları ya da özel sparse attention mask’ler tile hesabını karmaşıklaştırır.

Alternatifler: Meta’nın xFormers kütüphanesi Flash Attention’a yakın performans sunar ve bazı özel mimari konfigürasyonlarında tercih edilir. PyTorch SDPA standart deployment için genellikle yeterlidir. Triton tabanlı custom kernel’lar araştırma ortamlarında esneklik gerektiğinde kullanılır.

vLLM ve PagedAttention bu ekosisteme entegre çalışır. Flash Attention dikkat matrisini optimize ederken PagedAttention KV cache bellek bloklarını yönetir; ikisi birlikte kullanılır.

Speculative Decoding ile birlikte kullanıldığında taslak ve hedef modelin ikisi de Flash Attention çalıştırır. Ayrı cache yönetimi ve senkronizasyon gerektirse de iki teknik birbirini dışlamaz.

Karşılaştırma Tablosu

ÖzellikStandart AttentionFlash Attention v2Flash Attention v3
Bellek karmaşıklığıO(N²)O(N)O(N)
Hız (A100’e göre)1x~3-4x~5-6x (H100)
HBM I/OYüksekDüşükÇok düşük
FP8 desteğiHayırHayırEvet
Uzun bağlamKısıtlıİyiÇok iyi
GQA desteğiHayırEvetEvet

Eğitim Maliyetine ve Throughput’a Etkisi

Flash Attention’ın eğitim döngüsüne katkısı yalnızca hız değil, doğrudan mali tasarruf anlamına gelir. GPT-3 ölçeğinde bir model eğitirken attention hesabı toplam GPU süresinin yüzde otuz ila kırkını tüketir. FlashAttention-2 ile bu oranı yarıya indirmek, milyonlarca dolarlık veri merkezi faturasında doğrudan bir düşüş demektir.

Pratik karşılaştırmalar bu tabloyu netleştirir. A100 80 GB GPU’da 2048 token uzunluğunda standart PyTorch dikkat hesabı yaklaşık 18 ms sürerken Flash Attention v2 aynı işi 5-6 ms’de bitirir. Bağlam uzadıkça fark katlanarak büyür: 8192 token’da standart attention OOM hatası verirken Flash Attention v2 aynı donanımda sorunsuz çalışır.

Eğitim throughput’una yansıması da ölçülebilir. Llama 2’nin eğitim sürecini belgeleyen Meta raporları, Flash Attention’ın etkin olduğu konfigürasyonlarda token/saniye değerinin yüzde altmış ila seksen artığını ortaya koyar. Bu artış hem pratikte daha kısa eğitim döngüleri hem de aynı donanımla daha büyük batch boyutları çalıştırma imkânı sağlar.

Araştırma ortamlarında da fark belirgin. Yüz bin tokenın üzerinde bağlam deneyleri yapan takımlar, Flash Attention olmadan bu deneylerin kurulabilir olmadığını belirtir: ya donanım yetersiz kalır ya da model daha kısa dizilerle eğitilmek zorunda kalır. Bu yüzden uzun bağlam araştırmalarının büyük çoğunluğu Flash Attention’ı temel varsayım olarak alır.

Flash Attention Neden Bu Kadar Yaygınlaştı?

Algoritma hem eğitim hem üretim maliyetini doğrudan etkiliyor. Model eğitirken döngünün büyük kısmı attention hesabında geçer; Flash Attention bu maliyeti yukarıdan aşağıya indirger. Üretimde çalıştırırken throughput artar, gecikme düşer, aynı donanımla daha uzun bağlamlar işlenebilir.

API kullanıcısı olarak Flash Attention’ı doğrudan kontrol etmezsiniz; sağlayıcının altyapısında çalışan bir katmandır. Ama hangi sürümün etkin olduğu, özellikle uzun bağlam isteklerinin hızını ve fiyatını etkiler.

Flash Attention’ı daha derinlemesine incelemek isteyenler için başlangıç noktaları: Tri Dao’nun orijinal makalesi, PyTorch SDPA dokümantasyonu ve Dao-AILab/flash-attention GitHub deposu.