Learn Before
  • Objective Function for Student Model Training via Knowledge Distillation

  • Sequence-Level Loss

  • Definition of Teacher's Probability Distribution (Pt) in Knowledge Distillation

Cross-Entropy Loss for Knowledge Distillation

A frequently used loss function in knowledge distillation is the sequence-level loss, which often takes the form of cross-entropy. This loss measures the dissimilarity between the teacher model's output distribution, $\text{Pr}^t(\mathbf{y}|\mathbf{c}, \mathbf{z})$, and the student model's distribution, $\text{Pr}_{\theta}^s(\mathbf{y}|\mathbf{c}', \mathbf{z})$. The total loss is the negative log probability of the student's predictions, summed over all possible output sequences $\mathbf{y}$ and weighted by the teacher's probability for each sequence:

$$\text{Loss} = -\sum_{\mathbf{y}} \text{Pr}^t(\mathbf{y}|\mathbf{c}, \mathbf{z}) \log \text{Pr}_{\theta}^s(\mathbf{y}|\mathbf{c}', \mathbf{z})$$
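To make the formula concrete, here is a minimal sketch that evaluates this cross-entropy over a small, made-up set of candidate output sequences. The probability tables and names (`teacher_probs`, `student_probs`) are illustrative assumptions rather than output of any real teacher or student model, and a real distillation setup would not enumerate sequences explicitly.

```python
import math

# Hypothetical example: both models assign probabilities to the same small
# set of candidate output sequences y for one input. In practice these
# would come from a teacher and a student language model; here they are made up.
teacher_probs = {            # Pr^t(y | c, z)
    ("the", "cat", "sat"): 0.7,
    ("a", "cat", "sat"): 0.2,
    ("the", "dog", "sat"): 0.1,
}
student_probs = {            # Pr^s_theta(y | c', z)
    ("the", "cat", "sat"): 0.5,
    ("a", "cat", "sat"): 0.3,
    ("the", "dog", "sat"): 0.2,
}

def distillation_cross_entropy(p_teacher, p_student):
    """Cross-entropy between teacher and student over candidate sequences:
    Loss = -sum_y Pr^t(y) * log Pr^s(y)."""
    return -sum(p_t * math.log(p_student[y]) for y, p_t in p_teacher.items())

loss = distillation_cross_entropy(teacher_probs, student_probs)
print(f"sequence-level distillation loss: {loss:.4f}")
```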

Tags

Ch.3 Prompting - Foundations of Large Language Models

Foundations of Large Language Models

Foundations of Large Language Models Course

Computing Sciences

Related
  • Cross-Entropy Loss for Knowledge Distillation

  • Using KL Divergence for Knowledge Distillation Loss

  • A research team is training a small, efficient 'student' model to replicate the behavior of a large, powerful 'teacher' model. The team's goal is to find the optimal parameters for the student model ($\hat{\theta}$) by minimizing a loss function over a dataset of simplified inputs ($\mathcal{D}'$), as defined by the following objective:

    $$\hat{\theta} = \arg\min_{\theta} \sum_{\mathbf{x}' \in \mathcal{D}'} \text{Loss}(\text{Pr}^t(\cdot|\cdot), \text{Pr}_{\theta}^s(\cdot|\cdot), \mathbf{x}')$$

    Where $\text{Pr}^t$ is the teacher's output probability distribution and $\text{Pr}_{\theta}^s$ is the student's. (A toy code sketch of this objective appears after this list.)

    If the team mistakenly configures the training process to use the teacher's original, complex dataset instead of the intended simplified dataset $\mathcal{D}'$, which of the following outcomes is the most direct and likely consequence for the student model?

  • Loss Function for RNN

  • Sample-wise Negative Log-Likelihood Loss for a Sub-sequence

  • Cross-Entropy Loss for Knowledge Distillation

  • A language model is being trained to generate the four-word sentence 'The quick brown fox'. The model generates one word at a time, and the error (loss) is calculated at each step:

    • Loss for 'The' = 0.1
    • Loss for 'quick' = 0.3
    • Loss for 'brown' = 0.2
    • Loss for 'fox' = 0.4

    To update the model's parameters, the training process computes a single, overall loss value for the entire sentence. Which statement best analyzes this method of calculating the overall loss?

  • Total Loss Calculation for a Token Sequence

  • Calculating Average Sequence-Level Loss

  • KL Divergence Loss for Knowledge Distillation

  • Cross-Entropy Loss for Knowledge Distillation

  • A large, complex language model is used to generate target probabilities for training a smaller, more efficient model. For the input sentence 'The cat sat on the ___', the large model could produce different probability distributions for the next word. Which of the following distributions, representing $P(\text{output} \mid \text{context})$, would provide the most informative and nuanced training signal for the smaller model?
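As referenced in the objective-function question above, the following is a minimal, hypothetical sketch of the quantity that $\arg\min_{\theta}$ ranges over: per-example distillation losses summed across a simplified dataset $\mathcal{D}'$. The dataset, the candidate outputs, and the toy student (which ignores its input) are all illustrative assumptions; a real student would condition on each $\mathbf{x}'$ and be optimized with gradient descent.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical distilled-training data: for each simplified input x' in D',
# the teacher supplies a target distribution over three candidate outputs.
D_prime = [
    {"input": "x'_1", "teacher": [0.8, 0.1, 0.1]},
    {"input": "x'_2", "teacher": [0.2, 0.7, 0.1]},
]

def objective(theta):
    """Sum over D' of the cross-entropy between teacher and student
    distributions -- the quantity minimized to obtain theta-hat."""
    total = 0.0
    for example in D_prime:
        student = softmax(theta)  # toy student: same distribution for every input
        total += -sum(p_t * math.log(p_s)
                      for p_t, p_s in zip(example["teacher"], student))
    return total

print(objective([0.0, 0.0, 0.0]))  # objective value for a uniform student
```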

Learn After
  • Computational Infeasibility of Full Output Summation in Distillation Loss

  • A student model is trained to mimic a teacher model by minimizing the following loss function, which measures the dissimilarity between their output probability distributions for a given input:

    $$\text{Loss} = -\sum_{\mathbf{y}} \text{Pr}^t(\mathbf{y}) \log \text{Pr}_{\theta}^s(\mathbf{y})$$

    In this formula, $\text{Pr}^t(\mathbf{y})$ is the teacher's probability for an output sequence $\mathbf{y}$, $\text{Pr}_{\theta}^s(\mathbf{y})$ is the student's probability, and the summation is over all possible output sequences. What is the primary function of the summation ($\sum_{\mathbf{y}}$) over the entire space of possible outputs?
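Related to the "Computational Infeasibility of Full Output Summation in Distillation Loss" item above, a quick back-of-the-envelope sketch shows why this summation cannot be carried out exactly in practice: the number of candidate sequences grows exponentially with length. The vocabulary size of 32,000 is an assumed, illustrative figure.

```python
# Illustrative only: with an assumed vocabulary of 32,000 tokens, the number of
# possible output sequences grows as |V| ** n, so exact summation over all y
# in the distillation loss is intractable for realistic sequence lengths.
vocab_size = 32_000
for length in (1, 2, 4, 8):
    print(f"length {length}: {vocab_size ** length:.3e} possible sequences")
```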