Posts oneDNN
Post
Cancel

oneDNN

介绍

oneDNN(前身为mkl-dnn和dnnl),是intel开发的开源深度学习加速计算库,实现了部分常用神经网络算子,它是oneAPI的一部分。

开发oneDNN库的目的是为了提高intel处理器和显卡上开发深度学习应用的性能,因此该库主要针对intel的CPU和GPU进行优化,对AArch64和NVIDIA GPU实验性支持。目前使用了oneDNN的应用有TensorFlow、Pytorch、Matlab、Mindspore等。

oneDNN核心模块:

  • Primitives. 封装了算子执行时需要的所有信息。加上attributesprimitives可以表示更复杂的融合算子。
  • Engines. 是计算设备的抽象,主要为GPU设计,对CPU来说是个空壳。
  • Streams. 封装了一系列执行命令,和eigines绑定,对CPU来说也是个空壳。因为CPU算子的执行都是同步的。
  • Memory. 封装了一系列在eigines上的内存分配信息,包括维度,类型,格式等。

核心接口:

  • memory
  • desc
  • primitive_desc
  • primitive
  • execute

源码阅读

onednn中算子根据芯片架构(X64, ARM)、芯片指令集(asimd, sve, sse, avx, avx2, avx512)、数据类型(f32, bf16, s32, s8,u8)以及是否使用jit等有多种实现。每类算子都有cpu_xxx_list.cpp文件,里面impl_list_map变量则给出了各种实现的实例。通过get_xxx_v2_impl_list()接口获取对应的实现。

比如pooling算子,onednn给出了正向的19种实现,反向的12中实现。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
    static const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> the_map = REG_POOLING_P({
        {{forward}, {
            /* fp */
            CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx512_core, bf16>)
            CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx512_core, f32>)
            CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx2, f32>)
            CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx, f32>)
            CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<sse41, f32>)
            CPU_INSTANCE_AARCH64(jit_uni_pooling_fwd_t<sve_512, f32>)
            CPU_INSTANCE(nchw_pooling_fwd_t<bf16>)
            CPU_INSTANCE(nchw_pooling_fwd_t<f32>)
            CPU_INSTANCE(nhwc_pooling_fwd_t<bf16>)
            CPU_INSTANCE(nhwc_pooling_fwd_t<f32>)
            CPU_INSTANCE(ref_pooling_fwd_t<f32>)
            CPU_INSTANCE(ref_pooling_fwd_t<bf16, f32>)
            /* int */
            CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t<avx512_core>)
            CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t<avx2>)
            CPU_INSTANCE_X64(jit_uni_i8i8_pooling_fwd_t<sse41>)
            CPU_INSTANCE_AARCH64(jit_uni_i8i8_pooling_fwd_t<sve_512>)
            CPU_INSTANCE(ref_pooling_fwd_t<s32>)
            CPU_INSTANCE(ref_pooling_fwd_t<s8, s32>)
            CPU_INSTANCE(ref_pooling_fwd_t<u8, s32>)
            nullptr,
        }},
        {{backward}, REG_BWD_PK({
            CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx512_core, bf16>)
            CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx512_core, f32>)
            CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx2, f32>)
            CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<avx, f32>)
            CPU_INSTANCE_X64(jit_uni_pooling_bwd_t<sse41, f32>)
            CPU_INSTANCE_AARCH64(jit_uni_pooling_bwd_t<sve_512, f32>)
            CPU_INSTANCE(nchw_pooling_bwd_t<bf16>)
            CPU_INSTANCE(nchw_pooling_bwd_t<f32>)
            CPU_INSTANCE(nhwc_pooling_bwd_t<bf16>)
            CPU_INSTANCE(nhwc_pooling_bwd_t<f32>)
            CPU_INSTANCE(ref_pooling_bwd_t<f32>)
            CPU_INSTANCE(ref_pooling_bwd_t<bf16>)
            nullptr,
        })},
    });
    return the_map;
}

在创建primitive_desc时:

1
2
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty = false)
    : dnnl::primitive_desc(&adesc.data, nullptr, aengine, nullptr, allow_empty) {}

会根据算子的输入情况以及芯片的情况,选择相应impl_list_map中的最佳实现。比如,对于pooling算子,如果数据类型是fp32,芯片是X86架构,且支持AVX512指令集,则会选择CPU_INSTANCE_X64(jit_uni_pooling_fwd_t<avx512_core, f32>)实现。然后根据具体的实现生成汇编代码(通过重写基类jit_generator中的generate()接口)。

获取支持的isa的接口get_supported_isa()@src/cpu/x64/prelu/jit_prelu_utils.cpp,在创建primitive_desc流程中。

根据impl_list_map创建dnnl_primitive_desc_iterator的流程的调用栈

