Spaces:
Runtime error
Runtime error
| # Copyright 2024 NVIDIA CORPORATION & AFFILIATES | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids | |
| import io | |
| import json | |
| import os | |
| import tempfile | |
| from urllib.parse import urlparse, urlunparse | |
| from .wids_dl import download_and_open | |
def urldir(url):
    """Return the directory part of a url.

    Only the path component is shortened to its parent directory; the
    scheme, host, params, query and fragment of ``url`` are kept as they are.
    """
    parts = urlparse(url)
    parent = os.path.dirname(parts.path)
    return parts._replace(path=parent).geturl()
def urlmerge(base, url):
    """Merge a base URL and a relative URL.

    The function fills in any missing part of the url from the base,
    except for params, query, and fragment, which are taken only from the 'url'.
    For the pathname component, it merges the paths like a POSIX path join:
    an absolute path in 'url' overrides the base path, otherwise the url path
    is appended to the base path and the result is normalized ('.' and '..'
    segments are collapsed).

    Parameters:
        base (str): The base URL.
        url (str): The URL to merge with the base.

    Returns:
        str: The merged URL.
    """
    # URL paths always use forward slashes, so they are merged with posixpath;
    # os.path would rewrite the separators to backslashes on Windows.
    import posixpath

    # Parse the base and the relative URL
    parsed_base = urlparse(base)
    parsed_url = urlparse(url)

    if parsed_url.path.startswith("/"):
        # An absolute url path overrides the base path
        merged_path = parsed_url.path
    elif not parsed_base.path and not parsed_url.path:
        # Nothing to merge; normpath("") would otherwise invent a "." segment
        merged_path = ""
    else:
        merged_path = posixpath.normpath(posixpath.join(parsed_base.path, parsed_url.path))

    # Construct the merged URL
    return urlunparse(
        (
            parsed_url.scheme or parsed_base.scheme,
            parsed_url.netloc or parsed_base.netloc,
            merged_path,
            parsed_url.params,  # Use params from the url only
            parsed_url.query,  # Use query from the url only
            parsed_url.fragment,  # Use fragment from the url only
        )
    )
def check_shards(l):
    """Verify that a shard list has the expected structure.

    The argument must be a list whose entries are all dictionaries
    containing both a "url" and a "nsamples" key. Any violation raises
    an AssertionError; on success the same list object is returned.
    """
    assert isinstance(l, list)
    for entry in l:
        assert isinstance(entry, dict)
        for required_key in ("url", "nsamples"):
            assert required_key in entry
    return l
def set_all(l, k, v):
    """Fill in key ``k`` with value ``v`` on every dictionary in ``l``.

    Dictionaries that already contain ``k`` are left untouched, and nothing
    happens at all when ``v`` is None. The dictionaries are modified in place.
    """
    if v is not None:
        for entry in l:
            entry.setdefault(k, v)
def load_remote_dsdesc_raw(source):
    """Load a remote or local dataset description in JSON format.

    Parameters:
        source: Where to read the description from. A string is fetched
            with ``download_and_open`` into a temporary directory; a file
            object (an ``io.IOBase`` instance or any object with a ``read``
            method) is parsed directly; anything else is passed to
            ``requests.get`` as a URL.

    Returns:
        The decoded JSON document (no validation is performed here).
    """
    if isinstance(source, str):
        # The download target lives only as long as the temporary directory,
        # so the file is parsed before the directory is cleaned up.
        with tempfile.TemporaryDirectory() as tmpdir:
            dlname = os.path.join(tmpdir, "dataset.json")
            with download_and_open(source, dlname) as f:
                return json.load(f)

    # Accept duck-typed readers too, not only io.IOBase subclasses;
    # json.load only needs a .read() method.
    if isinstance(source, io.IOBase) or hasattr(source, "read"):
        return json.load(source)

    # FIXME: use gopen
    import requests

    jsondata = requests.get(source).text
    return json.loads(jsondata)
