Innovative training of LLMs in continuous latent spaces, by Meta AI
LLMs have made significant progress in tasks based on language processing and understanding. However, their reasoning capabilities, particularly in complex scenarios, often fall short. Traditional approaches like Chain-of-Thought (CoT) reasoning, which asks the model to reason before answering by thinking step-by-step, face inherent limitations. A recent paper introduces a novel approach called Coconut (Chain of Continuous Thought), which aims to enhance LLM reasoning by shifting from “language space” to a continuous latent space.
The limitations of language-based reasoning
Language-based reasoning, primarily done by methods like CoT, involves LLMs generating intermediate steps in natural language to solve complex problems. This method aligns with how humans articulate their thought processes but does not necessarily reflect the actual cognitive processes involved in reasoning. Neuroimaging studies have shown that brain regions responsible for language are not always active during reasoning tasks.
Consequently, language-based reasoning can be inefficient, as many tokens generated are more for textual coherence than for reasoning. This inefficiency arises because LLM architectures allocate a nearly uniform computational budget for predicting each token, regardless of its importance to the reasoning process.
Introducing Coconut: A new paradigm
Coconut proposes a shift from language tokens to "continuous thoughts" in a latent space. Instead of decoding the last hidden state of an LLM into a word token, Coconut uses this state directly as the input embedding for the next reasoning step. Because the hidden state is continuous, it is not forced to collapse onto a single vocabulary item, which allows for more flexible and efficient reasoning. A continuous thought can encode multiple potential next steps simultaneously, enabling problem-solving that resembles a breadth-first search (BFS). This differs from CoT, which commits to one token at every step and can therefore lock in a single reasoning path prematurely.
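The difference between the two kinds of step can be illustrated with a toy sketch. Everything here is an assumption made for illustration: the two-token vocabulary, the 2-dimensional embeddings, and `toy_hidden_state`, which merely stands in for a transformer forward pass. It is not the paper's implementation.

```python
import math

# Hypothetical 2-dimensional token embeddings, chosen for readability.
EMBED = {
    "A": [1.0, 0.0],
    "B": [0.0, 1.0],
}

def toy_hidden_state(x):
    """Stand-in for a forward pass: a fixed linear map followed by tanh."""
    w = [[0.6, 0.5], [0.4, 0.8]]  # arbitrary illustrative weights
    return [math.tanh(sum(w[i][j] * x[j] for j in range(2))) for i in range(2)]

def decode_token(hidden):
    """Language-space decoding: keep only the best-scoring token."""
    scores = {
        tok: sum(h * e for h, e in zip(hidden, emb))
        for tok, emb in EMBED.items()
    }
    return max(scores, key=scores.get)

def language_step(x):
    """CoT-style step: hidden state -> one token -> that token's embedding."""
    return EMBED[decode_token(toy_hidden_state(x))]

def latent_step(x):
    """Coconut-style step: the hidden state itself is the next input."""
    return toy_hidden_state(x)

x0 = [0.5, 0.5]
next_language = language_step(x0)  # collapses onto a single token embedding
next_latent = latent_step(x0)      # keeps weight on both directions
```

In this toy setup the language step discards everything except the winning token, while the latent step passes the full vector forward, so information about the runner-up is still available to the next step.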
Emergent advanced reasoning patterns
One of the key advantages of Coconut is its ability to encode multiple alternative reasoning paths within the continuous latent space. This capability allows the model to perform a BFS-like search, maintaining several possible options and progressively eliminating incorrect paths. This approach is particularly effective in tasks that require substantial backtracking and planning. For instance, in logical reasoning tasks, Coconut has been shown to outperform CoT, solving problems more efficiently and with fewer tokens.
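For readers less familiar with BFS, the classical algorithm the analogy refers to can be sketched as follows. This is ordinary explicit graph search, shown only to make the "keep several options alive, drop the dead ends" idea concrete; it is not a description of what happens inside the model, and the graph and node names are made up for the example.

```python
from collections import deque

def bfs_path(graph, start, goal):
    """Return a shortest path from start to goal, or None if unreachable."""
    queue = deque([[start]])  # every entry is a partial path still in play
    visited = {start}
    while queue:
        path = queue.popleft()
        node = path[-1]
        if node == goal:
            return path
        for nxt in graph.get(node, []):
            if nxt not in visited:
                visited.add(nxt)
                queue.append(path + [nxt])
        # A path with no unvisited successors is not re-queued: it is dropped.
    return None

# Hypothetical toy graph; the node names are illustrative only.
graph = {
    "Alex": ["lempus", "sterpus"],
    "lempus": ["zumpus"],       # leads to a dead end
    "sterpus": ["brimpus"],
    "brimpus": ["rorpus"],
}
path = bfs_path(graph, "Alex", "rorpus")
```

Both branches from the start node are explored in parallel, and the branch through the dead end is abandoned once it has nowhere left to go.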
Methodology
Training procedure
Coconut's training uses a multi-stage curriculum. The model (GPT-2) is first trained on standard CoT data. In each later stage, more of the CoT reasoning steps are replaced with continuous thoughts, starting from the beginning of the chain. Special tokens, <bot> and <eot>, mark the beginning and end of the latent thought mode. Throughout training, the model optimizes the negative log-likelihood loss but masks the loss on questions and latent thoughts, so only the remaining language tokens are supervised. This encourages the LLM to learn effective representations of reasoning steps beyond natural language.
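A rough sketch of how one training example might be assembled for a given curriculum stage is shown below. Several details are assumptions made for illustration and are not taken from the text above: the `<latent>` placeholder string, the `thoughts_per_step` parameter, keeping an empty `<bot>`/`<eot>` span at stage 0, and masking the two marker tokens along with the question.

```python
BOT, EOT = "<bot>", "<eot>"
LATENT = "<latent>"  # hypothetical placeholder for one continuous thought

def build_stage_example(question, steps, answer, stage, thoughts_per_step=1):
    """Build the token sequence and loss mask for one curriculum stage.

    stage = 0 keeps every CoT step as text (with an empty latent span);
    stage = k replaces the first k steps with latent placeholders.
    """
    k = min(stage, len(steps))
    latents = [LATENT] * (k * thoughts_per_step)
    remaining = [tok for step in steps[k:] for tok in step]
    tokens = question + [BOT] + latents + [EOT] + remaining + answer
    # 0 = position excluded from the loss, 1 = position supervised.
    # Assumption: the markers are masked together with question and latents.
    n_masked = len(question) + 1 + len(latents) + 1
    mask = [0] * n_masked + [1] * (len(remaining) + len(answer))
    return tokens, mask

# Toy example with made-up token strings.
question = ["Q1", "Q2"]
steps = [["s1a", "s1b"], ["s2a"]]
answer = ["ans"]
stage1_tokens, stage1_mask = build_stage_example(question, steps, answer, stage=1)
```

At stage 1 the first reasoning step disappears from the text and a single latent position takes its place, while the second step and the answer remain supervised language tokens.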
Inference process
During inference, Coconut operates like a standard language model except between <bot> and <eot>, where it switches to latent mode and uses the last hidden state as the next input embedding. Latent reasoning can be terminated in one of two ways: the model decides autonomously when to stop, or the latent thoughts are padded to a constant length. This process allows Coconut to handle complex reasoning tasks with greater flexibility and efficiency.
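The control flow of the constant-length variant can be sketched as below. The `generate` function and its toy stand-ins (`toy_forward`, `toy_embed`, `toy_decode`, and the scalar "embeddings") are hypothetical and exist only to make the loop runnable; a real model would operate on hidden-state vectors, and the autonomous stopping variant is not covered here.

```python
def generate(question_embeds, forward, embed_token, decode,
             n_latent, max_new_tokens, eos):
    """Run n_latent latent steps, then decode tokens greedily as usual."""
    inputs = list(question_embeds) + [embed_token("<bot>")]
    # Latent mode: the last hidden state is appended as the next input.
    for _ in range(n_latent):
        inputs.append(forward(inputs))
    inputs.append(embed_token("<eot>"))
    # Language mode: ordinary autoregressive decoding.
    output = []
    for _ in range(max_new_tokens):
        token = decode(forward(inputs))
        if token == eos:
            break
        output.append(token)
        inputs.append(embed_token(token))
    return output

# Toy stand-ins (assumptions): scalar "embeddings" and a trivial "model".
TOY_EMBED = {"<bot>": 0.0, "<eot>": 10.0, "7": 11.0}

def toy_forward(inputs):
    return inputs[-1] + 1.0          # "hidden state" of the last position

def toy_embed(token):
    return TOY_EMBED[token]

def toy_decode(hidden):
    return "7" if hidden == 11.0 else "<eos>"

answer_tokens = generate([1.0, 1.0], toy_forward, toy_embed, toy_decode,
                         n_latent=3, max_new_tokens=5, eos="<eos>")
```

Passing the model pieces in as callables keeps the sketch independent of any particular framework; only the switch between the two modes is the point of the example.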
Experimental results
Coconut was tested on various datasets to evaluate its performance in different reasoning scenarios:
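- Math reasoning: On GSM8k, a dataset of diverse, open-domain grade-school math problems, Coconut improved clearly over answering without a reasoning chain while generating fewer tokens than CoT, although explicit CoT retained higher accuracy in the paper's reported results.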
- Logical reasoning: Because the ProntoQA benchmark does not require complex planning, the authors also created ProsQA, a dataset that requires the model to plan and search over a graph to find the correct reasoning chain. On both ProntoQA and ProsQA, Coconut outperformed traditional CoT and other baselines, particularly in tasks requiring extensive planning and search.
The experiments revealed that continuous thoughts enhance the model's reasoning depth and capability. By encoding a distribution of different reasoning traces, Coconut can handle more complex problem-solving scenarios effectively.
Coconut was compared against four baselines:
- CoT: the model is finetuned with reasoning chains and generates the reasoning before the answer.
- No-CoT: the model answers directly, without a reasoning chain.
- iCoT: the model is trained with reasoning chains and internalizes CoT so that it predicts the answer directly at inference.
- Pause token: special tokens are inserted between the question and the answer to give the model extra compute time.
In-depth analysis
The findings in the paper highlight several critical aspects of Coconut's performance:
- Planning and look-ahead capability: Coconut's latent reasoning mode allows the model to evaluate multiple potential steps ahead, improving decision-making in planning-intensive tasks.
- Efficiency of continuous thoughts: Continuous thoughts provide an efficient representation of reasoning steps, reducing the number of tokens generated and enhancing the model's computational efficiency.
- Comparative performance: Coconut outperforms traditional CoT and the other baselines on the logical reasoning benchmarks, most clearly on the planning-heavy ProsQA, which suggests that latent reasoning is most useful when search and planning dominate.
A notable example from the paper is a case study in which CoT hallucinated a nonexistent edge, leading to incorrect reasoning. In contrast, Coconut maintained multiple possible paths, progressively eliminating incorrect ones to arrive at the correct solution. This demonstrates the practical advantages of latent space reasoning in complex scenarios.
Conclusion and future directions
Coconut represents a significant advancement in LLM reasoning capabilities by leveraging continuous latent space. This approach not only enhances the efficiency and depth of reasoning but also opens new avenues for solving complex problems that require planning and backtracking. Future research can focus on refining the training methods, integrating more advanced curriculum strategies, and exploring pretraining LLMs with continuous thoughts. By shifting the focus from language-based reasoning to continuous latent space, Coconut offers a promising direction for enhancing the cognitive capabilities of LLMs and addressing the limitations of current approaches.