{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Implementing PMFs\n", "\n", "This notebook outlines the API for `Pmf` objects in the `empiricaldist` library, showing the implementations of many methods.\n", "\n", "[Click here to run this notebook on Colab](https://colab.research.google.com/github/AllenDowney/empiricaldist/blob/master/empiricaldist/pmf_demo.ipynb)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "try:\n", " import empiricaldist\n", "except ImportError:\n", " !pip install empiricaldist" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "import inspect\n", "\n", "def psource(obj):\n", " \"\"\"Prints the source code for a given object.\n", "\n", " obj: function or method object\n", " \"\"\"\n", " print(inspect.getsource(obj))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Constructor\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/1).\n", "\n", "The `Pmf` class inherits its constructor from `pd.Series`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can create an empty `Pmf` and then add elements.\n", "\n", "Here's a `Pmf` that represents a six-sided die." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from empiricaldist import Pmf\n", "\n", "d6 = Pmf()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "for x in [1,2,3,4,5,6]:\n", " d6[x] = 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initially the probabilities don't add up to 1." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>1</td></tr>\n",
"<tr><th>2</th><td>1</td></tr>\n",
"<tr><th>3</th><td>1</td></tr>\n",
"<tr><th>4</th><td>1</td></tr>\n",
"<tr><th>5</th><td>1</td></tr>\n",
"<tr><th>6</th><td>1</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 1\n", "2 1\n", "3 1\n", "4 1\n", "5 1\n", "6 1\n", "dtype: int64" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`normalize` adds up the probabilities and divides through. The return value is the total probability before normalizing." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def normalize(self):\n", " \"\"\"Make the probabilities add up to 1 (modifies self).\n", "\n", " Returns: normalizing constant\n", " \"\"\"\n", " total = self.sum()\n", " self /= total\n", " return total\n", "\n" ] } ], "source": [ "psource(Pmf.normalize)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.normalize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now the Pmf is normalized." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.166667</td></tr>\n",
"<tr><th>2</th><td>0.166667</td></tr>\n",
"<tr><th>3</th><td>0.166667</td></tr>\n",
"<tr><th>4</th><td>0.166667</td></tr>\n",
"<tr><th>5</th><td>0.166667</td></tr>\n",
"<tr><th>6</th><td>0.166667</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 0.166667\n", "2 0.166667\n", "3 0.166667\n", "4 0.166667\n", "5 0.166667\n", "6 0.166667\n", "dtype: float64" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Properties\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/2).\n", "\n", "In a `Pmf` the index contains the quantities (`qs`) and the values contain the probabilities (`ps`).\n", "\n", "These attributes are available as properties that return arrays (same semantics as the Pandas `values` property)." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 2, 3, 4, 5, 6])" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.qs" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.16666667, 0.16666667, 0.16666667, 0.16666667, 0.16666667,\n", " 0.16666667])" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.ps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sharing\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/3).\n", "\n", "Because `Pmf` is a `Series`, you can initialize it with any type `Series.__init__` can handle.\n", "\n", "Here's an example with a dictionary." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>a</th><td>1</td></tr>\n",
"<tr><th>b</th><td>2</td></tr>\n",
"<tr><th>c</th><td>3</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "a 1\n", "b 2\n", "c 3\n", "dtype: int64" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d = dict(a=1, b=2, c=3)\n", "pmf = Pmf(d)\n", "pmf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's an example with two lists." ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.25</td></tr>\n",
"<tr><th>2</th><td>0.25</td></tr>\n",
"<tr><th>3</th><td>0.25</td></tr>\n",
"<tr><th>4</th><td>0.25</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 0.25\n", "2 0.25\n", "3 0.25\n", "4 0.25\n", "dtype: float64" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "qs = [1,2,3,4]\n", "ps = [0.25, 0.25, 0.25, 0.25]\n", "d4 = Pmf(ps, index=qs)\n", "d4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can copy a `Pmf` like this." ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.166667</td></tr>\n",
"<tr><th>2</th><td>0.166667</td></tr>\n",
"<tr><th>3</th><td>0.166667</td></tr>\n",
"<tr><th>4</th><td>0.166667</td></tr>\n",
"<tr><th>5</th><td>0.166667</td></tr>\n",
"<tr><th>6</th><td>0.166667</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 0.166667\n", "2 0.166667\n", "3 0.166667\n", "4 0.166667\n", "5 0.166667\n", "6 0.166667\n", "dtype: float64" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6_copy = Pmf(d6)\n", "d6_copy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, you have to be careful about sharing. In this example, the copies share the arrays:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.index is d6_copy.index" ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.ps is d6_copy.ps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can avoid sharing with `copy=True`" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.166667</td></tr>\n",
"<tr><th>2</th><td>0.166667</td></tr>\n",
"<tr><th>3</th><td>0.166667</td></tr>\n",
"<tr><th>4</th><td>0.166667</td></tr>\n",
"<tr><th>5</th><td>0.166667</td></tr>\n",
"<tr><th>6</th><td>0.166667</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 0.166667\n", "2 0.166667\n", "3 0.166667\n", "4 0.166667\n", "5 0.166667\n", "6 0.166667\n", "dtype: float64" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6_copy = Pmf(d6, copy=True)\n", "d6_copy" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.index is d6_copy.index" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.ps is d6_copy.ps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or by calling `copy` explicitly." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.25</td></tr>\n",
"<tr><th>2</th><td>0.25</td></tr>\n",
"<tr><th>3</th><td>0.25</td></tr>\n",
"<tr><th>4</th><td>0.25</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 0.25\n", "2 0.25\n", "3 0.25\n", "4 0.25\n", "dtype: float64" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4_copy = d4.copy()\n", "d4_copy" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "True" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4.index is d4_copy.index" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "False" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4.ps is d4_copy.ps" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Displaying PMFs\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/4).\n", "\n", "`Pmf` provides `_repr_html_`, so it looks good when displayed in a notebook." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def _repr_html_(self):\n", " \"\"\"Returns an HTML representation of the series.\n", "\n", " Mostly used for Jupyter notebooks.\n", " \"\"\"\n", " df = pd.DataFrame(dict(probs=self))\n", " return df._repr_html_()\n", "\n" ] } ], "source": [ "psource(Pmf._repr_html_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Pmf` provides `bar`, which plots the Pmf as a bar chart." 
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def bar(self, **options):\n", " \"\"\"Make a bar plot.\n", "\n", " Note: A previous version of this function used pd.Series.plot.bar,\n", " but that was a mistake, because that function treats the quantities\n", " as categorical, even if they are numerical, leading to hilariously\n", " unexpected results!\n", "\n", " Args:\n", " options: passed to plt.bar\n", " \"\"\"\n", " underride(options, label=self.name)\n", " plt.bar(self.qs, self.ps, **options)\n", "\n" ] } ], "source": [ "psource(Pmf.bar)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "def decorate_dice(title):\n", " \"\"\"Labels the axes.\n", " \n", " title: string\n", " \"\"\"\n", " plt.xlabel('Outcome')\n", " plt.ylabel('PMF')\n", " plt.title(title)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "d6.bar()\n", "decorate_dice('One die')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Pmf` inherits `plot` from `Series`." ] }, { "cell_type": "code", "execution_count": 26, "metadata": { "scrolled": true }, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "d6.plot()\n", "decorate_dice('One die')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make Pmf from sequence\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/5).\n", "\n", "\n", "The following function makes a `Pmf` object from a sequence of values." ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " @staticmethod\n", " def from_seq(\n", " seq,\n", " normalize=True,\n", " sort=True,\n", " ascending=True,\n", " dropna=True,\n", " na_position=\"last\",\n", " **options,\n", " ):\n", " \"\"\"Make a PMF from a sequence of values.\n", "\n", " Args:\n", " seq: iterable\n", " normalize: whether to normalize the Pmf, default True\n", " sort: whether to sort the Pmf by values, default True\n", " ascending: whether to sort in ascending order, default True\n", " dropna: whether to drop NaN values, default True\n", " na_position: If ‘first’ puts NaNs at the beginning,\n", " ‘last’ puts NaNs at the end.\n", " options: passed to the pd.Series constructor\n", "\n", " Returns: Pmf object\n", " \"\"\"\n", " # compute the value counts\n", " series = pd.Series(seq).value_counts(\n", " normalize=normalize, sort=False, dropna=dropna\n", " )\n", " # make the result a Pmf\n", " # (since we just made a fresh Series, there is no reason to copy it)\n", " options[\"copy\"] = False\n", " underride(options, name=\"\")\n", " pmf = Pmf(series, **options)\n", "\n", " # sort in place, if desired\n", " if sort:\n", " pmf.sort_index(\n", " inplace=True, ascending=ascending, na_position=na_position\n", " )\n", "\n", " return pmf\n", "\n" ] } ], "source": [ "psource(Pmf.from_seq)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>a</th><td>0.2</td></tr>\n",
"<tr><th>e</th><td>0.2</td></tr>\n",
"<tr><th>l</th><td>0.4</td></tr>\n",
"<tr><th>n</th><td>0.2</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "a 0.2\n", "e 0.2\n", "l 0.4\n", "n 0.2\n", "Name: , dtype: float64" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmf = Pmf.from_seq(list('allen'))\n", "pmf" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.2</td></tr>\n",
"<tr><th>2</th><td>0.4</td></tr>\n",
"<tr><th>3</th><td>0.2</td></tr>\n",
"<tr><th>5</th><td>0.2</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 0.2\n", "2 0.4\n", "3 0.2\n", "5 0.2\n", "Name: , dtype: float64" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "pmf = Pmf.from_seq(np.array([1, 2, 2, 3, 5]))\n", "pmf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Selection\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/6).\n", "\n", "`Pmf` inherits `__getitem__` from `pd.Series`, as the source below shows, so the bracket operator looks up the probability of a quantity by its label." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def __getitem__(self, key):\n", " check_dict_or_set_indexers(key)\n", " key = com.apply_if_callable(key, self)\n", "\n", " if key is Ellipsis:\n", " return self\n", "\n", " key_is_scalar = is_scalar(key)\n", " if isinstance(key, (list, tuple)):\n", " key = unpack_1tuple(key)\n", "\n", " if is_integer(key) and self.index._should_fallback_to_positional:\n", " warnings.warn(\n", " # GH#50617\n", " \"Series.__getitem__ treating keys as positions is deprecated. \"\n", " \"In a future version, integer keys will always be treated \"\n", " \"as labels (consistent with DataFrame behavior). 
To access \"\n", " \"a value by position, use `ser.iloc[pos]`\",\n", " FutureWarning,\n", " stacklevel=find_stack_level(),\n", " )\n", " return self._values[key]\n", "\n", " elif key_is_scalar:\n", " return self._get_value(key)\n", "\n", " # Convert generator to list before going through hashable part\n", " # (We will iterate through the generator there to check for slices)\n", " if is_iterator(key):\n", " key = list(key)\n", "\n", " if is_hashable(key) and not isinstance(key, slice):\n", " # Otherwise index.get_value will raise InvalidIndexError\n", " try:\n", " # For labels that don't resolve as scalars like tuples and frozensets\n", " result = self._get_value(key)\n", "\n", " return result\n", "\n", " except (KeyError, TypeError, InvalidIndexError):\n", " # InvalidIndexError for e.g. generator\n", " # see test_series_getitem_corner_generator\n", " if isinstance(key, tuple) and isinstance(self.index, MultiIndex):\n", " # We still have the corner case where a tuple is a key\n", " # in the first level of our MultiIndex\n", " return self._get_values_tuple(key)\n", "\n", " if isinstance(key, slice):\n", " # Do slice check before somewhat-costly is_bool_indexer\n", " return self._getitem_slice(key)\n", "\n", " if com.is_bool_indexer(key):\n", " key = check_bool_indexer(self.index, key)\n", " key = np.asarray(key, dtype=bool)\n", " return self._get_rows_with_mask(key)\n", "\n", " return self._get_with(key)\n", "\n" ] } ], "source": [ "psource(Pmf.__getitem__)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.16666666666666666" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6[1]" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.16666666666666666" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6[6]" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "If you use square brackets to look up a quantity that's not in the `Pmf`, you get a `KeyError`. " ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# d6[7]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Pmf` objects are mutable, but in general the result is not normalized." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "d7 = d6.copy()" ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.166667</td></tr>\n",
"<tr><th>2</th><td>0.166667</td></tr>\n",
"<tr><th>3</th><td>0.166667</td></tr>\n",
"<tr><th>4</th><td>0.166667</td></tr>\n",
"<tr><th>5</th><td>0.166667</td></tr>\n",
"<tr><th>6</th><td>0.166667</td></tr>\n",
"<tr><th>7</th><td>0.166667</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 0.166667\n", "2 0.166667\n", "3 0.166667\n", "4 0.166667\n", "5 0.166667\n", "6 0.166667\n", "7 0.166667\n", "dtype: float64" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d7[7] = 1/6\n", "d7" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.1666666666666665" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d7.sum()" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.1666666666666665" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d7.normalize()" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.0000000000000002" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d7.sum()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Statistics\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/7).\n", "\n", "`Pmf` overrides the statistics methods to compute `mean`, `median`, etc.\n", "\n", "These functions only work correctly if the `Pmf` is normalized." 
] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def mean(self):\n", " \"\"\"Computes expected value.\n", "\n", " Returns: float\n", " \"\"\"\n", " if not np.allclose(1, self.sum()):\n", " raise ValueError(\"Pmf must be normalized before computing mean\")\n", "\n", " if not pd.api.types.is_numeric_dtype(self.dtype):\n", " raise ValueError(\"mean is only defined for numeric data\")\n", "\n", " return np.sum(self.ps * self.qs)\n", "\n" ] } ], "source": [ "psource(Pmf.mean)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3.5" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.mean()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def var(self):\n", " \"\"\"Variance of a PMF.\n", "\n", " Returns: float\n", " \"\"\"\n", " m = self.mean()\n", " d = self.qs - m\n", " return np.sum(d**2 * self.ps)\n", "\n" ] } ], "source": [ "psource(Pmf.var)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "2.9166666666666665" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.var()" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def std(self):\n", " \"\"\"Standard deviation of a PMF.\n", "\n", " Returns: float\n", " \"\"\"\n", " return np.sqrt(self.var())\n", "\n" ] } ], "source": [ "psource(Pmf.std)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1.707825127659933" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.std()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Sampling\n", "\n", "For 
comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/8).\n", "\n", "`choice` chooses random values from the `Pmf`, following the API of `np.random.choice`." ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def choice(self, *args, **kwargs):\n", " \"\"\"Makes a random sample.\n", "\n", " Uses the probabilities as weights unless `p` is provided.\n", "\n", " Args:\n", " args: same as np.random.choice\n", " kwargs: same as np.random.choice\n", "\n", " Returns: NumPy array\n", " \"\"\"\n", " underride(kwargs, p=self.ps)\n", " return np.random.choice(self.qs, *args, **kwargs)\n", "\n" ] } ], "source": [ "psource(Pmf.choice)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([3, 3, 5, 6, 1, 5, 3, 6, 6, 1])" ] }, "execution_count": 46, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.choice(size=10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`sample` chooses random values from the `Pmf`, following the API of `pd.Series.sample`." ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def sample(self, *args, **kwargs):\n", " \"\"\"Samples with replacement using probabilities as weights.\n", "\n", " Uses the inverse CDF.\n", "\n", " Args:\n", " n: number of values\n", "\n", " Returns: NumPy array\n", " \"\"\"\n", " cdf = self.make_cdf()\n", " return cdf.sample(*args, **kwargs)\n", "\n" ] } ], "source": [ "psource(Pmf.sample)" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([3., 3., 5., 1., 1., 1., 3., 5., 3., 6.])" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.sample(n=10)" ] }, { "cell_type": "markdown", "metadata": {}, 
"source": [ "## Arithmetic\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/9).\n", "\n", "`Pmf` provides `add_dist`, which computes the distribution of the sum.\n", "\n", "The implementation uses outer products to compute the convolution of the two distributions." ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def add_dist(self, x):\n", " \"\"\"Computes the Pmf of the sum of values drawn from self and x.\n", "\n", " Args:\n", " x: Distribution, scalar, or sequence\n", "\n", " Returns: new Pmf\n", " \"\"\"\n", " if isinstance(x, Distribution):\n", " return self.convolve_dist(x, np.add.outer)\n", " else:\n", " return Pmf(self.ps.copy(), index=self.qs + x)\n", "\n" ] } ], "source": [ "psource(Pmf.add_dist)" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def convolve_dist(self, dist, ufunc):\n", " \"\"\"Convolve two distributions.\n", "\n", " Args:\n", " dist: Distribution\n", " ufunc: elementwise function for arrays\n", "\n", " Returns: new Pmf\n", " \"\"\"\n", " dist = dist.make_pmf()\n", " qs = ufunc(self.qs, dist.qs).flatten()\n", " ps = np.multiply.outer(self.ps, dist.ps).flatten()\n", " series = pd.Series(ps).groupby(qs).sum()\n", "\n", " return Pmf(series)\n", "\n" ] } ], "source": [ "psource(Pmf.convolve_dist)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's the distribution of the sum of two dice." ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>2</th><td>0.027778</td></tr>\n",
"<tr><th>3</th><td>0.055556</td></tr>\n",
"<tr><th>4</th><td>0.083333</td></tr>\n",
"<tr><th>5</th><td>0.111111</td></tr>\n",
"<tr><th>6</th><td>0.138889</td></tr>\n",
"<tr><th>7</th><td>0.166667</td></tr>\n",
"<tr><th>8</th><td>0.138889</td></tr>\n",
"<tr><th>9</th><td>0.111111</td></tr>\n",
"<tr><th>10</th><td>0.083333</td></tr>\n",
"<tr><th>11</th><td>0.055556</td></tr>\n",
"<tr><th>12</th><td>0.027778</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "2 0.027778\n", "3 0.055556\n", "4 0.083333\n", "5 0.111111\n", "6 0.138889\n", "7 0.166667\n", "8 0.138889\n", "9 0.111111\n", "10 0.083333\n", "11 0.055556\n", "12 0.027778\n", "dtype: float64" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6 = Pmf.from_seq([1,2,3,4,5,6])\n", "\n", "twice = d6.add_dist(d6)\n", "twice" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6.999999999999998" ] }, "execution_count": 52, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "twice.bar()\n", "decorate_dice('Two dice')\n", "twice.mean()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To add a constant to a distribution, you could construct a deterministic `Pmf`" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>2</th><td>0.25</td></tr>\n",
"<tr><th>3</th><td>0.25</td></tr>\n",
"<tr><th>4</th><td>0.25</td></tr>\n",
"<tr><th>5</th><td>0.25</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "2 0.25\n", "3 0.25\n", "4 0.25\n", "5 0.25\n", "dtype: float64" ] }, "execution_count": 53, "metadata": {}, "output_type": "execute_result" } ], "source": [ "const = Pmf.from_seq([1])\n", "d4.add_dist(const)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But `add_dist` also handles constants as a special case:" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>2</th><td>0.25</td></tr>\n",
"<tr><th>3</th><td>0.25</td></tr>\n",
"<tr><th>4</th><td>0.25</td></tr>\n",
"<tr><th>5</th><td>0.25</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "2 0.25\n", "3 0.25\n", "4 0.25\n", "5 0.25\n", "dtype: float64" ] }, "execution_count": 54, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4.add_dist(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Other arithmetic operations are also implemented" ] }, { "cell_type": "code", "execution_count": 55, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>-3</th><td>0.041667</td></tr>\n",
"<tr><th>-2</th><td>0.083333</td></tr>\n",
"<tr><th>-1</th><td>0.125000</td></tr>\n",
"<tr><th>0</th><td>0.166667</td></tr>\n",
"<tr><th>1</th><td>0.166667</td></tr>\n",
"<tr><th>2</th><td>0.166667</td></tr>\n",
"<tr><th>3</th><td>0.125000</td></tr>\n",
"<tr><th>4</th><td>0.083333</td></tr>\n",
"<tr><th>5</th><td>0.041667</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "-3 0.041667\n", "-2 0.083333\n", "-1 0.125000\n", " 0 0.166667\n", " 1 0.166667\n", " 2 0.166667\n", " 3 0.125000\n", " 4 0.083333\n", " 5 0.041667\n", "dtype: float64" ] }, "execution_count": 55, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.sub_dist(d4)" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>1</th><td>0.0625</td></tr>\n",
"<tr><th>2</th><td>0.1250</td></tr>\n",
"<tr><th>3</th><td>0.1250</td></tr>\n",
"<tr><th>4</th><td>0.1875</td></tr>\n",
"<tr><th>6</th><td>0.1250</td></tr>\n",
"<tr><th>8</th><td>0.1250</td></tr>\n",
"<tr><th>9</th><td>0.0625</td></tr>\n",
"<tr><th>12</th><td>0.1250</td></tr>\n",
"<tr><th>16</th><td>0.0625</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "1 0.0625\n", "2 0.1250\n", "3 0.1250\n", "4 0.1875\n", "6 0.1250\n", "8 0.1250\n", "9 0.0625\n", "12 0.1250\n", "16 0.0625\n", "dtype: float64" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4.mul_dist(d4)" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n",
"<table>\n",
"<thead><tr><th></th><th>probs</th></tr></thead>\n",
"<tbody>\n",
"<tr><th>0.250000</th><td>0.0625</td></tr>\n",
"<tr><th>0.333333</th><td>0.0625</td></tr>\n",
"<tr><th>0.500000</th><td>0.1250</td></tr>\n",
"<tr><th>0.666667</th><td>0.0625</td></tr>\n",
"<tr><th>0.750000</th><td>0.0625</td></tr>\n",
"<tr><th>1.000000</th><td>0.2500</td></tr>\n",
"<tr><th>1.333333</th><td>0.0625</td></tr>\n",
"<tr><th>1.500000</th><td>0.0625</td></tr>\n",
"<tr><th>2.000000</th><td>0.1250</td></tr>\n",
"<tr><th>3.000000</th><td>0.0625</td></tr>\n",
"<tr><th>4.000000</th><td>0.0625</td></tr>\n",
"</tbody>\n",
"</table>\n", "
" ], "text/plain": [ "0.250000 0.0625\n", "0.333333 0.0625\n", "0.500000 0.1250\n", "0.666667 0.0625\n", "0.750000 0.0625\n", "1.000000 0.2500\n", "1.333333 0.0625\n", "1.500000 0.0625\n", "2.000000 0.1250\n", "3.000000 0.0625\n", "4.000000 0.0625\n", "dtype: float64" ] }, "execution_count": 57, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4.div_dist(d4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Comparison operators\n", "\n", "`Pmf` implements comparison operators that return probabilities.\n", "\n", "You can compare a `Pmf` to a scalar:" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.3333333333333333" ] }, "execution_count": 58, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.lt_dist(3)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.75" ] }, "execution_count": 59, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4.ge_dist(2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or compare `Pmf` objects:" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.25" ] }, "execution_count": 60, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4.gt_dist(d6)" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.41666666666666663" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6.le_dist(d4)" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.16666666666666666" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4.eq_dist(d6)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Interestingly, this way of comparing distributions is [nontransitive]()." 
] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [], "source": [ "A = Pmf.from_seq([2, 2, 4, 4, 9, 9])\n", "B = Pmf.from_seq([1, 1, 6, 6, 8, 8])\n", "C = Pmf.from_seq([3, 3, 5, 5, 7, 7])" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5555555555555556" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "A.gt_dist(B)" ] }, { "cell_type": "code", "execution_count": 65, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5555555555555556" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "B.gt_dist(C)" ] }, { "cell_type": "code", "execution_count": 66, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.5555555555555556" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "C.gt_dist(A)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Joint distributions\n", "\n", "For comments or questions about this section, see [this issue](https://github.com/AllenDowney/EmpyricalDistributions/issues/10).\n", "\n", "`Pmf.make_joint` takes two `Pmf` objects and makes their joint distribution, assuming independence." ] }, { "cell_type": "code", "execution_count": 67, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def make_joint(self, other, **options):\n", " \"\"\"Make joint distribution (assuming independence).\n", "\n", " Args:\n", " other: Pmf\n", " options: passed to Pmf constructor\n", "\n", " Returns: new Pmf\n", " \"\"\"\n", " qs = pd.MultiIndex.from_product([self.qs, other.qs])\n", " ps = np.multiply.outer(self.ps, other.ps).flatten()\n", " return Pmf(ps, index=qs, **options)\n", "\n" ] } ], "source": [ "psource(Pmf.make_joint)" ] }, { "cell_type": "code", "execution_count": 68, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
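<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>0.25</td></tr>\n",
" <tr><th>2</th><td>0.25</td></tr>\n",
" <tr><th>3</th><td>0.25</td></tr>\n",
" <tr><th>4</th><td>0.25</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>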
" ], "text/plain": [ "1 0.25\n", "2 0.25\n", "3 0.25\n", "4 0.25\n", "Name: , dtype: float64" ] }, "execution_count": 68, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d4 = Pmf.from_seq(range(1,5))\n", "d4" ] }, { "cell_type": "code", "execution_count": 69, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>0.166667</td></tr>\n",
" <tr><th>2</th><td>0.166667</td></tr>\n",
" <tr><th>3</th><td>0.166667</td></tr>\n",
" <tr><th>4</th><td>0.166667</td></tr>\n",
" <tr><th>5</th><td>0.166667</td></tr>\n",
" <tr><th>6</th><td>0.166667</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ "1 0.166667\n", "2 0.166667\n", "3 0.166667\n", "4 0.166667\n", "5 0.166667\n", "6 0.166667\n", "Name: , dtype: float64" ] }, "execution_count": 69, "metadata": {}, "output_type": "execute_result" } ], "source": [ "d6 = Pmf.from_seq(range(1,7))\n", "d6" ] }, { "cell_type": "code", "execution_count": 70, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th rowspan=\"6\" valign=\"top\">1</th><th>1</th><td>0.041667</td></tr>\n",
" <tr><th>2</th><td>0.041667</td></tr>\n",
" <tr><th>3</th><td>0.041667</td></tr>\n",
" <tr><th>4</th><td>0.041667</td></tr>\n",
" <tr><th>5</th><td>0.041667</td></tr>\n",
" <tr><th>6</th><td>0.041667</td></tr>\n",
" <tr><th rowspan=\"6\" valign=\"top\">2</th><th>1</th><td>0.041667</td></tr>\n",
" <tr><th>2</th><td>0.041667</td></tr>\n",
" <tr><th>3</th><td>0.041667</td></tr>\n",
" <tr><th>4</th><td>0.041667</td></tr>\n",
" <tr><th>5</th><td>0.041667</td></tr>\n",
" <tr><th>6</th><td>0.041667</td></tr>\n",
" <tr><th rowspan=\"6\" valign=\"top\">3</th><th>1</th><td>0.041667</td></tr>\n",
" <tr><th>2</th><td>0.041667</td></tr>\n",
" <tr><th>3</th><td>0.041667</td></tr>\n",
" <tr><th>4</th><td>0.041667</td></tr>\n",
" <tr><th>5</th><td>0.041667</td></tr>\n",
" <tr><th>6</th><td>0.041667</td></tr>\n",
" <tr><th rowspan=\"6\" valign=\"top\">4</th><th>1</th><td>0.041667</td></tr>\n",
" <tr><th>2</th><td>0.041667</td></tr>\n",
" <tr><th>3</th><td>0.041667</td></tr>\n",
" <tr><th>4</th><td>0.041667</td></tr>\n",
" <tr><th>5</th><td>0.041667</td></tr>\n",
" <tr><th>6</th><td>0.041667</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>
\n", "
" ], "text/plain": [ "1 1 0.041667\n", " 2 0.041667\n", " 3 0.041667\n", " 4 0.041667\n", " 5 0.041667\n", " 6 0.041667\n", "2 1 0.041667\n", " 2 0.041667\n", " 3 0.041667\n", " 4 0.041667\n", " 5 0.041667\n", " 6 0.041667\n", "3 1 0.041667\n", " 2 0.041667\n", " 3 0.041667\n", " 4 0.041667\n", " 5 0.041667\n", " 6 0.041667\n", "4 1 0.041667\n", " 2 0.041667\n", " 3 0.041667\n", " 4 0.041667\n", " 5 0.041667\n", " 6 0.041667\n", "dtype: float64" ] }, "execution_count": 70, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joint = Pmf.make_joint(d4, d6)\n", "joint" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is a `Pmf` object that uses a MultiIndex to represent the values." ] }, { "cell_type": "code", "execution_count": 71, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MultiIndex([(1, 1),\n", " (1, 2),\n", " (1, 3),\n", " (1, 4),\n", " (1, 5),\n", " (1, 6),\n", " (2, 1),\n", " (2, 2),\n", " (2, 3),\n", " (2, 4),\n", " (2, 5),\n", " (2, 6),\n", " (3, 1),\n", " (3, 2),\n", " (3, 3),\n", " (3, 4),\n", " (3, 5),\n", " (3, 6),\n", " (4, 1),\n", " (4, 2),\n", " (4, 3),\n", " (4, 4),\n", " (4, 5),\n", " (4, 6)],\n", " )" ] }, "execution_count": 71, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joint.index" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you ask for the `qs`, you get an array of pairs:" ] }, { "cell_type": "code", "execution_count": 72, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([(1, 1), (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (2, 1), (2, 2),\n", " (2, 3), (2, 4), (2, 5), (2, 6), (3, 1), (3, 2), (3, 3), (3, 4),\n", " (3, 5), (3, 6), (4, 1), (4, 2), (4, 3), (4, 4), (4, 5), (4, 6)],\n", " dtype=object)" ] }, "execution_count": 72, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joint.qs" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can select elements using tuples:" ] }, { "cell_type": "code", 
"execution_count": 73, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0.041666666666666664" ] }, "execution_count": 73, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joint[1,1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can get unnormalized conditional distributions by selecting on different axes:" ] }, { "cell_type": "code", "execution_count": 74, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>0.041667</td></tr>\n",
" <tr><th>2</th><td>0.041667</td></tr>\n",
" <tr><th>3</th><td>0.041667</td></tr>\n",
" <tr><th>4</th><td>0.041667</td></tr>\n",
" <tr><th>5</th><td>0.041667</td></tr>\n",
" <tr><th>6</th><td>0.041667</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ "1 0.041667\n", "2 0.041667\n", "3 0.041667\n", "4 0.041667\n", "5 0.041667\n", "6 0.041667\n", "dtype: float64" ] }, "execution_count": 74, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Pmf(joint[1])" ] }, { "cell_type": "code", "execution_count": 75, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
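<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>0.041667</td></tr>\n",
" <tr><th>2</th><td>0.041667</td></tr>\n",
" <tr><th>3</th><td>0.041667</td></tr>\n",
" <tr><th>4</th><td>0.041667</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>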
\n", "
" ], "text/plain": [ "1 0.041667\n", "2 0.041667\n", "3 0.041667\n", "4 0.041667\n", "dtype: float64" ] }, "execution_count": 75, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Pmf(joint.loc[:, 1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`Pmf` also provides `conditional(i, val)`, which returns the normalized conditional distribution where the value on level `i` is `val`." ] }, { "cell_type": "code", "execution_count": 76, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def conditional(self, i, val, name=None):\n", " \"\"\"Gets the conditional distribution of the indicated variable.\n", "\n", " Args:\n", " i: index of the variable we're conditioning on\n", " val: the value the ith variable has to have\n", " name: string\n", "\n", " Returns: Pmf\n", " \"\"\"\n", " pmf = Pmf(self.xs(key=val, level=i), copy=True, name=name)\n", " pmf.normalize()\n", " return pmf\n", "\n" ] } ], "source": [ "psource(joint.conditional)" ] }, { "cell_type": "code", "execution_count": 77, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>0.166667</td></tr>\n",
" <tr><th>2</th><td>0.166667</td></tr>\n",
" <tr><th>3</th><td>0.166667</td></tr>\n",
" <tr><th>4</th><td>0.166667</td></tr>\n",
" <tr><th>5</th><td>0.166667</td></tr>\n",
" <tr><th>6</th><td>0.166667</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ "1 0.166667\n", "2 0.166667\n", "3 0.166667\n", "4 0.166667\n", "5 0.166667\n", "6 0.166667\n", "dtype: float64" ] }, "execution_count": 77, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joint.conditional(0, 1)" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
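<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>0.25</td></tr>\n",
" <tr><th>2</th><td>0.25</td></tr>\n",
" <tr><th>3</th><td>0.25</td></tr>\n",
" <tr><th>4</th><td>0.25</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>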
\n", "
" ], "text/plain": [ "1 0.25\n", "2 0.25\n", "3 0.25\n", "4 0.25\n", "dtype: float64" ] }, "execution_count": 78, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joint.conditional(1, 1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It also provides `marginal(i)`, which returns the marginal distribution along axis `i`." ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " def marginal(self, i, name=None):\n", " \"\"\"Gets the marginal distribution of the indicated variable.\n", "\n", " Args:\n", " i: index of the variable we want\n", " name: string\n", "\n", " Returns: Pmf\n", " \"\"\"\n", " # The following is deprecated now\n", " # return Pmf(self.sum(level=i))\n", "\n", " # here's the new version\n", " return Pmf(self.groupby(level=i).sum(), name=name)\n", "\n" ] } ], "source": [ "psource(Pmf.marginal)" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
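<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>0.25</td></tr>\n",
" <tr><th>2</th><td>0.25</td></tr>\n",
" <tr><th>3</th><td>0.25</td></tr>\n",
" <tr><th>4</th><td>0.25</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>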
" ], "text/plain": [ "1 0.25\n", "2 0.25\n", "3 0.25\n", "4 0.25\n", "dtype: float64" ] }, "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joint.marginal(0)" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
<div>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\"><th></th><th>probs</th></tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr><th>1</th><td>0.166667</td></tr>\n",
" <tr><th>2</th><td>0.166667</td></tr>\n",
" <tr><th>3</th><td>0.166667</td></tr>\n",
" <tr><th>4</th><td>0.166667</td></tr>\n",
" <tr><th>5</th><td>0.166667</td></tr>\n",
" <tr><th>6</th><td>0.166667</td></tr>\n",
" </tbody>\n",
"</table>\n",
"</div>
" ], "text/plain": [ "1 0.166667\n", "2 0.166667\n", "3 0.166667\n", "4 0.166667\n", "5 0.166667\n", "6 0.166667\n", "dtype: float64" ] }, "execution_count": 81, "metadata": {}, "output_type": "execute_result" } ], "source": [ "joint.marginal(1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are some ways of iterating through a joint distribution." ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 1)\n", "(1, 2)\n", "(1, 3)\n", "(1, 4)\n", "(1, 5)\n", "(1, 6)\n", "(2, 1)\n", "(2, 2)\n", "(2, 3)\n", "(2, 4)\n", "(2, 5)\n", "(2, 6)\n", "(3, 1)\n", "(3, 2)\n", "(3, 3)\n", "(3, 4)\n", "(3, 5)\n", "(3, 6)\n", "(4, 1)\n", "(4, 2)\n", "(4, 3)\n", "(4, 4)\n", "(4, 5)\n", "(4, 6)\n" ] } ], "source": [ "for q in joint.qs:\n", " print(q)" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n", "0.041666666666666664\n" ] } ], "source": [ "for p in joint.ps:\n", " print(p)" ] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(1, 1) 0.041666666666666664\n", "(1, 2) 0.041666666666666664\n", "(1, 3) 0.041666666666666664\n", "(1, 4) 0.041666666666666664\n", "(1, 5) 0.041666666666666664\n", "(1, 6) 0.041666666666666664\n", "(2, 
1) 0.041666666666666664\n", "(2, 2) 0.041666666666666664\n", "(2, 3) 0.041666666666666664\n", "(2, 4) 0.041666666666666664\n", "(2, 5) 0.041666666666666664\n", "(2, 6) 0.041666666666666664\n", "(3, 1) 0.041666666666666664\n", "(3, 2) 0.041666666666666664\n", "(3, 3) 0.041666666666666664\n", "(3, 4) 0.041666666666666664\n", "(3, 5) 0.041666666666666664\n", "(3, 6) 0.041666666666666664\n", "(4, 1) 0.041666666666666664\n", "(4, 2) 0.041666666666666664\n", "(4, 3) 0.041666666666666664\n", "(4, 4) 0.041666666666666664\n", "(4, 5) 0.041666666666666664\n", "(4, 6) 0.041666666666666664\n" ] } ], "source": [ "for q, p in joint.items():\n", " print(q, p)" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "1 1 0.041666666666666664\n", "1 2 0.041666666666666664\n", "1 3 0.041666666666666664\n", "1 4 0.041666666666666664\n", "1 5 0.041666666666666664\n", "1 6 0.041666666666666664\n", "2 1 0.041666666666666664\n", "2 2 0.041666666666666664\n", "2 3 0.041666666666666664\n", "2 4 0.041666666666666664\n", "2 5 0.041666666666666664\n", "2 6 0.041666666666666664\n", "3 1 0.041666666666666664\n", "3 2 0.041666666666666664\n", "3 3 0.041666666666666664\n", "3 4 0.041666666666666664\n", "3 5 0.041666666666666664\n", "3 6 0.041666666666666664\n", "4 1 0.041666666666666664\n", "4 2 0.041666666666666664\n", "4 3 0.041666666666666664\n", "4 4 0.041666666666666664\n", "4 5 0.041666666666666664\n", "4 6 0.041666666666666664\n" ] } ], "source": [ "for (q1, q2), p in joint.items():\n", " print(q1, q2, p)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Copyright 2019 Allen Downey\n", "\n", "BSD 3-clause license: https://opensource.org/licenses/BSD-3-Clause" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 2 }