Speculative Decoding
Faster inference for LLMs
Parallel token computation
To generate 10 tokens of a sentence, the LLM has to run 10 decoding steps, one after the other (in series). A sketch of this serial loop is below.
Slow.
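A minimal sketch of that serial loop, assuming a hypothetical `model` that maps a token-id sequence to next-token logits (not any specific library's API):

```python
import torch

@torch.no_grad()
def greedy_decode(model, prompt_ids, num_new_tokens=10):
    """Plain autoregressive decoding: one full forward pass per new token."""
    ids = prompt_ids.clone()
    for _ in range(num_new_tokens):        # 10 new tokens -> 10 serial forward passes
        logits = model(ids)                # full pass over the whole model
        next_id = logits[-1].argmax()      # pick the most likely next token
        ids = torch.cat([ids, next_id.view(1)])
    return ids
```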
Key observations about generation with LLMs
Some tokens are easier to predict than others; some need more compute/recall to predict.
"what's the square root of 7? Square root of 7 is 2.blah blah"
In above sentence, 7 is simply copying previous token, where as for 2.blah blah we will gave to do some computation.
Memory is the bottleneck on GPUs, because the computation itself parallelises very efficiently.
Memory read/write is roughly 100 times slower than compute, i.e. for every byte read from memory, we can do about 100 compute operations in the time the read takes.
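A rough back-of-the-envelope illustration of why single-token decoding is memory-bound (the model size and byte counts below are illustrative assumptions, not measurements):

```python
# Per decoding step the GPU streams every weight from memory,
# but does only ~2 FLOPs (one multiply + one add) per weight per token.
params = 7e9                    # e.g. a 7B-parameter model (assumed)
bytes_per_param = 2             # fp16 weights
flops_per_param_per_token = 2

bytes_read = params * bytes_per_param
flops_done = params * flops_per_param_per_token
print(flops_done / bytes_read)  # ~1 FLOP per byte read

# A modern GPU can sustain on the order of 100+ FLOPs per byte of memory
# bandwidth, so the compute units mostly sit idle. Verifying several
# speculated tokens in one pass reuses the same weight reads almost for free.
```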
Perform some operation in speculation of it being needed in the future. For example, uploading an Instagram image: the app already uploads the image to the server while you are still editing it locally. If you then go ahead and post, the image appears to upload instantly because it was uploading in the background; if you discard your edits instead, the uploaded image is simply deleted.
Use a smaller network as a proxy for the true model: keep predicting (drafting) tokens with it, and use the larger model as the verifier (true) model.
The prediction here is done with the smaller model; the larger model only checks the drafted tokens, and it can check all of them in a single parallel pass.
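A minimal sketch of one draft-then-verify step, assuming hypothetical `draft_model` (small, fast) and `target_model` (large, the "true" model) that both return next-token logits for a sequence. This is the greedy variant only; the published algorithm accepts/rejects by sampling from the two distributions.

```python
import torch

@torch.no_grad()
def speculative_step(draft_model, target_model, ids, k=4):
    # 1) Draft: the small model proposes k tokens, one by one (cheap).
    draft_ids = ids.clone()
    for _ in range(k):
        next_id = draft_model(draft_ids)[-1].argmax()
        draft_ids = torch.cat([draft_ids, next_id.view(1)])

    # 2) Verify: one parallel pass of the large model over all drafted positions.
    logits = target_model(draft_ids)                           # [len(ids)+k, vocab]
    target_choice = logits[len(ids) - 1 : -1].argmax(dim=-1)   # big model's pick at each drafted position
    proposed = draft_ids[len(ids):]

    # 3) Accept the longest prefix where draft and target agree.
    matches = (proposed == target_choice).int()
    n_accept = int(matches.cumprod(dim=0).sum())
    accepted = proposed[:n_accept]

    # 4) At the first mismatch, fall back to the target model's token;
    #    if everything matched, take one extra token from the last position for free.
    if n_accept < k:
        bonus = target_choice[n_accept].view(1)
    else:
        bonus = logits[-1].argmax().view(1)
    return torch.cat([ids, accepted, bonus])
```

When the draft model is usually right, each call to the large model yields several accepted tokens instead of one, and the extra drafted positions ride along on the same weight reads, which is exactly where the memory-bound slack gets used.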