This change adds support for FlashInfer. FlashInfer can be enabled using
`FLASH_INFER=1` and is currently only implemented in `FlashCausalLM`.
Since this functionality is currently only for testing, FlashInfer is
not installed anywhere yet.
The FlashInfer API is quite different from FlashAttention/vLLM in that
it requires more global bookkeeping:
* A wrapper class needs to be constructed (which we just call *state*).
Since this is fairly expensive (due to pinned host memory allocation),
we only do this once in a FlashCausalLM instance or for each CUDA
Graph size.
* Each model forward call needs to be wrapped in `begin_forward` and
`end_forward`. This sets up data structures that can be reused for all
calls to attention for that forward call.
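As a rough illustration of the "construct once, then reuse" point above, the
sketch below caches one state for eager execution and one per CUDA Graph
size. `DummyWrapper` and `StateCache` are hypothetical stand-ins invented for
this example; they are not FlashInfer or `FlashCausalLM` classes, and the
real construction arguments are not shown here.

```python
from typing import Dict, Optional


class DummyWrapper:
    """Hypothetical stand-in for the expensive-to-build FlashInfer wrapper."""

    constructed = 0  # counts how many wrappers have been built

    def __init__(self, cuda_graph_size: Optional[int]) -> None:
        DummyWrapper.constructed += 1
        self.cuda_graph_size = cuda_graph_size


class StateCache:
    """Builds one state for eager mode and one per CUDA Graph size."""

    def __init__(self) -> None:
        self._eager: Optional[DummyWrapper] = None
        self._per_graph: Dict[int, DummyWrapper] = {}

    def get(self, cuda_graph_size: Optional[int] = None) -> DummyWrapper:
        # Eager (non-graph) forward calls share a single state.
        if cuda_graph_size is None:
            if self._eager is None:
                self._eager = DummyWrapper(None)
            return self._eager
        # Each CUDA Graph size gets its own state, built on first use.
        if cuda_graph_size not in self._per_graph:
            self._per_graph[cuda_graph_size] = DummyWrapper(cuda_graph_size)
        return self._per_graph[cuda_graph_size]
```

With this shape, repeated forward calls for the same graph size pay the
construction cost only once.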
When calling attention, we need access to the state object. To avoid
passing an argument down the call chain (which would require changes to
all models), we use a context variable.
Each model forward call is wrapped using a context manager that does all
the bookkeeping for such a call:
* Set the context variable to the forward call's state.
* Call `begin_forward` on the state.
* Yield.
* Call `end_forward` on the state.
* Reset the context variable.
We cannot use a single shared global variable for this, since e.g. CUDA
Graphs of different sizes each have their own state.
* Using flash decoding
Conditional flashdecoding.
Fix max_q.
Working kvcache
Working version with flash decoding.
Make it work for mistral.
Fix after rebase.
Less intrusive.
Revert changes in modeling.
Speedup flashdecoding.
HHachweew
Hack to make other models work.
Fixing non flash decoding llama path.
Router logic knows about page size.
Missing 2 models.
Missing cohere.
Fixing cohere flash decoding.
Revamped all this architecture.
Fix cohere.
Fixing falcon.
Enabling custom block size schedule.
Update router/src/infer.rs
Not sending preallocated output.
* Making it work on non flash decoding.
* Fix Cohere.
* Fix non decoding paths.
* Rebased.
* No need for cache_manager anymore.
* Update?
* "ipex" -> "cpu"
* These do not belong.
* Factoring cu_seqlen_qk for better abstracting over every model.
* Fixing non flash tests/imports.
* Changing return everywhere.
* Update mistral past.
* Fixing Mi{s,x}tral (non-functional in Flash Decoding mode though).
* Fixup mistral clamping (had issues with cuda graphs).
* No need to recreate anything actually.