lucidrains/compressive-transformer-pytorch: PyTorch implementation of Compressive Transformers, from DeepMind
PyTorch implementation of Compressive Transformers, a variant of Transformer-XL with compressed memory for long-range language modeling. I will also combine this with an idea from another paper, 'Stabilizing Transformers for Reinforcement Learning', which adds gating at the residual intersection. The memory and the gating may be synergistic and lead to further improvements in both language modeling and reinforcement learning.
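To make the memory scheme concrete, here is a minimal sketch of how a compressive memory could be updated after each segment, following the description in the Compressive Transformer paper. New hidden states are appended to a FIFO memory of length mem_len; the oldest states that fall out are compressed by a factor of cmem_ratio and appended to a compressed memory of length cmem_len. The function name update_memories is hypothetical, and mean pooling stands in for the compression function (the paper also considers learned options such as a 1D convolution). This is not this library's internal code.

```python
import torch
import torch.nn.functional as F


def update_memories(mem, cmem, hidden, mem_len, cmem_len, cmem_ratio):
    """Hypothetical sketch of one compressive-memory update.

    mem:    (batch, n_mem, dim)  current memory
    cmem:   (batch, n_cmem, dim) current compressed memory
    hidden: (batch, seq, dim)    hidden states of the segment just processed
    """
    # append the new hidden states to the FIFO memory
    mem = torch.cat((mem, hidden), dim=1)

    # the oldest states beyond mem_len are evicted from the memory
    n_evict = max(mem.shape[1] - mem_len, 0)
    old, mem = mem[:, :n_evict], mem[:, n_evict:]

    if n_evict > 0:
        # compress the evicted states by mean pooling over windows of cmem_ratio
        # (assumption: n_evict is a multiple of cmem_ratio)
        pooled = F.avg_pool1d(old.transpose(1, 2), kernel_size=cmem_ratio, stride=cmem_ratio)
        cmem = torch.cat((cmem, pooled.transpose(1, 2)), dim=1)

    # keep only the most recent cmem_len compressed slots
    cmem = cmem[:, max(cmem.shape[1] - cmem_len, 0):]
    return mem, cmem
```

With mem_len = 1024 and cmem_ratio = 4, each full memory's worth of evicted states occupies 1024 / 4 = 256 compressed slots, which is why the usage example below sets cmem_len = 1024 // 4.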
Install
$ pip install compressive_transformer_pytorch
Usage
import torch
from compressive_transformer_pytorch import CompressiveTransformer

model = CompressiveTransformer(
    num_tokens = 20000,
    emb_dim = 128,                   # embedding dimensions, embedding factorization from the ALBERT paper
    dim = 512,
    depth = 12,
    seq_len = 1024,
    mem_len = 1024,                  # memory length
    cmem_len = 1024 // 4,            # compressed memory buffer length
    cmem_ratio = 4,                  # compressed memory ratio, 4 was recommended in the paper
    reconstruction_loss_weight = 1,  # weight to place on compressed memory reconstruction loss
    attn_dropout = 0.1,              # dropout post-attention
    ff_dropout = 0.1,                # dropout in feedforward
    attn_layer_dropout = 0.1,        # dropout for attention layer output
    gru_gated_residual = True,       # whether to gate the residual intersection, from the 'Stabilizing Transformers for Reinforcement Learning' paper
    mogrify_gru = False,             # experimental feature that adds a mogrifier for the update and residual before gating by the GRU
    memory_layers = range(6, 13),    # specify which layers use long-range memory, from the 'Do Transformers Need Deep Long-Range Memory?' paper
    ff_glu = True                    # use GLU variant for feedforward
)

inputs = torch.randint(0, 256, (1, 2048))
masks = torch.ones_like(inputs).bool()

segments = inputs.reshape(1, -1, 1024).transpose(0, 1)
masks = masks.reshape(1, -1, 1024).transpose(0, 1)

logits, memories, aux_loss = model(segments[0], mask = masks[0])
logits, _, aux_loss = model(segments[1], mask = masks[1], memories = memories)
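The gru_gated_residual flag in the constructor above replaces the plain residual connection x + sublayer(x) with a learned gate. Below is a minimal sketch of the GRU-type gating layer as described in the 'Stabilizing Transformers for Reinforcement Learning' paper, where x is the residual stream entering the sublayer and y is the sublayer output. The class name GRUGate and the default gate bias of 2.0 are illustrative assumptions; this library's own gating module may be implemented differently.

```python
import torch
from torch import nn


class GRUGate(nn.Module):
    """Sketch of the GRU-type gating layer g(x, y):

    r = sigmoid(W_r y + U_r x)
    z = sigmoid(W_z y + U_z x - b_g)
    h = tanh(W_g y + U_g (r * x))
    g(x, y) = (1 - z) * x + z * h
    """

    def __init__(self, dim, gate_bias=2.0):
        super().__init__()
        self.W_r = nn.Linear(dim, dim, bias=False)
        self.U_r = nn.Linear(dim, dim, bias=False)
        self.W_z = nn.Linear(dim, dim, bias=False)
        self.U_z = nn.Linear(dim, dim, bias=False)
        self.W_g = nn.Linear(dim, dim, bias=False)
        self.U_g = nn.Linear(dim, dim, bias=False)
        # b_g > 0 pushes z toward 0, so the gate starts close to the identity map on x
        self.gate_bias = gate_bias

    def forward(self, x, y):
        r = torch.sigmoid(self.W_r(y) + self.U_r(x))                   # reset gate
        z = torch.sigmoid(self.W_z(y) + self.U_z(x) - self.gate_bias)  # update gate
        h = torch.tanh(self.W_g(y) + self.U_g(r * x))                  # candidate state
        return (1 - z) * x + z * h
```

In a gated block, the usual x + attn(x) becomes gate(x, attn(x)), and likewise for the feedforward sublayer.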
The memories value returned by each call is a named tuple that contains the memory (mem) and the compressed memory (cmem).
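For inputs longer than two segments, the same pattern generalizes to a loop that feeds each segment in turn and threads the returned memories into the next call. The helper run_segments below is a hypothetical sketch built only on the call signature shown in the example; it uses Tensor.split instead of reshape so that batch sizes other than 1 also work. It assumes the total length is a multiple of seq_len and that logits come back as (batch, seq_len, num_tokens).

```python
import torch


def run_segments(model, inputs, mask, seq_len):
    """Hypothetical helper: run a long sequence through the model one segment at a time.

    inputs, mask: (batch, total_len), with total_len a multiple of seq_len.
    Returns the concatenated logits, the final memories, and the summed auxiliary loss.
    """
    memories = None
    all_logits, total_aux_loss = [], 0.0
    for seg, seg_mask in zip(inputs.split(seq_len, dim=-1), mask.split(seq_len, dim=-1)):
        # the first segment has no memories yet, so omit the keyword as in the example
        kwargs = {} if memories is None else {"memories": memories}
        logits, memories, aux_loss = model(seg, mask=seg_mask, **kwargs)
        all_logits.append(logits)
        total_aux_loss = total_aux_loss + aux_loss
    return torch.cat(all_logits, dim=1), memories, total_aux_loss
```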
When training, wrap the model in AutoregressiveWrapper and it will manage the memories across segments for you. As easy as it gets.
import torch
from compressive_transformer_pytorch import CompressiveTransformer
from compressive_transformer_pytorch import AutoregressiveWrapper

model = CompressiveTransformer(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    seq_len = 1024,
    mem_len = 1024,
    cmem_len = 256,
    cmem_ratio = 4,
    memory_layers = [5,6]
).cuda()

model = AutoregressiveWrapper(model)

inputs = torch.randint(0, 20000, (1, 2048 + 1)).cuda()  # one extra token, since inputs and targets are offset by one

for loss, aux_loss, _ in model(inputs, return_loss = True):
    (loss + aux_loss).backward()

# optimizer step and zero grad
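The training loop leaves the optimizer step as a comment. One way to fill it in, sketched below under the assumption that the wrapper yields one (loss, aux_loss, _) tuple per segment as in the loop, is to backpropagate each segment's loss so gradients accumulate, then take a single optimizer step per batch. The function name train_step is hypothetical and not part of the library.

```python
import torch


def train_step(wrapper, optimizer, batch):
    """Hypothetical training step: backprop every segment's loss, then step once."""
    wrapper.train()
    total = 0.0
    for loss, aux_loss, _ in wrapper(batch, return_loss=True):
        segment_loss = loss + aux_loss
        segment_loss.backward()  # gradients accumulate across segments
        total += segment_loss.item()
    # one parameter update per batch, then reset the accumulated gradients
    optimizer.step()
    optimizer.zero_grad()
    return total
```

For example, with optim = torch.optim.Adam(model.parameters(), lr = 1e-4), each batch becomes a single call to train_step(model, optim, inputs).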
... after much training ...

Generation is also greatly simplified and automated away. Just pass in the prime, which can be a single start token or a sequence of any length, and everything else is taken care of for you.

prime = torch.ones(1, 1, dtype = torch.long).cuda()  # token ids must be integers; assume 1 is the start token
sample = model.generate(prime, 4096)
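For intuition about what generate automates, here is a generic autoregressive sampling loop: starting from the prime, it repeatedly takes the logits at the last position, samples the next token, and appends it. This is a hypothetical sketch, not the library's generate. Here logits_fn stands in for any callable that maps (batch, t) token ids to (batch, t, num_tokens) logits, and the sketch omits any memory bookkeeping.

```python
import torch


@torch.no_grad()
def sample(logits_fn, prime, num_new_tokens, temperature=1.0):
    """Hypothetical sampling loop; temperature=0 means greedy decoding."""
    out = prime
    for _ in range(num_new_tokens):
        logits = logits_fn(out)[:, -1]  # logits for the last position: (batch, num_tokens)
        if temperature == 0:
            next_token = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        out = torch.cat((out, next_token), dim=1)
    return out[:, prime.shape[1]:]  # return only the newly generated tokens
```

Returning only the new tokens is a design choice of this sketch; the library's generate may return its output in a different form.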
Citations
@misc{rae2019compressive,
    title = {Compressive Transformers for Long-Range Sequence Modelling},
    author = {Jack W. Rae and Anna Potapenko and Siddhant M. Jayakumar and Timothy P. Lillicrap},
    year = {2019},
    eprint = {1911.05507},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@misc{parisotto2019stabilizing,
    title = {Stabilizing Transformers for Reinforcement Learning},
    author = {Emilio Parisotto and H. Francis Song and Jack W. Rae and Razvan Pascanu and Caglar Gulcehre and Siddhant M. Jayakumar and Max Jaderberg and Raphael Lopez Kaufman and Aidan Clark and Seb Noury and Matthew M. Botvinick and Nicolas Heess and Raia Hadsell},
    year = {2019},
    eprint = {1910.06764},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}

@inproceedings{rae-razavi-2020-transformers,
    title = {Do Transformers Need Deep Long-Range Memory?},
    author = {Rae, Jack and Razavi, Ali},
    booktitle = {Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
    month = jul,
    year = {2020},
    address = {Online},
    publisher = {Association for Computational Linguistics},
    url = {https://www.aclweb.org/anthology/2020.acl-main.672}
}

@article{Shazeer2019FastTD,
    title = {Fast Transformer Decoding: One Write-Head is All You Need},
    author = {Noam Shazeer},
    journal = {ArXiv},
    year = {2019},
    volume = {abs/1911.02150}
}

@misc{shazeer2020glu,
    title = {GLU Variants Improve Transformer},
    author = {Noam Shazeer},
    year = {2020},
    url = {https://arxiv.org/abs/2002.05202}
}

@misc{lan2019albert,
    title = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year = {2019},
    url = {https://arxiv.org/abs/1909.11942}
}

@misc{ding2021erniedoc,
    title = {ERNIE-Doc: A Retrospective Long-Document Modeling Transformer},
    author = {Siyu Ding and Junyuan Shang and Shuohuan Wang and Yu Sun and Hao Tian and Hua Wu and Haifeng Wang},
    year = {2021},
    eprint = {2012.15688},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
