Bibliographic Details
Main Authors: Yao, Jinwei, Chen, Kaiqi, Zhang, Kexun, You, Jiaxuan, Yuan, Binhang, Wang, Zeke, Lin, Tao
Format: Preprint
Published: 2024
Subjects: Computation and Language, Artificial Intelligence
Online Access: https://arxiv.org/abs/2404.00242
author Yao, Jinwei
Chen, Kaiqi
Zhang, Kexun
You, Jiaxuan
Yuan, Binhang
Wang, Zeke
Lin, Tao
contents Large language models (LLMs) are increasingly employed for complex tasks that process multiple generation calls in a tree structure with shared token prefixes, such as few-shot prompting, multi-step reasoning, and speculative decoding. However, existing inference systems for tree-based applications are inefficient because they partition queries and KV cache improperly during attention calculation. This leads to two main issues: (1) a lack of memory access (IO) reuse for the KV cache of shared prefixes, and (2) poor load balancing. As a result, there is redundant KV cache IO between GPU global memory and shared memory, along with low GPU utilization. To address these challenges, we propose DeFT (Decoding with Flash Tree-Attention), a hardware-efficient attention algorithm with prefix-aware and load-balanced KV cache partitions. DeFT reduces the number of KV cache read/write operations during attention calculation through KV-Guided Grouping, a method that avoids repeatedly loading the KV cache of shared prefixes. Additionally, we propose Flattened Tree KV Splitting, a mechanism that ensures an even distribution of the KV cache across partitions with little computation redundancy, improving GPU utilization during attention computation. By reducing KV cache IO by 73-99% and IO for partial results by nearly 100% during attention calculation, DeFT achieves up to 2.23x speedup in end-to-end latency and up to 3.59x speedup in attention latency across three practical tree-based workloads compared to state-of-the-art attention algorithms. Our code is available at https://github.com/LINs-lab/DeFT.
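The two mechanisms named in the abstract can be illustrated with a toy sketch that only counts KV tokens loaded; it computes no attention and does not reproduce DeFT's GPU kernels (the authors' implementation is in the linked repository). The tree shape, node sizes, chunk size, and every name below (`KVNode`, `query_guided_io`, `kv_guided_io`, `flattened_splits`) are hypothetical. "Query-guided" loading reloads each query's whole root-to-leaf path; KV-guided grouping reads each node once for all queries that share it; the flattened split lays each node's KV out once in a single sequence and cuts it into equal-size chunks.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(eq=False)  # eq=False keeps identity hashing so nodes can be dict keys
class KVNode:
    """Hypothetical tree node holding the KV cache of `length` tokens."""
    name: str
    length: int
    parent: Optional["KVNode"] = None

def path_to_root(node):
    """Return the nodes from the root down to `node`."""
    path = []
    while node is not None:
        path.append(node)
        node = node.parent
    return path[::-1]

def query_guided_io(leaves):
    """Tokens loaded if each query (one per leaf) reloads its whole path."""
    return sum(n.length for leaf in leaves for n in path_to_root(leaf))

def kv_guided_groups(leaves):
    """Map each KV node to the queries that attend to it."""
    groups = {}
    for q, leaf in enumerate(leaves):
        for n in path_to_root(leaf):
            groups.setdefault(n, []).append(q)
    return groups

def kv_guided_io(leaves):
    """Tokens loaded when each node's KV is read once for its whole query group."""
    return sum(n.length for n in kv_guided_groups(leaves))

def flattened_splits(leaves, chunk_size):
    """Flatten the tree's KV (each node once) and cut it into chunks of at most
    `chunk_size` tokens. Each token keeps the set of queries allowed to attend
    to it, acting as a per-chunk mask."""
    token_queries = []
    for node, queries in kv_guided_groups(leaves).items():
        token_queries.extend([frozenset(queries)] * node.length)
    return [token_queries[i:i + chunk_size]
            for i in range(0, len(token_queries), chunk_size)]

# Toy tree: a 1000-token shared prompt with three decoding branches.
root = KVNode("prompt", 1000)
a = KVNode("a", 50, root)
leaves = [KVNode("a1", 20, a), KVNode("a2", 10, a), KVNode("b", 30, root)]

print(query_guided_io(leaves))  # 3160: the prompt is read three times
print(kv_guided_io(leaves))     # 1110: every node is read once
print([len(c) for c in flattened_splits(leaves, 256)])  # [256, 256, 256, 256, 86]
```

In this toy case the shared 1000-token prompt dominates, so grouping by KV node rather than by query removes most of the reloads, and the flat split gives every partition the same token count no matter how unevenly the tree's nodes are sized. Merging the per-chunk partial attention results (for example with a log-sum-exp rescaling) is left out of the sketch.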
format Preprint
id arxiv_https___arxiv_org_abs_2404_00242
institution arXiv
publishDate 2024
record_format arxiv
title DeFT: Decoding with Flash Tree-attention for Efficient Tree-structured LLM Inference
topic Computation and Language
Artificial Intelligence
url https://arxiv.org/abs/2404.00242