ICLR 2024 Roundup
TLDR; I’m sharing my personal reactions to ICLR 2024.
I’m an early(ish?) career researcher and I want a pulse on trends in the ML community. Topics that I sought out included neuroscience-inspired AI (e.g., spiking neural networks, local learning rules, etc.), mechanistic interpretability, some ML systems work, and various other topics like continual learning. Here’s a hodge-podge of comments, reactions, and takes. Feel free to skip to the paper summaries below.
- I’m paraphrasing heavily when I summarize my takeaways from the papers and talks.
- I’m a big fan of the 5-minute presentations that ICLR requested from all authors. Please encourage this at all conferences! It makes it so much easier to get the gist of a paper compared to reading the abstract, which is often too high level, or reading the paper, which can take too long.
- I’ve short-listed 67 out of 2,296 accepted papers. Of those, I would like to actually read 28 of them. Of those I will probably read ~4.2 papers 😅. But even going through this little exercise of giving my one line summary of each of these papers was helpful to get a sense for the design space of paper possibilities AND to see lots of examples of research that made it through the peer-review process.
- Methods matter big time, but seeing what problems people are working on is really important for me right now. For instance, does learning time delays in spiking neural networks help with task performance? This is an interesting problem, and if you sit down to think about it, you might come up with a dozen methods for approaching it. The problem is the interesting thing in this example.
- It’s possible that GPT-5 will solve all of our problems but if not, then people may start looking for the “next thing”. SNNs and neuromorphic chips may be the next battleground in AI in terms of both algorithms and hardware. I think there were enough papers on these topics at ICLR to say this is a trend.
⭐️s indicate papers that I’m particularly interested in reading.
Day 1 - Tuesday, May 7, 2024
- Predictive auxiliary objectives in deep RL mimic learning in the brain - Auxiliary predictive objectives improve learning in deep RL. My sense is that in addition to learning from reward signals, they introduce self-supervised learning objectives (e.g., next-token prediction, etc.) to improve learning. Examining the learned representations, they show that they resemble representations in the brain.
Extra notes
- Cited previous work from Dan Yamins and Jim DiCarlo showing representational similarity between CNNs and visual cortex.
- Authors claim deep RL can be a useful framework for thinking about interacting brain regions (e.g., hippocampus). In particular, they map the predictive network to the hippocampus and the Q-learning module to the striatum (what's the striatum?).
- Predictive objectives include next-state prediction and maximizing the distance between randomly sampled different states.
- Model Tells You What to Discard: Adaptive KV Cache Compression for LLMs - This work analyzes attention patterns in transformer based language models to determine which tokens are being attended to. This can be used to determine which tokens can be discarded from the key-value cache, which is a major bottleneck for LLM inference on GPUs since they are memory bandwidth constrained. FlashAttention is another solution to the same problem - limited GPU memory bandwidth.
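To make the idea concrete, here’s a rough sketch of attention-based cache eviction as I understand the general flavor (this is my toy version with made-up names, not the paper’s actual compression policy):

```python
import numpy as np

def evict_kv_cache(attn_weights, keys, values, keep: int):
    """Toy KV-cache eviction: keep the `keep` cached tokens that received the
    most attention, averaged over heads and recent queries.

    attn_weights: (heads, queries, tokens) softmaxed attention scores
    keys, values: (tokens, d) cached key/value vectors
    """
    scores = attn_weights.mean(axis=(0, 1))          # (tokens,) average attention per cached token
    keep_idx = np.sort(np.argsort(scores)[-keep:])   # top-`keep` tokens, original order preserved
    return keys[keep_idx], values[keep_idx], keep_idx

# toy example: 2 heads, 4 recent queries, 10 cached tokens, head dim 8
rng = np.random.default_rng(0)
attn = rng.random((2, 4, 10)); attn /= attn.sum(-1, keepdims=True)
K, V = rng.normal(size=(10, 8)), rng.normal(size=(10, 8))
K_small, V_small, kept = evict_kv_cache(attn, K, V, keep=6)
print(kept)  # indices of the 6 tokens retained in the compressed cache
```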
- ⭐️ A Stable, Fast, and Fully Automatic Learning Algorithm for Predictive Coding Networks - This work iterates on the idea of predictive coding. Predictive coding is the idea that neural representations (in the brain and in ANNs?) should be suited for predicting the next state from the current state. I see this as very similar to most self-supervised learning objectives.
Extra notes
- Authors point to the potential biological implausibility of backpropagation in the brain.
- They call for parallelization, locality, and automation in learning algorithms for non-von-Neumann architectures.
- In computational models, this is done via a minimization of the variational free energy, in this case a function of the total error of the generative model. TODO: read up on variational free energy.
- BrainSCUBA: Fine-Grained Natural Language Captions of Visual Cortex Selectivity - The authors present a method for assigning semantically meaningful labels to functional areas of the visual cortex (e.g., this little voxel activates for faces, this one for houses, etc.). Best I can tell, they use a model that predicts from image space to voxel-wise brain activations and they search for the image that maximally activates a voxel. Then they use a vision-language multi-modal model to generate a text caption for the image that maximally activates the voxel.
- Circuit Component Reuse Across Tasks in Transformer Language Models - This mechanistic interpretability paper provides evidence for the hypothesis that circuit components found in small transformers (e.g., GPT-2) may also be the same components used in larger models (e.g., GPT-2 Large). This would be desirable as it allows us to move towards a more general taxonomy of transformer components that are not specific to a single model and task.
- Conformal Risk Control - The authors extend conformal prediction to control the expected value of any monotone loss function. Recall conformal prediction adjusts the size of the prediction set to guarantee coverage (e.g., expand regression prediction interval to guarantee the correct answer is in the set 90% of the time). The generalization here is to just consider other loss functions such as the false negative rate. If the loss function is the miscoverage loss, then their proposed risk control is equivalent to conformal prediction.
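For a refresher on the base case, here’s a minimal split conformal prediction sketch for regression intervals (plain conformal prediction with my own variable names, not the paper’s risk-control procedure):

```python
import numpy as np

def conformal_interval(cal_preds, cal_targets, test_pred, alpha=0.1):
    """Split conformal prediction for regression: widen the interval by the
    (1 - alpha) quantile of calibration residuals to get ~90% coverage."""
    scores = np.abs(cal_targets - cal_preds)                 # nonconformity scores
    n = len(scores)
    q_level = np.ceil((n + 1) * (1 - alpha)) / n             # finite-sample correction
    qhat = np.quantile(scores, min(q_level, 1.0))
    return test_pred - qhat, test_pred + qhat                # prediction interval

rng = np.random.default_rng(1)
y_cal = rng.normal(size=500)
y_cal_pred = y_cal + rng.normal(scale=0.3, size=500)         # a noisy stand-in "model"
lo, hi = conformal_interval(y_cal_pred, y_cal, test_pred=0.2, alpha=0.1)
print(lo, hi)  # interval that should contain the true value ~90% of the time
```

Conformal risk control generalizes the quantile step so that the expected value of other monotone losses (e.g., the false negative rate) stays below a target level; with the miscoverage loss you recover exactly this procedure.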
- Initializing Models with Larger Ones - In this work, the authors propose to initialize neural network model weights for a smaller model using weights from larger models. This reminds me of two related ideas: knowledge distillation and Tony Zador’s genomic bottleneck idea.
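My naive mental model of this is just carving a sub-block out of each larger weight matrix; a toy sketch (the paper’s actual selection rule is more careful than this):

```python
import numpy as np

def init_from_larger(W_large, out_dim, in_dim):
    """Toy 'weight selection': initialize a smaller layer from the leading
    sub-block of a larger model's weight matrix."""
    assert W_large.shape[0] >= out_dim and W_large.shape[1] >= in_dim
    return W_large[:out_dim, :in_dim].copy()

W_teacher = np.random.default_rng(2).normal(size=(1024, 1024))   # a layer from a larger model
W_student = init_from_larger(W_teacher, out_dim=256, in_dim=256)
print(W_student.shape)  # (256, 256) initialization for the smaller model
```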
- Leveraging Generative Models for Unsupervised Alignment of Neural Time Series Data - My TLDR; this paper applies foundation-model-style pretraining to models that learn patterns in neural data across recording sessions, tasks, and animals, in contrast to the traditional approach of building a separate generative model for each recording session, task, or animal.
- ⭐️ Masks, Signs, And Learning Rate Rewinding - I’m not super familiar with learning rate rewinding (LRR) as a method for finding lottery ticket sparse subnetworks, but they say that LRR is effective at finding sparse subnetworks and is robust to sign changes in the parameters learned in the model. 🤷♂️
- ⭐️ Sparse Autoencoders Find Highly Interpretable Features in Language Models - This is the type of work that I personally find very interesting. They learn a sparse codebook of interpretable features found in transformer models using sparse autoencoders to help with the challenge of polysemanticity of neurons. This type of decomposition can be useful for mechanistic explanations of the algorithms that deep learning models are implementing.
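The basic recipe, as I understand it (a minimal sketch with made-up dimensions, not their exact architecture or training setup):

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Overcomplete autoencoder with an L1 penalty on the hidden code,
    trained to reconstruct a model's internal activations."""
    def __init__(self, d_model=512, d_dict=4096):
        super().__init__()
        self.encoder = nn.Linear(d_model, d_dict)
        self.decoder = nn.Linear(d_dict, d_model)

    def forward(self, x):
        code = torch.relu(self.encoder(x))   # sparse, non-negative feature activations
        recon = self.decoder(code)
        return recon, code

sae = SparseAutoencoder()
acts = torch.randn(64, 512)                  # stand-in for transformer residual-stream activations
recon, code = sae(acts)
l1_coeff = 1e-3
loss = ((recon - acts) ** 2).mean() + l1_coeff * code.abs().mean()
loss.backward()                              # each dictionary direction ideally maps to one interpretable feature
```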
- Synaptic Weight Distributions Depend on the Geometry of Plasticity - This work suggests that the Euclidean distance assumed for e.g., backpropagation results in synaptic weights that follow a different distribution from those found in the human brain, implying that other distance measures (and therefore geometry) may be more appropriate for synaptic weight learning.
- Towards Best Practices of Activation Patching in Language Models: Metrics and Methods - This is definitely a paper of interest for me. They examine two types of activation patching used for circuit discovery in mechanistic interpretability and show that counterfactual prompting appears to work best compared to adding Gaussian noise to input tokens.
- Vision Transformers Need Registers - The authors add filler tokens to vision transformers that don’t correspond to any real input to serve as registers. This results in interpretable vision maps. Follow-up reading: Massive Activations in Large Language Models.
Extra notes
- Use [CLS] attention map in vision transformers to understand what the model is attending to.
- There's a real variety of attention patterns in vision transformers - many that seem spurious.
- High-norm tokens are aggregators that are collecting information from similar tokens.
- Register tokens take on object-oriented attention behavior.
- The mechanistic basis of data dependence and abrupt learning in an in-context classification task - They describe in-context learning abilities in terms of attention maps, the input context, and induction heads.
Extra notes
- Explained their results in terms of induction heads
- Talked about the data that leads to in-context learning or memorization.
- Distilled the induction head to two parameters that explain the in-context learning skills.
- Self-RAG: Learning to Retrieve, Generate, and Critique through Self-Reflection - Their framework trains an LM that adaptively retrieves passages by predicting a special token that tells the model to do retrieval. They then feed the retrieved content back into the model with the user query, score the generations, and generate more special tokens to indicate which of the responses is best. It was a bit of a Frankenstein system.
- ⭐️ A differentiable brain simulator bridging brain simulation and brain-inspired computing - This paper introduces BrainPy, a Python framework that attempts to implement differentiable spiking neural networks alongside traditional brain simulation models (e.g., Hodgkin-Huxley). It was developed using JAX and XLA.
- BrainLM: A foundation model for brain activity recordings - This paper essentially does large scale language model style pre-training using a huge collection of fMRI recordings and shows that the resulting representations are quite good for downstream static classification tasks like age, anxiety, and PTSD. It’s less clear if the learned representations are useful for decoding brain activity for dynamic tasks like inferring what’s in the image someone is looking at.
- Critical Learning Periods Emerge Even in Deep Linear Networks - Critical learning periods are time periods early in development where temporary sensory deficits can permanently damage the outcome of learning. This paper focuses on an analytical model of deep linear models (ignoring the complexity of non-linearities, etc.) and shows that both the depth of the neural network architecture and the data distributions under study can lead to critical learning periods. They liken this to classic Hubel and Wiesel experiments showing that early sensory deprivation can lead to permanent deficits in visual processing.
- ⭐️ Multi-modal Gaussian Process Variational Autoencoders for Neural and Behavioral Data - This paper attempts to use a VAE to learn latent variable representations of neural and behavioral data jointly. It attempts to disentangle which dimensions are representing which modality by attempting to reconstruct the data from a subset of the latent variables.
- Scaling Laws for Associative Memories - The authors define a mathematical abstraction of a transformer model using high-dimensional outer products and characterize the memory capacity of the model as the parameter count grows.
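The outer-product abstraction is basically a linear associative memory; here’s a tiny toy version of my own (not the paper’s exact setup) to show what "capacity" means here:

```python
import numpy as np

rng = np.random.default_rng(3)
d, n_pairs = 256, 20
keys = rng.normal(size=(n_pairs, d)) / np.sqrt(d)      # roughly orthogonal random keys
values = rng.normal(size=(n_pairs, d))

M = sum(np.outer(v, k) for k, v in zip(keys, values))  # memory = sum of outer products

recalled = M @ keys[7]                                  # query the memory with a stored key
print(np.corrcoef(recalled, values[7])[0, 1])           # high correlation while well under capacity
```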
Day 2 - Wednesday, May 8, 2024
- ReLU Strikes Back: Exploiting Activation Sparsity in Large Language Models - This is mostly an ML systems paper that takes advantage of the simple computation of ReLU as compared to more complicated activation functions like GELU. They exploit the sparsity of ReLU activations to speed up LLM inference by observing that if an output of a neuron is zero, then the weights that would multiply that activation can be ignored. It’s just not obvious to me how they determine this ahead of time!
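The trick they exploit is easy to show schematically (the hard part, as I said, is predicting the sparsity pattern ahead of time; names here are mine):

```python
import numpy as np

def sparse_mlp_second_layer(h_relu, W2):
    """If a ReLU output is zero, the corresponding row of the next weight
    matrix contributes nothing, so skip it entirely."""
    active = np.nonzero(h_relu)[0]           # indices of non-zero activations
    return h_relu[active] @ W2[active]       # only touch the weight rows we actually need

rng = np.random.default_rng(4)
h = np.maximum(rng.normal(size=4096), 0) * (rng.random(4096) < 0.1)  # highly sparse activations
W2 = rng.normal(size=(4096, 1024))
out_sparse = sparse_mlp_second_layer(h, W2)
assert np.allclose(out_sparse, h @ W2)       # same result, far fewer weight rows read from memory
```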
- ⭐️ Emergent mechanisms for long timescales depend on training curriculum and affect performance in memory tasks - This paper examines recurrent neural network abilities to perform a task that requires increasing memory over time. The task is $N$-parity, which looks at the trailing $N$ inputs and has to compute the parity of those inputs. They study both the timescale of individual neurons (think time constant $\tau$ corresponding to biophysical properties of spike passing) and the network-mediated timescale, which they say is the rate at which neurons in a network decorrelate their spiking activity.
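The task itself is easy to write down; here’s how I’d generate $N$-parity data (my construction - their curriculum and setup surely differ in the details):

```python
import numpy as np

def n_parity_stream(T=1000, N=5, seed=0):
    """Binary input stream; at each step t >= N-1 the target is the parity
    (XOR) of the last N inputs, so the network needs ~N steps of memory."""
    rng = np.random.default_rng(seed)
    x = rng.integers(0, 2, size=T)
    y = np.array([x[t - N + 1 : t + 1].sum() % 2 for t in range(N - 1, T)])
    return x, y

x, y = n_parity_stream(T=20, N=3)
print(x)
print(y)  # y[t] = parity of the trailing 3 bits
```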
- ⭐️ Forward Learning with Top-Down Feedback: Empirical and Analytical Characterization - Very cool work in the spirit of local learning rules and alternatives to backpropagation. They provide some unifying principles that link a number of forward-only learning procedures.
- ⭐️ Is This the Subspace You Are Looking for? An Interpretability Illusion for Subspace Activation Patching - This paper explores the problem of trying to determine if activation patching is truly discovering the underlying circuits in the model.
- Local Search GFlowNets - I know very little about GFlowNets and need to do some more reading. But Yoshua Bengio is bullish on them so I’m interested.
- ⭐️ Neuron Activation Coverage: Rethinking Out-of-distribution Detection and Generalization - Studying the out-of-distribution problem from a neuron activation perspective seems like a natural thing to want to do (versus e.g., the input space or the entropy of the model’s predictions).
- Prompt Gradient Projection for Continual Learning - This paper mashes up two existing ideas: 1) if you project gradients in the orthogonal direction to existing gradients, you can avoid catastrophic forgetting and 2) prompt-tuning, which introduces new trainable tokens to the model to help it learn new tasks.
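Idea 1 is simple to sketch: keep a basis for old-task gradient directions and project new gradients onto the orthogonal complement (a bare-bones version; the paper combines this with prompt-tuning, which isn’t shown here):

```python
import numpy as np

def project_orthogonal(grad, old_basis):
    """Remove the components of `grad` that lie in the subspace spanned by
    previous tasks' gradient directions (columns of `old_basis`)."""
    Q, _ = np.linalg.qr(old_basis)            # orthonormalize the stored directions
    return grad - Q @ (Q.T @ grad)

rng = np.random.default_rng(5)
old_grads = rng.normal(size=(100, 3))         # 3 stored gradient directions from old tasks
g_new = rng.normal(size=100)
g_proj = project_orthogonal(g_new, old_grads)
print(np.abs(old_grads.T @ g_proj).max())     # ~0: the update no longer moves along old-task directions
```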
- ⭐️ Scaling Laws for Sparsely-Connected Foundation Models - This paper introduces scaling laws for transformers as the sparsity level (the number of non-zero parameters) and the dataset size vary.
- The Expressive Leaky Memory Neuron: an Efficient and Expressive Phenomenological Neuron Model Can Solve Long-Horizon Tasks - The abstract alone on this paper just makes you say wow. On a first read it looks like individual neurons in the brain are so incredibly complex that it requires neural networks with tens of thousands of parameters to model them.
- Mechanistic Interpretability Social Meetup
- Packed! The space is very high growth right now
- Behavioral tests are insufficient because models might have “I’m being evaluated” circuits.
- “How to avoid getting scooped?” - will have to read a lot and find time to do your own thing
- Many good 1-on-1 discussions after the panel finished
- ⭐️ Bayesian Bi-clustering of Neural Spiking Activity with Latent Structures - The author proposes a Bayesian procedure for identifying clusters of neural activity in both time and space.
- ⭐️ From Sparse to Soft Mixtures of Experts - Mixture of experts in language models is a way to route tokens to specific parameters in the model so you can avoid having to use all model parameters. Not only is this ideally more compute efficient (this is dynamic computation - closer to what the brain does) but maybe it allows for more specialization among the experts. You can do hard assignments or soft assignments (e.g., send some portion of the token to each expert). I think in this work, they’re putting a distribution over tokens and sending a little bit of each token to all experts. So a token that gets sent to an expert is a weighted combination of all tokens.
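Here’s a minimal soft-routing sketch of how I read it (dimensions, naming, and the one-expert-per-slot simplification are all mine):

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(6)
n_tokens, d_model, n_slots = 16, 32, 4
X = rng.normal(size=(n_tokens, d_model))        # input tokens
Phi = rng.normal(size=(d_model, n_slots))       # learned slot parameters
experts = [rng.normal(size=(d_model, d_model)) for _ in range(n_slots)]  # one tiny "expert" per slot

logits = X @ Phi                                 # (tokens, slots)
dispatch = softmax(logits, axis=0)               # each slot = weighted combination of ALL tokens
slots = dispatch.T @ X                           # (slots, d_model)
expert_out = np.stack([slots[i] @ experts[i] for i in range(n_slots)])

combine = softmax(logits, axis=1)                # each token = weighted combination of slot outputs
Y = combine @ expert_out                         # (tokens, d_model), fully differentiable routing
print(Y.shape)
```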
- How connectivity structure shapes rich and lazy learning in neural circuits - This work builds on the idea that there are different learning regimes for neural networks: 1) rich learning, where the network modifies its weights quite a bit, and 2) lazy learning, where the network doesn’t modify its weights much. Rich learning may be better at generalization by learning to ignore irrelevant features, for example. In this work, they look at how different initialization strategies (e.g., random or connectivity structures mimicking biological networks) affect the learning regime of the network. TLDR; the weights change much more when the network is initialized with a biological connectivity structure compared with random initialization. The learned parameters are also low-rank. A low initial rank also seems to lead to richer learning (as measured by the Frobenius norm of the weight updates). The exception may be when evolutionary/inductive biases lead to good initial performance.
- Large Brain Model for Learning Generic Representations with Tremendous EEG Data in BCI - This paper is very similar in spirit to the BrainLM paper above, where instead of fMRI data, they’re working with EEG data.
- ⭐️ Learning dynamic representations of the functional connectome in neurobiological networks - This work attempts to identify which neural circuitry is working together within a short period of time, while accounting for the fact that these circuits may change over time.
- ⭐️ Mechanistically analyzing the effects of fine-tuning on procedurally defined tasks - The authors show an interesting result: fine-tuning doesn’t change the internal representations of language models all that much; rather, it learns a wrapper around existing capabilities, where you might simply pre- or post-process the input or output of the model’s internal capability to get the desired behavior.
- ⭐️ Pre-Training and Fine-Tuning Generative Flow Networks - This work attempts to define an unsupervised pre-training objective for GFlowNets, which typically require a reward function to be specified.
- ⭐️ TAB: Temporal Accumulated Batch Normalization in Spiking Neural Networks - Batch normalization is a technique used to stabilize training in deep neural networks by normalizing the activations of a layer over the batch dimension. It works by reducing internal covariate shift, which happens when the distribution of the activations changes after weight updates, resulting in confused downstream layers. In SNNs, BatchNorm becomes harder for a few reasons: 1) the activations are spikes, not continuous values, and 2) there’s a temporal dimension to the activations. This work introduces a new normalization technique that takes into account the temporal dimension of the activations. Best I can tell they’re computing a moving average of mean and variance statistics over the time-dimension before applying the normalization.
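My reading of "accumulate statistics over the time dimension" looks roughly like this (a guess at the mechanics, not their exact formulation):

```python
import torch

def temporal_accumulated_norm(x, eps=1e-5):
    """Normalize spiking-layer pre-activations using statistics accumulated
    over all timesteps up to t. Input shape: (time, batch, features)."""
    T = x.shape[0]
    out = torch.empty_like(x)
    for t in range(T):
        window = x[: t + 1]                               # everything seen so far
        mean = window.mean(dim=(0, 1))
        var = window.var(dim=(0, 1), unbiased=False)
        out[t] = (x[t] - mean) / torch.sqrt(var + eps)    # normalize step t with accumulated stats
    return out

x = torch.randn(10, 32, 64)                               # 10 timesteps, batch 32, 64 features
print(temporal_accumulated_norm(x).shape)
```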
- ⭐️ What does the Knowledge Neuron Thesis Have to do with Knowledge? - This paper makes the claim that factual information isn’t simply stored in the MLP neurons of a language model, as the 2022 paper “Locating and Editing Factual Associations in GPT” seems to imply.
Day 3 - Thursday, May 9, 2024
- ⭐️ Analyzing Feed-Forward Blocks in Transformers through the Lens of Attention Maps - MLPs are hard to understand from a mechanistic interpretability perspective, which I think is the case due to not being able to linearly decompose activations because of the non-linear activation functions. I’m having a hard time getting a gist for this paper - probably need to read it in its entirety. Maybe they’re trying to measure how much outputs from MLPs affect downstream attention calculations?
- ⭐️ Dictionary Contrastive Learning for Efficient Local Supervision without Auxiliary Networks - In this work, the authors are trying to train models using only local supervision (e.g., no backprop) with a contrastive-learning-based approach that pushes together activations from the same class and pushes apart activations from different classes. Their innovation was to add a bank of learned class representations that are used to compute the contrastive loss (i.e., an instance of a class should be similar to that class’s representation).
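A stripped-down version of what I think the local loss looks like at a single layer (my own toy; the "dictionary" here is one learned embedding per class):

```python
import torch
import torch.nn.functional as F

d_feat, n_classes, batch = 128, 10, 32
class_dict = torch.nn.Parameter(torch.randn(n_classes, d_feat))   # learned class embeddings (the "dictionary")
local_layer = torch.nn.Linear(784, d_feat)                         # a layer trained with only a local signal

x = torch.randn(batch, 784)
labels = torch.randint(0, n_classes, (batch,))

feats = local_layer(x)
sims = F.normalize(feats, dim=1) @ F.normalize(class_dict, dim=1).T  # cosine similarity to each class entry
loss = F.cross_entropy(sims / 0.1, labels)   # pull toward own class entry, push away from the others
loss.backward()                              # gradients stay local to this layer and the dictionary
print(loss.item())
```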
- In-Context Pretraining: Language Modeling Beyond Document Boundaries - The idea of this work builds on a quirk of modern LLM training, which packs random documents together to form a single training instance. This is for efficiency, but the downside is that later documents can attend back to completely unrelated documents. This decision has always perplexed me. In this work, they find approximately related documents to pack into a single training instance and show it helps with downstream tasks.
- ⭐️ Learning Delays in Spiking Neural Networks using Dilated Convolutions with Learnable Spacings - Theoretical work has shown that spiking neural networks that have adjustable time-delays are more expressive than those that don’t. This work introduces a new method for learning time-delays in SNNs.
- ⭐️ Online Stabilization of Spiking Neural Networks - This paper appears to be similar to the paper above, TAB, which attempted to implement BatchNorm for SNNs.
- Pre-training with Random Orthogonal Projection Image Modeling - Vision transformers can be trained through a masked modeling objective where patches are dropped from the image and the model attempts to predict the missing bits. This work takes a different approach by randomly projecting the image patches and attempting to recover the original image patch. This might lead to a stronger learning signal, and a smoother signal at that. They use sketching methods to implement the random projections. It’s slightly reminiscent of diffusion models because they’re predicting a sort of additive noise to the image. They show models train faster with their technique.
- ⭐️ Spatio-Temporal Approximation: A Training-Free SNN Conversion for Transformers - This paper provides a method for converting a pretrained transformer model to a spiking neural network. The challenge is that operations like self-attention and normalization pose challenges for SNNs, which have to compute through time.
- ⭐️ Spike-driven Transformer V2: Meta Spiking Neural Network Architecture Inspiring the Design of Next-generation Neuromorphic Chips - This paper focuses on transformer based SNNs (in contrast to CNN based SNNs) and how they might inspire next-generation neuromorphic chips.
- Towards Understanding Factual Knowledge of Large Language Models - They introduce a dataset of fact related questions to probe LLM factuality.
- ⭐️ ExeDec: Execution Decomposition for Compositional Generalization in Neural Program Synthesis - They probe LLM abilities to piece together program subroutines into coherent programs. At a first glance, they might be using LLM agents to solve these compositional tasks so it could be cool to see another example of using LLM agents.
- ⭐️ A Framework for Inference Inspired by Human Memory Mechanisms - This paper uses the (potentially old?) idea of trying to model short and long-term memory systems. I think moving toward a multi-tiered memory system seems like a fun and potentially fruitful research direction for AI systems.
- ⭐️ A Progressive Training Framework for Spiking Neural Networks with Learnable Multi-hierarchical Model - The authors of this work roughly say the Leaky Integrate-and-Fire (LIF) neuron underperforms with deep-layer gradient calculation and capturing global information on the time dimension. They propose a new neuron to alleviate these issues. To draw a probably imperfect analogy, this feels like going from a vanilla RNN cell to an LSTM cell.
- Complex priors and flexible inference in recurrent circuits with dendritic nonlinearities - The authors are interested in the question of how the brain represents priors about the world (e.g., the structure of faces, etc.). They put a probability distribution over neural circuits. They also draw inspiration from diffusion models and liken the prior to a data manifold and claim that diffusion modeling is like trying to take steps to get back onto the data manifold.
- Implicit regularization of deep residual networks towards neural ODEs - This work investigates links between discrete (e.g., ResNets) and continuous time neural ODEs. The key result seems to be that if the network is initialized as a discretization of a neural ODE (not sure what this means), then such a discretization holds throughout training. I would have to look more closely at the paper to unpack this.
- Manipulating dropout reveals an optimal balance of efficiency and robustness in biological and machine visual systems - This work examines what happens when you crank up the dropout level and 1) evaluate the model’s efficiency (I think measured by the dimensionality of information where higher-dimensions carries more information?), 2) evaluate robustness of the representations (robustness to data perturbations?), and 3) compare representational similarity between the model and the brain.
- Modeling state-dependent communication between brain regions with switching nonlinear dynamical systems - This work attempts to provide a descriptive model of how brain regions communicate with each other.
- One-hot Generalized Linear Model for Switching Brain State Discovery - This is a Bayesian treatment of identifying functional connections between disparate brain regions as the functional connectivity changes over time (e.g., when the subject is performing different tasks).
- Parsing neural dynamics with infinite recurrent switching linear dynamical systems - This is probably a very crude approximation of what this paper is about but it seems to be investigating low-dimensional dynamics of neural activity while allowing those dynamics to vary as the brain “changes states” such as when the subject is performing different tasks.
- Towards Energy Efficient Spiking Neural Networks: An Unstructured Pruning Framework - This is a systems paper for SNNs that attempts to sparsify networks with an eye towards energy savings on neuromorphic chips.
Day 4 - Friday, May 10, 2024
- Neuroformer: Multimodal and Multitask Generative Pretraining for Brain Data - This paper attempts to pre-train general purpose transformers on multi-modal brain data (e.g., neural recordings, behavior data, etc.) and show that the representations learned are useful for downstream tasks. Clearly there was a trend of trying to do unsupervised pre-training of transformers on brain data at this conference.
- ⭐️ SpikePoint: An Efficient Point-based Spiking Neural Network for Event Cameras Action Recognition - This paper proposes a point-based SNN for action recognition on event-camera data streams.
- Sasha Rush Session - There were career chat sessions with tons of big-name researchers. Sasha answered all sorts of questions about career advice and where the field of NLP is heading. This was a cool idea by the conference organizers to host these sessions.
- Provable Compositional Generalization for Object-Centric Learning - This was a bit of a theory paper that defined what it means for a model (an autoencoder) to generalize to novel compositions of objects. Compositionality is appealing because it allows us to generalize to novel objects by combining known objects in novel ways. For instance, if we know what legs of a chair look like and we know what tops of tables look like, we might be able to imagine what a table with legs might look like.
- BTR: Binary Token Representations for Efficient Retrieval Augmented Language Models - This is a retrieval augmented generation (RAG) paper that speeds up the retrieval process by using binary token representations. Question: is it common to pass both the query and document to be retrieved through a transformer model? I thought sentence embeddings were used for retrieval.
- Decoding Natural Images from EEG for Object Recognition - This work follows a long line of work trying to predict what someone is thinking about (e.g., an image label, etc.) based on their brain activity. This paper uses two encoders - an image encoder and an EEG encoder - with a contrastive loss to learn a shared representation between the two modalities. They also introduce some new spatial attention mechanisms for the EEG encoder.
- Overthinking the Truth: Understanding how Language Models Process False Demonstrations - This mechanistic interpretability paper sets up careful counterfactual prompts. The correct prompt shows $k$ text examples with their corresponding correct sentiment labels and the counterfactual is to flip the labels to be incorrect. They ablate attention heads by zeroing out their outputs to observe which has the most impact on the model’s in-context learning of the incorrect labels. TLDR; models might use separate circuits for prompts that reflect reality/truth and those that are more adversarial or are incorrect.
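Zero-ablating a head is conceptually simple; here’s the shape-level idea (illustrative only - the commented-out hook registration is hypothetical and the module path depends on the model implementation):

```python
import torch

def make_zero_head_hook(head_idx, n_heads):
    """Forward hook that zeroes one attention head's slice of a module output
    (assumes the output has shape: batch x seq x n_heads*head_dim)."""
    def hook(module, inputs, output):
        out = output.clone()
        head_dim = out.shape[-1] // n_heads
        out[..., head_idx * head_dim : (head_idx + 1) * head_dim] = 0.0
        return out
    return hook

# hypothetical usage on a real model (path and output format vary; many attention modules return tuples):
# handle = model.transformer.h[7].attn.register_forward_hook(make_zero_head_hook(3, 12))
attn_out = torch.randn(1, 12, 768)                  # fake attention output: 12 heads x 64 dims
hook = make_zero_head_hook(head_idx=3, n_heads=12)
ablated = hook(None, None, attn_out)
print(ablated[0, 0, 192:256].abs().sum())           # the ablated head's slice is all zeros
```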
- Epitopological learning and Cannistraci-Hebb network shape intelligence brain-inspired theory for ultra-sparse advantage in deep learning - My gist (which may not be precise) after talking with the authors is that they have a network science method to initialize sparse connectivity structures in neural networks, which serve as good starting points for gradient descent.
- Hebbian Learning based Orthogonal Projection for Continual Learning of Spiking Neural Networks - It’s very funny how the title of this paper contains all my favorite buzzwords. They build on the classic idea from continual learning to only update weights using gradient directions that are orthogonal to the directions of previous tasks.
- Successor Heads: Recurring, Interpretable Attention Heads In The Wild - This paper examines the mechanisms responsible for incrementation tasks (e.g., today is Tuesday so tomorrow is Wednesday) in transformers.
- ⭐️ Traveling Waves Encode The Recent Past and Enhance Sequence Learning - This work treats encoded information as waves and shows how different information can be combined and decoded using these waves. It’s an interesting idea for short-term memory and I think it has analogs to recent state space models.
Day 5 - Saturday, May 11, 2024
- First Workshop on Representational Alignment (Re-Align) - This workshop was aimed at examining representational alignment, primarily between neural networks and neural data. In particular, its stated aim was defining, evaluating, and understanding the implications of representational alignment among biological & artificial systems. The organizers released a position/survey paper that you might find interesting - https://arxiv.org/abs/2310.13018. There were lots of interesting talks and posters.