Class: LSTM_NETWORK
- Inherits:
-
Object
- Object
- LSTM_NETWORK
- Defined in:
- lib/NETWORK.rb
Overview
██╗░░░░░░██████╗████████╗███╗░░░███╗ ███╗░░██╗███████╗████████╗░██╗░░░░░░░██╗░█████╗░██████╗░██╗░░██╗
██║░░░░░██╔════╝╚══██╔══╝████╗░████║ ████╗░██║██╔════╝╚══██╔══╝░██║░░██╗░░██║██╔══██╗██╔══██╗██║░██╔╝
██║░░░░░╚█████╗░░░░██║░░░██╔████╔██║ ██╔██╗██║█████╗░░░░░██║░░░░╚██╗████╗██╔╝██║░░██║██████╔╝█████═╝░
██║░░░░░░╚═══██╗░░░██║░░░██║╚██╔╝██║ ██║╚████║██╔══╝░░░░░██║░░░░░████╔═████║░██║░░██║██╔══██╗██╔═██╗░
███████╗██████╔╝░░░██║░░░██║░╚═╝░██║ ██║░╚███║███████╗░░░██║░░░░░╚██╔╝░╚██╔╝░╚█████╔╝██║░░██║██║░╚██╗
╚══════╝╚═════╝░░░░╚═╝░░░╚═╝░░░░░╚═╝ ╚═╝░░╚══╝╚══════╝░░░╚═╝░░░░░░╚═╝░░░╚═╝░░░╚════╝░╚═╝░░╚═╝╚═╝░░╚═╝

This class handles the single-layer LSTM network. The network comprises an array of LSTM cell objects, an input matrix, and a target matrix. Network-level forward propagation and backward propagation are implemented here; this is separate from the cell-level forward and backward propagation implemented in the LSTM_CELL class.
Instance Method Summary
- #applyWeightChange ⇒ Object
- #backwardPropagate ⇒ Object
- #forwardPropagate(initialH = DFloat.zeros(1, @sz), initialC = DFloat.zeros(1, @sz)) ⇒ Object
- #getInput(mode = "word_mode") ⇒ Object
- #getLSTMNodes ⇒ Object
- #getOutput(mode = "word_mode") ⇒ Object
- #getTarget(mode = "word_mode") ⇒ Object
- #init(nodes, x_dim, alpha, terminal_output = nil) ⇒ Object
- #setDictionary(dictionary) ⇒ Object
- #setInput(input, mode = nil) ⇒ Object
- #setTarget(target, mode = nil) ⇒ Object
Instance Method Details
#applyWeightChange ⇒ Object
158 159 160 161 162 |
# File 'lib/NETWORK.rb', line 158
# Calls applyWeightChange on every LSTM cell held in row 0 of @lstm_nodes,
# visiting columns in order 0 .. column_count - 1.
#
# Returns the Range of column indices that was walked.
def applyWeightChange()
  columns = 0...@lstm_nodes.column_count()
  columns.each { |col| @lstm_nodes[0, col].applyWeightChange() }
end
#backwardPropagate ⇒ Object
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
# File 'lib/NETWORK.rb', line 140
# Backpropagation through time across the chain of LSTM cells.
#
# forwardPropagate feeds the hidden/cell state of cell i-1 into cell i, so
# cell 0 is the first time step and the last column is the final one.
# Gradients therefore flow from the last cell back towards cell 0. Each cell
# receives its own output error plus the bottom deltas (getBottomDeltaHt /
# getBottomDeltaCt) of the cell that follows it in time. The final cell has no
# successor and starts from a zero cell-state delta.
#
# NOTE(review): the output error is 2 * (Yt - Ht), i.e. the negative of the
# gradient of (Ht - Yt)^2; presumably LSTM_CELL#applyWeightChange adds the
# accumulated change rather than subtracting it - confirm.
#
# Returns nil. Does nothing when the network has no cells.
def backwardPropagate()
  last = @lstm_nodes.column_count() - 1
  return nil if last < 0

  # Final time step: only its own output error, zero incoming cell-state delta.
  cell = @lstm_nodes[0, last]
  delta_h = 2 * (cell.getYt() - cell.getHt())
  delta_c = DFloat.zeros(1, @sz)
  cell.backwardPropagation(delta_h, delta_c)

  # Earlier time steps, walking backwards; each one adds the deltas that its
  # successor propagated back through h_prev / c_prev.
  (last - 1).downto(0) do |i|
    cell = @lstm_nodes[0, i]
    successor = @lstm_nodes[0, i + 1]
    delta_h = 2 * (cell.getYt() - cell.getHt()) + successor.getBottomDeltaHt()
    delta_c = successor.getBottomDeltaCt()
    cell.backwardPropagation(delta_h, delta_c)
  end
  nil