1
2
3
4
5
6
7
8
9
#0 dnnl::impl::cpu::get_prelu_impl_list at MKL-DNN/src/cpu/cpu_prelu_list.cpp:63
#1 dnnl::impl::cpu::cpu_engine_impl_list_t::get_implementation_list at MKL-DNN/src/cpu/cpu_engine.hpp:103
#2 dnnl::impl::cpu::cpu_engine_t::get_implementation_list at MKL-DNN/src/cpu/cpu_engine.hpp:148
#3 dnnl_primitive_desc_iterator::dnnl_primitive_desc_iterator at MKL-DNN/src/common/primitive_iterator.hpp:52
#4 dnnl_primitive_desc_iterator_create at MKL-DNN/src/common/primitive_iterator.cpp:46
#5 dnnl::primitive_desc::primitive_desc at MKL-DNN/include/oneapi/dnnl/dnnl.hpp:4569
#6 dnnl::prelu_forward::primitive_desc::primitive_desc at MKL-DNN/include/oneapi/dnnl/dnnl.hpp:12569
#7 prelu_example (engine_kind=dnnl::engine::kind::cpu) at MKL-DNN/examples/primitives/prelu.cpp:102
#8 main (argc=1, argv=0x7fffffffe288) at MKL-DNN/examples/primitives/prelu.cpp:138

根据创建的dnnl_primitive_desc_iterator选择最佳实现的流程的调用栈。

1
2
3
4
5
6
#0 dnnl_primitive_desc_iterator::operator++ at MKL-DNN/src/common/primitive_iterator.hpp:93
#1 dnnl_primitive_desc_iterator_create at MKL-DNN/src/common/primitive_iterator.cpp:53
#2 dnnl::primitive_desc::primitive_desc at MKL-DNN/include/oneapi/dnnl/dnnl.hpp:4569
#3 dnnl::prelu_forward::primitive_desc::primitive_desc at MKL-DNN/include/oneapi/dnnl/dnnl.hpp:12569
#4 prelu_example (engine_kind=dnnl::engine::kind::cpu) at MKL-DNN/examples/primitives/prelu.cpp:102
#5 main (argc=1, argv=0x7fffffffe288) at MKL-DNN/examples/primitives/prelu.cpp:138

balance211

oneDNN里的线程均衡函数(src/common/dnnl_thread.hpp)

一维

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
typedef int T;
typedef int U;

inline int div_up(int a, int b) {
  return (a + b - 1) / b;
}

void balance211(T n, U team, U tid, T& n_start, T& n_end) {
    T n_min = 1;
    T &n_my = n_end;
    if (team <= 1 || n == 0) {
        n_start = 0;
        n_my = n;
    } else if (n_min == 1) {
        // team = T1 + T2
        // n = T1*n1 + T2*n2  (n1 - n2 = 1)
        T n1 = div_up(n, (T)team);
        T n2 = n1 - 1;
        T T1 = n - n2 * (T)team;
        n_my = (T)tid < T1 ? n1 : n2;
        n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
    }

    n_end += n_start;  
}

int main() {
  int n = 153;
  int team = 16;
  for (int tid = 0; tid < 16; tid++) {
    int n_start = 0;
    int n_end = 0;
    balance211(n, team, tid, n_start, n_end);
    std::cout << "[" << n_start << ", " << n_end << ")" << std::endl;
  }
}

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
[0, 10)
[10, 20)
[20, 30)
[30, 40)
[40, 50)
[50, 60)
[60, 70)
[70, 80)
[80, 90)
[90, 99)
[99, 108)
[108, 117)
[117, 126)
[126, 135)
[135, 144)
[144, 153)

二维

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
void balance2D(U nthr, U ithr, T ny, T &ny_start, T &ny_end, 
               T nx, T &nx_start, T &nx_end, T nx_divider) {
    const T grp_count = std::min(nx_divider, static_cast<T>(nthr));
    const int grp_size_big = nthr / static_cast<int>(grp_count) + 1;
    const int grp_size_small = nthr / static_cast<int>(grp_count);
    const int n_grp_big = nthr % static_cast<int>(grp_count);
    const int threads_in_big_groups = n_grp_big * grp_size_big;

    const int ithr_bound_distance = ithr - threads_in_big_groups;
    T grp, grp_ithr, grp_nthr;
    if (ithr_bound_distance < 0) { // ithr in first groups
        grp = ithr / grp_size_big;
        grp_ithr = ithr % grp_size_big;
        grp_nthr = grp_size_big;
    } else { // ithr in last groups
        grp = n_grp_big + ithr_bound_distance / grp_size_small;
        grp_ithr = ithr_bound_distance % grp_size_small;
        grp_nthr = grp_size_small;
    }

    balance211(nx, grp_count, grp, nx_start, nx_end);
    balance211(ny, grp_nthr, grp_ithr, ny_start, ny_end);
}

int main() {
  int nthr = 16;
  int nx = 153;
  int ny = 147;
  int nx_divider = 2;
  for (int ithr = 0; ithr < nthr; ithr++) {
    int ny_start = 0;
    int ny_end = 0;
    int nx_start = 0;
    int nx_end = 0;
    balance2D(nthr, ithr, ny, ny_start, ny_end, nx, nx_start, nx_end, nx_divider);
    std::cout << "[" << nx_start << ", " << nx_end << ")" << ","; 
    std::cout << "[" << ny_start << ", " << ny_end << ")" << std::endl;
  }
}

输出结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
[0, 77),[0, 19)
[0, 77),[19, 38)
[0, 77),[38, 57)
[0, 77),[57, 75)
[0, 77),[75, 93)
[0, 77),[93, 111)
[0, 77),[111, 129)
[0, 77),[129, 147)
[77, 153),[0, 19)
[77, 153),[19, 38)
[77, 153),[38, 57)
[77, 153),[57, 75)
[77, 153),[75, 93)
[77, 153),[93, 111)
[77, 153),[111, 129)
[77, 153),[129, 147)
This post is licensed under CC BY 4.0 by the author.