Canvas JavaScript

Training a 2D Perceptron

This JavaScript neural network program demonstrates how to train a single perceptron, which acts as a linear classifier, using the perceptron learning algorithm.

PerceptronLearning2D.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <!-- Declare the encoding explicitly instead of leaving it to the browser to guess -->
    <meta charset="utf-8">
    <title>XoaX.net's Javascript Perceptron Learning</title>
    <script type="text/javascript" src="PerceptronLearning2D.js"></script>
  </head>
  <!-- Initialize() sets up the canvas wrapper and builds the first dataset -->
  <body onload="Initialize()">
    <!-- The script looks this canvas up by its id -->
    <canvas id="idCanvas" width="640" height="640" style="background-color: #F0F0F0;"></canvas>
    <hr />
    <button style="clear:both;" onclick="Reset()">Reset</button>
    <!-- Play/Pause toggle; the script relabels this button through its id -->
    <button id="idPlay" onclick="PlayPause()">Play</button>
  </body>
</html>

PerceptronLearning2D.js

// Animation state shared by the page's event handlers and the timer-driven Update() loop
var gidTimeOutID = null;   // id of the pending setTimeout(Update, ...) call, or null when paused
var gbDoneLearning = true; // stays true through a whole pass only if no point adjusted the perceptron

// The drawing surface, the labelled dataset, and the perceptron being trained
var gqCanvas = null;
var gqData = null;
var gqPerceptron = null;

// Index of the point currently presented to the perceptron, and the number of points in the dataset
var giPointIndex = 0;
var giPointCount = 0;

function Initialize() {
	// Invoked from <body onload>: bind the drawing surface once, then let Reset()
	// build a fresh dataset and perceptron and paint the first frame.
	var qNewCanvas = new CCanvas();
	gqCanvas = qNewCanvas;
	Reset();
}

function Reset() {
	// Restart the training pass at the first point with a clean "no updates yet" flag
	giPointIndex = 0;
	gbDoneLearning = true;
	// A fresh random dataset and an untrained perceptron
	gqData = new CClassifiedDataSet();
	gqPerceptron = new CPerceptron2D();
	giPointCount = gqData.GetPointCount();
	// Cancel any pending animation step, then show the grid, axes, points and initial regions
	Pause();
	Redraw();
}

function PlayPause() {
	// A pending timeout id means the animation is running, so a click stops it; otherwise a click starts it.
	var bRunning = (gidTimeOutID != null);
	if (!bRunning) {
		Play();
		return;
	}
	Pause();
}

function Pause() {
	// Cancel the pending Update() call, if any (clearTimeout ignores null), and record that nothing is scheduled
	clearTimeout(gidTimeOutID);
	gidTimeOutID = null;
	// Relabel the toggle button so the next click resumes playback
	document.getElementById("idPlay").innerHTML = "Play";
}

function Play() {
	// Schedule the first perceptron update; Update() reschedules itself until learning
	// finishes or Pause() cancels the pending call.
	gidTimeOutID = setTimeout(Update, 10);
	// Relabel the toggle button so the next click pauses
	document.getElementById("idPlay").innerHTML = "Pause";
}

function Update() {
	// One animation step: present the current data point to the perceptron, redraw the scene with
	// that point highlighted, then advance. A full pass over the data in which no point adjusts the
	// perceptron means every point is classified correctly, so the animation stops. Otherwise the
	// next step is scheduled.
	if (!(giPointCount > 0)) {
		// Empty dataset: there is no point to present, so stop instead of passing undefined to Learn()
		Pause();
		return;
	}
	var daP = gqData.GetPoint(giPointIndex);
	// If the perceptron learned from this point, note that we need another pass
	if (gqPerceptron.Learn(daP)) {
		gbDoneLearning = false;
	}
	Redraw();
	// Highlight the point that is under consideration
	gqCanvas.DrawPoint(daP, 5, "black");
	++giPointIndex;
	if (giPointIndex >= giPointCount) {
		// End of a pass over the dataset
		giPointIndex = 0;
		if (gbDoneLearning) {
			// No point caused an update: stop, and redraw without the highlighted point
			Pause();
			Redraw();
			return;
		}
		// Otherwise reset the flag for the next pass
		gbDoneLearning = true;
	}
	// Schedule the next step
	gidTimeOutID = setTimeout(Update, 10);
}

