# Source code for ffp.norms

"""
Norms module.
"""

import struct
from os import PathLike
from typing import BinaryIO, Union

import numpy as np

from ffp.io import Chunk, find_chunk, ChunkIdentifier, TypeId, _pad_float32, _read_binary, \
    FinalfusionFormatError, _write_binary


class Norms(np.ndarray, Chunk):
    """
    Embedding Norms.

    Norms subclass `numpy.ndarray`, all typical numpy operations are available.

    The ith norm is expected to correspond to the l2 norm of the ith row in the
    storage before normalizing it. Therefore, Norms should have at most the
    same length as a given Storage and are expected to match the length of the
    Vocabulary.
    """

    def __new__(cls, array: np.ndarray):
        """
        Wrap an existing array as Norms.

        The result is a view on ``array``; no data is copied.

        Parameters
        ----------
        array : np.ndarray
            One-dimensional array with dtype float32.

        Raises
        ------
        TypeError
            If ``array`` is not one-dimensional or its dtype is not float32.
        """
        if array.dtype != np.float32 or array.ndim != 1:
            raise TypeError("expected 1-d float32 array")
        return array.view(cls)

    @staticmethod
    def chunk_identifier():
        """
        Get the identifier of the Norms chunk.

        Returns
        -------
        chunk_identifier : ChunkIdentifier
            Always ``ChunkIdentifier.NdNorms``.
        """
        return ChunkIdentifier.NdNorms

    @staticmethod
    def read_chunk(file: BinaryIO) -> 'Norms':
        """
        Read Norms from a binary file.

        The file is expected to be positioned at the number of norms (u64),
        followed by the type id (u32), optional padding and the norms
        themselves.

        Parameters
        ----------
        file : BinaryIO
            File positioned right before the number of norms.

        Returns
        -------
        norms : Norms
            The norms read from the file.

        Raises
        ------
        FinalfusionFormatError
            If the type id is not f32 or if the file holds fewer norms than
            its header announces.
        """
        n_norms, type_id = _read_binary(file, "<QI")
        if int(TypeId.f32) != type_id:
            raise FinalfusionFormatError(
                f"Invalid Type, expected {TypeId.f32}, got {type_id}")
        # Skip the padding in front of the data (whence=1: relative to the
        # current position). Presumably _pad_float32 returns the number of
        # bytes needed to reach float32 alignment -- confirm in ffp.io.
        padding = _pad_float32(file.tell())
        file.seek(padding, 1)
        # The header is little-endian, so the payload is read little-endian as
        # well instead of relying on the byte order of the host.
        array = np.fromfile(file=file, count=n_norms, dtype="<f4")
        # Fewer values than announced means the chunk is truncated; fail
        # loudly rather than returning a Norms array that is too short.
        if array.size != n_norms:
            raise FinalfusionFormatError(
                f"Expected {n_norms} norms, but only {array.size} could be read")
        # No copy is made when "<f4" already is the native float32.
        return Norms(array.astype(np.float32, copy=False))

    def write_chunk(self, file: BinaryIO):
        """
        Write the Norms as a chunk to a binary file.

        Written in order: chunk identifier (u32), chunk length in bytes (u64),
        number of norms (u64), type id (u32), padding and the norms as
        little-endian float32.

        Parameters
        ----------
        file : BinaryIO
            File opened for binary writing.
        """
        _write_binary(file, "<I", int(self.chunk_identifier()))
        # The padding is computed before the 20 header bytes (u64, u64, u32)
        # are written; 20 is a multiple of the float32 size, so the alignment
        # of the data is the same as if it were computed after the header.
        padding = _pad_float32(file.tell())
        # The chunk length covers everything after the length field itself:
        # number of norms, type id, padding and the data. Sizes are taken with
        # the same "<" layout that is used for writing.
        chunk_len = struct.calcsize("<QI") + padding \
            + self.size * struct.calcsize("<f")
        _write_binary(file, f"<QQI{padding}x", chunk_len, self.size,
                      int(TypeId.f32))
        # Serialize little-endian regardless of the byte order of the host;
        # on little-endian hosts this does not copy.
        np.asarray(self, dtype="<f4").tofile(file)

    def __getitem__(self, key):
        """
        Index into the norms.

        Slices are explicitly wrapped as Norms again; every other key is
        passed on unchanged to `numpy.ndarray` indexing.
        """
        if isinstance(key, slice):
            return Norms(super().__getitem__(key))
        return super().__getitem__(key)
def load_norms(file: Union[str, bytes, int, PathLike]) -> Norms:
    """
    Read the Norms chunk of a finalfusion file.

    The file is opened in binary mode and searched for a Norms chunk, which
    is then read.

    Parameters
    ----------
    file : str, bytes, int, PathLike
        Finalfusion file with a norms chunk.

    Returns
    -------
    norms : Norms
        The Norms from the file.

    Raises
    ------
    ValueError
        If the file did not contain a Norms chunk.
    """
    with open(file, "rb") as stream:
        found = find_chunk(stream, [ChunkIdentifier.NdNorms])
        # Guard clauses first; the happy path is the last statement.
        if found is None:
            raise ValueError("File did not contain a Norms chunk")
        if found != ChunkIdentifier.NdNorms:
            raise ValueError(f"unexpected chunk: {str(found)}")
        return Norms.read_chunk(stream)
# Public names exported by a wildcard import of this module.
__all__ = ['Norms', 'load_norms']