Source code for pythae.models.base.base_utils

import importlib
import io
import logging
from collections import OrderedDict
from typing import Any, Tuple

try:
    import pickle5 as pickle
except:
    import pickle
import torch

# Module-level logger for this file: records at INFO level and above are
# emitted through a dedicated StreamHandler (stderr by default).
console = logging.StreamHandler()
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(console)

# Model-card text: a YAML front-matter header (language, tags, license)
# followed by Markdown instructions for reloading the model with
# ``AutoModel.load_from_hf_hub``. Presumably written out as the README of
# repositories pushed to the Hugging Face Hub — confirm against callers.
model_card_template = """---
language: en
tags:
- pythae
license: apache-2.0
---

### Downloading this model from the Hub
This model was trained with pythae. It can be downloaded or reloaded using the method `load_from_hf_hub`
```python
>>> from pythae.models import AutoModel
>>> model = AutoModel.load_from_hf_hub(hf_hub_path="your_hf_username/repo_name")
```
"""


def hf_hub_is_available() -> bool:
    """Return whether the ``huggingface_hub`` package is installed.

    Only looks up an import spec for the package; the package itself is not
    imported.

    Returns:
        bool: ``True`` if ``huggingface_hub`` can be found, ``False`` otherwise.
    """
    # ``import importlib`` alone does not guarantee that the ``importlib.util``
    # submodule has been loaded, so import it explicitly before using it.
    import importlib.util

    return importlib.util.find_spec("huggingface_hub") is not None


class ModelOutput(OrderedDict):
    """Base ModelOutput class fixing the output type from the models.

    Every entry is stored both as a mapping item and as an instance attribute,
    so a value can be read as ``output["loss"]``, as ``output.loss``, or by
    position (``output[0]``, ``output[1:]``) following insertion order.

    This class is inspired from the ``ModelOutput`` class from the Hugging Face
    transformers library.
    """

    def __getitem__(self, k):
        """Look up a value by key or by position.

        Args:
            k: A string key, or an int/slice applied to :meth:`to_tuple`.

        Returns:
            The value stored under ``k`` when ``k`` is a string; otherwise the
            result of indexing the tuple of values with ``k``.

        Raises:
            KeyError: If ``k`` is a string that is not a stored key.
            IndexError: If ``k`` is an out-of-range integer position.
        """
        if isinstance(k, str):
            # Direct O(1) lookup instead of copying the whole mapping into a
            # new dict on every access.
            return super().__getitem__(k)
        return self.to_tuple()[k]

    def __setattr__(self, name, value):
        """Set ``name`` both as a mapping item and as an instance attribute."""
        super().__setitem__(name, value)
        super().__setattr__(name, value)

    def __setitem__(self, key, value):
        """Set ``key`` both as a mapping item and as an instance attribute."""
        super().__setitem__(key, value)
        super().__setattr__(key, value)

    def to_tuple(self) -> Tuple[Any, ...]:
        """Convert self to a tuple of all stored values, in insertion order.

        ``None`` values are kept, so tuple positions line up with the keys.
        """
        # Read values straight from the mapping rather than via
        # ``self[k]``: this is O(n) overall and cannot recurse back into
        # ``to_tuple`` for non-string keys.
        return tuple(self.values())
class CPU_Unpickler(pickle.Unpickler):
    """Unpickler that reloads serialized torch storages onto the CPU.

    When the pickle stream references ``torch.storage._load_from_bytes``, a
    replacement loader is returned that calls ``torch.load`` with
    ``map_location="cpu"``. Every other global is resolved by the standard
    :meth:`pickle.Unpickler.find_class` lookup.

    NOTE: like any unpickler, this can execute arbitrary code embedded in the
    stream; only use it on trusted data.
    """

    def find_class(self, module, name):
        is_torch_storage = module == "torch.storage" and name == "_load_from_bytes"
        if not is_torch_storage:
            return super().find_class(module, name)

        def _load_storage_on_cpu(raw_bytes):
            # Deserialize the storage bytes, forcing them onto the CPU.
            return torch.load(io.BytesIO(raw_bytes), map_location="cpu")

        return _load_storage_on_cpu