import datetime
import os
import pickle
import platform
import shutil
import sys
import time
from contextlib import contextmanager
from typing import List, Optional
from dagster import (
Field,
InitResourceContext,
InputContext,
Int,
IOManager,
MetadataValue,
OutputContext,
String,
io_manager,
)
from dagster._core.storage.io_manager import dagster_maintained_io_manager
from wandb.sdk.data_types.base_types.wb_value import WBValue
from wandb.sdk.wandb_artifacts import Artifact
from .resources import WANDB_CLOUD_HOST
from .version import __version__
if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict
try:
import dill
has_dill = True
except ImportError:
has_dill = False
try:
import cloudpickle
has_cloudpickle = True
except ImportError:
has_cloudpickle = False
try:
import joblib
has_joblib = True
except ImportError:
has_joblib = False
PICKLE_FILENAME = "output.pickle"
DILL_FILENAME = "output.dill"
CLOUDPICKLE_FILENAME = "output.cloudpickle"
JOBLIB_FILENAME = "output.joblib"
ACCEPTED_SERIALIZATION_MODULES = [
"dill",
"cloudpickle",
"joblib",
"pickle",
]
class Config(TypedDict):
dagster_run_id: str
wandb_host: str
wandb_entity: str
wandb_project: str
wandb_run_name: Optional[str]
wandb_run_id: Optional[str]
wandb_run_tags: Optional[List[str]]
base_dir: str
cache_duration_in_minutes: Optional[int]
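
# A minimal sketch of a Config instance with illustrative (hypothetical)
# values; wandb_artifacts_io_manager below builds this dictionary itself:
#
#   config: Config = {
#       "dagster_run_id": "7e4df022-1bf2-44b5-a383-bb852df4077e",
#       "wandb_host": "https://api.wandb.ai",
#       "wandb_entity": "my_entity",
#       "wandb_project": "my_project",
#       "wandb_run_name": None,
#       "wandb_run_id": None,
#       "wandb_run_tags": None,
#       "base_dir": "/opt/dagster/home",
#       "cache_duration_in_minutes": None,
#   }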
class WandbArtifactsIOManagerError(Exception):
"""Represents an execution error of the W&B Artifacts IO Manager."""
def __init__(self, message="A W&B Artifacts IO Manager error occurred."):
self.message = message
super().__init__(self.message)
class ArtifactsIOManager(IOManager):
"""IO Manager to handle Artifacts in Weights & Biases (W&B) .
It handles 3 different inputs:
- Pickable objects (the serialization module is configurable)
- W&B Objects (Audio, Table, Image, etc)
- W&B Artifacts
"""
def __init__(self, wandb_client, config: Config):
self.wandb = wandb_client
dagster_run_id = config["dagster_run_id"]
self.dagster_run_id = dagster_run_id
self.wandb_host = config["wandb_host"]
self.wandb_entity = config["wandb_entity"]
self.wandb_project = config["wandb_project"]
self.wandb_run_id = config.get("wandb_run_id") or dagster_run_id
self.wandb_run_name = config.get("wandb_run_name") or f"dagster-run-{dagster_run_id[0:8]}"
# augments the run tags
wandb_run_tags = config["wandb_run_tags"] or []
if "dagster_wandb" not in wandb_run_tags:
wandb_run_tags = [*wandb_run_tags, "dagster_wandb"]
self.wandb_run_tags = wandb_run_tags
self.base_dir = config["base_dir"]
cache_duration_in_minutes = config["cache_duration_in_minutes"]
default_cache_expiration_in_minutes = 60 * 24 * 30 # 60 minutes * 24 hours * 30 days
self.cache_duration_in_minutes = (
cache_duration_in_minutes
if cache_duration_in_minutes is not None
else default_cache_expiration_in_minutes
)
def _get_local_storage_path(self):
path = self.base_dir
if os.path.basename(path) != "storage":
path = os.path.join(path, "storage")
path = os.path.join(path, "wandb_artifacts_manager")
os.makedirs(path, exist_ok=True)
return path
def _get_artifacts_path(self, name, version):
local_storage_path = self._get_local_storage_path()
path = os.path.join(local_storage_path, "artifacts", f"{name}:{version}")
os.makedirs(path, exist_ok=True)
return path
def _get_wandb_logs_path(self):
local_storage_path = self._get_local_storage_path()
path = os.path.join(local_storage_path, "runs", self.dagster_run_id)
os.makedirs(path, exist_ok=True)
return path
def _clean_local_storage_path(self):
local_storage_path = self._get_local_storage_path()
cache_duration_in_minutes = self.cache_duration_in_minutes
current_timestamp = int(time.time())
expiration_timestamp = current_timestamp - (
cache_duration_in_minutes * 60 # convert to seconds
)
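        # For example, with the default duration of 43200 minutes (30 days),
        # any file whose last access time (st_atime) is more than 30 days old
        # is purged below; a duration of 0 purges everything.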
for root, dirs, files in os.walk(local_storage_path, topdown=False):
for name in files:
current_file_path = os.path.join(root, name)
most_recent_access = os.lstat(current_file_path).st_atime
if most_recent_access <= expiration_timestamp or cache_duration_in_minutes == 0:
os.remove(current_file_path)
for name in dirs:
current_dir_path = os.path.join(root, name)
if not os.path.islink(current_dir_path):
if len(os.listdir(current_dir_path)) == 0 or cache_duration_in_minutes == 0:
shutil.rmtree(current_dir_path)
@contextmanager
def wandb_run(self):
self.wandb.init(
id=self.wandb_run_id,
name=self.wandb_run_name,
project=self.wandb_project,
entity=self.wandb_entity,
dir=self._get_wandb_logs_path(),
tags=self.wandb_run_tags,
anonymous="never",
resume="allow",
)
try:
yield self.wandb.run
finally:
self.wandb.finish()
self._clean_local_storage_path()
def _upload_artifact(self, context: OutputContext, obj):
with self.wandb_run() as run:
parameters = context.metadata.get("wandb_artifact_configuration", {}) # type: ignore
serialization_module = parameters.get("serialization_module", {})
serialization_module_name = serialization_module.get("name", "pickle")
if serialization_module_name not in ACCEPTED_SERIALIZATION_MODULES:
raise WandbArtifactsIOManagerError(
f"The provided value '{serialization_module_name}' is not a supported"
f" serialization module. Supported: {ACCEPTED_SERIALIZATION_MODULES}."
)
serialization_module_parameters = serialization_module.get("parameters", {})
serialization_module_parameters_with_protocol = {
"protocol": pickle.HIGHEST_PROTOCOL, # we use the highest available protocol if we don't pass one
**serialization_module_parameters,
}
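            # For illustration (hypothetical values), a 'serialization_module'
            # entry in the output metadata could look like:
            #   {"name": "joblib", "parameters": {"compress": 3}}
            # Any parameters are forwarded verbatim to the chosen module's dump
            # function, alongside the pickle protocol above.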
artifact_type = parameters.get("type", "artifact")
artifact_description = parameters.get("description")
artifact_metadata = {
"source_integration": "dagster_wandb",
"source_integration_version": __version__,
"source_dagster_run_id": self.dagster_run_id,
"source_created_at": datetime.datetime.now(datetime.timezone.utc).isoformat(),
"source_python_version": platform.python_version(),
}
if isinstance(obj, Artifact):
if parameters.get("name") is not None:
raise WandbArtifactsIOManagerError(
"A 'name' property was provided in the 'wandb_artifact_configuration'"
" metadata dictionary. A 'name' property can only be provided for output"
" that is not already an Artifact object."
)
if parameters.get("type") is not None:
raise WandbArtifactsIOManagerError(
"A 'type' property was provided in the 'wandb_artifact_configuration'"
" metadata dictionary. A 'type' property can only be provided for output"
" that is not already an Artifact object."
)
if context.has_partition_key:
raise WandbArtifactsIOManagerError(
"A partitioned job was detected for an output of type Artifact. This is not"
" currently supported. We would love to hear about your use case. Please"
" contact W&B Support."
)
if len(serialization_module) != 0: # not an empty dict
context.log.warning(
"A 'serialization_module' dictionary was provided in the"
" 'wandb_artifact_configuration' metadata dictionary. It has no effect on"
" an output that is already an Artifact object."
)
                # The obj is already an Artifact; we only augment its metadata
artifact = obj
artifact.metadata = {**artifact.metadata, **artifact_metadata}
if artifact.description is not None and artifact_description is not None:
raise WandbArtifactsIOManagerError(
"A 'description' value was provided in the 'wandb_artifact_configuration'"
" metadata dictionary for an existing Artifact with a non-null description."
" Please, either set the description through 'wandb_artifact_argument' or"
" when constructing your Artifact."
)
if artifact_description is not None:
artifact.description = artifact_description
else:
if context.has_asset_key:
if parameters.get("name") is not None:
raise WandbArtifactsIOManagerError(
"A 'name' property was provided in the 'wandb_artifact_configuration'"
" metadata dictionary. A 'name' property is only required when no"
" 'AssetKey' is found. Artifacts created from an @asset use the asset"
" name as the Artifact name. Artifacts created from an @op with a"
" specified 'asset_key' for the output will use that value. Please"
" remove the 'name' property."
)
artifact_name = context.get_asset_identifier()[0] # name of asset
else:
if parameters.get("name") is None:
raise WandbArtifactsIOManagerError(
"Missing 'name' property in the 'wandb_artifact_configuration' metadata"
" dictionary. A 'name' property is required for Artifacts created from"
" an @op. Alternatively you can use an @asset."
)
artifact_name = parameters.get("name")
if context.has_partition_key:
artifact_name = f"{artifact_name}.{context.partition_key}"
# We replace the | character with - because it is not allowed in artifact names
# The | character is used in multi-dimensional partition keys
artifact_name = str(artifact_name).replace("|", "-")
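                # For example (hypothetical names), an asset "my_asset" with the
                # multi-dimensional partition key "2023-01-01|us" yields the
                # artifact name "my_asset.2023-01-01-us".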
# Creates an artifact to hold the obj
artifact = self.wandb.Artifact(
name=artifact_name,
type=artifact_type,
description=artifact_description,
metadata=artifact_metadata,
)
if isinstance(obj, WBValue):
if len(serialization_module) != 0: # not an empty dict
context.log.warning(
"A 'serialization_module' dictionary was provided in the"
" 'wandb_artifact_configuration' metadata dictionary. It has no effect"
" on when the output is a W&B object."
)
# Adds the WBValue object using the class name as the name for the file
artifact.add(obj, obj.__class__.__name__)
elif obj is not None:
# The output is not a native wandb Object, we serialize it
if serialization_module_name == "dill":
if not has_dill:
raise WandbArtifactsIOManagerError(
"No module named 'dill' found. Please, make sure that the module is"
" installed."
)
artifact.metadata = {
**artifact.metadata,
**{
"source_serialization_module": "dill",
"source_dill_version_used": dill.__version__,
"source_pickle_protocol_used": serialization_module_parameters_with_protocol[
"protocol"
],
},
}
with artifact.new_file(DILL_FILENAME, "wb") as file:
try:
dill.dump(
obj,
file,
**serialization_module_parameters_with_protocol,
)
context.log.info(
"Output serialized using dill with"
f" parameters={serialization_module_parameters_with_protocol}"
)
except Exception as exception:
raise WandbArtifactsIOManagerError(
"An error occurred in the dill serialization process. Please,"
" verify that the passed arguments are correct and your data is"
" compatible with the module."
) from exception
elif serialization_module_name == "cloudpickle":
if not has_cloudpickle:
raise WandbArtifactsIOManagerError(
"No module named 'cloudpickle' found. Please, make sure that the"
" module is installed."
)
artifact.metadata = {
**artifact.metadata,
**{
"source_serialization_module": "cloudpickle",
"source_cloudpickle_version_used": cloudpickle.__version__,
"source_pickle_protocol_used": serialization_module_parameters_with_protocol[
"protocol"
],
},
}
with artifact.new_file(CLOUDPICKLE_FILENAME, "wb") as file:
try:
cloudpickle.dump(
obj,
file,
**serialization_module_parameters_with_protocol,
)
context.log.info(
"Output serialized using cloudpickle with"
f" parameters={serialization_module_parameters_with_protocol}"
)
except Exception as exception:
raise WandbArtifactsIOManagerError(
"An error occurred in the cloudpickle serialization process."
" Please, verify that the passed arguments are correct and your"
" data is compatible with the module."
) from exception
elif serialization_module_name == "joblib":
if not has_joblib:
raise WandbArtifactsIOManagerError(
"No module named 'joblib' found. Please, make sure that the module"
" is installed."
)
artifact.metadata = {
**artifact.metadata,
**{
"source_serialization_module": "joblib",
"source_joblib_version_used": joblib.__version__,
"source_pickle_protocol_used": serialization_module_parameters_with_protocol[
"protocol"
],
},
}
with artifact.new_file(JOBLIB_FILENAME, "wb") as file:
try:
joblib.dump(
obj,
file,
**serialization_module_parameters_with_protocol,
)
context.log.info(
"Output serialized using joblib with"
f" parameters={serialization_module_parameters_with_protocol}"
)
except Exception as exception:
raise WandbArtifactsIOManagerError(
"An error occurred in the joblib serialization process. Please,"
" verify that the passed arguments are correct and your data is"
" compatible with the module."
) from exception
else:
artifact.metadata = {
**artifact.metadata,
**{
"source_serialization_module": "pickle",
"source_pickle_protocol_used": serialization_module_parameters_with_protocol[
"protocol"
],
},
}
with artifact.new_file(PICKLE_FILENAME, "wb") as file:
try:
pickle.dump(
obj,
file,
**serialization_module_parameters_with_protocol,
)
context.log.info(
"Output serialized using pickle with"
f" parameters={serialization_module_parameters_with_protocol}"
)
except Exception as exception:
raise WandbArtifactsIOManagerError(
"An error occurred in the pickle serialization process."
" Please, verify that the passed arguments are correct and"
" your data is compatible with pickle. Otherwise consider"
" using another module. Supported serialization:"
f" {ACCEPTED_SERIALIZATION_MODULES}."
) from exception
# Add any files: https://docs.wandb.ai/ref/python/artifact#add_file
add_files = parameters.get("add_files")
if add_files is not None and len(add_files) > 0:
for add_file in add_files:
artifact.add_file(**add_file)
# Add any dirs: https://docs.wandb.ai/ref/python/artifact#add_dir
add_dirs = parameters.get("add_dirs")
if add_dirs is not None and len(add_dirs) > 0:
for add_dir in add_dirs:
artifact.add_dir(**add_dir)
# Add any reference: https://docs.wandb.ai/ref/python/artifact#add_reference
add_references = parameters.get("add_references")
if add_references is not None and len(add_references) > 0:
for add_reference in add_references:
artifact.add_reference(**add_reference)
# Augments the aliases
            aliases = list(parameters.get("aliases", []))  # copy so we don't mutate the metadata dict
aliases.append(f"dagster-run-{self.dagster_run_id[0:8]}")
if "latest" not in aliases:
aliases.append("latest")
# Logs the artifact
self.wandb.log_artifact(artifact, aliases=aliases)
artifact.wait()
# Adds useful metadata to the output or Asset
artifacts_base_url = (
"https://wandb.ai"
if self.wandb_host == WANDB_CLOUD_HOST
else self.wandb_host.rstrip("/")
)
output_metadata = {
"dagster_run_id": MetadataValue.dagster_run(self.dagster_run_id),
"wandb_artifact_id": MetadataValue.text(artifact.id), # type: ignore
"wandb_artifact_type": MetadataValue.text(artifact.type),
"wandb_artifact_version": MetadataValue.text(artifact.version),
"wandb_artifact_size": MetadataValue.int(artifact.size),
"wandb_artifact_url": MetadataValue.url(
f"{artifacts_base_url}/{run.entity}/{run.project}/artifacts/{artifact.type}/{artifact.id}/{artifact.version}"
),
"wandb_entity": MetadataValue.text(run.entity),
"wandb_project": MetadataValue.text(run.project),
"wandb_run_id": MetadataValue.text(run.id),
"wandb_run_name": MetadataValue.text(run.name),
"wandb_run_path": MetadataValue.text(run.path),
"wandb_run_url": MetadataValue.url(run.url),
}
context.add_output_metadata(output_metadata)
def _download_artifact(self, context: InputContext):
with self.wandb_run() as run:
parameters = context.metadata.get("wandb_artifact_configuration", {}) # type: ignore
artifact_alias = parameters.get("alias")
artifact_version = parameters.get("version")
if artifact_alias is not None and artifact_version is not None:
raise WandbArtifactsIOManagerError(
"A value for 'version' and 'alias' have been provided. Only one property can be"
" used at the same time."
)
artifact_identifier = artifact_alias or artifact_version or "latest"
if context.has_asset_key:
artifact_name = context.get_asset_identifier()[0] # name of asset
else:
artifact_name = parameters.get("name")
if artifact_name is None:
raise WandbArtifactsIOManagerError(
"Missing 'name' property in the 'wandb_artifact_configuration' metadata"
" dictionary. A 'name' property is required for Artifacts used in an @op."
" Alternatively you can use an @asset."
)
if context.has_partition_key:
artifact_name = f"{artifact_name}.{context.partition_key}"
artifact = run.use_artifact(
f"{run.entity}/{run.project}/{artifact_name}:{artifact_identifier}"
)
name = parameters.get("get")
path = parameters.get("get_path")
if name is not None and path is not None:
raise WandbArtifactsIOManagerError(
"A value for 'get' and 'get_path' has been provided in the"
" 'wandb_artifact_configuration' metadata dictionary. Only one property can be"
" used. Alternatively you can use neither and the entire Artifact will be"
" dowloaded."
)
if name is not None:
return artifact.get(name)
artifacts_path = self._get_artifacts_path(artifact_name, artifact.version)
if path is not None:
path = artifact.get_path(path)
return path.download(root=artifacts_path)
artifact_dir = artifact.download(root=artifacts_path, recursive=True)
if os.path.exists(f"{artifact_dir}/{DILL_FILENAME}"):
if not has_dill:
raise WandbArtifactsIOManagerError(
"An object pickled with 'dill' was found in the Artifact. But the module"
" was not found. Please, make sure it's installed."
)
with open(f"{artifact_dir}/{DILL_FILENAME}", "rb") as file:
input_value = dill.load(file)
return input_value
elif os.path.exists(f"{artifact_dir}/{CLOUDPICKLE_FILENAME}"):
if not has_cloudpickle:
raise WandbArtifactsIOManagerError(
"An object pickled with 'cloudpickle' was found in the Artifact. But the"
" module was not found. Please, make sure it's installed."
)
with open(f"{artifact_dir}/{CLOUDPICKLE_FILENAME}", "rb") as file:
input_value = cloudpickle.load(file)
return input_value
elif os.path.exists(f"{artifact_dir}/{JOBLIB_FILENAME}"):
if not has_joblib:
raise WandbArtifactsIOManagerError(
"An object pickled with 'joblib' was found in the Artifact. But the module"
" was not found. Please, make sure it's installed."
)
with open(f"{artifact_dir}/{JOBLIB_FILENAME}", "rb") as file:
input_value = joblib.load(file)
return input_value
elif os.path.exists(f"{artifact_dir}/{PICKLE_FILENAME}"):
with open(f"{artifact_dir}/{PICKLE_FILENAME}", "rb") as file:
input_value = pickle.load(file)
return input_value
artifact.verify(root=artifacts_path)
return artifact
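    # A hedged sketch of the input metadata read by _download_artifact; the keys
    # mirror the parameters.get calls above and the values are hypothetical:
    #
    #   "wandb_artifact_configuration": {
    #       "alias": "best",         # or "version": "v3" (never both)
    #       "get": "my_table",       # or "get_path": "model.onnx" (never both)
    #   }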
def handle_output(self, context: OutputContext, obj) -> None:
if obj is None:
context.log.warning(
"The output value passed to W&B IO Manager is empty. Ignore if expected."
)
else:
try:
self._upload_artifact(context, obj)
except WandbArtifactsIOManagerError as exception:
raise exception
except Exception as exception:
raise WandbArtifactsIOManagerError() from exception
def load_input(self, context: InputContext):
try:
return self._download_artifact(context)
except WandbArtifactsIOManagerError as exception:
raise exception
except Exception as exception:
raise WandbArtifactsIOManagerError() from exception
@dagster_maintained_io_manager
@io_manager(
required_resource_keys={"wandb_resource", "wandb_config"},
description="IO manager to read and write W&B Artifacts",
config_schema={
"run_name": Field(
String,
is_required=False,
description=(
"Short display name for this run, which is how you'll identify this run in the UI."
" By default, it`s set to a string with the following format dagster-run-[8 first"
" characters of the Dagster Run ID] e.g. dagster-run-7e4df022."
),
),
"run_id": Field(
String,
is_required=False,
description=(
"Unique ID for this run, used for resuming. It must be unique in the project, and"
" if you delete a run you can't reuse the ID. Use the name field for a short"
" descriptive name, or config for saving hyperparameters to compare across runs."
r" The ID cannot contain the following special characters: /\#?%:.. You need to set"
" the Run ID when you are doing experiment tracking inside Dagster to allow the IO"
" Manager to resume the run. By default it`s set to the Dagster Run ID e.g "
" 7e4df022-1bf2-44b5-a383-bb852df4077e."
),
),
"run_tags": Field(
[String],
is_required=False,
description=(
"A list of strings, which will populate the list of tags on this run in the UI."
" Tags are useful for organizing runs together, or applying temporary labels like"
" 'baseline' or 'production'. It's easy to add and remove tags in the UI, or filter"
" down to just runs with a specific tag. Any W&B Run used by the integration will"
" have the dagster_wandb tag."
),
),
"base_dir": Field(
String,
is_required=False,
description=(
"Base directory used for local storage and caching. W&B Artifacts and W&B Run logs"
" will be written and read from that directory. By default, it`s using the"
" DAGSTER_HOME directory."
),
),
"cache_duration_in_minutes": Field(
Int,
is_required=False,
description=(
"Defines the amount of time W&B Artifacts and W&B Run logs should be kept in the"
" local storage. Only files and directories that were not opened for that amount of"
" time are removed from the cache. Cache purging happens at the end of an IO"
" Manager execution. You can set it to 0, if you want to disable caching"
" completely. Caching improves speed when an Artifact is reused between jobs"
" running on the same machine. It defaults to 30 days."
),
),
},
)
def wandb_artifacts_io_manager(context: InitResourceContext):
"""Dagster IO Manager to create and consume W&B Artifacts.
It allows any Dagster @op or @asset to create and consume W&B Artifacts natively.
For a complete set of documentation, see `Dagster integration <https://docs.wandb.ai/guides/integrations/dagster>`_.
**Example:**
.. code-block:: python
@repository
def my_repository():
return [
*with_resources(
load_assets_from_current_module(),
resource_defs={
"wandb_config": make_values_resource(
entity=str,
project=str,
),
"wandb_resource": wandb_resource.configured(
{"api_key": {"env": "WANDB_API_KEY"}}
),
"wandb_artifacts_manager": wandb_artifacts_io_manager.configured(
{"cache_duration_in_minutes": 60} # only cache files for one hour
),
},
resource_config_by_key={
"wandb_config": {
"config": {
"entity": "my_entity",
"project": "my_project"
}
}
},
),
]
@asset(
name="my_artifact",
metadata={
"wandb_artifact_configuration": {
"type": "dataset",
}
},
io_manager_key="wandb_artifacts_manager",
)
def create_dataset():
return [1, 2, 3]
"""
wandb_client = context.resources.wandb_resource["sdk"]
wandb_host = context.resources.wandb_resource["host"]
wandb_entity = context.resources.wandb_config["entity"]
wandb_project = context.resources.wandb_config["project"]
wandb_run_name = None
wandb_run_id = None
wandb_run_tags = None
cache_duration_in_minutes = None
if context.resource_config is not None:
wandb_run_name = context.resource_config.get("run_name")
wandb_run_id = context.resource_config.get("run_id")
wandb_run_tags = context.resource_config.get("run_tags")
base_dir = context.resource_config.get(
"base_dir",
context.instance.storage_directory()
if context.instance
else os.environ["DAGSTER_HOME"],
)
cache_duration_in_minutes = context.resource_config.get("cache_duration_in_minutes")
if "PYTEST_CURRENT_TEST" in os.environ:
dagster_run_id = "unit-testing"
else:
dagster_run_id = context.run_id
config: Config = {
"dagster_run_id": dagster_run_id or "",
"wandb_host": wandb_host,
"wandb_entity": wandb_entity,
"wandb_project": wandb_project,
"wandb_run_name": wandb_run_name,
"wandb_run_id": wandb_run_id,
"wandb_run_tags": wandb_run_tags,
"base_dir": base_dir,
"cache_duration_in_minutes": cache_duration_in_minutes,
}
return ArtifactsIOManager(wandb_client, config)