Skip to content

Commit ab294bd

Browse files
iftenneyLIT team
authored and
LIT team
committed
Utils and helpers for sequence salience, most notably token grouping code.
PiperOrigin-RevId: 606346156
1 parent 27e6901 commit ab294bd

File tree

8 files changed

+243
-4
lines changed

8 files changed

+243
-4
lines changed

lit_nlp/client/elements/tooltip.css

+8
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
* with tooltip positioning.
1111
*/
1212
--anchor-display-mode: inline-block;
13+
--tooltip-position-left: unset;
14+
--tooltip-position-right: unset;
15+
--tooltip-position-top: unset;
16+
--tooltip-position-bottom: unset;
1317
}
1418

1519
/* Tooltip */
@@ -34,6 +38,10 @@
3438
font-size: 12px;
3539
font-weight: normal;
3640
line-height: 16px;
41+
left: var(--tooltip-position-left);
42+
right: var(--tooltip-position-right);
43+
top: var(--tooltip-position-top);
44+
bottom: var(--tooltip-position-bottom);
3745

3846
display: -webkit-box;
3947
-webkit-line-clamp: 6;

lit_nlp/client/elements/tooltip.ts

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ export class LitTooltip extends ReactiveElement {
7171
'disabled': this.disabled,
7272
});
7373

74+
// prettier-ignore
7475
return html`<div class='lit-tooltip'>
7576
<slot name="tooltip-anchor">
7677
${this.content === '' ? '' : html`

lit_nlp/client/lib/token_utils.ts

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/**
2+
* @fileoverview Utils for working with tokenized text.
3+
*/
4+
5+
/**
6+
* Evil underscore used by sentencepiece to replace spaces.
7+
*/
8+
export const SPM_SPACE_SENTINEL = '▁';
9+
10+
/**
11+
* Clean SPM text to make it more human-readable.
12+
*/
13+
export function cleanSpmText(text: string): string {
14+
return text.replaceAll(SPM_SPACE_SENTINEL, ' ');
15+
}
16+
17+
/**
18+
* Use a regex to match segment prefixes. The prefix and anything
19+
* following it (until the next match) are treated as one segment.
20+
*/
21+
export function groupTokensByRegexPrefix(
22+
tokens: string[],
23+
matcher: RegExp,
24+
): string[][] {
25+
const text = tokens.join('');
26+
const matches = [...text.matchAll(matcher)];
27+
28+
let textCharOffset = 0; // chars into text
29+
let matchIdx = 0; // indices into matches
30+
const groups: string[][] = [];
31+
let acc: string[] = [];
32+
for (let i = 0; i < tokens.length; i++) {
33+
const token = tokens[i];
34+
const nextMatch = matches[matchIdx];
35+
36+
// Look ahead to see if this token intrudes on a match.
37+
// If so, start a new segment before pushing the token.
38+
if (nextMatch !== undefined &&
39+
textCharOffset + token.length > nextMatch.index!) {
40+
// Don't push an empty group if the first token is part of a match.
41+
if (acc.length > 0 || groups.length > 0) groups.push(acc);
42+
acc = [];
43+
matchIdx += 1;
44+
}
45+
46+
// Push the token.
47+
acc.push(token);
48+
textCharOffset += token.length;
49+
}
50+
// Finally, push any open group.
51+
if (acc.length > 0) groups.push(acc);
52+
return groups;
53+
}
+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/**
 * Testing for token_utils.ts
 */

import 'jasmine';

import * as tokenUtils from './token_utils';

describe('cleanSpmText test', () => {
  it('cleans magic underscores from SPM output', () => {
    // Newlines must be preserved; only the '▁' sentinel becomes a space.
    const text = 'Summarize▁this▁sentence:\n\nOnce▁upon▁a▁time';
    expect(tokenUtils.cleanSpmText(text))
        .toEqual('Summarize this sentence:\n\nOnce upon a time');
  });
});