function Redraw() {
	// Background layers first: wipe the canvas, then grid lines, then the axes on top of the grid
	var qCanvas = gqCanvas;
	qCanvas.Clear();
	qCanvas.DrawGrid();
	qCanvas.DrawAxes();
	// Foreground layers: the labelled points, then the perceptron's shaded class regions
	// (negative class in red, positive class in lime for both)
	gqData.DrawPoints(qCanvas, "red", "lime");
	gqPerceptron.DrawRegions(qCanvas, "red", "lime");
}

class CCanvas {
	// Wraps the 2D context of a <canvas> element and draws in graph coordinates: the origin is at the
	// centre of the canvas, x increases to the right and y increases upward, one unit per pixel.
	// sCanvasId: id of the canvas element to draw on; defaults to the page's "idCanvas".
	constructor(sCanvasId = "idCanvas") {
		var qCanvas = document.getElementById(sCanvasId);
		this.mqContext = qCanvas.getContext("2d");
		// Canvas size in pixels, read from the element instead of being hard-coded
		this.mdaSize = [qCanvas.width, qCanvas.height];
		// Graph area with a 20-pixel margin on every side (a 640x640 canvas gives 600x600, i.e. -300..300)
		var iMargin = 20;
		this.mdaGraph = [
			Math.max(0, this.mdaSize[0] - 2*iMargin),
			Math.max(0, this.mdaSize[1] - 2*iMargin)
		];
	}

