AI Basics: approximate a sine wave with a neural network

Today we are going to look at how to train a neural network that approximates a sine wave: given an input x, the prediction y = model(x) should be very close to y = sin(x).

The classic “hello world” of machine learning is the MNIST dataset, where a model is trained to recognise handwritten digits, but I was looking for something even simpler than that.

I find very basic examples like this really help me get my head around the underlying math and what is actually going on when you train a model.

First we will use PyTorch to define a model made of Linear layers separated by activation functions (ReLU here). The initial input size is 1, representing the value of x you would usually pass to sin(), and the final output size is also 1, representing the number sin() returns.

				
```python
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

class Model(nn.Module):
    def __init__(self, input_dim, middle_dim, output_dim):
        super().__init__()
        # Linear layers separated by ReLU activations: input -> hidden -> hidden -> output
        self.model = nn.Sequential(
            nn.Linear(input_dim, middle_dim),
            nn.ReLU(),
            nn.Linear(middle_dim, middle_dim),
            nn.ReLU(),
            nn.Linear(middle_dim, output_dim),
        )

    def forward(self, x):
        # x has shape (batch_size, input_dim); the output has shape (batch_size, output_dim)
        return self.model(x)

# Create our model: 1 input (x), 512 units per hidden layer, 1 output (the estimate of sin(x))
model = Model(1, 512, 1)

# Define the loss function (mean squared error) and the optimizer (stochastic gradient descent)
criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
```
				
			

Our model now looks like this: one input, two hidden layers of 512 units each followed by a ReLU, and one output.
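PyTorch can also describe the model in text. Below is a minimal sketch, assuming the layer sizes used above (1 → 512 → 512 → 1), rebuilt as a plain `nn.Sequential` named `net` (a hypothetical name) so it runs on its own. It prints the layers, counts the trainable parameters, and checks that a batch of shape (N, 1) produces predictions of shape (N, 1).

```python
import torch
import torch.nn as nn

# Same layers as the Model class above, written directly as a Sequential
net = nn.Sequential(
    nn.Linear(1, 512),
    nn.ReLU(),
    nn.Linear(512, 512),
    nn.ReLU(),
    nn.Linear(512, 1),
)
print(net)

# Each Linear layer has in_features * out_features weights plus out_features biases
n_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(n_params)  # 264193

# A batch of 5 x values (one feature each) gives 5 predictions (one value each)
x = torch.rand(5, 1)
out = net(x)
print(out.shape)  # torch.Size([5, 1])
```

Almost all of those parameters (262,656) sit in the 512 × 512 middle layer, so `middle_dim` is the main knob for the size of the model.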

Now that we have a model to train, we need some training data. We will use NumPy to generate 100,000 random values of x between 0 and 2π, along with the output sin() produces for each of them.

				
```python
# 100,000 random x values in [0, 2*pi) and the matching sin(x) targets
X = np.random.rand(10**5) * 2 * np.pi
y = np.sin(X)

# train_test_split holds back 20% of the data for testing; converting to float32 here
# matches the dtype of the model's weights
X_train, X_test, y_train, y_test = (torch.from_numpy(a).float() for a in train_test_split(X, y, test_size=0.2))

# DataLoaders split the data into batches of 64 that can be fed into the model.
# unsqueeze(1) turns each tensor of shape (N,) into a column of shape (N, 1)
train_dataloader = DataLoader(TensorDataset(X_train.unsqueeze(1), y_train.unsqueeze(1)), batch_size=64, shuffle=True)
val_dataloader = DataLoader(TensorDataset(X_test.unsqueeze(1), y_test.unsqueeze(1)), batch_size=64, shuffle=False)
```
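Converting the data to `float32` once, up front, means the training loop no longer has to cast every batch. Here is a minimal sketch of the same pipeline that does this. Note that it swaps scikit-learn's `train_test_split` for PyTorch's own `torch.utils.data.random_split` so it only depends on PyTorch and NumPy, and the seed of 0 is an arbitrary choice for reproducibility. It also shows how many batches each epoch contains.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# 100,000 random x values in [0, 2*pi) and their sine, stored as float32 columns
X = np.random.rand(10**5) * 2 * np.pi
y = np.sin(X)
X_t = torch.from_numpy(X).float().unsqueeze(1)  # shape (100000, 1)
y_t = torch.from_numpy(y).float().unsqueeze(1)  # shape (100000, 1)

# 80/20 train/validation split, seeded so it is reproducible
dataset = TensorDataset(X_t, y_t)
train_set, val_set = random_split(dataset, [80_000, 20_000], generator=torch.Generator().manual_seed(0))

train_dataloader = DataLoader(train_set, batch_size=64, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=64, shuffle=False)

# The last validation batch is partial (20,000 is not a multiple of 64)
print(len(train_dataloader), len(val_dataloader))  # 1250 313
xb, yb = next(iter(train_dataloader))
print(xb.shape, xb.dtype)  # torch.Size([64, 1]) torch.float32
```

With 80,000 training samples and a batch size of 64, one epoch is exactly 1,250 optimizer steps, so 10 epochs amount to 12,500 weight updates.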
				
			

For fun, let's now run our untrained model on some test data, a simple array of evenly spaced x values from 0 to 2π, and see how well it does.

				
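```python
# 629 evenly spaced x values from 0 to just under 2*pi, as a column of shape (629, 1)
lin_test = np.arange(0.0, 2*np.pi, 0.01)[:, np.newaxis]
with torch.no_grad():  # no gradients are needed when only making predictions
    y_1 = model(torch.from_numpy(lin_test).float())
```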
				
			

Not very well.

The blue line represents our prediction, and the orange dots are the correct data points it should be estimating.
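The plotting code behind these charts isn't shown in the post. One way to draw a similar chart with matplotlib is sketched below; `plot_predictions` is a hypothetical helper, and the small untrained `net` is just a stand-in so the example runs on its own. Pass in the real `model`, `lin_test`, and some of the test data instead.

```python
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

def plot_predictions(model, x_grid, x_points, y_points):
    """Blue line: model predictions on x_grid. Orange dots: the true (x, sin(x)) pairs."""
    with torch.no_grad():
        y_pred = model(torch.from_numpy(x_grid).float()).numpy()
    fig, ax = plt.subplots()
    ax.plot(x_grid, y_pred, color="tab:blue", label="prediction")
    ax.scatter(x_points, y_points, s=2, color="tab:orange", label="sin(x)")
    ax.legend()
    return fig, ax

# Hypothetical stand-in network and sample points, for illustration only
net = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
lin_test = np.arange(0.0, 2 * np.pi, 0.01)[:, np.newaxis]
x_pts = np.random.rand(500) * 2 * np.pi
fig, ax = plot_predictions(net, lin_test, x_pts, np.sin(x_pts))
```

Call `plt.show()` afterwards to open a window, or `fig.savefig(...)` to write the chart to a file.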

Maybe we can do better if we train the model on the data above before running it. Here is how we do that:

				
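```python
for epoch in range(10):  # one epoch = one full pass over the training data
    for train, expected in train_dataloader:
        train = train.type(torch.float32)
        expected = expected.type(torch.float32)

        # Clear the gradients left over from the previous batch
        optimizer.zero_grad()

        # Feed the data into our model
        y_pred = model(train)

        # Calculate the loss (how far off the model is from the expected result),
        # backpropagate it to get a gradient for every weight, then let the
        # optimizer adjust the weights using those gradients
        loss = criterion(y_pred, expected)
        loss.backward()
        optimizer.step()

    print(f"epoch {epoch + 1}: last batch loss {loss.item():.4f}")
```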
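The three calls at the heart of the loop can be checked by hand. This minimal sketch uses a hypothetical single `nn.Linear(1, 1)` layer in place of the full model and confirms two things: `nn.MSELoss(reduction='mean')` is the average of the squared differences, (1/N) Σ (ŷᵢ − yᵢ)², and with plain SGD (no momentum or weight decay, which are the defaults) `optimizer.step()` moves each parameter by `-lr * grad`, where the gradient was filled in by `loss.backward()`.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(1, 1)  # hypothetical stand-in for the full model
criterion = nn.MSELoss(reduction="mean")
optimizer = torch.optim.SGD(layer.parameters(), lr=1e-3)

x = torch.rand(8, 1) * 2 * math.pi
target = torch.sin(x)

optimizer.zero_grad()
pred = layer(x)
loss = criterion(pred, target)

# MSE with reduction='mean' is the average of the squared differences
manual_loss = ((pred - target) ** 2).mean()
print(torch.allclose(loss, manual_loss))  # True

# backward() stores d(loss)/d(parameter) in each parameter's .grad
loss.backward()

# Plain SGD: new_value = old_value - lr * gradient
expected_weight = layer.weight.detach() - 1e-3 * layer.weight.grad
expected_bias = layer.bias.detach() - 1e-3 * layer.bias.grad
optimizer.step()
print(torch.allclose(layer.weight.detach(), expected_weight))  # True
```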
				
			

Now let's run it again and see what we get!

				
```python
# Same evenly spaced test grid as before, now fed through the trained model
lin_test = np.arange(0.0, 2*np.pi, 0.01)[:, np.newaxis]
with torch.no_grad():
    y_1 = model(torch.from_numpy(lin_test).float())
```
				
			

Much better: the model's predictions now closely follow the sine wave.
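“Much better” can also be put into numbers. The post creates `val_dataloader` but never uses it; the hypothetical helpers `evaluate` and `max_abs_error` below measure the average loss on held-out data and the largest gap between the model and `np.sin` on the evenly spaced test grid. A small untrained stand-in network keeps the example self-contained; pass the trained `model`, `val_dataloader`, and `criterion` to get real numbers.

```python
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate(model, loader, criterion):
    """Average loss over every sample in `loader`, with gradients switched off."""
    was_training = model.training
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.float(), yb.float()
            # criterion returns a per-batch mean, so weight it by the batch size
            total += criterion(model(xb), yb).item() * xb.shape[0]
            count += xb.shape[0]
    model.train(was_training)  # restore the previous mode
    return total / count

def max_abs_error(model):
    """Largest gap between the model and np.sin on the evenly spaced test grid."""
    lin_test = np.arange(0.0, 2 * np.pi, 0.01)[:, np.newaxis]
    with torch.no_grad():
        y_pred = model(torch.from_numpy(lin_test).float()).numpy()
    return float(np.abs(y_pred - np.sin(lin_test)).max())

# Hypothetical stand-in network and held-out set, for illustration only
net = nn.Sequential(nn.Linear(1, 16), nn.ReLU(), nn.Linear(16, 1))
x_val = torch.rand(1000, 1) * 2 * np.pi
val_loader = DataLoader(TensorDataset(x_val, torch.sin(x_val)), batch_size=64)
print(evaluate(net, val_loader, nn.MSELoss()), max_abs_error(net))
```

Weighting each batch by its size matters because the last batch here holds only 40 samples; a plain average of batch losses would over-count it.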

Although this is a very basic example, the steps used here (define a model, generate data, train it, then test it) are very similar to those required for much more advanced use cases such as image recognition or other classification tasks.
