Datasets with Scala Case Class and Java Bean Class

Learn how Scala's case classes and Java's bean classes can be used with Datasets.

Generating data using SparkSession

We can also create a Dataset directly from a SparkSession object with its createDataset() method, as demonstrated below.

// define a case class
scala> case class MovieDetailShort(imdbID: String, rating: Int)

// define a seeded random number generator
scala> val rnd = new scala.util.Random(9)

// create some data: 101 objects, for i from 0 to 100 inclusive
scala> val data = for (i <- 0 to 100) yield MovieDetailShort("movie-" + i, rnd.nextInt(10))

// use the SparkSession to create a Dataset from the objects created in the previous step
scala> val datasetMovies = spark.createDataset(data)

// display three rows from the Dataset
scala> datasetMovies.show(3)
+-------+------+
| imdbID|rating|
+-------+------+
|movie-0|     0|
|movie-1|     3|
|movie-2|     8|
+-------+------+
only showing top 3 rows

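When working with Scala, we didn't have to specify the encoder explicitly. The spark-shell automatically imports spark.implicits._, which supplies implicit encoders for case classes. In a compiled application, you add import spark.implicits._ yourself. Java has no implicits, so there we have to pass the encoder explicitly. The equivalent Java bean class for MovieDetailShort is listed below: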

import java.io.Serializable;

public class MovieDetailShort implements Serializable {
    String imdbID;
    int rating;

    // public no-arg constructor, expected by the bean encoder
    public MovieDetailShort() {
    }

    public MovieDetailShort(String imdbID, int rating) {
        this.imdbID = imdbID;
        this.rating = rating;
    }

    // JavaBean getters and setters
    public String getImdbID() { return imdbID; }
    public void setImdbID(String imdbID) { this.imdbID = imdbID; }
    public int getRating() { return rating; }
    public void setRating(int rating) { this.rating = rating; }
}
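Encoders.bean() builds its schema from the class's JavaBean properties, which is why the getters and setters above matter. The Spark-free sketch below gives a rough check of which properties a bean exposes. It uses the JDK's java.beans.Introspector and copies the bean so that it compiles on its own. The helper class BeanPropertyCheck is hypothetical, and Spark's own inference may apply further rules, for example about which field types are supported.

```java
import java.beans.BeanInfo;
import java.beans.IntrospectionException;
import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

// Copy of the bean above so this sketch compiles on its own.
class MovieDetailShort implements Serializable {
    String imdbID;
    int rating;

    public MovieDetailShort() { }

    public MovieDetailShort(String imdbID, int rating) {
        this.imdbID = imdbID;
        this.rating = rating;
    }

    public String getImdbID() { return imdbID; }
    public void setImdbID(String imdbID) { this.imdbID = imdbID; }
    public int getRating() { return rating; }
    public void setRating(int rating) { this.rating = rating; }
}

// Hypothetical helper: lists properties that have both a getter and a setter.
class BeanPropertyCheck {
    static List<String> readWriteProperties(Class<?> beanClass) {
        try {
            // Stop at Object.class so the inherited getClass() is not reported as a property.
            BeanInfo info = Introspector.getBeanInfo(beanClass, Object.class);
            List<String> names = new ArrayList<>();
            for (PropertyDescriptor pd : info.getPropertyDescriptors()) {
                if (pd.getReadMethod() != null && pd.getWriteMethod() != null) {
                    names.add(pd.getName());
                }
            }
            return names;
        } catch (IntrospectionException e) {
            throw new IllegalArgumentException("Cannot introspect " + beanClass, e);
        }
    }

    public static void main(String[] args) {
        // Prints the read/write properties: imdbID and rating.
        System.out.println(readWriteProperties(MovieDetailShort.class));
    }
}
```

If a setter were removed, that property would drop out of this list. This is a quick way to spot a bean the encoder cannot round-trip.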

Creating sample data using the Java bean class looks like this:

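import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;

// create an explicit Encoder for the bean class
Encoder<MovieDetailShort> encoder = Encoders.bean(MovieDetailShort.class);

// create a random number generator with a fixed seed
Random rand = new Random();
rand.setSeed(5);

// create a list of randomly generated objects
List<MovieDetailShort> data = new ArrayList<>();

for (int i = 0; i < 1000; i++) {
    data.add(new MovieDetailShort("movie-" + i, rand.nextInt(10)));
}

// create a Dataset of MovieDetailShort typed data (spark is an existing SparkSession)
Dataset<MovieDetailShort> movies = spark.createDataset(data, encoder);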

Note: In the listing above, we have to pass the encoder to spark.createDataset() as its second argument.
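Because the generator is seeded, the sample data is reproducible: two Random instances with the same seed return the same sequence. The Spark-free sketch below covers only the data-generation loop. It produces the ratings twice with seed 5 and compares them. The class SampleData and its method sampleRatings are hypothetical names chosen for illustration.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Spark-free sketch: reproduces only the rating-generation loop from the listing above.
class SampleData {
    // Hypothetical helper: returns n ratings in the range 0..9 from a generator seeded with `seed`.
    static List<Integer> sampleRatings(long seed, int n) {
        Random rand = new Random(seed);    // equivalent to new Random() followed by setSeed(seed)
        List<Integer> ratings = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            ratings.add(rand.nextInt(10)); // nextInt(10) returns a value in [0, 10)
        }
        return ratings;
    }

    public static void main(String[] args) {
        List<Integer> first = sampleRatings(5, 1000);
        List<Integer> second = sampleRatings(5, 1000);
        // Same seed, same sequence, so this prints true.
        System.out.println(first.equals(second));
    }
}
```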

Filter

We can use higher-order functions such as filter() with Datasets:

scala> datasetMovies.filter(mov => mov.rating > 5).show(3)
+-------+------+
| imdbID|rating|
+-------+------+
|movie-0|     9|
|movie-1|     6|
|movie-2|     8|
+-------+------+
only showing top 3 rows

Here, we use dot notation to access the fields of the movie object inside the anonymous function. Instead of an anonymous function, we can also define a named function and pass it to filter() as follows:

scala> def highRatedMovies(mov : MovieDetailShort) = mov.rating > 5
highRatedMovies: (mov: MovieDetailShort)Boolean

scala> datasetMovies.filter(highRatedMovies(_)).show(3)
+-------+------+
| imdbID|rating|
+-------+------+
|movie-0|     9|
|movie-1|     6|
|movie-2|     8|
+-------+------+
only showing top 3 rows

The equivalent operation in Java is much more verbose than in Scala. From Java, filter() takes an instance of FilterFunction<T>, where T is the element type of the Dataset being filtered.

import org.apache.spark.api.java.function.FilterFunction;

// define the filter as an anonymous FilterFunction assigned to a named variable
FilterFunction<MovieDetailShort> func = new FilterFunction<MovieDetailShort>() {
    @Override
    public boolean call(MovieDetailShort mov) {
        return mov.getRating() > 5;
    }
};

// pass the named function to filter()
movies.filter(func).show(3);
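FilterFunction has a single abstract method, so from Java 8 on the anonymous class above can usually be replaced by a lambda. The lambda is typically cast to FilterFunction<MovieDetailShort> so the compiler picks that overload of filter(), just as the map() example in the next section casts to MapFunction. The Spark-free sketch below shows that the two forms are equivalent. It uses the JDK's java.util.function.Predicate as a stand-in for FilterFunction and a stream as a stand-in for the Dataset. The class names FilterSketch and Movie and the sample values are hypothetical.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;
import java.util.stream.Collectors;

// Spark-free sketch: Predicate stands in for FilterFunction, a stream for the Dataset.
class FilterSketch {
    // Hypothetical minimal element type with just the fields the filter needs.
    static final class Movie {
        final String imdbID;
        final int rating;
        Movie(String imdbID, int rating) { this.imdbID = imdbID; this.rating = rating; }
    }

    // Anonymous class, mirroring the FilterFunction listing above.
    static final Predicate<Movie> HIGH_RATED_ANON = new Predicate<Movie>() {
        @Override
        public boolean test(Movie mov) {
            return mov.rating > 5;
        }
    };

    // Equivalent lambda: same logic, far less boilerplate.
    static final Predicate<Movie> HIGH_RATED_LAMBDA = mov -> mov.rating > 5;

    // Keeps the ids of the movies that satisfy the predicate.
    static List<String> highRatedIds(List<Movie> movies, Predicate<Movie> pred) {
        return movies.stream().filter(pred).map(m -> m.imdbID).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Movie> movies = Arrays.asList(
                new Movie("movie-0", 9), new Movie("movie-1", 5), new Movie("movie-2", 3));
        System.out.println(highRatedIds(movies, HIGH_RATED_ANON));   // [movie-0]
        System.out.println(highRatedIds(movies, HIGH_RATED_LAMBDA)); // [movie-0]
    }
}
```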

Map

We can use the map() function to compute a new value for each element. For example, suppose we want to assign the grades “A,” “B,” and “C” to movies whose rating is above 5, equal to 5, and below 5, respectively. We can write a named function in Scala and pass it to the map() function. We also create a new case class, MovieGrade, which our named function returns for each entry in the Dataset.

scala> case class MovieGrade(imdbID: String, grade: String)
defined class MovieGrade

scala> def movieGrade(mov: MovieDetailShort):MovieGrade = { val grade = if (mov.rating == 5) "B" else if (mov.rating < 5) "C" else "A"; MovieGrade(mov.imdbID, grade) }
movieGrade: (mov: MovieDetailShort)MovieGrade

scala> datasetMovies.map(movieGrade).show(3)
+-------+-----+
| imdbID|grade|
+-------+-----+
|movie-0|    A|
|movie-1|    A|
|movie-2|    A|
+-------+-----+
only showing top 3 rows

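To achieve the equivalent functionality in Java, we pass an instance of MapFunction<T, U> to the map() function, where T is the input element type and U is the type returned for each element. Here, the function returns only the grade, as a String: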

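import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;

// the cast tells the compiler which overload of map() the lambda targets
Dataset<String> grades = movies.map((MapFunction<MovieDetailShort, String>) mov -> {
    if (mov.getRating() == 5)
        return "B";
    else if (mov.getRating() < 5)
        return "C";
    else
        return "A";
}, Encoders.STRING());

grades.show(3);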

Note that in Java we have to pass the encoder for map()'s result type (here Encoders.STRING()) explicitly as the second argument.
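The grading rule is easier to test outside Spark if it lives in a plain static method, which the MapFunction lambda can then call, for example mov -> Grades.grade(mov.getRating()). The sketch below is a hypothetical refactoring and not part of the lesson's code. The class name Grades and the method grade are assumptions.

```java
// Hypothetical helper holding the grading rule from the map() examples above.
class Grades {
    // "A" for ratings above 5, "B" for exactly 5, "C" for below 5.
    static String grade(int rating) {
        if (rating == 5) {
            return "B";
        } else if (rating < 5) {
            return "C";
        } else {
            return "A";
        }
    }

    public static void main(String[] args) {
        // Prints "9 -> A", "5 -> B", "0 -> C", one per line.
        for (int rating : new int[] {9, 5, 0}) {
            System.out.println(rating + " -> " + grade(rating));
        }
    }
}
```

Keeping the rule in one method also keeps the Scala and Java versions from drifting apart if the thresholds change.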
