Submitted by Chemont t3_109z8om in MachineLearning

I recently came across "Confident Adaptive Language Modeling", which allows Transformers to exit early during inference and skip some model layers when a token is easy to predict. Is there any research on basically doing the opposite and allowing Transformers to spend more compute on tokens that are very hard to predict?

19

Comments

amrit_za t1_j418a4l wrote

It sounds like what you're considering the "opposite" is just a reframing of the original task, i.e. if a token is difficult to predict, then more layers (and therefore more compute) would be used; if it's easy, fewer layers. Am I missing something from what you're asking?

18

Chemont OP t1_j41eamz wrote

I should have been clearer with my question. What I was wondering is whether there are any extensions to the Transformer architecture that allow it to, in theory, spend an unbounded amount of compute on a single token. I suppose one could train a very deep Transformer, use CALM during inference, and only use all of the layers for tokens that are difficult to predict, but this would still place an arbitrary limit on the maximum amount of compute per token.

6

tdgros t1_j41f1nz wrote

You'll still pay the full price at train time, right? Early exiting works by attaching decoders to earlier layers during training. Conversely, if you want to spend more compute on some tokens, you will need more layers at train time, so at some point you will hit your memory/complexity limits.

4

visarga t1_j46b2po wrote

No, but if you use a decoder (autoregressive) model, you can generate more tokens for the same task depending on its difficulty. Chain-of-thought prompting makes use of this trick.

2

PassingTumbleweed t1_j41pibv wrote

Yes. This thread made me think of Universal Transformers, which have dynamic halting and have been around for a while now: https://openreview.net/forum?id=HyzdRiR9Y7
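
For anyone curious what the dynamic halting could look like in practice, here's a rough PyTorch sketch of an ACT-style loop over a single shared layer, loosely in the spirit of the Universal Transformer. The names (`ACTBlock`, `halting_head`, `max_steps`, the 0.99 threshold) are my own placeholders, not anything from the paper's code.

```python
import torch
import torch.nn as nn

class ACTBlock(nn.Module):
    """Sketch: one shared Transformer layer applied repeatedly, with per-position
    halting probabilities accumulated ACT-style until they cross a threshold."""

    def __init__(self, d_model=512, nhead=8, max_steps=12, threshold=0.99):
        super().__init__()
        self.shared_layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.halting_head = nn.Linear(d_model, 1)  # per-position "should I stop?" score
        self.max_steps = max_steps
        self.threshold = threshold

    def forward(self, x):                          # x: (batch, seq, d_model)
        b, s, _ = x.shape
        halt_sum = x.new_zeros(b, s)               # accumulated halting probability
        weighted_state = torch.zeros_like(x)       # weighted sum of intermediate states
        running = torch.ones(b, s, dtype=torch.bool, device=x.device)

        for _ in range(self.max_steps):
            p = torch.sigmoid(self.halting_head(x)).squeeze(-1)       # (b, s)
            halting_now = running & (halt_sum + p >= self.threshold)  # positions that halt this step
            p = torch.where(halting_now, 1.0 - halt_sum, p)           # spend only the remainder
            p = p * running                                           # already-halted positions contribute 0

            x = self.shared_layer(x)                                  # one more "ponder" step
            weighted_state = weighted_state + p.unsqueeze(-1) * x     # accumulate weighted output
            halt_sum = halt_sum + p
            running = running & ~halting_now
            if not running.any():
                break
        return weighted_state
```

The point is that hard positions keep taking ponder steps up to `max_steps`, while easy ones effectively stop contributing after a step or two, all with one set of shared weights.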

14

Raphaelll_ t1_j45u38j wrote

Did this ever get any traction?

1

PassingTumbleweed t1_j46sco1 wrote

That depends on what you mean. I don't think any of the LLMs use it, but it has some citations and follow-up literature.

1

Professor_Entropy t1_j41xfok wrote

Chain-of-thought prompting does this for autoregressive LM Transformers: they can complete harder objectives by spending more computation in the form of extra generated tokens.

I know it doesn't solve the general case, but you may take inspiration from it in other domains.
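
A minimal sketch of what that looks like in code, just to make the idea concrete. The prompt wording and the use of gpt2 via the Hugging Face pipeline are placeholders; chain-of-thought behaviour really only emerges with much larger models.

```python
# Chain-of-thought prompting sketch: the prompt nudges the model to emit
# intermediate reasoning tokens, so harder problems consume more forward
# passes (more generated tokens) before the final answer.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")  # stand-in model for illustration only

prompt = (
    "Q: A train travels 60 km in 1.5 hours. What is its average speed?\n"
    "A: Let's think step by step. The train covers 60 km in 1.5 hours, "
    "so its speed is 60 / 1.5 = 40 km/h. The answer is 40 km/h.\n"
    "Q: A shop sells 12 apples for 3 dollars. How much do 20 apples cost?\n"
    "A: Let's think step by step."
)

out = generator(prompt, max_new_tokens=64, do_sample=False)
print(out[0]["generated_text"])
```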

7

FutureIsMine t1_j41l2ck wrote

There's ALBERT, which reuses the same layers throughout. I can see a case where ALBERT is used together with a small decoder of just a few neurons, where at each step it uses the input to determine whether it's time to stop; ReasoNet did something similar for Q&A.
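
Roughly something like this, as a sketch only. The class, the pooling, and the stopping rule are made up for illustration, not taken from ALBERT or ReasoNet.

```python
import torch
import torch.nn as nn

class SharedLayerWithStop(nn.Module):
    """Sketch: an ALBERT-style shared encoder layer applied in a loop, plus a tiny
    stop head that looks at the current hidden states and decides whether to keep going."""

    def __init__(self, d_model=512, nhead=8, max_passes=24):
        super().__init__()
        self.layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)
        self.stop_head = nn.Linear(d_model, 1)    # just a few parameters deciding "done or not"
        self.max_passes = max_passes

    def forward(self, x):                         # x: (batch, seq, d_model)
        for _ in range(self.max_passes):
            x = self.layer(x)                     # the same weights every pass
            # pool the sequence and score "am I done?"
            stop_prob = torch.sigmoid(self.stop_head(x.mean(dim=1)))  # (batch, 1)
            if bool((stop_prob > 0.5).all()):     # crude whole-batch stopping rule
                break
        return x
```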

3

rehrev t1_j414ibv wrote

What does exiting early during inference mean, though?

1

tdgros t1_j41fn3f wrote

At train time, you plug decoders into many levels with the same objective, so you can find out whether some things can be decoded earlier, using an additional network that outputs a sort of confidence. At inference time, you run the layers one by one and stop when the confidence is high, which allows you to skip some computation. (It's probably a simplistic description, feel free to correct me.)
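
Something like this at inference time, as a simplified sketch. The names and the single sigmoid-threshold rule are placeholders; the actual CALM recipe uses several different confidence measures and a more careful threshold calibration.

```python
import torch

def early_exit_forward(layers, exit_heads, lm_head, x, threshold=0.9):
    """Run layers one by one; after each, an exit head scores how confident we are
    about the current token, and we stop as soon as that score clears the threshold.

    layers     : list of Transformer layers, each mapping (batch, seq, d) -> (batch, seq, d)
    exit_heads : one small module per layer, mapping the last hidden state to a score
    lm_head    : projection from hidden state to vocabulary logits
    """
    for layer, exit_head in zip(layers, exit_heads):
        x = layer(x)
        confidence = torch.sigmoid(exit_head(x[:, -1]))   # confidence for the token being decoded
        if bool((confidence > threshold).all()):          # confident enough: skip the remaining layers
            break
    return lm_head(x[:, -1])                              # next-token logits from wherever we exited
```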

4

cfoster0 t1_j4alveu wrote

FWIW, in a certain sense this goes against the design philosophy of Transformers, which is to jointly compute all representations within a layer at once to maximize the degree of parallelism on GPUs and other accelerators.

1

icecubeinanicecube t1_j415q2o wrote

How would you even do that? Once you have run inference through all the layers, you can't just pull additional layers out of thin air, can you?

0