Thursday, May 23, 2024

GPT-2 in about 150 lines of Java

I found a very interesting post, "GPT in 500 lines of SQL" (https://explainextended.com/2023/12/31/happy-new-year-15/), which was inspired by "GPT in 60 Lines of NumPy" (https://jaykmody.com/blog/gpt-from-scratch/). Of course, I could not resist the temptation to do the same in Java :-) Besides, "all I saw" in the source code was mostly matrix operations, which I like a lot.

(Image: matrix multiplication, from https://www.mathsisfun.com/algebra/matrix-multiplying.html)
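For readers who want the pictured operation in code, here is a minimal plain-Java sketch of matrix multiplication using only `double[][]` arrays (no Nd4j). It is an illustration written for this post, not code from the linked repository. The class name `MatMulDemo` is hypothetical. Each output cell `c[i][j]` is the sum over `k` of `a[i][k] * b[k][j]`.

```java
import java.util.Arrays;

/** Hypothetical plain-Java matrix multiplication, written for illustration (not from the linked repo). */
public class MatMulDemo {

    /** Returns a (n x m) times b (m x p) as a new n x p matrix. */
    public static double[][] matmul(double[][] a, double[][] b) {
        int n = a.length, m = b.length, p = b[0].length;
        if (a[0].length != m) {
            throw new IllegalArgumentException("inner dimensions differ: " + a[0].length + " vs " + m);
        }
        double[][] c = new double[n][p];
        for (int i = 0; i < n; i++) {
            for (int k = 0; k < m; k++) {
                double aik = a[i][k];               // reuse a[i][k] for the whole row k of b
                for (int j = 0; j < p; j++) {
                    c[i][j] += aik * b[k][j];       // accumulate a[i][k] * b[k][j] into c[i][j]
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        double[][] a = {{1, 2}, {3, 4}};
        double[][] b = {{5, 6}, {7, 8}};
        System.out.println(Arrays.deepToString(matmul(a, b)));
    }
}
```

The `i-k-j` loop order walks both `b` and `c` row by row, which is friendlier to the CPU cache than the textbook `i-j-k` order. With Nd4j, the same product is a single call, `a.mmul(b)`, on two `INDArray` objects.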

Thanks to the Nd4j library, which offers the same flexibility for working with matrices as NumPy, I managed to port the source code from https://jaykmody.com/blog/gpt-from-scratch/ into about 150 lines of Java (https://github.com/Streeling/RD_Archive/tree/main/ams/gpt2-in-about-150-lines-of-nd4j).
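To show why the port is mostly matrix operations, here is a hedged plain-Java sketch (no Nd4j) of three building blocks described in "GPT in 60 Lines of NumPy": the tanh approximation of GELU, a numerically stable row-wise softmax, and causal scaled dot-product attention, `softmax(q k^T / sqrt(d_k) + mask) v`. This is not the code from the linked repository. The class and method names (`AttentionSketch`, `causalAttention`, and so on) are hypothetical, and the toy inputs in `main` are made up.

```java
import java.util.Arrays;

/** Hypothetical plain-Java sketch of GPT-2 building blocks; not the linked repo's code. */
public class AttentionSketch {

    /** GELU, tanh approximation: 0.5 x (1 + tanh(sqrt(2/pi) (x + 0.044715 x^3))). */
    public static double gelu(double x) {
        return 0.5 * x * (1 + Math.tanh(Math.sqrt(2 / Math.PI) * (x + 0.044715 * x * x * x)));
    }

    /** Row-wise softmax; subtracting the row max keeps Math.exp from overflowing. */
    public static double[][] softmax(double[][] x) {
        double[][] out = new double[x.length][];
        for (int i = 0; i < x.length; i++) {
            double max = Arrays.stream(x[i]).max().orElse(0);
            double[] e = Arrays.stream(x[i]).map(v -> Math.exp(v - max)).toArray();
            double sum = Arrays.stream(e).sum();
            out[i] = Arrays.stream(e).map(v -> v / sum).toArray();
        }
        return out;
    }

    /** a (n x m) times b (m x p). */
    static double[][] matmul(double[][] a, double[][] b) {
        double[][] c = new double[a.length][b[0].length];
        for (int i = 0; i < a.length; i++)
            for (int k = 0; k < b.length; k++)
                for (int j = 0; j < b[0].length; j++)
                    c[i][j] += a[i][k] * b[k][j];
        return c;
    }

    /** Swaps rows and columns. */
    static double[][] transpose(double[][] a) {
        double[][] t = new double[a[0].length][a.length];
        for (int i = 0; i < a.length; i++)
            for (int j = 0; j < a[0].length; j++)
                t[j][i] = a[i][j];
        return t;
    }

    /**
     * softmax(q k^T / sqrt(d_k) + mask) v, where mask adds -1e10 wherever j > i,
     * so position i only attends to positions 0..i.
     */
    public static double[][] causalAttention(double[][] q, double[][] k, double[][] v) {
        int dK = q[0].length;
        double[][] scores = matmul(q, transpose(k));
        for (int i = 0; i < scores.length; i++) {
            for (int j = 0; j < scores[i].length; j++) {
                scores[i][j] /= Math.sqrt(dK);
                if (j > i) {
                    scores[i][j] += -1e10;   // causal mask: hide future tokens
                }
            }
        }
        return matmul(softmax(scores), v);
    }

    public static void main(String[] args) {
        double[][] x = {{1, 0}, {0, 1}, {1, 1}};   // 3 toy tokens with 2-dim embeddings
        System.out.println(Arrays.deepToString(causalAttention(x, x, x)));
        System.out.println(gelu(1.0));
    }
}
```

Because of the mask, the first token can only attend to itself, so the first output row equals the first row of `v`. In an Nd4j port, loops like these would presumably collapse into `INDArray` calls such as `mmul` and `transpose`, which is what keeps the line count close to the NumPy original.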