describe('groupTokensByRegexPrefix test', () => {
  // Parameterized cases: each entry supplies raw tokens, a regex whose
  // matches mark segment starts, and the expected grouping.
  // NOTE(review): several regexes rely on lookbehind `(?<=...)` — confirm
  // the supported runtime baseline includes lookbehind assertions.
  // NOTE(review): the word-grouping cases expect '▁this' to start a new
  // group; verify the character classes below match the '▁' sentinel as
  // intended (this page capture may have dropped characters).
  [{
    testcaseName: 'groups tokens by word',
    tokens: ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'],
    regex: /[\s]+/g,
    expectedGroups: [['Sum', 'mar', 'ize'], ['▁this'], ['▁sent', 'ence', ':']],
  },
  {
    testcaseName: 'groups tokens by word, handling newlines',
    tokens: [
      'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once',
      '▁upon', '▁a', '▁time'
    ],
    // Consecutive newlines should be their own segment.
    // Start a new word on the first non-\n afterwards.
    regex: /([\s]+)|(?<=\n)[^\n]/g,
    expectedGroups: [
      ['Sum', 'mar', 'ize'], ['▁this'], ['▁sent', 'ence', ':'], ['\n', '\n'],
      ['Once'], ['▁upon'], ['▁a'], ['▁time']
    ],
  },
  {
    testcaseName: 'groups tokens by sentence, simple version',
    tokens: [
      'Sent', 'ence', '▁one', '.', '▁Sent', 'ence', '▁two', '!', '▁Sent',
      'ence', '▁three', '?'
    ],
    // Whitespace following sentence-ending punctuation starts a segment.
    regex: /(?<=[.?!])[\s]+/g,
    expectedGroups: [
      ['Sent', 'ence', '▁one', '.'],
      ['▁Sent', 'ence', '▁two', '!'],
      ['▁Sent', 'ence', '▁three', '?'],
    ],
  },
  {
    testcaseName: 'groups tokens by sentence, handling newlines',
    tokens: [
      'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once',
      '▁upon', '▁a', '▁time'
    ],
    // Sentence start is one of:
    // - a run of consecutive \n as its own segment
    // - any non-\n following \n
    // - whitespace or magic underscore following punctuation [.?!]
    regex: /(\n+)|((?<=\n)[^\n])|((?<=[.?!])([\s]+))/g,
    expectedGroups: [
      ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], ['\n', '\n'],
      ['Once', '▁upon', '▁a', '▁time']
    ],
  },
  {
    testcaseName: 'groups tokens by line',
    tokens: [
      'Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':', '\n', '\n', 'Once',
      '▁upon', '▁a', '▁time'
    ],
    // Line start is either:
    // - a run of consecutive \n as its own segment
    // - any non-\n following \n
    regex: /(\n+)|([^\n]+)/g,
    expectedGroups: [
      ['Sum', 'mar', 'ize', '▁this', '▁sent', 'ence', ':'], ['\n', '\n'],
      ['Once', '▁upon', '▁a', '▁time']
    ],
  },
  ].forEach(({testcaseName, tokens, regex, expectedGroups}) => {
    it(testcaseName, () => {
      const groups = tokenUtils.groupTokensByRegexPrefix(tokens, regex);
      expect(groups).toEqual(expectedGroups);
    });
  });
});

lit_nlp/client/lib/utils.ts

+21
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,27 @@ export function cumSumArray(array: number[]) {
302302
return newArray;
303303
}
304304

