TorchDynamo初探:Python ByteCode的動態(tài)修改-天天觀察

          來源:CSDN博客 | 2023-01-04 11:56:00 |

          作者|strint

          1背景

          深度學(xué)習(xí)框架編譯優(yōu)化時(shí),需要先根據(jù)計(jì)算邏輯形成一個(gè)邏輯計(jì)算圖,然后再改寫計(jì)算圖,最后執(zhí)行改寫后的計(jì)算圖。其中生成邏輯計(jì)算圖方式有兩種。


          【資料圖】

          一種計(jì)算圖生成是基于 trace tensor 的,跟蹤 tensor 的執(zhí)行路徑。tensor 執(zhí)行時(shí),基于函數(shù)重載,可以落到支持 tensor 計(jì)算的框架自定義函數(shù),該函數(shù)一般是 c++ 層的。c++ 層的自定義函數(shù)中,功能是用于生成一個(gè) Operation 的符號表達(dá)。比如一個(gè)對于加法運(yùn)算,trace 就是記錄一個(gè)符號化的加法算子。如此一連串的運(yùn)算就被轉(zhuǎn)換了符號化的計(jì)算圖。

          另外一種計(jì)算圖生成是基于 AST(抽象語法樹) 解析的。在代碼執(zhí)行前,直接根據(jù) Python 文本代碼得到 Python AST,然后根據(jù) AST 來翻譯成計(jì)算圖(也叫做中間代碼 IR)。

          Python(特指 CPython)解釋器執(zhí)行,第一階段會先把 Python 源碼解析成 AST,第二階段根據(jù) AST 生成和優(yōu)化 ByteCode(字節(jié)碼),第三階段在虛擬機(jī)中執(zhí)行 ByteCode。

          基于 AST 解析的計(jì)算圖生成,發(fā)生在這里的第一階段;基于 trace tensor 的計(jì)算圖生成,發(fā)生在第三階段之后。

          TorchDynamo 特別的地方在于其工作在第二階段,動態(tài)修改 Python ByteCode,這樣第三階段執(zhí)行的已經(jīng)是修改后的 ByteCode了。

          2

          TorchDynamo 概述

          TorchDynamo 是 PyTorch 新實(shí)驗(yàn)的 JIT 編譯接口,支持使用 Python 在運(yùn)行時(shí)修改動態(tài)執(zhí)行邏輯,修改的時(shí)機(jī)是 CPython 的 ByteCode 執(zhí)行前。這個(gè)思想類似?DynamoRIO(https://dynamorio.org)?項(xiàng)目,DynamoRIO 可以動態(tài)的修改 x86 機(jī)器碼。

          CPython 的每次函數(shù)調(diào)用會生成一個(gè) Frame(或者叫 Stack),F(xiàn)rame 中帶有的代碼部分就是 ByteCode。CPython 運(yùn)行時(shí)支持基于現(xiàn)有的 Frame 去設(shè)置一個(gè)自定義的 Frame,然后后面執(zhí)行的就是自定義的 Frame。

          TorchDynamo 的工作原理就是在運(yùn)行時(shí)設(shè)置一個(gè)自定義的 Frame,該 Frame 中的 ByteCode 支持 CallBack 到 Python 層去修改。其提供的典型的修改接口是 FX Graph,也就是說 TorchDynamo 會分析 ByteCode,生成對應(yīng)的 FX Graph,然后提供 FX Graph 的接口供用戶自定義計(jì)算圖。這種做法有如下優(yōu)點(diǎn):

          可以支持所有的 Python 語法,因?yàn)槿绻谧远x Frame 過程中的任何一點(diǎn)發(fā)現(xiàn)不支持,都可以選擇不修改 Frame 而回退到原 Frame;

          開銷少,劫持發(fā)生在 Python 執(zhí)行比較早的階段(ByteCode 生成和優(yōu)化階段),而非 Python ByteCode 執(zhí)行后的階段,有時(shí)可以減少 Python ByteCode 的執(zhí)行開銷(猜測如果很多次 ByteCode 層面的函數(shù)調(diào)用被融合層成一次函數(shù)調(diào)用,的確可以縮減開銷);

          可以做到不增加編譯帶來的延遲(之前的基于 tensor trace 或者 ast 解析的做法,一般都有先編譯執(zhí)行所以編譯開銷無法掩蓋,但是改寫 ByteCode 這個(gè)做法,猜測是可以在識別出熱點(diǎn)代碼后,單獨(dú)開一個(gè)線程去做編譯,而不影響主線程工作。Python ByteCode 改寫的 API 中有這種延遲編譯的樣例,peps.python.org/pep-052?)。

          之前計(jì)算圖生成機(jī)制(基于 trace tensor、基于 AST 解析的)中的幾個(gè)問題,得到了緩解:

          存在無法靜態(tài)化的操作,之前一般需要顯式的移除靜態(tài)化作用域,現(xiàn)在總是允許不做編譯,直接執(zhí)行原 Python 代碼,這樣使得靜態(tài)化標(biāo)注變得簡單;

          打開靜態(tài)圖編譯優(yōu)化,之前編譯時(shí)一般無法掩蓋,現(xiàn)在有辦法部分掩蓋;

          動態(tài) shape 問題,因?yàn)橛辛司幾g時(shí)和運(yùn)行時(shí)的掩蓋,也可以得到緩解。

          這種盡量優(yōu)化、動態(tài)優(yōu)化的設(shè)計(jì),最大程度了照顧了代碼開發(fā)的體驗(yàn),讓編譯優(yōu)化上手變得更簡單了。這是 TorchDynamo 帶來的最主要的好處。這種做法非常符合 PyTorch 的 Python First、Eager First、User Experience First的偏好。但是這個(gè)設(shè)計(jì)對于尋求最好的性能、最方便的靜態(tài)化部署這兩個(gè)目標(biāo)并沒有改善。

          3

          CPython 的標(biāo)準(zhǔn)執(zhí)行流程

          上文提到了 CPython 的執(zhí)行從 Python 文本代碼,到 AST,到 ByteCode。這里用一個(gè)示例展開看一下。Python 的標(biāo)準(zhǔn)組件非常易用,可以在 Python 層用 ast 組件來查看 AST,可以用 compile 內(nèi)置函數(shù)來編譯 ByteCode,可以用 exec 系統(tǒng)函數(shù)來執(zhí)行 ByteCode。我們先在代碼開頭導(dǎo)入相關(guān)組件:

          import?astimport disimport sys

          然后我們構(gòu)造一個(gè) python 代碼,可以看到 src_code 就是普通的字符串。其中包含了一段普通的 python 內(nèi)置的乘法,一段深度學(xué)習(xí)的 tensor scalar 加法,最后一段是當(dāng)前Python Frame 中的 ByteCode 關(guān)聯(lián)對象的打印(用于一個(gè)檢驗(yàn),后面會提到)。

          print("=== source code ===")src_code = """# normal python operationx = 1x = x * 2# tensor operationy = dl_framework.ones((1, 2))z = x + yprint(z)# print python framef = sys._getframe()# print the code objectprint(f.f_code)"""print(src_code)

          然后使用 ast 組件來生成這段代碼的 AST。

          print("===?source?code?to?ast?===")# 把源代碼解析成 ASTast_obj = ast.parse(src_code)# 打印 ASTprint(ast.dump(ast_obj))

          可以得到 AST,這里展示的結(jié)果額外做了格式化,另外刪減掉了和計(jì)算邏輯無關(guān)的打印 frame 的部分,代碼和其 AST 的對應(yīng)關(guān)系參見注釋。AST解析是純文本層面的,`dl_framework` 還沒有被 import 進(jìn)來,AST解析仍然可以正常工作。AST 基本是一個(gè)多叉樹的結(jié)構(gòu),每個(gè)節(jié)點(diǎn)對應(yīng)一個(gè)表達(dá)式,節(jié)點(diǎn)子節(jié)點(diǎn)代表子表達(dá)式。以 `x = x + 2` 為例,Assign 是一個(gè)節(jié)點(diǎn),是賦值運(yùn)算,被賦值的是 `x`,賦值的值是一個(gè)二元乘法運(yùn)算。

          Module(body=[ # x = 1 Assign(targets=[Name(id="x", ctx=Store())], value=Constant(value=1, kind=None), type_comment=None), # x = x * 2 Assign(targets=[Name(id="x", ctx=Store())], value=BinOp(left=Name(id="x", ctx=Load()), op=Mult(), right=Constant(value=2, kind=None)), type_comment=None), # y = dl_framework.ones((1, 2)) Assign(targets=[Name(id="y", ctx=Store())], # dl_framework.ones((1, 2)) value=Call(func=Attribute(value=Name(id="dl_framework", ctx=Load()), attr="ones", ctx=Load()), args=[Tuple(elts=[Constant(value=1, kind=None), Constant(value=2, kind=None)], ctx=Load())], keywords=[]), type_comment=None), # z = x + y Assign(targets=[Name(id="z", ctx=Store())], # x + y value=BinOp(left=Name(id="x", ctx=Load()), op=Add(), right=Name(id="y", ctx=Load())), type_comment=None), # print(z) Expr(value=Call(func=Name(id="print", ctx=Load()), args=[Name(id="z", ctx=Load())], keywords=[])), # 省略了打印 frame 的代碼],type_ignores=[])

          Python AST 生成后,可以利用系統(tǒng)函數(shù) `compile` 把它轉(zhuǎn)成 ByteCode 字節(jié)碼。解釋器執(zhí)行也存在編譯的環(huán)節(jié),只不過是編譯成字節(jié)碼。

          print("===?ast?to?bytecode?===")# 編譯成 ByteCodecode_obj = compile(ast_obj, filename="", mode="exec")print(code_obj)# 展示 ByteCode 的語法糖byte_obj = dis.Bytecode(code_obj)print(byte_obj.dis())

          `print(code_obj)`的結(jié)果是 ` at 0x7ff79bb5c660, file "", line 3>`,這里可以看到生成的 code object 對象的指針是 `0x7ff79bb5c660`,后面我們在執(zhí)行字節(jié)碼時(shí),會再次看到這個(gè)指針。

          `print(byte_obj.dis())` 的結(jié)果如下,每一行對應(yīng)一條字節(jié)碼,也即一條指令, 通過字面含義基本可以看出是在做什么:

          # x = 1 3 0 LOAD_CONST 0 (1) 2 STORE_NAME 0 (x) # x = x * 2 4 4 LOAD_NAME 0 (x) 6 LOAD_CONST 1 (2) 8 BINARY_MULTIPLY 10 STORE_NAME 0 (x) # y = dl_framework.ones((1, 2)) 7 12 LOAD_NAME 1 (dl_framework) 14 LOAD_METHOD 2 (ones) 16 LOAD_CONST 2 ((1, 2)) 18 CALL_METHOD 1 20 STORE_NAME 3 (y) # x = x + y 8 22 LOAD_NAME 0 (x) 24 LOAD_NAME 3 (y) 26 BINARY_ADD 28 STORE_NAME 4 (z) # print(z) 9 30 LOAD_NAME 5 (print) 32 LOAD_NAME 4 (z) 34 CALL_FUNCTION 1 36 POP_TOP # 省略了打印 frame 的代碼

          得到 ByteCode 之后,就可以傳遞給 Python VM 執(zhí)行了。在真正執(zhí)行前,先做了一下 ByteCode 中指令的打印,實(shí)際 Python VM 執(zhí)行時(shí),也基本是這樣遍歷每一行指令,然后執(zhí)行指令。可以想象,如果這些指令被修改,就可以讓 Python VM 執(zhí)行自定義的指令了。

          print("===?execute?bytecode?===")# print instructionfor instr in byte_obj: print(instr.opname, instr.opcode)# You can also do `import torch as dl_framework``import oneflow as dl_framework# execute bytecodeexec(code_obj)

          字節(jié)碼的執(zhí)行結(jié)果如下。只需要在真正執(zhí)行前,把 `dl_framework`導(dǎo)入就好,然后可以看到 tensor 計(jì)算的結(jié)果,是符合預(yù)期的。

          frame(或者叫 stack)是運(yùn)行時(shí)的對象,對應(yīng)一個(gè)函數(shù)調(diào)用的棧,在執(zhí)行時(shí)被創(chuàng)建。frame 中要執(zhí)行的指令就是之前創(chuàng)建的 ByteCode。

          在運(yùn)行時(shí)之前,像我們之前看到的,存在一個(gè)編譯時(shí)進(jìn)行 AST 和 ByteCode 的編譯,之前編譯時(shí)生成的 code object 對象的指針是 `0x7ff79bb5c660`。

          在運(yùn)行時(shí),可以獲取當(dāng)前的 frame,然后通過 `frame.f_code`拿到當(dāng)前 frame 里面包含的 ByteCode(即 code object),可以發(fā)現(xiàn)它的指針就是之前編譯時(shí)生成的那個(gè)。

          #?print(z)?的結(jié)果tensor([[3., 3.]], dtype=oneflow.float32)# 運(yùn)行時(shí)獲取當(dāng)前 frame ,然后打印 frame 中的 ByteCode 對象的結(jié)果# f = sys._getframe()# print(f.f_code) at 0x7f5cea7f1660, file "", line 3>

          到此,窺見了一下 Python 源碼到 AST, AST 到 ByteCode,ByteCode 到 Frame 執(zhí)行這個(gè)默認(rèn)的 Python 執(zhí)行流程。TorchDynamo 用下圖做了簡單的介紹:

          其中 foo 對應(yīng)一個(gè) Python 函數(shù),即上文介紹的 Python Source Code。PyCodeObject 是上文介紹的 code object (ByteCode)在 C 代碼層面對應(yīng)的類。PyFrameObject 是上文介紹的 Frame 在 C 代碼層面對應(yīng)的類,它包含了代碼段 PyCodeObject。_PyEval_EvalFrameDefault 對應(yīng)上文介紹的 exec,它執(zhí)行一個(gè) Frame,即運(yùn)行 Frame 帶有的 `PyCodeObject`。

          現(xiàn)在我們看一下 CPython 在 C 層面的執(zhí)行 Frame 的實(shí)現(xiàn),對應(yīng)?_PyEval_EvalFrameDefault(https://github.com/python/cpython/blob/d48ecebad5ac78a1783e09b0d32c211d9754edf4/Python/ceval.c#L757)。它的主邏輯就是取 ByteCode 指令和執(zhí)行指令(https://github.com/python/cpython/blob/d48ecebad5ac78a1783e09b0d32c211d9754edf4/Python/ceval.c#L1080):

          co?=?f->f_code;?//?從?PyFrameObject*?f?中取出?PyCodeObject*?,放到?co?中 names = co->co_names; consts = co->co_consts; fastlocals = f->f_localsplus; freevars = f->f_localsplus + co->co_nlocals; // 從 co 中取出第一條指令 first_instr = (_Py_CODEUNIT *) PyBytes_AS_STRING(co->co_code); next_instr = first_instr;#define NEXTOPARG() do { \ _Py_CODEUNIT word = *next_instr; \ opcode = _Py_OPCODE(word); \ oparg = _Py_OPARG(word); \ // 指向下一條指令 next_instr++; \ } while (0) // 循環(huán)執(zhí)行指令 for (;;) { // 從當(dāng)前的指令 next_instr 中獲取 opcode NEXTOPARG(); switch (opcode) { // 執(zhí)行 op code,參見下個(gè)部分 } }

          每個(gè)指令類型對應(yīng)一個(gè) opcode,它是一個(gè)數(shù)值,執(zhí)行 opcode(https://github.com/python/cpython/blob/d48ecebad5ac78a1783e09b0d32c211d9754edf4/Python/ceval.c#L1266),這里的 opcode 可以清晰的看到和之前我們打印的 ByteCode 的類型對應(yīng)關(guān)系:

          #define?TARGET(opcode)?\ case opcode: switch (opcode) { // TARGET 就是一個(gè) case // load TARGET(LOAD_FAST) { PyObject *value = GETLOCAL(oparg); if (value == NULL) { format_exc_check_arg(PyExc_UnboundLocalError, UNBOUNDLOCAL_ERROR_MSG, PyTuple_GetItem(co->co_varnames, oparg)); goto error; } Py_INCREF(value); PUSH(value); FAST_DISPATCH(); } // store TARGET(STORE_FAST) { PyObject *value = POP(); SETLOCAL(oparg, value); FAST_DISPATCH(); } // 二元加法 TARGET(BINARY_ADD) { PyObject *right = POP(); PyObject *left = TOP(); PyObject *sum; if (PyUnicode_CheckExact(left) && PyUnicode_CheckExact(right)) { sum = unicode_concatenate(left, right, f, next_instr); /* unicode_concatenate consumed the ref to left */ } else { sum = PyNumber_Add(left, right); Py_DECREF(left); } Py_DECREF(right); SET_TOP(sum); if (sum == NULL) goto error; DISPATCH(); } // 函數(shù)調(diào)用 TARGET(CALL_FUNCTION) { PyObject **sp, *res; PCALL(PCALL_ALL); sp = stack_pointer; res = call_function(&sp, oparg, NULL); stack_pointer = sp; PUSH(res); if (res == NULL) { goto error; } DISPATCH(); }????}

          以上總結(jié)了 Python的默認(rèn)執(zhí)行流程。

          4

          TorchDynamo 的工作流程

          TorchDynamo 在標(biāo)準(zhǔn)的 Python 執(zhí)行流程中做的主要改變就是支持修改 Frame 執(zhí)行前的 ByteCode。我們暫時(shí)不關(guān)注 AST 生成,看 Python 的執(zhí)行流程,是 Python Source Code -> ByteCode -> Evaluate. TorchDynamo 支持 Python Source Code -> ByteCode -> [ByteCode rewrite] -> Evaluate。

          ByteCode rewrite 的工作方式是把一段 ByteCode 轉(zhuǎn)成 FX Graph,然后調(diào)用用戶自定義的 FX Graph 改寫執(zhí)行邏輯,生成一個(gè)可以經(jīng)過編譯的執(zhí)行函數(shù)。然后把該段 ByteCode 替換成函數(shù)調(diào)用 ByteCode,而調(diào)用的函數(shù)就是經(jīng)過編譯的執(zhí)行函數(shù)。從而實(shí)現(xiàn)編譯優(yōu)化的功能。

          FX Graph 支持了在 Python 層做代碼改寫,提高了寫編譯 Pass 的便利性,這里不做深入,可以參考資料1(https://pytorch.org/docs/stable/fx.html)和2(https://zhuanlan.zhihu.com/p/416165157)。

          ByteCode rewrite 發(fā)生在 ByteCode 執(zhí)行前。同樣的 Source Code,每次執(zhí)行都會走到這個(gè)步驟,都可以選擇是否進(jìn)行 ByteCode rewrite,或者選擇進(jìn)行什么樣的 rewrite,還可以支持 rewrite 結(jié)果的緩存和復(fù)用。這體現(xiàn)了 Dynamo 的動態(tài)性。

          下面看一個(gè) TorchDynamo 下 fn() 函數(shù)編譯的的例子:

          #?一個(gè)普通的函數(shù)def fn(a, b): x = a + b x = x / 2.0 if x.sum() < 0: return x * -1.0 return x # torchdynamo 函數(shù)接口with torchdynamo.optimize(custom_compiler): fn(torch.randn(10), torch.randn(10))

          fn() 函數(shù)對應(yīng)的原始的 python ByteCode,和代碼對應(yīng)的關(guān)系參見其中的注釋:

          #?x?=?a?+?b 0 LOAD_FAST 0 (a) 2 LOAD_FAST 1 (b) 4 BINARY_ADD 6 STORE_FAST 2 (x) # x = x / 2.0 8 LOAD_FAST 2 (x) 10 LOAD_CONST 1 (2.0) 12 BINARY_TRUE_DIVIDE 14 STORE_FAST 2 (x) # if x.sum() < 0: 16 LOAD_FAST 2 (x) 18 LOAD_METHOD 0 (sum) 20 CALL_METHOD 0 22 LOAD_CONST 2 (0) 24 COMPARE_OP 0 (<) 26 POP_JUMP_IF_FALSE 36 # return x * -1.0 28 LOAD_FAST 2 (x) 30 LOAD_CONST 3 (-1.0) 32 BINARY_MULTIPLY 34 RETURN_VALUE # return x 36 LOAD_FAST 2 (x) 38 RETURN_VALUE

          經(jīng)過 TorchDynamo 動態(tài)改寫后的 ByteCode:

          #?x?=?a?+?b # x = x / 2.0 # x.sum() < 0 # 上面兩行被轉(zhuǎn)換成了 __compiled_fn_0 # __compiled_fn_0 會返回 x 和 x.sum() < 0 組成的 tuple 0 LOAD_GLOBAL 1 (__compiled_fn_0) 2 LOAD_FAST 0 (a) 4 LOAD_FAST 1 (b) 6 CALL_FUNCTION 2 8 UNPACK_SEQUENCE 2 10 STORE_FAST 2 (x) 12 POP_JUMP_IF_FALSE 22 # x * -1.0 被轉(zhuǎn)換成了 __compiled_fn_1 14 LOAD_GLOBAL 2 (__compiled_fn_1) 16 LOAD_FAST 2 (x) 18 CALL_FUNCTION 1 20 RETURN_VALUE # return x 22 LOAD_FAST 2 (x) 24 RETURN_VALUE

          可以看到新增了兩個(gè)函數(shù)調(diào)用, `__compiled_fn_0`?和 `__compiled_fn_1`?,這兩個(gè)函數(shù)對應(yīng)的代碼邏輯參見 bytecode 中的注釋。這兩個(gè)函數(shù)對應(yīng)的 fx graph 如下:

          __compiled_fn_0:opcode name target args kwargs------------- ------- --------------------------- ---------------- --------placeholder a_0 a_0 () {}placeholder b_1 b_1 () {}call_function add (a_0, b_1) {}call_function truediv (add, 2.0) {}call_method sum_1 sum (truediv,) {}call_function lt (sum_1, 0) {}output output output ((truediv, lt),) {}__compiled_fn_1:opcode name target args kwargs------------- ------ ----------------------- ----------- --------placeholder x_4 x_4 () {}call_function mul (x_4, -1.0) {}output output output (mul,) {}

          在 ByteCode rewrite 的最后,TorchDynamo 為這一段代碼的輸入創(chuàng)建兩個(gè) Guard:

          局部參數(shù) a 必須是一個(gè) Tensor

          局部參數(shù) b 必須是一個(gè) Tensor

          該 fn 函數(shù)被再次調(diào)用時(shí),如果符合這兩個(gè)條件,則可以命中緩存的 TrochDynamo 處理結(jié)果;否則下次 fn 執(zhí)行時(shí),會觸發(fā)新的 ByteCode 分析和變換。

          另外,對于和 tensor 無關(guān)的、比較特別的 python 代碼,其 ByteCode 會保持原狀。這樣就達(dá)到了不需要用戶標(biāo)注區(qū)域、自動尋找優(yōu)化機(jī)會的設(shè)計(jì)目標(biāo)。

          現(xiàn)在看下 TorchDynamo 執(zhí)行的流程總結(jié):

          可以看到它把原來的 PyFrameObject 替換成了 Patched PyFrameObject,這個(gè)是 CPython 支持的特性。這個(gè) Patched PyFrameObject 中最主要的改動就是 Frame 中的 ByteCode (即 PyCodeObject)被修改了,原來的 PyCodeObject 變成了 Transformed PyCodeObject。而這個(gè)被改寫的 PyCodeObject 如上文和上圖所示,主要是部分 ByteCode 被替換成了調(diào)用被編譯過函數(shù)。這個(gè)被編譯過的函數(shù),支持自定義編譯邏輯,當(dāng)前默認(rèn)的編譯接口是 FX Graph。

          這部分基本參考了Dynamo的官方介紹(https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)。

          5

          TorchDynamo 修改 Python ByteCode 的實(shí)現(xiàn)

          Python ByteCode 修改主要依賴?PEP 523(https://peps.python.org/pep-0523/)?提供的執(zhí)行自定義 Frame Evaluation API。默認(rèn)的 Eval Frame 邏輯入口函數(shù)是 _PyEval_EvalFrame,默認(rèn)情況,它會直接調(diào)用 _PyEval_EvalFrameDefault()?來處理沒被修改的 frame,但是如果發(fā)現(xiàn)存在一個(gè)自定義的 Eval Frame 函數(shù),就會執(zhí)行自動線的函數(shù)。

          CPython _PyEval_EvalFrame 函數(shù)實(shí)現(xiàn)(https://github.com/python/cpython/blob/76449350b3467b85bcb565f9e2bf945bd150a66e/Include/internal/pycore_ceval.h#L84),所以只要在 ByteCode 執(zhí)行前,設(shè)置一個(gè)自定義的 eval frame 函數(shù)即可:

          static?inline?PyObject*_PyEval_EvalFrame(PyThreadState *tstate, struct _PyInterpreterFrame *frame, int throwflag){ EVAL_CALL_STAT_INC(EVAL_CALL_TOTAL); if (tstate->interp->eval_frame == NULL) { // 這是默認(rèn)的 eval frame return _PyEval_EvalFrameDefault(tstate, frame, throwflag); } // 如果存在 eval_frame 就會被執(zhí)行 return tstate->interp->eval_frame(tstate, frame, throwflag);}

          可以看到 TorchDynamo 正是這么做的。第一步,在 Python 層基于 ContextManger 在進(jìn)入 Dynamo 作用域時(shí),就觸發(fā) eval_frame 的設(shè)置,實(shí)現(xiàn)(https://github.com/pytorch/pytorch/blob/4068c5467d496cd3c09a841f40adacedf3ab41a0/torch/_dynamo/eval_frame.py#L128):

          # torch._dynamo.optimize(...) 對應(yīng)的 context manager.class _TorchDynamoContext: def __init__( self, callback: DynamoCallback, ): super().__init__() assert callable(callback) or callback is False or callback is None self.callback: DynamoCallback = callback self.prior: Union[Unset, DynamoCallback] = unset def __enter__(self): # 設(shè)置 eval_frame,記錄之前的 eval frame self.prior = set_eval_frame(self.callback) def __exit__(self, exc_type, exc_val, exc_tb): assert self.prior is not unset # 恢復(fù)之前的 eval frame set_eval_frame(self.prior)

          這里先大致認(rèn)為設(shè)置的 DynamoCallback 對應(yīng)一個(gè)自定義的 eval frame 所需的參數(shù),通常是自定義的 eval frame 中所需的編譯邏輯。

          看下 set_eval_frame ,C 代碼層面的實(shí)現(xiàn)(https://github.com/pytorch/pytorch/blob/eaf4fe3d2b7096579b05b52d543756f74d0e91e7/torch/csrc/dynamo/eval_frame.c#L446),它有點(diǎn)繞但最終走到了這里(https://github.com/pytorch/pytorch/blob/eaf4fe3d2b7096579b05b52d543756f74d0e91e7/torch/csrc/dynamo/eval_frame.c#L121),也是設(shè)置 tstate->interp->eval_frame?,把 eval_frame 設(shè)置成自定義的 custom_eval_frame_shim:

          // custom_eval_frame_shim 是自定義的 frameinline static void enable_eval_frame_shim(PyThreadState* tstate) { if (tstate->interp->eval_frame != &custom_eval_frame_shim) { // First call // 設(shè)置自定義的 eval frame tstate->interp->eval_frame = &custom_eval_frame_shim; }}

          現(xiàn)在回頭看一下 PEP 523 提供的 Python JIT 編譯器的自定義 frame 執(zhí)行的樣例,它提供了一個(gè)比較標(biāo)準(zhǔn)的模版(注意筆者對例子做了微調(diào),原文有多余和不合理的地方)。在自定義 eval frame 之前,一般還需要自定義一個(gè)存放自定義 ByteCode 的數(shù)據(jù)結(jié)構(gòu),可以認(rèn)為是自定義編譯結(jié)果,比如樣例中自定義編譯結(jié)果包括3個(gè)字段:

          exec_count, 代表改 frame 被執(zhí)行的次數(shù);

          jit_failed, 代表之前 jit 編譯是否失敗過;

          jit_code,代表 jit 編譯過后的自定義 ByteCode;

          據(jù)此,來看下自定義 eval frame 的樣例:

          # 輸入原始的 framedef eval_frame(frame, throw_flag): # 獲取 frame 中的 code object 中的存放自定義編譯結(jié)果的字段 pyjion_code = frame.code.co_extra if not pyjion_code: # 不如不存在,就設(shè)置一個(gè)空的默認(rèn)值 frame.code.co_extra = PyjionJittedCode() elif not pyjion_code.jit_failed: # 如果之前 jit 執(zhí)行成功 if pyjion_code.jit_code: # 如果存在 jit 生成的 bytecode,就執(zhí)行它 return pyjion_code.eval(pyjion_code.jit_code, frame) elif pyjion_code.exec_count > 20000: # 沒有 jit 編譯過,且 frame 被執(zhí)行超過 20000 次,就嘗試進(jìn)行 jit 編譯 # 如果不存在 jit 生成的 bytecode,就 jit 編譯生成它 if jit_compile(frame): # 如果 jit 編譯成功,就執(zhí)行 jit 編譯的 bytecode return pyjion_code.eval(pyjion_code.jit_code, frame) else: # 如果 jit 編譯失敗,就記錄下,后面不再編譯 pyjion_code.jit_failed = True # 增加 frame 執(zhí)行次數(shù)計(jì)數(shù) pyjion_code.exec_count += 1 # 執(zhí)行默認(rèn)的 frame return _PyEval_EvalFrameDefault(frame, throw_flag)

          下面接著看 TorchDynamo 自定義 evale frame 的實(shí)現(xiàn)。在了解具體的自定義 frame 執(zhí)行邏輯前,有個(gè)前置知識是 PyFrameObject 中的 PyCodeObject 為了執(zhí)行自定義 frame 增加了一個(gè) co_extra 字段,用來讓用戶放置自定義的數(shù)據(jù),一般是存放自定義編譯結(jié)果(https://peps.python.org/pep-0523/#expanding-pycodeobject)。

          typedef struct { ... void *co_extra; /* 自定義的 frame 需要的自定義數(shù)據(jù) */} PyCodeObject;

          TorchDynamo 在自定義編譯結(jié)果的類型是 CacheEntry,其中最重要的字段是 code,是被編譯器修改后的 ByteCode:

          typedef struct cache_entry { // check the guards: lambda: : bool PyObject* check_fn; // modified user bytecode (protected by check_fn"s guards) PyCodeObject* code; // on a cache miss, linked list of next thing to try struct cache_entry* next;} CacheEntry;

          現(xiàn)在看下自定義的 eval frame 邏輯?custom_eval_frame_shim(https://github.com/pytorch/pytorch/blob/eaf4fe3d2b7096579b05b52d543756f74d0e91e7/torch/csrc/dynamo/eval_frame.c#L342):

          static PyObject* _custom_eval_frame(PyThreadState* tstate, PyFrameObject* frame, int throw_flag, PyObject* callback) { // 獲取當(dāng)前 frame 的 PyCodeObject 的 extra 字段用于后面設(shè)置 // 該字段用于放置自定義的編譯結(jié)果 CacheEntry* extra = get_extra(frame->f_code); // callback 即上文說的自定義編譯器 // 使用 callback 進(jìn)行 bytecode 的修改,即編譯 // 編譯結(jié)果寫在了 frame->f_code中的 extra 中 PyObject* result = call_callback(callback, (PyObject*)frame, cache_size(extra)); if (result != Py_None) { // 緩存編譯結(jié)果 extra = create_cache_entry(extra, result); Py_DECREF(result); // 執(zhí)行自定義的 frame // eval_custom_code 最終會調(diào)用 CPython 接口 _PyEval_EvalFrameDefault 來執(zhí)行計(jì)算 // 其中 extra->code 中存放的就自定義編譯器生成的 ByteCode // 所以最終 _PyEval_EvalFrameDefault 執(zhí)行的是編譯器生成的 ByteCode return eval_custom_code(tstate, frame, extra->code, throw_flag); }}inline static PyObject* eval_custom_code(PyThreadState* tstate, PyFrameObject* frame, PyCodeObject* custom_code, int throw_flag) { // 使用 custom_code 創(chuàng)建一個(gè)自定義的 frame PyFrameObject* shadow_frame = PyFrame_New(tstate, custom_code, frame->f_globals, NULL); // 調(diào)用 Python 的 frame 執(zhí)行自定義 frame return _PyEval_EvalFrameDefault(tstate, shadow_frame, throw_flag);}

          到這里,已經(jīng)清楚了修改 Python ByteCode 執(zhí)行的主線邏輯。

          6

          小結(jié)

          這里對 Python 的執(zhí)行和 TorchDynamo 的主要原理做了初探,主要是自定義 Eval Frame 的實(shí)現(xiàn)技巧。其它相關(guān)的 Python ByteCode 標(biāo)準(zhǔn),ByteCode 到 FX Graph 的轉(zhuǎn)換,ByteCode 的改寫等內(nèi)容還沒涉及。

          參考資料 ?

          tenthousandmeters.com/b?(https://tenthousandmeters.com/blog/python-behind-the-scenes-1-how-the-cpython-vm-works/)

          peps.python.org/pep-052?(https://peps.python.org/pep-0523/)

          dev-discuss.pytorch.org?(https://dev-discuss.pytorch.org/t/torchdynamo-an-experiment-in-dynamic-python-bytecode-transformation/361)

          (原文:https://zhuanlan.zhihu.com/p/589115427) 其他人都在看

          李白:你的模型權(quán)重很不錯(cuò),可惜被我沒收了

          單RTX 3090訓(xùn)練YOLOv5s,時(shí)間減少11小時(shí)

          OpenAI掌門Sam Altman:AI下一個(gè)發(fā)展階段

          32篇年度最佳AI論文;Python編譯器Codon開源

          對比四大深度學(xué)習(xí)框架,我發(fā)現(xiàn)都關(guān)注兩大問題

          比快更快,開源Stable Diffusion刷新作圖速度

          OneEmbedding:單卡訓(xùn)練TB級推薦模型不是夢

          歡迎Star、試用OneFlow最新版本:GitHub - Oneflow-Inc/oneflow: OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient.OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient. - GitHub - Oneflow-Inc/oneflow: OneFlow is a deep learning framework designed to be user-friendly, scalable and efficient.https://github.com/Oneflow-Inc/oneflow/

          關(guān)鍵詞: ByteCode

          久久精品国产亚洲5555| 亚洲视频免费播放| 国产成人综合亚洲AV第一页| 精品久久久久久久久亚洲偷窥女厕| 亚洲一区二区三区在线网站| 亚洲日本在线看片| 亚洲AV综合色区无码另类小说| 国产亚洲色视频在线| 亚洲欭美日韩颜射在线二| 亚洲综合日韩久久成人AV| 亚洲精品在线视频观看| 亚洲国产精品久久66| 亚洲色图黄色小说| 亚洲欧洲尹人香蕉综合| 亚洲精品在线播放视频| 亚洲激情电影在线| 亚洲综合色区中文字幕| 亚洲熟女综合色一区二区三区| 亚洲精品久久无码av片俺去也| 亚洲第一街区偷拍街拍| av无码东京热亚洲男人的天堂| 精品亚洲福利一区二区| 亚洲乱亚洲乱少妇无码| 亚洲一区二区三区偷拍女厕 | 亚洲一区无码精品色| 国产成人99久久亚洲综合精品| 国产亚洲自拍一区| 亚洲国产成人一区二区精品区| 亚洲Aⅴ无码专区在线观看q| 亚洲美女视频一区二区三区| 亚洲AV无码成人专区| 亚洲JIZZJIZZ妇女| 亚洲视频一区二区| 久久亚洲国产午夜精品理论片| 亚洲一区免费观看| 亚洲第一成人在线| 激情小说亚洲图片| 亚洲AV无码乱码在线观看性色扶 | 亚洲国产精品乱码在线观看97| 亚洲va在线va天堂成人| 亚洲1区2区3区精华液|