Solution: Train and Test a Model
Let’s review the solution for the previous challenge.
After importing all the necessary libraries, your first step is to define the network.
Define the network
Create a convolutional neural network with the Linen API by subclassing nn.Module. Because the architecture in this case is relatively simple (you’re just stacking layers), you can define the submodules inline within the __call__ method and wrap that method with the @nn.compact decorator.
class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)
        return x
Define loss
We simply use optax.softmax_cross_entropy(). Note that this function expects both logits and labels to have the shape [batch, num_classes]. Since the labels will be read from the TensorFlow dataset as integer values, we first need to convert them to one-hot encoding.
Our function returns a simple scalar value ready for ...