嗨,我正在使用現有的C 代碼,我通常使用VB.NET,我所看到的許多東西對我來說是混亂和矛盾的。
現有的代碼從一個檔案中加載神經網路的權重,該檔案的編碼如下:
。
2
模型.0.conv.conv. 重量 5 3e17c000 3e9be000 3e844000 bc2f8000 3d676000
模型.0.conv.bn. 重量7 4006a000 3f664000 3fc980003fa6a000 3ff2e000 3f5dc000 3fc94000
第一行給出了后續行的數量。每一行都有一個描述,一個數字代表后面有多少個值,然后是十六進制的權重值。在真實的檔案中,有數百行,每行可能有數十萬個權重。權重檔案的大小為400MB。這些值被轉換為浮點數,以便在NN中使用。
對該檔案進行解碼需要 3 分鐘以上。我希望通過消除從十六進制編碼到二進制的轉換來提高性能,而只是將數值原生地存盤為浮點數。問題是我不明白代碼在做什么,也不明白我應該如何以二進制存盤這些值。對行進行解碼的相關部分在這里:
while (count--)
{
權重 wt{ DataType::kFLOAT, nullptr, 0 }。
uint32_t size。
//讀取blob的名稱和型別。
std::string name;
輸入 >> name >> std::dec >> size。
wt.type = DataType::kFLOAT。
//加載blob。
uint32_t* val = reinterpret_cast< uint32_t*>(malloc(sizeof(val) * size))
for (uint32_t x = 0, y = size; x < y; x)
{
輸入 >> std::hex >> val[x]。
}
wt.values = val;
wt.count = size;
weightMap[name] = wt;
}
這里描述了Weights類。 DataType::kFLOAT是一個32位浮點。
我希望在內回圈中添加一行input >> std::hex >> val[x];,這樣我就可以將浮點數值從十六進制轉換為二進制檔案,但我不明白到底發生了什么。它看起來像是被分配了記憶體來保存這些值,但是sizeof(val)是8位元組,而uint32_t是4位元組。此外,看起來值被存盤在wt.values中,但val包含整數而不是浮點數。我真的不明白這里的意圖是什么。
我能否得到一些建議,如何存盤和加載二進制值以消除十六進制的轉換。如果有任何建議,我將不勝感激。非常感謝。
uj5u.com熱心網友回復:
這里有一個示例程式,它可以將顯示的文本格式轉換成二進制格式,然后再轉換回來。 我從問題中提取了資料,并成功地轉換為二進制并回傳。我的感覺是,在用實際的應用程式消耗資料之前,最好先用一個單獨的程式來處理資料,這樣應用程式的讀取代碼才是單一目的。
最后還有一個關于如何將二進制檔案讀入Weights類的例子。我沒有使用TensorRT,所以我從檔案中復制了所使用的兩個類,所以這個例子可以編譯。請確保你不要將這些添加到你的實際代碼中。
如果你有任何問題,請告訴我。希望這對你有幫助,并使你的加載速度更快。
如果你有任何問題,請告訴我。
#include <fstream>
#include <iostream>
#include <unordered_map>
#include <vector>
void usage()
{
std::cerr << "用法:轉換<操作> <輸入檔案> <輸出檔案>。
"。
std::cerr << " convert b in.txt out.bin - 將文本轉換成二進制檔案
"。
std::cerr << " convert t in.bin out.txt - 轉換二進制為文本
"。
}
bool text_to_binary(const char *infilename。const char *outfilename)。
{
std::ifstream in(infilename)。
if (! in)
{
std::cerr << "錯誤。無法打開輸入檔案'" << infilename << "'
"。
return false。
}
std::ofstream out(outfilename, std::ios::binary)。
if (! out)
{
std::cerr << "錯誤。無法打開輸出檔案'" << outfilename << "'
"。
return false。
}
uint32_t line_count。
if (! (in >> line_count))
{
return false;
}
if (!out. write(reinterpret_cast< const char *>(&line_count), sizeof(line_count)))
{
return false;
}
for (uint32_t l = 0; l < line_count; l)
{
std::string name;
uint32_t num_values;
if (! (in >> name >> std::dec >> num_values)
{
return false。
}
std::vector<uint32_t> values(num_values)。
for (uint32_t i = 0; i < num_values; i)
{
if (! (in >> std::hex >> values[i])
{
return false。
}
}
uint32_t name_size = static_cast<uint32_t> (name.size())。
bool result = out. write(reinterpret_cast< const char *>(&name_size), sizeof(name_size))&&。
out.write(name.data(), name.size() &&
出來。 write(reinterpret_cast< const char *>(&num_values), sizeof(num_values))&&
out.write(reinterpret_cast<const char *>(values. data()), values.size() * sizeof(values[0] )。
if (! result)
{
return false;
}
}
return true;
}
bool binary_to_text(const char *infilename。const char *outfilename)。
{
std::ifstream in(infilename, std::ios::binary)。
if (! in)
{
std::cerr << "錯誤。無法打開輸入檔案'" << infilename << "'
"。
return false。
}
std::ofstream out(outfilename)。
if (! out)
{
std::cerr << "錯誤。無法打開輸出檔案'" << outfilename << "'
"。
return false。
}
uint32_t line_count。
if (!in. read(reinterpret_cast<char *>(&line_count), sizeof(line_count))
{
return false。
}
if (! (out << line_count << "
"))
{
return false;
}
for (uint32_t l = 0; l < line_count; l)
{
uint32_t name_size。
if (!in. read(reinterpret_cast<char *>(&name_size), sizeof(name_size))
{
return false。
}
std::string name(name_size, 0)。
if (!in.read(name.data(), name_size)
{
return false。
}
uint32_t num_values;
if (!in. read(reinterpret_cast<char *>(&num_values), sizeof(num_values)))
{
return false。
}
std::vector<float> values(num_values)>。
if (!in.read(reinterpret_cast<char *> (values. data()), num_values * sizeof(values[0]) ) )
{
return false。
}
if (! (out << name << " " << std::dec << num_values))
{
return false。
}
for (float &f : values)
{
uint32_t i;
memcpy(&i, &f, sizeof(i))。
if (! (out << " << std::hex << i))
{
return false。
}
}
if (! (out << "
"))
{
return false;
}
}
return true;
}
int main(int argc, const char *argv[])
{
if (argc != 4)
{
usage()。
return EXIT_FAILURE。
}
char op = argv[1][0] 。
bool result = false;
switch (op)
{
case 'b'/span>:
case 'B':
result = text_to_binary(argv[2], argv[3] 。)
break。
case 't':
case 'T':
result = binary_to_text(argv[2], argv[3] 。)
break。
default:
usage()。
break;
}
return result ? exit_success : exit_failure。
}
//可能實作原問題中的代碼片段來讀取權重。
//START 復制自TensorRT檔案 - 不要包含在你的代碼中。
enum class DataType : int32_t
{
kFLOAT = 0,
kHALF = 1,
kINT8 = 2,
kINT32 = 3,
kBOOL = 4.
};
classWeights
{
public:
資料型別型別。
const void *values;
int64_t count;
};
//END 復制自TensorRT檔案 - 不要包含在你的代碼中。
bool read_weights(const char *infilename)
{
std::unordered_map<std::string, Weights> weightMap;
std::ifstream in(infilename, std::ios::binary)。
if (! in)
{
std::cerr << "錯誤。無法打開輸入檔案'" << infilename << "'
"。
return false。
}
uint32_t line_count。
if (!in. read(reinterpret_cast<char *>(&line_count), sizeof(line_count))
{
return false。
}
for (uint32_t l = 0; l < line_count; l)
{
uint32_t name_size。
if (!in. read(reinterpret_cast<char *>(&name_size), sizeof(name_size))
{
return false。
}
std::string name(name_size, 0)。
if (!in.read(name.data(), name_size)
{
return false。
}
uint32_t num_values;
if (!in. read(reinterpret_cast<char *>(&num_values), sizeof(num_values)))
{
return false。
}
//通常我會使用float* values = new float[num_values]; 這里會
//需要delete [] ptr; 來釋放記憶體。
//我使用malloc來匹配原例,因為我不知道誰是
//負責以后的清理作業,而TensorRT可能會使用free(ptr)。
//只要new/delete ro malloc/free是匹配的,就沒有真正的區別。
float *values = reinterpret_cast<float *> (malloc(num_values * sizeof(*values))。
if (!in. read(reinterpret_cast<char *>(values), num_values * sizeof(*values)))
{
return false。
}
weightMap[name] = Weights { DataType::kFLOAT, values, num_values };
}
return true;
}
轉載請註明出處,本文鏈接:https://www.uj5u.com/caozuo/329428.html
標籤:
上一篇:一次性檢查大量變數的真實性
