[源码解析] 深度学习流水线并行 PipeDream(2)--- 计算分区

时间:2024-02-17 17:53:58

[源码解析] 深度学习流水线并行 PipeDream(2)--- 计算分区

0x00 摘要

在前文中,我们介绍了PipeDream的总体架构和Profile阶段,本文我们继续介绍计算分区阶段。其功能是:依据profile结果确定所有层的运行时间,然后使用动态规划对模型进行划分,将模型划分为不同的stage,以及得到每个stage的replication数。计算结果具体如下图所示:

流水线并行其他文章链接如下:

[源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

[源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

[源码解析] 深度学习流水线并行之PipeDream(1)--- Profile阶段

0x01 前言

1.1 Profile文件

我们首先看看profile文件 profiler/translation/profiles/gnmt/graph.txt 内容,这里只是做摘录。

node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
node2 -- Input1 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000
node7 -- LSTM(2048, 1024) -- forward_compute_time=3.190, backward_compute_time=5.348, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=50364416.000
node8 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node9 -- __getitem__(1) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=131072.0, parameter_size=0.000
node10 -- Dropout(p=0.2) -- forward_compute_time=0.064, backward_compute_time=0.128, activation_size=6291456.0, parameter_size=0.000
node11 -- LSTM(1024, 1024) -- forward_compute_time=2.491, backward_compute_time=4.203, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=33587200.000
node12 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node13 -- __getitem__(1) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=131072.0, parameter_size=0.000
node14 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
node15 -- Dropout(p=0.2) -- forward_compute_time=0.059, backward_compute_time=0.121, activation_size=6291456.0, parameter_size=0.000
node16 -- LSTM(1024, 1024) -- forward_compute_time=2.492, backward_compute_time=4.201, activation_size=[6291456.0; 131072.0; 131072.0], parameter_size=33587200.000
node17 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000
......
	node1 -- node4
	node4 -- node5
	node2 -- node5
	node5 -- node6
	node6 -- node7
	node7 -- node8
	node7 -- node9
	node8 -- node10
	node10 -- node11
	node11 -- node12
	node11 -- node13
	node12 -- node14
	node8 -- node14
	node14 -- node15
	node15 -- node16
	node16 -- node17
	node16 -- node18
	node17 -- node19
	node14 -- node19
......

1.2 总体思路

在前文我们也提到了几个挑战,其中有:

  • 如何高效划分流水线
    • 模型特质和硬件拓扑会降低效率。分配算法也必须考虑模型特质和硬件拓扑
    • 机器间的过度通信会降低硬件效率。
  • 如何防止流水线瓶颈
    • 由木桶原理我们可以知道,一个流水线管道的吞吐量由这个流水线上最慢环节的吞吐量决定。所以需要确保流水线中所有阶段都大致花费相同的计算时间,否则最慢的阶段将会成为整个流水线的瓶颈

因此当跨机器将层划分为不同的阶段时,PipeDream的自动划分算法必须确保每个阶段大致执行相同的总工作量。同时还必须确保各阶段之间通信的数据量尽可能小,以避免通信中断。

PipeDream的自动划分算法总体目标是输出一个平衡的管道,算法如下:

  • 将DNN层划分为多个阶段,以便每个阶段以大致相同的速率完成,即花费大致相同的计算时间。
  • 尝试以拓扑感知的方式尽量减少worker之间的通信(例如,如果可能,向更高带宽的链路发送较大的输出)。
  • 因为DNN并不总可以在可用的workers做平均分配,为了进一步改进负载平衡,PipeDream允许复制一个stage,即在这个stage上使用多个worker进行数据并行。

这个划分问题等价于最小化流水线的最慢阶段所花费的时间,并且具有最优子问题属性:在给定worker工作量前提下,吞吐量最大化的流水线由一系列子流水线构成,其中每一个子流水线针对较小worker工作量来最大化自己的输出。因此PipeDream使用动态规划来寻找最优解。

这里给出对应的架构图如下:

我们下面先看看计算分区之前的准备工作:图相关工作和构建反链。

0x02 图相关

图的定义位于 graph/graph.py 文件之中,主要数据结构有两个:Graph 和 Node。

2.1 Graph

Graph就是图的数据结构,其主要成员包括:

  • nodes :图内节点;
  • edges :图内每个节点的输出边;
  • in_edges :图的每个节点的输入边;
  • _predecessors :每个节点的前序节点;
  • _successors :每个节点的后序节点;
  • _antichain_dag :反链DAG;
class Graph(object):
    def __init__(self, node=None):
        self.nodes = {} # 节点
        if node is not None:
            self.nodes[node.node_id] = node
        self.edges = {} # 出边
        self.in_edges = {} # 入边

        self._predecessors = {} #每个节点的前序节点 
        self._successors = {} # 每个节点的后序节点
        self._augmented_antichains = {}
        self._deaugmented_augmented_antichains = {}
        self._next_antichains = {}
        self._antichain_dag = None # 反链DAG

        if node is not None:
            self.in_edges[node.node_id] = list()

节点定义如下,里面就是从profile获取到的结构,比如:

  • forward_compute_time : 前向传播时间;
  • backward_compute_time :反向传播时间;
  • activation_size : 激活值大小;
  • parameter_size : 参数大小;
class Node(object):
    def __init__(self, node_id, node_desc="", forward_compute_time=0.0,
                 backward_compute_time=0.0, activation_size=0.0, parameter_size=0.0,
                 stage_id=None):
        self.node_id = node_id
        self.node_desc = node_desc
        self.forward_compute_time = forward_compute_time
        self.backward_compute_time = backward_compute_time
        self.activation_size = activation_size
        self.parameter_size = parameter_size
        self.stage_id = stage_id
        self.depth = None
        self.height = None

我们打印出运行时看看,可以发现 Graph 的具体情况。

gr = {Graph} 
 # 边
 edges = {dict: 39}
  'node1' = {list: 1} 
   0 = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  'node4' = {list: 1} 
   0 = {Node} node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
   ......

 # 输入边 
 in_edges = {dict: 44} 
  'node4' = {list: 1} 
   0 = {Node} node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
  'node5' = {list: 2} 
   0 = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
   1 = {Node} node2 -- Input1 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
   ......
  
 # 节点 
 nodes = {dict: 48}
  'node1' = {Node} node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
  'node4' = {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  'node5' = {Node} node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
 ......

# 前置节点
_predecessors = {dict: 36} 
 'node4' = {set: 0} set()
  __len__ = {int} 0
 'node5' = {set: 1} {<graph.graph.Node object at 0x7fb055e4bf28>}
  {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  __len__ = {int} 1
 'node6' = {set: 2} {<graph.graph.Node object at 0x7fb055e4bf98>, <graph.graph.Node object at 0x7fb055e4bf28>}
  {Node} node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
  {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  __len__ = {int} 2
 'node7' = {set: 3} {<graph.graph.Node object at 0x7fb055e4bf98>, <graph.graph.Node object at 0x7fb055e4bf28>, <graph.graph.Node object at 0x7fb055e670f0>}
  {Node} node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
  {Node} node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
  {Node} node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000
  __len__ = {int} 3

 # 其他变量
  _antichain_dag = {NoneType} None
  _augmented_antichains = {dict: 0} {}
  _deaugmented_augmented_antichains = {dict: 0} {}
  _next_antichains = {dict: 0} {}
  _successors = {dict: 0} {}

2.2 构建图

图是由profile文件的字符串构建出来。找出来profile文件内容我们就可以知道,具体是针对每行进行不同处理。

node1 -- Input0 -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=0.0, parameter_size=0.000
node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000
node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000
	node1 -- node4
	node4 -- node5
	node2 -- node5

构建图具体代码如下:

@staticmethod
def from_str(graph_str):
    gr = Graph()
    graph_str_lines = graph_str.strip().split('\n') 
    for graph_str_line in graph_str_lines: # 逐行处理
        if not graph_str_line.startswith('\t'):
            node = Node.from_str(graph_str_line.strip()) # 构建节点
            gr.nodes[node.node_id] = node
        else:
            # 构建边
            [in_node_id, node_id] = graph_str_line.strip().split(" -- ")
            if node_id not in gr.in_edges: # 每个节点的输入边
                gr.in_edges[node_id] = [gr.nodes[in_node_id]]
            else:
                gr.in_edges[node_id].append(gr.nodes[in_node_id])
            if in_node_id not in gr.edges: # 每个节点的输出边
                gr.edges[in_node_id] = [gr.nodes[node_id]]
            else:
                gr.edges[in_node_id].append(gr.nodes[node_id])
    return gr

构建节点具体代码如下:

    @staticmethod
    def from_str(node_str):
        node_str_tokens = node_str.strip().split(" -- ")
        node_id = node_str_tokens[0] # 节点名字
        node_desc = node_str_tokens[1] # 节点描述
        node_metadata = node_str_tokens[2] # 元数据
        stage_id = None
        if len(node_str_tokens) > 3:
            stage_id = int(node_str_tokens[3].split("=")[1]) # 阶段信息
        [forward_compute_time, backward_compute_time, activation_size, parameter_size] = node_metadata.split(", ")
        forward_compute_time = float(forward_compute_time.split("=")[1]) # 前向传播计算时间
        backward_compute_time = float(backward_compute_time.split("=")[1]) # 后向传播计算时间
        if "[" in activation_size:
            activation_size = activation_size.split("=")[1] # 激活值大小
            activation_size = sum([float(x) for x in activation_size.lstrip("[").rstrip("]").split("; ")])
        else:
            activation_size = float(activation_size.split("=")[1])
        parameter_size = float(parameter_size.split("=")[1]) # 参数大小
        # 构建节点
        return Node(node_id, node_desc, forward_compute_time=forward_compute_time,
                    backward_compute_time=backward_compute_time, activation_size=activation_size,
                    parameter_size=parameter_size, stage_id=stage_id)

2.3 反链

在有向无环图中,有如下的一些概念:

  • 链 :一条链是一些点的集合,在此链上的任意两个点x, y,满足以下条件:或者 x 能到达 y ,或者 y 能到达 x 。也可以认为是某一个偏序集S的全序子集(所谓全序是指其中任意两个元素可以比较)

  • 反链 :一条反链也是一些点的集合,在此链上任意两个点x, y,满足如下条件: x 不能到达 y,且 y 也不能到达 x。也可以认为是某一个偏序集S的子集,其中任意两个元素不可比较。

在PipeDream的图数据结构之中,也有反链的概念。反链节点定义如下:

class AntichainNode(Node):
    def __init__(self, node_id, antichain, node_desc=""):
        self.antichain = antichain
        self.output_activation_size = 0.0
        super(AntichainNode, self).__init__(node_id, node_desc)

因为此处过于复杂,所以我们会在下面用一节专门分析。

0x03 构建反链

因为本节概念比较绕,所以我们先提前剧透。

寻找某节点后续反链的目的就是找到下一个图分割点 A(可能是若干node的组合),为了确定 A 的运行时间(或者其他信息),我们需要找到 A 的增强反链

此处具体代码位于optimizer_graph_hierarchical.py 文件。

我们利用如下逻辑来演示:

+-------+       +-------+
| node1 |       | node2 |
+---+---+       +---+---+
    |               |
    |               |
    |               |
    v               v
+---+---+       +---+---+        +-------+        +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+       +-------+        +-------+        +-+-+---+
                                                    | |
                                                    | |
                                      +-------------+ |
                                      |               |
                                      v               v
                                 +----+--+        +---+---+
                                 | node9 |        | node8 +-----+
                                 +-------+        +---+---+     |
                                                      |         |
                    +---------------------------------+         |
                    |                                           |
                    v                                           |
               +----+---+       +--------+        +--------+    |
               | node10 +-----> | node11 +------> | node12 |    |
               +--------+       +---+----+        +----+---+    |
                                    |                  |        |
                                    |                  |        |
                                    v                  v        |
                                +---+----+        +----+---+    |
                                | node13 |        | node14 +<---+
                                +--------+        +-+----+-+
                                                    |    |
                                             +------+    +---+
                                             |               |
                                             v               v
                                        +----+---+        +--+-----+
                                        | node15 |        | node19 |
                                        +--------+        +--------+

3.1 main函数入口

我们首先从 main 函数看起。main函数第一部分是构建反链和拓扑排序,具体如下:

  • 从图中移除source节点。目的是排除干扰,因为input必然在第一层,没必要让优化器再来选择把输入放在哪里,所以先去除,后续转换模型时候会再加上。
  • 对图的输出进行处理,移除没有用到的输出。
  • 得到反链DAG。
  • 对反链DAG进行拓扑排序,得到一个排序好的节点列表。

具体代码如下:

def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
         straight_pipeline, use_memory_constraint, use_fewer_machines,
         activation_compression_ratio, output_directory,
         print_configuration=True, verbose=False):
    gr = graph.Graph.from_str(open(profile_filename, 'r').read())

    # Zero out all metadata associated with inputs in graph, since the optimizer
    # shouldn't really get a choice with where to place the input (should always
    # be in the first stage).
    # 排除干扰,因为input必然在第一层,没必要让优化器再来选择把输入放在哪里,所以先去除,后续会再加上。
    sources = gr.sources() # 对图的输入进行处理
    nodes_to_remove = OrderedDict()
    for source in sources:
        if source.node_desc.startswith("Input"): # 只处理input
            source.forward_compute_time = 0.0
            source.backward_compute_time = 0.0
            source.activation_size = 0.0
            source.parameter_size = 0.0
            nodes_to_remove[source] = []
            for out_node in gr.edges[source.node_id]:
                nodes_to_remove[source].append(out_node) # 记录这些删除source对应了哪些out节点,因为后续还要处理
            gr.remove_node(source) # 在图中移除这些input source

    # Remove all unneeded sinks that are not used, makes code generation and
    # optimization easier.
    sinks = gr.sinks() # 对图的输出进行处理,移除没有用到的输出
    for sink in sinks:
        if sink.node_desc.startswith("__getitem__"):
            gr.remove_node(sink)

    antichain_gr = gr.antichain_dag() # 得到反链DAG
    states = antichain_gr.topological_sort() # 拓扑排序,得到一个排序好的节点列表

    # 后续代码暂时省略

这里再取出反链节点定义如下,可以看出来和代码对应关系。

class AntichainNode(Node):
    def __init__(self, node_id, antichain, node_desc=""):
        self.antichain = antichain
        self.output_activation_size = 0.0
        super(AntichainNode, self).__init__(node_id, node_desc)

3.2 增强反链

首先要介绍先增强反链概念。每个节点的增强反链包括:本身节点 + 部分前序节点

这个前序节点的选取算法是:

  1. 获取本节点的全部前序节点列表;
  2. 如果一个前序节点的"出边目的节点"不在全部前序节点列表,且"出边目的节点"不为本身,则选取此前序节点为增强反链的一部分。

从下面图例中可以看出来,如果某一个节点 A,其前置节点中有一个分叉节点 Z,且这个分叉之中,有一个分叉绕过了节点 A,则对于节点 A,他的增强反链就是 [A, Z]。

对于增强反链概念,可以理解为:对于节点 A,他只有把节点 Z 一起考虑,才能唯一确定自己节点的运行时间。因为如果思考节点 A 的运行时间,我理解的大致思路是:

  • 因为各个阶段可以流水线并行,所以 A 的运行时间应该是以下三个时间的最大值:A的计算时间,A的输入时间,A的输出时间。
  • A 的输入时间是以下两个时间的最大值: X --> A 节点输出时间,Z --> A 节点的输出时间。
  • 但是因为不清楚 Z 的内部运行机制,所以不能确定 Z 的两个输出之间是否有依赖关系,比如 "必须先完成 Z--> D,才能输出 Z--> A", 所以,也需要考虑 Z --> D 的传输时间。

所以,需要把 [ A,Z ] 放在一起作为一个状态考虑,事实上 PipeDream 就是这么处理的,用 [ A,Z ] 这个状态来统一计算

因为作为一个状态考虑,所以给节点 A 计算输出激活值大小,具体是通过遍历其反链(增强反链)来计算,就是把其增强反链的前序节点给自己的输出都叠加起来

    +-----+            +-----+
    |  X  |            |  Z  |
    +--+--+            +--+-++
       |                  | |
       |                  | |
       +------+   +-------+ |
              |   |         |
              v   v         |
             ++---++        |
             |  A  |        |
             ++-+--+        |
              | |           |
    +---------+ |           |
    |           |           |
    v           v           v
+---+-+      +--+--+      +-+---+
|  B  |      |  C  |      |  D  |
+-----+      +-----+      +-----+

在代码之中,_augmented_antichains 是增强反链,也是一个字典类,key是节点名字,value是 key 节点的增强反链,比如:

augment_antichain函数作用就是对每个节点,找到其增强反链。

def augment_antichain(self, antichain):
    # 参数 antichain 是一个节点列表
    antichain_key = tuple(sorted(antichain))
    # 如果key已经在扩大反链之中,就直接返回对应key的增强反链
    if antichain_key in self._augmented_antichains:
        return self._augmented_antichains[antichain_key]
    extra_nodes = set()
    all_predecessors = set()
    # 遍历参数list之中的反链节点,获取每个节点的前置节点,归并在all_predecessors之中。
    for antichain_node in antichain:
        predecessors = self.predecessors(antichain_node)
        all_predecessors = all_predecessors.union(predecessors)
    # 遍历参数list之中的反链节点
    for antichain_node in antichain:
        # 获取每个反链节点的前置节点列表
        predecessors = self.predecessors(antichain_node)
        # 遍历每个前置节点
        for predecessor in predecessors:
            # 看每个前置节点的出边,如果出边不在前置节点列表之中,且 出边节点不等于本反链节点
            for out_node in self.edges[predecessor.node_id]:
                if out_node not in predecessors and out_node.node_id != antichain_node:
                    # 把这个前置节点插入到附加节点列表中
                    extra_nodes.add(predecessor.node_id)
    # 最终把个附加节点列表插入到增强节点之中
    self._augmented_antichains[antichain_key] = list(extra_nodes) + antichain
    return self._augmented_antichains[antichain_key]

比如对应下图中的逻辑,初始化之后,_augmented_antichains 就是

_augmented_antichains = {dict: 1} 
 ('node4',) = {list: 1} ['node4']

后续迭代node 5之后,_augmented_antichains 就是

_augmented_antichains = {dict: 2} 
 ('node4',) = {list: 1} ['node4']
 ('node5',) = {list: 1} ['node5']
 __len__ = {int} 2

继续迭代,增强反链为:

_augmented_antichains = {dict: 7} 
('node4',) = {list: 1} ['node4'] # node4的增强反链只有自己
('node5',) = {list: 1} ['node5'] # node5的增强反链只有自己
('node6',) = {list: 1} ['node6']
('node7',) = {list: 1} ['node7']
('node8',) = {list: 1} ['node8']
('node10',) = {list: 2} ['node8', 'node10'] # node10的增强反链是'node8', 'node10'
('node14',) = {list: 1} ['node14']
('node11',) = {list: 2} ['node8', 'node11'] # node11的增强反链是'node8', 'node11'
('node15',) = {list: 2} ['node14', 'node15']
('node19',) = {list: 1} ['node19']
('node12',) = {list: 2} ['node8', 'node12']
('node16',) = {list: 2} ['node14', 'node16']
('node23',) = {list: 2} ['node20', 'node23']
('node17',) = {list: 2} ['node14', 'node17']  

图例中可以看出来,因为有 node 8的出边 [node 8,node 14] 存在,对于 node 10, node 11, node 12 来说,他们必须把 node 8 加入自己的增强反链之中。

对于 node 10,我们可以认为,必须结合 node 8之后,node 10 才能确定 node 10 的运行时间。下面图上标记出来了 node 10 的 augmented 反链(本身节点 + 部分前序节点)。

+-------+       +-------+
| node1 |       | node2 |
+---+---+       +---+---+
    |               |
    |               |
    |               |
    v               v
+---+---+       +---+---+        +-------+        +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+       +-------+        +-------+        +-+-+---+
                                                    | |
                                                    | |
                                      +-------------+ |
                                      |               |
                                      v               v  augmented
                                 +----+--+        +---+---+
                                 | node9 |        | node8 +-----+
                                 +-------+        +---+---+     |
                                                      |         |
                    +---------------------------------+         |
                    |                                           |
                    v                                           |
               +----+---+       +--------+        +--------+    |
     antichain | node10 +-----> | node11 +------> | node12 |    |
               +--------+       +---+----+        +----+---+    |
             augmented              |                  |        |
                                    |                  |        |
                                    v                  v        |
                                +---+----+        +----+---+    |
                                | node13 |        | node14 +<---+
                                +--------+        +-+----+-+
                                                    |    |
                                             +------+    +---+
                                             |               |
                                             v               v
                                        +----+---+        +--+-----+
                                        | node15 |        | node19 |
                                        +--------+        +--------+

3.3 后续反链

在代码之中,_next_antichains 是一个字典类,key是节点名字,value是 key 节点的后续反链。

比如,对于 node A 来说,下一个反链是 [ node B, node C ],其中 node B 和 node C 彼此之间无法排序。寻找反链的目的就是找到下一个图分割点

    +-----+            +-----+
    |  X  |            |  Z  |
    +--+--+            +--+-++
       |                  | |
       |                  | |
       +------+   +-------+ |
              |   |         |
              v   v         |
             ++---++        |
             |  A  |        |
             ++-+--+        |
              | |           |
    +---------+ |           |
    |           |           |
    v           v           v
+---+-+      +--+--+      +-+---+
|  B  |      |  C  |      |  D  |
+-----+      +-----+      +-----+

对于每个节点 antichain ,next_antichains 函数获取其后续反链。

    def next_antichains(self, antichain):
        # 构建antichain的反链key,其实就是 antichain 自己作为key
        antichain_key = tuple(sorted(antichain))
        # 如果key已经在后续反链之中,则返回这个后续反链
        if antichain_key in self._next_antichains:
            return self._next_antichains[antichain_key]

        next_antichains = []
        antichain_set = set(antichain)
        # 获取 antichain 的增强反链
        augmented_antichain = self.augment_antichain(antichain)
        # 遍历增强反链
        for augmented_antichain_node in augmented_antichain:
            # 遍历增强反链某节点的出边
            next_nodes = self.edges[augmented_antichain_node] if augmented_antichain_node in self.edges else []
            # 遍历增强反链某节点的出边
            for next_node in next_nodes:
                # 如果出边节点已经在反链集合之中,跳过,进入下一循环
                if next_node.node_id in antichain_set:
                    continue
                # 如果出边节点是后续反链,则假如到反链列表   
                if self.is_next_antichain(augmented_antichain, next_node.node_id):
                    next_antichain = self.construct_antichain(augmented_antichain,
                                                              augmented_antichain_node,
                                                              next_node.node_id)
                    next_antichains.append(next_antichain)
        # 最终把反链列表设置为key对应的反链            
        self._next_antichains[antichain_key] = next_antichains
        return self._next_antichains[antichain_key]

is_next_antichain 方法用来判断某新节点是否为后续反链。

def is_next_antichain(self, augmented_antichain, new_node):
    successors = self.successors(new_node)
    augmented_antichain_set = set(augmented_antichain)
    # 遍历新节点的后续节点
    for successor in successors:
        # 如果后续节点有一个在增强节点之中,就返回false,说明不是后续反链
        if successor.node_id in augmented_antichain_set:
            return False
    # 否则就是后续反链      
    return True

_next_antichains举例如下,大家可以结合之前的增强反链对比看看。

  • 以 node 10 为例,其增强节点为:[ node 8,node 10 ],
  • 遍历这些增强节点,看每一个增强节点的出边。8 的出边 [ node 10,node 14 ],10 的出边是 [ node 11]。
  • 所以有三个点 node 10,node 11,node 14 可以继续看。其中node 10 已经在[ node 8,node 10 ]之中,所以不考虑。
  • 用 14 调用 is_next_antichain。
    • is_next_antichain 之中,augmented_antichain 为 [ node 8, node 10],new_node 是 node 14。
    • 得到 successors 集合为 [ node31,node16,node23,node44,node48 ....] 等22个节点,这些节点都不在 [ node 8, node 10] 之中,所以 is_next_antichain 为 true,14 是后续反链节点之一。
  • 用 11 调用 is_next_antichain。
    • is_next_antichain 之中,augmented_antichain 为 [ node 8, node 10],new_node 是 node 11。
    • 得到 successors 集合为 [ node16,node40,node23,....] 等节点,这些节点都不在 [ node 8, node 10] 之中,所以 is_next_antichain 为 true,11 是后续反链节点之一。

所以 node 10 的后续反链是 [ ['node14'] ,[ 'node11'] ]。

对比 看看,node 10 的增强反链是 ['node8', 'node10'],

_next_antichains = {dict: 99} 
 ('node4',) = {list: 1} [['node5']]
 ('node5',) = {list: 1} [['node6']]
 ('node6',) = {list: 1} [['node7']]
 ('node7',) = {list: 1} [['node8']]
 ('node8',) = {list: 2} [['node10'], ['node14']]
 ('node10',) = {list: 2} [['node14'], ['node11']] # 这里
 ('node14',) = {list: 2} [['node15'], ['node19']]
 ('node11',) = {list: 2} [['node14'], ['node12']]
 ('node15',) = {list: 2} [['node19'], ['node16']]
 ('node19',) = {list: 1} [['node23']]
 ('node12',) = {list: 2} [['node14'], ['node14']]
 ('node16',) = {list: 2} [['node19'], ['node17']]

具体如下图,可以看出来,node 11和 node 14确实是 node 10的后续反链,就是在这两个节点上可以对于图进行分割。

可以这么理解:对于 node 10 来说,下一个反链是 [ node 11, node 14],其中 node 11 和 node 14 彼此之间无法排序寻找后续反链的目的就是找到下一个图分割点

+-------+       +-------+
| node1 |       | node2 |
+---+---+       +---+---+
    |               |
    |               |
    |               |
    v               v
+---+---+       +---+---+        +-------+        +-------+
| node4 +-----> | node5 +------> | node6 +------->+ node7 |
+-------+       +-------+        +-------+        +-+-+---+
                                                    | |
                                                    | |
                                      +-------------+ |
                                      |               |
                                      v               v  augmented
                                 +----+--+        +---+---+
                                 | node9 |        | node8 +-----+
                                 +-------+        +---+---+     |
                                                      |         |
                    +---------------------------------+         |
                    |                                           |
                    v              next                         |
               +----+---+       +--------+        +--------+    |
     antichain | node10 +-----> | node11 +------> | node12 |    |
               +--------+       +---+----+        +----+---+    |
             augmented              |                  |        |
                                    |                  |        |
                                    v             next v        |
                                +---+----+        +----+---+    |
                                | node13 |        | node14 +<---+
                                +--------+        +-+----+-+
                                                    |    |
                                             +------+    +---+
                                             |               |
                                             v               v
                                        +----+---+        +--+-----+
                                        | node15 |        | node19 |
                                        +--------+        +--------+

3.4 总体构建

antichain_dag 的目的是依据 增强反链列表后续反链列表来构建一个反链 DAG。

我们以上面的图例进行讲解,以 node 8 为例。

def antichain_dag(self):
    if self._antichain_dag is not None:
        return self._antichain_dag

    antichain_dag = Graph()
    antichain_id = 0
    antichain = [self.sources()[0].node_id] # 获取source第一个节点。
    # 构建首节点,同时利用 augment_antichain 来往_augmented_antichains 之中添加首节点。
    source_node = AntichainNode("antichain_%d" % antichain_id, self.augment_antichain(antichain))
    antichain_dag.source = source_node
    antichain_queue = [antichain] # 把第一个节点插入queue
    antichain_mapping = {tuple(sorted(antichain)): source_node}

    # 如果queue之中还有节点
    while len(antichain_queue) > 0:
        antichain = antichain_queue.pop(0) # 弹出第一个节点,赋值为 antichain,这里为 node 8
        # key就是由 antichain 节点名字构建,比如 antichain_key = {tuple: 1} node8
        antichain_key = tuple(sorted(antichain)) 
        # 如果 antichain_key 已经位于self._next_antichains之中,即 antichain_key 的后续反链已经被记录,就跳过去
        if antichain_key in self._next_antichains:  
            continue
        # 获取 antichain 的后续反链,对于8,这里是[[10],[14]]
        next_antichains = self.next_antichains(antichain)
        # 遍历后续反链[10,14]
        for next_antichain in next_antichains:
            # 下一个反链节点的key 10
            next_antichain_key = tuple(sorted(next_antichain))
            if next_antichain_key not in antichain_mapping: # 如果存在,就跳过
                antichain_id += 1
                # 下一反链节点 10 被设置为其增强节点 [ 8, 10 ]
                next_antichain_node = AntichainNode("antichain_%d" % antichain_id, self.augment_antichain(next_antichain))
                # 设置 antichain_mapping
                antichain_mapping[next_antichain_key] = next_antichain_node
            # 向 反链DAG 插入边:    
            antichain_dag.add_edge(antichain_mapping[antichain_key],
                                   antichain_mapping[next_antichain_key])
            # 把最新反链节点插入queue,下次迭代使用
            antichain_queue.append(next_antichain)

    self._antichain_dag = antichain_dag
    return antichain_dag

这里其实目的是设置 antichain_mapping。

流程是:

  • 从 antichain_queue 弹出第一个节点,赋值为 antichain,这里为 node 8。
  • 获取 antichain 的后续反链,对于8,这里是[[10],[14]]。
  • 遍历后续反链 [10,14]。
  • 以 10 为例,设置下一个反链节点的key 为 10。
  • 下一反链节点 10 被设置为其增强节点 [ 8, 10 ],即 ('node10',) = {AntichainNode} antichain_5 -- ['node8', 'node10']。

可以看到,寻找某节点后续反链的目的就是找到下一个图分割点 A,然后为了确定 A 的运行时间(或者其他信息),需要找到 A 的增强反链(一些增强反链就是一些状态),A 的 antichain_mapping 就是其增强反链

antichain_mapping 示例如下:

antichain_mapping = {dict: 99} 
 ('node4',) = {AntichainNode} antichain_0 -- ['node4']
 ('node5',) = {AntichainNode} antichain_1 -- ['node5']
 ('node6',) = {AntichainNode} antichain_2 -- ['node6']
 ('node7',) = {AntichainNode} antichain_3 -- ['node7']
 ('node8',) = {AntichainNode} antichain_4 -- ['node8']
 ('node10',) = {AntichainNode} antichain_5 -- ['node8', 'node10'] # 最新设置
 ('node14',) = {AntichainNode} antichain_6 -- ['node14']
 ('node11',) = {AntichainNode} antichain_7 -- ['node8', 'node11']
 ('node15',) = {AntichainNode} antichain_8 -- ['node14', 'node15']
 ('node19',) = {AntichainNode} antichain_9 -- ['node19']
 ('node12',) = {AntichainNode} antichain_10 -- ['node8', 'node12']
 ('node16',) = {AntichainNode} antichain_11 -- ['node14', 'node16']
 ('node23',) = {AntichainNode} antichain_12 -- ['node20', 'node23']
 ('node17',) = {AntichainNode} antichain_13 -- ['node14', 'node17']

antichain_dag 示例如下,可以认为就是增强反链DAG

antichain_dag = {Graph}
	nodes = {dict: 99} 
   'antichain_0' = {AntichainNode} antichain_0 -- ['node4']
   'antichain_1' = {AntichainNode} antichain_1 -- ['node5']
   'antichain_2' = {AntichainNode} antichain_2 -- ['node6']
   'antichain_3' = {AntichainNode} antichain_3 -- ['node7']
   'antichain_4' = {AntichainNode} antichain_4 -- ['node8']
   'antichain_5' = {AntichainNode} antichain_5 -- ['node8', 'node10']
   'antichain_6' = {AntichainNode} antichain_6 -- ['node14']
   'antichain_7' = {AntichainNode} antichain_7 -- ['node8', 'node11']
   'antichain_8' = {AntichainNode} antichain_8 -- ['node14', 'node15']
   'antichain_9' = {AntichainNode} antichain_9 -- ['node19']
   'antichain_10' = {AntichainNode} antichain_10 -- ['node8', 'node12']
   'antichain_11' = {AntichainNode} antichain_11 -- ['node14', 'node16']
   'antichain_12' = {AntichainNode} antichain_12 -- ['node20', 'node23']
   'antichain_13' = {AntichainNode} antichain_13 -- ['node14', 'node17']
   'antichain_14' = {AntichainNode} antichain_14 -- ['node20', 'node30', 'node23']
   'antichain_15' = {AntichainNode} antichain_15 -- ['node20', 'node36', 'node23']
   'antichain_16' = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
   'antichain_17' = {AntichainNode} antichain_17 -- ['node20', 'node23', 'node24']

3.5 拓扑排序

得到了增强反链之后,需要进行拓扑排序之后才能使用

antichain_gr = gr.antichain_dag()
states = antichain_gr.topological_sort()

得出拓扑排序的目的是:如果按照拓扑序列的顶点次序,在到达某节点之前,可以保证它的所有前序活动都已经完成,从而整个工程顺序执行,不会冲突

在图论中,拓扑排序(Topological Sorting)是一个有向无环图(DAG, Directed Acyclic Graph)的所有顶点的线性序列。且该序列必须满足下面两个条件:

  1. 每个顶点出现且只出现一次。
  2. 若存在一条从顶点 A 到顶点 B 的路径,那么在序列中顶点 A 出现在顶点 B 的前面。

有向无环图(DAG)才有拓扑排序,非DAG图没有拓扑排序一说。一个有向无环图可以有一个或多个拓扑排序序列。

例如,下面这个图:

+--------+                  +--------+
|        +----------------> |        |
|   1    |                  |   4    +------------+
|        |    +-----------> |        |            |
+-----+--+    |             +---+----+            |
      |       |                 |                 v
      |       |                 |              +--+--+
      |       |                 |        +---> |  5  |
      |       |                 |        |     +-----+
      v       |                 |        |
              |                 v        |
+--------+    |             +---+-----+  |
|        +----+             |         |  |
|    2   +----------------->+    3    +--+
|        |                  |         |
+--------+                  +---------+

得到拓扑排序后的结果是 { 1, 2, 4, 3, 5 }。

这里的拓扑排序算法使用的是深度优先排序。

    def topological_sort(self):
        # Algorithm from https://en.wikipedia.org/wiki/Topological_sorting
        self.sorted_nodes = []
        self.marked_nodes = set()
        self.temporarily_marked_nodes = set()
        nodes = list(self.nodes.values())
        nodes.sort(key=lambda x: x.node_desc)
        for node in nodes:
            if node.node_id in self.marked_nodes:
                continue
            self.topological_sort_helper(node.node_id)
        return [self.nodes[node_id] for node_id in self.sorted_nodes]

    def topological_sort_helper(self, node_id):
        if node_id in self.marked_nodes:
            return
        if node_id in self.temporarily_marked_nodes:
            raise Exception("Graph has a cycle")
        self.temporarily_marked_nodes.add(node_id)
        if node_id in self.edges:
            out_nodes = list(self.edges[node_id])
            out_nodes.sort(key=lambda x: (x.node_desc, x.height))
            for out_node in out_nodes:
                self.topological_sort_helper(out_node.node_id)
        self.marked_nodes.add(node_id)
        self.temporarily_marked_nodes.remove(node_id)
        self.sorted_nodes.insert(0, node_id)

最终结果举例如下,可以和上面的反链DAG antichain_dag 比对,看看异同:

states = {list: 99} 
 00 = {AntichainNode} antichain_0 -- ['node4']
 01 = {AntichainNode} antichain_1 -- ['node5']
 02 = {AntichainNode} antichain_2 -- ['node6']
 03 = {AntichainNode} antichain_3 -- ['node7']
 04 = {AntichainNode} antichain_4 -- ['node8']
 05 = {AntichainNode} antichain_5 -- ['node8', 'node10']
 06 = {AntichainNode} antichain_7 -- ['node8', 'node11']
 07 = {AntichainNode} antichain_10 -- ['node8', 'node12']
 08 = {AntichainNode} antichain_6 -- ['node14']
 09 = {AntichainNode} antichain_8 -- ['node14', 'node15']
 10 = {AntichainNode} antichain_11 -- ['node14', 'node16']
 11 = {AntichainNode} antichain_13 -- ['node14', 'node17']
 12 = {AntichainNode} antichain_9 -- ['node19']
 13 = {AntichainNode} antichain_12 -- ['node20', 'node23']
 14 = {AntichainNode} antichain_18 -- ['node23', 'node20', 'node26']
 15 = {AntichainNode} antichain_17 -- ['node23', 'node20', 'node24']
 16 = {AntichainNode} antichain_32 -- ['node23', 'node20', 'node28']
 17 = {AntichainNode} antichain_31 -- ['node23', 'node20', 'node26', 'node24']
 18 = {AntichainNode} antichain_63 -- ['node23', 'node20', 'node26', 'node28']
 19 = {AntichainNode} antichain_33 -- ['node20', 'node26', 'node29']
 20 = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
 21 = {AntichainNode} antichain_30 -- ['node23', 'node20', 'node43', 'node26']
 22 = {AntichainNode} antichain_29 -- ['node23', 'node20', 'node43', 'node24']
 23 = {AntichainNode} antichain_59 -- ['node23', 'node20', 'node43', 'node28']

我们 也可以和如下增强反链比对,看到 states 就是对增强反链DAG进行拓扑排序之后的结果,按照这个顺序进行训练是符合逻辑的。

_augmented_antichains = {dict: 99} 
 ('node4',) = {list: 1} ['node4']
 ('node5',) = {list: 1} ['node5']
 ('node6',) = {list: 1} ['node6']
 ('node7',) = {list: 1} ['node7']
 ('node8',) = {list: 1} ['node8']
 ('node10',) = {list: 2} ['node8', 'node10']
 ('node14',) = {list: 1} ['node14']
 ('node11',) = {list: 2} ['node8', 'node11']
 ('node15',) = {list: 2} ['node14', 'node15']
 ('node19',) = {list: 1} ['node19']
 ('node12',) = {list: 2} ['node8', 'node12']
 ('node16',) = {list: 2} ['node14', 'node16']
 ('node23',) = {list: 2} ['node20', 'node23']
 ('node17',) = {list: 2} ['node14', 'node17']
 ('node23', 'node30') = {list: 3} ['node20', 'node30', 'node23']
 ('node23', 'node36') = {list: 3} ['node20', 'node36', 'node23']
 ('node23', 'node43') = {list: 3} ['node20', 'node43', 'node23']
 ('node24',) = {list: 3} ['node23', 'node20', 'node24']
 ('node26',) = {list: 3} ['node23', 'node20', 'node26']
 ('node23', 'node30', 'node36') = {list: 4} ['node20', 'node36', 'node30', 'node23']
 ('node23', 'node30', 'node43') = {list: 4} ['node20', 'node43', 'node30', 'node23']
 ('node31',) = {list: 3} ['node20', 'node26', 'node31']
 ('node24', 'node30') = {list: 4} ['node23', 'node20', 'node30', 'node24']
 ('node26', 'node30') = {list: 4} ['node23', 'node20', 'node30', 'node26']
 ('node23', 'node36', 'node43') = {list: 4} ['node20', 'node43', 'node36', 'node23']
 ('node37',) = {list: 4} ['node32', 'node20', 'node26', 'node37']
 ('node24', 'node36') = {list: 4} ['node23', 'node20', 'node36', 'node24']
 ('node26', 'node36') = {list: 4} ['node23', 'node20', 'node36', 'node26']
 ('node44',) = {list: 2} ['node40', 'node44']
 ('node24', 'node43') = {list: 4} ['node23', 'node20', 'node43', 'node24']
 ('node26', 'node43') = {list: 4} ['node23', 'node20', 'node43', 'node26']
 ('node24', 'node26') = {list: 4} ['node23', 'node20', 'node26', 'node24']

3.6 总结

因为目前的算法比较复杂,所以我们暂时总结一下目前为止的工作:

  • 计算出了每个节点的增强反链,最终得到增强反链组合 _augmented_antichains
  • 计算出了每个节点的后续反链。寻找某节点后续反链的目的就是找到下一个图分割点 A,然后为了确定 A 的运行时间(或者其他信息),需要找到 A 的增强反链(一些增强反链就是一些状态)。_next_antichains 是后续反链组合。
  • antichain_dag 函数依据 _next_antichains_augmented_antichains 进行处理,构建一个反链 DAG,就是变量 antichain_dag。
  • 得到了增强反链DAG之后,需要进行拓扑排序之后才能使用。得出拓扑排序的目的是:如果按照拓扑序列的顶点次序,在到达某节点之前,可以保证它的所有前序活动都已经完成,从而整个工程顺序执行,不会冲突
  • states 就是对增强反链DAG进行拓扑排序之后的结果,按照这个顺序进行训练是符合逻辑的。所以后续工作就是在 states 基础上运行。

0x04 计算分区

至此,图已经依据后续反链被分割成若干状态(states),每个状态很重要的一个属性是其增强反链。states 就是对增强反链进行拓扑排序之后的结果,按照这个顺序进行训练是符合逻辑的。

自动分区算法具体分为两部分。

  • compute_partitioning 是使用动态规划算法对于这些状态得出一个最优化结果,但是没有做具体分区。
  • analyze_partitioning 是利用最优化结果来做具体分区,排序后得到了一个偏序结果。

下面我们逐一分析。

4.1 main函数的逻辑

main函数接下来与计算分区相关的逻辑如下:

  • 为每个状态设置index。
  • 给每个状态计算出输出激活值大小,具体是通过遍历其反链(增强反链),可以认为就是其必要前序节点给自己的输出。
  • 给每个状态计算其信息,比如计算时间,激活大小,参数大小等等,都是通过前置节点完成的 。
  • 得到总体输出大小 output_activation_sizes & 所有前置节点id,后面计算分区时候需要。
  • 依据profile估计出系统内部的计算时间,compute_times_row 是 i 节点到 后续节点(i+1, i+2, ...)的计算时间,下面类似。
  • 依据profile估计出系统内部的激活值大小。
  • 依据profile估计出系统内部的参数大小。
  • 遍历机器集&网络带宽组合。流水线可以是straight(数目为1)或者并行(数目为num_machines),依据目前的信息,以及机器数量,网络带宽等,使用动态规划算法计算分区。假如机器集&网络带宽组合有两个,则会用每个组合进行一次动态规划算法,最后 all_As.append(A) 这里就是两个动态规划的结果,就是考虑到各种必要因素之后的最优结果

具体代码如下:

def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
         straight_pipeline, use_memory_constraint, use_fewer_machines,
         activation_compression_ratio, output_directory,
         print_configuration=True, verbose=False):
    gr = graph.Graph.from_str(open(profile_filename, 'r').read())

    # Zero out all metadata associated with inputs in graph, since the optimizer
    # shouldn't really get a choice with where to place the input (should always
    # be in the first stage).
    # 排除干扰,因为input必然在第一层,没必要让优化器再来选择把输入放在哪里,所以先去除,后续会再加上。
    sources = gr.sources() # 对图的输入进行处理
    nodes_to_remove = OrderedDict()
    for source in sources:
        if source.node_desc.startswith("Input"): # 只处理input
            source.forward_compute_time = 0.0
            source.backward_compute_time = 0.0
            source.activation_size = 0.0
            source.parameter_size = 0.0
            nodes_to_remove[source] = []
            for out_node in gr.edges[source.node_id]:
                nodes_to_remove[source].append(out_node) # 记录这些删除source对应了哪些out节点,因为后续还要处理
            gr.remove_node(source) # 在图中移除这些input source

    # Remove all unneeded sinks that are not used, makes code generation and
    # optimization easier.
    sinks = gr.sinks() # 对图的输出进行处理,移除没有用到的输出
    for sink in sinks:
        if sink.node_desc.startswith("__getitem__"):
            gr.remove_node(sink)

    antichain_gr = gr.antichain_dag() # 得到反链DAG
    states = antichain_gr.topological_sort() # 拓扑排序,得到一个排序好的节点列表

    ###########################################################################
    # 之前代码在上节分析过,我们本节从这里继续分析
    ###########################################################################
    
    states_indices = {} # 为每个状态设置index
    for i in range(len(states)):
        states_indices[states[i]] = i
        
##################################### 运行时如下        
#states_indices = {dict: 99} 
# antichain_0 -- ['node4'] = {int} 0
# antichain_1 -- ['node5'] = {int} 1
# antichain_2 -- ['node6'] = {int} 2
# antichain_3 -- ['node7'] = {int} 3
# antichain_4 -- ['node8'] = {int} 4
# ......
         
    # 给每个状态计算出输出激活值大小,具体是通过遍历其反链(增强反链),可以认为就是其必要前序节点给自己的输出
    for i in range(len(states)):
        for antichain_node in states[i].antichain:
            states[i].output_activation_size += gr.nodes[antichain_node].activation_size
       
    # 给每个状态计算其信息,比如计算时间,激活大小,参数大小等等,都是通过前置节点完成的      
    for i in range(len(states)):
        antichain = states[i].antichain
        all_predecessors = gr.all_predecessors(antichain)
        states[i].compute_time = 0.0
        states[i].activation_size = 0.0
        states[i].parameter_size = 0.0
        for predecessor in all_predecessors: # 计算所有前置节点的信息
            states[i].compute_time += ((predecessor.forward_compute_time +
                                        predecessor.backward_compute_time) / 1000.0)
            states[i].activation_size += predecessor.activation_size
            states[i].parameter_size += predecessor.parameter_size
    gr.reset()

    # 得到总体输出大小 & 所有前置节点id,后面计算分区时候需要
    output_activation_sizes = [state.output_activation_size for state in states]
    all_predecessor_ids = [[states_indices[predecessor] for predecessor in
                            antichain_gr.predecessors(states[i].node_id)]
                           for i in range(len(states))]

##################################### 运行时如下      
# output_activation_sizes = {list: 99} 
# 00 = {float} 6291456.0
# 01 = {float} 12582912.0
# 02 = {float} 12582912.0
# 03 = {float} 6553600.0    
# .....
# all_predecessor_ids = {list: 99} 
#  00 = {list: 0} []
#  01 = {list: 1} [0]
#  02 = {list: 2} [0, 1]
#  03 = {list: 3} [0, 1, 2]
#  04 = {list: 4} [0, 1, 2, 3]
#  05 = {list: 5} [2, 3, 4, 0, 1]
#  06 = {list: 6} [2, 3, 4, 0, 1, 5]
#  07 = {list: 7} [6, 2, 3, 4, 0, 1, 5]
# ......
    
    compute_times = [] # 初始化计算时间
    activation_sizes = [] # 初始化激活值大小
    parameter_sizes = [] # 初始化参数值大小
    for i in range(len(states)+1): # 具体计算每一个节点的信息,去除他之前节点的影响
        compute_times_row = []
        activation_sizes_row = []
        parameter_sizes_row = []
        for j in range(len(states)): # 去除之前的节点
            if i == 0: # 列表中第一个节点
                compute_times_row.append(states[j].compute_time) # i 到 j 的计算时间
                activation_sizes_row.append(states[j].activation_size)
                parameter_sizes_row.append(states[j].parameter_size)
            else: # 列表中后续节点
                if j > (i-1):
                    compute_times_row.append(states[j].compute_time -
                        states[i-1].compute_time) # i 到 j 的计算时间
                    activation_sizes_row.append(states[j].activation_size -
                        states[i-1].activation_size)
                    parameter_sizes_row.append(states[j].parameter_size -
                        states[i-1].parameter_size)
                else:
                    compute_times_row.append(None)
                    activation_sizes_row.append(None)
                    parameter_sizes_row.append(None)
        compute_times.append(compute_times_row) # 依据profile估计出系统内部的计算时间,compute_times_row 是 i 节点到 后续节点(i+1, i+2, ...)的计算时间,下面类似
        activation_sizes.append(activation_sizes_row) # 依据profile估计出系统内部的激活值大小
        parameter_sizes.append(parameter_sizes_row) # 依据profile估计出系统内部的参数大小

##################################### 运行时如下  
# compute_times = {list: 100} 
# 000 = {list: 99} [0.0070220000000000005, 0.012285, 0.012558, 0.021096000000,...
# 001 = {list: 99} [None, 0.005263, 0.005535999999999999, 0.014074000000000003, ...
# 002 = {list: 99} [None, None, 0.00027299999999999894, 0.008811000000000003, ...
# 003 = {list: 99} [None, None, None, 0.008538000000000004, 0.008538, ...
# 004 = {list: 99} [None, None, None, None, -3.469446951953614e-18, 0.000191999999...

    counter = 1
    all_As = []
    num_machines_in_machine = 1 #第一个节点就是1
    # all_num_machines, network_bandwidths 是用户在输入中指定
    # 遍历机器集&网络带宽组合。流水线可以是straight(数目为1)或者并行(数目为num_machines)
    for num_machines, network_bandwidth in zip(all_num_machines, network_bandwidths):
        print("Solving optimization problem with %d machines with inter-machine bandwidth of %.2f GB/s" % (num_machines, network_bandwidth / 10**9))
        import numpy as np
        print(np.array(compute_times))
        # 依据目前的信息,以及机器数量,网络带宽等计算分区
        A = compute_partitioning(compute_times, activation_sizes, parameter_sizes,
                                 output_activation_sizes, all_predecessor_ids,
                                 num_machines, num_machines_in_machine,
                                 network_bandwidth,
                                 final_level=(counter==len(network_bandwidths)))
        num_machines_in_machine = num_machines # 因为计算完了,所以设置为本阶段的机器数目
        for i in range(len(compute_times)): # 遍历机器
            for j in range(len(compute_times[0])): # 后续机器
                compute_times[i][j] = A[i][j][-1][0] # 记录计算时间(本阶段最后一个机器的计算时间)
        counter += 1
        all_As.append(A) # 添加逻辑关系,就是里面包括了不同阶段的优化逻辑
    print(np.array(compute_times))
    
    # 省略后续代码

其中compute_times 是一个计算时间的二维数组,也可以认为是矩阵,具体举例如下。

[w12,w13,w14,w15], // 第一个节点到后续节点的计算时间

[None, w23,w24,w25], // 第二个节点到后续节点的计算时间

[None, None, w34, w35], // 第三个节点到后续节点的计算时间

[None, None, None, w45], // 第四个节点到后续节点的计算时间

activation_sizes 和 parameter_sizes 与之类似。

4.2 动态规划

4.2.1 总体思路

这里有一些动态规划的算法需要分析。

分割算法试图减少模型的整体训练时间。对于流水线系统,这个问题等价于最小化流水线最慢阶段所花费的时间。该问题具有最优化子问题性质;在给定机器计数的情况下,使吞吐量最大化的管道由子管道组成,这些子管道分别使自己这个子管道的吞吐量最大化。因此,我们可以用动态规划来寻找这个问题的最优解。

分区算法获取profiling步骤的输出,并计算:

1)将层划分为多个阶段,

2)每个阶段的复制因子(worker数),

3)保持训练管道繁忙的最佳动态小批量数。

