XOR Training Exercise using TensorFlow.js
In this example, we use TensorFlow.js to train a neural network on the XOR (exclusive OR) problem. XOR is a classic machine-learning benchmark because its two classes are not linearly separable. No single straight line divides the inputs labeled 1 from those labeled 0, so the model must learn a non-linear decision boundary to classify all four inputs correctly.
Model Description
The neural network used for this XOR problem consists of an input layer with two neurons, a hidden layer with two neurons, and an output layer with one neuron. The activation function used in the hidden layer is ReLU (Rectified Linear Unit), and the output layer uses the sigmoid activation function to produce values between 0 and 1.
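To make the 2-2-1 architecture concrete, here is a minimal forward pass in plain TypeScript (no TensorFlow.js). It uses a ReLU hidden layer and a sigmoid output. The weights are hand-picked for illustration and are hypothetical, not the trained model's weights. The first hidden unit computes relu(x1 + x2). The second computes relu(x1 + x2 - 1). The output unit computes sigmoid(10*h1 - 20*h2 - 5), whose pre-activation is positive only when exactly one input is 1.

```typescript
// Hand-picked (hypothetical, not trained) weights for a 2-2-1 network.
// Hidden layer: h1 = relu(x1 + x2), h2 = relu(x1 + x2 - 1)
const W1: number[][] = [
  [1, 1], // weights into hidden unit 1
  [1, 1], // weights into hidden unit 2
];
const b1: number[] = [0, -1];
// Output layer: z = 10 * h1 - 20 * h2 - 5, followed by sigmoid
const W2: number[] = [10, -20];
const b2 = -5;

const relu = (z: number): number => Math.max(0, z);
const sigmoid = (z: number): number => 1 / (1 + Math.exp(-z));

// Forward pass: 2 inputs -> 2 ReLU hidden units -> 1 sigmoid output
function forward(x: [number, number]): number {
  const h = W1.map((row, j) => relu(row[0] * x[0] + row[1] * x[1] + b1[j]));
  const z = W2[0] * h[0] + W2[1] * h[1] + b2;
  return sigmoid(z);
}

const inputs: [number, number][] = [[0, 0], [0, 1], [1, 0], [1, 1]];
for (const x of inputs) {
  console.log(x, forward(x).toFixed(4));
}
```

This network has 9 parameters: 2×2 hidden weights plus 2 hidden biases, and 2 output weights plus 1 output bias. That is the same count a Keras-style summary would report for this layer layout.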
Dataset
The XOR dataset consists of four samples, each with two features (inputs) and one label (output). The dataset is as follows:
Input 1 | Input 2 | Output |
---|---|---|
0 | 0 | 0 |
0 | 1 | 1 |
1 | 0 | 1 |
1 | 1 | 0 |
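The table can be written out directly as arrays. The sketch below is plain TypeScript, and its names are chosen for illustration. It checks that each label equals the bitwise XOR of its inputs. It then searches a coarse grid for a single linear threshold unit that fits all four rows.

The search only illustrates the point; the proof is short. A unit that outputs 1 when w1*x1 + w2*x2 + b > 0 would need three conditions: b ≤ 0 for input (0, 0), w2 + b > 0 for (0, 1), and w1 + b > 0 for (1, 0). Adding the last two gives w1 + w2 + b > -b ≥ 0. That contradicts the requirement w1 + w2 + b ≤ 0 for input (1, 1).

```typescript
// The XOR truth table from the dataset above
const X: [number, number][] = [[0, 0], [0, 1], [1, 0], [1, 1]];
const Y: number[] = [0, 1, 1, 0];

// Sanity check: every label equals the bitwise XOR of its two inputs
const matchesXor = X.every(([a, b], i) => (a ^ b) === Y[i]);

// A single linear threshold unit predicts 1 when w1*x1 + w2*x2 + b > 0.
// Search a coarse grid of weights and biases for one that fits all four rows.
function findLinearSeparator(): [number, number, number] | null {
  const grid: number[] = [];
  for (let v = -5; v <= 5; v += 0.5) grid.push(v);
  for (const w1 of grid) {
    for (const w2 of grid) {
      for (const b of grid) {
        const fitsAll = X.every(
          ([x1, x2], i) => (w1 * x1 + w2 * x2 + b > 0 ? 1 : 0) === Y[i],
        );
        if (fitsAll) return [w1, w2, b];
      }
    }
  }
  return null;
}

console.log("labels match a XOR b:", matchesXor); // true
console.log("linear separator found:", findLinearSeparator()); // null
```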
Training Procedure
The training is performed using TensorFlow.js with the following hyperparameters:
- Learning Rate: 0.1
- Number of Epochs: 10000
- Loss Function: Mean Squared Error (MSE); note that the `compile` call in the code below specifies `binaryCrossentropy` instead
- Optimizer: Stochastic Gradient Descent (SGD)
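The procedure above can be sketched without TensorFlow.js as full-batch gradient descent in plain TypeScript. The sketch uses the listed learning rate (0.1), epoch count (10,000) and MSE loss. TensorFlow.js's `model.fit` defaults to a batch size of 32, so four samples form a single batch per epoch, and each SGD step uses the full dataset.

Several parts of this sketch are assumptions, not TensorFlow.js internals:
- the flat parameter layout,
- the fixed starting weights (`INIT`, hypothetical),
- the hand-derived backpropagation.

With only two ReLU hidden units, whether a run reaches a low loss depends on the starting point. A hidden unit whose input is negative for all four samples receives no gradient and stops learning.

```typescript
// Parameter layout (an assumption of this sketch):
//   p[2*j + k] = weight from input k to hidden unit j   (p[0..3])
//   p[4 + j]   = bias of hidden unit j                  (p[4..5])
//   p[6 + j]   = weight from hidden unit j to output    (p[6..7])
//   p[8]       = output bias
type Params = number[];

const XS: [number, number][] = [[0, 0], [0, 1], [1, 0], [1, 1]];
const YS: number[] = [0, 1, 1, 0];

// Arbitrary fixed starting point (hypothetical; not TensorFlow.js's initializer)
const INIT: Params = [0.6, 0.4, 0.7, 0.8, 0.1, -0.5, 0.5, -0.5, 0];

const sigmoid = (z: number): number => 1 / (1 + Math.exp(-z));

// Forward pass for one sample: ReLU hidden layer, sigmoid output
function forward(p: Params, x: [number, number]) {
  const a = [0, 1].map((j) => p[2 * j] * x[0] + p[2 * j + 1] * x[1] + p[4 + j]);
  const h = a.map((v) => Math.max(0, v));
  const yHat = sigmoid(p[6] * h[0] + p[7] * h[1] + p[8]);
  return { a, h, yHat };
}

// Mean squared error over the four samples
function mseLoss(p: Params): number {
  let sum = 0;
  XS.forEach((x, i) => {
    sum += (forward(p, x).yHat - YS[i]) ** 2;
  });
  return sum / XS.length;
}

// Gradient of mseLoss, derived by hand with the chain rule (backpropagation)
function mseGrad(p: Params): Params {
  const g: Params = new Array<number>(p.length).fill(0);
  XS.forEach((x, i) => {
    const { a, h, yHat } = forward(p, x);
    // d(loss)/dz for the output pre-activation z; sigmoid'(z) = yHat * (1 - yHat)
    const dz = (2 / XS.length) * (yHat - YS[i]) * yHat * (1 - yHat);
    g[6] += dz * h[0];
    g[7] += dz * h[1];
    g[8] += dz;
    for (let j = 0; j < 2; j++) {
      // ReLU passes the gradient through only where its input was positive
      const da = a[j] > 0 ? dz * p[6 + j] : 0;
      g[2 * j] += da * x[0];
      g[2 * j + 1] += da * x[1];
      g[4 + j] += da;
    }
  });
  return g;
}

// Full-batch gradient descent: learning rate 0.1 and 10,000 epochs by default
function train(init: Params, learningRate = 0.1, epochs = 10000): Params {
  const p = [...init];
  for (let epoch = 0; epoch < epochs; epoch++) {
    const g = mseGrad(p);
    for (let k = 0; k < p.length; k++) p[k] -= learningRate * g[k];
  }
  return p;
}

const trained = train(INIT);
console.log("initial MSE:", mseLoss(INIT), "final MSE:", mseLoss(trained));
XS.forEach((x) => console.log(x, forward(trained, x).yHat.toFixed(4)));
```

A framework such as TensorFlow.js computes these gradients automatically. Writing them out by hand shows what each SGD update does with the MSE loss.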
Code Implementation
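Below is the code that loads the trained XOR model with TensorFlow.js, compiles it, prints a summary and runs predictions on the four XOR inputs:
```typescript
// Import the TensorFlow.js Node.js bindings (GPU build)
import * as tf from '@tensorflow/tfjs-node-gpu';

async function predictXor(modelPath: string): Promise<void> {
  // Load the saved Layers model (model.json plus its weight files)
  const model = await tf.loadLayersModel(`file://${modelPath}/model.json`);
  model.compile({
    optimizer: tf.train.sgd(0.1),
    loss: 'binaryCrossentropy', // Binary classification loss
    metrics: ['accuracy'],
  });
  model.summary();

  // Run inference on all four XOR inputs
  const x = tf.tensor2d([[0, 0], [0, 1], [1, 0], [1, 1]]);
  const prediction = model.predict(x) as tf.Tensor;
  prediction.print();
}
```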
Resulting Predictions
After training, the model produces the following predictions for the four XOR inputs:
```
Tensor
[[0.0020679],
[0.9994502],
[0.9994048],
[0.0002599]]
```
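The sigmoid outputs above are probabilities, not class labels. Here is a minimal sketch of converting them to labels and scoring them. It rests on two assumptions: the four printed rows follow the order of the dataset table, and labels use the conventional 0.5 threshold.

```typescript
// Sigmoid outputs reported above (assumed to follow the dataset table's row order)
const predictions: number[] = [0.0020679, 0.9994502, 0.9994048, 0.0002599];
const targets: number[] = [0, 1, 1, 0];

// Threshold at 0.5 to turn probabilities into class labels
const labels = predictions.map((p) => (p >= 0.5 ? 1 : 0));
const accuracy =
  labels.filter((label, i) => label === targets[i]).length / targets.length;

// Mean squared error of the raw outputs against the targets
const mse =
  predictions.reduce((sum, p, i) => sum + (p - targets[i]) ** 2, 0) / targets.length;

console.log("labels:", labels); // [ 0, 1, 1, 0 ]
console.log("accuracy:", accuracy); // 1
console.log("MSE:", mse);
```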
Intended Uses & Limitations
- The trained network is specific to the XOR problem; it is not intended for other tasks without changes to the model and retraining.
- TensorFlow.js can also run models directly in the browser (using the `@tensorflow/tfjs` package rather than `@tensorflow/tfjs-node-gpu`), which makes it useful for web applications that perform inference on the client side.
Framework Versions Used
- TensorFlow.js
Please note that this example assumes you have set up your project with the necessary dependencies and have a basic understanding of JavaScript and TensorFlow.js.