Torch 和 C++互相呼叫 pybind11
- torch 和 C++互相呼叫
- 代碼例子
torch 和 C++互相呼叫
只需要安裝 torch 即可,已在 Linux 環境下實驗通過。
torch.utils.cpp_extension 通過 pybind11 實作 C++ 和 Python 互相通信。
在 ninja 構建框架下即時編譯(JIT)C++ 代碼,只有第一次載入時需要編譯。
代碼例子
展示了 Python 如何和 C++ 後端相互呼叫,以及如何傳遞串列和 torch.Tensor
CPP.cpp
#include <torch/extension.h>
#include <iostream>
#include <string>
#include <iterator>
// Class definition
// https://pybind11.readthedocs.io/en/latest/advanced/classes.html
/// A pet identified only by its name.
struct Pet {
    std::string name; ///< Current name of the pet.

    /// Creates a pet called `initialName`.
    Pet(const std::string &initialName) : name(initialName) {}

    /// Returns a reference to the stored name.
    const std::string &getName() const { return name; }

    /// Replaces the stored name with `name_`.
    void setName(const std::string &name_) { name = name_; }
};
// Alias for the vector of pets that is shared between C++ and Python.
using PetList = std::vector<Pet>;
// Define functions that call into Python, passing arguments by reference.
// The vector type is declared opaque for that purpose; see the pybind11
// documentation on opaque types:
// https://pybind11.readthedocs.io/en/latest/advanced/cast/stl.html?highlight=STL#making-opaque-types
PYBIND11_MAKE_OPAQUE(std::vector<Pet>)
/// Creates a PetList holding one pet ("CatCpp"), passes it to the Python
/// function `addPet` from module `PY`, then prints the name of every pet in
/// the list to std::cout, one per line, prefixed with "from CPP ".
///
/// The list is passed by address so that Python works on this very vector
/// (the type is declared opaque and bound as `PetList` elsewhere in this
/// file); any pets appended on the Python side are therefore printed here.
void addAndprintPet()
{
    PetList petlist;
    petlist.emplace_back("CatCpp");

    py::object addPet = py::module::import("PY").attr("addPet");
    addPet(&petlist);

    // Iterate by const reference: iterating by value copied each Pet
    // (including its std::string) on every pass.
    for (const auto &pet : petlist)
    {
        // std::endl is kept on purpose: it flushes after every line, which
        // presumably keeps this output ordered relative to Python's print().
        std::cout << "from CPP " << pet.getName() << std::endl;
    }
}
//通過回傳值傳遞串列注意一切python回傳皆為object,需要強轉
//https://pybind11.readthedocs.io/en/latest/advanced/pycpp/object.html#instantiating-compound-python-types-from-c
void printList()
{
py::list a;
a.append(123); //python 有的基本都能用,包括模塊
// py::module sys=py::module::import("sys");
// py::print(sys.attr("path"));
py::object addNumer=py::module::import("PY").attr("addNumer");
py::list b = addNumer(a);
for (auto number:b)
{
std::cout<<"from CPP "<<number.cast<int>()<<std::endl;
}
}
/// Returns the result of `a + b` as computed by torch::Tensor's operator+.
///
/// @param a Left-hand operand.
/// @param b Right-hand operand.
/// @return A tensor holding the sum.
torch::Tensor TensorAdd(const torch::Tensor &a, const torch::Tensor &b)
{
    torch::Tensor sum = a + b;
    return sum;
}
// Entry point exposed to Python as "mainFun": runs the two demonstrations in
// order, first the PetList round trip (addAndprintPet), then the Python-list
// round trip (printList).
void mainFun()
{
addAndprintPet();
printList();
}
// Bindings
// Exposes Pet, PetList, mainFun and TensorAdd to Python under the extension
// module name supplied by the torch build (TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Pet is registered before PetList so that it is already known to pybind11
    // when the signature of PetList.push_back (which takes a Pet) is generated.
    py::class_<Pet>(m, "Pet")
        .def(py::init<const std::string &>())
        .def("setName", &Pet::setName)
        .def("getName", &Pet::getName)
        .def("__repr__", [](const Pet &u) { return u.getName(); }); // used when Python prints a Pet

    py::class_<PetList>(m, "PetList")
        .def(py::init<>())
        // std::vector::pop_back on an empty vector is undefined behaviour, so
        // an empty list is reported to Python as IndexError instead.
        .def("pop_back", [](PetList &v) {
            if (v.empty()) {
                throw py::index_error("pop_back from empty PetList");
            }
            v.pop_back();
        })
        // A lambda selects the copying overload without a C-style cast on a
        // standard-library member-function pointer.
        .def("push_back", [](PetList &v, const Pet &pet) { v.push_back(pet); })
        .def("__len__", [](const PetList &v) { return v.size(); })
        // keep_alive<0, 1> ties the list's lifetime to the returned iterator.
        .def("__iter__", [](PetList &v) {
            return py::make_iterator(v.begin(), v.end());
        }, py::keep_alive<0, 1>());

    m.def("mainFun", &mainFun, "mainFun");
    m.def("TensorAdd", &TensorAdd, "TensorAdd");
}
PY.py
import os
import torch
from torch.utils.cpp_extension import load
# Directory containing this script; CPP.cpp is looked up next to it.
dir = os.path.dirname(os.path.realpath(__file__))
# Build and import the C++ extension through torch's cpp_extension loader.
CPP = load(name="CPP", sources=[os.path.join(dir, "CPP.cpp")], verbose=False)
def addPet(petlist):
    """Print every pet in ``petlist``, then append a new pet named 'CatPy'.

    ``petlist`` is modified in place through its ``push_back`` method; nothing
    is returned.
    """
    # petlist.pop_back()
    for pet in petlist:
        print('from PY', pet)  # petlist is of type PetList
    new_pet = CPP.Pet('CatPy')
    petlist.push_back(new_pet)
def addNumer(numlist):
    """Print ``numlist`` and return a new list with 1234 appended.

    The input list itself is left unchanged.
    """
    print('from PY', numlist)
    extended = numlist + [1234]
    return extended
if __name__ == '__main__':
    # Call the C++ entry point, which runs addAndprintPet and printList
    CPP.mainFun()

    # Use the class defined in C++
    pet = CPP.Pet("Molly")
    print(pet)
    print(pet.getName())
    pet.setName("Charly")
    print(pet.getName())

    # Tensor addition
    zeros = torch.zeros((3, 3))
    ones = torch.ones(3, 3)
    print(CPP.TensorAdd(zeros, ones))
運行結果:(verbose=True)
![Emitting ninja build file /home/fuchy/.cache/torch_extensions/CPP/build.ninja...Building extension module CPP...Using envvar MAX_JOBS (64) as the number of workers...ninja: no work to do.Loading extension module CPP...Using /home/fuchy/.cache/torch_extensions as PyTorch extensions root...No modifications detected for re-loaded extension module CPP, skipping build step...Loading extension module CPP...from PY CatCppfrom CPP CatCppfrom CPP CatPyfrom PY [123]from CPP 123from CPP 1234MollyMollyCharlytensor([[1., 1., 1.],[1., 1., 1.],[1., 1., 1.]])](https://img.uj5u.com/2021/10/24/277235241205301.png)
轉載請註明出處,本文鏈接:https://www.uj5u.com/houduan/333963.html
標籤:python
