diff --git a/src/spike_sort/ui/plotting.py b/src/spike_sort/ui/plotting.py index 1b00318..f85c184 100644 --- a/src/spike_sort/ui/plotting.py +++ b/src/spike_sort/ui/plotting.py @@ -139,11 +139,30 @@ def featuresgraph(features_dict, color='k', size=1, datarange=None, fig=None): _, n_feats = features.shape if fig is None: fig = plt.gcf() - axes = [[fig.add_subplot(n_feats, n_feats, i * n_feats + j + 1) - for i in range(n_feats)] for j in range(n_feats)] + axes = np.empty((n_feats, n_feats), dtype=object) + axes.fill(None) + + #fill first row + axes[0, 0] = fig.add_subplot(n_feats, n_feats, 1) + for i in range(1, n_feats): + axes[i, 0] = fig.add_subplot(n_feats, n_feats, i * n_feats + 1, sharex=axes[0, 0]) + #then first column + axes[0, 1] = fig.add_subplot(n_feats, n_feats, 2) + for i in range(2, n_feats): + axes[0, i] = fig.add_subplot(n_feats, n_feats, i + 1, sharey=axes[0, 1]) + #and now the rest + for i in range(1, n_feats): + for j in range(1, n_feats): + if i == j: + axes[i, j] = fig.add_subplot(n_feats, n_feats, i * (n_feats+1) + 1, + sharex=axes[0, j]) + else: + axes[i, j] = fig.add_subplot(n_feats, n_feats, i*n_feats + j + 1, + sharex=axes[0, j], sharey=axes[i, 0]) + for i in range(n_feats): for j in range(n_feats): - ax = axes[i][j] + ax = axes[j, i] if i != j: ax.plot(features[:, i], features[:, j], ".", @@ -168,10 +187,10 @@ def featuresgraph(features_dict, color='k', size=1, datarange=None, fig=None): ax.yaxis.label.set_visible(False) for i in range(n_feats): - ax = axes[i][0] - ax.xaxis.label.set_visible(True) - ax = axes[0][i] + ax = axes[i, 0] ax.yaxis.label.set_visible(True) + ax = axes[0, i] + ax.xaxis.label.set_visible(True) def legend(labels, colors=None, ax=None):