昇思MindSpore动静结合中list和dict方法实现

01 概述

静态图和动态图是神经学习框架中的重要概念,昇思MindSpore同时支持动态图和静态图两种模式,在动态图与静态图的结合方面做了很多工作。本文以昇思MindSpore框架中图模式下list和dict的实现方式为例,介绍昇思MindSpore框架中的动静结合知识。

02 背景知识

2.1 动态图与静态图

目前主流的神经学习框架大概可以分为静态图和动态图两种。

在动态图中,每次编译都会重新构建一个新的计算图。这意味着计算图的构建和计算同时发生(define by run),这种机制由于能够实时得到中间结果的值,使得调试更加容易,同时我们将大脑中的想法转化为代码方案也变得更加容易,对于编程实现来说更友好。动态图以PyTorch框架为代表。

在静态图中,会事先了解和定义好整个运算流,在运行前可以对图结构进行优化,可以获得更快的前向速度,从性能上来说更加高效。但是只有运行起来之后才能看到变量的值,无法像动态图一样随时拿到中间计算结果。静态图以早期的TensorFlow框架为代表。

2.2 昇思MindSpore中的动静结合

昇思MindSpore支持动态图和静态图两种模式,动态图通过解释执行,具有动态语法亲和性,表达灵活;静态图使用JIT编译优化执行,偏静态语法,在语法上有较多限制。动态图模式是昇思MindSpore的默认模式,主要用于调试等用途,而静态图模式拥有更高效的执行性能,主要用于部署。

昇思MindSpore提供了静态图和动态图统一的编码方式,大大增加了静态图和动态图的可兼容性,用户无需开发多套代码,仅变更一行代码便可切换静态图/动态图模式。例如,在静态图模式下,使用 context.set_context(mode=context.PYNATIVE_MODE) 切换为动态图(PyNative)模式; 同理,昇思MindSpore处于动态图(PyNative)模式时,可以通过 context.set_context(mode=context.GRAPH_MODE) 切换为静态图(Graph)模式。

为了提高动态图模式下的前向计算任务执行速度,昇思MindSpore提供了jit装饰器,可以通过修饰Python函数或者Python类的成员函数使其被编译成计算图,通过图优化等技术提高运行速度。昇思MindSpore支持在动态图下使用静态编译的方式来进行混合执行,通过使用jit装饰符来修饰需要用静态图来执行的函数对象,即可实现动态图和静态图的混合执行。

import numpy as np
import mindspore.ops as ops
import mindspore as ms

#设置运行模式为动态图模式
ms.set_context(mode=ms.PYNATIVE_MODE)

#使用装饰器,指定静态图模式下执行
@ms.jit
def add_func(x,y):
    return ops.add(x,y)

x = ms.Tensor(np.array([1.0, 2.0, 3.0]).astype(np.float32))
y = ms.Tensor(np.array([4.0, 5.0, 6.0]).astype(np.float32))

out = add_func(x,y)
print(out)

03 实现方案和过程

为了实现静态图模式(Graph 模式)和动态图模式(PyNative 模式)的灵活切换,部分动态图的语法需要昇思MindSpore在静态图下单独进行开发。为了提高运算效率,昇思MindSpore引入了Pybind11库,使Python代码可以方便的调用C++代码,一些重要的功能使用C++代码实现。

3.1 图模式下 list.extend 方法支持

list.extend(obj) 方法会在原list后追加obj list内容。在昇思MindSpore图模式中,该方法的实现是使用ListGetItem算子取出原list和objlist中的所有元素,将所有元素按顺序添加到一个std::vector中,再使用MakeList算子构建新list,并将该list返回,替代原list。关键代码如下所示。

FuncGraphPtr ListExtend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  abstract::CheckArgsSize("ListExtend", args_list, 2);

  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("extend");

  std::vector<AnfNodePtr> elems;
  elems.push_back(NewValueNode(prim::kPrimMakeList));
  AddNodeToElems(args_list[0], ret, &elems);
  AddNodeToElems(args_list[1], ret, &elems);
  auto out = ret->NewCNode(elems);
  ret->set_output(out);
  return ret;
}

void ListExtend::AddNodeToElems(const AbstractBasePtr &arg, const FuncGraphPtr &ret, std::vector<AnfNodePtr> *elems) { 
  abstract::AbstractListPtr arg_list = dyn_cast<abstract::AbstractList>(arg);
  MS_EXCEPTION_IF_NULL(arg_list);
  int64_t len = SizeToLong(arg_list->size());
  AnfNodePtr arg_node = ret->add_parameter();
  for (int64_t i = 0; i < len; ++i) {
    auto value = ret->NewCNode({NewValueNode(prim::kPrimListGetItem), arg_node, NewValueNode(i)});
    elems->push_back(value);
  }
}

3.2 图模式下 list.extend 方法支持

