TSN
1.如何提幀
1.1資料集準備
下載網址:http://crcv.ucf.edu/data/UCF101/UCF101.rar
下載成功后的UCF檔案夾如下所示:
該檔案夾下是各種動作的視頻檔案,共有101種類別

下圖是UCF101在進行訓練和測驗時,分割的依據檔案

1.2原始碼準備
在實驗程序中,我們需要使用tsn-pytorch和mmaction的一些代碼檔案,所以我們提前從Git上獲得存盤在本地,
下載mmaction:
git clone --recursive https://github.com/open-mmlab/mmaction.git
下載tsn-pytorch:
git clone --recursive https://github.com/yjxiong/tsn-pytorch
1.3提幀
在我們下載好的UCF101資料集中,視頻大多是長時間的,很難對其進行動作識別,所以需要進行提幀操作,
首先在mmaction的data/ucf101中創建rawframes、videos、annotations檔案夾,
rawframes:視頻提幀后存放的檔案目錄
videos:拷貝ucf101資料集中的101個檔案目錄,放置其中
annotations:ucf101之后進行分割訓練集、測驗集的依據檔案

然后在mmaction/data_tools/build_rawframes.py的同級目錄下進行視頻提幀的代碼檔案,輸入命令如下所示:
python build_rawframes.py ../data/ucf101/videos ../data/ucf101/rawframes/ --level 2 --ext avi

生成的檔案目錄形式如下所示:


在這里插入圖片描述
運行完成后,將每一個視頻的每一幀提取出來,放在特定名稱的檔案夾中,
1.4生成file_list
在tsn-pytorch的readme檔案中可以看到,訓練程序中需要<ucf101_rgb_train_list>和<ucf101_rgb_val_list>,所以生成這兩個list檔案是必需的,使用mmaction/data_tools/buid_file_list.py即可對ucf101生成的幀進行訓練集和測驗集的劃分,輸入命令如下所示:
python data_tools/build_file_list.py ucf101 data/ucf101/rawframes/ --level 2 --format rawframes --shuffle
也可在mmaction/data_tools/ucf101/中輸入
bash generate_filelist.sh

生成的filelist在data/ucf101目錄下,形式如下:

file_list的內容如下所示:

file_list中有三列,第一列代表檔案的地址,第二列代表視頻的幀數,第三列代表視頻的類別,這里僅僅使用ucf101的3個檔案夾,所以類別只有0 1 2,
2.如何feed幀出特征
代碼修改部分參考 https://blog.csdn.net/qq_39862223/article/details/108461526
2.1IPO
下圖展示了,TSN如何將ucf101資料集提出的幀進行分類的程序,標明了每一個階段的tensor大小


3.如何save,以便load
定義的保存模型以及引數資訊的方法,該方法會在進行模型訓練的時候得到呼叫,
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
filename = '_'.join((args.snapshot_pref, args.modality.lower(), filename)) # 用于保存模型以及引數資訊的路徑以及檔案名
torch.save(state, filename) # 將模型以上述名稱保存在該路徑下
if is_best: # 如果準確率得到提高就進行模型的被備份
best_name = '_'.join((args.snapshot_pref, args.modality.lower(), 'model_best.pth.tar')) # 備份路徑以及檔案名稱
shutil.copyfile(filename, best_name) # 進行檔案復制
對該方法的呼叫,通過該方法保存模型,準確率,模型引數并判斷是否進行模型復制
best_prec1 = max(prec1, best_prec1)
save_checkpoint({
'epoch': epoch + 1,
'arch': args.arch,
'state_dict': model.state_dict(),
'best_prec1': best_prec1,
}, is_best)
加載保存的引數
if args.resume: # args.resume是保存模型的路徑
if os.path.isfile(args.resume): # 判斷該絕對路徑下是否是檔案,也就是保存模型方法中的絕對路徑
print(("=> loading checkpoint '{}'".format(args.resume)))
checkpoint = torch.load(args.resume) # 進行加載checkpoint 字典的形式,里面包括epoch,arch,state_dict,best_prec1
args.start_epoch = checkpoint['epoch']
best_prec1 = checkpoint['best_prec1']
model.load_state_dict(checkpoint['state_dict'])
print(("=> loaded checkpoint '{}' (epoch {})"
.format(args.evaluate, checkpoint['epoch'])))
else:
print(("=> no checkpoint found at '{}'".format(args.resume)))
轉載請註明出處,本文鏈接:https://www.uj5u.com/qianduan/157792.html
標籤:其他
上一篇:CTF隱寫術之總結 讓你少走彎路
