diff --git a/stac_extension_genmeta/__init__.py b/stac_extension_genmeta/__init__.py index ea36f6d7014ea0e74efbd7139af770ba5a0c918c..b50774dce6f244552050784ea685b3ec44779936 100644 --- a/stac_extension_genmeta/__init__.py +++ b/stac_extension_genmeta/__init__.py @@ -1,2 +1,2 @@ from .core import create_extension_cls -__version__ = "0.0.21" \ No newline at end of file +__version__ = "0.0.22" \ No newline at end of file diff --git a/stac_extension_genmeta/core.py b/stac_extension_genmeta/core.py index ae7a88ac8e366f38246a6c63dfda3c94dd692c06..e4aba2b0af22a90553b801b904d5918975416b92 100644 --- a/stac_extension_genmeta/core.py +++ b/stac_extension_genmeta/core.py @@ -44,7 +44,7 @@ def create_extension_cls( def __init__(self, obj: T): if isinstance(obj, pystac.Item): self.properties = obj.properties - elif isinstance(obj, pystac.Asset): + elif isinstance(obj, (pystac.Asset, pystac.Collection)): self.properties = obj.extra_fields else: raise pystac.ExtensionTypeError( @@ -109,12 +109,12 @@ def create_extension_cls( ) -> model_cls.__name__: if isinstance(obj, pystac.Item): cls.ensure_has_extension(obj, add_if_missing) - return cast(CustomExtension[T], - ItemCustomExtension(obj)) + return cast(CustomExtension[T], ItemCustomExtension(obj)) elif isinstance(obj, pystac.Asset): cls.ensure_owner_has_extension(obj, add_if_missing) - return cast(CustomExtension[T], - AssetCustomExtension(obj)) + return cast(CustomExtension[T], AssetCustomExtension(obj)) + elif isinstance(obj, pystac.Collection): + return cast(CustomExtension[T], CollectionCustomExtension(obj)) raise pystac.ExtensionTypeError( f"{model_cls.__name__} does not apply to type " f"{type(obj).__name__}" @@ -134,5 +134,12 @@ def create_extension_cls( if asset.owner and isinstance(asset.owner, pystac.Item): self.additional_read_properties = [asset.owner.properties] + class CollectionCustomExtension(CustomExtension[pystac.Collection]): + properties: dict[str, Any] + additional_read_properties: Iterable[dict[str, Any]] | None = None + + def __init__(self, collection: pystac.Collection): + self.properties = collection.extra_fields + CustomExtension.__name__ = f"CustomExtensionFrom{model_cls.__name__}" return CustomExtension diff --git a/stac_extension_genmeta/testing.py b/stac_extension_genmeta/testing.py index 518e063c09e4663c8c790a8e57ba3adde799fa95..24ecd44ac93c94ff8f0d03b620cf10332af303d6 100644 --- a/stac_extension_genmeta/testing.py +++ b/stac_extension_genmeta/testing.py @@ -50,7 +50,7 @@ def create_dummy_item(date=None): ) col.add_item(item) - return item + return item, col def basic_test( @@ -58,6 +58,7 @@ def basic_test( ext_cls, item_test: bool = True, asset_test: bool = True, + collection_test: bool = True, validate: bool = True ): print( @@ -92,7 +93,7 @@ def basic_test( """ Test extension against item """ - item = create_dummy_item() + item, _ = create_dummy_item() apply(item) print_item(item) if validate: @@ -104,7 +105,7 @@ def basic_test( """ Test extension against asset """ - item = create_dummy_item() + item, _ = create_dummy_item() apply(item.assets["ndvi"]) print_item(item) if validate: @@ -112,12 +113,28 @@ def basic_test( # Check that we can retrieve the extension metadata from the asset comp(item.assets["ndvi"]) + def test_collection(): + """ + Test extension against collection + """ + item, col = create_dummy_item() + print_item(col) + apply(col) + print_item(col) + if validate: + col.validate() # <--- This will try to read the actual schema URI + # Check that we can retrieve the extension metadata from the asset + comp(col) + if item_test: print("Test item") test_item() if asset_test: print("Test asset") test_asset() + if collection_test: + print("Test collection") + test_collection() def is_schema_url_synced(cls):