end
#forwardPropagate(initialH = DFloat.zeros(1, @sz), initialC = DFloat.zeros(1, @sz)) ⇒ Object
108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 |
# File 'lib/NETWORK.rb', line 108
# Runs the forward pass over every cell in row 0 of @lstm_nodes, in column
# order.
#
# Cell 0 is seeded with initialH / initialC. Each later cell receives the
# hidden state (getHt) and cell state (getCt) of the cell before it. Row
# `step` of @Input and @Target is wrapped in an outer array (one-row array)
# and handed to cell `step` via setXt / setYt. Each cell's hidden state is
# written into row `step` of a freshly zeroed @Output (@Nodes x @sz).
#
# initialH - starting hidden state for cell 0 (default: zeros(1, @sz)).
# initialC - starting cell state for cell 0 (default: zeros(1, @sz)).
#            If either one is nil, both are replaced with zeros(1, @sz).
#
# Returns nil; read the result with getOutput. An empty network only resets
# @Output.
def forwardPropagate(initialH=DFloat.zeros(1, @sz), initialC=DFloat.zeros(1, @sz))
  @Output = DFloat.zeros(@Nodes, @sz)

  if initialH.nil? || initialC.nil?
    h_prev = DFloat.zeros(1, @sz)
    c_prev = DFloat.zeros(1, @sz)
  else
    h_prev = initialH
    c_prev = initialC
  end

  (0...@lstm_nodes.column_count()).each do |step|
    cell = @lstm_nodes[0, step]
    cell.setHprev(h_prev)
    cell.setCprev(c_prev)
    cell.setXt(DFloat[@Input[step, true].to_a])
    cell.setYt(DFloat[@Target[step, true].to_a])
    cell.forwardPropagation()
    # This cell's states become the next cell's previous states.
    h_prev = cell.getHt()
    c_prev = cell.getCt()
    @Output[step, true] = h_prev
  end
  nil
end
#getInput(mode = "word_mode") ⇒ Object
90 91 92 93 94 95 96 97 98 |
# File 'lib/NETWORK.rb', line 90
# Returns @Input in the representation selected by +mode+:
#   "word_mode"  (default) - decoded with @dict.decodeArray
#   "char_mode"            - converted with @encoder.nArrayToMatrix, then
#                            decoded with @encoder.hotDecodeSentance
#   "array_mode"           - the raw @Input value
# Any other mode returns nil.
def getInput(mode="word_mode")
  case mode
  when "word_mode"
    @dict.decodeArray(@Input)
  when "char_mode"
    matrix = @encoder.nArrayToMatrix(@Input)
    @encoder.hotDecodeSentance(matrix)
  when "array_mode"
    @Input
  end
end
#getLSTMNodes ⇒ Object
53 54 55 |
# File 'lib/NETWORK.rb', line 53
# Returns the Matrix of LSTM_CELL objects stored in @lstm_nodes.
def getLSTMNodes
  @lstm_nodes
end
#getOutput(mode = "word_mode") ⇒ Object
99 100 101 102 103 104 105 106 107 |
# File 'lib/NETWORK.rb', line 99
# Returns @Output (set by forwardPropagate) in the representation selected by
# +mode+:
#   "word_mode"  (default) - decoded with @dict.decodeArrayByMaximum
#   "char_mode"            - converted with @encoder.nArrayToMatrix, then
#                            decoded with @encoder.hotDecodeSentance
#   "array_mode"           - the raw @Output value
# Any other mode returns nil.
def getOutput(mode="word_mode")
  case mode
  when "word_mode"
    @dict.decodeArrayByMaximum(@Output)
  when "char_mode"
    matrix = @encoder.nArrayToMatrix(@Output)
    @encoder.hotDecodeSentance(matrix)
  when "array_mode"
    @Output
  end
end
#getTarget(mode = "word_mode") ⇒ Object
81 82 83 84 85 86 87 88 89 |
# File 'lib/NETWORK.rb', line 81
# Returns @Target in the representation selected by +mode+:
#   "word_mode"  (default) - decoded with @dict.decodeArray
#   "char_mode"            - converted with @encoder.nArrayToMatrix, then
#                            decoded with @encoder.hotDecodeSentance
#   "array_mode"           - the raw @Target value
# Any other mode returns nil.
def getTarget(mode="word_mode")
  case mode
  when "word_mode"
    @dict.decodeArray(@Target)
  when "char_mode"
    matrix = @encoder.nArrayToMatrix(@Target)
    @encoder.hotDecodeSentance(matrix)
  when "array_mode"
    @Target
  end
