Plots.jl/examples/meetup/nnet.ipynb
Thomas Breloff a3a8cb9368 nnet
2015-10-26 01:40:12 -04:00

643 lines
14 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"ENV[\"PYTHONPATH\"] = joinpath(Pkg.dir(\"Qwt\"), \"src\", \"python\");\n",
"\n",
"using Plots, Distributions; qwt()\n",
"default(size=(500,300), leg=false)\n",
"\n",
"# creates x/y vectors which can define a grid in a zig-zag pattern\n",
"function gridxy(lim, n::Int)\n",
" xs = linspace(lim..., n)\n",
" xypairs = vec([(x,y) for x in vcat(xs,reverse(xs)), y in xs])\n",
" Plots.unzip(xypairs)\n",
"end"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# The problem... can we classify the functions?"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# these are the functions we want to classify\n",
"scalar = 5 # larger is harder\n",
"noise = Distributions.Normal(0, 0.05)\n",
"\n",
"# problem #1... non-overlapping\n",
"f1(x) = 0.6sin(scalar * x) + 0.1 + rand(noise)\n",
"f2(x) = f1(x) - 0.3\n",
"\n",
"# problem #2... overlapping\n",
"# f1(x) = 0.6sin(scalar * x)\n",
"# f2(x) = 0.6sin(scalar * (x+0.1))\n",
"\n",
"# our target function is ∈ {-1,1}\n",
"target(f) = f == f1 ? 1.0 : -1.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# On to the fun..."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# pick the plotting limits\n",
"lim = (-1,1)\n",
"funcs = [f1, f2]\n",
"n = 20\n",
"gridx, gridy = gridxy(lim, n)\n",
"# default(xlim = lim, ylim = lim)\n",
"\n",
"function initialize_plot(funcs, lim, gridx, gridy; kw...)\n",
" # show the grid\n",
" plot([gridx gridy], [gridy gridx], c=:black; kw...)\n",
"\n",
" # show the funcs\n",
" plot!(funcs, lim..., l=(4,[:royalblue :orangered]))\n",
"end\n",
"\n",
"# kick off an animation... we can save frames whenever we want; let's save the starting frame\n",
"function initialize_animation()\n",
" anim = Animation()\n",
" frame(anim)\n",
" anim\n",
"end\n",
"\n",
"# let's see what we're dealing with...\n",
"p = initialize_plot(funcs, lim, gridx, gridy)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Let's build a neural net!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using OnlineAI\n",
"\n",
"# gradientModel = SGDModel(η=1e-4, μ=0.8, λ=0)\n",
"# gradientModel = AdagradModel(η=1e-1)\n",
"# gradientModel = AdadeltaModel(η=0.1, ρ=0.99, λ=0)\n",
"# gradientModel = AdamModel(η=1e-4, λ=1e-8)\n",
"gradientModel = AdaMaxModel(η=1e-4, ρ1=0.9, ρ2=0.9)\n",
"\n",
"# learningRateModel = FixedLearningRate()\n",
"learningRateModel = AdaptiveLearningRate(gradientModel, 2e-2, 0.05, wgt=ExponentialWeighting(30))\n",
"\n",
"function OnlineAI.initialWeights(nin::Int, nout::Int, activation::Activation)\n",
" 0.1randn(nout, nin) / sqrt(nin) + eye(nout, nin)\n",
"end\n",
"\n",
"net = buildTanhClassificationNet(\n",
" 2, # number of inputs\n",
" 1, # number of outputs\n",
" [2,2,2,2,2,2], # hidden layers structure\n",
" params = NetParams(gradientModel = gradientModel)\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Update our model and the visualization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# set up a visualization of the projections\n",
"layers = filter(l -> l.nout == 2, net.layers[1:end-1])\n",
"num_hidden_layers = length(layers)\n",
"plts = [initialize_plot(funcs, lim, gridx, gridy, title=\"Hidden Layer $i\") for i in 1:num_hidden_layers]\n",
"sz = round(Int, sqrt(num_hidden_layers) * 400)\n",
"projectionviz = subplot(plts..., n=num_hidden_layers, size=(sz,sz))\n",
"\n",
"# set up the animation, then show the plots in a window\n",
"anim = initialize_animation()\n",
"gui()\n",
"\n",
"# create another visualization to track the internal progress of the neural net\n",
"progressviz = track_progress(net, fields=[:w,:b,:Σ,:a], size=(num_hidden_layers*300,800), m=2, w=0);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"dist = Distributions.Uniform(lim...)\n",
"# dist = Distributions.Uniform(-0.6,0.6)\n",
"progressgui = false\n",
"\n",
"function test_data(n, lim, funcs)\n",
" xs = linspace(lim..., n)\n",
" x1, x2 = [hcat(xs,map(f,xs)) for f in funcs]\n",
" y1, y2 = ones(n), -ones(n)\n",
" DataPoints(vcat(x1,x2), vcat(y1,y2))\n",
"end\n",
"\n",
"testn = 100\n",
"testdata = test_data(testn, lim, funcs)\n",
"\n",
"function activateHidden(net, layers, x, y, seriesidx, plts)\n",
" n = length(x)\n",
" p = length(layers)\n",
" projx, projy = zeros(n,p), zeros(n,p)\n",
" for i in 1:n\n",
" # feed the data through the neural net\n",
" OnlineAI.forward!(net, [x[i], y[i]])\n",
" \n",
" # grab the net's activations at each layer\n",
" for j in 1:p\n",
" projx[i,j], projy[i,j] = layers[j].Σ\n",
" end\n",
" end\n",
" \n",
" # now we can update the plots\n",
" for j in 1:p\n",
" plts[j][seriesidx] = (vec(projx[:,j]), vec(projy[:,j]))\n",
" end\n",
"end\n",
"\n",
"# final plot to track test error\n",
"errviz = subplot([totalCost(net, testdata) gradientModel.η], m=3, title=[\"Error\" \"η\"], n=2,nc=1, pos=(800,0))\n",
"gui(errviz)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"iterations_per_frame = 1000\n",
"total_frames = 100\n",
"for frm in 1:total_frames\n",
" # pick one of the functions at random, sample from the x line, then update the\n",
" # neural net with [x, f(x)] as the inputs\n",
" for i in 1:iterations_per_frame\n",
" f = sample(funcs)\n",
" x = rand(dist)\n",
" y = target(f)\n",
" update!(net, Float64[x, f(x)], [y])\n",
" end\n",
" \n",
" # update the progress visualization\n",
" update!(progressviz, true, show=progressgui)\n",
" \n",
" # update the error plot\n",
" err = totalCost(net, testdata)\n",
" push!(errviz.plts[1], err)\n",
" update!(learningRateModel, err)\n",
" push!(errviz.plts[2], gradientModel.η)\n",
" gui(errviz)\n",
"\n",
" # update the projections\n",
" x = linspace(lim..., 70)\n",
" for (seriesidx, (x,y)) in enumerate([(gridx,gridy), (gridy,gridx), (x,map(f1,x)), (x,map(f2,x))])\n",
" activateHidden(net, layers, x, y, seriesidx, projectionviz.plts)\n",
" end\n",
" \n",
" # show/update the plot\n",
" gui(projectionviz)\n",
" frame(anim)\n",
" sleep(0.001)\n",
"end\n",
"\n",
"# displays the progress if there's no gui\n",
"progressgui || progressviz.subplt"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# show stacked and linked histograms of the predictions for each class\n",
"xs = OnlineAI.unzip(testdata)[1]\n",
"yhat = predict(net, xs)\n",
"yhat1, yhat2 = yhat[1:testn], yhat[testn+1:end]\n",
"subplot(histogram(yhat1), histogram(yhat2), nc=1, linkx=true, title=[\"f1 prediction\" \"f2 prediction\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"xs = xs[1:testn]\n",
"plot(xs, hcat(map(f1,xs), map(f2,xs), yhat1, yhat2), leg=true,\n",
" line=([2 2 5 5], [:royalblue :orangered], [:solid :solid :dash :dash]))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Animate!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"gif(anim, fps = 10)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# Network viz"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# show the network (uses Qwt, visualize isn't available unless you import it)\n",
"ENV[\"PYTHONPATH\"] = joinpath(Pkg.dir(\"Qwt\"), \"src\", \"python\");\n",
"import Qwt\n",
"viz = visualize(net);"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# update the net representation with weights, etc\n",
"update!(viz)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"metadata": {
"collapsed": true
},
"source": [
"# testing..."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"selection[3][2]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"p[4][2] |> length"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"gui(progressviz.subplt)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"histogram(yhat1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"progressviz.subplt.plts[1].seriesargs[1][:serieshandle][:get_offsets]()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"learningRateModel"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"update!(d,5)\n",
"diff(d)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"using Plots\n",
"p1 = plot(rand(20))\n",
"p2 = plot(rand(10))\n",
"p3 = scatter(rand(100))\n",
"p4 = plot(rand(1000))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"subplot(p1,p2,p3,p4, nr=1, leg=false)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# ENV[\"MPLBACKEND\"] = \"qt4agg\"\n",
"using Plots; pyplot()\n",
"p = scatter(rand(10))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"p.seriesargs[1][:serieshandle][:get_offsets]()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"PyPlot.backend"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"gui()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"append!(p,1,rand(10))\n",
"gui()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"sp = progressviz.subplt.plts[1].o.widget[:minimumSizeHint]()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"testn = 100\n",
"xs = linspace(lim..., testn)\n",
"x1, x2 = [hcat(xs,map(f,xs)) for f in funcs]\n",
"y1, y2 = ones(testn), -ones(testn)\n",
"yhat1, yhat2 = [vec(predict(net, x)) for x in (x1,x2)]\n",
"DataPoints(vcat(x1,x2), vcat(y1,y2))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 0.4.0",
"language": "julia",
"name": "julia-0.4"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "0.4.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}