被测对象(NetworkFeaturesSingle.get_features)中有一处需要特别注意:
api_res = self.cache_api_wrapper.get_api(self.network_api_svc.network_whole_features, company_name)
与以前直接调用 api 获取网络图特征的做法不同,这里的 api 调用外面还包了一层在多个 web 进程之间负责调度和 Redis 缓存的 get_api 函数;
因此如果像以前那样直接把被包装的 api 函数替换成 mock 对象来测试,就会报错(见下文写法1的 AttributeError)!
def get_api(self, function, *params):
    """
    Schedule and cache an API call shared by multiple web processes.

    Lookup order:
      1. If Redis already has a value for the call's key, return it.
      2. Otherwise try to take the per-key call token. The process that gets
         it calls the API itself and always releases the token afterwards,
         even when the call raises.
      3. A process that does not get the token polls Redis, calling sleep(1)
         between polls, for up to 100 polls. It then stops waiting and calls
         the API itself.

    NOTE(review): the key comes from func_and_params_2_key(function, params).
    If that helper reads ``function.__name__`` (presumably it does), a bare
    MagicMock passed as ``function`` fails with AttributeError.

    :param function: the API callable to invoke
    :param params: positional arguments forwarded to ``function``
    :return: the cached value from Redis or the freshly fetched API data
    """
    key = self.__cache_api.func_and_params_2_key(function, params)
    if self.__redis_helper.exists(key):  # cache hit
        return self.__get_redis_value(key)

    if self.__token_cache_db.try_add(key):  # this process owns the call token
        try:
            return self.__cache_api.get_api(function, *params)
        finally:
            # Release the token even if the API call raises. Otherwise the
            # other processes keep waiting on it until their poll limit runs out.
            self.__token_cache_db.delete_by_token_string(key)

    # Another process holds the token: block until the key shows up in Redis.
    second_count = 0
    while True:
        second_count += 1
        if self.__redis_helper.exists(key):
            # The value is cached now (presumably written by the token holder's
            # __cache_api.get_api call).
            # NOTE(review): this deletes a token this process did not acquire.
            self.__token_cache_db.delete_by_token_string(key)
            return self.__get_redis_value(key)
        if second_count >= 100:
            # Waited too long: drop the (possibly stale) token and fetch directly.
            self.__token_cache_db.delete_by_token_string(key)
            return self.__cache_api.get_api(function, *params)
        sleep(1)
被测对象(测试中将 mock 其依赖的 NetworkApiSvc.network_whole_features):
class NetworkFeaturesSingle(object):
    """
    Collect network-graph features for a single node (company).
    """
    network_api_svc = NetworkApiSvc()
    cache_api_wrapper = CacheApiWrapper()

    # (DTO attribute, key in the network_whole_features response), listed in
    # the order the attributes are assigned.
    _DTO_FIELD_TO_API_KEY = (
        ('network_share_cancel_cnt', 'shareOrPosRevokedCnt'),
        ('cancel_cnt', 'frRevokedCnt'),
        ('fr_zhi_xing_cnt', 'frZhixingCnt'),
        ('network_share_zhixing_cnt', 'shareOrPosZhixingCnt'),
        ('network_share_judge_doc_cnt', 'shareOrPosJudgeDocCnt'),
        ('net_judgedoc_defendant_cnt', 'allLinkJudgedocCnt'),
        ('judge_doc_cnt', 'frJudgedocCnt'),
    )

    def get_features(self, company_name):
        """
        Fetch the network-graph features of ``company_name``.

        The API is not called directly. It goes through
        ``cache_api_wrapper.get_api``, so the response may come from the
        shared cache.

        :param company_name: company to look up
        :return: a NetworkFeatureDto; each counter is 0 when the response
                 lacks the corresponding key
        """
        features = self.cache_api_wrapper.get_api(
            self.network_api_svc.network_whole_features, company_name)
        dto = NetworkFeatureDto()
        for attr_name, api_key in self._DTO_FIELD_TO_API_KEY:
            setattr(dto, attr_name, features.get(api_key, 0))
        return dto
