

Solution: Train and Test a Model


Let’s review the solution for the previous challenge.

After importing all the necessary libraries, your first step is to define the network.

Define the network

Create a convolutional neural network with the Linen API by subclassing flax.linen.Module. Because the architecture here is simple (you're just stacking layers), you can define the submodules inline inside the __call__ method and decorate that method with @nn.compact.

```python
import flax.linen as nn


class CNN(nn.Module):
    """A simple CNN model."""

    @nn.compact
    def __call__(self, x):
        # First conv block: 64 filters of size 3x3, ReLU, then 2x2 average pooling.
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        # Second conv block with the same configuration.
        x = nn.Conv(features=64, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten to [batch, features]
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=2)(x)  # one logit per class (2 classes)
        return x
```

Define the loss

We use optax.softmax_cross_entropy(). This function expects both logits and labels to have shape [batch, num_classes]. Because the labels are read from the TensorFlow dataset as integers, we first need to convert them to one-hot vectors. The function also returns one loss value per example, so we average over the batch.

Our function returns a simple scalar value ready for ...
