nnet
This commit is contained in:
parent
62fad3724f
commit
96c66b33a2
@ -8,18 +8,9 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"ENV[\"PYTHONPATH\"] = joinpath(Pkg.dir(\"Qwt\"), \"src\", \"python\");"
|
"ENV[\"PYTHONPATH\"] = joinpath(Pkg.dir(\"Qwt\"), \"src\", \"python\");\n",
|
||||||
]
|
"\n",
|
||||||
},
|
"using Plots, Distributions; qwt()\n",
|
||||||
{
|
|
||||||
"cell_type": "code",
|
|
||||||
"execution_count": null,
|
|
||||||
"metadata": {
|
|
||||||
"collapsed": false
|
|
||||||
},
|
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"using Plots; qwt()\n",
|
|
||||||
"default(size=(500,300), leg=false)\n",
|
"default(size=(500,300), leg=false)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# creates x/y vectors which can define a grid in a zig-zag pattern\n",
|
"# creates x/y vectors which can define a grid in a zig-zag pattern\n",
|
||||||
@ -46,9 +37,16 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# these are the functions we want to classify\n",
|
"# these are the functions we want to classify\n",
|
||||||
"scalar = 5 # larger is harder\n",
|
"scalar = 10 # larger is harder\n",
|
||||||
"f1(x) = 0.6sin(scalar * x) + 0.1\n",
|
"noise = Distributions.Normal(0, 0.05)\n",
|
||||||
"f2(x) = f1(x) - 0.2\n",
|
"\n",
|
||||||
|
"# # problem #1... non-overlapping\n",
|
||||||
|
"f1(x) = 0.6sin(scalar * x) + 0.1 + rand(noise)\n",
|
||||||
|
"f2(x) = f1(x) - 0.3\n",
|
||||||
|
"\n",
|
||||||
|
"# problem #2... overlapping\n",
|
||||||
|
"# f1(x) = 0.6sin(scalar * x)\n",
|
||||||
|
"# f2(x) = 0.6sin(scalar * (x+0.1))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# our target function is ∈ {-1,1}\n",
|
"# our target function is ∈ {-1,1}\n",
|
||||||
"target(f) = f == f1 ? 1.0 : -1.0"
|
"target(f) = f == f1 ? 1.0 : -1.0"
|
||||||
@ -81,7 +79,7 @@
|
|||||||
" plot([gridx gridy], [gridy gridx], c=:black; kw...)\n",
|
" plot([gridx gridy], [gridy gridx], c=:black; kw...)\n",
|
||||||
"\n",
|
"\n",
|
||||||
" # show the funcs\n",
|
" # show the funcs\n",
|
||||||
" plot!(funcs, lim..., l=(4,[:blue :red]))\n",
|
" plot!(funcs, lim..., l=(4,[:royalblue :orangered]))\n",
|
||||||
"end\n",
|
"end\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# kick off an animation... we can save frames whenever we want, lets save the starting frame\n",
|
"# kick off an animation... we can save frames whenever we want, lets save the starting frame\n",
|
||||||
@ -115,18 +113,20 @@
|
|||||||
"# gradientModel = SGDModel(η=1e-4, μ=0.8, λ=0)\n",
|
"# gradientModel = SGDModel(η=1e-4, μ=0.8, λ=0)\n",
|
||||||
"# gradientModel = AdagradModel(η=1e-1)\n",
|
"# gradientModel = AdagradModel(η=1e-1)\n",
|
||||||
"# gradientModel = AdadeltaModel(η=0.1, ρ=0.99, λ=0)\n",
|
"# gradientModel = AdadeltaModel(η=0.1, ρ=0.99, λ=0)\n",
|
||||||
"gradientModel = AdamModel(η=1e-4)\n",
|
"# gradientModel = AdamModel(η=1e-4, λ=1e-8)\n",
|
||||||
|
"gradientModel = AdaMaxModel(η=1e-4, ρ1=0.9, ρ2=0.9)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"learningRateModel = AdaptiveLearningRate(gradientModel)\n",
|
"# learningRateModel = FixedLearningRate()\n",
|
||||||
|
"learningRateModel = AdaptiveLearningRate(gradientModel, 2e-2, 0.05, wgt=ExponentialWeighting(30))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"function OnlineAI.initialWeights(nin::Int, nout::Int, activation::Activation)\n",
|
"function OnlineAI.initialWeights(nin::Int, nout::Int, activation::Activation)\n",
|
||||||
" 0.5randn(nout, nin) / sqrt(nin) + eye(nout, nin)\n",
|
" 0.1randn(nout, nin) / sqrt(nin) + eye(nout, nin)\n",
|
||||||
"end\n",
|
"end\n",
|
||||||
"\n",
|
"\n",
|
||||||
"net = buildTanhClassificationNet(\n",
|
"net = buildTanhClassificationNet(\n",
|
||||||
" 2, # number of inputs\n",
|
" 2, # number of inputs\n",
|
||||||
" 1, # number of outputs\n",
|
" 1, # number of outputs\n",
|
||||||
" 2ones(Int,4), # hidden layers structure\n",
|
" [100,100,2], # hidden layers structure\n",
|
||||||
" params = NetParams(gradientModel = gradientModel)\n",
|
" params = NetParams(gradientModel = gradientModel)\n",
|
||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
@ -204,7 +204,7 @@
|
|||||||
"end\n",
|
"end\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# final plot to track test error\n",
|
"# final plot to track test error\n",
|
||||||
"errviz = subplot([totalCost(net, testdata) gradientModel.η], m=3, title=[\"Error\" \"η\"], n=2,nc=1)\n",
|
"errviz = subplot([totalCost(net, testdata) gradientModel.η], m=3, title=[\"Error\" \"η\"], n=2,nc=1, pos=(800,0))\n",
|
||||||
"gui(errviz)"
|
"gui(errviz)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
@ -278,7 +278,8 @@
|
|||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"xs = xs[1:testn]\n",
|
"xs = xs[1:testn]\n",
|
||||||
"plot(xs, hcat(map(f1,xs), map(f2,xs), yhat1, yhat2), leg=true, line=([2 2 5 5], [:blue :red], [:solid :solid :dash :dash]))"
|
"plot(xs, hcat(map(f1,xs), map(f2,xs), yhat1, yhat2), leg=true,\n",
|
||||||
|
" line=([2 2 5 5], [:royalblue :orangered], [:solid :solid :dash :dash]))"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -305,7 +306,7 @@
|
|||||||
},
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"gif(anim, fps = 20)"
|
"gif(anim, fps = 10)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user