	// Convert a graph point [x, y] to canvas pixel coordinates [px, py] (pixel y grows downward).
	PointToPixel(daP) {
		return [this.mdaSize[0]/2 + daP[0], this.mdaSize[1]/2 - daP[1]];
	}
	// Erase the whole canvas.
	Clear() {
		this.mqContext.clearRect(0, 0, this.mdaSize[0], this.mdaSize[1]);
	}
	// Draw a line segment from graph point daP1 to daP2 with a filled arrowhead at daP2.
	DrawArrow(daP1, daP2, sColor, dAlpha = 1) {
		this.DrawLine(daP1, daP2, sColor, dAlpha);
		// The arrowhead is filled with the globalAlpha that DrawLine just set
		this.DrawArrowheadPixel(this.PointToPixel(daP1), this.PointToPixel(daP2), 5, sColor);
	}
	// Fill an equilateral triangle centred on pixel daPix2, with one vertex pointing along the
	// direction from daPix1 to daPix2. dSize is the distance in pixels from the centre to each vertex.
	DrawArrowheadPixel(daPix1, daPix2, dSize, sColor) {
		this.mqContext.fillStyle = sColor;
		this.mqContext.beginPath();
		// The direction of the arrow is P1 to P2 in pixels
		var dInitAngle = Math.atan2(daPix2[1] - daPix1[1], daPix2[0] - daPix1[0]);
		this.mqContext.moveTo(daPix2[0] + dSize*Math.cos(dInitAngle), daPix2[1] + dSize*Math.sin(dInitAngle));
		// The other two vertices are 120 and 240 degrees further around
		for (var i = 1; i < 3; ++i) {
			var dAngle = dInitAngle + (2*Math.PI*i)/3;
			this.mqContext.lineTo(daPix2[0] + dSize*Math.cos(dAngle), daPix2[1] + dSize*Math.sin(dAngle));
		}
		this.mqContext.closePath();
		this.mqContext.fill();
	}
	// Draw the y and x axes across the graph area, with arrowheads at their positive ends.
	DrawAxes() {
		var dHalfX = this.mdaGraph[0]/2;
		var dHalfY = this.mdaGraph[1]/2;
		this.DrawArrow([0, -dHalfY], [0, dHalfY], "black");
		this.DrawArrow([-dHalfX, 0], [dHalfX, 0], "black");
	}
	// Draw a light grid of iDivisions x iDivisions cells covering the graph area (20 by default).
	DrawGrid(iDivisions = 20) {
		var sColor = "gray";
		var dAlpha = .25;
		var dHalfX = this.mdaGraph[0]/2;
		var dHalfY = this.mdaGraph[1]/2;
		var dDeltaX = this.mdaGraph[0]/iDivisions;
		var dDeltaY = this.mdaGraph[1]/iDivisions;
		for (var i = 0; i <= iDivisions; ++i) {
			var dY = -dHalfY + i*dDeltaY;
			var dX = -dHalfX + i*dDeltaX;
			// One horizontal and one vertical grid line per step
			this.DrawLine([-dHalfX, dY], [dHalfX, dY], sColor, dAlpha);
			this.DrawLine([dX, -dHalfY], [dX, dHalfY], sColor, dAlpha);
		}
	}
	// Stroke a line segment between two graph points.
	DrawLine(daP1, daP2, sColor, dAlpha = 1) {
		this.mqContext.globalAlpha = dAlpha;
		var daPix1 = this.PointToPixel(daP1);
		var daPix2 = this.PointToPixel(daP2);
		this.mqContext.strokeStyle = sColor;
		this.mqContext.beginPath();
		this.mqContext.moveTo(daPix1[0], daPix1[1]);
		this.mqContext.lineTo(daPix2[0], daPix2[1]);
		this.mqContext.stroke();
	}
	// Fill a disc of dRadius pixels centred on graph point daP (only daP[0] and daP[1] are read).
	DrawPoint(daP, dRadius, sColor, dAlpha = 1) {
		this.mqContext.globalAlpha = dAlpha;
		this.mqContext.fillStyle = sColor;
		this.mqContext.beginPath();
		var daPixel = this.PointToPixel(daP);
		this.mqContext.arc(daPixel[0], daPixel[1], dRadius, 0, 2.0*Math.PI, true);
		this.mqContext.fill();
	}
	// Fill and outline the closed polygon through the graph points daPts, then reset globalAlpha to 1.
	DrawPolygon(daPts, sColor, dFillAlpha = 1, dStrokeAlpha = 1) {
		var daPixel = this.PointToPixel(daPts[0]);
		this.mqContext.beginPath();
		this.mqContext.moveTo(daPixel[0], daPixel[1]);
		for (var i = 1; i < daPts.length; ++i) {
			daPixel = this.PointToPixel(daPts[i]);
			this.mqContext.lineTo(daPixel[0], daPixel[1]);
		}
		this.mqContext.closePath();
		this.mqContext.globalAlpha = dFillAlpha;
		this.mqContext.fillStyle = sColor;
		this.mqContext.fill();
		this.mqContext.globalAlpha = dStrokeAlpha;
		this.mqContext.strokeStyle = sColor;
		this.mqContext.lineWidth = 1;
		this.mqContext.stroke();
		this.mqContext.globalAlpha = 1;
	}
}

class CPerceptron2D {
	// A single perceptron in the plane. Its decision boundary is the line A*x + B*y + C = 0,
	// stored as mdaLine = [A, B, C]; the sign of A*x + B*y + C is the predicted class of a point.
	constructor() {
		// Arbitrary starting line x + y + 1 = 0
		this.mdaLine = [1.0, 1.0, 1.0];
	}
	// Apply one step of the perceptron learning rule to the labelled point daP = [x, y, label]
	// (the dataset uses labels -1 and 1). Returns true if the point was misclassified and the line
	// was adjusted (the perceptron is still learning from this point), false if it was already correct.
	Learn(daP) {
		// Get the calculated class value of the point, using the current classifier line
		var dClassValue = this.mdaLine[0]*daP[0] + this.mdaLine[1]*daP[1] + this.mdaLine[2];
		// The point is classified correctly when the calculated value and the actual class daP[2]
		// have the same sign; a value of exactly zero is accepted for either class.
		var bIsGood = (dClassValue*daP[2] >= 0);
		if (bIsGood) {
			return false;
		}
		// Move the line toward the misclassified point: [A, B, C] += label*[x, y, 1]
		this.mdaLine[0] += daP[2]*daP[0];
		this.mdaLine[1] += daP[2]*daP[1];
		this.mdaLine[2] += daP[2];
		return true;
	}
	// Scale [A, B, C] so that (A, B) is a unit vector; A*x + B*y + C is then the signed distance of
	// (x, y) from the line. A degenerate line with A = B = 0 is left unchanged to avoid dividing by zero.
	NormalizeLine() {
		var dMag = Math.hypot(this.mdaLine[0], this.mdaLine[1]);
		if (dMag === 0) {
			return;
		}
		this.mdaLine[0] /= dMag;
		this.mdaLine[1] /= dMag;
		this.mdaLine[2] /= dMag;
	}

