src.wic.contextual_embedder#

Module Contents#

Classes#

LayerAggregator

str(object='') -> str

SubwordAggregator

str(object='') -> str

EmbeddingCache

ContextualEmbedder

Attributes#

log

T

TargetName

src.wic.contextual_embedder.log#
src.wic.contextual_embedder.T#
class src.wic.contextual_embedder.LayerAggregator#

Bases: str, enum.Enum

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

AVERAGE = 'average'#
CONCAT = 'concat'#
SUM = 'sum'#
__call__(tensor: torch.Tensor, layers: list[int]) torch.Tensor#
class src.wic.contextual_embedder.SubwordAggregator#

Bases: str, enum.Enum

str(object='') -> str str(bytes_or_buffer[, encoding[, errors]]) -> str

Create a new string object from the given object. If encoding or errors is specified, then the object must expose a data buffer that will be decoded using the given encoding and error handler. Otherwise, returns the result of object.__str__() (if defined) or repr(object). encoding defaults to sys.getdefaultencoding(). errors defaults to 'strict'.

AVERAGE = 'average'#
FIRST = 'first'#
LAST = 'last'#
SUM = 'sum'#
MAX = 'max'#
MIN = 'min'#
__call__(tensor: torch.Tensor) torch.Tensor#
src.wic.contextual_embedder.TargetName: TypeAlias#
class src.wic.contextual_embedder.EmbeddingCache(**data)#

Bases: pydantic.BaseModel

metadata: dict[Any, Any]#
_cache: dict[TargetName, dict[src.use.UseID, torch.Tensor]]#
_targets_with_new_uses: set[TargetName]#
_index: pandas.DataFrame#
_index_dir: pathlib.Path#
_index_path: pathlib.Path#
add_use(use: src.use.Use, embedding: torch.Tensor) None#

Add a use and its embedding to cache.

Parameters:
  • use (Use) – a single data item, given as a Use

  • embedding (torch.Tensor) – the embedding of the data

retrieve(use: src.use.Use) torch.Tensor | None#

If the target is not in the cache yet, create an entry for it in the cache. Then retrieve the embedding of the use by its identifier.

Parameters:

use (Use) – a single data item, given as a Use

Returns:

the embedding of the use

Return type:

torch.Tensor | None

load(target: str) dict[src.use.UseID, torch.Tensor] | None#

Load the data of the target, i.e. its identifiers and their embeddings.

Parameters:

target (str) – the target term

Returns:

a mapping from each identifier to its embedding

Return type:

dict[UseID, torch.Tensor] | None

_ids() set[src.use.UseID]#

Retrieve the identifiers.

Returns:

the set of identifiers

Return type:

set[UseID]

targets() set[TargetName]#

Get all the targets from the cache.

Returns:

the set of all targets in the cache

Return type:

set[TargetName]

clean()#

Remove duplicated rows from the index dataframe.

persist(target: str) None#

Save the embeddings of the target, then remove the target from the set of targets with new uses.

Parameters:

target (str) – the target term

has_new_uses(target: str) bool#

Check whether the target is in the set of targets with new uses.

Parameters:

target (str) – the target term

Returns:

True if the target is in the set of targets with new uses, False otherwise

Return type:

bool

class src.wic.contextual_embedder.ContextualEmbedder(**data: Any)#

Bases: src.wic.model.WICModel

property device: torch.device#
property tokenizer: transformers.PreTrainedTokenizerBase#
property model: transformers.PreTrainedModel#
truncation_tokens_before_target: float#
similarity_metric: Callable[Ellipsis, float]#
normalization: None | Callable[[torch.Tensor], torch.Tensor]#
ckpt: str#
layers: conlist(item_type=conint(ge=0), unique_items=True)#
embedding_cache: EmbeddingCache | None#
gpu: int | None#
layer_aggregator: LayerAggregator#
subword_aggregator: SubwordAggregator#
encode_only: bool#
_embeddings: dict[src.use.Use, torch.Tensor]#
_device: torch.device#
_tokenizer: transformers.PreTrainedTokenizerBase#
_model: transformers.PreTrainedModel#
__enter__()#
__exit__(exc_type, exc_val, exc_tb)#
as_df() pandas.DataFrame#
truncation_indices(target_subword_indices: list[bool]) tuple[int, int]#
predict(use_pairs: Iterable[tuple[src.use.Use, src.use.Use]]) list[float]#
tokenize(use: src.use.Use) transformers.BatchEncoding#
aggregate(tensor: torch.Tensor, layers: list[int]) torch.Tensor#
encode_all(uses: list[src.use.Use], type: Type[T] = np.ndarray) list[T]#
encode(use: src.use.Use, type: Type[T] = np.ndarray) T#