我正在嘗試生成 16 個子圖,我的最終目標是最終圖的大小為 8 x 2,我的代碼如下所示:
def visualize_t2t(token_dict, scores):
fig = plt.figure(figsize=(50, 50))
for idx, scores in enumerate(scores):
scores_np = np.array(scores)
ax = fig.add_subplot(12, 12, idx 1)
# append the attention weights
im = ax.imshow(scores, cmap='viridis')
fontdict = {'fontsize': 3}
ax.set_xticks(range(len(all_tokens)))
ax.set_yticks(range(len(all_tokens)))
ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
ax.set_yticklabels(all_tokens, fontdict=fontdict)
ax.set_xlabel('{} {}'.format('label_name', idx 1))
fig.colorbar(im, fraction=0.046, pad=0.04)
plt.tight_layout()
name_f = str(uuid.uuid4())
plt.savefig(f'{name_f}.pdf',
bbox_inches='tight',
dpi=350)
輸入資料
all_tokens = ['[CLS]',
'what',
'type',
'of',
'heart',
'issue',
'does',
'the',
'Person',
'have',
'[CLS]']
dummy_input = np.random.uniform(-1, 1, [16, len(all_tokens), len(all_tokens)])
visualize_t2t(all_tokens, dummy_input)
但結果如下所示:

如何在此處設定 rows 和 col 以使一行中有 8 個子圖并保留在另一行中?
uj5u.com熱心網友回復:
只需替換ax = fig.add_subplot(12, 12, idx 1)為ax = fig.add_subplot(2, 8, idx 1).
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/530162.html
標籤:Pythonpython-3.xmatplotlib阴谋
上一篇:C中的printf:Speicherzugriffsfehler(Speicherabzuggeschrieben)核心轉儲
