diff --git a/src/backends/pyplot.jl b/src/backends/pyplot.jl index 1bfe4a64..082e704b 100644 --- a/src/backends/pyplot.jl +++ b/src/backends/pyplot.jl @@ -10,6 +10,7 @@ function _initialize_backend(::PyPlotBackend) const mplot3d = PyPlot.pywrap(PyPlot.pyimport("mpl_toolkits.mplot3d")) const pypatches = PyPlot.pywrap(PyPlot.pyimport("matplotlib.patches")) const pyfont = PyPlot.pywrap(PyPlot.pyimport("matplotlib.font_manager")) + const pyticker = PyPlot.pywrap(PyPlot.pyimport("matplotlib.ticker")) # const pycolorbar = PyPlot.pywrap(PyPlot.pyimport("matplotlib.colorbar")) end @@ -42,6 +43,11 @@ function getPyPlotColorMap(c::ColorGradient, α=nothing) pycolors.pymember("LinearSegmentedColormap")[:from_list]("tmp", pyvals) end +# convert vectors and ColorVectors to standard ColorGradients +# TODO: move this logic to colors.jl and keep a barebones wrapper for pyplot +getPyPlotColorMap(cv::ColorVector, α=nothing) = getPyPlotColorMap(ColorGradient(cv.v), α) +getPyPlotColorMap(v::AVec, α=nothing) = getPyPlotColorMap(ColorGradient(v), α) + # anything else just gets a bluesred gradient getPyPlotColorMap(c, α=nothing) = getPyPlotColorMap(default_gradient(), α) @@ -138,6 +144,15 @@ function getPyPlotFont(font::Font) ) end +function get_locator_and_formatter(vals::AVec) + pyticker.pymember("FixedLocator")(1:length(vals)), pyticker.pymember("FixedFormatter")(vals) +end + +function add_pyfixedformatter(cbar, vals::AVec) + cbar[:locator], cbar[:formatter] = get_locator_and_formatter(vals) + cbar[:update_ticks]() +end + # --------------------------------------------------------------------------- type PyPlotAxisWrapper @@ -310,6 +325,7 @@ function _add_series(pkg::PyPlotBackend, plt::Plot, d::KW) # holds references to any python object representing the matplotlib series handles = [] needs_colorbar = false + discrete_colorbar_values = nothing # for each plotting command, optionally build and add a series handle to the list @@ -500,6 +516,9 @@ function _add_series(pkg::PyPlotBackend, plt::Plot, d::KW) if lt == :heatmap x, y, z = heatmap_edges(x), heatmap_edges(y), z.surf' + if !(eltype(z) <: Number) + z, discrete_colorbar_values = indices_and_unique_values(z) + end handle = ax[:pcolormesh](x, y, z; label = d[:label], zorder = plt.n, @@ -531,7 +550,23 @@ function _add_series(pkg::PyPlotBackend, plt::Plot, d::KW) # add the colorbar legend if needs_colorbar && plt.plotargs[:colorbar] != :none - PyPlot.colorbar(handles[end], ax=ax) + # cbar = PyPlot.colorbar(handles[end], ax=ax) + + # do we need a discrete colorbar? + if discrete_colorbar_values == nothing + PyPlot.colorbar(handles[end], ax=ax) + else + # add_pyfixedformatter(cbar, discrete_colorbar_values) + locator, formatter = get_locator_and_formatter(discrete_colorbar_values) + vals = 1:length(discrete_colorbar_values) + PyPlot.colorbar(handles[end], + ax = ax, + ticks = locator, + format = formatter, + boundaries = vcat(0, vals + 0.5), + values = vals + ) + end end # this sets the bg color inside the grid diff --git a/src/series_args.jl b/src/series_args.jl index 8db6e392..2dceba87 100644 --- a/src/series_args.jl +++ b/src/series_args.jl @@ -302,7 +302,7 @@ function process_inputs{TX,TY}(plt::AbstractPlot, d::KW, x::AVec{TX}, y::AVec{TY end # surface-like... matrix grid -function process_inputs{TX,TY,TZ<:Number}(plt::AbstractPlot, d::KW, x::AVec{TX}, y::AVec{TY}, zmat::AMat{TZ}) +function process_inputs{TX,TY,TZ}(plt::AbstractPlot, d::KW, x::AVec{TX}, y::AVec{TY}, zmat::AMat{TZ}) @assert size(zmat) == (length(x), length(y)) if TX <: Number && !issorted(x) idx = sortperm(x) @@ -312,13 +312,7 @@ function process_inputs{TX,TY,TZ<:Number}(plt::AbstractPlot, d::KW, x::AVec{TX}, idx = sortperm(y) y, zmat = y[idx], zmat[:, idx] end - # - # !issorted(y) - # y_idx = sortperm(y) - # x, y = x[x_idx], y[y_idx] - # zmat = zmat[x_idx, y_idx] - # end - d[:x], d[:y], d[:z] = x, y, Surface{Matrix{Float64}}(zmat) + d[:x], d[:y], d[:z] = x, y, Surface{Matrix{TZ}}(zmat) if !like_surface(get(d, :linetype, :none)) d[:linetype] = :contour end diff --git a/src/utils.jl b/src/utils.jl index 30f6106d..d2f591e0 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -243,6 +243,14 @@ Base.merge(a::AbstractVector, b::AbstractVector) = sort(unique(vcat(a,b))) nanpush!(a::AbstractVector, b) = (push!(a, NaN); push!(a, b)) nanappend!(a::AbstractVector, b) = (push!(a, NaN); append!(a, b)) +# given an array of discrete values, turn it into an array of indices of the unique values +# returns the array of indices (znew) and a vector of unique values (vals) +function indices_and_unique_values(z::AbstractArray) + vals = sort(unique(z)) + vmap = Dict([(v,i) for (i,v) in enumerate(vals)]) + newz = map(zi -> vmap[zi], z) + newz, vals +end # ---------------------------------------------------------------