651 lines
53 KiB
Plaintext
651 lines
53 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"gridxy (generic function with 1 method)"
|
||
]
|
||
},
|
||
"execution_count": 1,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"using Plots; qwt()\n",
|
||
"default(size=(500,300), leg=false)\n",
|
||
"\n",
|
||
"# creates x/y vectors which can define a grid in a zig-zag pattern\n",
|
||
"function gridxy(lim, n::Int)\n",
|
||
" xs = linspace(lim..., n)\n",
|
||
" xypairs = vec([(x,y) for x in vcat(xs,reverse(xs)), y in xs])\n",
|
||
" Plots.unzip(xypairs)\n",
|
||
"end"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# The problem... can we classify the functions?"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 2,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"target (generic function with 1 method)"
|
||
]
|
||
},
|
||
"execution_count": 2,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# these are the functions we want to classify\n",
|
||
"scalar = 8 # larger is harder... start with 3\n",
|
||
"f1(x) = 0.6sin(scalar * x) + 0.1\n",
|
||
"f2(x) = f1(x) - 0.2\n",
|
||
"\n",
|
||
"# our target function is ∈ {-1,1}\n",
|
||
"target(f) = f == f1 ? 1.0 : -1.0"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# On to the fun..."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 3,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"[Plots.jl] Initializing backend: qwt"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"Plot{Plots.QwtPackage() n=4}"
|
||
]
|
||
},
|
||
"execution_count": 3,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# pick the plotting limits\n",
|
||
"lim = (-1,1)\n",
|
||
"funcs = [f1, f2]\n",
|
||
"n = 40\n",
|
||
"gridx, gridy = gridxy(lim, n)\n",
|
||
"default(xlim = lim, ylim = lim)\n",
|
||
"\n",
|
||
"function initialize_plot(funcs, lim, gridx, gridy; kw...)\n",
|
||
" # show the grid\n",
|
||
" plot([gridx gridy], [gridy gridx], c=:black; kw...)\n",
|
||
"\n",
|
||
" # show the funcs\n",
|
||
" plot!(funcs, lim..., l=(4,[:blue :red]))\n",
|
||
"end\n",
|
||
"\n",
|
||
"# kick off an animation... we can save frames whenever we want, let's save the starting frame\n",
|
||
"function initialize_animation()\n",
|
||
" anim = Animation()\n",
|
||
" frame(anim)\n",
|
||
" anim\n",
|
||
"end\n",
|
||
"\n",
|
||
"# let's see what we're dealing with...\n",
|
||
"p = initialize_plot(funcs, lim, gridx, gridy)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# That looks tricky... let's build a neural net!"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 33,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"NeuralNet{\n",
|
||
" params: NetParams{OnlineAI.AdadeltaModel(1.0e-6,0.3,0.96,1.0e-5) NoDropout OnlineAI.L2CostModel()}\n",
|
||
" solverParams: OnlineAI.SolverParams(1000,1000,10000,-1,[:x,:xhat,:y,:Σ,:a],100,1.0e-5,OnlineAI.donothing)\n",
|
||
" layers:\n",
|
||
" NormalizedLayer{2=>2 OnlineAI.TanhActivation() p=1.0 ‖δΣ‖₁=0.0 ‖δy‖₁=0.0 }\n",
|
||
" NormalizedLayer{2=>2 OnlineAI.TanhActivation() p=1.0 ‖δΣ‖₁=0.0 ‖δy‖₁=0.0 }\n",
|
||
" NormalizedLayer{2=>5 OnlineAI.TanhActivation() p=1.0 ‖δΣ‖₁=0.0 ‖δy‖₁=0.0 }\n",
|
||
" NormalizedLayer{5=>100 OnlineAI.TanhActivation() p=1.0 ‖δΣ‖₁=0.0 ‖δy‖₁=0.0 }\n",
|
||
" NormalizedLayer{100=>5 OnlineAI.TanhActivation() p=1.0 ‖δΣ‖₁=0.0 ‖δy‖₁=0.0 }\n",
|
||
" NormalizedLayer{5=>2 OnlineAI.TanhActivation() p=1.0 ‖δΣ‖₁=0.0 ‖δy‖₁=0.0 }\n",
|
||
" NormalizedLayer{2=>2 OnlineAI.TanhActivation() p=1.0 ‖δΣ‖₁=0.0 ‖δy‖₁=0.0 }\n",
|
||
" NormalizedLayer{2=>1 OnlineAI.TanhActivation() p=1.0 ‖δΣ‖₁=0.0 ‖δy‖₁=0.0 }\n",
|
||
"}\n"
|
||
]
|
||
},
|
||
"execution_count": 33,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"using OnlineAI\n",
|
||
"\n",
|
||
"# gradientModel = SGDModel(η=1e-4, μ=0.5)\n",
|
||
"# gradientModel = AdagradModel(η=1e-2)\n",
|
||
"gradientModel = AdadeltaModel(η=3e-1, ρ=0.96)\n",
|
||
"\n",
|
||
"net = buildTanhClassificationNet(\n",
|
||
" 2, # number of inputs\n",
|
||
" 1, # number of outputs\n",
|
||
" [2,2,5,100,5,2,2], # hidden layers structure\n",
|
||
" params = NetParams(gradientModel = gradientModel)\n",
|
||
")"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Update our model and the visualization"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# set up a visualization of the projections\n",
|
||
"layers = filter(l -> l.nout == 2, net.layers[1:end-1])\n",
|
||
"num_hidden_layers = length(layers)\n",
|
||
"plts = [initialize_plot(funcs, lim, gridx, gridy, title=\"Hidden Layer $i\") for i in 1:num_hidden_layers]\n",
|
||
"sz = round(Int, sqrt(num_hidden_layers) * 600)\n",
|
||
"projectionviz = subplot(plts..., n=num_hidden_layers, size=(sz,sz))\n",
|
||
"\n",
|
||
"# setup animation, then show the plots in a window\n",
|
||
"anim = initialize_animation()\n",
|
||
"gui()\n",
|
||
"\n",
|
||
"# create another visualization to track the internal progress of the neural net\n",
|
||
"progressviz = track_progress(net, fields=[:w,:b,:Σ,:a], size=(num_hidden_layers*300,800), m=2, w=0);"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 35,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"activateHidden (generic function with 1 method)"
|
||
]
|
||
},
|
||
"execution_count": 35,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"dist = Distributions.Uniform(lim...)\n",
|
||
"progressgui = false\n",
|
||
"\n",
|
||
"function activateHidden(net, layers, x, y, seriesidx, plts)\n",
|
||
" n = length(x)\n",
|
||
" p = length(plts)\n",
|
||
" projx, projy = zeros(n,p), zeros(n,p)\n",
|
||
" for i in 1:n\n",
|
||
" # feed the data through the neural net\n",
|
||
" OnlineAI.forward!(net, [x[i], y[i]])\n",
|
||
" \n",
|
||
" # grab the net's activations at each layer\n",
|
||
" for j in 1:p\n",
|
||
" projx[i,j], projy[i,j] = layers[j].a\n",
|
||
" end\n",
|
||
" end\n",
|
||
" \n",
|
||
" # now we can update the plots\n",
|
||
" for j in 1:p\n",
|
||
" plts[j][seriesidx] = (vec(projx[:,j]), vec(projy[:,j]))\n",
|
||
" end\n",
|
||
"end"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"iterations_per_frame = 10000\n",
|
||
"total_frames = 100\n",
|
||
"for frm in 1:total_frames\n",
|
||
" # pick one of the functions at random, sample from the x line, then update the\n",
|
||
" # neural net with [x, f(x)] as the inputs\n",
|
||
" for i in 1:iterations_per_frame\n",
|
||
" f = sample(funcs)\n",
|
||
" x = rand(dist)\n",
|
||
" y = target(f)\n",
|
||
" update!(net, Float64[x, f(x)], [y])\n",
|
||
" end\n",
|
||
" \n",
|
||
" # update the progress visualization\n",
|
||
" update!(progressviz, true, show=progressgui)\n",
|
||
"\n",
|
||
" # update the projections\n",
|
||
" x = linspace(lim..., 100)\n",
|
||
" for (seriesidx, (x,y)) in enumerate([(gridx,gridy), (gridy,gridx), (x,map(f1,x)), (x,map(f2,x))])\n",
|
||
" activateHidden(net, layers, x, y, seriesidx, projectionviz.plts)\n",
|
||
" end\n",
|
||
" \n",
|
||
" # show/update the plot\n",
|
||
" gui(projectionviz)\n",
|
||
" frame(anim)\n",
|
||
" sleep(0.001)\n",
|
||
"end\n",
|
||
"\n",
|
||
"# displays the progress if there's no gui\n",
|
||
"progressgui || progressviz.subplt"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"WARNING: handleLinkInner isn't implemented for qwt\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"Subplot{Plots.QwtPackage() p=2 n=2}"
|
||
]
|
||
},
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"# show stacked and linked histograms of the predictions for each class\n",
|
||
"testn = 100\n",
|
||
"xs = linspace(lim..., testn)\n",
|
||
"x1, x2 = [hcat(xs,map(f,xs)) for f in funcs]\n",
|
||
"y1, y2 = ones(testn), -ones(testn)\n",
|
||
"yhat1, yhat2 = [vec(predict(net, x)) for x in (x1,x2)]\n",
|
||
"subplot(histogram(yhat1), histogram(yhat2), nc=1, linkx=true, title=[\"f1 prediction\", \"f2 prediction\"])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 32,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [
|
||
{
|
||
"data": {
|
||
"image/png": "",
|
||
"text/plain": [
|
||
"Plot{Plots.QwtPackage() n=4}"
|
||
]
|
||
},
|
||
"execution_count": 32,
|
||
"metadata": {},
|
||
"output_type": "execute_result"
|
||
}
|
||
],
|
||
"source": [
|
||
"plot(xs, hcat(map(f1,xs), map(f2,xs), yhat1, yhat2), leg=true, line=([2 2 5 5], [:blue :red], [:solid :solid :dot :dot]))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Animate!"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"gif(anim, fps = 20)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"# Network viz"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# show the network (uses Qwt, visualize isn't available unless you import it)\n",
|
||
"import Qwt\n",
|
||
"viz = visualize(net);"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# update the net representation with weights, etc\n",
|
||
"update!(viz)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"# testing..."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"selection[3][2]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"p[4][2] |> length"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"gui(progressviz.subplt)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"histogram(yhat1)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"using Plots\n",
|
||
"p1 = plot(rand(20))\n",
|
||
"p2 = plot(rand(10))\n",
|
||
"p3 = scatter(rand(100))\n",
|
||
"p4 = plot(rand(1000))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"subplot(p1,p2,p3,p4, nr=1, leg=false)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"using Plots; immerse()\n",
|
||
"p = plot(rand(10))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"gui()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"append!(p,1,rand(10))\n",
|
||
"gui()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"sp = progressviz.subplt.plts[1].o.widget[:minimumSizeHint]()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Julia 0.4.0",
|
||
"language": "julia",
|
||
"name": "julia-0.4"
|
||
},
|
||
"language_info": {
|
||
"file_extension": ".jl",
|
||
"mimetype": "application/julia",
|
||
"name": "julia",
|
||
"version": "0.4.0"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0
|
||
}
|