Preparing a Text Dataset
Learn how to load textual data in JAX.
JAX doesn’t ship with data-loading utilities. This keeps JAX focused on providing a fast tool for building and training machine learning models. Data is typically loaded with an external library such as TensorFlow or PyTorch and then passed to JAX. This lesson focuses on how to load datasets in JAX using TensorFlow.
Let’s dive in!
Loading text data in JAX
Let’s use the Worldcup tweets dataset to illustrate how to load text datasets with JAX. We import pandas, read the CSV file, and then view a sample of the data.
import pandas as pd

# Read the csv file by giving the path to the extracted file
df = pd.read_csv("/usr/local/notebooks/worldcup-tweets.csv")
print(df.head())
In the code above:
Line 1: We import the pandas library.
Line 4: We read the Worldcup tweets dataset using the read_csv() function from the pandas library, and store the DataFrame in df.
Line 5: Lastly, we call the df.head() method to print the first five rows of the dataset.
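The file path above is specific to the lesson environment, so the same steps can be tried without the dataset. The following is a minimal sketch, assuming a small in-memory CSV as a stand-in: the column names tweet and label and all of the rows are invented for illustration and are not taken from the real Worldcup tweets file.

```python
import io

import pandas as pd

# Hypothetical stand-in for worldcup-tweets.csv. The column names and rows
# are invented for illustration and do not come from the real dataset.
csv_text = io.StringIO(
    "tweet,label\n"
    "What a match tonight,positive\n"
    "That referee decision was terrible,negative\n"
    "Kickoff is at nine,neutral\n"
    "Best goal of the tournament,positive\n"
    "Our defense fell apart again,negative\n"
    "Lineups have been announced,neutral\n"
)

# read_csv() accepts a file path or a file-like object such as StringIO
df = pd.read_csv(csv_text)

# head() returns the first five rows by default
print(df.head())

# Inspect the overall size and the column names before any preprocessing
print(df.shape)
print(df.columns.tolist())
```

Checking the shape and column names at this stage is a quick way to confirm that the file was parsed as expected before moving on to text preprocessing.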