PipeDream的优化器假设机器拓扑是分层的,并且可以被组织成多个级别,如下图所示。一个级别内的带宽是相同的,而跨级别的带宽是不同的。我们假设 k 级由 mk 个 k-1层组件构成 ,这些组件通过带宽为Bk的链路连接。在下图中,m2=2,m1=4。此外,我们定义m0为1。即 4 个 m0 构成一个 m1, 2个 m1 构成一个 m2。

层 0 就是绿色矩形,代表最底层的计算设备,比如GPU,4个GPU构成了一个层1(虚线矩形,代表一个服务器),2个层1构成了一个层2(就是下图全部模块)。

PipeDream的优化器从最低层到最高层逐步解决动态规划问题。直观地说,这个过程在服务器中找到最佳分区,然后使用这些分区在服务器之间最优地分割模型

4.2.2 具体分析

假设 A(j, m) 表示使用m台机器在第1层和第j层之间的最佳管道中,最慢阶段所用的时间

我们算法的目标是找到 A(N,M) 和相应的划分。让T( i → j,m) 表示跨越层 i 到 j 的单级所用的时间,此时间在m台机器上复制。

其中:

  • max中的左项是在此阶段中所有层的总计算时间,右项是此阶段中所有层的总通信时间。

  • 因为计算和通信可以重叠,所以不需要相加,直接取最大数值。

  • 由1到j的由m个机器组成的最佳流水线可以是单个阶段复制m次,也可以由多个阶段组成。

