
This is a brief review of “LLM‑JEPA: Large Language Models Meet Joint Embedding Predictive Architectures”. You can see the paper at this link.

Overview

Modern language models are largely trained by predicting the next token in an input sequence. Although effective for generation, this objective does not necessarily yield the most abstract or transferable representations. The paper introduces LLM‑JEPA, which augments the usual autoregressive loss with a joint embedding predictive architecture (JEPA) loss inspired by self‑supervised learning in vision. The idea is to treat pairs of inputs that express the same underlying information (such as a natural‑language description and a regular expression) as two views of the same concept.

The JEPA component trains the model to predict the embedding of one view from the other, using cosine similarity as the metric. A special [PRED] token leverages the LLM’s own weights to implement the predictor, keeping the architecture unchanged[1]. To compute embeddings for each view without mixing information in attention, the model runs two separate forward passes and uses the hidden state of the last token as the representation[2]. The final loss combines the standard next‑token prediction and the JEPA loss, with a hyperparameter controlling their balance.

Experiments show that LLM‑JEPA improves fine‑tuning performance on tasks such as natural language to regex (NL‑RX‑SYNTH and NL‑RX‑TURK), math problem solving (GSM8K) and text‑to‑SQL generation (Spider) across model families like Llama 3, Gemma 2, OpenELM and OLMo[3]. Preliminary pre‑training results indicate benefits for downstream sentiment analysis while preserving generative capabilities[4]. The principal limitations are the need for multiple forward passes per training example, which increases compute cost, and the reliance on datasets that naturally provide multiple views[5].
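To make the training recipe concrete, here is a minimal PyTorch-style sketch of how such a combined objective could be computed. The function name, the argument names (text_ids, code_ids, pred_token_id, lam), and the exact loss form (1 − cosine similarity) are illustrative assumptions rather than the paper’s code, and padding/attention-mask handling is omitted. It assumes a Hugging Face-style causal LM interface.

```python
import torch
import torch.nn.functional as F


def llm_jepa_loss(model, text_ids, code_ids, pred_token_id, lam=1.0):
    """Sketch of a combined LLM-JEPA objective: next-token loss + embedding-prediction loss.

    Assumes a causal LM that returns .loss when given labels and
    .hidden_states when called with output_hidden_states=True.
    """
    # 1) Standard autoregressive next-token loss on the first view.
    lm_loss = model(input_ids=text_ids, labels=text_ids).loss

    # 2) Predictor pass: append the special [PRED] token to the first view so
    #    the LLM's own weights act as the predictor (no extra parameters).
    pred_col = torch.full((text_ids.size(0), 1), pred_token_id,
                          dtype=text_ids.dtype, device=text_ids.device)
    pred_ids = torch.cat([text_ids, pred_col], dim=1)
    pred_out = model(input_ids=pred_ids, output_hidden_states=True)
    predicted = pred_out.hidden_states[-1][:, -1, :]   # Pred(Enc(view 1))

    # 3) Target pass: a separate forward pass over the second view, so
    #    attention never mixes the two views; the last token's final-layer
    #    hidden state serves as the view's embedding.
    tgt_out = model(input_ids=code_ids, output_hidden_states=True)
    target = tgt_out.hidden_states[-1][:, -1, :]       # Enc(view 2)

    # 4) JEPA term: cosine-similarity loss between predicted and target
    #    embeddings, balanced against the LM loss by lam.
    jepa_loss = 1.0 - F.cosine_similarity(predicted, target, dim=-1).mean()
    return lm_loss + lam * jepa_loss
```

Note that this sketch requires three forward passes per training pair (one for the autoregressive loss, one per view), which is consistent with the increased training cost discussed under Limitations below.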

Key Ideas

  • Joint embedding prediction in NLP: Pairs of inputs representing the same concept (e.g., text and code) are used as two views. A cosine‑similarity loss is applied between the predicted and target embeddings, encouraging the model to align representations without collapsing dimensions.
  • Hybrid training objective: The overall loss is the sum of the standard autoregressive next‑token loss and the JEPA loss. A [PRED] token allows reuse of the LLM’s parameters to implement the predictor, avoiding architectural changes[1] (see the formula sketch after this list).
  • Multiple forward passes: Each view is encoded via separate forward passes through the model to avoid cross‑view interactions in self‑attention. The hidden state of the last token from the final layer serves as the embedding[2].
  • Empirical gains: Across diverse datasets and model sizes, LLM‑JEPA yields higher exact‑match accuracy in fine‑tuning experiments and shows promise in pre‑training scenarios[3].
  • Limitations: The method currently triples training cost because each pair of views requires additional forward passes[6]. It also depends on tasks that naturally provide multiple aligned views, leaving open the question of how to generalize JEPA to arbitrary language data.
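For reference, the hybrid objective described above can be summarized in one line. The notation and the exact (1 − cosine) form are shorthand for this review, not symbols taken from the paper:

```latex
\mathcal{L}(x, y) \;=\; \mathcal{L}_{\mathrm{LM}}(x)
\;+\; \lambda \left( 1 - \cos\!\big(\mathrm{Pred}(x),\ \mathrm{Enc}(y)\big) \right)
```

Here x and y are the two views of the same concept, Enc(·) is the last‑token hidden state obtained from a separate forward pass, Pred(·) is realized by the LLM itself through the appended [PRED] token, and λ balances the autoregressive and JEPA terms.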

Why it matters

LLM‑JEPA bridges representation learning techniques from vision and language. By jointly training models to predict embeddings across complementary views while retaining generative ability, it produces more robust representations and improves performance on reasoning and code‑generation tasks. This work suggests that combining generative and embedding‑based objectives could be a fruitful direction for advancing language models.
