Artificial Neural Networks (ANNs) are powerful inference tools. They can be trained to fit complex functions and then used to predict new (unseen) data outside their training set. Fitting the training data is relatively easy for ANNs because of their Universal Approximation capability. However, that does not mean ANNs can learn the rules as we humans do. Here we aim to show how well a trained ANN, which fits its training data accurately, can generalize to new and unseen data. We categorize the unseen data into two types:
- Data points within the training range (that can be interpolated).
- Data points outside the training range (that can be extrapolated).
Take this example: | 1->1 | 2->4 | 3->? | 4->16 | 5->25 | 6->? | …
If we ask a human to predict 3->? and 6->? from the sequence above, most can answer correctly once they discover the rule that fits the set, in this example y = X^2. Predicting 3->9 is called interpolation, while predicting 6->36 is called extrapolation. In this specific example, the distinction between interpolation and extrapolation is not important for humans because once they find the power rule, they can apply it to any other number in the sequence. Even if we jump to 100->?, the answer is straightforward (100^2 = 10000). Nevertheless, the distinction between interpolation and extrapolation is quite important for ANNs, and it gives some insight into the difference between learning in humans and learning in ANNs.
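To make the two categories concrete, here is a minimal Python sketch that labels the query points from the sequence above. The helper name `split_by_training_range` is ours and purely illustrative; it assumes a one-dimensional input where "within the training range" simply means between the smallest and largest training input.

```python
def split_by_training_range(train_x, query_x):
    """Label each query point by whether it falls inside the training range."""
    lo, hi = min(train_x), max(train_x)
    labels = {}
    for x in query_x:
        labels[x] = "interpolation" if lo <= x <= hi else "extrapolation"
    return labels

train_x = [1, 2, 4, 5]   # inputs whose squares are given in the sequence
query_x = [3, 6, 100]    # inputs marked with '?', plus the 100->? jump
labels = split_by_training_range(train_x, query_x)

for x, kind in labels.items():
    # A human who knows the rule answers x**2 in every case.
    print(f"{x} -> {x ** 2} ({kind})")
```

For a human who has found the rule, all three answers are equally easy; the labels only start to matter for a model that has learned an approximation instead of the rule.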
Here we demonstrate this difference by implementing a simple yet powerful ANN architecture with a single hidden layer and examining its generalization capability for various parameter settings. We randomly assign and freeze the weights in the first layer and train only the weights in the final layer using a closed-form solution. This method was popularized under the name Extreme Learning Machine (ELM) and has some controversial origins. In our experience, this method trains quickly and gives very accurate results for low-dimensional datasets that have shallow structures. We use Matlab for the demonstration because its syntax is similar to algebra and we can implement the ANN from scratch in just a few lines of code. We start with the hello world of Machine Learning, which is training an ANN to solve the XOR problem. Below is the complete Matlab code:
    X = [0, 0; 1, 1; 0, 1; 1, 0];       % XOR inputs
    y = [0, 0, 1, 1];                   % targets
    inp = 2; neurons = 5;               % parameters
    Wx = randn(inp, neurons) * 0.01;    % input-hidden weights (random, std 0.01, frozen)
    z = tanh(X * Wx);                   % 1st-layer forward activation (tanh)
    Wo = y * pinv(z');                  % training output weights (closed-form solution)
    predictions = tanh(X * Wx) * Wo';   % feedforward propagation | inference
    disp(predictions)                   % display the predicted data
Believe it or not, the above is all you need to construct and train a single hidden layer ANN (no external libraries are needed). You can achieve comparable conciseness with Python + NumPy ( https://github.com/hunar4321/Simple_ANN/blob/master/ELM.py ). If you run this code in your IDE, it will output predictions very close to the targets, i.e. [0, 0, 1, 1]. What is more, the same lines of code can in principle approximate any continuous function over a bounded input range, given enough neurons in the hidden layer. It should be noted, though, that for very large inputs the pseudoinverse in the closed-form solution is computationally expensive, and in the case of very noisy datasets, over-fitting is the usual enemy.
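For readers who prefer Python, the XOR listing can be sketched in NumPy roughly as follows. This is our own minimal port, not the code at the linked repository: the helper names `train_elm` and `predict_elm`, the `scale` argument, and the fixed `seed` are illustrative assumptions, and it relies on `np.linalg.pinv` for the closed-form step.

```python
import numpy as np

def train_elm(X, y, neurons, scale=0.01, seed=0):
    """Random frozen tanh hidden layer plus closed-form output weights."""
    rng = np.random.default_rng(seed)
    # Input-hidden weights: drawn once and never trained.
    Wx = rng.standard_normal((X.shape[1], neurons)) * scale
    Z = np.tanh(X @ Wx)            # hidden-layer activations
    Wo = np.linalg.pinv(Z) @ y     # least-squares output weights via pseudoinverse
    return Wx, Wo

def predict_elm(X, Wx, Wo):
    """Forward pass: tanh hidden layer followed by a linear output layer."""
    return np.tanh(X @ Wx) @ Wo

X = np.array([[0, 0], [1, 1], [0, 1], [1, 0]], dtype=float)  # XOR inputs
y = np.array([0, 0, 1, 1], dtype=float)                      # XOR targets

Wx, Wo = train_elm(X, y, neurons=5)
predictions = predict_elm(X, Wx, Wo)
print(np.round(predictions, 3))
```

`np.linalg.pinv(Z) @ y` is the transposed form of `y * pinv(z')` in the Matlab listing, so both compute the same minimum-norm least-squares output weights.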
Coming back to our quest, which was to examine the ability of such trained networks to interpolate and extrapolate unseen data: the XOR example is not useful here because all the input data is used for training. A better example is to use these networks to solve the power rule y = X^2 (our first example). Following is the complete Matlab code for our next example:
    step = 2;                           % step size, i.e. the gap between the training inputs
    X = [2:step:100]';                  % training data (even numbers): 2, 4, 6, ...
    y = X.^2;                           % train targets
    inp = 1; neurons = 100;
    Wx = randn(inp, neurons) * 0.01;
    z = tanh(X * Wx);
    Wo = y' * pinv(z');
    yhat = tanh(X * Wx) * Wo';
    Xt = [1:step:121]';                 % testing data (odd numbers): 1, 3, 5, ...
    yt = Xt.^2;                         % test targets
    prediction = tanh(Xt * Wx) * Wo';   % inference
    % visualizations
    figure; hold on; plot(X, y, 'og'); plot(X, yhat, '*r'); hold off;
    legend('target', 'prediction'); title('y = X^2')
    figure; hold on; plot(Xt, yt, 'og'); plot(Xt, prediction, '*r'); hold off;
    legend('target', 'prediction'); title('y = X^2')
    xticks = 1:10:121; set(gca, 'XTick', xticks)
The training inputs are the even numbers from 2 to 100, and their squares are the target outputs. As you can see from Figure 1, the network learned to fit the training data very well.
Figure 1: Fitting y = X^2 | training-set
We tested the same network, after training on the even numbers, on the unseen odd numbers from 1 to 121. Figure 2 shows the network’s ability to generalize to the unseen data. As you can see, the ANN predicts the odd numbers within the training range very well but cannot extrapolate beyond it: from 101 onward, its predictions become wrong.
Figure 2: Fitting y = X^2 | test-set
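The figures show the gap visually; it can also be summarized as two numbers. The following NumPy sketch is our own port of the power-rule experiment, not the original code. The function name `fit_power_rule`, its `scale` and `seed` arguments, and the choice of mean absolute error as the metric are all assumptions made for illustration. It trains on the even numbers, predicts the odd numbers from 1 to 121, and reports the error separately for test points inside and outside the training range.

```python
import numpy as np

def fit_power_rule(step=2, neurons=100, scale=0.01, seed=0):
    """Train an ELM on y = x^2 and return (interpolation, extrapolation) errors."""
    rng = np.random.default_rng(seed)

    # Training data: 2, 2 + step, ... up to at most 100.
    X = np.arange(2, 101, step, dtype=float).reshape(-1, 1)
    y = X[:, 0] ** 2

    # Frozen random hidden layer, closed-form output weights.
    Wx = rng.standard_normal((1, neurons)) * scale
    Wo = np.linalg.pinv(np.tanh(X @ Wx)) @ y

    # Test data: the odd numbers 1, 3, ..., 121.
    Xt = np.arange(1, 122, 2, dtype=float).reshape(-1, 1)
    yt = Xt[:, 0] ** 2
    prediction = np.tanh(Xt @ Wx) @ Wo

    # Split test points by whether they lie inside the training range.
    inside = (Xt[:, 0] >= X.min()) & (Xt[:, 0] <= X.max())
    abs_err = np.abs(prediction - yt)
    return abs_err[inside].mean(), abs_err[~inside].mean()

interp_err, extrap_err = fit_power_rule()
print(f"mean |error| inside the training range:  {interp_err:.4g}")
print(f"mean |error| outside the training range: {extrap_err:.4g}")
```

The `step` and `neurons` arguments make it easy to rerun the sketch with wider training gaps or a larger hidden layer. Exact numbers depend on the random seed and on the tolerance the pseudoinverse routine uses, so they will not match the Matlab run digit for digit.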
If we increase the number of hidden neurons from 100 to 10,000, we can see a steady improvement in the ANN’s ability to predict the unseen odd numbers.
Although increasing the number of hidden neurons improves the ANN’s capability for both interpolation and extrapolation, the network still fails to extrapolate values that are far from the training range. This suggests that the ‘power rule’ cannot be modeled by this ANN no matter how large the hidden layer is. This is understandable because the weighted summation of the inputs (i.e. the dot product) cannot model multiplication among the inputs; it can only approximate it within a specific range.
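A complementary way to see the limit, assuming the tanh hidden units used in the listings above: tanh is bounded between -1 and 1, so each hidden unit flattens out once its weighted input is large, and a finite weighted sum of flat units cannot keep growing the way X^2 does. The short Python sketch below illustrates this for a single hidden unit; the weight value 0.01 is an illustrative assumption chosen to match the scale of the random weights, not a value taken from a trained network.

```python
import math

w = 0.01  # assumed magnitude of one frozen input-hidden weight

# Inside the training range the unit still responds to x;
# far outside it, the activation is pinned near 1 while x^2 keeps growing.
for x in (50, 100, 1_000, 10_000):
    print(f"x = {x:>6}   tanh(w*x) = {math.tanh(w * x):.6f}   x^2 = {x ** 2}")

saturated = [math.tanh(w * x) for x in (1_000, 10_000)]
```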
It is also worth noting that the interpolation ability of ANNs remains very good, even if we increase the step size (i.e the gap) between the training set numbers (Figure 4).
We can draw the following conclusions from our results above:
- ANNs can fit the training set of our non-linear function, y = X^2, very well.
- ANNs can fit the testing set of the function above provided that the test data are within the range of the training set, i.e. ANNs are good at interpolation.
- ANNs can interpolate the unseen data well even if the gaps between the training data are big.
- ANNs are bad at fitting the test data that are far outside the range of the training set.
- Increasing the number of neurons in the hidden layer improves both the interpolation and extrapolation capabilities of ANNs. However, the power rule cannot be modeled no matter how large the hidden layer is; it can only be approximated within a specific range.
The last point is an important distinction between how ANNs learn to generalize and how we humans learn to generalize. However, the way we present the data to ANNs might matter. For example, in deep learning the data are often presented as one-hot encodings, which treat the samples as discrete categories instead of continuous values. Some newer deep learning architectures such as Transformers (e.g. GPT-3) have been shown to produce coherent text and generate compelling answers to questions, suggesting there might be some rule-learning capability. To our knowledge, GPT-3 was also bad at learning ‘multiplication’ (GPT-3 paper / Figure 3.10) but good at learning ‘addition’, and this behavior closely resembles that of a typical feed-forward ANN such as the one in our example.
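To illustrate what presenting numbers as discrete categories means, here is a minimal Python sketch of one-hot encoding. The `one_hot` helper and the four-value `vocabulary` are hypothetical and chosen only for illustration.

```python
def one_hot(value, vocabulary):
    """Return a 0/1 vector with a single 1 at the position of `value`."""
    if value not in vocabulary:
        # A category never seen in training has no slot in the encoding.
        raise KeyError(f"{value!r} is not in the vocabulary")
    return [1 if v == value else 0 for v in vocabulary]

vocabulary = [2, 4, 6, 8]        # hypothetical training inputs, treated as categories
print(one_hot(4, vocabulary))    # [0, 1, 0, 0]
print(one_hot(6, vocabulary))    # [0, 0, 1, 0]
```

In this encoding, 4 is no closer to 6 than it is to 8, and an unseen value such as 5 has no representation at all, so the notion of "between two training points" that interpolation relies on is not given to the network directly.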
Another interesting observation about ANNs is that even though we increased the number of neurons beyond the number of samples, our network did not suffer from the bad effects of over-fitting. Over-fitting is usually problematic when we have a non-representative training dataset. This usually happens when there is a lot of variance in the data due to external noise. By external noise, we mean the non-interesting variance that is not part of the data structure itself. People who work with time-series signals that have a low signal-to-noise ratio (SNR), e.g. EEG and fMRI, usually prefer simpler Machine Learning methods such as SVMs and Ridge Regression because over-parameterized ANNs are usually driven toward fitting the external noise instead of the signal, which is a very bad side effect of over-fitting.
In contrast, in the deep learning world, over-parameterized ANNs seem to be the norm, especially for those working in the vision and NLP fields, because their training datasets are usually large, representative, and clean. Since the training datasets in those fields are large and include a wide range of structural variation (i.e. the training datasets are representative of the testing datasets), it is no surprise that deep learning networks like CNNs, GANs, and Transformers are capable of learning to classify and generate new variations of unseen data that can be interpolated. This is similar to our example, where the ANN was able to successfully predict the unseen odd numbers that were within the range of the training set.
In conclusion, we attribute the success of ANNs in vision and NLP fields to their good interpolation capability, especially when they are fed with large and representative training datasets. We also argue that rule learning and extrapolation beyond the training range are ANNs’ weak points.
At Brainxyz, our focus is on learning algorithms that are capable of learning and generalizing in a controllable manner. We also aim for biological plausibility and efficiency. In future articles, we will test and compare PHUN, our in-progress novel ML algorithm, with ANNs and other ML types. Please stay tuned.
- ML: Machine Learning
- ANNs: Artificial Neural Networks
- CNNs: Convolutional Neural Networks
- SVMs: Support Vector Machines
- GANs: Generative Adversarial Networks
- GPT-3: Generative Pre-trained Transformer 3
- NLP: Natural Language Processing
- PHUN: Predictive Hebbian Unified Neurons
- fMRI: Functional Magnetic Resonance Imaging
- EEG: Electroencephalogram