DMelt:AI/4 Recurrent NN and LSTM

Recurrent Neural Networks (RNN) and LSTM

A recurrent neural network (RNN) is an artificial neural network in which connections between nodes form a directed graph along a sequence, which allows it to exhibit dynamic temporal behavior for a time sequence. See the Recurrent neural network article. Long short-term memory (LSTM) units are a building block for the layers of a recurrent neural network (RNN).
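
To make the recurrence concrete, here is a minimal pure-Python sketch of a single vanilla RNN step, h_t = tanh(Wx x_t + Wh h_(t-1) + b). The function name, weights, and toy sizes are illustrative only and are not part of the recunn library used below:

from math import tanh
from random import uniform

def rnn_step(x, h_prev, Wx, Wh, b):
    # h_t = tanh(Wx*x_t + Wh*h_{t-1} + b), computed with plain lists
    h = []
    for i in range(len(b)):
        s = b[i]
        for j in range(len(x)):
            s += Wx[i][j]*x[j]
        for j in range(len(h_prev)):
            s += Wh[i][j]*h_prev[j]
        h.append(tanh(s))
    return h

# toy example: 3-dimensional input, 2-dimensional hidden state
Wx = [[uniform(-0.1, 0.1) for j in range(3)] for i in range(2)]
Wh = [[uniform(-0.1, 0.1) for j in range(2)] for i in range(2)]
b  = [0.0, 0.0]
h  = [0.0, 0.0]
for x in [[1,0,0],[0,1,0],[0,0,1]]:   # a short input sequence
    h = rnn_step(x, h, Wx, Wh, b)     # the hidden state carries context forward
print h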

Recurrent Neural Networks and Long Short-Term Memory networks (LSTM) are included via the recunn Java library, distributed under the MIT license. Recurrent networks, and their popular variant LSTM, are artificial neural networks designed to recognize patterns in sequences of data, such as text, handwriting, spoken words, or numerical time series.
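
To show what an LSTM unit adds over a plain recurrent node, here is a minimal scalar sketch of one LSTM step. The gate equations are the standard ones; the weight names and toy values are illustrative only (recunn implements the vectorized version internally):

from math import tanh, exp

def sigmoid(z):
    return 1.0/(1.0 + exp(-z))

def lstm_step(x, h_prev, c_prev, W):
    i = sigmoid(W['ix']*x + W['ih']*h_prev + W['ib'])   # input gate
    f = sigmoid(W['fx']*x + W['fh']*h_prev + W['fb'])   # forget gate
    o = sigmoid(W['ox']*x + W['oh']*h_prev + W['ob'])   # output gate
    g = tanh(W['gx']*x + W['gh']*h_prev + W['gb'])      # candidate value
    c = f*c_prev + i*g   # cell state: keep part of the old memory, write new
    h = o*tanh(c)        # hidden state passed to the next time step
    return h, c

# toy usage with all weights set to 0.1
W = dict((k, 0.1) for k in ['ix','ih','ib','fx','fh','fb','ox','oh','ob','gx','gh','gb'])
h, c = 0.0, 0.0
for x in [1.0, 0.0, 1.0]:
    h, c = lstm_step(x, h, c, W)
print h, c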

Here is a demo that reads sentences from Paul Graham's essays, encoding his knowledge into the weights of a recurrent network. The long-term goal of the project is to generate startup wisdom at will. Feel free to train on whatever data you wish and to experiment with the parameters. If you want more impressive models, increase the sizes of the hidden layers and, perhaps slightly, the size of the letter vectors; training will then take longer.

# see http://cs.stanford.edu/people/karpathy/recurrentjs/
# This demo uses the recunn library to train a deep Recurrent Neural Network (RNN)
# with Long Short-Term Memory (LSTM) units on sentences from Paul Graham's essays,
# generating text character by character.


from java.util import Random

from recunn.datasets import TextGeneration
from recunn.datastructs import DataSet
from recunn.model import Model
from recunn.trainer import Trainer
from recunn.util import NeuralNetworkHelper
from jhplot import *


textSource = "PaulGraham"
http="http://datamelt.org/examples/data/"
print Web.get(http+"/text/"+textSource+".txt")
data = TextGeneration(textSource+".txt")
savePath = textSource+".ser"

initFromSaved = True # set this to false to start with a fresh model
overwriteSaved = True
TextGeneration.reportSequenceLength = 100
TextGeneration.singleWordAutocorrect = False # set this to true to constrain generated sentences to contain only words observed in the training data.

bottleneckSize = 10      # one-hot input is squeezed through this low-dimensional layer
hiddenDimension = 200    # size of each hidden LSTM layer
hiddenLayers = 1         # number of stacked LSTM layers
learningRate = 0.001
initParamsStdDev = 0.08  # standard deviation for random weight initialization

rng = Random()
# build an LSTM whose one-hot input is first projected into a small bottleneck layer
lstm = NeuralNetworkHelper.makeLstmWithInputBottleneck(
       data.inputDimension, bottleneckSize,
       hiddenDimension, hiddenLayers,
       data.outputDimension, data.getModelOutputUnitToUse(),
       initParamsStdDev, rng)

reportEveryNthEpoch = 10
trainingEpochs = 1000
# train the network; progress is reported periodically and the model is saved to savePath
Trainer.train(trainingEpochs, learningRate, lstm, data, reportEveryNthEpoch, initFromSaved, overwriteSaved, savePath, rng)

The above example is based on the JavaScript library http://cs.stanford.edu/people/karpathy/recurrentjs/.
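
Since savePath ends in .ser, the trained network is presumably stored with standard Java serialization. Under that assumption, here is a minimal sketch for loading the saved model back in a later Jython session; the file name mirrors the script above:

from java.io import FileInputStream, ObjectInputStream

# Assumes the trainer saved the model via standard Java serialization,
# as the ".ser" extension of savePath suggests.
fin = FileInputStream("PaulGraham.ser")
oin = ObjectInputStream(fin)
lstm = oin.readObject()   # the previously trained network object
oin.close()
print lstm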