What are tree recursive neural networks (Tree-RNNs)?

Overview

A Recursive Neural Network (RNN, not to be confused with a recurrent neural network) is a kind of deep neural network. It is constructed by applying the same set of weights recursively (that is, by a repeating procedure or rule) over a tree-like structure. Recursive neural networks generalize recurrent neural networks from a chain-like structure to a tree-like structure. They are typically used for processing natural language sentences, which makes them a common tool in Natural Language Processing (NLP).

A simple RNN structure is illustrated in the figure below:

Recurrent neural network | tree-structured RNN

A recursive neural network uses learning algorithms for structured prediction: it predicts structured output values (e.g., a number, a letter, a word, or an object) rather than single real or discrete values, and it does so from variable-size input data. Hence, it is more powerful than the traditional feedforward neural network.

Implementation

Tree-RNNs are commonly used to build syntactic trees, which recognize the structure of sentences by identifying the meaningful phrases within them and how those phrases relate to each other. The idea is to recursively merge (combine) pairs of words/phrases (the input) until they form a complete sentence (the output). A score is calculated at each traversal of the nodes, which tells us which pair of words or phrases should be combined first to form the best syntactic tree for the given sentence. The question arises: how do we actually calculate the score and merge the phrases? The architecture of a simple Tree-RNN below illustrates how a recursive neural network works.

Architecture-Simple Tree RNN

In our illustration, x1 and x2 carry the word vectors. At the hidden layer, these words are merged into a vector representation p, which is calculated through the following recurrence relation:

p = f(W[x1; x2] + b)

Where,

  • f = Non-linear activation function (e.g., Sigmoid, Tanh).

  • W[x1;x2]: Concatenation of x1 and x2 which is then multiplied by weight W.

  • b = Bias

After the merged pair calculation, we need to find the score S at each node, which is calculated by the following relation:

S = W_score · p

We get S by multiplying the score weight vector W_score by the merged vector p. This enables us to find the best pair of nodes (words/phrases) to combine first when building up the complete sentence.
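The merge and score computations described above can be sketched in NumPy. The dimensionality and the random weights here are illustrative assumptions, not trained values:

```python
import numpy as np

def merge_and_score(x1, x2, W, b, W_score):
    # concatenate the two child vectors: [x1; x2]
    children = np.concatenate([x1, x2])
    # parent representation: p = f(W[x1; x2] + b), with f = tanh
    p = np.tanh(W @ children + b)
    # score of this merge: S = W_score . p
    s = float(W_score @ p)
    return p, s

# hypothetical 4-dimensional word vectors and random (untrained) weights
d = 4
rng = np.random.default_rng(0)
x1, x2 = rng.normal(size=d), rng.normal(size=d)
W = rng.normal(size=(d, 2 * d))   # maps the concatenated children back to d dims
b = np.zeros(d)
W_score = rng.normal(size=d)

p, s = merge_and_score(x1, x2, W, b, W_score)
print(p.shape)  # (4,) -- the parent has the same dimensionality as each child
```

Note that because p has the same dimensionality as its children, the same W can be applied again when p is merged with another node higher up the tree.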

Example

Example of Tree-RNN [Source: Socher, Lin, Ng, and Manning, 2011]

The above illustration shows that there are different words at each input, and they are combined together at the hidden layer. This is done by calculating p and S at each node; the traversal and scoring are applied recursively. In the end, all the words are merged at the top to form a sentence, which in our example is "A small crowd quietly enters the historic church."
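The greedy merge-by-score procedure from the example can be sketched as follows. Random vectors and weights stand in for trained parameters here; an actual Tree-RNN would choose merges using learned weights:

```python
import numpy as np

def greedy_parse(vectors, W, b, W_score):
    """Repeatedly merge the adjacent pair with the highest score
    until a single vector (the sentence representation) remains."""
    nodes = list(vectors)
    while len(nodes) > 1:
        # score every adjacent pair of nodes
        scores = []
        for i in range(len(nodes) - 1):
            p = np.tanh(W @ np.concatenate([nodes[i], nodes[i + 1]]) + b)
            scores.append(float(W_score @ p))
        # merge the best-scoring pair into a single parent node
        best = int(np.argmax(scores))
        merged = np.tanh(W @ np.concatenate([nodes[best], nodes[best + 1]]) + b)
        nodes[best:best + 2] = [merged]
    return nodes[0]

d = 4
rng = np.random.default_rng(1)
words = ["A", "small", "crowd", "quietly", "enters", "the", "historic", "church"]
vectors = [rng.normal(size=d) for _ in words]  # stand-ins for word embeddings
W, b, W_score = rng.normal(size=(d, 2 * d)), np.zeros(d), rng.normal(size=d)

root = greedy_parse(vectors, W, b, W_score)
print(root.shape)  # (4,) -- the whole sentence reduced to one vector
```

Each iteration removes one node, so eight words take exactly seven merges to reach the root, mirroring the tree in the figure.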

Code

# initialize the recursive neural network and the hidden layers
my_rnn = RNN()
hidden_layers = [0, 0, 0, 0, 0, 0, 0, 0]

sequence = ["A", "small", "crowd", "quietly", "enters", "the", "historic", "church"]

# feed the sequence as input to the network
for word in sequence:
    prediction, hidden_layers = my_rnn(word, hidden_layers)

# prediction for the next word
next_word = prediction

Explanation

  • Lines 2 and 3: We initialize the recursive neural network and hidden layers.

  • Line 5: We enter sample input data in the form of verbs/nouns/phrases.

  • Lines 8 and 9: We loop through the given input data sequence which feeds both the current word and the previous internal state into the RNN. A structured prediction is then made and the internal state is updated.

  • Line 12: We store the final prediction, which is the network's output for the variable-size input data.
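The snippet above assumes an RNN class with the interface my_rnn(word, hidden_state) -> (prediction, new_hidden_state). A minimal toy stand-in consistent with that interface might look like this; the embedding scheme and weights below are illustrative assumptions, not part of the original source:

```python
import numpy as np

class RNN:
    """Toy stand-in: embeds each word, mixes it into the hidden state,
    and returns the new state as its 'prediction'."""

    def __init__(self, dim=8, seed=0):
        rng = np.random.default_rng(seed)
        self.dim = dim
        self.W_h = rng.normal(scale=0.1, size=(dim, dim))  # recurrent weights
        self.W_x = rng.normal(scale=0.1, size=(dim, dim))  # input weights

    def embed(self, word):
        # deterministic pseudo-embedding seeded by the word's bytes
        local = np.random.default_rng(sum(word.encode()))
        return local.normal(size=self.dim)

    def __call__(self, word, hidden):
        h = np.asarray(hidden, dtype=float)
        # h_t = tanh(W_h h_{t-1} + W_x x_t)
        h_new = np.tanh(self.W_h @ h + self.W_x @ self.embed(word))
        # the 'prediction' here is just the state vector, not a decoded word
        return h_new, h_new.tolist()

my_rnn = RNN()
hidden_layers = [0] * 8
for word in ["A", "small", "crowd"]:
    prediction, hidden_layers = my_rnn(word, hidden_layers)
print(len(hidden_layers))  # 8
```

A real implementation would add a trained output layer that decodes the hidden state into a probability distribution over the vocabulary; this sketch only shows how the state threads through the loop.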

Conclusion

Therefore, a recursive neural network is a tree-based structure that uses supervised learning algorithms for structured prediction. It also plays an important role in Natural Language Processing (NLP), where Tree-RNNs are used for operations such as sentiment analysis and semantic parsing.