测试脚本,写法1(直接 mock,运行失败):
报错:AttributeError: __name__
from unittest import TestCase
import mock
from api.network_api_svc import NetworkApiSvc
from biz.biz_utils.cache_api import CacheApi
from biz.integration_api.common.network_features_single import NetworkFeaturesSingle
from common.helper.test_helper import TestHelper
from test.test_clean.test_common.test_utils import TestUtils
class TestNetworkFeaturesSingle(TestCase):
    """
    Variant 1 (the failing example): patch the API with a plain MagicMock.

    get_features() does not call network_whole_features itself. It hands the
    function to cache_api_wrapper.get_api(). This variant neither gives the
    mock a ``__name__`` nor clears any Redis entry for the call, and it fails
    with ``AttributeError: __name__``. That error is presumably raised when the
    cache layer builds its key from ``function.__name__`` (confirm in
    CacheApi.func_and_params_2_key). Also, if Redis already holds a value for
    the same key, get_api returns it without ever calling the mock.
    """
    # Test helpers; t_h is not used in this variant.
    t_h = TestHelper()
    t_u = TestUtils()

    def tearDown(self):
        # Reset the CacheApi.REFRESH_CACHE class flag after every test.
        CacheApi.REFRESH_CACHE = False

    # Replace NetworkApiSvc.network_whole_features with a MagicMock for the
    # duration of the test; the mock is passed in as the extra argument.
    @mock.patch.object(NetworkApiSvc, 'network_whole_features')
    def test_get_features(self, network_whole_features):
        # given
        network_whole_features.return_value = self.__get_api_value()
        n_f_w = NetworkFeaturesSingle()
        company_name = u'测试公司'
        # when
        network_dto = n_f_w.get_features(company_name)
        # then -- each DTO field must carry its distinct value from the canned response
        self.assertEqual(1, network_dto.network_share_judge_doc_cnt)
        self.assertEqual(2, network_dto.judge_doc_cnt)
        self.assertEqual(3, network_dto.network_share_cancel_cnt)
        self.assertEqual(4, network_dto.cancel_cnt)
        self.assertEqual(5, network_dto.network_share_zhixing_cnt)
        self.assertEqual(6, network_dto.net_judgedoc_defendant_cnt)
        self.assertEqual(7, network_dto.fr_zhi_xing_cnt)

    def test_get_features_no_mock(self):
        # Integration check: nothing is mocked, so this goes through the real
        # service (presumably needs its backend reachable); it only checks
        # that a DTO is returned.
        # given
        n_f_w = NetworkFeaturesSingle()
        company_name = u'小米科技有限责任公司'
        # when
        network_dto = n_f_w.get_features(company_name)
        # then
        self.t_u.print_domain(network_dto)
        assert network_dto

    @staticmethod
    def __get_api_value():
        """
        Canned network_whole_features response used by test_get_features.

        Each key maps to a distinct value (1..7) so the assertions can tell
        the DTO fields apart.
        """
        return {
            u'shareOrPosJudgeDocCnt': 1,
            u'frJudgedocCnt': 2,
            u'shareOrPosRevokedCnt': 3,
            u'frRevokedCnt': 4,
            u'shareOrPosZhixingCnt': 5,
            u'allLinkJudgedocCnt': 6,
            u'frZhixingCnt': 7
        }
测试脚本,写法2:
from unittest import TestCase
import mock
from api.network_api_svc import NetworkApiSvc
from biz.biz_utils.cache_api import CacheApi
from biz.integration_api.common.network_features_single import NetworkFeaturesSingle
from common.helper.test_helper import TestHelper
from test.test_clean.test_common.test_utils import TestUtils
class TestNetworkFeaturesSingle(TestCase):
    """
    Variant 2 (working): test NetworkFeaturesSingle through the cache wrapper.

    Before get_features() runs, the mocked API is prepared with
    TestHelper.set_api_return_value(). That call sets the mock's return
    value, deletes the matching Redis key and assigns the mock a ``__name__``.
    """
    t_h = TestHelper()
    t_u = TestUtils()

    def tearDown(self):
        # Every test ends with CacheApi.REFRESH_CACHE switched back to False.
        CacheApi.REFRESH_CACHE = False

    @mock.patch.object(NetworkApiSvc, 'network_whole_features')
    def test_get_features(self, network_whole_features):
        company_name = u'测试公司'
        # The helper finds the mock's variable name by scanning locals(). At
        # this point locals() holds only self, network_whole_features and
        # company_name. Do not rename the mock parameter: its name becomes the
        # API name used for the Redis key and the mock's __name__.
        self.t_h.set_api_return_value(
            locals(), network_whole_features, company_name, self.__get_api_value())

        dto = NetworkFeaturesSingle().get_features(company_name)

        # The canned response maps these fields to 1, 2, ..., 7 in this order.
        for expected, field_name in enumerate((
                'network_share_judge_doc_cnt',
                'judge_doc_cnt',
                'network_share_cancel_cnt',
                'cancel_cnt',
                'network_share_zhixing_cnt',
                'net_judgedoc_defendant_cnt',
                'fr_zhi_xing_cnt',
        ), 1):
            self.assertEqual(expected, getattr(dto, field_name))
