Performance optimized interleaved mode JetStream server #122
JoeZijunZhou wants to merge 1 commit into main from
Conversation
Optimizing TTFT and optimizing output token throughput conflict with each other. Can we expose a parameter to tune the trade-off between the two?
    for _ in self._generate_engines
]
self._prefill_detokenize_backlogs = [
    # We don't let detokenization accumulate more than 8 steps to avoid
Can you elaborate on why there is a synchronization issue after 8 steps?
Set it to 8 for the detokenize thread; a backlog that is too large or too small causes performance issues.
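To illustrate the capped backlog discussed above, here is a minimal sketch (not the actual JetStream code; all names are illustrative): the backlog between the decode loop and the detokenize thread is a bounded queue, so the producer blocks once detokenization falls 8 steps behind instead of letting device results accumulate without bound.

```python
import queue
import threading

backlog = queue.Queue(maxsize=8)  # mirrors the 8-step cap discussed above
results = []

def decode_loop(num_steps):
    # Producer: put() blocks whenever the backlog already holds 8 steps,
    # throttling decode instead of piling up undetokenized results.
    for step in range(num_steps):
        backlog.put(step)

def detokenize_thread(num_steps):
    # Consumer: drains steps in order as they arrive.
    for _ in range(num_steps):
        results.append(backlog.get())

producer = threading.Thread(target=decode_loop, args=(20,))
consumer = threading.Thread(target=detokenize_thread, args=(20,))
producer.start()
consumer.start()
producer.join()
consumer.join()
```

With a single producer and single consumer, the FIFO queue preserves step order while bounding how far detokenization can lag.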
Thanks, is this PR ready to submit?
    slot, active_request = data
    my_live_requests[slot] = active_request

def _prefill_detokenize_thread(self, idx: int):
I thought we already had a prefill detokenize thread. Does current JetStream (before this PR) always return the prefill token (first token) after the first decode step?
We only had a single detokenize thread that combined prefill detokenization and decode detokenization. The problem is that jax.block_until_ready() blocks the thread while it waits for the prefill or decode token's async copy to host, so putting both in one thread makes TTFT slow. JetStream returns the prefill token in the prefill thread (after the prefill step generates the first token).
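The split described above can be sketched as follows (a simplified illustration, not JetStream's actual classes): prefill and decode detokenization each get their own backlog and thread, so a slow blocking wait on decode results (in the real code, jax.block_until_ready() on the device-to-host copy) can no longer delay the first token. Here the blocking wait is simulated with time.sleep().

```python
import queue
import threading
import time

prefill_backlog = queue.Queue()
decode_backlog = queue.Queue()
emitted = []

def prefill_detokenize_thread():
    while True:
        token = prefill_backlog.get()
        if token is None:  # sentinel: shut down
            break
        # Real code would call jax.block_until_ready(token) here; prefill
        # tokens now only ever wait on their own device-to-host copies.
        emitted.append(("prefill", token))

def decode_detokenize_thread():
    while True:
        token = decode_backlog.get()
        if token is None:
            break
        time.sleep(0.2)  # simulated slow copy-to-host wait
        emitted.append(("decode", token))

threads = [
    threading.Thread(target=prefill_detokenize_thread),
    threading.Thread(target=decode_detokenize_thread),
]
for t in threads:
    t.start()
decode_backlog.put(1)   # a decode result is enqueued first...
prefill_backlog.put(0)  # ...but the prefill (first) token is emitted
for q in (prefill_backlog, decode_backlog):
    q.put(None)         # without waiting behind the decode thread's sleep
for t in threads:
    t.join()
```

With one combined thread the prefill token would queue behind the slow decode wait; with two threads it is emitted first, which is the TTFT win.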
sounds good, thanks for sharing insights!
Currently, this PR prioritizes prefills in interleaved mode and applies the correct JAX blocking for the async copy to host, to reduce wasted wait time. One more optimization to do is to ensure the result is returned immediately once the return channel has it (from the orchestrator).
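The remaining optimization mentioned above could look roughly like this (class and method names are illustrative, not JetStream's actual API): the client-facing return channel yields each result the moment the orchestrator enqueues it, rather than on a polling interval or in batches.

```python
import queue

class ReturnChannel:
    """Hypothetical return channel: a blocking queue wrapped as an iterator."""

    def __init__(self):
        self._q = queue.Queue()

    def put(self, result):
        # Called by the orchestrator as soon as a token is detokenized.
        self._q.put(result)

    def close(self):
        self._q.put(None)  # sentinel marking end of stream

    def __iter__(self):
        while True:
            result = self._q.get()  # blocks only until a result exists
            if result is None:
                return
            yield result

channel = ReturnChannel()
channel.put("first_token")
channel.put("next_token")
channel.close()
tokens = list(channel)
```

Because `get()` unblocks as soon as `put()` runs, each result reaches the consumer immediately, with no polling delay between the orchestrator and the caller.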