P

Flash Attention : Définition et Exemples

Flash Attention est un algorithme optimisé de calcul du mécanisme d'attention dans les Transformers, qui réduit drastiquement la consommation mémoire et accélère l'entraînement et l'inférence des grands modèles de langage.

Définition complète

Flash Attention est une implémentation optimisée du mécanisme d'attention (self-attention) utilisé dans les architectures Transformer. Développé par Tri Dao et ses collaborateurs à Stanford en 2022, cet algorithme résout un problème fondamental : le calcul classique de l'attention a une complexité mémoire quadratique par rapport à la longueur de la séquence, ce qui limite fortement la taille des contextes que les modèles peuvent traiter.

L'idée centrale de Flash Attention repose sur le concept de "tiling" (découpage en blocs) et une gestion intelligente de la hiérarchie mémoire du GPU. Au lieu de matérialiser la matrice d'attention complète en mémoire haute (HBM), l'algorithme effectue les calculs par blocs directement dans la mémoire rapide du GPU (SRAM), puis fusionne les résultats. Cela évite les allers-retours coûteux entre les différents niveaux de mémoire, qui constituent le véritable goulot d'étranglement.

Flash Attention 2, publié en 2023, a encore amélioré les performances en optimisant le parallélisme et en réduisant les opérations non essentielles. Cette version atteint environ 50 à 70 % de l'utilisation théorique maximale des GPU modernes, contre seulement 25 à 40 % pour l'attention standard. Flash Attention 3, sorti en 2024, pousse ces optimisations plus loin pour les GPU Hopper (H100).

L'impact de Flash Attention sur l'écosystème IA est considérable : il a rendu possible l'entraînement de modèles avec des fenêtres de contexte de 100 000 tokens et plus, là où les implémentations classiques plafonnaient à quelques milliers. Aujourd'hui, pratiquement tous les grands modèles de langage (GPT-4, Claude, Llama, Mistral) utilisent une variante de Flash Attention.

Étymologie

Le terme "Flash" fait référence à la rapidité d'exécution de l'algorithme et à son utilisation de la mémoire SRAM (parfois appelée "flash memory" par analogie avec sa vitesse), par opposition à la mémoire HBM plus lente. Le nom évoque aussi l'idée d'un calcul "éclair" qui évite de stocker les résultats intermédiaires volumineux.

Exemples concrets

Entraînement d'un LLM avec un long contexte

Tu es un assistant qui peut traiter des documents de 200 000 tokens. Analyse ce rapport annuel complet et identifie les 5 risques stratégiques majeurs mentionnés.

Optimisation d'inférence en production

Configure le serveur d'inférence vLLM avec Flash Attention 2 activé pour servir un modèle Llama 3 70B sur 4 GPU A100 avec un throughput maximal.

Recherche et fine-tuning sur GPU limité

Fine-tune ce modèle 7B sur un seul GPU 24 Go en utilisant Flash Attention et le gradient checkpointing pour maximiser la longueur de séquence supportée.

Usage pratique

En prompt engineering, Flash Attention n'est pas un concept que l'on utilise directement dans ses prompts, mais il conditionne ce qui est possible : c'est grâce à lui que les modèles modernes acceptent des contextes très longs. Concrètement, cela signifie que vous pouvez fournir des documents entiers, de longues conversations ou de nombreux exemples few-shot dans un seul prompt sans dégradation majeure de performance. Comprendre Flash Attention aide aussi à choisir les bons paramètres d'inférence et à dimensionner l'infrastructure GPU pour déployer des modèles en production.

Concepts liés

Mécanisme d'attention (Self-Attention)TransformerFenêtre de contexteComplexité quadratiqueKV CacheMulti-Head Attention

FAQ

Flash Attention change-t-il les résultats du modèle ?
Non. Flash Attention calcule exactement le même résultat mathématique que l'attention standard. C'est une optimisation purement algorithmique et matérielle : seule la manière dont les calculs sont ordonnés et stockés en mémoire change, pas le résultat final. Les sorties sont identiques à la précision numérique près.
Pourquoi Flash Attention est-il plus rapide alors qu'il fait le même calcul ?
Le gain vient de la réduction des transferts mémoire, pas de la réduction du nombre d'opérations. L'attention classique écrit une énorme matrice intermédiaire (N×N) en mémoire lente (HBM), puis la relit pour calculer le résultat. Flash Attention effectue tout par petits blocs dans la mémoire rapide du GPU (SRAM), éliminant ces transferts coûteux. Sur les GPU modernes, c'est la bande passante mémoire — et non la puissance de calcul — qui est le facteur limitant.
Faut-il activer Flash Attention manuellement ?
Cela dépend du framework utilisé. Dans Hugging Face Transformers, on peut l'activer avec le paramètre attn_implementation='flash_attention_2' lors du chargement du modèle. Dans PyTorch 2.0+, la fonction scaled_dot_product_attention utilise automatiquement Flash Attention quand les conditions sont réunies (GPU compatible, tenseurs au bon format). Les frameworks d'inférence comme vLLM et TensorRT-LLM l'activent par défaut.

Voir aussi

Recevez de nouveaux prompts chaque semaine

Rejoignez notre newsletter.