Spaces:
Runtime error
Runtime error
| # Copyright 2024 NVIDIA CORPORATION & AFFILIATES | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # This file is copied from https://github.com/NVlabs/VILA/tree/main/llava/wids | |
| import collections | |
| import fcntl | |
| import io | |
| import mmap | |
| import os | |
| import struct | |
# Record holding the raw (undecoded) byte fields of one tar header, in the
# order they are laid out in the header block.
TarHeader = collections.namedtuple(
    "TarHeader",
    [
        "name",
        "mode",
        "uid",
        "gid",
        "size",
        "mtime",
        "chksum",
        "typeflag",
        "linkname",
        "magic",
        "version",
        "uname",
        "gname",
        "devmajor",
        "devminor",
        "prefix",
    ],
)

# Fixed-width byte fields, one per TarHeader attribute; together they cover
# the first 500 bytes of a header block. Compiled once instead of per call.
_TAR_HEADER_STRUCT = struct.Struct("!100s8s8s8s12s12s8s1s100s6s2s32s32s8s8s155s")


def parse_tar_header(header_bytes):
    """Split a raw tar header into its fixed-width fields.

    Args:
        header_bytes: Bytes-like object (``bytes``, ``memoryview``, ``mmap``,
            ...) that starts with a tar header. Only the first 500 bytes are
            read, so a full 512-byte header block may be passed as-is.

    Returns:
        A ``TarHeader`` whose attributes are the raw ``bytes`` of each field,
        including any NUL/space padding. No decoding or validation is done.

    Raises:
        struct.error: If ``header_bytes`` holds fewer than 500 bytes.
    """
    # unpack_from reads from the start of the buffer and tolerates trailing
    # data, unlike unpack which demands an exact 500-byte input.
    return TarHeader._make(_TAR_HEADER_STRUCT.unpack_from(header_bytes))
