box_embeddings.parameterizations.box_tensor
Base class for creating a wrapper around torch.Tensor to represent boxes.
A BoxTensor contains a single tensor which represents one or more boxes.
- note:
We have to use composition instead of inheritance because it is currently not safe to inherit from torch.Tensor: creating an instance of such a class always makes it a leaf node. This works for torch.nn.Parameter but won’t work for a general BoxTensor. This will most likely change in the future as PyTorch starts officially supporting inheritance from Tensor; give this point some thought when that happens.
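A minimal sketch of the composition approach, using a hypothetical WrappedBox class that is not part of box_embeddings: the tensor is stored as an ordinary attribute, so it keeps its autograd history and gradients reach the parameter it was computed from.

```python
import torch


class WrappedBox:
    """Hypothetical minimal wrapper: holds a tensor instead of subclassing torch.Tensor."""

    def __init__(self, data: torch.Tensor) -> None:
        # Stored as a plain attribute, so the tensor stays a node in the autograd graph.
        self.data = data


# A trainable parameter that produces the box data through a computation.
weights = torch.nn.Parameter(torch.randn(3, 2, 4))
box = WrappedBox(weights * 2.0)

print(box.data.is_leaf)  # False: the wrapped tensor is not a leaf
box.data.sum().backward()  # gradients flow back through the wrapper to `weights`
print(weights.grad.shape)
```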
Module Contents
- logger
- TBoxTensor
- class BoxTensor(data: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], *args: Any, **kwargs: Any)
Bases: object
Base class defining the interface for BoxTensor.
- w2z_ratio: int = 2
- reinit(self, data: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]) -> None
Re-initializes the box with new data.
- Parameters
data – Tensor of shape (…, zZ, num_dims). Here, zZ=2, where index 0 is the bottom left corner and index 1 is the top right corner of the box.
- Returns
None
- Raises
ValueError – If the new shape is different from the old shape.
- property kwargs(self) -> Dict
Configuration attribute values as a Dict
- Returns
Dict
- property args(self) -> Tuple
Configuration attribute values as a Tuple
- Returns
Tuple
- property z(self) -> torch.Tensor
Lower left coordinate as a Tensor
- Returns
lower left corner
- Return type
Tensor
- property Z(self) -> torch.Tensor
Top right coordinate as a Tensor
- Returns
top right corner
- Return type
Tensor
- property centre(self) -> torch.Tensor
Centre coordinate as a Tensor
- Returns
centre
- Return type
Tensor
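As a sketch of how z, Z and centre relate to the (…, 2, num_dims) layout described under reinit, the snippet below indexes a raw tensor directly. It assumes the base-class convention that index 0 along dim -2 is z, index 1 is Z, and that the centre is their midpoint; child parameterizations may compute these differently.

```python
import torch

# Hypothetical data for 3 boxes in 2 dimensions, laid out as (..., 2, num_dims).
data = torch.tensor([
    [[0.0, 0.0], [2.0, 4.0]],
    [[1.0, 1.0], [3.0, 2.0]],
    [[-1.0, 0.0], [1.0, 6.0]],
])

z = data[..., 0, :]   # lower left corners, shape (3, 2)
Z = data[..., 1, :]   # top right corners, shape (3, 2)
centre = (z + Z) / 2  # midpoint of each box, shape (3, 2)
print(centre)
```

Note that z, Z and centre all share the shape (3, 2), which is what box_shape reports, while the data itself has shape (3, 2, 2).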
- classmethod check_if_valid_zZ(cls: Type[TBoxTensor], z: torch.Tensor, Z: torch.Tensor) -> None
Checks whether (z, Z) form a valid box.
If your child class's parameterization bounds the boxes to some universe box, this is the right place to check that.
- Parameters
z – Lower left coordinate of shape (…, hidden_dims)
Z – Top right coordinate of shape (…, hidden_dims)
- Raises
ValueError – If z and Z do not have the same shape
ValueError – If Z < z
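The two documented checks can be sketched as a standalone function. This is not the library's exact code; in particular, treating "Z < z" as "any coordinate of Z is below the matching coordinate of z" is an assumption. The shape check runs first so that mismatched shapes are not silently broadcast by the comparison.

```python
import torch


def check_if_valid_zZ(z: torch.Tensor, Z: torch.Tensor) -> None:
    """Sketch of the documented validity checks for a box (z, Z)."""
    if z.shape != Z.shape:
        raise ValueError(
            f"z and Z must have the same shape, got {tuple(z.shape)} and {tuple(Z.shape)}"
        )
    if (Z < z).any():
        # A child class bounding boxes to a universe box would add its own check here.
        raise ValueError("Z must be >= z in every coordinate")


check_if_valid_zZ(torch.zeros(2, 3), torch.ones(2, 3))  # valid box: no error
```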
- classmethod W(cls: Type[TBoxTensor], z: torch.Tensor, Z: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor
Given (z, Z), returns one set of valid box weights W such that Box(W) = (z, Z).
For the base BoxTensor class, we simply return z and Z stacked together. If you implement a new parameterization for boxes, you will most likely need to override this method.
- Parameters
z – Lower left coordinate of shape (…, hidden_dims)
Z – Top right coordinate of shape (…, hidden_dims)
*args – extra arguments for child class
**kwargs – extra arguments for child class
- Returns
Parameters of the box. In the base class implementation, this will have shape (…, 2, hidden_dims).
- Return type
Tensor
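To illustrate why a new parameterization usually needs its own W, the sketch below compares the documented base behaviour (stacking z and Z along dim -2) with a hypothetical child parameterization that stores (z, Z - z). Both functions are illustrative stand-ins, not box_embeddings APIs.

```python
import torch


def base_W(z: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
    # Documented base-class behaviour: stack z and Z along dim -2 -> (..., 2, hidden_dims).
    return torch.stack((z, Z), dim=-2)


def min_delta_W(z: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
    # Hypothetical child parameterization storing (z, Z - z). Its W must invert
    # whatever that child's z/Z properties compute from the stored weights.
    return torch.stack((z, Z - z), dim=-2)


z = torch.tensor([[0.0, 1.0]])
Z = torch.tensor([[2.0, 3.0]])
w_base = base_W(z, Z)       # shape (1, 2, 2)
w_delta = min_delta_W(z, Z)
# For the hypothetical parameterization, Z is recovered as weights[0] + weights[1].
recovered_Z = w_delta[..., 0, :] + w_delta[..., 1, :]
```

Reusing base_W for the hypothetical child would give weights that decode to the wrong box, which is the reason the override is needed.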
- classmethod zZ_to_embedding(cls, z: torch.Tensor, Z: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor
Converts (z, Z) into a single embedding tensor by collapsing the last two dimensions of the box parameters into one.
- Parameters
z – Lower left coordinate of shape (…, hidden_dims)
Z – Top right coordinate of shape (…, hidden_dims)
*args – extra arguments for child class
**kwargs – extra arguments for child class
- Returns
A tensor in which the last two dimensions (z, Z) are collapsed into one
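Assuming the base-class parameters are z and Z stacked along dim -2, collapsing the last two dimensions amounts to a flatten, as in this sketch (plain torch calls, not the library method itself):

```python
import torch

z = torch.zeros(4, 3)  # lower left corners of 4 boxes in 3 dimensions
Z = torch.ones(4, 3)   # top right corners

W = torch.stack((z, Z), dim=-2)     # base-class style parameters, shape (4, 2, 3)
embedding = W.flatten(start_dim=-2)  # shape (4, 6): z's coordinates followed by Z's
```

With this layout, each row of the embedding holds the lower left coordinates in its first half and the top right coordinates in its second half.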
- classmethod from_zZ(cls: Type[TBoxTensor], z: torch.Tensor, Z: torch.Tensor, *args: Any, **kwargs: Any) -> TBoxTensor
Creates a box for the given min-max coordinates (z, Z).
In this base implementation, we do this by stacking z and Z along dim -2 to form W.
- Parameters
z – lower left
Z – top right
*args – extra arguments for child class
**kwargs – extra arguments for child class
- Returns
A BoxTensor
- like_this_from_zZ(self, z: torch.Tensor, Z: torch.Tensor) -> BoxTensor
Creates a box for the given min-max coordinates (z, Z). This is similar to the class method from_zZ, but it uses the attributes on self instead of external args and kwargs.
For the base class, since we do not have extra attributes, we simply call from_zZ.
- Parameters
z – lower left
Z – top right
- Returns
A BoxTensor
- classmethod from_vector(cls, vector: torch.Tensor, *args: Any, **kwargs: Any) -> TBoxTensor
Creates a box from a vector. In this base implementation, the vector is split into two halves along its last dimension, and these are used as z and Z.
- Parameters
vector – Tensor whose last dimension has even length
*args – extra arguments for child class
**kwargs – extra arguments for child class
- Returns
A BoxTensor
- Raises
ValueError – If the last dimension is not even
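The base-class split can be sketched with torch's split along the last dimension. The helper name split_vector is hypothetical; it assumes, as the description says, that the first half becomes z and the second half becomes Z, and it raises ValueError for an odd length as from_vector does.

```python
import torch
from typing import Tuple


def split_vector(vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper mirroring the documented base-class from_vector split."""
    n = vector.shape[-1]
    if n % 2 != 0:
        raise ValueError(f"last dimension must be even, got {n}")
    z, Z = vector.split(n // 2, dim=-1)  # first half -> z, second half -> Z
    return z, Z


z, Z = split_vector(torch.tensor([[0.0, 1.0, 2.0, 3.0]]))
```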
- property box_shape(self) -> Tuple
Shape of z, Z and centre.
- Returns
Shape of z, Z and centre.
Note
This is not the shape of the data attribute.
- broadcast(self, target_shape: Tuple) -> None
Broadcasts the internal data member in-place such that z and Z return tensors that can be automatically broadcasted to perform arithmetic operations with shape target_shape.
- Ex:
target_shape = (4,5,10)
self.box_shape = (10,) => (1,1,10)
self.box_shape = (3,) => ValueError
self.box_shape = (4,10) => (4,1,10)
self.box_shape = (4,2,10) => ValueError
self.box_shape = (5,10) => (1,5,10)
Note
This operation will not make self.z, self.Z and self.centre return tensors of shape target_shape, but it will make them return tensors that are arithmetically compatible (broadcastable) with target_shape.
- Parameters
target_shape – Shape of the broadcast target. This will usually be the shape of the tensor you wish to use with z and Z. For instance, if you wish to add this box's centre [shape=(batch, hidden_dim)] to another box whose centre has shape (batch, extra_dim, hidden_dim), this function will reshape the data so that the resulting centre has shape (batch, 1, hidden_dim).
- Raises
ValueError – If target_shape is not compatible with the box shape
- Todo:
Add an extra argument repeat, which tells the function to repeat values until the target shape is satisfied. This is needed for gumbel_intersection, where the broadcasted tensors need to be stacked.
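The shape rule behind the examples above can be sketched in pure Python. The function broadcast_box_shape is hypothetical and is not necessarily the library's algorithm. It assumes that the last (space) dimension must match, that a box with as many dimensions as the target must already be broadcast-compatible, and that a box with fewer dimensions has each leading dimension matched greedily, left to right, to the first target position of equal size, with 1 inserted everywhere else. The resulting shape could then be applied to the underlying tensor with a reshape.

```python
from typing import Tuple


def broadcast_box_shape(box_shape: Tuple[int, ...], target_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Return a shape for z/Z that broadcasts against target_shape (hypothetical sketch)."""
    if not box_shape or box_shape[-1] != target_shape[-1]:
        raise ValueError(f"space dimensions differ: {box_shape} vs {target_shape}")
    if len(box_shape) > len(target_shape):
        raise ValueError(f"box shape {box_shape} has more dimensions than {target_shape}")
    if len(box_shape) == len(target_shape):
        # Same rank: every leading dimension must be equal or 1 on either side.
        for s, t in zip(box_shape[:-1], target_shape[:-1]):
            if s != t and s != 1 and t != 1:
                raise ValueError(f"cannot broadcast {box_shape} to {target_shape}")
        return tuple(box_shape)
    # Fewer dimensions: place each leading dim at the first matching target position.
    result = [1] * (len(target_shape) - 1)
    pos = 0
    for s in box_shape[:-1]:
        while pos < len(result) and target_shape[pos] != s and s != 1:
            pos += 1
        if pos == len(result):
            raise ValueError(f"cannot broadcast {box_shape} to {target_shape}")
        result[pos] = s
        pos += 1
    return tuple(result) + (box_shape[-1],)


target = (4, 5, 10)
print(broadcast_box_shape((4, 10), target))  # (4, 1, 10), as in the examples above
```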
- box_reshape(self, target_shape: Tuple) -> BoxTensor
Reshapes z, Z and centre.
- Ex:
1. self.box_shape = (5,10), target_shape = (-1,10), creates box_shape (5,10)
2. self.box_shape = (5,4,10), target_shape = (-1,10), creates box_shape (20,10)
3. self.box_shape = (20,10), target_shape = (10,2,10), creates box_shape (10,2,10)
4. self.box_shape = (5,), target_shape = (-1,10), raises RuntimeError
5. self.box_shape = (5,10), target_shape = (2,10), raises RuntimeError
- Parameters
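target_shape – Desired shape of z, Z and centre. Its last (space) dimension must match that of the current box; as with torch.reshape, -1 may be used for one of the extra dimensions.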
- Returns
TBoxTensor
- Raises
RuntimeError – If the space dimensions, i.e., the last dimensions, do not match.
RuntimeError – If the extra dimensions cannot be reshaped and torch.reshape raises.
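The shape arithmetic behind the box_reshape examples can be sketched in pure Python. The helper box_reshape_shape is hypothetical; it assumes that the space dimension must match exactly and that the extra dimensions follow torch.reshape rules (at most one -1, same number of elements), raising RuntimeError in both failure cases as documented.

```python
from math import prod
from typing import Tuple


def box_reshape_shape(box_shape: Tuple[int, ...], target_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Compute the new box_shape for a reshape of z/Z (hypothetical sketch)."""
    if box_shape[-1] != target_shape[-1]:
        raise RuntimeError(f"space dimensions differ: {box_shape} vs {target_shape}")
    n = prod(box_shape[:-1])  # number of boxes (product of the extra dimensions)
    lead = list(target_shape[:-1])
    if lead.count(-1) > 1:
        raise RuntimeError("only one dimension can be inferred")
    if -1 in lead:
        known = prod(d for d in lead if d != -1)
        if known == 0 or n % known != 0:
            raise RuntimeError(f"cannot reshape {box_shape} to {target_shape}")
        lead[lead.index(-1)] = n // known
    if prod(lead) != n:
        raise RuntimeError(f"cannot reshape {box_shape} to {target_shape}")
    return tuple(lead) + (box_shape[-1],)


print(box_reshape_shape((5, 4, 10), (-1, 10)))  # (20, 10), as in example 2
```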
- class BoxFactory(name: str, kwargs_dict: Dict = None)
Bases: box_embeddings.common.registrable.Registrable
A factory class which will be subclassed (one for each box type).
- box_registry: Dict[str, Tuple[Type[BoxTensor], Optional[str]]]
- classmethod register_box_class(cls, name: str, constructor: str = None, exist_ok: bool = False) -> Callable[[Type[BoxTensor]], Type[BoxTensor]]
This differs from AllenNLP's Registrable because the classes it registers are not subclasses of this class but subclasses of BoxTensor.
- Parameters
name – TODO
constructor – TODO
exist_ok – TODO
- Returns
A decorator that registers the decorated BoxTensor subclass and returns it.
- Raises
RuntimeError –
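The decorator-based registration can be sketched with a hypothetical MiniBoxFactory whose box_registry has the documented layout, mapping a name to (class, optional constructor name). Because name, constructor and exist_ok are undocumented, two behaviours here are assumptions: constructor is stored as-is next to the class, and re-registering a name raises RuntimeError unless exist_ok=True. A placeholder Box class stands in for BoxTensor so the sketch runs without torch.

```python
from typing import Callable, Dict, Optional, Tuple, Type


class Box:
    """Placeholder standing in for BoxTensor in this sketch."""


class MiniBoxFactory:
    """Hypothetical registry mirroring the documented box_registry type."""

    box_registry: Dict[str, Tuple[Type[Box], Optional[str]]] = {}

    @classmethod
    def register_box_class(
        cls, name: str, constructor: Optional[str] = None, exist_ok: bool = False
    ) -> Callable[[Type[Box]], Type[Box]]:
        def add_to_registry(subclass: Type[Box]) -> Type[Box]:
            if name in cls.box_registry and not exist_ok:
                raise RuntimeError(f"{name!r} is already registered")
            cls.box_registry[name] = (subclass, constructor)
            return subclass  # returned unchanged, so the decorated class is usable as-is

        return add_to_registry


@MiniBoxFactory.register_box_class("mindelta", constructor="from_zZ")
class MinDeltaBox(Box):
    pass


print(MiniBoxFactory.box_registry["mindelta"])
```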