src.wic.contextual_embedder#
Module Contents#
Classes#
- LayerAggregator
- SubwordAggregator
- EmbeddingCache
- ContextualEmbedder
Attributes#
- src.wic.contextual_embedder.log#
- src.wic.contextual_embedder.T#
- class src.wic.contextual_embedder.LayerAggregator#
Bases:
str, enum.Enum
- AVERAGE = 'average'#
- CONCAT = 'concat'#
- SUM = 'sum'#
- __call__(tensor: torch.Tensor, layers: list[int]) → torch.Tensor#
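A minimal sketch of how a string-valued enum like this can double as a callable aggregator. It is an assumption about the behaviour, not the module's implementation: the name `LayerAggregatorSketch` is hypothetical, and nested Python lists (one vector per layer) stand in for `torch.Tensor` so the example runs without dependencies.

```python
from enum import Enum


class LayerAggregatorSketch(str, Enum):
    """Hypothetical stand-in for LayerAggregator; not the module's code."""

    AVERAGE = "average"
    CONCAT = "concat"
    SUM = "sum"

    def __call__(self, tensor: list[list[float]], layers: list[int]) -> list[float]:
        # Assumption: `tensor` holds one vector per layer and `layers`
        # selects which of those vectors to combine.
        selected = [tensor[i] for i in layers]
        if self is LayerAggregatorSketch.CONCAT:
            # Join the selected layer vectors end to end.
            return [x for vec in selected for x in vec]
        # Element-wise sum across the selected layers.
        summed = [sum(column) for column in zip(*selected)]
        if self is LayerAggregatorSketch.SUM:
            return summed
        # AVERAGE: element-wise mean across the selected layers.
        return [x / len(selected) for x in summed]


hidden = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 layers, 2 dimensions each
avg = LayerAggregatorSketch("average")(hidden, [0, 2])  # look up by value, then call
cat = LayerAggregatorSketch.CONCAT(hidden, [0, 1])
```

Because the enum also subclasses `str`, a member can be constructed from a plain configuration string such as `"average"` and then called directly.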
- class src.wic.contextual_embedder.SubwordAggregator#
Bases:
str, enum.Enum
- AVERAGE = 'average'#
- FIRST = 'first'#
- LAST = 'last'#
- SUM = 'sum'#
- MAX = 'max'#
- MIN = 'min'#
- __call__(tensor: torch.Tensor) → torch.Tensor#
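A sketch of what the six strategies might compute when a target word is split into several subword tokens. This is an assumption, not the module's implementation: `SubwordAggregatorSketch` is a hypothetical name, and a list of per-subword vectors stands in for the `torch.Tensor` argument.

```python
from enum import Enum


class SubwordAggregatorSketch(str, Enum):
    """Hypothetical stand-in for SubwordAggregator; not the module's code."""

    AVERAGE = "average"
    FIRST = "first"
    LAST = "last"
    SUM = "sum"
    MAX = "max"
    MIN = "min"

    def __call__(self, tensor: list[list[float]]) -> list[float]:
        # Assumption: `tensor` holds one vector per subword of the target.
        kinds = SubwordAggregatorSketch
        if self is kinds.FIRST:
            return list(tensor[0])
        if self is kinds.LAST:
            return list(tensor[-1])
        # The remaining strategies reduce each dimension across subwords.
        columns = list(zip(*tensor))
        if self is kinds.SUM:
            return [sum(c) for c in columns]
        if self is kinds.MAX:
            return [max(c) for c in columns]
        if self is kinds.MIN:
            return [min(c) for c in columns]
        return [sum(c) / len(c) for c in columns]  # AVERAGE


pieces = [[1.0, 8.0], [3.0, 4.0], [5.0, 0.0]]  # 3 subwords, 2 dimensions each
mean_vec = SubwordAggregatorSketch.AVERAGE(pieces)
max_vec = SubwordAggregatorSketch("max")(pieces)
```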
- src.wic.contextual_embedder.TargetName: TypeAlias#
- class src.wic.contextual_embedder.EmbeddingCache(**data)#
Bases:
pydantic.BaseModel
- metadata: dict[Any, Any]#
- _cache: dict[TargetName, dict[src.use.UseID, torch.Tensor]]#
- _targets_with_new_uses: set[TargetName]#
- _index: pandas.DataFrame#
- _index_dir: pathlib.Path#
- _index_path: pathlib.Path#
- add_use(use: src.use.Use, embedding: torch.Tensor) → None#
Add a use and its embedding to the cache.
- Parameters:
use (Use) – the use to add
embedding (torch.Tensor) – the embedding of the use
- retrieve(use: src.use.Use) → torch.Tensor | None#
If the target of the use is not in the cache yet, create an entry for it. Then retrieve the embedding of the use by its identifier.
- Parameters:
use (Use) – the use whose embedding is requested
- Returns:
the embedding of the use, or None if no embedding is cached for it
- Return type:
torch.Tensor | None
- load(target: str) → dict[src.use.UseID, torch.Tensor] | None#
Load the data for the target: its use identifiers and their embeddings.
- Parameters:
target (str) – the target term
- Returns:
a mapping from use identifier to embedding
- Return type:
dict[UseID, torch.Tensor] | None
- _ids() → set[src.use.UseID]#
Retrieve the use identifiers.
- Returns:
the set of use identifiers
- Return type:
set[UseID]
- targets() → set[TargetName]#
Get all targets in the cache.
- Returns:
the set of target names in the cache
- Return type:
set[TargetName]
- clean()#
Remove duplicated rows from the index dataframe.
- persist(target: str) → None#
Save the embeddings of the target and remove the target from the set of targets with new uses.
- Parameters:
target (str) – the target term
- has_new_uses(target: str) → bool#
Check whether the target is in the set of targets with new uses.
- Parameters:
target (str) – the target term
- Returns:
True if the target is in the set of targets with new uses
- Return type:
bool
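The methods above suggest a simple lifecycle: look up an embedding, add it on a miss, and persist targets that have gained new uses. The following in-memory sketch illustrates that flow under stated assumptions. `ToyUse`, `ToyEmbeddingCache`, and the `identifier` and `target` attributes are hypothetical, a plain dict stands in for on-disk storage, and lists stand in for tensors; the index dataframe is not modelled.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyUse:
    """Hypothetical stand-in for src.use.Use (attribute names are assumed)."""
    identifier: str
    target: str


class ToyEmbeddingCache:
    """In-memory sketch of the cache lifecycle; not the module's code."""

    def __init__(self):
        self._cache = {}                     # target -> {use id -> embedding}
        self._targets_with_new_uses = set()  # targets not yet persisted
        self.persisted = {}                  # stands in for files on disk

    def add_use(self, use, embedding):
        self._cache.setdefault(use.target, {})[use.identifier] = embedding
        self._targets_with_new_uses.add(use.target)

    def retrieve(self, use):
        # Create the target entry on first access, then look up by identifier.
        return self._cache.setdefault(use.target, {}).get(use.identifier)

    def targets(self):
        return set(self._cache)

    def has_new_uses(self, target):
        return target in self._targets_with_new_uses

    def persist(self, target):
        self.persisted[target] = dict(self._cache[target])
        self._targets_with_new_uses.discard(target)


cache = ToyEmbeddingCache()
use = ToyUse(identifier="u1", target="bank")
miss = cache.retrieve(use)            # nothing cached yet
cache.add_use(use, [0.1, 0.2])
hit = cache.retrieve(use)
dirty_before = cache.has_new_uses("bank")
cache.persist("bank")
dirty_after = cache.has_new_uses("bank")
```

Tracking which targets have new uses lets a caller skip writing targets whose embeddings have not changed since the last save.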
- class src.wic.contextual_embedder.ContextualEmbedder(**data: Any)#
Bases:
src.wic.model.WICModel
- property device: torch.device#
- property tokenizer: transformers.PreTrainedTokenizerBase#
- property model: transformers.PreTrainedModel#
- truncation_tokens_before_target: float#
- similarity_metric: Callable[..., float]#
- normalization: None | Callable[[torch.Tensor], torch.Tensor]#
- ckpt: str#
- layers: conlist(item_type=conint(ge=0), unique_items=True)#
- embedding_cache: EmbeddingCache | None#
- gpu: int | None#
- layer_aggregator: LayerAggregator#
- subword_aggregator: SubwordAggregator#
- encode_only: bool#
- _embeddings: dict[src.use.Use, torch.Tensor]#
- _device: torch.device#
- _tokenizer: transformers.PreTrainedTokenizerBase#
- _model: transformers.PreTrainedModel#
- __enter__()#
- __exit__(exc_type, exc_val, exc_tb)#
- as_df() → pandas.DataFrame#
- truncation_indices(target_subword_indices: list[bool]) → tuple[int, int]#
- predict(use_pairs: Iterable[tuple[src.use.Use, src.use.Use]]) → list[float]#
- tokenize(use: src.use.Use) → transformers.BatchEncoding#
- aggregate(tensor: torch.Tensor, layers: list[int]) → torch.Tensor#
- encode_all(uses: list[src.use.Use], type: Type[T] = np.ndarray) → list[T]#
- encode(use: src.use.Use, type: Type[T] = np.ndarray) → T#
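The signatures of `predict`, `encode`, and `similarity_metric` suggest that each use pair is scored by encoding both uses and comparing the two embeddings. The sketch below illustrates that reading only; the source does not confirm it. `predict_sketch`, `cosine`, and the toy vectors are hypothetical, and plain lists stand in for embeddings.

```python
import math
from typing import Callable, Iterable, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, used here as an example similarity metric."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def predict_sketch(
    use_pairs: Iterable[tuple[str, str]],
    encode: Callable[[str], Sequence[float]],
    similarity_metric: Callable[..., float] = cosine,
) -> list[float]:
    # Assumption: one score per pair, computed from the two embeddings.
    return [similarity_metric(encode(u1), encode(u2)) for u1, u2 in use_pairs]


# Toy "encoder": a lookup table instead of a transformer model.
toy_vectors = {
    "river bank": [1.0, 0.0],
    "bank shore": [2.0, 0.0],
    "bank loan": [0.0, 1.0],
}
scores = predict_sketch(
    [("river bank", "bank shore"), ("river bank", "bank loan")],
    toy_vectors.__getitem__,
)
```

In this toy setup the first pair points in the same direction and the second pair is orthogonal, so the first score is higher than the second.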