524 lines
11 KiB
Plaintext
524 lines
11 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"using Plots\n",
|
||
"immerse()\n",
|
||
"default(size=(500,300), leg=false)\n",
|
||
"\n",
|
||
"# gridxy(lim, n): x and y coordinate vectors that trace a grid as ONE continuous\n",
|
||
"# zig-zag path. At each of the n y-levels spanning `lim`, the path sweeps x forward\n",
|
||
"# across `lim` and then back, so consecutive levels are joined along the x = lim[1] edge.\n",
|
||
"function gridxy(lim, n::Int)\n",
|
||
"    ticks = linspace(lim..., n)\n",
|
||
"    sweep = vcat(ticks, reverse(ticks))\n",
|
||
"    # the 2-iterator comprehension is a matrix (sweep varies along dim 1); vec flattens it\n",
|
||
"    # column-major, i.e. one full forward-and-back sweep per y-level, in order\n",
|
||
"    pts = [(px, py) for px in sweep, py in ticks]\n",
|
||
"    Plots.unzip(vec(pts))\n",
|
||
"end"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# The problem: given a point (x, f(x)), can we classify which function it came from?"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# the two curves we want to tell apart\n",
|
||
"scalar = 5 # larger is harder... start with 3\n",
|
||
"f1(x) = 0.6 * sin(scalar * x) + 0.1\n",
|
||
"f2(x) = f1(x) - 0.2  # the same wave, shifted down by 0.2\n",
|
||
"\n",
|
||
"# classification label in {-1, 1}: +1 for f1, -1 for any other function\n",
|
||
"function target(f)\n",
|
||
"    if f == f1\n",
|
||
"        1.0\n",
|
||
"    else\n",
|
||
"        -1.0\n",
|
||
"    end\n",
|
||
"end"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# On to the fun..."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# pick the plotting limits\n",
|
||
"lim = (-1,1)\n",
|
||
"funcs = [f1, f2]\n",
|
||
"n = 40\n",
|
||
"gridx, gridy = gridxy(lim, n)\n",
|
||
"default(xlim = lim, ylim = lim)\n",
|
||
"\n",
|
||
"# Build the base plot: the zig-zag grid in both orientations (black), then each function\n",
|
||
"# drawn over `lim` with line width 4 (blue for funcs[1], red for funcs[2]).\n",
|
||
"# This creates 4 series in order: grid, transposed grid, funcs[1], funcs[2]; the training\n",
|
||
"# loop overwrites them by index (p[1]..p[4]), so keep this order.\n",
|
||
"# Extra keyword arguments `kw` are forwarded to the grid `plot` call.\n",
|
||
"function initialize_plot(funcs, lim, gridx, gridy; kw...)\n",
|
||
"    # show the grid; `kw...` must follow `;` so it is splatted as keyword\n",
|
||
"    # arguments rather than as extra positional arguments\n",
|
||
"    plot([gridx gridy], [gridy gridx]; c=:black, kw...)\n",
|
||
"\n",
|
||
"    # show the funcs\n",
|
||
"    plot!(funcs, lim..., l=(4,[:blue :red]))\n",
|
||
"end\n",
|
||
"\n",
|
||
"# kick off an animation... we can save frames whenever we want, so let's save the starting frame\n",
|
||
"function initialize_animation()\n",
|
||
"    anim = Animation()\n",
|
||
"    frame(anim)\n",
|
||
"    anim\n",
|
||
"end\n",
|
||
"\n",
|
||
"# let's see what we're dealing with...\n",
|
||
"p = initialize_plot(funcs, lim, gridx, gridy)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# That looks tricky... let's build a neural net!"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"using OnlineAI\n",
|
||
"\n",
|
||
"# tanh classification net: 2 inputs, 1 output, hidden layers structure [2].\n",
|
||
"# Other optimizers that were tried:\n",
|
||
"#   NetParams(gradientModel = SGDModel(η=1e-5))\n",
|
||
"#   NetParams(gradientModel = AdagradModel(η=1e-1))\n",
|
||
"net = buildTanhClassificationNet(\n",
|
||
"    2,    # number of inputs\n",
|
||
"    1,    # number of outputs\n",
|
||
"    [2],  # hidden layers structure\n",
|
||
"    params = NetParams(gradientModel = AdadeltaModel(η=1e-3, ρ=0.98))\n",
|
||
")\n",
|
||
"\n",
|
||
"# Push each row of `x` through every layer except the output layer and return the\n",
|
||
"# resulting activations as two vectors (one per column of the 2-wide projection).\n",
|
||
"# Asserts the output layer takes exactly 2 inputs, so the projection is 2-d and plottable.\n",
|
||
"function activateHidden(net, x)\n",
|
||
"    @assert net.layers[end].nin == 2\n",
|
||
"    nsamples = nrows(x)\n",
|
||
"    proj = zeros(nsamples, 2)\n",
|
||
"    for i in 1:nsamples\n",
|
||
"        act = row(x, i)\n",
|
||
"        for layer in net.layers[1:end-1]\n",
|
||
"            # after forward!, this layer's activation is read from layer.a and fed onward\n",
|
||
"            # NOTE(review): meaning of the third `false` argument is defined by OnlineAI\n",
|
||
"            OnlineAI.forward!(layer, act, false)\n",
|
||
"            act = layer.a\n",
|
||
"        end\n",
|
||
"        row!(proj, i, act)\n",
|
||
"    end\n",
|
||
"    vec(proj[:,1]), vec(proj[:,2])\n",
|
||
"end"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Update our model and the visualization"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# rebuild a fresh base plot and start a new animation (initialize_animation saves the current plot as its first frame)\n",
|
||
"p = initialize_plot(funcs, lim, gridx, gridy)\n",
|
||
"anim = initialize_animation()\n",
|
||
"# presumably opens/refreshes the GUI window showing the current plot\n",
|
||
"gui()\n",
|
||
"\n",
|
||
"# progress visualization of the net's :x, :Σ and :a fields; refreshed during training via update!(progressviz, ...)\n",
|
||
"# (the trailing `;` suppresses the cell's output)\n",
|
||
"progressviz = track_progress(net, fields=[:x,:Σ,:a], size=(800,800), m=2, w=0);"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"iterations_per_frame = 1000\n",
|
||
"total_frames = 200\n",
|
||
"dist = Distributions.Uniform(lim...)\n",
|
||
"# x-values at which the two function curves are re-projected every frame;\n",
|
||
"# loop-invariant, so compute it once (and don't shadow the sampled `x` below)\n",
|
||
"curvex = linspace(lim..., 50)\n",
|
||
"\n",
|
||
"for frm in 1:total_frames\n",
|
||
"    # pick one of the functions at random, sample from the x line, then update the\n",
|
||
"    # neural net with [x, f(x)] as the inputs and target(f) as the label\n",
|
||
"    for i in 1:iterations_per_frame\n",
|
||
"        f = sample(funcs)\n",
|
||
"        x = rand(dist)\n",
|
||
"        y = target(f)\n",
|
||
"        update!(net, Float64[x, f(x)], [y])\n",
|
||
"    end\n",
|
||
"\n",
|
||
"    # update the progress visualization\n",
|
||
"    update!(progressviz, true, show=false)\n",
|
||
"\n",
|
||
"    # update the plot... project each series through all non-output layers and reset the data\n",
|
||
"    # NOTE: this works because `getindex` and `setindex` are overloaded to get/set the underlying plot series data\n",
|
||
"    p[1] = activateHidden(net, hcat(gridx, gridy))\n",
|
||
"    p[2] = activateHidden(net, hcat(gridy, gridx))\n",
|
||
"    p[3] = activateHidden(net, hcat(curvex, map(f1, curvex)))\n",
|
||
"    p[4] = activateHidden(net, hcat(curvex, map(f2, curvex)))\n",
|
||
"\n",
|
||
"    # show/update the plot and record it as an animation frame\n",
|
||
"    gui(p)\n",
|
||
"    frame(anim)\n",
|
||
"end\n",
|
||
"\n",
|
||
"# displays the progress\n",
|
||
"progressviz.subplt"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# show stacked and linked histograms of the predictions for each class\n",
|
||
"testn = 100\n",
|
||
"xs = linspace(lim..., testn)\n",
|
||
"# test inputs: rows of [x f(x)] sampled along f1 (x1) and along f2 (x2)\n",
|
||
"x1, x2 = [hcat(xs,map(f,xs)) for f in funcs]\n",
|
||
"# expected labels: f1 is the +1 class and f2 the -1 class (see target)\n",
|
||
"y1, y2 = ones(testn), -ones(testn)\n",
|
||
"yhat1, yhat2 = [vec(predict(net, x)) for x in (x1,x2)]\n",
|
||
"subplot(histogram(yhat1), histogram(yhat2), nc=1, linkx=true, title=[\"f1 prediction\", \"f2 prediction\"], xlim=lim)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# NOTE(review): argument-less plot() call, presumably an empty plot -- looks like a leftover scratch cell\n",
|
||
"plot()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# overlay the true curves (width 2) and the net's predictions along each curve (width 5) against x;\n",
|
||
"# uses yhat1/yhat2 from the histogram cell (the old combined `yhat` no longer exists)\n",
|
||
"plot(xs, hcat(map(f1,xs), map(f2,xs), yhat1, yhat2), leg=true, w=[2 2 5 5])"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {},
|
||
"source": [
|
||
"# Animate!"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# render the frames saved in `anim` (via frame(anim)) as a GIF at 20 frames per second\n",
|
||
"gif(anim, fps = 20)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"# Network viz"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# show the network; `visualize` uses Qwt and isn't available until Qwt is imported\n",
|
||
"import Qwt\n",
|
||
"# the trailing `;` suppresses the cell's output display\n",
|
||
"viz = visualize(net);"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# refresh the network visualization `viz` with the net's current state (weights, etc.)\n",
|
||
"update!(viz)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"source": [
|
||
"# testing..."
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# NOTE(review): `selection` is not defined in any cell of this notebook; presumably a\n",
|
||
"# leftover from interactive exploration -- this cell will likely fail on a fresh kernel\n",
|
||
"selection[3][2]"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# debugging: length of p[4][2] -- presumably the y-data of series 4 (the f2 curve's projection)\n",
|
||
"p[4][2] |> length"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# show the progress-visualization subplot in a GUI window\n",
|
||
"gui(progressviz.subplt)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# scratch: four independent random-data plots, used to try out subplot layouts\n",
|
||
"# (Plots is already loaded by the first cell; repeating `using` here is harmless)\n",
|
||
"using Plots\n",
|
||
"p1 = plot(rand(20))\n",
|
||
"p2 = plot(rand(10))\n",
|
||
"p3 = scatter(rand(100))\n",
|
||
"p4 = plot(rand(1000))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# lay the four scratch plots out in one row (nr=1), legends off\n",
|
||
"subplot(p1,p2,p3,p4, nr=1, leg=false)"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# scratch: re-activate the immerse backend and draw a random line plot.\n",
|
||
"# NOTE(review): this rebinds the global `p`, replacing the training plot built earlier\n",
|
||
"using Plots; immerse()\n",
|
||
"p = plot(rand(10))"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# presumably shows the current plot (`p` from the cell above) in its GUI window\n",
|
||
"gui()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# presumably appends 10 random values to series 1 of `p`, then redraws the window\n",
|
||
"append!(p,1,rand(10))\n",
|
||
"gui()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": false
|
||
},
|
||
"outputs": [],
|
||
"source": [
|
||
"# debugging: reaches into backend internals (`.o.widget`) of the first progress plot and\n",
|
||
"# calls the widget's minimumSizeHint method -- appears to be a Qt size query; confirm\n",
|
||
"sp = progressviz.subplt.plts[1].o.widget[:minimumSizeHint]()"
|
||
]
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"metadata": {
|
||
"collapsed": true
|
||
},
|
||
"outputs": [],
|
||
"source": []
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Julia 0.4.0",
|
||
"language": "julia",
|
||
"name": "julia-0.4"
|
||
},
|
||
"language_info": {
|
||
"file_extension": ".jl",
|
||
"mimetype": "application/julia",
|
||
"name": "julia",
|
||
"version": "0.4.0"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0
|
||
}
|