Clean TreeLSTMs implementation in PyTorch using NLTK treepositions and Easy-First Parsing
Overview
Tree LSTMs generalise LSTMs to tree-structured network topologies. Compared to sequence models, which consider words in temporal order (or in reverse), tree-structured models compose each phrase from its subphrases according to the given syntactic structure.
Tree LSTMs are a conceptually straightforward extension of RNN-LSTMs but need a fair bit of thought to implement. Here’s what we need:
- A parser (takes in a sentence and outputs the parse tree structure).
  - Stanford CoreNLP or your own parser
- Given the parse tree structure, which implicitly contains how the word units should be progressively combined, convert it into a series of instructions that explicitly describes how the words should be combined.
- Write an RNN that takes in a series of instructions on how to combine a list of inputs. This strategy is inspired by stack LSTMs (Dyer et al., 2015) and easy-first parsing (Kiperwasser and Goldberg, 2016). The key advantage of the latter approach is that we can access any element of the ‘stack’ (so it is no longer strictly a stack).
1. Parser
- Download Stanford CoreNLP (2018 onwards) into your working directory.
- Write a script called ParserDemo.java, which reads in our text file and outputs the parses.
- A bash script to compile and run our parser script.
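If you would rather stay in Python, nltk also ships a client for a running CoreNLP server; the sketch below is an optional alternative to the Java route, assuming you have started the server from the CoreNLP directory on port 9000 (the class and sentence here are illustrative, not part of the original setup):

```python
from nltk.parse.corenlp import CoreNLPParser

# Assumes a CoreNLP server is already running, started from the CoreNLP directory with e.g.
#   java -mx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer -port 9000
parser = CoreNLPParser(url="http://localhost:9000")

# raw_parse returns an iterator over nltk Trees for the sentence
tree = next(parser.raw_parse("Papa ate the caviar with a spoon."))
print(tree)
```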
Now, the sentence “Papa ate the caviar with a spoon” becomes
(ROOT (S (NP (NN Papa)) (VP (VBD ate) (NP (DT the) (NN caviar)) (PP (IN with) (NP (DT a) (NN spoon)))) (. .)))
We can visualise this with the nltk package
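For example, a minimal sketch with nltk (pretty_print draws the tree as ASCII art; chomsky_normal_form binarises the tree, which the instruction extraction in the next section assumes):

```python
from nltk import Tree

parse = ("(ROOT (S (NP (NN Papa)) (VP (VBD ate) (NP (DT the) (NN caviar)) "
         "(PP (IN with) (NP (DT a) (NN spoon)))) (. .)))")
tree = Tree.fromstring(parse)
tree.pretty_print()            # draw the original parse as ASCII art

tree.chomsky_normal_form()     # binarise: every node now has at most two children
tree.pretty_print()            # draw the binary-branching version
```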
2. Converting the parse tree into a series of instructions
We build a Tree-LSTM from our understanding of how a standard RNN works. In contrast to the standard RNN, which takes in the input from the previous time step, the tree-LSTM takes inputs from the hidden states of its child cells, as described by the syntactic parse.
This implementation is heavily influenced by shift-reduce parsing. However, unlike shift-reduce, which processes elements sequentially from the top of the stack, our instructions allow us to combine the current representations at any position of the stack.
We need to
- Maintain a stack of word and phrase representations to process
- Extract instructions on what to combine
- Update the stack accordingly
To extract the instructions, we rely on the nltk tree package. The CNF (binarised) tree is encoded as a series of binary branching instructions: 0 indicates branch left and 1 indicates branch right. Based on the binary branching tree above, we can work out what each position refers to.
Taking the leaves as an example:
- ‘Papa’: Left(0)-Left(0)-Left(0).
- ‘Caviar’: Left(0)-Right(1)-Left(0)-Right(1)-Right(1)-Left(0).
We can encode every tree position in this fashion.
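nltk exposes these position tuples directly; a quick check (the exact tuples depend on how the tree was binarised):

```python
from nltk import Tree

parse = ("(ROOT (S (NP (NN Papa)) (VP (VBD ate) (NP (DT the) (NN caviar)) "
         "(PP (IN with) (NP (DT a) (NN spoon)))) (. .)))")
tree = Tree.fromstring(parse)
tree.chomsky_normal_form()

print(tree.leaf_treeposition(0))              # position of 'Papa', a tuple of 0/1 branch choices
print(tree.leaf_treeposition(3))              # position of 'caviar'
print(tree.treepositions(order="postorder"))  # every position, children before parents
```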
treepositions takes an order argument (e.g. preorder or postorder), corresponding to different types of depth-first search. We use the post-order traversal for this implementation, so that children appear in the buffer before their parents.
The buffer is given by tree.treepositions('postorder'). For each element in the buffer, if it is a leaf, add it to the stack to process. If not, get the stack positions of all its children and replace the children on the stack with the parent element. After processing each non-leaf buffer element, append an action to the action sequence. Each action is either unary (a single element) or binary (a tuple), and tells us which stack positions should be progressively fed as input to the neural network, starting from the leaves (words).
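A minimal sketch of this conversion, assuming a binarised nltk Tree as input (tree_to_actions and its exact bookkeeping are illustrative choices, not the only way to do this):

```python
from nltk import Tree

def tree_to_actions(tree):
    """Turn a binarised Tree into a word list and a sequence of actions.

    Each action is a tuple of stack indices to combine: unary actions have
    one index, binary actions have two."""
    buffer = tree.treepositions(order="postorder")
    stack = []      # tree positions currently on the stack
    words = []      # leaf tokens, in the order they are pushed onto the stack
    actions = []    # one action per non-leaf buffer element

    for pos in buffer:
        node = tree[pos]
        if isinstance(node, str):
            # leaf: push its position and remember the word
            stack.append(pos)
            words.append(node)
        else:
            # non-leaf: find the stack positions of this node's children
            child_idx = [i for i, p in enumerate(stack)
                         if len(p) == len(pos) + 1 and p[:len(pos)] == pos]
            actions.append(tuple(child_idx))
            # replace the children on the stack with the parent element
            stack[child_idx[0]] = pos
            for i in reversed(child_idx[1:]):
                del stack[i]
    return words, actions
```

Running this on the binarised tree above yields the words in order plus a mix of unary and binary actions over stack positions.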
3. Composing RNNs
Unlike standard RNN modules, which take in a whole sequence, we have to progressively combine inputs as specified by the action sequence, so all our RNN units are maintained at the cell level. The architecture consists of
- two identical LSTM cells for combining the child nodes of the tree
- a final output layer
- initialisation layers for the hidden and cell states of the LSTM
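A hedged sketch of that architecture in PyTorch (class and attribute names are assumptions, not the original code):

```python
import torch
import torch.nn as nn

class TreeLSTM(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # two identical LSTM cells, one per child of a binary node
        self.left_cell = nn.LSTMCell(input_size, hidden_size)
        self.right_cell = nn.LSTMCell(input_size, hidden_size)
        # final layer applied to the root representation
        self.out = nn.Linear(hidden_size, output_size)
        # initialisation layers for the hidden and cell states of the leaves
        self.init_h = nn.Linear(input_size, hidden_size)
        self.init_c = nn.Linear(input_size, hidden_size)
```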
While looping over the action sequence, we update our stack after each action with the new input, hidden state and cell state: (x, (hx, cx)). Note that beyond the leaf nodes, x is simply a zero tensor, because we only need to consider hx as the input to the next cell.
We also need to maintain the stack with update and delete actions the same way we did previously when getting instructions from the parse tree.
In this example we combine child cells by simply summing them together, but there are of course more sophisticated ways of combining cells by doing clever things with the gating mechanism of the LSTM (Tai et al., 2015).
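Continuing the TreeLSTM sketch from the previous section, a possible forward pass over the action sequence (sum-combination of the children’s states, as described above; the details are illustrative):

```python
    # method of the TreeLSTM class sketched above
    def forward(self, embeddings, actions):
        """embeddings: list of (1, input_size) word tensors in leaf order;
        actions: the unary/binary stack-index tuples from tree_to_actions."""
        # stack entries are (x, (hx, cx))
        stack = [(x, (torch.tanh(self.init_h(x)), torch.tanh(self.init_c(x))))
                 for x in embeddings]
        zero_x = torch.zeros_like(embeddings[0])

        for action in actions:
            if len(action) == 1:                       # unary node: run a single cell
                i = action[0]
                x, state = stack[i]
                state = self.left_cell(x, state)
                stack[i] = (zero_x, state)             # beyond the leaves, x is a zero tensor
            else:                                      # binary node: one cell per child, then sum
                i, j = action
                xl, state_l = stack[i]
                xr, state_r = stack[j]
                hl, cl = self.left_cell(xl, state_l)
                hr, cr = self.right_cell(xr, state_r)
                stack[i] = (zero_x, (hl + hr, cl + cr))
                del stack[j]                           # delete the consumed child entry

        _, (h_root, _) = stack[0]                      # the surviving entry is the root
        return self.out(h_root)
```

A unary action updates its single entry in place, while a binary action keeps the leftmost slot and deletes the other, mirroring the stack bookkeeping used when extracting the instructions.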
References
Tai, K. S., Socher, R., & Manning, C. D. (2015). Improved semantic representations from tree-structured long short-term memory networks. arXiv preprint arXiv:1503.00075.
Dyer, C., Ballesteros, M., Ling, W., Matthews, A., & Smith, N. A. (2015). Transition-based dependency parsing with stack long short-term memory. arXiv preprint arXiv:1505.08075.
Kiperwasser, E., & Goldberg, Y. (2016). Easy-first dependency parsing with hierarchical tree LSTMs. Transactions of the Association for Computational Linguistics, 4, 445-461.