**locals() 用法:**
**locals() 函数会以字典类型返回当前位置的全部局部变量。**
(注:原文此处的“对于函数、方法、lambda 函式、类以及实现了 __call__ 方法的类实例,它都返回 True”描述的是 callable() 函数,并非 locals(),系误抄。)
>>> def runoob(arg): # 两个局部变量:arg、z
... z = 1
... print (locals())
...
>>> runoob(4)
{'z': 1, 'arg': 4} # 返回一个名字/值对的字典
>>>
在写法2的 test_get_features 中调用 set_api_return_value 时,locals() 返回的是:
{'network_whole_features': <MagicMock name='network_whole_features' id='140258690673808'>, 'company_name': u'\u6d4b\u8bd5\u516c\u53f8', 'self': <test_networkFeaturesSingle.TestNetworkFeaturesSingle testMethod=test_get_features>}
(以下接写法2中 TestNetworkFeaturesSingle 类的其余方法:)
def test_get_features_no_mock(self):
# given
n_f_w = NetworkFeaturesSingle()
company_name = u'小米科技有限责任公司'
# when
network_dto = n_f_w.get_features(company_name)
# then
self.t_u.print_domain(network_dto)
assert network_dto
@staticmethod
def __get_api_value():
return {
u'shareOrPosJudgeDocCnt': 1,
u'frJudgedocCnt': 2,
u'shareOrPosRevokedCnt': 3,
u'frRevokedCnt': 4,
u'shareOrPosZhixingCnt': 5,
u'allLinkJudgedocCnt': 6,
u'frZhixingCnt': 7
}
在 TestHelper 中新增的方法(写法2通过 self.t_h 调用):
def set_api_return_value(self, local_vars, api, company_name, return_value, api_name=None):
    """
    Prepare a mocked API so it works behind the Redis-backed cache wrapper.

    Steps:
      1. make ``api`` return ``return_value``;
      2. determine the API name: ``api_name`` if given, otherwise the name of
         the local variable bound to ``api`` in ``local_vars``;
      3. delete the Redis key ``<api_name>__<company_name>`` so that a value
         cached by an earlier run cannot bypass the mock;
      4. set ``api.__name__`` (a MagicMock has no ``__name__`` unless one is
         assigned, and the cache layer presumably reads it to build its key).

    NOTE(review): the Redis key built here must match the cache layer's key
    format (CacheApi.func_and_params_2_key). It only covers APIs called with
    a single ``company_name`` argument -- confirm against that helper.

    :param local_vars: ``locals()`` of the calling test; used only when
                       ``api_name`` is not given
    :param api: the mock standing in for the API function
    :param company_name: company name the test will query
    :param return_value: canned data the mocked API should return
    :param api_name: optional explicit API name; skips the ``locals()`` lookup
    :raises ValueError: if ``api_name`` is not given and ``api`` is not bound
                        to any variable in ``local_vars``
    """
    api.return_value = return_value
    if api_name is None:
        api_name = self.__get_variable_name(local_vars, api)
    if api_name is None:
        # Without this check the key concatenation below would fail with an
        # unhelpful TypeError (None + unicode).
        raise ValueError('cannot determine the API name: pass api_name or a '
                         'locals() dict that contains the mock')
    redis_helper = RedisHelper()
    redis_helper.delete(api_name + u'__' + company_name)
    api.__name__ = api_name
@staticmethod
def __get_variable_name(local_vars, x):
    """
    Return the name under which ``x`` is bound in ``local_vars``.

    Matching is by identity (``is``), not equality. If several names are bound
    to the same object, the first one in the mapping's iteration order wins.
    If no name is bound to it, None is returned.

    :param local_vars: mapping of names to values, typically from ``locals()``
    :param x: the object whose variable name is wanted
    """
    return next((name for name, value in local_vars.items() if value is x), None)