list.count(item) 会统计item元素在原list中的数量。在图模式中,顺序遍历原list中的元素,逐个与obj进行比较,首先判断元素类型,Tensor类型单独处理,其余类型使用其各自对应的比较符号比较。关键代码如下所示。

FuncGraphPtr ListCount::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  const size_t list_count_args_size = 2;
  abstract::CheckArgsSize("ListCount", args_list, list_count_args_size);
  auto &list_input = args_list[0];
  auto &element_value = args_list[1];

  auto arg_list = dyn_cast_ptr<abstract::AbstractList>(list_input);
  MS_EXCEPTION_IF_NULL(arg_list);
  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("count");
  (void)ret->add_parameter();
  (void)ret->add_parameter();

  ValuePtr count_value = element_value->BuildValue();
  const auto &values = arg_list->elements();
  int64_t count = 0;
  for (auto value : values) {
    if (ComparesTwoValues(count_value, value->BuildValue())) {
      ++count;
    }
  }

  auto out = NewValueNode(MakeValue(count));
  ret->set_output(out);
  return ret;
}

3.3 图模式下 dict.fromkeys 方法支持

dict.fromkeys(seq[, value=None])根据给定的可迭代对象seq和value(默认为None),创建一个新的dict并返回。在图模式中,首先判断传入seq对象是否是支持的可迭代对象(目前支持list、tuple、dict、string),不支持将抛出异常。之后遍历该可迭代对象,判断每个元素是否为string类型,不通过将抛出异常。最后根据该key与传入的value值创建新的dict对象,并返回。关键代码如下所示。

FuncGraphPtr DictFromKeys::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  constexpr size_t dict_fromkeys_args_size = 3; 
  abstract::CheckArgsSize("DictFromKeys", args_list, dict_fromkeys_args_size);
  const auto &values = ParseIterableObject(args_list[1]);
  auto value_node = args_list[2]->BuildValue();
  MS_EXCEPTION_IF_NULL(value_node);

  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("fromkeys");
  (void)ret->add_parameter();
  (void)ret->add_parameter();
  (void)ret->add_parameter();

  std::vector<std::pair<std::string, ValuePtr>> key_values;
  for (auto &value : values) {
    auto key = value->BuildValue();
    if (!key->IsSameTypeId(StringImm::kTypeId)) {
      MS_LOG(EXCEPTION) << "The key should be string, but got " << key->type_name();
    }

    std::string key_node = GetValue<std::string>(key);
    (void)key_values.emplace_back(std::make_pair(key_node, value_node));
  }

  ret->set_output(NewValueNode(std::make_shared<ValueDictionary>(key_values)));
  return ret;
}

3.4 图模式下 dict.update 方法支持

dict.update(obj) 会在原dict后追加obj dict内容。在图模式中,首先创建一个AnfNodePtrList对象用来保存value值,并创建unordered_map对象记录dict key和value的index值,之后遍历原dict和obj dict,使用kPrimDictGetItem算子取出key对应的value值,并将value插入到AnfNodePtrList对象中,并将key和value index插入到unordered_map中,如存在重复key,则取出对应的index,替换AnfNodePtrList中对应的value。遍历完成后,使用kPrimMakeTuple算子和kPrimMakeDict算子创建新dict,使用该dict对象代替原dict对象。关键代码如下所示。

FuncGraphPtr DictUpdate::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) {
  constexpr size_t dict_update_args_size = 2;
  abstract::CheckArgsSize("DictUpdate", args_list, dict_update_args_size);

  FuncGraphPtr ret = std::make_shared<FuncGraph>();
  ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  ret->debug_info()->set_name("update");

  AnfNodePtrList key_inputs;
  AnfNodePtrList value_inputs;
  (void)key_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
  (void)value_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));

  std::unordered_map<std::string, size_t> hash_map;
  AddNodeToLists(args_list[0], ret, &key_inputs, &value_inputs, &hash_map);
  AddNodeToLists(args_list[1], ret, &key_inputs, &value_inputs, &hash_map);

  ret->set_output(ret->NewCNode(
    {NewValueNode(prim::kPrimMakeDict), ret->NewCNode(std::move(key_inputs)), ret->NewCNode(std::move(value_inputs))}));
  return ret;
}

04 总结

首先感谢昇思MindSpore社区提供的这次机会,感谢在完成任务过程中导师的帮助。之前对开源的了解一直都是模糊的,通过这次机会,才让我了解到开源项目,并且加入到了开源贡献之中。在刚开始做任务的时候,读昇思MindSpore源码,不断地惊讶于它巧妙的设计,同时学习了深度框架的知识。在最后提交PR的过程中,使自己认识到代码规范和注释的细节,同时有很多专家会对代码的结构提出很多建议,在这个过程中学到了很多的知识,成长了很多。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值