<h1>ELI5 numpy axes</h1>
<p>Nikita Evdokimov, evdokimovn’s blog, 2021-09-27</p>
<p>Axes in numpy can be a little tricky for beginners. Usually there is no problem with axes when they are used for indexing; the trouble starts when we begin working with numpy functions. By the end of this post you should have an intuition that lets you use axes effectively in numpy operations.</p>
<p>First of all, what is a numpy axis? An axis is nothing more than another term for an array dimension. As with any coordinate system, the number of axes equals the dimensionality.</p>
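<p>A quick way to check this in code is the <code class="language-plaintext highlighter-rouge">ndim</code> attribute, which reports the number of axes, while <code class="language-plaintext highlighter-rouge">shape</code> has one entry per axis. The arrays <code class="language-plaintext highlighter-rouge">a</code> and <code class="language-plaintext highlighter-rouge">b</code> below are just illustrative examples:</p>

```python
import numpy as np

# A 2D matrix (two axes) and a 3D array (three axes).
a = np.array([[1, 2, 3], [3, 2, 1]])
b = np.zeros((2, 3, 4))

# ndim is the number of axes; shape has one entry per axis.
print(a.ndim, a.shape)  # 2 (2, 3)
print(b.ndim, b.shape)  # 3 (2, 3, 4)
```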
<p>In Cartesian coordinate systems, axes are usually referred to by letters: \(X\), \(Y\) and \(Z\) are the three axes of three-dimensional space. In numpy, however, axes are referred to by numbers rather than letters, with the first axis being \(0\) (no surprises here).</p>
<p>It’s not enough to know how many axes there are; we must also know which direction each of them corresponds to. As an example, let’s imagine an arbitrary 2D matrix. Since it has two dimensions, there must be two axes. These dimensions are usually referred to as rows and columns, and the axes are naturally closely related to them. Axis \(0\) corresponds to rows and points in the direction in which the row index increases (downwards), while axis \(1\) corresponds to columns and points in the direction in which the column index increases (to the right).</p>
<p><img src="/assets/posts/numpy-axes/MatrixWithAxes.png" alt="MatrixWithAxes" /></p>
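<p>A small sketch of the same picture in code: the length of axis \(0\) is the number of rows and the length of axis \(1\) is the number of columns. Stepping along axis \(0\) moves down a row, stepping along axis \(1\) moves right a column. The variable names are only for illustration.</p>

```python
import numpy as np

a = np.array([[1, 2, 3], [3, 2, 1]])

# Axis 0 counts rows, axis 1 counts columns.
n_rows, n_cols = a.shape
print(n_rows)  # 2
print(n_cols)  # 3

# One step along axis 0 moves down a row; one step along axis 1 moves right a column.
print(a[0, 0], a[1, 0])  # 1 3
print(a[0, 0], a[0, 1])  # 1 2
```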
<p>A question you might ask: “If we were to add another dimension, in which direction would axis \(2\) point?”. The answer is quite intuitive: it points in the direction of increasing depth. And if we add one more? That one is harder to picture. For the first three axes it helps to remember that numpy axes behave like the axes of a coordinate system, so axes \(0\), \(1\) and \(2\) can be pictured as a <a href="https://en.wikipedia.org/wiki/Right-hand_rule">right-handed coordinate system</a>: knowing the directions of axes \(0\) and \(1\) is enough to deduce the direction of axis \(2\). Beyond three dimensions there is no physical direction left to point in, and it is easier to think of each extra axis simply as one more index.</p>
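<p>To make the third axis concrete, here is a sketch with a hypothetical 3D array <code class="language-plaintext highlighter-rouge">c</code> of shape <code class="language-plaintext highlighter-rouge">(2, 3, 4)</code>: 2 rows, 3 columns and 4 depth layers. Fixing a row and a column and walking along axis \(2\) goes “into the depth”, while fixing a depth index leaves an ordinary 2D matrix.</p>

```python
import numpy as np

# Hypothetical 3D array: 2 rows (axis 0), 3 columns (axis 1), 4 depth layers (axis 2).
c = np.arange(24).reshape(2, 3, 4)
print(c.shape)  # (2, 3, 4)

# Fix row 0 and column 0, then walk along axis 2 (into the depth).
print(c[0, 0, :])  # [0 1 2 3]

# Fixing a depth index (axis 2) leaves a regular 2x3 matrix.
print(c[:, :, 0].shape)  # (2, 3)
```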
<h2 id="working-with-axes">Working with axes</h2>
<p>Now, let’s look at how to work with axes. There are two common things we do with them: indexing and passing them to numpy operations.</p>
<h3 id="indexing">Indexing</h3>
<p>Using axes to index into a numpy array is straightforward and no different from indexing multidimensional arrays in other languages. Again taking a 2D matrix as an example, to choose an element we must specify a row (axis \(0\)) and a column (axis \(1\)). Let’s look at this matrix</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="n">np</span>

<span class="c1"># a 2x3 matrix: 2 rows (axis 0) and 3 columns (axis 1)</span>
<span class="n">a</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">3</span><span class="p">],[</span><span class="mi">3</span><span class="p">,</span><span class="mi">2</span><span class="p">,</span><span class="mi">1</span><span class="p">]])</span>
<span class="k">print</span><span class="p">(</span><span class="n">a</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"With shape"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="s">"""
[[1 2 3]
[3 2 1]]
With shape
(2, 3)
"""</span>
</code></pre></div></div>
<p>and find the element in the second row, last column</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="n">a</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>  <span class="c1"># row 1 (axis 0), column 2 (axis 1)</span>
<span class="s">"""
1
"""</span>
</code></pre></div></div>
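<p>Indexing also works with fewer indices or with slices. A short sketch: leaving out the column index selects a whole row, and a <code class="language-plaintext highlighter-rouge">:</code> in the row position keeps all of axis \(0\), which selects a whole column.</p>

```python
import numpy as np

a = np.array([[1, 2, 3], [3, 2, 1]])

# One index per axis: row (axis 0), then column (axis 1).
print(a[1, 2])    # 1
# Only the row index given: the whole second row.
print(a[1])       # [3 2 1]
# ":" keeps all of axis 0, so this is the whole last column.
print(a[:, 2])    # [3 1]
# Negative indices count from the end of an axis.
print(a[-1, -1])  # 1
```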
<h3 id="operations">Operations</h3>
<p>There are dozens of different <a href="https://numpy.org/doc/stable/reference/routines.html">operations</a> that can be applied to numpy arrays, and many of them accept <code class="language-plaintext highlighter-rouge">axis</code> as one of their parameters. This is the part that some numpy beginners find confusing.</p>
<p>Let’s look at a very common one, <a href="https://numpy.org/doc/stable/reference/generated/numpy.sum.html"><code class="language-plaintext highlighter-rouge">sum</code></a>, an aggregation operation that simply sums the elements of an array. Taking the 2D matrix from the previous example, let’s try to calculate the sum of each row.</p>
<p>Usually the thought process goes like this: “I need to calculate the sum of each row. When indexing, axis \(0\) corresponds to rows, hence I need to call <code class="language-plaintext highlighter-rouge">sum</code> with <code class="language-plaintext highlighter-rouge">axis=0</code>.”</p>
<p>Let’s do exactly that and see what we get!</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">row_sums</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">0</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">row_sums</span><span class="p">)</span>
<span class="s">"""
[4 4 4]
"""</span>
</code></pre></div></div>
<p>Well, this is definitely not what we wanted. Since there are two rows, we expect to get only 2 values, but instead we got 3, which is the number of columns. So, when indexing, axis \(0\) corresponds to rows, but when used with operations it corresponds to columns?</p>
<p>Yes and no.</p>
<p>Let’s forget about aggregated values for a second and see what shape we get when we run <code class="language-plaintext highlighter-rouge">sum</code> with different axes as arguments.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="s">"Shape of original matrix"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Shape of sum with axis = 0"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">0</span><span class="p">).</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Shape of sum with axis = 1"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">1</span><span class="p">).</span><span class="n">shape</span><span class="p">)</span>
<span class="s">"""
Shape of original matrix
(2, 3)
Shape of sum with axis = 0
(3,)
Shape of sum with axis = 1
(2,)
"""</span>
</code></pre></div></div>
<p>Notice that it looks like we need to pass <code class="language-plaintext highlighter-rouge">1</code> as the <code class="language-plaintext highlighter-rouge">axis</code> argument, because that is the argument that gives us two values. It also looks as if the axis specified by the <code class="language-plaintext highlighter-rouge">axis</code> argument is collapsed.</p>
<p>This is much closer to the truth.</p>
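<p>The same pattern shows up with more dimensions. A small check, using an arbitrary 3D array of ones named <code class="language-plaintext highlighter-rouge">b</code>, suggests that <code class="language-plaintext highlighter-rouge">sum</code> drops exactly the axis it is given and keeps the others:</p>

```python
import numpy as np

# Arbitrary 3D example: axes of length 2, 3 and 4.
b = np.ones((2, 3, 4))

# For each axis, the result's shape is the input shape with that entry removed.
for axis in range(b.ndim):
    print(axis, np.sum(b, axis=axis).shape)
# 0 (3, 4)
# 1 (2, 4)
# 2 (2, 3)
```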
<p>To get the full picture, let’s look at another operation: <a href="https://numpy.org/doc/stable/reference/generated/numpy.repeat.html"><code class="language-plaintext highlighter-rouge">repeat</code></a>. It repeats each element of an array a specified number of times. Like <code class="language-plaintext highlighter-rouge">sum</code>, it accepts <code class="language-plaintext highlighter-rouge">axis</code> as one of its parameters. Let’s set up <code class="language-plaintext highlighter-rouge">repeat</code> to double the number of elements and run it with both axes, again focusing on the shape rather than the result itself.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">print</span><span class="p">(</span><span class="s">"Shape of original matrix"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">a</span><span class="p">.</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Shape of repeat with axis = 0"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">0</span><span class="p">).</span><span class="n">shape</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="s">"Shape of repeat with axis = 1"</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">np</span><span class="p">.</span><span class="n">repeat</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">1</span><span class="p">).</span><span class="n">shape</span><span class="p">)</span>
<span class="s">"""
Shape of original matrix
(2, 3)
Shape of repeat with axis = 0
(4, 3)
Shape of repeat with axis = 1
(2, 6)
"""</span>
</code></pre></div></div>
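<p>For completeness, here are the actual values, not just the shapes. In this sketch <code class="language-plaintext highlighter-rouge">r0</code> and <code class="language-plaintext highlighter-rouge">r1</code> are just illustrative names: with <code class="language-plaintext highlighter-rouge">axis=0</code> each row is copied downwards, with <code class="language-plaintext highlighter-rouge">axis=1</code> each column is copied to the right.</p>

```python
import numpy as np

a = np.array([[1, 2, 3], [3, 2, 1]])

# axis=0: every row is repeated, so the matrix grows downwards.
r0 = np.repeat(a, 2, axis=0)
print(r0)
# [[1 2 3]
#  [1 2 3]
#  [3 2 1]
#  [3 2 1]]

# axis=1: every column is repeated, so the matrix grows to the right.
r1 = np.repeat(a, 2, axis=1)
print(r1)
# [[1 1 2 2 3 3]
#  [3 3 2 2 1 1]]
```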
<p>Look at this! With <code class="language-plaintext highlighter-rouge">repeat</code>, the <code class="language-plaintext highlighter-rouge">axis</code> parameter does what we expect: with <code class="language-plaintext highlighter-rouge">axis = 0</code> there are twice as many rows, and with <code class="language-plaintext highlighter-rouge">axis = 1</code> there are twice as many columns. Does this mean that <code class="language-plaintext highlighter-rouge">axis</code> refers to different notions in <code class="language-plaintext highlighter-rouge">sum</code> and <code class="language-plaintext highlighter-rouge">repeat</code>?</p>
<p><strong>No!</strong></p>
<p>In both cases <code class="language-plaintext highlighter-rouge">axis</code> controls the <strong>direction along which</strong> an operation is applied, and the behaviours that look different at first glance are nothing more than an artefact of the type of operation: <code class="language-plaintext highlighter-rouge">sum</code> contracts the matrix, while <code class="language-plaintext highlighter-rouge">repeat</code> expands it. If numpy kept the collapsed dimension instead of dropping it, <code class="language-plaintext highlighter-rouge">np.sum(a, axis = 0).shape</code> would be <code class="language-plaintext highlighter-rouge">(1, 3)</code> and <code class="language-plaintext highlighter-rouge">np.sum(a, axis = 1).shape</code> would be <code class="language-plaintext highlighter-rouge">(2, 1)</code>. Now there is no discrepancy: just as with <code class="language-plaintext highlighter-rouge">repeat</code>, axis \(0\) changes the number of rows and axis \(1\) changes the number of columns.</p>
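<p>numpy can in fact keep the collapsed axis: reductions such as <code class="language-plaintext highlighter-rouge">np.sum</code> accept a <code class="language-plaintext highlighter-rouge">keepdims</code> argument, and with <code class="language-plaintext highlighter-rouge">keepdims=True</code> the summed axis stays in the result with length 1. A short check of the shapes described above (the variable names are illustrative):</p>

```python
import numpy as np

a = np.array([[1, 2, 3], [3, 2, 1]])

# keepdims=True keeps the collapsed axis with length 1 instead of dropping it.
col_sums = np.sum(a, axis=0, keepdims=True)
row_sums = np.sum(a, axis=1, keepdims=True)
print(col_sums.shape)  # (1, 3)
print(row_sums.shape)  # (2, 1)
print(row_sums)
# [[6]
#  [6]]
```

<p>Keeping the axis is also handy when the result should later broadcast back against the original array, for example to divide every row by its own sum.</p>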
<p>In fact, if we go back to the documentation of <a href="https://numpy.org/doc/stable/reference/generated/numpy.sum.html?highlight=axis"><code class="language-plaintext highlighter-rouge">sum</code></a> and <a href="https://numpy.org/doc/stable/reference/generated/numpy.repeat.html?highlight=axis"><code class="language-plaintext highlighter-rouge">repeat</code></a> and read what the <code class="language-plaintext highlighter-rouge">axis</code> parameter means, we see the phrase “axis along which” in both cases. It is such an important concept that you can even find it in the <a href="https://numpy.org/doc/stable/glossary.html#term-along-an-axis">glossary</a>!</p>
<p>With this newfound understanding, let’s go back to where we started: finding the sum of each row. Now we know that the correct <code class="language-plaintext highlighter-rouge">axis</code> to use is 1.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">row_sums</span> <span class="o">=</span> <span class="n">np</span><span class="p">.</span><span class="nb">sum</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">axis</span> <span class="o">=</span> <span class="mi">1</span><span class="p">)</span>
<span class="k">print</span><span class="p">(</span><span class="n">row_sums</span><span class="p">)</span>
<span class="s">"""
[6 6]
"""</span>
</code></pre></div></div>
<p>which is exactly the correct answer.</p>
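<p>One way to convince yourself of the “along which” reading is to spell the two sums out as plain Python loops. This is only an illustrative sketch, not how numpy computes them: summing along axis \(1\) walks across the columns of each row, summing along axis \(0\) walks down the rows of each column.</p>

```python
import numpy as np

a = np.array([[1, 2, 3], [3, 2, 1]])
n_rows, n_cols = a.shape

# "Along axis 1": for each row i, walk across the columns j and add them up.
loop_row_sums = [sum(int(a[i, j]) for j in range(n_cols)) for i in range(n_rows)]
# "Along axis 0": for each column j, walk down the rows i and add them up.
loop_col_sums = [sum(int(a[i, j]) for i in range(n_rows)) for j in range(n_cols)]

print(loop_row_sums)  # [6, 6]
print(loop_col_sums)  # [4, 4, 4]
print(np.array_equal(loop_row_sums, np.sum(a, axis=1)))  # True
```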
<p>Visually it looks like this: the <strong>left</strong> animation shows <code class="language-plaintext highlighter-rouge">np.sum(a, axis = 0)</code> and the <strong>right</strong> animation shows <code class="language-plaintext highlighter-rouge">np.sum(a, axis = 1)</code>.</p>
<table>
<tbody>
<tr>
<td><img src="/assets/posts/numpy-axes/SumAnimationAxis0.gif" alt="SumAnimationAxis0" /></td>
<td><img src="/assets/posts/numpy-axes/SumAnimationAxis1.gif" alt="SumAnimationAxis1" /></td>
</tr>
</tbody>
</table>
<p>With all that, the one thing you need to remember is</p>
<blockquote>
<p>the axis argument controls the axis along which an operation is applied</p>
</blockquote>
<h4 id="several-axis">Several axes</h4>
<p>Sometimes <code class="language-plaintext highlighter-rouge">axis</code> accepts not only an integer (a single axis) but also a tuple (several axes). Even though there is more than one axis, the concept is the same: the operation is applied along each of the axes listed in the tuple.</p>
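<p>A sketch with <code class="language-plaintext highlighter-rouge">sum</code>, which documents a tuple of ints as a valid <code class="language-plaintext highlighter-rouge">axis</code>. The 3D array <code class="language-plaintext highlighter-rouge">c</code> is an arbitrary example. Summing over a tuple of axes gives the same result as summing along one listed axis and then the other; when doing it step by step, reducing the higher-numbered axis first keeps the remaining axis numbers unchanged.</p>

```python
import numpy as np

a = np.array([[1, 2, 3], [3, 2, 1]])
# Summing along both axes of the matrix collapses everything.
total = np.sum(a, axis=(0, 1))
print(total)  # 12

# Arbitrary 3D example of shape (2, 3, 4).
c = np.arange(24).reshape(2, 3, 4)
s = np.sum(c, axis=(0, 2))
print(s.shape)  # (3,)

# Step by step: reduce axis 2 first, then axis 0 (its number is unaffected).
step_by_step = np.sum(np.sum(c, axis=2), axis=0)
print(np.array_equal(s, step_by_step))  # True
```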
<h3 id="default-value">Default value</h3>
<p><code class="language-plaintext highlighter-rouge">axis</code> is an optional argument, which means that if you don’t supply a value, a default one will be used. For many operations the default value is <code class="language-plaintext highlighter-rouge">None</code>, which usually (<a href="https://numpy.org/doc/stable/reference/generated/numpy.transpose.html">not always</a>) means the operation will be applied across all axes. Sometimes this is what you want, but most of the time it is not. Therefore it is important to check what the <code class="language-plaintext highlighter-rouge">axis</code> argument controls and what its default value corresponds to.</p>
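<p>To illustrate why checking the default pays off, here are three functions called without <code class="language-plaintext highlighter-rouge">axis</code>, based on their documented defaults: <code class="language-plaintext highlighter-rouge">sum</code> reduces over every axis, <code class="language-plaintext highlighter-rouge">repeat</code> flattens its input first, and <code class="language-plaintext highlighter-rouge">sort</code> defaults to <code class="language-plaintext highlighter-rouge">axis=-1</code> (the last axis) rather than <code class="language-plaintext highlighter-rouge">None</code>. The variable names are just for this sketch.</p>

```python
import numpy as np

a = np.array([[1, 2, 3], [3, 2, 1]])

# sum: axis=None (the default) reduces over all axes.
total = np.sum(a)
print(total)  # 12

# repeat: axis=None flattens the input first, so the result is 1D.
flat = np.repeat(a, 2)
print(flat.shape)  # (12,)

# sort: the default is axis=-1, so each row is sorted separately.
sorted_rows = np.sort(a)
print(sorted_rows)
# [[1 2 3]
#  [1 2 3]]
```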
<h2 id="tldr">TLDR</h2>
<p>Axis is another term for dimension. Axes are referred to by numbers, the first axis being \(0\). Axis \(0\) corresponds to rows (height), axis \(1\) to columns (width), axis \(2\) to depth, etc. When indexing, axes behave as expected: axis \(0\) chooses a row, axis \(1\) chooses a column, etc. When applying an operation to a numpy array, the <code class="language-plaintext highlighter-rouge">axis</code> argument controls the axis along which the operation is applied.</p>