Solution: Train and Test a Model
Let’s review the solution for the previous challenge.
After importing all the necessary libraries, your first step is to define the network.
Define the network
Create a convolutional neural network with the Linen API by subclassing flax.linen.Module. Because the architecture here is relatively simple (you're just stacking layers), you can define the submodules inline within the __call__ method and decorate that method with @compact (nn.compact).
Define loss
We simply use optax.softmax_cross_entropy(). Note that this function expects logits and labels of the same shape, [batch, num_classes], and returns one loss value per example. Since the labels are read from the TensorFlow dataset as integer class indices, we first need to convert them to one-hot encoding.
Our function returns a simple scalar value ready for optimization, so we ...