I built minLlama because I wanted a Llama implementation that was easy to understand and hack on for KV cache compression research. There are also PyTorch and JAX versions in ~140 lines.
I'd be interested in feedback from people who have written transformer implementations before: are there any implementation "tricks" I'm missing (e.g., a cleaner KV cache for PyTorch/JAX, or RoPE tricks)?
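For concreteness, one common alternative to growing the cache by concatenation at every decode step is to preallocate fixed-size key/value buffers and write into them in place. The sketch below is a hypothetical PyTorch illustration of that generic pattern, not minLlama's code; the class name `StaticKVCache` and the `(batch, n_heads, max_seq_len, head_dim)` layout are assumptions.

```python
import torch


class StaticKVCache:
    """Hypothetical preallocated KV cache (a sketch, not minLlama's code).

    Buffers have shape (batch, n_heads, max_seq_len, head_dim); new keys and
    values are written in place instead of being concatenated each step.
    """

    def __init__(self, batch, n_heads, max_seq_len, head_dim, dtype=torch.float32):
        shape = (batch, n_heads, max_seq_len, head_dim)
        self.k = torch.zeros(shape, dtype=dtype)
        self.v = torch.zeros(shape, dtype=dtype)
        self.pos = 0  # number of sequence positions filled so far

    def update(self, k_new, v_new):
        """Write (batch, n_heads, t, head_dim) tensors at the current position
        and return views of all keys/values cached so far."""
        t = k_new.shape[2]
        end = self.pos + t
        if end > self.k.shape[2]:
            raise ValueError("KV cache overflow")
        self.k[:, :, self.pos:end] = k_new
        self.v[:, :, self.pos:end] = v_new
        self.pos = end
        return self.k[:, :, :end], self.v[:, :, :end]


# Usage: a 4-token prefill followed by one decode step.
cache = StaticKVCache(batch=1, n_heads=2, max_seq_len=8, head_dim=4)
k, v = cache.update(torch.randn(1, 2, 4, 4), torch.randn(1, 2, 4, 4))
k, v = cache.update(torch.randn(1, 2, 1, 4), torch.randn(1, 2, 1, 4))
print(k.shape)  # torch.Size([1, 2, 5, 4])
```

A variant that always returns the full buffers together with a mask over unfilled positions keeps every tensor shape fixed, which is the usual motivation for this pattern under `torch.compile` or `jax.jit` (in JAX the in-place write would become `jax.lax.dynamic_update_slice`).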