Showing change in a treemap in matplotlib

79 Views Asked by At

I am trying to create this:

Treemap in matplotlib

The data for the chart is:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


# Number of sites per country, one row per (country, year) observation.
data = {
    "year": [2004, 2022, 2004, 2022, 2004, 2022],
    "countries": ["Denmark", "Denmark", "Norway", "Norway", "Sweden", "Sweden"],
    "sites": [4, 10, 5, 8, 13, 15],
}
df = pd.DataFrame(data)

# Change in the number of sites relative to the previous row of the same
# country; the first row of each country has no predecessor and becomes NaN.
df["diff"] = df.groupby(["countries"])["sites"].diff()

# Replace those NaNs with the country's own site count. The result is assigned
# back to the column: calling fillna(..., inplace=True) on the df["diff"]
# selection is a chained in-place update, which pandas deprecates and which
# leaves df untouched once Copy-on-Write is active.
df["diff"] = df["diff"].fillna(df["sites"])

df

I am aware that there are packages that do treemaps (squarify and plotly, to name a few), but I have not figured out how to create the one above, where the values for the two years are added to each other (or, to be exact, where the difference between them is shown). It would be fantastic to learn how to do it in pure matplotlib, if it is not too complex.

Does anyone have any pointers? I haven't found much information on treemaps on Google.

1

There is 1 best solution below

4
Paul Brodersen On BEST ANSWER

There are two parts to this task.

  1. Computing a layout for the rectangles.
  2. Drawing the rectangles.

The first part can get quite involved: people publish scientific papers on the topic. It's not advisable to re-invent the wheel here. However, the second part is quite straightforward and can be done in matplotlib.

The solution below uses squarify to compute a layout using the larger value for each value pair, and then matplotlib to draw two rectangles on top of each other.

enter image description here

import numpy as np
import matplotlib.pyplot as plt
import squarify

from matplotlib import colormaps
from matplotlib.colors import to_rgba

# Default color pairs: consecutive entries of the tab20 colormap grouped two by
# two, i.e. (colors[0], colors[1]), (colors[2], colors[3]), ...
DEFAULT_COLORS = [
    (first, second)
    for first, second in zip(colormaps["tab20"].colors[0::2],
                             colormaps["tab20"].colors[1::2])
]


def color_to_grayscale(color):
    """Return the alpha-weighted brightness of a color.

    The color is converted with matplotlib's ``to_rgba``; its red, green and
    blue channels are combined with the weights 0.299 / 0.587 / 0.114 and the
    result is multiplied by the alpha channel.
    """
    # Channel weights adapted from: https://stackoverflow.com/a/689547/2912349
    red, green, blue, alpha = to_rgba(color)
    brightness = 0.299 * red + 0.587 * green + 0.114 * blue
    return brightness * alpha