def rebase_shardlist(shardlist, base):
    """Resolve every shard URL in ``shardlist`` against ``base``.

    Each shard's "url" entry is replaced in place by the result of
    ``urlmerge(base, url)``. When ``base`` is None the list is left
    untouched. The same list object is returned in both cases.
    """
    if base is not None:
        for entry in shardlist:
            entry["url"] = urlmerge(base, entry["url"])
    return shardlist
def resolve_dsdesc(dsdesc, *, options=None, base=None):
    """Resolve a dataset description.

    This rebases the shards as necessary and loads any remote references.
    The input description and its shard dictionaries are not modified;
    the result is a new dictionary holding copies of the shards.

    Dataset descriptions are JSON files. They must have the following format;

    {
        "wids_version": 1,
        # optional immediate shardlist
        "shardlist": [
            {"url": "http://example.com/file.tar", "nsamples": 1000},
            ...
        ],
        # sub-datasets
        "datasets": [
            {"source_url": "http://example.com/dataset.json"},
            {"shardlist": [
                {"url": "http://example.com/file.tar", "nsamples": 1000},
                ...
            ]}
            ...
        ]
    }

    Parameters:
        dsdesc (dict): The parsed dataset description.
        options (dict, optional): Extra top-level entries; they override
            entries of the same name in ``dsdesc``.
        base (str, optional): Base URL used to rebase the immediate shardlist.

    Returns:
        dict: A copy of the description whose "shardlist" contains the
        immediate shards followed by the shards of every sub-dataset.

    Raises:
        AssertionError: If the description is malformed (wrong types,
            missing or unknown "wids_version", missing shardlist in a
            sub-dataset, nested "source_url", or no shards at all).
    """
    if options is None:
        options = {}
    assert isinstance(dsdesc, dict)
    dsdesc = dict(dsdesc, **options)

    # dict(dsdesc, **options) is only a shallow copy: the shard list and the
    # shard dictionaries are still shared with the caller. Validate first,
    # then work on copies so that resolving the same description twice does
    # not duplicate shards or rebase URLs a second time.
    own_shards = check_shards(dsdesc.get("shardlist", []))
    shardlist = rebase_shardlist([dict(shard) for shard in own_shards], base)
    set_all(shardlist, "weight", dsdesc.get("weight"))
    set_all(shardlist, "name", dsdesc.get("name"))

    assert "wids_version" in dsdesc, "No wids_version in dataset description"
    assert dsdesc["wids_version"] == 1, "Unknown wids_version"

    for component in dsdesc.get("datasets", []):
        # we use the weight from the reference to the dataset,
        # regardless of remote loading
        weight = component.get("weight")

        # follow any source_url dsdescs through remote loading
        source_url = None
        if "source_url" in component:
            source_url = component["source_url"]
            component = load_remote_dsdesc_raw(source_url)
        assert "source_url" not in component, "double indirection in dataset description"
        assert "shardlist" in component, "no shardlist in dataset description"

        # if the component has a base, use it to rebase the shardlist
        # otherwise use the base from the source_url, if any
        subbase = component.get("base", urldir(source_url) if source_url else None)

        # copy the shards so that inline components are not modified in place
        shards = [dict(shard) for shard in check_shards(component["shardlist"])]
        if subbase is not None:
            rebase_shardlist(shards, subbase)
        set_all(shards, "weight", weight)
        set_all(shards, "source_url", source_url)
        set_all(shards, "dataset", component.get("name"))
        shardlist.extend(shards)

    assert len(shardlist) > 0, "No shards found"
    dsdesc["shardlist"] = shardlist
    return dsdesc
def load_dsdesc_and_resolve(source, *, options=None, base=None):
    """Load a dataset description from ``source`` and resolve it.

    The description is read with ``load_remote_dsdesc_raw`` and the result
    is handed to ``resolve_dsdesc`` together with ``base`` and ``options``
    (an empty dict is used when ``options`` is None).
    """
    effective_options = {} if options is None else options
    raw_description = load_remote_dsdesc_raw(source)
    return resolve_dsdesc(raw_description, base=base, options=effective_options)