Submitted by Avelina9X t3_109yuvi in MachineLearning
I'm a machine learning PhD student and I'm doing research on LMs and how to reduce their memory footprint.
One idea I've been toying with is Vector Quantized LMs. I'm not talking about quantization as a technique to speed up compute with int8 activations and the like, but vector quantization through a learned codebook.
The idea is based on a uni-directional RNN that reconstructs the source sequence after quantization. Unlike MLM, where the corruption comes from masking and replacing tokens, here we quantize the token vectors and try to predict the original token from its quantized version and the unquantized short/long-term memory states produced at the previous timestep.
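To make that concrete, here's a minimal sketch of what I have in mind, not a real implementation. I'm assuming a VQ-VAE-style nearest-neighbour codebook with a straight-through estimator, and a GRU standing in for the uni-directional RNN (the short/long-term memory states are folded into the GRU hidden state for brevity; a full version would also need the usual codebook/commitment losses):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class VQLM(nn.Module):
    def __init__(self, vocab_size, d_model=256, codebook_size=512):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.codebook = nn.Embedding(codebook_size, d_model)  # learned codes
        self.rnn = nn.GRU(d_model, d_model, batch_first=True)
        self.out = nn.Linear(d_model, vocab_size)

    def quantize(self, x):
        # Nearest-neighbour lookup into the codebook (VQ-VAE style).
        flat = x.reshape(-1, x.size(-1))                          # (B*T, D)
        codes = torch.cdist(flat, self.codebook.weight).argmin(-1)
        q = self.codebook(codes).view_as(x)
        # Straight-through estimator so gradients reach the embeddings.
        return x + (q - x).detach(), codes.view(x.shape[:-1])

    def forward(self, tokens):
        x = self.embed(tokens)                                    # (B, T, D)
        q, _ = self.quantize(x)
        # The RNN predicts each original token from its quantized vector
        # plus the (unquantized) memory state from the previous timestep.
        h, _ = self.rnn(q)
        logits = self.out(h)
        # Per-token reconstruction loss; this doubles as the entropy
        # signal discussed below.
        recon = F.cross_entropy(
            logits.transpose(1, 2), tokens, reduction="none")     # (B, T)
        return logits, recon
```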
The reason I'm interested in such a convoluted idea is that it effectively gives a metric for the entropy of each token in a sequence: if the VQ-LM can reconstruct a token with high likelihood, that token is unimportant, but if it fails to predict a token, that token probably carries a lot of information (a rare word, say) and contributes more entropy to the sequence. The motivation for wanting to measure this is to guide the memory of a transformer. Models like the Transformer-XL operate on longer sequences by keeping keys and values from past segments around as memory, and the Compressive Transformer takes it a step further by compressing older tokens... Well, what if we used the reconstruction loss from the VQ-LM along with an 'age' metric to guide the memory bank of such an architecture, discarding easily predicted tokens early while keeping higher-entropy tokens around for longer? A sketch of that retention rule is below.
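Something like this hypothetical eviction policy, which assumes we've kept the per-token reconstruction loss from the VQ-LM above; the weighting `alpha` and the linear age penalty are arbitrary choices for illustration:

```python
import torch

def select_memories(recon_loss, age, mem_size, alpha=0.1):
    """Pick which cached key/value slots to keep in the memory bank.

    recon_loss: (N,) per-token VQ-LM reconstruction loss (entropy proxy)
    age:        (N,) steps since each token entered the memory bank
    mem_size:   number of slots to retain
    """
    # High-entropy, recent tokens score highest; easily predicted or
    # very old tokens get evicted first.
    score = recon_loss - alpha * age.float()
    keep = score.topk(min(mem_size, score.numel())).indices
    return keep  # indices into the memory bank's key/value cache

# Usage: after each segment, gather the kept slots before appending
# the new segment's keys/values, e.g. keys, values = keys[keep], values[keep]
```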
Has anyone considered such a system before? I've done a lot of searching and come up blank so far.
dojoteef t1_j41m2hd wrote
See Fast Decoding in Sequence Models using Discrete Latent Variables (Kaiser et al., 2018)