-
Notifications
You must be signed in to change notification settings - Fork 75
/
Copy pathevaluate.py
123 lines (103 loc) · 4.48 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import numpy as np
from tqdm import tqdm
from pebble import ProcessPool
from concurrent.futures import TimeoutError
from eval.grader import *
from utils.parser import *
from utils.utils import load_jsonl
from utils.python_executor import PythonExecutor
def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, max_num_samples=None, execute=False):
    """Score predictions against ground truth and return a text summary.

    Samples are taken from ``samples`` or, when that is empty, loaded from
    ``file_path`` with ``load_jsonl``. Each sample gets ``gt_cot``/``gt``
    (from ``parse_ground_truth``) and ``score`` (one entry per prediction)
    written into it. When the first sample has no ``pred`` key, or
    ``execute`` is true, every entry of ``sample['code']`` is run through
    ``run_execute`` to (re)build ``pred`` and ``report``.

    Args:
        data_name: Dataset name forwarded to ``parse_ground_truth``.
        prompt_type: Prompt style; if it contains ``"pal"`` the executor reads
            the answer from ``solution()``, otherwise from stdout.
        samples: Optional list of sample dicts.
        file_path: JSONL file to load when ``samples`` is not given.
        max_num_samples: If truthy, only the first N samples are evaluated.
        execute: Force re-execution of ``sample['code']`` even if predictions
            are already present.

    Returns:
        The multi-line summary string that is also printed.

    Raises:
        AssertionError: If neither ``samples`` nor ``file_path`` is provided.
        SystemExit: If grading a prediction fails with an unexpected error.
    """
    assert samples or file_path, "samples or file_path must be provided"
    if not samples:
        samples = list(load_jsonl(file_path))

    # dedup by idx (the last occurrence of an idx wins), then order by idx
    if 'idx' in samples[0]:
        samples = {sample['idx']: sample for sample in samples}.values()
        samples = sorted(samples, key=lambda x: x['idx'])
    else:
        samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)]

    if max_num_samples:
        print(f"max_num_samples: {max_num_samples} / {len(samples)}")
        samples = samples[:max_num_samples]

    # parse gt
    for sample in samples:
        sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name)

    # execute
    if ('pred' not in samples[0]) or execute:
        if "pal" in prompt_type:
            executor = PythonExecutor(get_answer_expr="solution()")
        else:
            executor = PythonExecutor(get_answer_from_stdout=True)

        for sample in tqdm(samples, desc="Execute"):
            sample['pred'] = []
            sample['report'] = []
            for code in sample['code']:
                pred, report = run_execute(executor, code, prompt_type, execute=True)
                sample['pred'].append(pred)
                sample['report'].append(report)

    # One grading task per (sample, prediction) pair, in sample order.
    params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']]

    scores = []
    timeout_cnt = 0

    with ProcessPool() as pool:
        future = pool.map(math_equal_process, params, timeout=3)
        iterator = future.result()
        # The bar advances once per grading task, so its total is the number
        # of tasks rather than the number of samples.
        with tqdm(total=len(params), desc="Evaluate") as progress_bar:
            while True:
                try:
                    result = next(iterator)
                    scores.append(result)
                except StopIteration:
                    break
                except TimeoutError as error:
                    # A grading task that times out counts as a wrong answer.
                    print(error)
                    scores.append(False)
                    timeout_cnt += 1
                except Exception as error:
                    # Not every exception carries a ``traceback`` attribute;
                    # fall back to the exception itself so the real error is
                    # never masked by an AttributeError.
                    print(getattr(error, "traceback", error))
                    # Abort with a failure status instead of the
                    # site-provided ``exit()`` (which reports success).
                    raise SystemExit(1) from error
                progress_bar.update(1)

    # Slice the flat score list back into one row per sample.
    idx = 0
    score_mat = []
    for sample in samples:
        sample['score'] = scores[idx: idx+len(sample['pred'])]
        assert len(sample['score']) == len(sample['pred'])
        score_mat.append(sample['score'])
        idx += len(sample['pred'])

    max_len = max(len(s) for s in score_mat)

    # Pad short rows by repeating their last score so every column mean is
    # taken over all samples; a sample with no predictions is padded as wrong.
    for i, s in enumerate(score_mat):
        if len(s) < max_len:
            fill = s[-1] if s else False
            score_mat[i] = s + [fill] * (max_len - len(s))

    # output mean of each column of scores
    col_means = np.array(score_mat).mean(axis=0)
    mean_score = list(np.round(col_means * 100, decimals=1))

    # A sample is "empty" when it has no prediction or its last one is falsy.
    empty_cnt = len([s for s in samples if not s['pred'] or not s['pred'][-1]])

    result_str = f"Num samples: {len(samples)}\n" \
        f"Num scores: {len(scores)}\n" \
        f"Timeout samples: {timeout_cnt}\n" \
        f"Empty samples: {empty_cnt}\n" \
        f"Mean score: {mean_score}\n"

    # each type score (based on the last prediction of every sample)
    if "type" in samples[0]:
        type_scores = {}
        for sample in samples:
            last_score = sample['score'][-1] if sample['score'] else False
            type_scores.setdefault(sample['type'], []).append(last_score)
        type_scores = {k: np.round(np.array(v).mean() * 100, decimals=1) for k, v in type_scores.items()}
        type_scores = {k: v for k, v in sorted(type_scores.items(), key=lambda item: item[0])}
        result_str += f"Type scores: {type_scores}\n"

    print(result_str)
    return result_str
def parse_args():
    """Parse the evaluation script's command-line options.

    Returns:
        The ``argparse.Namespace`` with ``data_name``, ``prompt_type``,
        ``file_path`` (required), ``max_num_samples`` and ``execute``.
    """
    cli = argparse.ArgumentParser()

    # Plain string options that only differ in flag name and default value.
    string_options = (
        ("--data_name", "math"),
        ("--prompt_type", "tora"),
    )
    for flag, fallback in string_options:
        cli.add_argument(flag, type=str, default=fallback)

    cli.add_argument("--file_path", type=str, default=None, required=True)
    cli.add_argument("--max_num_samples", type=int, default=None)
    cli.add_argument("--execute", action="store_true")

    return cli.parse_args()
# Script entry point: evaluate the predictions file named on the command line.
if __name__ == "__main__":
    args = parse_args()
    evaluate(
        args.data_name,
        args.prompt_type,
        file_path=args.file_path,
        max_num_samples=args.max_num_samples,
        execute=args.execute,
    )