Matplotlib - TypeError: 'Axes' object is not subscriptable

ZhangZhihui's Blog / 2024-10-23 / Original post

Code that raises the error:

import matplotlib.pyplot as plt

fig, axs = plt.subplots(n_filters, n_in_channels, figsize=figsize)
print(axs[0, 0])  # TypeError: 'Axes' object is not subscriptable

This happens when n_filters = 1 and n_in_channels = 1. The squeeze parameter of plt.subplots defaults to True, which drops the extra dimensions from the returned array, so for a 1x1 grid axs is a single Axes object rather than a 2D array of Axes.
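
To make the squeezing behavior concrete, here is a minimal sketch that records what plt.subplots returns for a few grid sizes with the default squeeze=True. The helper name describe and the choice of the Agg backend are assumptions made for this illustration, not part of the original code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def describe(nrows, ncols):
    """Return the shape of the array from plt.subplots, or None for a bare Axes."""
    fig, axs = plt.subplots(nrows, ncols)  # squeeze=True by default
    plt.close(fig)
    return axs.shape if isinstance(axs, np.ndarray) else None


# (hypothetical) grid sizes chosen to cover every squeezing case
shapes = {grid: describe(*grid) for grid in [(1, 1), (1, 3), (3, 1), (2, 3)]}
print(shapes)
# {(1, 1): None, (1, 3): (3,), (3, 1): (3,), (2, 3): (2, 3)}
```

Only the 2x3 grid keeps two dimensions, so axs[0, 0] works there but fails for the 1x1 grid and also for any single-row or single-column grid, where axs is 1D.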

To always get a 2D array back, whatever the grid size, pass squeeze=False:

fig, axs = plt.subplots(n_filters, n_in_channels, figsize=figsize, squeeze=False)
print(axs[0, 0])
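
As a usage sketch, squeeze=False lets the same double loop handle every grid size, including 1x1. The function name plot_filters, the (n_filters, n_in_channels, height, width) weight layout, and the random input are assumptions for illustration; the original post does not show how the grid is filled.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_filters(weights, figsize=(4, 4)):
    """Draw each kernel of a (n_filters, n_in_channels, h, w) array in its own cell."""
    n_filters, n_in_channels = weights.shape[:2]
    # squeeze=False guarantees axs has shape (n_filters, n_in_channels)
    fig, axs = plt.subplots(n_filters, n_in_channels, figsize=figsize, squeeze=False)
    for i in range(n_filters):
        for j in range(n_in_channels):
            axs[i, j].imshow(weights[i, j], cmap="gray")
            axs[i, j].set_axis_off()
    return fig, axs


# A single filter with a single input channel, the case that raised the TypeError
fig, axs = plot_filters(np.random.rand(1, 1, 3, 3))
print(axs.shape)  # (1, 1)
```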

Alternatively, convert axs to a 2D NumPy array before indexing or slicing:

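import numpy as np

axs = np.atleast_2d(axs)  # wraps a bare Axes or a 1D array into a 2D array
axs = axs.reshape(n_filters, n_in_channels)  # restore the (rows, columns) layout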
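
The conversion can be wrapped in a small helper when the call to plt.subplots lives in code you cannot change. This is a sketch under that assumption; the name as_axes_grid is hypothetical and not a Matplotlib function.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes


def as_axes_grid(axs, nrows, ncols):
    """Normalize the squeezed result of plt.subplots to shape (nrows, ncols)."""
    # atleast_2d turns a bare Axes into a (1, 1) array and a 1D array into (1, n);
    # reshape then restores the intended row/column layout.
    return np.atleast_2d(axs).reshape(nrows, ncols)


fig, axs = plt.subplots(1, 1)      # squeezed: a single Axes
grid = as_axes_grid(axs, 1, 1)
print(grid.shape)                  # (1, 1)
print(isinstance(grid[0, 0], Axes))  # True
```

The reshape matters for single-column grids: a squeezed 3x1 result has shape (3,), which atleast_2d alone would turn into (1, 3) rather than (3, 1).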