Plots.jl/examples/meetup/nnet.ipynb
2015-10-23 13:53:13 -04:00

524 lines
11 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using Plots; immerse()\n",
"default(size=(500,300), leg=false)\n",
"\n",
"# creates x/y vectors which can define a grid in a zig-zag pattern\n",
"function gridxy(lim, n::Int)\n",
" xs = linspace(lim..., n)\n",
" xypairs = vec([(x,y) for x in vcat(xs,reverse(xs)), y in xs])\n",
" Plots.unzip(xypairs)\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The problem... can we classify the functions?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# these are the functions we want to classify\n",
"scalar = 5 # larger is harder... start with 3\n",
"f1(x) = 0.6sin(scalar * x) + 0.1\n",
"f2(x) = f1(x) - 0.2\n",
"\n",
"# our target function is ∈ {-1,1}\n",
"target(f) = f == f1 ? 1.0 : -1.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# On to the fun..."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# pick the plotting limits\n",
"lim = (-1,1)\n",
"funcs = [f1, f2]\n",
"n = 40\n",
"gridx, gridy = gridxy(lim, n)\n",
"default(xlim = lim, ylim = lim)\n",
"\n",
"function initialize_plot(funcs, lim, gridx, gridy; kw...)\n",
" # show the grid\n",
" plot([gridx gridy], [gridy gridx], c=:black, kw...)\n",
"\n",
" # show the funcs\n",
" plot!(funcs, lim..., l=(4,[:blue :red]))\n",
"end\n",
"\n",
"# kick off an animation... we can save frames whenever we want, so let's save the starting frame\n",
"function initialize_animation()\n",
" anim = Animation()\n",
" frame(anim)\n",
" anim\n",
"end\n",
"\n",
"# let's see what we're dealing with...\n",
"p = initialize_plot(funcs, lim, gridx, gridy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# That looks tricky... let's build a neural net!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using OnlineAI\n",
"net = buildTanhClassificationNet(\n",
" 2, # number of inputs\n",
" 1, # number of outputs\n",
" [2], # hidden layers structure\n",
"# params = NetParams(gradientModel = SGDModel(η=1e-5))\n",
"params = NetParams(gradientModel = AdadeltaModel(η=1e-3, ρ=0.98))\n",
"# params = NetParams(gradientModel = AdagradModel(η=1e-1))\n",
")\n",
"\n",
"# take the x matrix and forward each row through every layer except the last, returning the two hidden activations\n",
"function activateHidden(net, x)\n",
" @assert net.layers[end].nin == 2\n",
" proj = zeros(nrows(x), 2)\n",
" for i in 1:nrows(x)\n",
" data = row(x,i)\n",
" for layer in net.layers[1:end-1]\n",
" OnlineAI.forward!(layer, data, false)\n",
" data = layer.a\n",
" end\n",
" row!(proj, i, data)\n",
" end\n",
" vec(proj[:,1]), vec(proj[:,2])\n",
"end "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Update our model and the visualization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"p = initialize_plot(funcs, lim, gridx, gridy)\n",
"anim = initialize_animation()\n",
"gui()\n",
"\n",
"progressviz = track_progress(net, fields=[:x,:Σ,:a], size=(800,800), m=2, w=0);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"iterations_per_frame = 1000\n",
"total_frames = 200\n",
"dist = Distributions.Uniform(lim...)\n",
"\n",
"for frm in 1:total_frames\n",
" # pick one of the functions at random, sample from the x line, then update the\n",
" # neural net with [x, f(x)] as the inputs\n",
" for i in 1:iterations_per_frame\n",
" f = sample(funcs)\n",
" x = rand(dist)\n",
" y = target(f)\n",
" update!(net, Float64[x, f(x)], [y])\n",
" end\n",
" \n",
" # update the progress visualization\n",
" update!(progressviz, true, show=false)\n",
"\n",
" # update the plot... project each series to the first hidden layer and reset the data\n",
" # NOTE: this works because `getindex` and `setindex!` are overloaded to get/set the underlying plot series data\n",
" x = linspace(lim..., 50)\n",
" p[1] = activateHidden(net, hcat(gridx, gridy))\n",
" p[2] = activateHidden(net, hcat(gridy, gridx))\n",
" p[3] = activateHidden(net, hcat(x, map(f1,x)))\n",
" p[4] = activateHidden(net, hcat(x, map(f2,x)))\n",
"\n",
" # show/update the plot\n",
" gui(p)\n",
" frame(anim)\n",
"end\n",
"\n",
"# displays the progress\n",
"progressviz.subplt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# show stacked and linked histograms of the predictions for each class\n",
"testn = 100\n",
"xs = linspace(lim..., testn)\n",
"x1, x2 = [hcat(xs,map(f,xs)) for f in funcs]\n",
"# testx = vcat(hcat(xs,map(f1,xs)), hcat(xs,map(f2,xs)))\n",
"# testy = vcat(ones(testn), -ones(testn))\n",
"y1, y2 = ones(testn), -ones(testn)\n",
"yhat1, yhat2 = [vec(predict(net, x)) for x in (x1,x2)]\n",
"subplot(histogram(yhat1), histogram(yhat2), nc=1, linkx=true, title=[\"f1 prediction\", \"f2 prediction\"], xlim=lim)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"plot()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"plot(xs, hcat(map(f1,xs), map(f2,xs), reshape(yhat,testn,2)), leg=true, w=[2 2 5 5])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Animate!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"gif(anim, fps = 20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Network viz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# show the network (uses Qwt, visualize isn't available unless you import it)\n",
"import Qwt\n",
"viz = visualize(net);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# update the net representation with weights, etc\n",
"update!(viz)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# testing..."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"selection[3][2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"p[4][2] |> length"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"gui(progressviz.subplt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using Plots\n",
"p1 = plot(rand(20))\n",
"p2 = plot(rand(10))\n",
"p3 = scatter(rand(100))\n",
"p4 = plot(rand(1000))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"subplot(p1,p2,p3,p4, nr=1, leg=false)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using Plots; immerse()\n",
"p = plot(rand(10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"gui()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"append!(p,1,rand(10))\n",
"gui()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sp = progressviz.subplt.plts[1].o.widget[:minimumSizeHint]()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.4.0",
"language": "julia",
"name": "julia-0.4"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.4.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}