我正在嘗試使用wandb庫,我運行wandb.watch,但這似乎對我的代碼不起作用。它不應該是什么復雜的東西,所以我不明白為什么它不作業。
代碼:
"" ""
https://docs.wandb.ai/guides/track/advanced/distributed-training
輸入wandb
# 1. 開始一個新的運行
wandb.init(project='playground', entity='brando')
# 2. 保存模型輸入和超引數
config = wandb.config
config.learning_rate = 0.01
# 3. 記錄梯度和模型引數
wandb.watch(model)
for batch_idx, (data, target) in enumerate(train_loader):
...
如果 batch_idx % args.log_interval == 0。
# 4. 日志指標來顯示性能
wandb.log({"loss": loss})
注意。
- 只在領導行程中呼叫wandb.init和wandb.log
"""
from argparse import Namespace
from pathlib import Path
from typing import Union
輸入 torch
從 torch 中匯入 nn
從 torch.nn.functional 中匯入 mse_loss
從 torch.opt.import Optimizer 中匯入優化器
輸入uutils
from uutils.torch_uu.com import r2_score_from_torch
from uutils.torch_uu.distributed import is_lead_worker
從 uutils.torch_uu.models 中匯入 get_simple_model
from uutils.torch_uu.tensorboard import log_2_tb_supervisedlearning
輸入 wandb
def log_2_wandb_nice(it, loss, inputs, outputs, captions):
wandb.log({"loss": loss, "epoch": it,
"inputs": wandb.Image(inputs),
"logits": wandb.Histogram(output),
"captions": wandb.HTML(captions)})
def log_2_wandb(**metrics):
""" 記錄到wandb ""
new_metrics:dict = {}。
for key, value in metrics.items():
key = str(key).strip('_')
new_metrics[key] = value
wandb.log(new_metrics)
def log_train_val_stats(args: 命名空間。
it: int,
train_loss: float,
train_acc: float。
有效。
log_freq: int = 10,
ckpt_freq: int = 50,
force_log: bool = False, # 例如,在最后的it/epoch處
save_val_ckpt: bool = False,
log_to_tb: bool = False,
log_to_wandb: bool = False
):
"""
記錄訓練和值的統計資訊。
注意:與save ckpt不同,這個確實需要它明確地被傳遞(所以它可以在統計收集器中保存)。
"""
from uutils.torch_uu.tensorboard import log_2_tb
from matplotlib import pyplot as plt
# - 它是epoch還是iteration
it_or_epoch: str = 'epoch_num' if args.training_mode == 'epochs' else 'it'.
# 如果它的
total_its: int = args.num_empochs if args.training_mode == 'epochs' else args.num_its
print(f'-- {it == total_its - 1}' )
print(f'-- {it}')
print(f'-- {total_its}')
如果(it % log_freq == 0 或 is_lead_worker(args.rank) 或 it == total_its - 1 或 force_log) 和 is_lead_worker(args.rank)。
print('inside log')
# - 獲得評估統計資訊
val_loss, val_acc = valid(args, args.mdl, save_val_ckpt=save_val_ckpt)
# - 列印
args.logger.log('
')
args.logger.log(f"{it_or_epoch}={it}: {train_loss=}, {train_acc=}" )
args.logger.log(f"{it_or_epoch}={it}: {val_loss=}, {val_acc=}")
# - 記錄到統計資訊收集器中
args.logger.record_train_stats_stats_collector(it, train_loss, train_acc)
args.logger.record_val_stats_stats_collector(it, val_loss, val_acc)
args.logger.save_experiment_stats_to_json_file()
fig = args.logger.save_current_plots_and_stats()
# - 記錄到wandb
如果log_to_wandb:
# 如果它==0。
# -- todo 為什么不作業了?
# wandb.watch(args.mdl)
# print('watching model')
# log_2_wandb(train_loss=train_loss, train_acc=train_acc)
print('inside wandb log')
wandb.log(data={'train loss': train_loss, 'train acc': train_acc, 'val loss': val_loss, 'val acc': val_acc}, step=it)
wandb.log(data={'it': it}, step=it)
如果它== total_its - 1:
print(f'logging fig at {it=}')
wandb.log(data={'fig': fig}, step=it)
plt.close('all')
# - 記錄到tensorboard
如果log_to_tb:
log_2_tb_supervisedlearning(args.tb, args, it, train_loss, train_acc, 'train')
log_2_tb_supervisedlearning(args.tb, args, it, train_loss, train_acc, 'val')
# log_2_tb(args, it, val_loss, val_acc, 'train')
# log_2_tb(args, it, val_loss, val_acc, 'val')
# - 記錄 ckpt
如果(it % ckpt_freq == 0 或 it == total_its - 1 或 force_log) 和 is_lead_worker(args.rank)。
save_ckpt(args, args.mdl, args.optimizer)
def save_ckpt(args: 命名空間, mdl: nn.Module, optimizer: torch.optimizer,
dirname: Union[None, Path] = None, ckpt_name: str = 'ckpt.pt')。)
"""
為任何作業者保存檢查點。
預期用途是保存得到估值損失改善的作業者。
"""
輸入 dill
dirname = args.log_root if (dirname is None) else dirname
# - pickle ckpt
assert uutils.xor(args.training_mode == 'epochs', args.training_mode == 'iterations')
pickable_args = uutils.make_args_pickable(args)
torch.save({'state_dict': mdl.state_dict(),
'epoch_num': args.epoch_num,
'it': args.it,
'optimizer': optimizer.state_dict(),
'args': pickable_args,
'mdl': mdl}。
pickle_module=dill,
f=dirname / ckpt_name) # f'mdl_{epoch_num:03}.pt' 。
def get_args() -> 名稱空間。
args = uutils.parse_args_synth_agent()
# 我們可以在這里放置模型...
args = uutils.setup_args_for_experiment(args)
回傳 args
def valid_for_test(args: Namespace, mdl: nn.Module, save_val_ckpt: bool = False) 。
輸入Torch
for t in range(1):
x = torch.randn(args.batch_size, 5)
y = (x ** 2 x 1).sum(dim=1)
y_pred = mdl(x).squeeze(dim=1)
val_loss, val_acc = mse_loss(y_pred, y), r2_score_from_torch(y_true=y, y_pred=y_pred)
if val_loss.item() < args.best_val_loss and save_val_ckpt:
args.best_val_loss = val_loss.item()
save_ckpt(args, args.mdl, args.optimizer, ckpt_name='ckpt_best_val.pt')
回傳 val_loss, val_acc
def train_for_test(args: Namespace, mdl: nn.Module, optimizer: Optimizer, scheduler=None) 。
# wandb.watch(args.mdl)
for it in range(args.num_its):
x = torch.randn(args.batch_size, 5)
y = (x ** 2 x 1).sum(dim=1)
y_pred = mdl(x).squeeze(dim=1)
train_loss, train_acc = mse_loss(y_pred, y), r2_score_from_torch(y_true=y, y_pred=y_pred)
優化器.zero_grad()
train_loss.backward() # 每個程序都在后向傳遞中同步它的梯度
optimizer.step() # 正確的更新已經完成,因為所有行程都有正確的同步梯度
scheduler.step()
log_train_val_stats(args, it, train_loss, train_acc, valid_for_test,
log_freq=2, ckpt_freq=10,
save_val_ckpt=True, log_to_tb=True, log_to_wandb=True)
回傳 train_loss, train_acc
def debug_test():
args: 名稱空間 = get_args()
args.num_its = 12
# - 獲得mdl、opt、調度器等
args.mdl = get_simple_model(in_features=5, hidden_features=20, out_features=1, num_layer=2)
wandb.watch(args.mdl)
args.optimizer = torch.optim.Adam(args.mdl.parameters(), lr=1e-1)
args.scheduler = torch.optimizer.lr_scheduler.ExponentialLR(args.optimizer, gamma=0.999, verbose=False)
# - 訓練
train_loss, train_acc = train_for_test(args, args.mdl, args.optimizer, args.scheduler)
print(f'{train_loss=}, {train_loss=}' )
# - 評估
val_loss, val_acc = valid_for_test(args, args.mdl)
print(f'{val_loss=}, {val_acc=}')
# - 確保wandb正常關閉
如果args.log_to_wandb:
wandb.finish()
如果 __name__ == '__main__':
輸入 os
# print(os.environ['WANDB_API_KEY'] )
輸入時間
start = time.time()
debug_test()
duration_secs = time.time() - start
print(f"
成功,時間已過:小時:{duration_secs / (60 ** 2)},分鐘={duration_secs / 60},秒={duration_secs}")
print('Done!a')
示例運行。https://wandb.ai/brando/playground/runs/wpupxvg1
uj5u.com熱心網友回復:
我不知道為什么,但這行代碼似乎是有效的:
我不知道為什么,但這行代碼似乎是有效的。
wandb.watch(args.mdl, mse_loss, log="all", log_freq=10)
也許它真的需要損失和日志全部,盡管它不在介紹/快速入門指南中:
import wandb
# 1. 啟動一個新的運行
wandb.init(project='playground', entity='brando')
# 2. 保存模型輸入和超引數
config = wandb.config
config.learning_rate = 0.01
# 3. 記錄梯度和模型引數
wandb.watch(model)
for batch_idx, (data, target) in enumerate(train_loader):
...
如果 batch_idx % args.log_interval == 0。
# 4. 日志指標來顯示性能
wandb.log({"loss": loss})
uj5u.com熱心網友回復:
:在這里你可能會遇到兩件事 -- 無法確認,因為你的代碼依賴于 ultimate-utils 包。
wandb.watch只有在你呼叫wandb.log 之后一個觸及被監視的Module(docs)的反向傳遞后,才會開始作業。log_freq引數控制。如果日志呼叫的數量少于log_freq的值,那么就不會有資訊被記錄下來。下面是一個簡短的colab,再現了這一行為。另外,如果你想要引數和梯度,你需要將log kwarg設定為"all"。默認情況下,我們只記錄梯度。
轉載請註明出處,本文鏈接:https://www.uj5u.com/qiye/316891.html
標籤:
上一篇:y應該是一個1d陣列,得到了一個shape()的陣列來代替
下一篇:撤銷前綴和
