:py:mod:`src.wic.contextual_embedder`
=====================================

.. py:module:: src.wic.contextual_embedder


Module Contents
---------------

Classes
~~~~~~~

.. autoapisummary::

   src.wic.contextual_embedder.LayerAggregator
   src.wic.contextual_embedder.SubwordAggregator
   src.wic.contextual_embedder.EmbeddingCache
   src.wic.contextual_embedder.ContextualEmbedder




Attributes
~~~~~~~~~~

.. autoapisummary::

   src.wic.contextual_embedder.log
   src.wic.contextual_embedder.T
   src.wic.contextual_embedder.TargetName


.. py:data:: log

   

.. py:data:: T

   

.. py:class:: LayerAggregator


   Bases: :py:obj:`str`, :py:obj:`enum.Enum`

   String enum naming the strategy used to aggregate a tensor over the
   selected layers (``'average'``, ``'concat'`` or ``'sum'``). Calling a
   member with a *tensor* and a list of *layers* returns the aggregated
   tensor.

   .. py:attribute:: AVERAGE
      :value: 'average'

      

   .. py:attribute:: CONCAT
      :value: 'concat'

      

   .. py:attribute:: SUM
      :value: 'sum'

      

   .. py:method:: __call__(tensor: torch.Tensor, layers: list[int]) -> torch.Tensor



.. py:class:: SubwordAggregator


   Bases: :py:obj:`str`, :py:obj:`enum.Enum`

   String enum naming the strategy used to combine subword representations
   (``'average'``, ``'first'``, ``'last'``, ``'sum'``, ``'max'`` or
   ``'min'``). Calling a member with a *tensor* returns the aggregated
   tensor.

   .. py:attribute:: AVERAGE
      :value: 'average'

      

   .. py:attribute:: FIRST
      :value: 'first'

      

   .. py:attribute:: LAST
      :value: 'last'

      

   .. py:attribute:: SUM
      :value: 'sum'

      

   .. py:attribute:: MAX
      :value: 'max'

      

   .. py:attribute:: MIN
      :value: 'min'

      

   .. py:method:: __call__(tensor: torch.Tensor) -> torch.Tensor



.. py:data:: TargetName
   :type: TypeAlias

   

.. py:class:: EmbeddingCache(**data)


   Bases: :py:obj:`pydantic.BaseModel`

   .. py:attribute:: metadata
      :type: dict[Any, Any]

      

   .. py:attribute:: _cache
      :type: dict[TargetName, dict[src.use.UseID, torch.Tensor]]

      

   .. py:attribute:: _targets_with_new_uses
      :type: set[TargetName]

      

   .. py:attribute:: _index
      :type: pandas.DataFrame

      

   .. py:attribute:: _index_dir
      :type: pathlib.Path

      

   .. py:attribute:: _index_path
      :type: pathlib.Path

      

   .. py:method:: add_use(use: src.use.Use, embedding: torch.Tensor) -> None

      Add a use and its embedding to the cache.

      :param use: the use to cache
      :type use: Use
      :param embedding: the embedding of the use
      :type embedding: torch.Tensor


   .. py:method:: retrieve(use: src.use.Use) -> torch.Tensor | None

      Retrieve the cached embedding of a use, looked up by the use's
      identifier. If the use's target is not in the cache yet, an entry for
      it is created first.

      :param use: the use whose embedding is requested
      :type use: Use
      :return: the embedding of the use, or ``None``
      :rtype: torch.Tensor | None


   .. py:method:: load(target: str) -> dict[src.use.UseID, torch.Tensor] | None

      Load the data of the target: its use identifiers and their embeddings.

      :param target: the target term
      :type target: str
      :return: a mapping from use identifier to embedding, or ``None``
      :rtype: dict[UseID, torch.Tensor] | None


   .. py:method:: _ids() -> set[src.use.UseID]

      Retrieve the use identifiers.

      :return: the set of use identifiers
      :rtype: set[UseID]


   .. py:method:: targets() -> set[TargetName]

      Get all targets in the cache.

      :return: the names of the cached targets
      :rtype: set[TargetName]


   .. py:method:: clean()

      Remove duplicated rows from the index dataframe.


   .. py:method:: persist(target: str) -> None

      Save the embeddings of the target and remove the target from the set
      of targets with new uses.

      :param target: the target term
      :type target: str


   .. py:method:: has_new_uses(target: str) -> bool

      Check whether the target is in the set of targets with new uses.

      :param target: the target term
      :type target: str
      :return: ``True`` if the target is in the set of targets with new uses
      :rtype: bool



.. py:class:: ContextualEmbedder(**data: Any)


   Bases: :py:obj:`src.wic.model.WICModel`

   .. py:property:: device
      :type: torch.device


   .. py:property:: tokenizer
      :type: transformers.PreTrainedTokenizerBase


   .. py:property:: model
      :type: transformers.PreTrainedModel


   .. py:attribute:: truncation_tokens_before_target
      :type: float

      

   .. py:attribute:: similarity_metric
      :type: Callable[..., float]

      

   .. py:attribute:: normalization
      :type: None | Callable[[torch.Tensor], torch.Tensor]

      

   .. py:attribute:: ckpt
      :type: str

      

   .. py:attribute:: layers
      :type: conlist(item_type=conint(ge=0), unique_items=True)

      

   .. py:attribute:: embedding_cache
      :type: EmbeddingCache | None

      

   .. py:attribute:: gpu
      :type: int | None

      

   .. py:attribute:: layer_aggregator
      :type: LayerAggregator

      

   .. py:attribute:: subword_aggregator
      :type: SubwordAggregator

      

   .. py:attribute:: encode_only
      :type: bool

      

   .. py:attribute:: _embeddings
      :type: dict[src.use.Use, torch.Tensor]

      

   .. py:attribute:: _device
      :type: torch.device

      

   .. py:attribute:: _tokenizer
      :type: transformers.PreTrainedTokenizerBase

      

   .. py:attribute:: _model
      :type: transformers.PreTrainedModel

      

   .. py:method:: __enter__()


   .. py:method:: __exit__(exc_type, exc_val, exc_tb)


   .. py:method:: as_df() -> pandas.DataFrame


   .. py:method:: truncation_indices(target_subword_indices: list[bool]) -> tuple[int, int]


   .. py:method:: predict(use_pairs: Iterable[tuple[src.use.Use, src.use.Use]]) -> list[float]


   .. py:method:: tokenize(use: src.use.Use) -> transformers.BatchEncoding


   .. py:method:: aggregate(tensor: torch.Tensor, layers: list[int]) -> torch.Tensor


   .. py:method:: encode_all(uses: list[src.use.Use], type: Type[T] = np.ndarray) -> list[T]


   .. py:method:: encode(use: src.use.Use, type: Type[T] = np.ndarray) -> T