	// Shade the graph square [-300, 300] x [-300, 300] by predicted class: the part where
	// A*x + B*y + C < 0 in sNegColor and the part where it is > 0 in sPosColor, each drawn as a polygon
	// with 5% fill opacity. Vertices lying on the line belong to both polygons.
	DrawRegions(qCanvas, sNegColor, sPosColor) {
		var dX = 0;
		var dY = 0;
		var daLine = this.mdaLine;
		// Corners of the square in counter-clockwise order, starting in the first quadrant
		var daCorners = [[300, 300], [-300, 300], [-300, -300], [300, -300]];
		// Find the intersections with each side (x = -+300, y = -+300)
		var daIntersections = [];
		// Make sure that B != 0
		if (daLine[1] != 0) {
			// x = -300, 300, y = (-Ax - C)/B, if B != 0
			dX = -300;
			dY = (-daLine[0]*dX - daLine[2])/daLine[1];
			if (dY >= -300 && dY <= 300) {
				daIntersections.push([dX, dY]);
			}
			dX = 300;
			dY = (-daLine[0]*dX - daLine[2])/daLine[1];
			if (dY >= -300 && dY <= 300) {
				daIntersections.push([dX, dY]);
			}
		}
		// Make sure that A != 0
		if (daLine[0] != 0) {
			// y = -300, 300, x = (-By - C)/A, if A != 0
			dY = -300;
			dX = (-daLine[1]*dY - daLine[2])/daLine[0];
			if (dX >= -300 && dX <= 300) {
				daIntersections.push([dX, dY]);
			}
			dY = 300;
			dX = (-daLine[1]*dY - daLine[2])/daLine[0];
			if (dX >= -300 && dX <= 300) {
				daIntersections.push([dX, dY]);
			}
		}
		// Remove any points that are near enough to be considered the same
		for (var i = 0; i < daIntersections.length - 1; ++i) {
			var j = i + 1;
			while (j < daIntersections.length) {
				// If the points are close enough remove the jth one, because it is probably a duplicate corner vertex.
				var dDist = Math.sqrt(Math.pow(daIntersections[i][0]-daIntersections[j][0], 2)+Math.pow(daIntersections[i][1]-daIntersections[j][1], 2));
				if (dDist < 1.0e-3) {
					daIntersections.splice(j, 1);
				} else { // Otherwise, increment the index
					++j;
				}
			}
		}
		// With fewer than two distinct intersections the line misses the square (or only touches a
		// corner), so the whole square lies on one side of it. That side's class is the sign of the
		// line's value at the origin, which is C. Nothing is drawn for the degenerate line A = B = C = 0.
		if (daIntersections.length < 2) {
			if (daLine[2] < 0) {
				qCanvas.DrawPolygon(daCorners, sNegColor, .05, 1);
			} else if (daLine[2] > 0) {
				qCanvas.DrawPolygon(daCorners, sPosColor, .05, 1);
			}
			return;
		}
		// After de-duplication there are normally exactly two intersections; only the first two are inserted.
		var daPos = [];
		var daNeg = [];
		var daVertices = daCorners.slice();
		// Insert the intersections into the vertices array so the boundary stays in counter-clockwise order
		var iInsertions = 0;
		while (iInsertions < 2) {
			var bNotInserted = true;
			var iFirst = 0;
			var dInsertAngle = Math.atan2(daIntersections[iInsertions][1], daIntersections[iInsertions][0]);
			while (bNotInserted) {
				// Get the angles for each consecutive pair of vertices, unwrapping across -pi/pi
				var dFirstAngle = Math.atan2(daVertices[iFirst][1], daVertices[iFirst][0]);
				var iSecond = ((iFirst+1) % daVertices.length);
				var dSecondAngle = Math.atan2(daVertices[iSecond][1], daVertices[iSecond][0]);
				dSecondAngle = (dFirstAngle < dSecondAngle) ? dSecondAngle : (dSecondAngle + 2*Math.PI);
				if ((dInsertAngle > dFirstAngle && dInsertAngle <= dSecondAngle) ||
						(dInsertAngle + 2*Math.PI > dFirstAngle && dInsertAngle + 2*Math.PI <= dSecondAngle)) {
					daVertices.splice(iSecond, 0, daIntersections[iInsertions]);
					++iInsertions;
					bNotInserted = false;
				} else {
					++iFirst;
				}
			}
		}
		// Run through the vertices and put them into the positive and negative arrays
		var dLineMag = Math.sqrt(Math.pow(daLine[0], 2)+Math.pow(daLine[1], 2));
		var dNormalized = [daLine[0]/dLineMag, daLine[1]/dLineMag, daLine[2]/dLineMag];
		for (var k = 0; k < daVertices.length; ++k) {
			var dDistToLine = dNormalized[0]*daVertices[k][0] + dNormalized[1]*daVertices[k][1] + dNormalized[2];
			if (Math.abs(dDistToLine) < 1.0e-5) { // The points on the line
				daPos.push(daVertices[k]);
				daNeg.push(daVertices[k]);
			} else if (dDistToLine > 0.0) { // Positive values
				daPos.push(daVertices[k]);
			} else { // Negative values
				daNeg.push(daVertices[k]);
			}
		}
		qCanvas.DrawPolygon(daNeg, sNegColor, .05, 1);
		qCanvas.DrawPolygon(daPos, sPosColor, .05, 1);
	}
}