当最佳管道包含多个阶段时,它可以被分解成一个最优的子管道(由从1到 i 的 由m − m′ 个机器组成)和后续的一个单独阶段(由i+1到j 的被 m' 个机器复制组成)。因此,利用最优子问题的性质,我们得到

其中,max中:

  • 第一项是第1层和第i层之间的最优子管道(由m-m'个机器组成)的最慢阶段所用的时间。

  • 第二项是在层 i 和 i + 1 之间传递激活和梯度所用的时间。

  • 第三项是最后单个阶段的时间(由 m' 个数据并行的机器组成)。

我们具体看看如何计算,假设一个图逻辑如下:

                       +----------------+
+-----+                |                +--------+
|     +------------->  |  k[m_prime]    |        |          +-----+
|  i  |                |                |        +--------->+     |
|     +----+           +----------------+                   |  j  |
+-----+    |                                      +-------->+     |
           |           +----------------+         |         +-----+
           |           |                |         |
           +-------->  |  k[m-m_prime]  +---------+
                       |                |
                       +----------------+

在 (A [i] [k] [m-m_prime] [0], last_stage_time, output_transfer_time, input_transfer_time )之中选一个最大的:

  • A [i] [k] [m-m_prime] [0] :i 到 k 之间的计算时间,是已经计算好的子问题
  • last_stage_time :last_stage_time 是 (k 到 j 的计算时间) + 传输时间
    • 其中compute_times[k + 1] [j] 是k 到 j 的计算时间,compute_times[k + 1] 就对应了k的输出
    • 传输时间是依据k 到 j 的下一阶段参数大小(parameter_sizes[k + 1 ] [j])计算得出。
    • 即:last_stage_time = compute_times[k + 1] +(parameter_sizes[k + 1 ] [j])
  • input_transfer_time :使用 k 的输出激活大小计算出来的传输时间(就是 j 的输入)。
  • output_transfer_time :使用 j 的输出激活大小计算出来的传输时间。

因为传输和计算是可以重叠的,所以可以这样取最大数值

最后得到的 A 就是动态规划优化的结果,其中每一个元素 A[i][j][m] 是个三元组 (min_pipeline_time, optimal_split, optimal_num_machines)A[i][j][m] 表示节点 i 到 节点 j 之间的计算结果。三元组就是 (最小流水线时间,i 到 j 之间那个最佳分割点,最优机器数目)。

大致阶段如下图所示:

                                                       +----------------+
                                                       | i              |
                                                       |                |
                                                       |                |
                                                       +--+------+------+
                                                          |      |
                                                          |      +----------+
                                  A[i][k][m+m_prime][0]   |                 |
                                                          |                 |
                                                          v                 v
                                        +-----------------+-------+    +----+--------+
                                        | k[m-m_prime]            |    | k[m_prime]  |
                                        |                         |    |             |
last_stage_time = compute_times[k+1][j] |                         |    |             |
            + (parameter_sizes[k+1][j]) | output_activation_sizes |    |             |
                                        |                         |    |             |
                                        |                         |    |             |
                                        +-----------------+-------+    +-----+-------+
                                     input_transfer_time  |                  |
                                                          |      +-----------+
                                                          |      |
                                                          |      |
                                                          v      v
                                             +------------+------+------+
                                             | j                        |
                                             |                          |
                                             |                          |
                                             |                          |
                                             |  output_activation_sizes |
                                             |                          |
                                             +------------------+-------+
                                          output_transfer_time  |
                                                                |
                                                                |
                                                                v

具体代码如下:

def compute_partitioning(compute_times, activation_sizes, parameter_sizes,
                         output_activation_sizes, all_predecessor_ids,
                         num_machines, num_machines_within_machine,
                         bandwidth, final_level=True):
    # 初始化
    A = []
    for i in range(len(compute_times)): # 遍历所有节点
        row_A = []
        for j in range(len(compute_times[0])): # 所有后续节点(即第一个节点的所有后续节点)
            row_row_A = []
            for m in range(num_machines): # 机器数目
                row_row_A.append((None, None, None))
            row_A.append(row_row_A)
        A.append(row_A)

    # 得到计算时间
    for i in range(len(compute_times)): # 遍历所有节点
        for j in range(i, len(compute_times[0])): # 所有后续节点
            cum_compute_time = compute_times[i][j] # i --> j 的计算时间
            cum_activation_size = activation_sizes[i][j] # i --> j 的激活大小
            cum_parameter_size = parameter_sizes[i][j] # i --> j 的参数大小
            max_m = 1 if straight_pipeline else num_machines # 线性还是并行流水线
            for m in range(max_m): # 遍历流水线下一阶段的机器
                # 存储的数据大小
                stashed_data_size = math.ceil((num_machines - (m+1)) / (m+1)) * \
                                              (cum_activation_size + cum_parameter_size)
                # memory_size 是用户传进来的参数,就是每个机器有效的内存  
                # use_memory_constraint 也是用户传进来的参数,就是使用的内存限制
                if use_memory_constraint and stashed_data_size > memory_size:
                    continue
                # 数据并行通讯时间依据参数尺寸,带宽,下一阶段机器数量计算    
                data_parallel_communication_time = (4 * m * cum_parameter_size) / (bandwidth * (m+1))
                # 除以本阶段机器数量,如果本阶段机器多,当然就是分开计算了
                data_parallel_communication_time /= num_machines_within_machine

                if cum_compute_time is None:
                    # 需要计算下一阶段中,每个机器的计算时间,所以还要除以(m+1)
                    A[i][j][m] = (None, None, None) # 直接赋值
                else:
                    # 三元组,分别是[(计算时间 + 通信时间), None,(m+1)],对应的意义是 min_pipeline_time, optimal_split, optimal_num_machines,就对应了前面的公式 2
                    A[i][j][m] = (sum([cum_compute_time,
                                       data_parallel_communication_time]) / (m+1), None, (m+1))

    # 需要得到最小计算时间                
    min_machines = 1
    max_i = len(compute_times) if not final_level else 1
    for i in range(max_i): # 遍历节点
        for m in range(min_machines, num_machines): # 遍历下一阶段机器的可能选择
            for j in range(i+1, len(compute_times[0])): # 遍历 i 的后续节点
                (min_pipeline_time, optimal_split, optimal_num_machines) = A[i][j][m]
                if use_fewer_machines and m > 0 and ( # 如果设置了用尽量少的机器,则如果小于min_pipeline_time,就设置新的 min_pipeline_time
                    min_pipeline_time is None or A[i][j][m-1][0] < min_pipeline_time):
                    (min_pipeline_time, optimal_split, optimal_num_machines) = A[i][j][m-1]
                # 遍历 j 节点的前置机器 k,注意,j 是 i 的后续节点之一
                # 就是在 i --> k --> j 之间找到一个计算时间最小的,其中A[i][k][m-m_prime][0]已经是一个最优子问题了
                for k in all_predecessor_ids[j]:
                    # 如果k已经在之前计算过了,就跳过
                    if i > 0 and k in all_predecessor_ids[i-1]:
                        continue
                    # 设置质数    
                    max_m_prime = 2 if straight_pipeline else (m+1)
                    for m_prime in range(1, max_m_prime): # prime就是看看如何分割
                        # 输入传输时间 input_transfer_time 使用 k 的输出激活尺寸计算
                        input_transfer_time = (2.0 * output_activation_sizes[k]) / \
                            (bandwidth * m_prime)
                        # 输出传输时间 output_transfer_time 使用 j 的输出激活尺寸计算
                        output_transfer_time = None
                        if j < len(output_activation_sizes) -1:
                            output_transfer_time = (2.0 *
                                output_activation_sizes[j]) / (bandwidth * m_prime)
                        # last_stage_time 设置为 k 到 j 的计算时间, compute_times[k+1] 就对应了k的输出
                        last_stage_time = compute_times[k+1][j]
                        if last_stage_time is None:
                            continue
                        # 设置为 k 到 j 的下一阶段参数尺寸
                        last_stage_parameter_size = parameter_sizes[k+1][j]
                        # 设置为 k 到 j 的存储数据尺寸
                        stashed_data_size = (activation_sizes[k+1][j]) + last_stage_parameter_size
                        # 依据机器数据计算
                        stashed_data_size *= math.ceil((num_machines - (m+1)) / m_prime)
                        # 超过机器内存就跳过
                        if use_memory_constraint and stashed_data_size > memory_size:
                            continue
                        # 加上传输时间,所以 last_stage_time 是 (k 到 j 的计算时间) + 传输时间
                        last_stage_time = sum([last_stage_time,
                                               ((4 * (m_prime - 1) *
                                                last_stage_parameter_size) / (bandwidth * m_prime))])
                        last_stage_time /= m_prime

                        # 如果从i到k没有边,则跳过
                        if A[i][k][m-m_prime][0] is None:
                            continue
                        # 如果i到k已经有计算时间,则选一个较大的    
                        pipeline_time = max(A[i][k][m-m_prime][0], last_stage_time)
                        if activation_compression_ratio is not None: # 如果压缩
                            # 在(A[i][k][m-m_prime][0], last_stage_time, output_transfer_time, input_transfer_time 之中选一个最大的)
                            input_transfer_time /= activation_compression_ratio
                            # output_transfer_time 也压缩
                            if output_transfer_time is not None:
                                output_transfer_time /= activation_compression_ratio
                            # 选一个大的    
                            pipeline_time = max(pipeline_time, input_transfer_time)
                            if output_transfer_time is not None:
                                pipeline_time = max(pipeline_time, output_transfer_time)
                                
                        # 如果比min_pipeline_time小,则设定 min_pipeline_time,为了下一次循环
                        if min_pipeline_time is None or min_pipeline_time > pipeline_time:
                            optimal_split = (k, m-m_prime) # 选一个优化分割点
                            optimal_num_machines = m_prime
                            min_pipeline_time = pipeline_time
                # 设置            
                A[i][j][m] = (min_pipeline_time, optimal_split, optimal_num_machines)

    return A

all_As 就是动态规划的结果,示例如下:

all_As = {list: 2}  
 0 = {list: 100} 
  000 = {list: 99} 
   00 = {list: 5} [(0.0070220000000000005, None, 1), (0.1689894, None, 2), (0.14943257777777777, None, 3), (0.1258643, None, 4), (0.107310576, None, 5)]
   01 = {list: 5} [(0.012285, None, 1), (0.0070220000000000005, (0, 0), 1), (0.0865995, (0, 0), 2), (0.07639255555555556, (0, 0), 3), (0.06429175000000001, (0, 0), 4)]
   02 = {list: 5} [(0.012558, None, 1), (0.0070220000000000005, (0, 0), 1), (0.0070220000000000005, (1, 1), 1), (0.0070220000000000005, (1, 1), 2), (0.0070220000000000005, (1, 1), 3)]
   03 = {list: 5} [(0.021096, None, 1), (0.012285, (1, 0), 1), (0.008538, (2, 1), 1), (0.008538, (2, 2), 1), (0.008538, (2, 3), 1)]
   ......
  __len__ = {int} 100
  
1 = {list: 100} 
 000 = {list: 99} 
  00 = {list: 5} [(0.107310576, None, 1), (0.080131832, None, 2), (0.05930489777777778, None, 3), (0.046685052000000005, None, 4), (0.03840710336000001, None, 5)]
  01 = {list: 5} [(0.06429175000000001, None, 1), (0.072057299, None, 2), (0.05690740466666667, None, 3), (0.0460065055, None, 4), (0.03840166136, None, 5)]
  02 = {list: 5} [(0.0070220000000000005, None, 1), (0.043422424, None, 2), (0.037817488, None, 3), (0.031689068, None, 4), (0.026947711359999998, None, 5)]
  03 = {list: 5} [(0.008538, None, 1), (0.0419991328, (2, 0), 1), (0.043422424, (2, 1), 1), (0.0396227304, None, 4), (0.033697556608, None, 5)]
 ......
  __len__ = {int} 100
 __len__ = {int} 2

4.2.3 区别

我们接下来要分析代码作者两个相似名字变量之间的区别。

activation_sizes :某个节点所有前置节点的activation_size 之和。

for predecessor in all_predecessors:
    states[i].compute_time += ((predecessor.forward_compute_time +
                                predecessor.backward_compute_time) / 1000.0)
    states[i].activation_size += predecessor.activation_size
    states[i].parameter_size += predecessor.parameter_size

用来计算stashed数据大小,用来看看是否超过了节点配置的内存额度

stashed_data_size = (activation_sizes[k+1][j]) + last_stage_parameter_size
stashed_data_size *= math.ceil((num_machines - (m+1)) / m_prime)
if use_memory_constraint and stashed_data_size > memory_size:
		continue

output_activation_sizes : 某个节点所有增强反链的activation_size之和。

for i in range(len(states)):
    for antichain_node in states[i].antichain:
        states[i].output_activation_size += gr.nodes[antichain_node].activation_size

用来计算输出传播时间和输入传播时间

input_transfer_time = (2.0 * output_activation_sizes[k]) / \
    (bandwidth * m_prime)
output_transfer_time = None
if j < len(output_activation_sizes) -1:
    output_transfer_time = (2.0 *
        output_activation_sizes[j]) / (bandwidth * m_prime)

0x05 分析分区

5.1 main函数逻辑

前面计算分区只是得到了一个动态规划优化结果,需要在analyze_partitioning之中进行分析划分之后,赋予到各个层(stage)

main函数接下来与计算分区相关的逻辑如下:

  • states是反链DAG的结果,all_As 就是动态规划得到的优化结果,可能是多个
  • splits 初始化时候就只有一个二元组元素:最初的划分 (0, len(states))。
  • 遍历all_As的动态优化结果,对于每个动态优化结果,遍历其各个逻辑关系,调用 analyze_partitioning 对分区进行分析,在splits分割中遍历,splits会逐步更新(分割点逐步逐阶段细化),analyze_partitioning 返回一个 partial_splits。
  • 遍历 partial_splits,对于每一个分割点,获取其增强反链(states)的所有前置节点,给这些节点打上stage_id。这里是从前往后遍历,所以stage_id数值是逐步增加。
  • 把图写到文件之中。后续 convert_graph_to_model.py 会把这个文件转换成模型。
  • 做分析对比。

具体代码如下:

def main(all_num_machines, profile_filename, network_bandwidths, memory_size,
         straight_pipeline, use_memory_constraint, use_fewer_machines,
         activation_compression_ratio, output_directory,
         print_configuration=True, verbose=False):
    gr = graph.Graph.from_str(open(profile_filename, 'r').read())

    # Zero out all metadata associated with inputs in graph, since the optimizer
    # shouldn't really get a choice with where to place the input (should always
    # be in the first stage).
    # 排除干扰,因为input必然在第一层,没必要让优化器再来选择把输入放在哪里,所以先去除,后续会再加上。
    sources = gr.sources() # 对图的输入进行处理
    nodes_to_remove = OrderedDict()
    for source in sources:
        if source.node_desc.startswith("Input"): # 只处理input
            source.forward_compute_time = 0.0
            source.backward_compute_time = 0.0
            source.activation_size = 0.0
            source.parameter_size = 0.0
            nodes_to_remove[source] = []
            for out_node in gr.edges[source.node_id]:
                nodes_to_remove[source].append(out_node) # 记录这些删除source对应了哪些out节点,因为后续还要处理
            gr.remove_node(source) # 在图中移除这些input source

    # Remove all unneeded sinks that are not used, makes code generation and
    # optimization easier.
    sinks = gr.sinks() # 对图的输出进行处理,移除没有用到的输出
    for sink in sinks:
        if sink.node_desc.startswith("__getitem__"):
            gr.remove_node(sink)

    antichain_gr = gr.antichain_dag() # 得到反链DAG
    states = antichain_gr.topological_sort() # 拓扑排序,得到一个排序好的节点列表

    ###########################################################################
    # 计算阶段
    ###########################################################################
    states_indices = {} # 为每个状态设置index
    for i in range(len(states)):
        states_indices[states[i]] = i
        
##################################### 运行时如下        
#states_indices = {dict: 99} 
# antichain_0 -- ['node4'] = {int} 0
# antichain_1 -- ['node5'] = {int} 1
# antichain_2 -- ['node6'] = {int} 2
# antichain_3 -- ['node7'] = {int} 3
# antichain_4 -- ['node8'] = {int} 4
# ......
         
    # 给每个状态计算出输出激活值大小,具体是通过遍历其反链(增强反链),可以认为就是其必要前序节点给自己的输出
    for i in range(len(states)):
        for antichain_node in states[i].antichain:
            states[i].output_activation_size += gr.nodes[antichain_node].activation_size
       
    # 给每个状态计算其信息,比如计算时间,激活大小,参数大小等等,都是通过前置节点完成的      
    for i in range(len(states)):
        antichain = states[i].antichain
        all_predecessors = gr.all_predecessors(antichain)
        states[i].compute_time = 0.0
        states[i].activation_size = 0.0
        states[i].parameter_size = 0.0
        for predecessor in all_predecessors: # 计算所有前置节点的信息
            states[i].compute_time += ((predecessor.forward_compute_time +
                                        predecessor.backward_compute_time) / 1000.0)
            states[i].activation_size += predecessor.activation_size
            states[i].parameter_size += predecessor.parameter_size
    gr.reset()

    # 得到总体输出大小 & 所有前置节点id,后面计算分区时候需要
    output_activation_sizes = [state.output_activation_size for state in states]
    all_predecessor_ids = [[states_indices[predecessor] for predecessor in
                            antichain_gr.predecessors(states[i].node_id)]
                           for i in range(len(states))]

##################################### 运行时如下      
# output_activation_sizes = {list: 99} 
# 00 = {float} 6291456.0
# 01 = {float} 12582912.0
# 02 = {float} 12582912.0
# 03 = {float} 6553600.0    
# .....
# all_predecessor_ids = {list: 99} 
#  00 = {list: 0} []
#  01 = {list: 1} [0]
#  02 = {list: 2} [0, 1]
#  03 = {list: 3} [0, 1, 2]
#  04 = {list: 4} [0, 1, 2, 3]
#  05 = {list: 5} [2, 3, 4, 0, 1]
#  06 = {list: 6} [2, 3, 4, 0, 1, 5]
#  07 = {list: 7} [6, 2, 3, 4, 0, 1, 5]
# ......
    
    compute_times = [] # 初始化计算时间
    activation_sizes = [] # 初始化激活值大小
    parameter_sizes = [] # 初始化参数值大小
    for i in range(len(states)+1): # 具体计算每一个节点的信息,去除他之前节点的影响
        compute_times_row = []
        activation_sizes_row = []
        parameter_sizes_row = []
        for j in range(len(states)): # 去除之前的节点
            if i == 0: # 列表中第一个节点
                compute_times_row.append(states[j].compute_time) # i 到 j 的计算时间
                activation_sizes_row.append(states[j].activation_size)
                parameter_sizes_row.append(states[j].parameter_size)
            else: # 列表中后续节点
                if j > (i-1):
                    compute_times_row.append(states[j].compute_time -
                        states[i-1].compute_time) # i 到 j 的计算时间
                    activation_sizes_row.append(states[j].activation_size -
                        states[i-1].activation_size)
                    parameter_sizes_row.append(states[j].parameter_size -
                        states[i-1].parameter_size)
                else:
                    compute_times_row.append(None)
                    activation_sizes_row.append(None)
                    parameter_sizes_row.append(None)
        compute_times.append(compute_times_row) # 依据profile估计出系统内部的计算时间,compute_times_row 是 i 节点到 后续节点(i+1, i+2, ...)的计算时间,下面类似
        activation_sizes.append(activation_sizes_row) # 依据profile估计出系统内部的激活值大小
        parameter_sizes.append(parameter_sizes_row) # 依据profile估计出系统内部的参数大小

##################################### 运行时如下  
# compute_times = {list: 100} 
# 000 = {list: 99} [0.0070220000000000005, 0.012285, 0.012558, 0.021096000000,...
# 001 = {list: 99} [None, 0.005263, 0.005535999999999999, 0.014074000000000003, ...
# 002 = {list: 99} [None, None, 0.00027299999999999894, 0.008811000000000003, ...
# 003 = {list: 99} [None, None, None, 0.008538000000000004, 0.008538, ...
# 004 = {list: 99} [None, None, None, None, -3.469446951953614e-18, 0.000191999999...

    counter = 1
    all_As = []
    num_machines_in_machine = 1 #第一个节点就是1
    # all_num_machines, network_bandwidths 是用户在输入中指定
    # 遍历机器集&网络带宽组合。流水线可以是straight(数目为1)或者并行(数目为num_machines)
    for num_machines, network_bandwidth in zip(all_num_machines, network_bandwidths):
        print("Solving optimization problem with %d machines with inter-machine bandwidth of %.2f GB/s" % (num_machines, network_bandwidth / 10**9))
        import numpy as np
        print(np.array(compute_times))
        # 依据目前的信息,以及机器数量,网络带宽等计算分区
        A = compute_partitioning(compute_times, activation_sizes, parameter_sizes,
                                 output_activation_sizes, all_predecessor_ids,
                                 num_machines, num_machines_in_machine,
                                 network_bandwidth,
                                 final_level=(counter==len(network_bandwidths)))
        num_machines_in_machine = num_machines # 因为计算完了,所以设置为本阶段的机器数目
        for i in range(len(compute_times)): # 遍历机器
            for j in range(len(compute_times[0])): # 后续机器
                compute_times[i][j] = A[i][j][-1][0] # 记录计算时间(本阶段最后一个机器的计算时间)
        counter += 1
        all_As.append(A) # 添加逻辑关系,就是里面包括了不同阶段的优化逻辑
    print(np.array(compute_times))
    
    ###########################################################################
    # 我们从这里继续分析
    ###########################################################################
    
    # 分析阶段
    # 在 analyze_partitioning 内部做了具体分析
    # 这里最重要的是对 gr.all_predecessors 做设置,就是设置 gr 之中每个node的stage_id,这样就是利用stage_id把初始流水线重新划分
    splits = [(0, len(states))] # 如何分割,states是反链DAG的结果,所以 splits 初始化时候就只有一个二元组元素:最初的划分 (0, len(states))
    i = len(all_As) - 1 # all_As 就是动态规划得到的优化结果
    while i >= 0: # 遍历优化的出来的各个逻辑关系
        print("======================================")
        print("Level %d" % (i+1))
        print("======================================")
        new_splits = []
        stage_id = 0 # 在后续的convert_graph_to_model.py 之中会使用到
        for (start, end) in splits: # 在分割中遍历,splits会逐步更新
            # 依据新的splits中的二元组重新计算
            partial_splits = \
                analyze_partitioning(all_As[i], states, start, end,
                                     network_bandwidths[i], all_num_machines[i],
                                     activation_compression_ratio,
                                     print_configuration, verbose)
            start_point = start # 起始点
            for split in partial_splits: # 遍历分析得出的节点
                new_splits.append((start_point, split)) # 添加一个新的二元祖
                if i == 0:
                    predecessors = gr.all_predecessors(states[split-1].antichain)
                    for predecessor in predecessors:
                        if predecessor.stage_id is None:
                            predecessor.set_stage_id(stage_id) # 设置所在阶段
                start_point = split # 下一个阶段
                stage_id += 1 # 增加所在阶段
            new_splits.append((start_point, end)) # 添加一个新的二元祖
            if i == 0:                
                predecessors = gr.all_predecessors(states[end-1].antichain)
                for predecessor in predecessors:
                    if predecessor.stage_id is None:
                        predecessor.set_stage_id(stage_id) # 设置所在阶段
            stage_id += 1 # 增加所在阶段
        
        print("Total number of stages: %d" % stage_id)
        splits = new_splits # 加入新的分割
        i -= 1

    # 以下是为了把图写到文件之中。后续convert_graph_to_model.py会把这个文件转换成模型 
    for source in nodes_to_remove: # 之前移除了input节点,现在需要加回到图中
        for out_node in nodes_to_remove[source]: # input对应的哪些输出
            source.stage_id = 0
            gr.add_edge(source, out_node)

    if output_directory is not None:
        total_num_machines = 1
        for num_machines in all_num_machines:
            total_num_machines *= num_machines
        gr.to_dot(os.path.join(output_directory, "gpus=%d" % total_num_machines))
        gr_str = str(gr)
        with open(os.path.join(output_directory, "gpus=%d.txt" % total_num_machines), 'w') as f:
            f.write(gr_str)

    # 以下是为了做分析对比        
    # 计算数据并行需要的时间,以便接下来做比较,这个时间要比动态规划时间长。        
    total_time = states[-1].compute_time # 最后一个阶段的计算时间,是没有经过优化的最初计算时间
    total_parameter_size = states[-1].parameter_size
    data_parallel_total_time = total_time # 先赋值为最后一阶段的计算时间
    num_machines_in_machine = 1 # 本阶段的机器数目
    # 遍历流水线上各个阶段,因为没有优化,所以就是严格按照用户原始配置的流水线阶段来逐一计算
    for (num_machines, network_bandwidth) in zip(all_num_machines, network_bandwidths):
        # 计算传输时间。num_machines是下一阶段流水线机器数目,所以带宽需要乘以这个数字
        data_parallel_communication_time = (
            (4 * (num_machines - 1) * total_parameter_size) /
            (network_bandwidth * num_machines)) / num_machines_in_machine
        # 总时间需要加上传输时间
        data_parallel_total_time = sum(
            [data_parallel_total_time, data_parallel_communication_time]) / num_machines
        # 下个迭代中,本阶段的机器数目需要设置为num_machines
        num_machines_in_machine = num_machines

    # 这个是用动态规划算法得出来的优化时间    
    pipeline_parallel_total_time = A[0][len(states)-1][num_machines-1][0]

    # 可以看到用户需要注意哪些数据
    if verbose:
        print()
        print("Time taken by single-stage pipeline:", total_time)
        print("Time per stage in pipeline:", pipeline_parallel_total_time)
        print("Throughput increase (compared to single machine):",
              total_time / pipeline_parallel_total_time)
        dp_str = ",".join([str(elem) for elem in all_num_machines])
        print(("[Note that single-machine and (%s)-machine DP might not fit "
               "given memory constraints]") % dp_str)
        print("Throughput increase of (%s)-machine DP compared to single "
              "machine:" % dp_str, total_time / data_parallel_total_time)
        print("Throughput increase (compared to (%s)-machine DP):" % dp_str,
              data_parallel_total_time / pipeline_parallel_total_time)
    return pipeline_parallel_total_time, data_parallel_total_time    

5.2 分析阶段

分析阶段具体可以参见下面注释。

def analyze_partitioning(A, states, start, end, network_bandwidth, num_machines,
                         activation_compression_ratio, print_configuration, verbose):
    # start,end 是本组节点的起始点,终止点
    metadata = A[start][end-1][num_machines-1] # 这是个三元组  (min_pipeline_time, optimal_split, optimal_num_machines)
    next_split = metadata[1] # metadata[1] 是 optimal_split,即 (k, m-m_prime)
    remaining_machines_left = num_machines
    splits = []
    replication_factors = []
    prev_split = end - 1 # 前一个分割点
    
    while next_split is not None: #是否继续分割
        num_machines_used = metadata[2] # optimal_num_machines
        if verbose:
            print("-------------------------------------")
            print("Number of machines used: %d..." % num_machines_used)
            print("Split between layers %d and %d..." % (next_split[0], next_split[0] + 1))
            print("Split before antichain %s..." % (states[next_split[0]+1].antichain))
        splits.append(next_split[0]+1) # 得到了 k + 1,这是关键点,因为最后返回的是splits
        compute_time = states[prev_split-1].compute_time - \
            states[next_split[0]].compute_time
        parameter_size = states[prev_split-1].parameter_size - \
            states[next_split[0]].parameter_size

        dp_communication_time = (4 * (num_machines_used - 1) * parameter_size) \
            / (network_bandwidth * num_machines_used)
        pp_communication_time_input = ( # 下个阶段的数据输入时间
            2.0 * states[next_split[0]].output_activation_size *
            (1.0 / float(num_machines_used))) / network_bandwidth
        pp_communication_time_output = ( # 上个阶段的数据输出时间
            2.0 * states[prev_split-1].output_activation_size *
            (1.0 / float(num_machines_used))) / network_bandwidth
        # 如果需要压缩,就进行压缩
        if activation_compression_ratio is not None:
            pp_communication_time_input /= activation_compression_ratio
            pp_communication_time_output /= activation_compression_ratio
        if activation_compression_ratio is None:
            pp_communication_time_input = 0.0
            pp_communication_time_output = 0.0

        compute_time /= num_machines_used # 本阶段计算时间
        dp_communication_time /= num_machines_used # 数据并行时间

        if verbose:
            print(("Compute time = %f, Data-parallel communication time = %f, "
                   "Pipeline-parallel communication time = %f...") % (
                compute_time, dp_communication_time,
                max(pp_communication_time_input, pp_communication_time_output)))
        prev_split = splits[-1] # 设定新的前一分割点
        # next_split 格式是 (k, m-m_prime),就是 optimal_split 的格式
        # A[i][j][m] 格式是 (min_pipeline_time, optimal_split, optimal_num_machines)
        metadata = A[start][next_split[0]][next_split[1]]
        next_split = metadata[1] # 设定新的下一次分割点,就是 optimal_split
        replication_factors.append(num_machines_used) # 每个阶段的 replication factor
        remaining_machines_left -= num_machines_used # 剩余机器
    if verbose:
        print("-------------------------------------")
        print("Number of machines used: %d..." % metadata[2])

    #     
    num_machines_used = metadata[2]
    remaining_machines_left -= num_machines_used # 剩余的机器
    compute_time = states[prev_split-1].compute_time 
    parameter_size = states[prev_split-1].parameter_size
    dp_communication_time = ((4 * (num_machines_used - 1) * parameter_size) /
                             (network_bandwidth * num_machines_used)) 
    compute_time /= num_machines_used # 计算时间
    dp_communication_time /= num_machines_used # 数据并行通信时间

    if verbose:
        print("Compute time = %f, Data-parallel communication time = %f..." %
              (compute_time, dp_communication_time))
        print("-------------------------------------")
    if print_configuration:
        print("Number of machines in budget not used: %d..." %
              remaining_machines_left)
        print()
        print("(Split start, split end) / compute time taken per stage "
              "/ replication factor per stage:")
    # 下面就是打印 (Split start, split end) / compute time taken per stage / replication factor per stage    
    prev_split = start
    splits.reverse() # 
    splits.append(end)
    replication_factors.append(num_machines_used)
    replication_factors.reverse()
    for i in range(len(splits)):
        time = 0.0
        if prev_split > 0:
            time = states[splits[i]-1].compute_time - states[prev_split-1].compute_time
        else:
            time = states[splits[i]-1].compute_time
        if print_configuration:
            print((prev_split, splits[i]), time, replication_factors[i])
        prev_split = splits[i]
    if print_configuration:
        print()
    return splits[:-1] # 最后一个不返回

我们还是用样例进行说明。

这里是从后面进行分割,举例分析如下,这里设定了总机器数目为10:

回忆在计算分区之中,A[i][j][m] = (min_pipeline_time, optimal_split, optimal_num_machines),optimal_split = (k, m-m_prime) 是一个本阶段优化点。

所以在本函数之中,start = 0, end = 99,所以 metadata 为A[0][99][10],即 (0.01903199999999998, (95, 8), 1),next_split = (95, 8),prev_split = end - 1 = 98。

next_split 就是下一个分割点,splits 是目前的分割序列。

第一轮while循环

因为next_split = (95, 8),所以 splits = append(next_split[0]+1) = [96],因此计算 states[prev_split-1] - states[next_split[0]] = state[97] - state[95]。这样把0~99分成了 0 ~95 和 96 ~ 99。

然后 prev_split = 96,去找A[ 0 ] [ 95] [8] 得到 meta = (0.019031999999999993, (78, 7), 1),next_split = (78, 7)。

所以下一轮从78这个分割点开始分割。

第二轮while循环

因为next_split = (78, 7),所以 splits = [96, 79],这就是新的分割序列。,因此计算 states[96-1] - states[next_split[0]] = state[96] - state[78]。这样就使用 splits = [96, 79] 把0~99分成了 0 ~78,79 ~ 95 和 96 ~ 99。

然后 prev_split =79,去找A[ 0 ] [ 78 ] [ 7 ] 得到 meta = (0.011081, (48, 6), 1),next_split = (48, 6)。

所以下一轮从 48 这个分割点开始分割,以此类推。

while循环之后,得到 splits = [96, 79, 49, 15, 12, 7, 5, 3, 1]。

于是下面代码需要把顺序调整过来。

prev_split = start
splits.reverse()
splits.append(end)
replication_factors.append(num_machines_used)
replication_factors.reverse()

得到:splits = { 1,3,5,7,12,15,49,79,96 }。然后加上 end = 99。

最后返回 splits[:-1],即返回 { 1,3,5,7,12,15,49,79,96 },去掉刚刚添加的end。

而依据 { 1,3,5,7,12,15,49,79,96 } 得到的最终分割序列 是 [(0, 1), (1, 3), (3, 5), (5, 7), (7, 12), (12, 15), (15, 49), (49, 79), (79, 96), (96, 99)],这个列表会在后续"设定stage"之中会用到。

5.3 设定stage

目前我们得到了一个理想分割序列,但是事情没有结束,我们回忆一下分区算法的目的:依据profile结果确定所有层的运行时间,然后使用动态规划对模型进行划分,将模型划分为不同的stage,以及得到每个stage的replication数。

所以,分析的最终目的是给模型的每一个子层分配一个stage,如果某些子层属于同一个stage,这些子层最终就被分配到同一个worker(节点)上执行

因为这里涉及到多个子网,所以我们依然用实例来分析。

如果分成了两个子网,假设:

all_num_machines = [5,5]
network_bandwidths = [800000000, 1000000000]

初始化 splits = [0,99]。

第一轮 while 中,i = 1,

对于 splits 结果[(0, 99)] 遍历,每一段应用analyze_partitioning,得到 partial_splits 为 [3, 6, 30, 75, 99]。

最后,splits 更新为:[(0, 3), (3, 6), (6, 30), (30, 75), (75, 99)]。

此时不会设置stage_id

第二轮 while 中,i = 0,

对于第一轮的 splits 结果 [(0, 3), (3, 6), (6, 30), (30, 75), (75, 99)] 进行遍历,对于这里的每一段也应用 analyze_partitioning,比如对 (0,3) 应用analyze_partitioning,对 (3,6) 应用 analyze_partitioning,对(6,30) 也应用 analyze_partitioning,......,最后得到新的 partial_splits 为 [1, 2, 3, 4, 5, 6, 8, 10, 13, 28, 30, 45, 49, 51, 75, 79, 96, 99]。

最后,splits 更新为:[(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 8), (8, 10), (10, 13), (13, 28), (28, 30), (30, 45), (45, 49), (49, 51), (51, 75), (75, 79), (79, 96), (96, 99)]。

这个列表就是理想分割序列

在此轮中,得到了partial_splits之后,会遍历 for split in partial_splits: 然后对于每一个 split,利用

states[split-1].antichain 获取其增强反链的所有前置节点,给这些节点打上 split 对应的 stage_id

回忆一下增强反链的意义:

  • 每个节点的增强反链包括:本身节点 + 部分前序节点
  • 对于增强反链概念,可以理解为:对于节点 A,他只有把节点 Z 一起考虑,才能唯一确定自己节点的运行时间

所以,对于 split = 1,1 - 1 = 0,于是就得到 states[0].antichain ,就是 'node4',那么 'node4' 自己被打上了一个stage_id=0,说明 'node4' 被分到了一个 "与stage_id=0 所对应" 的 worker 节点上训练。

如果有疑问,我们回忆一下state如何构建,就是有序的 "节点组合"。

antichain_gr = gr.antichain_dag()
states = antichain_gr.topological_sort()

具体如下。

states = {list: 99} 
 00 = {AntichainNode} antichain_0 -- ['node4'] # states[0].antichain
 01 = {AntichainNode} antichain_1 -- ['node5']
 02 = {AntichainNode} antichain_2 -- ['node6']
 03 = {AntichainNode} antichain_3 -- ['node7']
 04 = {AntichainNode} antichain_4 -- ['node8']
 05 = {AntichainNode} antichain_5 -- ['node8', 'node10']
 06 = {AntichainNode} antichain_7 -- ['node8', 'node11']
 07 = {AntichainNode} antichain_10 -- ['node8', 'node12']
 08 = {AntichainNode} antichain_6 -- ['node14']
 09 = {AntichainNode} antichain_8 -- ['node14', 'node15']
 10 = {AntichainNode} antichain_11 -- ['node14', 'node16']
 11 = {AntichainNode} antichain_13 -- ['node14', 'node17']
 12 = {AntichainNode} antichain_9 -- ['node19']
 13 = {AntichainNode} antichain_12 -- ['node20', 'node23']
 14 = {AntichainNode} antichain_18 -- ['node23', 'node20', 'node26']
 15 = {AntichainNode} antichain_17 -- ['node23', 'node20', 'node24']
 16 = {AntichainNode} antichain_32 -- ['node23', 'node20', 'node28']
 17 = {AntichainNode} antichain_31 -- ['node23', 'node20', 'node26', 'node24']
 18 = {AntichainNode} antichain_63 -- ['node23', 'node20', 'node26', 'node28']
 19 = {AntichainNode} antichain_33 -- ['node20', 'node26', 'node29']
 20 = {AntichainNode} antichain_16 -- ['node20', 'node43', 'node23']
 21 = {AntichainNode} antichain_30 -- ['node23', 'node20', 'node43', 'node26']
 22 = {AntichainNode} antichain_29 -- ['node23', 'node20', 'node43', 'node24']
 23 = {AntichainNode} antichain_59 -- ['node23', 'node20', 'node43', 'node28']

设定stage 具体代码如下:

splits = [(0, len(states))]
i = len(all_As) - 1
while i >= 0:
    new_splits = []
    stage_id = 0
    for (start, end) in splits:
        partial_splits = \
            analyze_partitioning(all_As[i], states, start, end,
                                 network_bandwidths[i], all_num_machines[i],
                                 activation_compression_ratio,
                                 print_configuration, verbose)
        start_point = start
        for split in partial_splits: # 遍历这个偏序列表
            new_splits.append((start_point, split))
            if i == 0: # 最终的while
                # 针对每个节点,找到每个节点的所有反链
                predecessors = gr.all_predecessors(states[split-1].antichain)
                for predecessor in predecessors:
                    if predecessor.stage_id is None:
                        predecessor.set_stage_id(stage_id) # 打上stage id
            start_point = split
            stage_id += 1
        new_splits.append((start_point, end))
        if i == 0: # 最终的while
            predecessors = gr.all_predecessors(states[end-1].antichain)
            for predecessor in predecessors:
                if predecessor.stage_id is None:
                    predecessor.set_stage_id(stage_id) # 打上stage id
        stage_id += 1
    splits = new_splits
    i -= 1

5.4 总结

我们总结一下计算分区和分析分区所做的工作:

  • 反链DAG图已经被分割成若干状态(states),每个状态很重要的一个属性是其增强反链。states 就是对增强反链进行拓扑排序之后的结果,按照这个顺序进行训练是符合逻辑的。

  • compute_partitioning 是使用动态规划算法对于这些 states 状态得出一个最优化结果但是这个计算分区只是得到了一个动态规划优化结果,需要在analyze_partitioning之中进行分析划分之后,赋予到各个层(stage)

  • analyze_partitioning 是利用动态规划算法的最优化结果来做具体分区,排序后得到了一个偏序结果,就是理想分割序列。

  • 依据 analyze_partitioning 的结果,给模型的每一个子层分配一个stage,如果某些子层属于同一个stage,这些子层最终就被分配到同一个worker(节点)上执行

0x06 输出

输出文件如下(摘录部分),可以看到,关键之处在于给每一个节点加上了stage,具体如何使用我们将在下一篇进行分析。比如:

stage_id=0 对应的是 node4。

stage_id=1 对应的是 node5,node6。

stage_id=2 对应的是 node7。

stage_id=3 对应的是 node8,node10,node11,node12。

......

具体如下:

node4 -- Embedding(32320, 1024, padding_idx=0) -- forward_compute_time=0.073, backward_compute_time=6.949, activation_size=6291456.0, parameter_size=132382720.000 -- stage_id=0
node5 -- EmuBidirLSTM(  (bidir): LSTM(1024, 1024, bidirectional=True)  (layer1): LSTM(1024, 1024)  (layer2): LSTM(1024, 1024)) -- forward_compute_time=5.247, backward_compute_time=0.016, activation_size=12582912.0, parameter_size=67174400.000 -- stage_id=1
node6 -- Dropout(p=0.2) -- forward_compute_time=0.077, backward_compute_time=0.196, activation_size=12582912.0, parameter_size=0.000 -- stage_id=1
node7 -- LSTM(2048, 1024) -- forward_compute_time=3.190, backward_compute_time=5.348, activation_size=6553600.0, parameter_size=50364416.000 -- stage_id=2
node8 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node10 -- Dropout(p=0.2) -- forward_compute_time=0.064, backward_compute_time=0.128, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node11 -- LSTM(1024, 1024) -- forward_compute_time=2.491, backward_compute_time=4.203, activation_size=6553600.0, parameter_size=33587200.000 -- stage_id=3
node12 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=3
node14 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=4
node15 -- Dropout(p=0.2) -- forward_compute_time=0.059, backward_compute_time=0.121, activation_size=6291456.0, parameter_size=0.000 -- stage_id=4
node16 -- LSTM(1024, 1024) -- forward_compute_time=2.492, backward_compute_time=4.201, activation_size=6553600.0, parameter_size=33587200.000 -- stage_id=4
node17 -- __getitem__(0) -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=5
node19 -- Add -- forward_compute_time=0.000, backward_compute_time=0.000, activation_size=6291456.0, parameter_size=0.000 -- stage_id=5
	node1 -- node4
	node4 -- node5
	node2 -- node5
	node5 -- node6
	node6 -- node7
	node7 -- node8
	node8 -- node10
	node10 -- node11
	node11 -- node12
	node12 -- node14
	node8 -- node14
	node14 -- node15
	node15 -- node16
	node16 -- node17
	node17 -- node19

0xFF 参考

[源码解析] 深度学习流水线并行之PipeDream(1)--- Profile阶段