nnet
This commit is contained in:
parent
5a68003d16
commit
c4020080b3
@ -1,5 +1,223 @@
|
|||||||
{
|
{
|
||||||
"cells": [
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"plotgrid (generic function with 1 method)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"using Plots; qwt()\n",
|
||||||
|
"default(size=(500,300), leg=false)\n",
|
||||||
|
"\n",
|
||||||
|
"# creates x/y vectors which can define a grid in a zig-zag pattern\n",
|
||||||
|
"function gridxy(lim, n::Int)\n",
|
||||||
|
" xs = linspace(lim..., n)\n",
|
||||||
|
" xypairs = vec([(x,y) for x in vcat(xs,reverse(xs)), y in xs])\n",
|
||||||
|
" Plots.unzip(xypairs)\n",
|
||||||
|
"end\n",
|
||||||
|
"\n",
|
||||||
|
"# plot a grid from x/y vectors\n",
|
||||||
|
"function plotgrid(x, y)\n",
|
||||||
|
" plot([x y], [y x], c=:black)\n",
|
||||||
|
"end"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# The problem... can we classify the functions?"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"f2 (generic function with 1 method)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# these are the functions we want to classify\n",
|
||||||
|
"f1(x) = 0.6sin(10x) + 0.1\n",
|
||||||
|
"f2(x) = f1(x) - 0.2"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Build a neural net"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 941,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"using OnlineAI\n",
|
||||||
|
"\n",
|
||||||
|
"# first create a neural net to separate the functions\n",
|
||||||
|
"numInputs = 2\n",
|
||||||
|
"numOutputs = 1\n",
|
||||||
|
"hiddenLayerStructure = [3,3,2]\n",
|
||||||
|
"net = buildClassificationNet(numInputs, numOutputs, hiddenLayerStructure; hiddenActivation = TanhActivation())\n",
|
||||||
|
"\n",
|
||||||
|
"# show the network\n",
|
||||||
|
"viz = visualize(net);"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# On to the fun..."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 942,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# pick the plotting limits\n",
|
||||||
|
"lim = (-1,1)\n",
|
||||||
|
"default(xlim = lim, ylim = lim)\n",
|
||||||
|
"\n",
|
||||||
|
"# show the grid\n",
|
||||||
|
"n = 40\n",
|
||||||
|
"gridx, gridy = gridxy(lim, n)\n",
|
||||||
|
"p = plotgrid(gridx, gridy)\n",
|
||||||
|
"\n",
|
||||||
|
"# show the funcs\n",
|
||||||
|
"funcs = [f1, f2]\n",
|
||||||
|
"plot!(funcs, lim..., w=3)\n",
|
||||||
|
"\n",
|
||||||
|
"# kick off an animation... we can save frames whenever we want, lets save the start\n",
|
||||||
|
"anim = Animation()\n",
|
||||||
|
"frame(anim)\n",
|
||||||
|
"\n",
|
||||||
|
"# open a gui window\n",
|
||||||
|
"gui()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Yikes... that looks tricky to separate..."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 945,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"activateHidden (generic function with 1 method)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 945,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# function to sample from x's\n",
|
||||||
|
"xsample() = rand(Distributions.Uniform(lim...)) \n",
|
||||||
|
"\n",
|
||||||
|
"# pick one of the functions at random, sample from the x line, then update the\n",
|
||||||
|
"# neural net with [x, f(x)] as the inputs\n",
|
||||||
|
"function sampleAndUpdate()\n",
|
||||||
|
" f = sample(funcs)\n",
|
||||||
|
" x = xsample()\n",
|
||||||
|
" y = float(f == f1)\n",
|
||||||
|
" update!(net, Float64[x, f(x)], [y])\n",
|
||||||
|
"end\n",
|
||||||
|
"\n",
|
||||||
|
"# take x matrix and convert to the first layer's activation\n",
|
||||||
|
"function activateHidden(net, x)\n",
|
||||||
|
" input = x\n",
|
||||||
|
" for layer in net.layers[1:end-1]\n",
|
||||||
|
" proj = Array(nrows(x), layer.nout)\n",
|
||||||
|
" for i in 1:nrows(x)\n",
|
||||||
|
" OnlineAI.forward!(layer, row(proj,i), false)\n",
|
||||||
|
" row!(proj, i, layer.a)\n",
|
||||||
|
" end\n",
|
||||||
|
" input = proj\n",
|
||||||
|
" end\n",
|
||||||
|
" vec(proj[:,1]), vec(proj[:,2])\n",
|
||||||
|
"end "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 946,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"ename": "LoadError",
|
||||||
|
"evalue": "LoadError: MethodError: `convert` has no method matching convert(::Type{Array{T,N}}, ::Int64, ::Int64)\nThis may have arisen from a call to the constructor Array{T,N}(...),\nsince type constructors fall back to convert methods.\nClosest candidates are:\n convert{T,N}(::Type{Array{T,N}}, !Matched::DataArrays.DataArray{T,N}, ::Any)\n convert{T,R,N}(::Type{Array{T,N}}, !Matched::DataArrays.PooledDataArray{T,R,N}, ::Any)\n Array{T}(!Matched::Type{T}, ::Integer)\n ...\nwhile loading In[946], in expression starting on line 8",
|
||||||
|
"output_type": "error",
|
||||||
|
"traceback": [
|
||||||
|
"LoadError: MethodError: `convert` has no method matching convert(::Type{Array{T,N}}, ::Int64, ::Int64)\nThis may have arisen from a call to the constructor Array{T,N}(...),\nsince type constructors fall back to convert methods.\nClosest candidates are:\n convert{T,N}(::Type{Array{T,N}}, !Matched::DataArrays.DataArray{T,N}, ::Any)\n convert{T,R,N}(::Type{Array{T,N}}, !Matched::DataArrays.PooledDataArray{T,R,N}, ::Any)\n Array{T}(!Matched::Type{T}, ::Integer)\n ...\nwhile loading In[946], in expression starting on line 8",
|
||||||
|
"",
|
||||||
|
" in activateHidden at In[945]:17"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# update net with new samples\n",
|
||||||
|
"for i in 1:10000\n",
|
||||||
|
" sampleAndUpdate()\n",
|
||||||
|
"end\n",
|
||||||
|
"\n",
|
||||||
|
"# update the plot... project each series to the first hidden layer and reset the data\n",
|
||||||
|
"x = linspace(lim..., 100)\n",
|
||||||
|
"p[1] = activateHidden(net, hcat(gridx, gridy))\n",
|
||||||
|
"p[2] = activateHidden(net, hcat(gridy, gridx))\n",
|
||||||
|
"p[3] = activateHidden(net, hcat(x, map(f1,x)))\n",
|
||||||
|
"p[4] = activateHidden(net, hcat(x, map(f2,x)))\n",
|
||||||
|
"\n",
|
||||||
|
"# show/update the plot\n",
|
||||||
|
"gui(p)\n",
|
||||||
|
"frame(anim);"
|
||||||
|
]
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
@ -8,30 +226,20 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"using Plots, DataFrames, OnlineStats, OnlineAI\n",
|
"# build an animated gif\n",
|
||||||
"default(size=(500,300))\n",
|
"gif(anim, fps = 10)"
|
||||||
"df = readtable(joinpath(Pkg.dir(\"Plots\"), \"examples\", \"meetup\", \"winequality-white.csv\"), separator=';');"
|
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 940,
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
"collapsed": false
|
||||||
"scrolled": false
|
|
||||||
},
|
},
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"(xmeta,ymeta) = (nothing,nothing)\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"using Plots; gadfly()\n",
|
"# update the net representation with weights, etc\n",
|
||||||
"p = plot(10);"
|
"update!(viz)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -61,6 +269,51 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": []
|
"source": []
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": true
|
||||||
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": null,
|
"execution_count": null,
|
||||||
|
|||||||
@ -133,7 +133,7 @@ function image_comparison_tests(pkg::Symbol, idx::Int; debug = false, sigma = [1
|
|||||||
return true
|
return true
|
||||||
|
|
||||||
catch ex
|
catch ex
|
||||||
warn("Image did not match reference image $reffn")
|
warn("Image did not match reference image $reffn. err: $ex")
|
||||||
if isinteractive()
|
if isinteractive()
|
||||||
|
|
||||||
# if we're in interactive mode, open a popup and give us a chance to examine the images
|
# if we're in interactive mode, open a popup and give us a chance to examine the images
|
||||||
|
|||||||
BIN
test/refimg/v0.3/gadfly/ref1.png
Normal file
BIN
test/refimg/v0.3/gadfly/ref1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 55 KiB |
Loading…
x
Reference in New Issue
Block a user