# Using an LSTM in Python to predict text, one character at a time

# Description: train a character-level LSTM in Python that predicts the next character of a text.

# Import packages
import numpy
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras.callbacks import ModelCheckpoint
from keras.utils import np_utils

# Load the text and convert it to lowercase
# Read the whole corpus and lowercase it so the model sees a smaller vocabulary.
# The context manager guarantees the file handle is closed, and the explicit
# encoding makes decoding independent of the platform's default locale.
filename = "wonderland.txt"
with open(filename, encoding="utf-8") as text_file:
    raw_text = text_file.read()
raw_text = raw_text.lower()

# Create a mapping from each unique character to an integer index
# Sorted vocabulary of every distinct character in the corpus, plus a lookup
# from each character to its position in that sorted list.
chars = sorted(set(raw_text))
char_to_int = {char: index for index, char in enumerate(chars)}

# Summarize the loaded data
# Corpus size (characters) and vocabulary size (distinct characters).
# Sample values recorded in the original notes: 147674 characters, 47 vocab.
n_vocab = len(chars)
n_chars = len(raw_text)
for label, count in (("Total Characters: ", n_chars), ("Total Vocab: ", n_vocab)):
    print(label, count)

# Split the book text into overlapping subsequences of 100 characters; each one is paired with the character that follows it as the prediction target
# Slide a window of `seq_length` characters over the text one step at a time.
# Each window, encoded as integers, becomes an input pattern in dataX, and the
# character immediately after it becomes that pattern's target in dataY.
seq_length = 100
dataX = []
dataY = []
for i in range(0, n_chars - seq_length, 1):
    seq_in = raw_text[i: i + seq_length]
    seq_out = raw_text[i + seq_length]
    dataX.append([char_to_int[char] for char in seq_in])
    dataY.append(char_to_int[seq_out])
n_patterns = len(dataX)
# Sample output recorded in the original notes: Total Patterns:  147574
print("Total Patterns: ", n_patterns)

# Reshape X to [samples, time steps, features]
# Stack the integer patterns into a 3-D array: one row per pattern, one time
# step per character, and a single feature per time step.
X = numpy.array(dataX).reshape(n_patterns, seq_length, 1)

# Normalize the integer codes into the range [0, 1)
# Scale the integer character codes by the vocabulary size (float division).
X = numpy.divide(X, float(n_vocab))

# One-hot encode the output variable
# One-hot encode the target characters. num_classes is passed explicitly so
# that y always has one column per vocabulary entry. Without it, to_categorical
# infers max(dataY) + 1 columns, which is too narrow whenever the
# highest-index character never appears as a target (for example, if it
# occurs only within the first `seq_length` characters).
y = np_utils.to_categorical(dataY, num_classes=n_vocab)

# Define the LSTM model
# One 256-unit LSTM layer reading [time steps, features] inputs, dropout for
# regularization, then a softmax layer with one unit per column of y.
model = Sequential([
    LSTM(256, input_shape=X.shape[1:]),
    Dropout(0.2),
    Dense(y.shape[1], activation='softmax'),
])
model.compile(optimizer='adam', loss='categorical_crossentropy')

# Define a checkpoint callback that saves the model whenever the training loss improves
# Write a checkpoint only when the monitored training loss reaches a new
# minimum; the epoch number and loss value are embedded in the file name.
filepath = "weights-improvement-{epoch:02d}-{loss:.4f}.hdf5"
checkpoint = ModelCheckpoint(
    filepath,
    monitor='loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)
callbacks_list = [checkpoint]

# Fit the model
# Train for 20 epochs in mini-batches of 128 patterns, checkpointing as we go.
model.fit(x=X, y=y, batch_size=128, epochs=20, callbacks=callbacks_list)