bo_peng OP t1_j4rht4i wrote

RWKV is an RNN that also works as a linear transformer (or, equivalently, a linear transformer that also works as an RNN). So it has both a parallel mode and a serial mode, and you get the best of both worlds: fast parallel training and VRAM-efficient serial inference.

Almost all such "linear transformers" are bad at language modeling, but RWKV is the exception. The basic idea is somewhat similar to the Attention Free Transformer (https://arxiv.org/abs/2105.14103), plus lots of new ideas on top :)
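
To illustrate the parallel/serial duality, here is a minimal sketch of an AFT-style time mixing (positive key weights, causal weighted average of values) computed in both modes with identical results. This is my own simplified illustration: it omits RWKV's time-decay and other tricks, and the function names and shapes are not from the actual repo.

```python
import torch

def mix_parallel(k, v):
    """Parallel ("linear transformer") mode: whole sequence at once via cumulative sums.
    k, v: (T, C) tensors. Simplified AFT-style mixing, not the exact RWKV kernel."""
    w = torch.exp(k)                  # positive per-token, per-channel weights
    num = torch.cumsum(w * v, dim=0)  # running weighted sum of values
    den = torch.cumsum(w, dim=0)      # running normalizer
    return num / den

def mix_serial(k, v):
    """Serial (RNN) mode: same output, but only a small running state is kept,
    so autoregressive generation needs O(1) state instead of the full context."""
    num = torch.zeros(k.shape[1])
    den = torch.zeros(k.shape[1])
    out = []
    for t in range(k.shape[0]):
        w = torch.exp(k[t])
        num = num + w * v[t]
        den = den + w
        out.append(num / den)
    return torch.stack(out)

k, v = torch.randn(8, 4), torch.randn(8, 4)
assert torch.allclose(mix_parallel(k, v), mix_serial(k, v), atol=1e-5)
```

The parallel form is what you train with (one pass over the whole sequence), while the serial form is what makes generation cheap, since only the running numerator/denominator state has to be kept around.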


bo_peng OP t1_iwua2xh wrote

RWKV 7B is faster than GPT 6B, and RWKV actually scales well :)

If you check the table, RWKV at 3B is better than GPT-Neo on everything (while the smaller RWKV models lag behind on LAMBADA).

But GPT-J uses rotary position embeddings and is therefore quite a bit better than GPT-Neo, so I expect RWKV to surpass it at 14B.

Moreover, RWKV 3B becomes stronger after being trained on more tokens, and I am doing the same for the 7B model.


bo_peng OP t1_iwts867 wrote

RWKV-3 1.5B on an A40 (tf32): a constant 0.015 sec/token, tested using simple PyTorch code (no custom CUDA kernel), GPU utilization 45%, VRAM 7823 MB

GPT2-XL 1.3B on an A40 (tf32): 0.032 sec/token (at ctxlen 1000), tested using HF, GPU utilization also 45% (interesting), VRAM 9655 MB

Moreover, RWKV-4 uses bf16 and is faster than 16-bit GPT models.

Training speed: RWKV-4 1.5B, bf16, ctxlen 1024 = 106K tokens/s on 8×A100 40G.
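
For reference, here is a rough sketch of how sec/token can be measured for a HF model; this is my own illustrative script (model name, prompt construction, and token counts are placeholders), not the exact code behind the numbers above.

```python
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.backends.cuda.matmul.allow_tf32 = True  # tf32 matmuls, as in the A40 test above

name = "gpt2-xl"  # 1.3B parameters
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).cuda().eval()

# ~1000-token prompt of random ids, to roughly match the ctxlen 1000 setting above
ids = torch.randint(0, tok.vocab_size, (1, 1000)).cuda()

n_new = 100
torch.cuda.synchronize()
t0 = time.time()
with torch.no_grad():
    model.generate(ids, min_new_tokens=n_new, max_new_tokens=n_new, do_sample=False)
torch.cuda.synchronize()
print(f"{(time.time() - t0) / n_new:.3f} sec/token")
```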