class CClassifiedDataSet {
	// A random, linearly separable set of labelled 2D points. Each point is [x, y, label] with label -1 or 1,
	// assigned by the sign of a hidden classifier f(x, y) = A*x + B*y + C (label 1 when f >= 0).
	// iPointCount: number of points to generate (default 100).
	// dHalfExtent: both coordinates are drawn uniformly from [-dHalfExtent, dHalfExtent) (default 300).
	constructor(iPointCount = 100, dHalfExtent = 300) {
		// (A, B) is a random unit normal, so |C| is the boundary's distance from the origin. Limiting |C|
		// to 2/3 of the half extent (200 for the default) keeps the boundary crossing the sampled square.
		var dAngle = 2*Math.PI*Math.random();
		var dOffsetLimit = dHalfExtent*2/3;
		var daClassifier = [Math.cos(dAngle), Math.sin(dAngle), Math.random()*(2*dOffsetLimit) - dOffsetLimit];
		// Generate random points for the dataset
		this.mdaaPoints = [];
		for (var i = 0; i < iPointCount; ++i) {
			// The third component will be the classification value -1 or 1.
			var daPoint = [0, 0, 0];
			for (var j = 0; j < 2; ++j) {
				daPoint[j] = Math.random()*(2*dHalfExtent) - dHalfExtent;
			}
			var dValue = daClassifier[0]*daPoint[0] + daClassifier[1]*daPoint[1] + daClassifier[2];
			daPoint[2] = (dValue < 0.0) ? -1.0 : 1.0;
			this.mdaaPoints.push(daPoint);
		}
	}

	// Draw every point as a 2-pixel disc: sPosColor for a positive label, sNegColor otherwise.
	DrawPoints(qCanvas, sNegColor, sPosColor) {
		for (var i = 0; i < this.mdaaPoints.length; ++i) {
			var daPoint = this.mdaaPoints[i];
			qCanvas.DrawPoint(daPoint, 2, (daPoint[2] > 0) ? sPosColor : sNegColor);
		}
	}
	// Number of points in the dataset.
	GetPointCount() {
		return this.mdaaPoints.length;
	}
	// The i-th point [x, y, label]; the stored array itself is returned, not a copy.
	GetPoint(i) {
		return this.mdaaPoints[i];
	}
}
 

Output

 
 

© 2007–2025 XoaX.net LLC. All rights reserved.