def next_header(offset, header):
    """Compute the archive offset of the header block following ``header``.

    Args:
        offset: Byte offset at which ``header``'s block starts.
        header: Object whose ``size`` attribute is the raw size field of the
            header (``bytes`` holding octal digits).

    Returns:
        ``offset`` plus one 512-byte header block plus the member's data
        rounded up to a whole number of 512-byte blocks, or ``-1`` when the
        size field holds no digits (it is empty or all NUL bytes).

    Raises:
        ValueError: If the size field is not valid octal text (this includes
            ``UnicodeDecodeError`` for bytes that are not UTF-8).
    """
    block_size = 512
    # The digits end at the first NUL; whatever follows is padding. Cutting
    # there (rather than only stripping outer NULs) also accepts fields such
    # as b"0000144\x00 ". Surrounding spaces are tolerated by int() itself.
    size_text = header.size.decode("utf-8").lstrip("\x00").split("\x00", 1)[0]
    if size_text == "":
        return -1
    size = int(size_text, 8)
    # Member data occupies whole blocks: round a partial last block up.
    padded_file_size = -(-size // block_size) * block_size
    return offset + block_size + padded_file_size
| # TODO(ligeng): support gzip stream | |
class MMIndexedTar:
    """Read-only, memory-mapped tar archive with an in-memory member index.

    The whole file is mapped with ``mmap`` and scanned once at construction;
    afterwards members can be fetched by position or by name. Only entries
    whose typeflag is ``"0"`` or empty, and whose name is neither empty nor
    ``"././@PaxHeader"``, are indexed. Instances can be used as context
    managers; leaving the ``with`` block calls ``close()``.

    NOTE: member names are taken from the 100-byte ``name`` field only; the
    header's ``prefix`` field is not consulted.

    Attributes set by the constructor:
        by_name: dict mapping member name -> header offset (for duplicate
            names the last occurrence wins).
        by_index: list of ``(name, header_offset, size)`` in archive order.
    """

    def __init__(self, fname, index_file=None, verbose=True, cleanup_callback=None):
        """Map the archive and build its index.

        Args:
            fname: Path of the tar file (``str`` or ``os.PathLike``), or an
                already open binary file object exposing ``fileno()``.
            index_file: Accepted for interface compatibility; not used here.
            verbose: Stored on the instance; not otherwise used here.
            cleanup_callback: Optional callable invoked as
                ``cleanup_callback(fname, fd, "start")`` once the file is
                mapped and ``cleanup_callback(self.fname, fd, "end")`` from
                ``close()``.

        Raises:
            TypeError: If ``fname`` is neither a path nor an ``io.IOBase``.
        """
        self.verbose = verbose
        self.cleanup_callback = cleanup_callback
        if isinstance(fname, (str, os.PathLike)):
            self.stream = open(fname, "rb")
            self.fname = fname
            owns_stream = True
        elif isinstance(fname, io.IOBase):
            self.stream = fname
            self.fname = None
            owns_stream = False
        else:
            raise TypeError(
                f"fname must be a path or an open binary file, not {type(fname).__name__}"
            )
        self.mmapped_file = None
        try:
            self.mmapped_file = mmap.mmap(self.stream.fileno(), 0, access=mmap.ACCESS_READ)
            if cleanup_callback:
                cleanup_callback(fname, self.stream.fileno(), "start")
            self._build_index()
        except BaseException:
            # Construction failed, so nobody will ever call close(): release
            # what was acquired here. A caller-supplied stream stays open
            # because the caller still owns it.
            if self.mmapped_file is not None:
                self.mmapped_file.close()
            if owns_stream:
                self.stream.close()
            raise

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def close(self, dispose=False):
        """Run the "end" cleanup callback, then close the mapping and the file.

        Calling ``close()`` again after the mapping is closed is a no-op.
        ``dispose`` is accepted for interface compatibility and is not used.
        """
        if self.mmapped_file.closed:
            return
        try:
            if self.cleanup_callback:
                self.cleanup_callback(self.fname, self.stream.fileno(), "end")
        finally:
            # Release the resources even if the callback raised.
            self.mmapped_file.close()
            self.stream.close()

    @staticmethod
    def _member_size(header):
        """Return the member's data size parsed from the octal size field.

        The digits end at the first NUL byte. Raises ``ValueError`` when the
        field holds no valid octal number.
        """
        digits = header.size.decode("utf-8").lstrip("\x00").split("\x00", 1)[0]
        return int(digits, 8)

    def _build_index(self):
        """Walk every header in the mapping and fill ``by_name``/``by_index``."""
        self.by_name = {}
        self.by_index = []
        archive_len = len(self.mmapped_file)
        offset = 0
        # next_header() returns -1 once a header has an empty size field,
        # which ends the walk before the end of the mapping.
        while 0 <= offset < archive_len:
            header = parse_tar_header(self.mmapped_file[offset : offset + 500])
            name = header.name.decode("utf-8").strip("\x00")
            typeflag = header.typeflag.decode("utf-8").strip("\x00")
            if name != "" and name != "././@PaxHeader" and typeflag in ["0", ""]:
                try:
                    size = self._member_size(header)
                except ValueError:
                    # Show the offending header before propagating the error.
                    print(header)
                    raise
                self.by_name[name] = offset
                self.by_index.append((name, offset, size))
            offset = next_header(offset, header)

    def names(self):
        """Return a view of the indexed member names."""
        return self.by_name.keys()

    def get_at_offset(self, offset):
        """Return ``(name, data)`` for the member whose header starts at ``offset``.

        The header is re-read from the mapping; ``data`` is a ``bytes`` copy
        of the member's contents.
        """
        header = parse_tar_header(self.mmapped_file[offset : offset + 500])
        name = header.name.decode("utf-8").strip("\x00")
        # Member data begins right after the 512-byte header block.
        start = offset + 512
        end = start + self._member_size(header)
        return name, self.mmapped_file[start:end]

    def get_at_index(self, index):
        """Return ``(name, data)`` for the ``index``-th indexed member."""
        _, offset, _ = self.by_index[index]
        return self.get_at_offset(offset)

    def get_by_name(self, name):
        """Return ``(name, data)`` for the member called ``name``.

        Raises:
            KeyError: If no indexed member has that name.
        """
        offset = self.by_name[name]
        return self.get_at_offset(offset)

    def __iter__(self):
        """Yield ``(name, data)`` for every indexed member in archive order."""
        for name, offset, size in self.by_index:
            start = offset + 512
            yield name, self.mmapped_file[start : start + size]

    def __getitem__(self, key):
        """Look a member up by position (``int``) or by name (anything else)."""
        if isinstance(key, int):
            return self.get_at_index(key)
        return self.get_by_name(key)

    def __len__(self):
        """Return the number of indexed members."""
        return len(self.by_index)

    def get_file(self, i):
        """Return ``(name, io.BytesIO)`` wrapping the ``i``-th member's data."""
        fname, data = self.get_at_index(i)
        return fname, io.BytesIO(data)
def keep_while_reading(fname, fd, phase, delay=0.0):
    """This is a possible cleanup callback for cleanup_callback of MMIndexedTar.

    It assumes that as long as there are some readers for a file,
    more readers may be trying to open it.

    Note that on Linux, unlinking the file doesn't matter after
    it has been mmapped. The contents will only be deleted when
    all readers close the file. The unlinking merely makes the file
    unavailable to new readers, since the downloader checks first
    whether the file exists.

    Args:
        fname: Path of the file being read; ``None`` makes the call a no-op.
        fd: Open file descriptor for that file; a negative value makes the
            call a no-op.
        phase: ``"start"`` takes a shared lock on ``fd``. ``"end"`` tries to
            take an exclusive lock without blocking and, if it gets it,
            unlinks ``fname``.
        delay: Must be ``0.0``; any other value is not implemented.

    Raises:
        NotImplementedError: If ``delay`` is not ``0.0``.
        ValueError: If ``phase`` is neither ``"start"`` nor ``"end"`` (only
            checked when ``fd``/``fname`` are usable).
    """
    # Explicit check instead of `assert`, which is removed under `python -O`.
    if delay != 0.0:
        raise NotImplementedError("delay not implemented")
    if fd < 0 or fname is None:
        return
    if phase == "start":
        fcntl.flock(fd, fcntl.LOCK_SH)
    elif phase == "end":
        try:
            fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
            os.unlink(fname)
        except (FileNotFoundError, BlockingIOError):
            # FileNotFoundError: someone else deleted it already.
            # BlockingIOError: we couldn't get an exclusive lock, so someone
            # else is still reading and the file must stay.
            pass
    else:
        raise ValueError(f"Unknown phase {phase}")