class PairedTreeMap:
    """Treemap in which every tile displays a pair of values.

    Each tile is sized by the larger value of its pair. The smaller value is
    drawn on top of it as a second rectangle anchored at the same corner and
    scaled so that the ratio of the two areas equals the ratio of the values.

    After construction the instance exposes:

    - ``ax``: the axis that was drawn on,
    - ``rects``: the tile layout, one dict with keys x, y, dx, dy per pair,
    - ``artists``: a list of (large_patch, small_patch) tuples,
    - ``labels``: the created text artists (an empty list without labels).
    """

    def __init__(self, values, colors=DEFAULT_COLORS, labels=None, ax=None, bbox=(0, 0, 200, 100)):
        """
        Draw a treemap of value pairs.

        values : list[tuple[float, float]]
            A list of value pairs.

        colors : list[tuple[RGBA, RGBA]]
            The corresponding color pairs. Defaults to the paired tab20 matplotlib colors.
            If there are fewer color pairs than value pairs, the colors are reused cyclically.

        labels : list[str]
            The labels, one for each pair.

        ax : matplotlib.axes._axes.Axes
            The matplotlib axis instance to draw on.

        bbox : tuple[float, float, float, float]
            The (x, y) origin and (width, height) extent of the treemap.

        """

        self.ax = self.initialize_axis(ax)
        self.rects = self.get_layout(values, bbox)
        self.artists = list(self.draw(self.rects, values, colors, self.ax))

        # Always define the attribute so that it can be read even when no
        # labels were requested.
        self.labels = []
        if labels is not None:
            self.labels = list(self.add_labels(self.rects, labels, values, colors, self.ax))


    def get_layout(self, values, bbox):
        """Compute one tile per value pair, sized by the larger value of the pair.

        Returns a list of squarify rectangles (dicts with keys x, y, dx, dy)
        in the same order as `values`.
        """
        if len(values) == 0:
            # np.max(..., axis=1) cannot reduce an empty input.
            return []
        maxima = np.max(values, axis=1)
        # The tiles are laid out from the largest to the smallest maximum ...
        order = np.argsort(maxima)[::-1]
        normalized_maxima = squarify.normalize_sizes(maxima[order], *bbox[2:])
        rects = squarify.padded_squarify(normalized_maxima, *bbox)
        # ... and then put back into the order of the input values.
        reorder = np.argsort(order)
        return [rects[ii] for ii in reorder]


    def initialize_axis(self, ax=None):
        """Return the axis to draw on, creating a new figure if `ax` is None.

        The axis is set to an equal aspect ratio and its decorations are hidden.
        """
        if ax is None:
            fig, ax = plt.subplots()
        ax.set_aspect("equal")
        ax.axis("off")
        return ax


    @staticmethod
    def _sort_pair(value_pair, color_pair):
        """Return (small, large, color_small, color_large), ordered by value only."""
        # Sorting on the value alone keeps the colors out of the comparison:
        # with tied values, comparing e.g. a color name with an RGBA tuple
        # would raise a TypeError.
        (small, color_small), (large, color_large) = sorted(
            zip(value_pair, color_pair), key=lambda item: item[0])
        return small, large, color_small, color_large


    @staticmethod
    def _value_ratio(small, large):
        """Return small / large, or 0.0 when `large` is zero."""
        return small / large if large else 0.0


    @staticmethod
    def _cycle_colors(colors, count):
        """Return `count` color pairs, repeating `colors` cyclically if needed."""
        colors = list(colors)
        if count and not colors:
            raise ValueError("At least one color pair is required.")
        return [colors[ii % len(colors)] for ii in range(count)]


    def _get_artist_pair(self, rect, value_pair, color_pair):
        """Create the (large, small) rectangle patches for a single tile."""
        x, y, w, h = rect["x"], rect["y"], rect["dx"], rect["dy"]
        small, large, color_small, color_large = self._sort_pair(value_pair, color_pair)
        # Both sides are scaled by the square root of the value ratio such
        # that the ratio of the areas equals the ratio of the values.
        ratio = np.sqrt(self._value_ratio(small, large))
        return (plt.Rectangle((x, y), w,         h,         color=color_large, zorder=1),
                plt.Rectangle((x, y), w * ratio, h * ratio, color=color_small, zorder=2))


    def draw(self, rects, values, colors, ax):
        """Add both rectangles of every tile to `ax`.

        Yields a (large_patch, small_patch) tuple per tile. The view limits
        are rescaled once the generator has been exhausted.
        """
        rects = list(rects)
        # Without cycling, zip would silently drop every tile beyond the
        # number of available color pairs.
        colors = self._cycle_colors(colors, len(rects))
        for rect, value_pair, color_pair in zip(rects, values, colors):
            large_patch, small_patch = self._get_artist_pair(rect, value_pair, color_pair)
            ax.add_patch(large_patch)
            ax.add_patch(small_patch)
            yield(large_patch, small_patch)
        ax.autoscale_view()


    def add_labels(self, rects, labels, values, colors, ax):
        """Place each label at the center of its tile and yield the text artists."""
        rects = list(rects)
        colors = self._cycle_colors(colors, len(rects))
        for rect, label, value_pair, color_pair in zip(rects, labels, values, colors):
            x, y, w, h = rect["x"], rect["y"], rect["dx"], rect["dy"]
            # decide a fontcolor based on background brightness
            small, large, color_small, color_large = self._sort_pair(value_pair, color_pair)
            ratio = self._value_ratio(small, large)
            # The small rectangle starts at the tile origin and spans
            # sqrt(ratio) of each side, so it reaches the tile center once
            # ratio exceeds 0.25; 0.33 is that threshold plus some fudge.
            if ratio < 0.33:
                background_brightness = color_to_grayscale(color_large)
            else:
                background_brightness = color_to_grayscale(color_small)
            fontcolor = "white" if background_brightness < 0.5 else "black"
            yield ax.text(x + w/2, y + h/2, label, va="center", ha="center", color=fontcolor)


if __name__ == "__main__":

    # Demo input, one row per tile: (label, value pair, color pair).
    demo_rows = [
        ("Denmark", (4, 10), ("red", "coral")),
        ("Sweden", (13, 15), ("royalblue", "cornflowerblue")),
        ("Norway", (5, 8), ("darkslategrey", "gray")),
    ]

    # Split the rows into the parallel lists expected by PairedTreeMap.
    labels = [row[0] for row in demo_rows]
    values = [row[1] for row in demo_rows]
    colors = [row[2] for row in demo_rows]

    PairedTreeMap(values, colors=colors, labels=labels, bbox=(0, 0, 100, 100))
    plt.show()