
RNN Backpropagation Formulas

Let's start with a quick recap of the basic equations of an RNN:

$$h^{(t)} = g_1(Wh^{(t-1)} \oplus Ux^{(t)} + b_h)$$

$$\hat{y}^{(t)} = g_2(Vh^{(t)} + b_y)$$

$$J^{(t)} = \mathcal{L}(\hat{y}^{(t)}, y^{(t)}) \quad \text{and} \quad J = \frac{1}{T} \sum_{t=1}^T J^{(t)}$$
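As a rough illustration of the recap equations, here is a minimal forward-pass sketch for a scalar RNN (hidden size 1). Several choices are assumptions made only for this example and do not come from the text: $g_1 = \tanh$, $g_2$ is the identity, $\mathcal{L}$ is the squared error, $\oplus$ is treated as ordinary addition, and $h^{(0)} = 0$. The function name `forward` and the sample numbers are hypothetical.

```python
import math

def forward(xs, ys, w, u, v, b_h, b_y, h0=0.0):
    """Run a scalar RNN; return hidden states, predictions, per-step losses, and J."""
    hs, y_hats, losses = [h0], [], []          # hs[0] is h^(0) (assumed 0)
    for x, y in zip(xs, ys):
        # h^(t) = g1(W h^(t-1) + U x^(t) + b_h), assuming g1 = tanh
        h = math.tanh(w * hs[-1] + u * x + b_h)
        # y_hat^(t) = g2(V h^(t) + b_y), assuming g2 = identity
        y_hat = v * h + b_y
        hs.append(h)
        y_hats.append(y_hat)
        # J^(t) = L(y_hat^(t), y^(t)), assuming squared error
        losses.append(0.5 * (y_hat - y) ** 2)
    J = sum(losses) / len(losses)              # J = (1/T) * sum_t J^(t)
    return hs, y_hats, losses, J

hs, y_hats, losses, J = forward(
    xs=[1.0, 0.5, -1.0], ys=[0.2, 0.1, -0.3],
    w=0.5, u=1.0, v=2.0, b_h=0.0, b_y=0.0,
)
```

The hidden states are stored in a list so that `hs[t]` lines up with $h^{(t)}$, which is convenient when the gradients later need $h^{(k-1)}$ for each step $k$.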

We want to calculate the gradients of the loss function with respect to the parameters $U$, $V$, and $W$. Because $J$ is the average of the per-step losses, its gradient is the average of the per-step gradients, for example $\frac{\partial J}{\partial W} = \frac{1}{T} \sum_{t=1}^T \frac{\partial J^{(t)}}{\partial W}$. Thus we only need to find $\frac{\partial J^{(t)}}{\partial W}$, $\frac{\partial J^{(t)}}{\partial V}$, and $\frac{\partial J^{(t)}}{\partial U}$.

The derivative with respect to $V$ depends only on the current step:

$$\frac{\partial J^{(t)}}{\partial V} = \frac{\partial J^{(t)}}{\partial \hat{y}^{(t)}} \frac{\partial \hat{y}^{(t)}}{\partial V}$$
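The single-step formula for $V$ can be checked numerically. The sketch below assumes a scalar RNN with $g_1 = \tanh$, identity $g_2$, squared-error loss, zero biases, and $h^{(0)} = 0$; none of these choices come from the text, and `step_loss` is a hypothetical helper. Under those assumptions $\frac{\partial J^{(t)}}{\partial \hat{y}^{(t)}} = \hat{y}^{(t)} - y^{(t)}$ and $\frac{\partial \hat{y}^{(t)}}{\partial V} = h^{(t)}$.

```python
import math

def step_loss(w, u, v, xs, ys, t):
    """Return (J^(t), h^(t), y_hat^(t)) for a scalar RNN with zero biases, h^(0) = 0."""
    h = 0.0
    for x in xs[:t]:
        h = math.tanh(w * h + u * x)       # assumed g1 = tanh
    y_hat = v * h                          # assumed g2 = identity
    return 0.5 * (y_hat - ys[t - 1]) ** 2, h, y_hat

xs, ys = [1.0, 0.5, -1.0], [0.2, 0.1, -0.3]
w, u, v, t = 0.5, 1.0, 2.0, 3

_, h_t, y_hat_t = step_loss(w, u, v, xs, ys, t)

# Analytic gradient: dJ/dy_hat * dy_hat/dV, using only quantities from step t
grad_v = (y_hat_t - ys[t - 1]) * h_t

# Central finite difference as an independent check
eps = 1e-6
numeric_v = (step_loss(w, u, v + eps, xs, ys, t)[0]
             - step_loss(w, u, v - eps, xs, ys, t)[0]) / (2 * eps)
```

No loop over earlier steps is needed for `grad_v`: once $h^{(t)}$ is known, $V$ only enters through $\hat{y}^{(t)}$.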

This is not the case for the derivatives with respect to $U$ and $W$, because $h^{(t)}$ depends on both parameters through every earlier hidden state. Take $t=3$ as an example:

$$\begin{aligned}
\frac{\partial J^{(3)}}{\partial W} &= \sum_{k=1}^3 \frac{\partial J^{(3)}}{\partial \hat{y}^{(3)}} \frac{\partial \hat{y}^{(3)}}{\partial h^{(3)}} \frac{\partial h^{(3)}}{\partial h^{(k)}} \frac{\partial h^{(k)}}{\partial W} \\
&= \sum_{k=1}^3 \frac{\partial J^{(3)}}{\partial \hat{y}^{(3)}} \frac{\partial \hat{y}^{(3)}}{\partial h^{(3)}} \left( \prod_{j=k+1}^3 \frac{\partial h^{(j)}}{\partial h^{(j-1)}} \right) \frac{\partial h^{(k)}}{\partial W}
\end{aligned}$$

and

$$\begin{aligned}
\frac{\partial J^{(3)}}{\partial U} &= \sum_{k=1}^3 \frac{\partial J^{(3)}}{\partial \hat{y}^{(3)}} \frac{\partial \hat{y}^{(3)}}{\partial h^{(3)}} \frac{\partial h^{(3)}}{\partial h^{(k)}} \frac{\partial h^{(k)}}{\partial U} \\
&= \sum_{k=1}^3 \frac{\partial J^{(3)}}{\partial \hat{y}^{(3)}} \frac{\partial \hat{y}^{(3)}}{\partial h^{(3)}} \left( \prod_{j=k+1}^3 \frac{\partial h^{(j)}}{\partial h^{(j-1)}} \right) \frac{\partial h^{(k)}}{\partial U}
\end{aligned}$$

Here $\frac{\partial h^{(k)}}{\partial W}$ and $\frac{\partial h^{(k)}}{\partial U}$ denote the immediate derivatives at step $k$, with $h^{(k-1)}$ held fixed, and the product is empty (equal to the identity) when $k=3$.
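The $t=3$ sums can be sketched in code and compared against finite differences. This is an illustration under assumptions that are not in the text: a scalar RNN with $g_1 = \tanh$, identity $g_2$, squared-error loss, zero biases, and $h^{(0)} = 0$. In that scalar case $\frac{\partial h^{(j)}}{\partial h^{(j-1)}} = (1 - (h^{(j)})^2)\,W$, and the immediate derivatives are $(1 - (h^{(k)})^2)\,h^{(k-1)}$ for $W$ and $(1 - (h^{(k)})^2)\,x^{(k)}$ for $U$. The helper names are hypothetical.

```python
import math

def hidden_states(w, u, xs):
    """Return [h^(0), h^(1), ..., h^(T)] with h^(0) = 0 and assumed g1 = tanh."""
    hs = [0.0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def loss_at(w, u, v, xs, ys, t):
    """J^(t) with assumed identity output and squared-error loss."""
    h_t = hidden_states(w, u, xs[:t])[-1]
    return 0.5 * (v * h_t - ys[t - 1]) ** 2

def bptt_grads(w, u, v, xs, ys, t):
    """Gradients of J^(t) w.r.t. W and U via the sum over k = 1..t."""
    hs = hidden_states(w, u, xs[:t])
    # dJ/dy_hat * dy_hat/dh^(t): the factor shared by every term of the sum
    upstream = (v * hs[t] - ys[t - 1]) * v
    grad_w = grad_u = 0.0
    for k in range(1, t + 1):
        chain = 1.0                         # empty product when k == t
        for j in range(k + 1, t + 1):
            chain *= (1.0 - hs[j] ** 2) * w     # dh^(j)/dh^(j-1)
        local = 1.0 - hs[k] ** 2                # tanh'(.) at step k
        grad_w += upstream * chain * local * hs[k - 1]   # immediate dh^(k)/dW
        grad_u += upstream * chain * local * xs[k - 1]   # immediate dh^(k)/dU
    return grad_w, grad_u

xs, ys = [1.0, 0.5, -1.0], [0.2, 0.1, -0.3]
w, u, v, t = 0.5, 1.0, 2.0, 3
grad_w, grad_u = bptt_grads(w, u, v, xs, ys, t)

# Central finite differences as an independent check
eps = 1e-6
num_w = (loss_at(w + eps, u, v, xs, ys, t) - loss_at(w - eps, u, v, xs, ys, t)) / (2 * eps)
num_u = (loss_at(w, u + eps, v, xs, ys, t) - loss_at(w, u - eps, v, xs, ys, t)) / (2 * eps)
```

The nested loop mirrors the formula directly and recomputes each product from scratch. A practical implementation would more likely accumulate the product while walking backward from step $t$, so that each factor is multiplied in only once.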


Updated 2020-10-17

Tags

Data Science