Skip to content

Commit d1f16cf

Browse files
authored
Accelerate get_infos by caching the DatasetInfosDicts (#778)
* accelerate `get_infos` by caching the `DatasetInfosDict`s * quality * consistency
1 parent f5c3977 commit d1f16cf

File tree

2 files changed

+29
-8
lines changed

2 files changed

+29
-8
lines changed

Diff for: promptsource/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
DEFAULT_PROMPTSOURCE_CACHE_HOME = "~/.cache/promptsource"
1+
from pathlib import Path
2+
3+
4+
DEFAULT_PROMPTSOURCE_CACHE_HOME = str(Path("~/.cache/promptsource").expanduser())

Diff for: promptsource/app.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
import argparse
22
import functools
33
import multiprocessing
4+
import os
45
import textwrap
6+
from hashlib import sha256
57
from multiprocessing import Manager, Pool
68

79
import pandas as pd
810
import plotly.express as px
911
import streamlit as st
1012
from datasets import get_dataset_infos
13+
from datasets.info import DatasetInfosDict
1114
from pygments import highlight
1215
from pygments.formatters import HtmlFormatter
1316
from pygments.lexers import DjangoLexer
14-
from templates import INCLUDED_USERS
1517

18+
from promptsource import DEFAULT_PROMPTSOURCE_CACHE_HOME
1619
from promptsource.session import _get_state
17-
from promptsource.templates import DatasetTemplates, Template, TemplateCollection
20+
from promptsource.templates import INCLUDED_USERS, DatasetTemplates, Template, TemplateCollection
1821
from promptsource.utils import (
1922
get_dataset,
2023
get_dataset_confs,
@@ -25,6 +28,9 @@
2528
)
2629

2730

31+
DATASET_INFOS_CACHE_DIR = os.path.join(DEFAULT_PROMPTSOURCE_CACHE_HOME, "DATASET_INFOS")
32+
os.makedirs(DATASET_INFOS_CACHE_DIR, exist_ok=True)
33+
2834
# Python 3.8 switched the default start method from fork to spawn. OS X also has
2935
# some issues related to fork, see, e.g., https://github.com./bigscience-workshop/promptsource/issues/572
3036
# so we make sure we always use spawn for consistency
@@ -38,7 +44,17 @@ def get_infos(all_infos, d_name):
3844
:param all_infos: multiprocess-safe dictionary
3945
:param d_name: dataset name
4046
"""
41-
all_infos[d_name] = get_dataset_infos(d_name)
47+
d_name_bytes = d_name.encode("utf-8")
48+
d_name_hash = sha256(d_name_bytes)
49+
foldername = os.path.join(DATASET_INFOS_CACHE_DIR, d_name_hash.hexdigest())
50+
if os.path.isdir(foldername):
51+
infos_dict = DatasetInfosDict.from_directory(foldername)
52+
else:
53+
infos = get_dataset_infos(d_name)
54+
infos_dict = DatasetInfosDict(infos)
55+
os.makedirs(foldername)
56+
infos_dict.write_to_directory(foldername)
57+
all_infos[d_name] = infos_dict
4258

4359

4460
# add an argument for read-only
@@ -181,11 +197,13 @@ def show_text(t, width=WIDTH, with_markdown=False):
181197
else:
182198
subset_infos = infos[subset_name]
183199

184-
split_sizes = {k: v.num_examples for k, v in subset_infos.splits.items()}
200+
try:
201+
split_sizes = {k: v.num_examples for k, v in subset_infos.splits.items()}
202+
except Exception:
203+
# Fixing bug in some community datasets.
204+
# For simplicity, just filling `split_sizes` with nothing, so the displayed split sizes will be 0.
205+
split_sizes = {}
185206
else:
186-
# Zaid/coqa_expanded and Zaid/quac_expanded don't have dataset_infos.json
187-
# so infos is an empty dic, and `infos[list(infos.keys())[0]]` raises an error
188-
# For simplicity, just filling `split_sizes` with nothing, so the displayed split sizes will be 0.
189207
split_sizes = {}
190208

191209
# Collect template counts, original task counts and names

0 commit comments

Comments
 (0)