chain.from_iterable
第一次见这个函数是在看torch.optim.Optimizer这个类的load_state_dict()
方法
源码
# Update the state
id_map = {old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)),
chain.from_iterable((g['params'] for g in groups)))}
主要就是把加载进来的state_dict中的参数用来更新optimizer的state_dict。
再来看看chain.form_iterable()函数的功能,它属于终止迭代器类别(terminating iterators
)
把输入的可迭代对象的每个单个iterable作为参数进行迭代,具体
示例:
1
from itertools import chain
from_iterable = chain.from_iterable(['balabala','for', 'galagala'])
print(list(from_iterable))
#['b', 'a', 'l', 'a', 'b', 'a', 'l', 'a', 'f', 'o', 'r', 'g', 'a', 'l', 'a', 'g', 'a', 'l', 'a']
2
from itertools import chain
from_iterable = chain.from_iterable(['balabala','for', 'galagala',[(1,2),(2,3)]])
print(list(from_iterable))
#['b', 'a', 'l', 'a', 'b', 'a', 'l', 'a', 'f', 'o', 'r', 'g', 'a', 'l', 'a', 'g', 'a', 'l', 'a', (1, 2), (2, 3)]