Skip to content

zamba.models.utils

Attributes

S3_BUCKET = 's3://drivendata-public-assets' module-attribute

Classes

RegionEnum

Bases: str, Enum

Source code in zamba/models/utils.py (lines 17–20)
class RegionEnum(str, Enum):
    """Regions whose public S3 bucket can serve official model weights.

    Used as the ``weight_region`` argument of ``download_weights``. Because it
    mixes in ``str``, members compare equal to their plain string values
    (e.g. ``RegionEnum.us == "us"``).
    """

    # Default region; maps to the base bucket name with no region suffix.
    us = "us"
    eu = "eu"
    asia = "asia"

Attributes

asia = 'asia' class-attribute
eu = 'eu' class-attribute
us = 'us' class-attribute

Functions

download_weights(filename: str, destination_dir: Union[os.PathLike, str], weight_region: RegionEnum = RegionEnum('us')) -> Path

Source code in zamba/models/utils.py (lines 23–40)
def download_weights(
    filename: str,
    destination_dir: Union[os.PathLike, str],
    weight_region: RegionEnum = RegionEnum("us"),
) -> str:
    """Download an official zamba weights file from the public S3 bucket.

    Args:
        filename: Name of the file under ``zamba_official_models/`` in the bucket.
        destination_dir: Local directory to download into. It is also passed to
            the S3 client as its local cache directory.
        weight_region: Region whose bucket to download from. Accepts a
            ``RegionEnum`` member or its plain string value ("us", "eu", "asia").

    Returns:
        The local path of the downloaded file, as a string.

    Raises:
        ValueError: If ``weight_region`` is not a valid ``RegionEnum`` value.
    """
    # Normalize to the plain string value. Interpolating a ``(str, Enum)`` member
    # directly into an f-string renders as "RegionEnum.eu" on Python 3.11+,
    # which would yield a nonexistent bucket name.
    region = RegionEnum(weight_region).value

    # The US bucket has no suffix; every other region appends "-<region>".
    if region == RegionEnum.us.value:
        region_bucket = S3_BUCKET
    else:
        region_bucket = f"{S3_BUCKET}-{region}"

    s3p = S3Path(
        f"{region_bucket}/zamba_official_models/{filename}",
        # Weights are public, so requests are sent unsigned (no credentials needed).
        client=S3Client(local_cache_dir=destination_dir, no_sign_request=True),
    )

    s3p.download_to(destination_dir)
    return str(Path(destination_dir) / s3p.name)

get_checkpoint_hparams(checkpoint)

Source code in zamba/models/utils.py (lines 62–63)
def get_checkpoint_hparams(checkpoint):
    """Return a deep copy of the hyperparameters associated with ``checkpoint``.

    The lookup is delegated to ``_cached_hparams``. That helper is presumably
    memoized (judging by its name; its definition is not shown here). Returning
    a copy keeps any mutation by the caller from affecting that helper's result.
    """
    stored_hparams = _cached_hparams(checkpoint)
    # Hand back an independent object rather than the shared one.
    return copy.deepcopy(stored_hparams)

get_default_hparams(model)

Source code in zamba/models/utils.py (lines 53–59)
def get_default_hparams(model):
    """Load the default hyperparameters shipped for ``model``.

    Args:
        model: Model name, or an ``Enum`` member whose ``.value`` is the model
            name. Used as a subdirectory of ``MODELS_DIRECTORY``.

    Returns:
        The parsed contents of ``MODELS_DIRECTORY / <model> / "hparams.yaml"``.
    """
    if isinstance(model, Enum):
        model = model.value

    hparams_file = MODELS_DIRECTORY / model / "hparams.yaml"
    # Explicit encoding: the default is locale-dependent (e.g. cp1252 on
    # Windows), while YAML files are Unicode.
    with hparams_file.open(encoding="utf-8") as f:
        return yaml.safe_load(f)

get_model_checkpoint_filename(model_name)

Source code in zamba/models/utils.py (lines 43–50)
def get_model_checkpoint_filename(model_name):
    """Return the public checkpoint filename configured for ``model_name``.

    Args:
        model_name: Model name, or an ``Enum`` member whose ``.value`` is the
            model name. Used as a subdirectory of ``MODELS_DIRECTORY``.

    Returns:
        ``Path`` built from the ``public_checkpoint`` entry of
        ``MODELS_DIRECTORY / <model_name> / "config.yaml"``.

    Raises:
        KeyError: If the config has no ``public_checkpoint`` entry.
    """
    if isinstance(model_name, Enum):
        model_name = model_name.value

    config_file = MODELS_DIRECTORY / model_name / "config.yaml"
    # Explicit encoding: the default is locale-dependent, while YAML is Unicode.
    with config_file.open(encoding="utf-8") as f:
        config_dict = yaml.safe_load(f)
    return Path(config_dict["public_checkpoint"])

get_model_species(checkpoint, model_name)

Source code in zamba/models/utils.py (lines 71–77)
def get_model_species(checkpoint, model_name):
    """Return the ``species`` entry of the model's hyperparameters.

    When a ``checkpoint`` is supplied, its stored hyperparameters win. When it
    is ``None``, the packaged defaults for ``model_name`` are used instead.
    """
    if checkpoint is None:
        hparams = get_default_hparams(model_name)
    else:
        hparams = get_checkpoint_hparams(checkpoint)
    return hparams["species"]