...

Data Generators in the Skip-Gram Algorithm with TensorFlow

Learn about data generators in the skip-gram algorithm.

We’re now going to implement the algorithm from end to end. First, we’ll discuss the data we’re going to use and how TensorFlow can help us get that data into the format the model accepts. We’ll then implement the skip-gram algorithm with TensorFlow and finally train the model and evaluate it on the prepared data.

Implementing the data generators with TensorFlow

First, we’ll investigate how data can be generated in the correct format for the model. For this exercise, we’re going to use the BBC news articles dataset. It contains 2,225 news articles belonging to five topics (business, entertainment, politics, sports, and tech) that were published on the BBC website between 2004 and 2005.

We write the function download_data() below to download the data to a given folder and extract it from its compressed format:

import os
import zipfile
from urllib.request import urlretrieve

def download_data(url, data_dir):
    """Download the dataset if it's not already present, and extract it."""
    os.makedirs(data_dir, exist_ok=True)
    file_path = os.path.join(data_dir, 'bbc-fulltext.zip')
    if not os.path.exists(file_path):
        print('Downloading file...')
        filename, _ = urlretrieve(url, file_path)
    else:
        print("File already exists")
    extract_path = os.path.join(data_dir, 'bbc')
    if not os.path.exists(extract_path):
        with zipfile.ZipFile(file_path, 'r') as zipf:
            zipf.extractall(data_dir)
    else:
        print("bbc-fulltext.zip has already been extracted")
Downloading the data

The function first creates the data_dir if it doesn’t exist. Next, if the bbc-fulltext.zip file doesn’t exist, it will be downloaded from the provided URL. If bbc-fulltext.zip has not been extracted yet, it will be extracted to data_dir.

We can call this function as follows:

url = 'http://mlg.ucd.ie/files/datasets/bbc-fulltext.zip'
download_data(url, 'data')
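Before reading anything in, it can help to peek at what the archive unpacked. This is a minimal, optional check; the folder layout in the comment (one folder per topic plus a README) is the usual layout of this archive, so treat it as an assumption rather than a guarantee:

import os

# List the extracted contents of data/bbc (assuming the download above succeeded)
extract_path = os.path.join('data', 'bbc')
print(sorted(os.listdir(extract_path)))
# Assumed layout: ['README.TXT', 'business', 'entertainment', 'politics', 'sport', 'tech']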

Read data before preprocessing

With that, we are going to focus on reading the data contained in the news articles (in .txt format) into the memory. To do that, we’ll define the read_data() function, which takes a data directory path (data_dir) and reads the .txt files (except for the README file) found in the data directory:

def read_data(data_dir):
    # This will contain the full list of stories
    news_stories = []
    print("Reading files")
    i = 0  # Just used for printing progress
    for root, dirs, files in os.walk(data_dir):
        for fi, f in enumerate(files):
            # We don't read the README file
            if 'README' in f:
                continue
            # Printing progress
            i += 1
            print("." * i, f, end='\r')
            # Open the file (latin-1 avoids decoding errors in this corpus)
            with open(os.path.join(root, f), encoding='latin-1') as fh:
                story = []
                # Read all the lines
                for row in fh:
                    story.append(row.strip())
                # Create a single string with all the rows in the doc
                story = ' '.join(story)
                # Add that to the list
                news_stories.append(story)
    print('', end='\r')
    print(f"\nDetected {len(news_stories)} stories")
    return news_stories
Data read function

With the read_data() function defined, let’s use it to read in the data and print some samples as well as some statistics:

news_stories = read_data(os.path.join('data', 'bbc'))
# Printing some stats and sample data
print(f"{sum([len(story.split(' ')) for story in news_stories])} words found in the total news set")
print('Example words (start): ', news_stories[0][:50])
print('Example words (end): ', news_stories[-1][-50:])

We will get output like this:

Reading files
............. 361.txt
Detected 2225 stories
865163 words found in the total news set
Example words (start): Napster offers rented music to go Music downloadi
Example words (end): seem likely," Reuters quoted one trader as saying.

As we said at the beginning of this section, there are 2,225 stories with close to a million words.
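Before building a tokenizer, it’s worth glancing at the raw token frequencies, since this makes the need for lowercasing and punctuation removal obvious. Here’s a minimal sketch using Python’s collections.Counter on the stories read above (the exact counts depend on the corpus):

from collections import Counter

# Naive whitespace split, with no cleanup yet
raw_counts = Counter(token for story in news_stories for token in story.split(' '))

# Variants like "The"/"the" are counted separately, and tokens keep punctuation
# attached (e.g., "said,"), which is exactly what the tokenizer will fix
print(raw_counts.most_common(10))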

Build a tokenizer

In the next step, we need to tokenize each story (currently a single long string) into a list of tokens (or words). Along with that, we’ll perform some preprocessing on the text (both steps are sketched in plain Python after the list):

  • Lowercase all the characters.
  • Remove punctuation.
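To make these two steps concrete, here’s what they amount to in plain Python. This is purely illustrative and uses the standard library’s string.punctuation; the actual character set that gets stripped is governed by the tokenizer’s filters argument shown below:

import string

text = "Napster offers rented music to go."
text = text.lower()                                                # lowercase all characters
text = text.translate(str.maketrans('', '', string.punctuation))   # remove punctuation
print(text.split(' '))  # ['napster', 'offers', 'rented', 'music', 'to', 'go']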

All of these can be achieved with the tensorflow.keras.preprocessing.text.Tokenizer object. We can define a tokenizer as follows:

from tensorflow.keras.preprocessing.text import Tokenizer
tokenizer = Tokenizer(num_words = None, filters = '!"#$%&()*+,-./:;<=>?@[\\]^_`{|}~\t\n', lower = True, split = ' ')

Here, we can see some of the most popular keyword arguments and their default values used when defining a tokenizer:

  • num_words: This defines the size of the vocabulary. It defaults to None, meaning it will consider all the words appearing in the text corpus. If set to an integer n, it will only consider the n most common words appearing in the corpus.
  • filters: This defines any characters that need to be omitted during preprocessing. By default, it defines a string containing most of the common punctuation marks and symbols.
  • lower: This defines whether the text needs to be converted to lowercase.
  • split: This defines the character on which the words will be tokenized.
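To see these arguments in action, here’s a small, illustrative sketch that fits a tokenizer on two toy sentences and converts them to integer word-ID sequences. The names toy_tokenizer and toy_texts are introduced purely for this illustration, and the exact indices in the output depend on word frequencies in whatever text is fitted:

from tensorflow.keras.preprocessing.text import Tokenizer

# A separate toy tokenizer so we don't touch the one used for the real corpus
toy_tokenizer = Tokenizer(num_words=None, lower=True, split=' ')

toy_texts = [
    "Napster offers rented music to go!",
    "Music downloading, the music industry says, is changing."
]

# Learn the vocabulary from the texts (lowercasing and filtering are applied)
toy_tokenizer.fit_on_texts(toy_texts)
print(toy_tokenizer.word_index)  # {'music': 1, ...} -- most frequent words get the lowest indices

# Map each text to a list of word IDs from the learned vocabulary
print(toy_tokenizer.texts_to_sequences(toy_texts))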

Once the ...