end
#init(nodes, x_dim, alpha, terminal_output = nil) ⇒ Object
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
# File 'lib/NETWORK.rb', line 36
# Sets up a fresh network: a dictionary, an initialised encoder, a single row
# of +nodes+ LSTM_CELL objects, and zeroed input/target arrays.
#
# nodes           - number of LSTM_CELL objects (columns of @lstm_nodes).
# x_dim           - stored as @sz; passed to every cell and used as the
#                   column count of @Input / @Target.
# alpha           - stored in the class variable @@Alpha and passed to every
#                   cell's init (presumably the learning rate - confirm in
#                   LSTM_CELL).
# terminal_output - forwarded unchanged to every cell's init.
#
# Returns the new @Target.
def init(nodes, x_dim, alpha, terminal_output=nil)
  @sz = x_dim
  # NOTE: @@Alpha is a class variable, so it is shared by every LSTM_NETWORK.
  @@Alpha = alpha
  @Nodes = nodes

  @dict = DICTIONARY.new
  @encoder = ENCODER.new
  @encoder.init()

  # One row of cells; Matrix.build visits columns 0 .. @Nodes - 1 in order,
  # creating and initialising each cell in place.
  @lstm_nodes = Matrix.build(1, @Nodes) do
    cell = LSTM_CELL.new
    cell.init(@@Alpha, @sz, terminal_output)
    cell
  end

  # Zero placeholders (one row per cell) until setInput / setTarget are called.
  @Input = DFloat.zeros(nodes, @sz)
  @Target = DFloat.zeros(nodes, @sz)
end
#setDictionary(dictionary) ⇒ Object
56 57 58 |
# File 'lib/NETWORK.rb', line 56
# Replaces the dictionary stored in @dict (used by the "word_mode" getters and
# by setInput / setTarget for Array values).
#
# Returns the dictionary that was stored.
def setDictionary(dictionary)
  @dict = dictionary
  @dict
end
#setInput(input, mode = nil) ⇒ Object
70 71 72 73 74 75 76 77 78 79 80 |
# File 'lib/NETWORK.rb', line 70
# Stores the network input in @Input.
#
# input - one of:
#         * any value, when mode == "encoded" (stored as-is);
#         * exactly a String (instance_of?), one-hot encoded with
#           @encoder.hotEncodeSentance and converted with
#           @encoder.matrixToNArray;
#         * exactly an Array (instance_of?), encoded with @dict.encodeArray.
# mode  - pass "encoded" to skip encoding; other values are ignored.
#
# Returns the new value of @Input.
# Raises RuntimeError when mode is not "encoded" and input is neither a String
# nor an Array.
def setInput(input, mode=nil)
  if mode == "encoded"
    @Input = input
  elsif input.instance_of? String
    @Input = @encoder.matrixToNArray(@encoder.hotEncodeSentance(input))
  elsif input.instance_of? Array
    @Input = @dict.encodeArray(input)
  else
    # Previously said "Target input", a copy-paste slip from setTarget.
    raise "Input is not of type Array or String"
  end
end
#setTarget(target, mode = nil) ⇒ Object
59 60 61 62 63 64 65 66 67 68 69 |
# File 'lib/NETWORK.rb', line 59
# Stores the training target in @Target.
#
# target - one of:
#          * any value, when mode == "encoded" (stored as-is);
#          * exactly a String (instance_of?), one-hot encoded with
#            @encoder.hotEncodeSentance and converted with
#            @encoder.matrixToNArray;
#          * exactly an Array (instance_of?), encoded with @dict.encodeArray.
# mode   - pass "encoded" to skip encoding; other values are ignored.
#
# Returns the new value of @Target.
# Raises RuntimeError when mode is not "encoded" and target is neither a
# String nor an Array.
def setTarget(target, mode=nil)
  return @Target = target if mode == "encoded"

  if target.instance_of?(String)
    one_hot = @encoder.hotEncodeSentance(target)
    return @Target = @encoder.matrixToNArray(one_hot)
  end

  return @Target = @dict.encodeArray(target) if target.instance_of?(Array)

  raise "Target input is not of type Array or String"
end