305+
/**
306+
* Group elements of one list to match the partitions of another.
307+
*
308+
* Example:
309+
* groupAlike([0, 1, 2, 3, 4, 5], [['a', 'b'], ['c'], ['d', 'e', 'f']])
310+
*
311+
* Should return: [[0, 1], [2], [3, 4, 5]]
312+
*/
313+
export function groupAlike<T>(items: T[], groups: unknown[][]): T[][] {
314+
const offsets = [0, ...cumSumArray(groups.map(g => g.length))];
315+
if (offsets.at(-1) !== items.length) {
316+
throw new Error(`Total length of groups (${
317+
offsets.at(-1)}) !== number of items (${items.length}).`);
318+
}
319+
const ret = [];
320+
for (let i = 0; i < groups.length; i++) {
321+
ret.push(items.slice(offsets[i], offsets[i + 1]));
322+
}
323+
return ret;
324+
}
325+
305326
/**
306327
* Python-style array comparison.
307328
* Compare on first element, then second, and so on until a mismatch is found.

lit_nlp/client/lib/utils_test.ts

+13
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,19 @@ describe('cumSumArray test', () => {
436436
});
437437
});
438438

439+
// Tests for utils.groupAlike: partition-matching and length validation.
describe('groupAlike test', () => {
  it('groups items', () => {
    const result = utils.groupAlike(
        [0, 1, 2, 3, 4, 5], [['a', 'b'], ['c'], ['d', 'e', 'f']]);
    expect(result).toEqual([[0, 1], [2], [3, 4, 5]]);
  });
  it('raises an error if lengths do not match', () => {
    // The message text must match the Error thrown by groupAlike exactly.
    expect(() => utils.groupAlike([0, 1, 2, 3, 4, 5], [['a', 'b'], ['c']]))
        .toThrow(
            new Error('Total length of groups (3) !== number of items (6).'));
  });
});
451+
439452
describe('compareArrays test', () => {
440453
it('Correctly tests normal comparison', () => {
441454
// Shorter arrays.

lit_nlp/lib/utils.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,20 @@ def unbatch_preds(
198198
yield {key: value[i] for key, value in preds.items()}
199199

200200

201-
def pad1d(arr: list[T], min_len: int, pad_val: T) -> list[T]:
201+
def pad1d(
202+
arr: list[T],
203+
min_len: int,
204+
pad_val: T,
205+
pad_left: bool = False,
206+
max_len: int | None = None,
207+
) -> list[T]:
202208
"""Pad a list to the target length."""
203-
return arr + [pad_val] * max(0, min_len - len(arr))
209+
if pad_left:
210+
padded = [pad_val] * max(0, min_len - len(arr)) + arr
211+
return padded[-max_len:] if max_len is not None else padded
212+
else:
213+
padded = arr + [pad_val] * max(0, min_len - len(arr))
214+
return padded[:max_len] if max_len is not None else padded
204215

205216

206217
def find_all_combinations(

lit_nlp/lib/utils_test.py

+46-2
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,55 @@ def test_batch_inputs_raises(
252252
pad_val="",
253253
expected=["one", "two", "three", "", ""],
254254
),
255+
dict(
256+
testcase_name="truncate_max_len",
257+
inputs=[1, 2, 3, 4, 5],
258+
min_len=3,
259+
pad_val=0,
260+
max_len=3,
261+
expected=[1, 2, 3],
262+
),
263+
dict(
264+
testcase_name="pad_left",
265+
inputs=[1, 2, 3],
266+
min_len=5,
267+
pad_val=0,
268+
pad_left=True,
269+
expected=[0, 0, 1, 2, 3],
270+
),
271+
dict(
272+
testcase_name="truncate_max_len_left",
273+
inputs=[1, 2, 3, 4, 5],
274+
min_len=3,
275+
pad_val=0,
276+
pad_left=True,
277+
max_len=3,
278+
expected=[3, 4, 5],
279+
),
280+
dict(
281+
testcase_name="pad_left_with_strings",
282+
inputs=["one", "two", "three"],
283+
min_len=5,
284+
pad_val="",
285+
pad_left=True,
286+
expected=["", "", "one", "two", "three"],
287+
),
255288
)
256289
  def test_pad1d(
      self,
      inputs: list[T],
      min_len: int,
      pad_val: T,
      expected: list[T],
      pad_left: bool = False,
      max_len: int | None = None,
  ):
    """Checks pad1d output for each named-parameters case above."""
    self.assertEqual(
        utils.pad1d(
            inputs, min_len, pad_val, pad_left=pad_left, max_len=max_len
        ),
        expected,
    )
260304

261305
@parameterized.named_parameters(
262306
dict(

0 commit comments

Comments
 (0)