# Source: /tensorflow_object_detection_api/models_installed/research/slim/deployment/model_deploy.py
class DeploymentConfig(object):
    """Configuration for deploying a model with `deploy()`.

    You can pass an instance of this class to `deploy()` to specify exactly
    how to deploy the model to build. If you do not pass one, an instance built
    from the default deployment_hparams will be used.
    """

    def __init__(self,
                 num_clones=1,
                 clone_on_cpu=False,
                 replica_id=0,
                 num_replicas=1,
                 num_ps_tasks=0,
                 worker_job_name='worker',
                 ps_job_name='ps'):
        """Create a DeploymentConfig.

        The config describes how to deploy a model across multiple clones and
        replicas. The model will be replicated `num_clones` times in each
        replica. If `clone_on_cpu` is True, each clone will be placed on CPU.

        The `worker_device` and `ps_device` properties are built from
        `worker_job_name` and `ps_job_name` only when `num_ps_tasks` > 0;
        otherwise they are empty strings and the model runs in a single
        process. If `num_replicas` is greater than 1, `num_ps_tasks` must be
        positive.

        Args:
          num_clones: Number of model clones to deploy in each replica.
          clone_on_cpu: If True clones would be placed on CPU.
          replica_id: Integer. Index of the replica for which the model is
            deployed. Usually 0 for the chief replica.
          num_replicas: Number of replicas to use.
          num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
          worker_job_name: A name for the worker job.
          ps_job_name: A name for the parameter server job.

        Raises:
          ValueError: If the arguments are invalid.
        """
        if num_replicas > 1:
            if num_ps_tasks < 1:
                raise ValueError(
                    'When using replicas num_ps_tasks must be positive')
        if num_replicas > 1 or num_ps_tasks > 0:
            # Job names are required to build the '/job:<name>' prefixes.
            if not worker_job_name:
                raise ValueError(
                    'Must specify worker_job_name when using replicas')
            if not ps_job_name:
                raise ValueError(
                    'Must specify ps_job_name when using parameter server')
        if replica_id >= num_replicas:
            raise ValueError('replica_id must be less than num_replicas')
        self._num_clones = num_clones
        self._clone_on_cpu = clone_on_cpu
        self._replica_id = replica_id
        self._num_replicas = num_replicas
        self._num_ps_tasks = num_ps_tasks
        # Job prefixes are only meaningful when parameter servers are in use.
        self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else ''
        self._worker_device = (
            '/job:' + worker_job_name if num_ps_tasks > 0 else '')

    @property
    def num_clones(self):
        return self._num_clones

    @property
    def clone_on_cpu(self):
        return self._clone_on_cpu

    @property
    def replica_id(self):
        return self._replica_id

    @property
    def num_replicas(self):
        return self._num_replicas

    @property
    def num_ps_tasks(self):
        return self._num_ps_tasks

    @property
    def ps_device(self):
        return self._ps_device

    @property
    def worker_device(self):
        return self._worker_device

    def caching_device(self):
        """Returns the caching device to use for variables.

        Returns:
          When `num_ps_tasks` > 0, a callable that maps an op to that op's own
          `device` attribute (so cached copies live wherever the reading op
          runs). Otherwise None, meaning variables are not cached.
        """
        if self._num_ps_tasks > 0:
            return lambda op: op.device
        else:
            return None

    def _check_clone_index(self, clone_index):
        """Raise ValueError unless 0 <= clone_index < num_clones."""
        # A negative index would otherwise yield e.g. '/device:GPU:-1'.
        if clone_index < 0:
            raise ValueError('clone_index must be non-negative')
        if clone_index >= self._num_clones:
            raise ValueError('clone_index must be less than num_clones')

    def clone_device(self, clone_index):
        """Device used to create the clone and all the ops inside the clone.

        Args:
          clone_index: Int, representing the clone_index.

        Returns:
          A value suitable for `tf.device()`.

        Raises:
          ValueError: if `clone_index` is negative or greater than or equal to
            the number of clones.
        """
        self._check_clone_index(clone_index)
        device = ''
        if self._num_ps_tasks > 0:
            device += self._worker_device
        if self._clone_on_cpu:
            device += '/device:CPU:0'
        else:
            # One GPU per clone, indexed by clone position.
            device += '/device:GPU:%d' % clone_index
        return device

    def clone_scope(self, clone_index):
        """Name scope to create the clone.

        Args:
          clone_index: Int, representing the clone_index.

        Returns:
          A name_scope suitable for `tf.name_scope()`; empty when there is only
          a single clone.

        Raises:
          ValueError: if `clone_index` is negative or greater than or equal to
            the number of clones.
        """
        self._check_clone_index(clone_index)
        scope = ''
        if self._num_clones > 1:
            scope = 'clone_%d' % clone_index
        return scope

    def optimizer_device(self):
        """Device to use with the optimizer.

        Returns:
          A value suitable for `tf.device()`.
        """
        # NOTE(review): `num_clones > 0` holds for any usable config, so this
        # effectively always returns the worker CPU; behavior kept as-is.
        if self._num_ps_tasks > 0 or self._num_clones > 0:
            return self._worker_device + '/device:CPU:0'
        else:
            return ''

    def inputs_device(self):
        """Device to use to build the inputs.

        Returns:
          A value suitable for `tf.device()`.
        """
        device = ''
        if self._num_ps_tasks > 0:
            device += self._worker_device
        device += '/device:CPU:0'
        return device

    def variables_device(self):
        """Returns the device to use for variables created inside the clone.

        Returns:
          A value suitable for `tf.device()`: the string '/device:CPU:0' when
          `num_ps_tasks` is 0, otherwise a callable that assigns each new
          variable op to a ps task in round-robin order.
        """
        device = ''
        if self._num_ps_tasks > 0:
            device += self._ps_device
        device += '/device:CPU:0'

        if not self._num_ps_tasks:
            return device

        class _PSDeviceChooser(object):
            """Slim device chooser for variables when using PS."""

            def __init__(self, device, tasks):
                self._device = device
                self._tasks = tasks
                self._task = 0  # Next ps task to hand out.

            def choose(self, op):
                # Respect any device already set explicitly on the op.
                if op.device:
                    return op.device
                # Relies on a module-level `tf` import (not visible here).
                node_def = op if isinstance(op, tf.NodeDef) else op.node_def
                if node_def.op.startswith('Variable'):
                    t = self._task
                    self._task = (self._task + 1) % self._tasks
                    d = '%s/task:%d' % (self._device, t)
                    return d
                else:
                    return op.device

        chooser = _PSDeviceChooser(device, self._num_ps_tasks)
        return chooser.choose