odqa.py
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
.. module:: odqa
:platform: Unix
    :synopsis: the top-level submodule of Dragonfire that contains the classes related to **ODQA**: Dragonfire's Open-Domain Question Answering engine, based on DeepPavlov's SQuAD BERT model.
.. moduleauthor:: Mehmet Mert Yıldıran <[email protected]>
"""
from dragonfire.utilities import nostderr  # Context manager to suppress stderr output
import requests.exceptions  # To catch HTTP connection errors
import wikipedia  # Provides an API-like functionality to search and access Wikipedia data
import wikipedia.exceptions  # Exceptions of the wikipedia library
from deeppavlov import build_model, configs  # DeepPavlov question answering model
class ODQA:
"""Class to provide the factoid question answering ability.
"""
def __init__(self, nlp):
"""Initialization method of :class:`dragonfire.odqa.ODQA` class.
Args:
nlp: :mod:`spacy` model instance.
"""
        self.nlp = nlp  # the spaCy model instance (en_core_web_sm, English, 50 MB, the default model)
        self.model = build_model(configs.squad.squad, download=True)  # build the DeepPavlov SQuAD model, downloading the pretrained weights if necessary
def respond(self, com, tts_output=False, userin=None, user_prefix=None, is_server=False):
"""Method to respond the user's input/command using factoid question answering ability.
Args:
com (str): User's command.
Keyword Args:
tts_output (bool): Is text-to-speech output enabled?
userin: :class:`dragonfire.utilities.TextToAction` instance.
user_prefix (str): Prefix to address/call user when answering.
is_server (bool): Is Dragonfire running as an API server?
Returns:
str: Response.
.. note::
            Entry function for the :class:`ODQA` class. Dragonfire calls only this function. Unlike :func:`Learner.respond`, it executes TTS itself because of its high response latency.
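
        Example (illustrative; the actual output depends on the model and on
        the current content of Wikipedia)::

            >>> odqa.respond("Who invented the telephone?", user_prefix="sir", is_server=True)  # doctest: +SKIP
            'Alexander Graham Bell'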
"""
result = None
subject, subjects, focus, subject_with_objects = self.semantic_extractor(com) # Extract the subject, focus, objects etc.
if not subject:
return False
doc = self.nlp(com) # spaCy does all kinds of NLP analysis in one function
query = subject # Wikipedia search query (same as the subject)
# This is where the real search begins
if query: # If there is a Wikipedia query determined
if not tts_output and not is_server: print("Please wait...")
if tts_output and not is_server: userin.say("Please wait...", True, False) # Gain a few more seconds by saying Please wait...
wh_question = []
            for word in doc:  # Iterate over the words in the command (user's speech)
if word.tag_ in ['WDT', 'WP', 'WP$', 'WRB']: # if there is a wh word then
wh_question.append(word.text.upper()) # append it by converting to uppercase
if not wh_question:
return False
with nostderr():
try:
wikiresult = wikipedia.search(query) # run a Wikipedia search with the query
if len(wikiresult) == 0: # if there are no results
result = "Sorry, " + user_prefix + ". But I couldn't find anything about " + query + " in Wikipedia."
if not tts_output and not is_server: print(result)
if tts_output and not is_server: userin.say(result)
return result
wikipage = wikipedia.page(wikiresult[0])
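                    # The DeepPavlov SQuAD model takes parallel batches of contexts and
                    # questions and returns batched outputs with the answer texts first,
                    # so [0][0] picks the answer text of the first (and only) batch item.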
return self.model([wikipage.content], [com])[0][0]
except requests.exceptions.ConnectionError: # if there is a connection error
result = "Sorry, " + user_prefix + ". But I'm unable to connect to Wikipedia servers."
if not is_server:
userin.execute([" "], "Wikipedia connection error.")
if not tts_output: print(result)
if tts_output: userin.say(result)
return result
                except wikipedia.exceptions.DisambiguationError as disambiguation:  # if the search hits a disambiguation page
                    wikiresult = wikipedia.search(disambiguation.options[0])  # run the Wikipedia search again with the most common option
                    wikipage = wikipedia.page(wikiresult[0])
                    return self.model([wikipage.content], [com])[0][0]
                except Exception:  # any other failure while querying Wikipedia
                    result = "Sorry, " + user_prefix + ". But something went horribly wrong while I was searching Wikipedia."
                    if not tts_output and not is_server: print(result)
                    if tts_output and not is_server: userin.say(result)
                    return result
def phrase_cleaner(self, phrase):
"""Function to clean unnecessary words from the given phrase/string. (Punctuation mark, symbol, unknown, conjunction, determiner, subordinating or preposition and space)
Args:
phrase (str): Noun phrase.
Returns:
str: Cleaned noun phrase.
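
        Example (illustrative; the exact output depends on the spaCy model's
        part-of-speech tags)::

            >>> odqa.phrase_cleaner("the capital of France")  # doctest: +SKIP
            'capital France'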
"""
clean_phrase = []
for word in self.nlp(phrase):
if word.pos_ not in ['PUNCT', 'SYM', 'X', 'CONJ', 'DET', 'ADP', 'SPACE']:
clean_phrase.append(word.text)
return ' '.join(clean_phrase)
def semantic_extractor(self, string):
"""Function to extract subject, subjects, focus, subject_with_objects from given string.
Args:
string (str): String.
Returns:
            tuple: (the_subject, subjects, focus, subject_with_objects), or (None, None, None, None) if extraction fails.
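
        Example (illustrative; the exact output depends on the spaCy model's
        dependency parse)::

            >>> odqa.semantic_extractor("Where is the capital of France?")  # doctest: +SKIP
            ('France', ['the capital'], 'capital', 'the capital France')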
"""
doc = self.nlp(string) # spaCy does all kinds of NLP analysis in one function
the_subject = None # Wikipedia search query variable definition (the subject)
        # The following are lists because a single string can contain multiple subjects or objects
subjects = [] # subject list
pobjects = [] # object of a preposition list
dobjects = [] # direct object list
# https://nlp.stanford.edu/software/dependencies_manual.pdf - Hierarchy of typed dependencies
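        # Illustrative walk-through (the exact parse depends on the spaCy model):
        # in "Where is the capital of France?" the chunk "the capital" is an
        # nsubj and "France" is a pobj, so "France" wins the query priority below.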
for np in doc.noun_chunks: # Iterate over the noun phrases(chunks)
# print(np.text, np.root.text, np.root.dep_, np.root.head.text)
if (np.root.dep_ == 'nsubj' or np.root.dep_ == 'nsubjpass') and np.root.tag_ != 'WP': # if it's a nsubj(nominal subject) or nsubjpass(passive nominal subject) then
subjects.append(np.text) # append it to subjects
if np.root.dep_ == 'pobj': # if it's an object of a preposition then
pobjects.append(np.text) # append it to pobjects
if np.root.dep_ == 'dobj': # if it's a direct object then
dobjects.append(np.text) # append it to direct objects
# This block determines the Wikipedia query (the subject) by relying on this priority: [Object of a preposition] > [Subject] > [Direct object]
if pobjects:
the_subject = ' '.join(pobjects)
elif subjects:
the_subject = ' '.join(subjects)
elif dobjects:
the_subject = ' '.join(dobjects)
else:
return None, None, None, None
# This block determines the focus(objective/goal) by relying on this priority: [Direct object] > [Subject] > [Object of a preposition]
focus = None
if dobjects:
focus = self.phrase_cleaner(' '.join(dobjects))
elif subjects:
focus = self.phrase_cleaner(' '.join(subjects))
elif pobjects:
focus = self.phrase_cleaner(' '.join(pobjects))
        if focus in the_subject:  # drop the focus if it is already contained in the subject
focus = None
# Full string of all subjects and objects concatenated
subject_with_objects = []
for dobject in dobjects:
subject_with_objects.append(dobject)
for subject in subjects:
subject_with_objects.append(subject)
for pobject in pobjects:
subject_with_objects.append(pobject)
subject_with_objects = ' '.join(subject_with_objects)
wh_found = False
        for word in doc:  # Iterate over each word in the given command (user's speech)
            if word.tag_ in ['WDT', 'WP', 'WP$', 'WRB']:  # a command is recognized as a question only by the presence of a "wh-" word
wh_found = True
if not wh_found:
return None, None, None, None
return the_subject, subjects, focus, subject_with_objects
def check_how_odqa_performs(self):
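        """Method to benchmark the ODQA engine against the HotpotQA dev set.

        Downloads the full-wiki dev split, numbers the questions, splits them
        into ``CPU_COUNT * THREAD_MULTIPLIER`` shards and answers each shard in
        its own thread via :func:`async_performer`, then reports the overall
        success ratio and exits with a matching status code.
        """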
        import json
        import sys
import urllib.request
import random
import multiprocessing
import threading
from termcolor import colored
from dragonfire.utilities import split, s_print
HOTPOTQA_DATASET_URL = 'http://curtis.ml.cmu.edu/datasets/hotpot/hotpot_dev_fullwiki_v1.json'
SAMPLE_LENGTH = None
THREAD_MULTIPLIER = 1
CPU_COUNT = multiprocessing.cpu_count()
THREAD_COUNT = CPU_COUNT * THREAD_MULTIPLIER
response = urllib.request.urlopen(HOTPOTQA_DATASET_URL)
dataset = response.read()
if SAMPLE_LENGTH is not None:
samples = random.sample(json.loads(dataset), SAMPLE_LENGTH)
else:
samples = json.loads(dataset)
question_number = 0
for sample in samples:
question_number += 1
sample['question_number'] = question_number
print('\nThread Count: {0}\n'.format(THREAD_COUNT))
print('\nStarting to test {0} questions'.format(len(samples)))
samples_split = list(split(samples, THREAD_COUNT))
results = []
threads = []
for j in range(THREAD_COUNT):
t = threading.Thread(
target=self.async_performer,
args=(
j,
samples_split[j],
results,
s_print
)
)
t.daemon = True
t.start()
threads.append(t)
for t in threads:
t.join()
print(colored('\n(Correct, Wrong) Pairs: {0}\n'.format(results), 'yellow'))
correct_total = sum([pair[0] for pair in results])
wrong_total = sum([pair[1] for pair in results])
success = correct_total / (correct_total + wrong_total)
print(colored('\nPerformance: {0}\n'.format(success), 'yellow'))
        if success >= 0.05:
            print(colored('SUCCESS!', 'green'))
            sys.exit(0)
        else:
            print(colored('FAILURE!', 'red'))
            sys.exit(1)
def async_performer(self, thread_number, samples, results, s_print):
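        """Method to answer one shard of HotpotQA samples in a worker thread.

        Counts an answer as correct when it is a substring of the reference
        answer, then appends the thread's (correct, wrong) pair to the shared
        ``results`` list.
        """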
        s_print('\nThread {0} started.\n'.format(thread_number + 1))
from termcolor import colored
correct_counter = 0
wrong_counter = 0
for sample in samples:
out = ''
out += '\n({0})'.format(sample['question_number'])
question = sample['question']
correct_answer = sample['answer']
out += '\nQuestion: {0}'.format(question.encode('ascii', 'ignore').decode('ascii'))
out += '\nCorrect Answer: {0}'.format(correct_answer.encode('ascii', 'ignore').decode('ascii'))
            if not question or not correct_answer:
                out += colored('\nDataset contains an empty question or answer, so it\'s skipped!', 'yellow')
                s_print(out)
                continue
answer = self.respond(question, user_prefix="sir", is_server=True)
            if isinstance(answer, str):  # respond() produced an answer string
                out += '\nOur Answer: {0}'.format(answer.encode('ascii', 'ignore').decode('ascii'))
                if answer in correct_answer:  # count a substring match as correct
                    correct_counter += 1
                    out += colored('\nCORRECT', 'green')
                else:
                    wrong_counter += 1
                    out += colored('\nWRONG', 'red')
            else:  # respond() returned False, so no answer could be produced
                out += '\nOur Answer: {0}'.format(answer)
                wrong_counter += 1
                out += colored('\nWRONG', 'red')
s_print(out)
results.append((correct_counter, wrong_counter))
return True
if __name__ == '__main__':
import spacy
odqa = ODQA(spacy.load('en'))
odqa.check_how_odqa_performs()