Solving context length constraints by distillation
Berkeley researchers work toward giving models infinite context
If you’ve spent any time in language model circles, you’ve certainly heard of “context length” — the number of tokens the model can attend to in one go. A limited context length is the reason a model can’t properly summarize an entire book: the book is far longer than what the model can hold in context.
But it’s not like humans remember an entire book word for word when they read it. Yet, we can summarize it. Can we replicate this behavior in neural models?
Introduction and Motivation
Goal: Allow for infinitely long (or at least substantially longer) effective context by “distilling” context into the model’s weights. This frees up context space and also lets us fit more training examples in context, continually improving performance.
Method: Make the model “internalize” step-by-step reasoning by using a teacher/student setup and fine-tuning the student on the teacher’s output.
An example should clarify how this works.
The idea here is that the teacher’s output is much more likely to be correct, since it uses widely trusted context-augmentation methods like few-shot learning and chain-of-thought prompting.
When we train the student to directly predict the final answer, it no longer needs to explicitly generate the scratchpad the teacher did; the fine-tuning pushes it to do that reasoning internally. The goal is to transfer the useful context information into the model’s parameters.
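To make the training setup concrete, here is a minimal sketch of how the distillation data might be assembled. The helper names (`teacher_generate`, `extract_answer`) are my own placeholders, not the paper’s code; the point is only that the teacher sees the full context, while the student’s fine-tuning examples contain just the bare input and the teacher’s final answer.

```python
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class DistillationExample:
    student_prompt: str  # bare task input the student will see
    target: str          # teacher's final answer, used as the fine-tuning label

def build_distillation_set(
    inputs: List[str],
    teacher_context: str,                    # instructions / few-shot examples / CoT trigger
    teacher_generate: Callable[[str], str],  # hypothetical: wraps a call to the teacher LM
    extract_answer: Callable[[str], str],    # strips the scratchpad, keeps only the final answer
) -> List[DistillationExample]:
    """Run the teacher with the full context, keep only (bare input -> final answer) pairs."""
    dataset = []
    for x in inputs:
        teacher_output = teacher_generate(teacher_context + "\n" + x)
        dataset.append(
            DistillationExample(student_prompt=x, target=extract_answer(teacher_output))
        )
    return dataset

# The student is then fine-tuned with an ordinary language-modelling loss on
# student_prompt + target, with no instructions, examples, or scratchpad in its context.
```

The student never sees `teacher_context` at all — that is what is meant by the context being “distilled” into the parameters rather than carried around in the prompt.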
Evaluation
I quite like how they arranged their evaluation section — they come up with a set of hypotheses about context distillation and test each one to demonstrate its effectiveness. Some of the salient ones are as follows:
Internalizes abstract task instructions: When the teacher model was given a detailed task instruction and the student was given only the bare task input, fine-tuning with context distillation improved performance from 9% to 34.7%.
Learns from explanations: When the teacher model was given an explanation for a task (like sentiment categorization), student performance after fine-tuning correlated with how useful the explanation was to the teacher. Note that this does not claim that context distillation with explanations increased student performance outright.
Internalizes step-by-step reasoning: After distillation, the student model, even without a chain-of-thought scratchpad, improved its accuracy on direct addition from 0% to 94.7%. When the question was rephrased as indirect addition (“I have 2 sandwiches. Anton has 3. He gives me his sandwiches. How many do I have?”), the model retained its new arithmetic ability, demonstrating transfer to related tasks. A prompt-level sketch of this setup follows the list.
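To make the step-by-step case concrete, here is a sketch of what the teacher and the student might see in the addition experiment. The prompt wording and the scratchpad format are mine, not the paper’s exact prompts.

```python
# Teacher prompt: question plus a step-by-step scratchpad (illustrative wording only).
teacher_prompt = """Q: What is 364 + 258?
Let's work through this digit by digit.
4 + 8 = 12, write 2, carry 1.
6 + 5 + 1 = 12, write 2, carry 1.
3 + 2 + 1 = 6.
A: 622"""

# Student fine-tuning example: the same question, but the target is only the final answer.
# After distillation, the student is expected to produce "622" directly, with no scratchpad.
student_input = "Q: What is 364 + 258?"
student_target = "A: 622"
```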
It is worth mentioning that this technique drastically reduces inference-time tokens (by about 10x), which could be a massive cost saver for models in production.
Limitations and Future Work
One can imagine that such fine-tuning would lead to something like “overfitting”, where the model gets particularly good at a handful of tasks but worse at more general downstream tasks. The paper does not report such a degradation of capabilities, but I would have liked to see the authors explicitly test for it.
The paper ran a lot of its evaluation on InCoder, an edit-based code model comparable to Codex (the model behind GitHub Copilot). I would have liked to see this kind of testing on more widely used models, like fine-tuning GPT-3 or Flan 540B.
The paper also does not directly compare context distillation to the techniques it draws inspiration from, like chain-of-thought reasoning or few-shot learning. Given the same context budget, does distillation outperform them?
To summarize: I think context distillation is interesting work, but I suspect there are side effects to this technique that the paper does not investigate. It also does not quite live up to its promise of universally increasing context length, since the teacher/student setup comes with its own drawbacks. Nevertheless, for specializing LLMs and saving context space, distillation proves to be an effective strategy.