Skip to main content

auto download datasets

リモートサーバーからデータセットをwgetすると、403 Forbiddenが出ました。 (wget <URL> -dで確認すると、You don't have permission to access <URL> といわれていました)。 いつもはsshでデータを送っていたのですが、今回は600GBを超えていました. (ローカルはあと5GBしかない...)

データセットを自動setupしてくれるscriptが、最近試したnvidiaのサンプルコードにあったので、 参考にして、Pythonで何とかしようとしました。 (メモリがあふれていたのとデータが大きすぎて途中で止めていたことに気づかず結構はまりました。)

requests.Sesssionの.iter_contentでメモリを分けてダウンロードし、 tqdmでプログレスバーを表示させるとうまくいきました。

REF

次のような関数を定義しておきます。

import requests, zipfile, os, sys, subprocess
from tqdm import tqdm
def download_file(url, dir='./'):
session = requests.Session()
response = session.get(url, stream=True)
destination = os.path.join(dir, os.path.basename(url))
content_size = int(response.headers["content-length"])
try:
print('download %s'%destination)
CHUNK_SIZE = 32768
pbar = tqdm(total=content_size, unit="B", unit_scale=True)
with open(destination, "wb") as f:
for chunk in [c for c in response.iter_content(CHUNK_SIZE) if c]:
pbar.update(len(chunk))
f.write(chunk)
pbar.close()
except:
import traceback
traceback.print_exc()

torchnlpのコードを変えてunzipします。

def unzip_file(url, dir='./'):
destination = os.path.join(dir, os.path.basename(url))
extension = extension = os.path.basename(url).split('.', 1)[1]
if 'zip' in extension:
with zipfile.ZipFile(destination, "r") as f:
f.extractall(dir)
elif 'tar.gz' in extension or 'tgz' in extension:
subprocess.call(['tar', '-C', dir, '-zxvf', destination])
elif 'tar' in extension:
subprocess.call(['tar', '-C', dir, '-xvf', destination])
os.remove(destination)

今回はrgb_urlにダウンロードできるurlの一覧があったので、 リストで取得して各ダウンロードします。

def main():
rgb_url = "URL"
chpt_path = "./datasets"
rgb_dir = os.path.join(chpt_path, "train_images")
dirs = [chpt_path, rgb_dir]
_ = [os.makedirs(dir) for dir in dirs if not os.path.isdir(dir)]
for url in [u for u in requests.get(rgb_url).text.split("") if u][:1]:
download_file(url, rgb_dir)
unzip_file(url, rgb_dir)
if __name__=="__main__":
main()

追記:dataloader

本来のnvidiaのコードとは異なるディレクトリ構成なので、 dataloaderの構成を変えようと思ったら、 そもそも画像データが入ったpathかで判別していました。

IMG_EXTENSIONS = [
'.jpg', '.JPG', '.jpeg', '.JPEG',
'.png', '.PNG', '.ppm', '.PPM',
'.bmp', '.BMP', '.tiff', '.webp',
'.txt', '.json']

def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def make_grouped_dataset(dir):
for fname in sorted(sorted(os.walk(dir))):
paths = []; root = fname[0]
for f in sorted(fname[2]):
if is_image_file(f):
paths.append(os.path.join(root, f))
if len(paths) > 0:
images.append(paths)
return images

importlibによって、dataset_nameからimportするclassを選択できます。

dataset_filename = "data." + dataset_name + "_dataset"
datasetlib = importlib.import_module(dataset_filename)

importしたlibの中から、BaseDatasetを継承したカスタムデータセットのclassを見つけます。

dataset = None
for name, cls in datasetlib.__dict__.items():
if name.lower() == target_dataset_name.lower() \
and issubclass(cls, BaseDataset):
dataset = cls