# Copyright 2026 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Creates the datasets for FFF-NER model.""" import collections import json import math import os import sys import numpy as np import tensorflow as tf, tf_keras from tqdm import tqdm import transformers class NERDataset: """A Named Entity Recognition dataset for FFF-NER model.""" def __init__(self, words_path, labels_path, tokenizer, is_train, label_to_entity_type_index, ablation_not_mask, ablation_no_brackets, ablation_span_type_together): """Instantiates the class. Args: words_path: Path to the .words file that contains the text. labels_path: Path to the .ner file that contains NER labels for the text. tokenizer: A huggingface tokenizer. is_train: If creating a dataset for training, otherwise testing. label_to_entity_type_index: A mapping of NER labels to indices. ablation_not_mask: An ablation experiment that does not use mask tokens. ablation_no_brackets: An ablation experiment that does not use brackets. ablation_span_type_together: An ablation experiment that does span and type prediction together at a single token. """ self.words_path = words_path self.labels_path = labels_path self.tokenizer = tokenizer self.is_train = is_train self.label_to_entity_type_index = label_to_entity_type_index self.ablation_no_brackets = ablation_no_brackets self.ablation_span_type_together = ablation_span_type_together self.ablation_not_mask = ablation_not_mask self.left_bracket = self.tokenize_word(" [")[0] self.right_bracket = self.tokenize_word(" ]")[0] self.mask_id = self.tokenizer.mask_token_id self.cls_token_id = self.tokenizer.cls_token_id self.sep_token_id = self.tokenizer.sep_token_id self.data = [] self.id_to_sentence_infos = dict() self.id_counter = 0 self.all_tokens = [] self.all_labels = [] self.max_seq_len_in_data = 0 self.max_len = 128 def read_file(self): """Reads the input files from words_path and labels_paths.""" with open(self.words_path) as f1, open(self.labels_path) as f2: for _, (l1, l2) in enumerate(zip(f1, f2)): tokens = l1.strip().split(" ") labels = l2.strip().split(" ") # since we are use [ and ], we replace all [, ] in the text with (, ) tokens = ["(" if token == "[" else token for token in tokens] tokens = [")" if token == "]" else token for token in tokens] yield tokens, labels def tokenize_word(self, word): """Calls the tokenizer to produce word ids from text.""" result = self.tokenizer(word, add_special_tokens=False) return result["input_ids"] def tokenize_word_list(self, word_list): return [self.tokenize_word(word) for word in word_list] def process_to_input(self, input_ids, is_entity_token_pos, entity_type_token_pos, is_entity_label, entity_type_label, sid, span_start, span_end): """Process and store sentence and span id information.""" self.id_counter += 1 self.id_to_sentence_infos[self.id_counter] = { "sid": sid, # sentence id "span_start": span_start, "span_end": span_end, } seqlen = len(input_ids) self.max_seq_len_in_data = max(self.max_seq_len_in_data, seqlen) return { "input_ids": input_ids, "attention_mask": [1] * seqlen, "is_entity_token_pos": is_entity_token_pos, "entity_type_token_pos": entity_type_token_pos, "is_entity_label": 1 if is_entity_label else 0, "entity_type_label": entity_type_label, "sentence_id": sid, "span_start": span_start, "span_end": span_end, "id": self.id_counter, } def process_word_list_and_spans_to_inputs(self, sid, word_list, spans): """Constructs the fffner input with spans and types.""" tokenized_word_list = self.tokenize_word_list(word_list) final_len = sum(len(x) for x in tokenized_word_list) final_len = 2 + 3 + 2 + 3 + final_len # account for mask and brackets if final_len > self.max_len: print(f"final_len {final_len} too long, skipping") return for span_start, span_end, span_type, span_label in spans: assert span_type == "mask" input_ids = [] input_ids.append(self.cls_token_id) for ids in tokenized_word_list[:span_start]: input_ids.extend(ids) if not self.ablation_span_type_together: if not self.ablation_no_brackets: input_ids.append(self.left_bracket) is_entity_token_pos = len(input_ids) input_ids.append(self.mask_id if not self.ablation_not_mask else 8487) if not self.ablation_no_brackets: input_ids.append(self.right_bracket) if not self.ablation_no_brackets: input_ids.append(self.left_bracket) for ids in tokenized_word_list[span_start:span_end + 1]: input_ids.extend(ids) if not self.ablation_no_brackets: input_ids.append(self.right_bracket) if not self.ablation_no_brackets: input_ids.append(self.left_bracket) entity_type_token_pos = len(input_ids) if self.ablation_span_type_together: is_entity_token_pos = len(input_ids) input_ids.append(self.mask_id if not self.ablation_not_mask else 2828) if not self.ablation_no_brackets: input_ids.append(self.right_bracket) for ids in tokenized_word_list[span_end + 1:]: input_ids.extend(ids) input_ids.append(self.sep_token_id) is_entity_label = span_label in self.label_to_entity_type_index entity_type_label = self.label_to_entity_type_index.get(span_label, 0) yield self.process_to_input(input_ids, is_entity_token_pos, entity_type_token_pos, is_entity_label, entity_type_label, sid, span_start, span_end) def bio_labels_to_spans(self, bio_labels): """Gets labels to spans.""" spans = [] for i, label in enumerate(bio_labels): if label.startswith("B-"): spans.append([i, i, label[2:]]) elif label.startswith("I-"): if spans: print("Error... I-tag should not start a span") spans.append([i, i, label[2:]]) elif spans[-1][1] != i - 1 or spans[-1][2] != label[2:]: print("Error... I-tag not consistent with previous tag") spans.append([i, i, label[2:]]) else: spans[-1][1] = i elif label.startswith("O"): pass else: assert False, bio_labels spans = list( filter(lambda x: x[2] in self.label_to_entity_type_index.keys(), spans)) return spans def collate_fn(self, batch): batch = self.tokenizer.pad( batch, padding="max_length", max_length=self.max_len, ) return batch def prepare(self, negative_multiplier=3.): """Constructs negative sampling and handling train/test differences.""" desc = ("prepare data for training" if self.is_train else "prepare data for testing") total_missed_entities = 0 total_entities = 0 for sid, (tokens, labels) in tqdm(enumerate(self.read_file()), desc=desc): self.all_tokens.append(tokens) self.all_labels.append(labels) entity_spans = self.bio_labels_to_spans(labels) entity_spans_dict = { (start, end): ent_type for start, end, ent_type in entity_spans } num_entities = len(entity_spans_dict) num_negatives = int( (len(tokens) + num_entities * 10) * negative_multiplier) num_negatives = min(num_negatives, len(tokens) * (len(tokens) + 1) // 2) min_words = 1 max_words = len(tokens) total_entities += len(entity_spans) spans = [] if self.is_train: is_token_entity_prefix = [0] * (len(tokens) + 1) for start, end, _ in entity_spans: for i in range(start, end + 1): is_token_entity_prefix[i + 1] = 1 for i in range(len(tokens)): is_token_entity_prefix[i + 1] += is_token_entity_prefix[i] negative_spans = [] negative_spans_probs = [] for n_words in range(min_words, max_words + 1): for i in range(len(tokens) - n_words + 1): j = i + n_words - 1 ent_type = entity_spans_dict.get((i, j), "O") if not self.is_train or ent_type != "O": spans.append((i, j, "mask", ent_type)) else: negative_spans.append((i, j, "mask", ent_type)) intersection_size = (is_token_entity_prefix[j + 1] - is_token_entity_prefix[i] + 1) / ( j + 1 - i) negative_spans_probs.append(math.e**intersection_size) if negative_spans and num_negatives > 0: negative_spans_probs = np.array(negative_spans_probs) / np.sum( negative_spans_probs) negative_span_indices = np.random.choice( len(negative_spans), num_negatives, replace=True, p=negative_spans_probs) spans.extend([negative_spans[x] for x in negative_span_indices]) else: for n_words in range(min_words, max_words + 1): for i in range(len(tokens) - n_words + 1): j = i + n_words - 1 ent_type = entity_spans_dict.get((i, j), "O") spans.append((i, j, "mask", ent_type)) for instance in self.process_word_list_and_spans_to_inputs( sid, tokens, spans): self.data.append(instance) print(f"{total_missed_entities}/{total_entities} are ignored due to length") print(f"Total {self.__len__()} instances") def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] if __name__ == "__main__": path_to_data_folder = sys.argv[1] dataset_name = sys.argv[2] train_file = sys.argv[3] dataset = os.path.join(path_to_data_folder, dataset_name) test_file = "test" _tokenizer = transformers.AutoTokenizer.from_pretrained("bert-base-uncased") entity_map = json.load(open(os.path.join(dataset, "entity_map.json"))) _label_to_entity_type_index = { k: i for i, k in enumerate(list(entity_map.keys())) } train_ds = NERDataset( words_path=os.path.join(dataset, train_file + ".words"), labels_path=os.path.join(dataset, train_file + ".ner"), tokenizer=_tokenizer, is_train=True, ablation_not_mask=False, ablation_no_brackets=False, ablation_span_type_together=False, label_to_entity_type_index=_label_to_entity_type_index) eval_ds = NERDataset( words_path=os.path.join(dataset, test_file + ".words"), labels_path=os.path.join(dataset, test_file + ".ner"), tokenizer=_tokenizer, is_train=False, ablation_not_mask=False, ablation_no_brackets=False, ablation_span_type_together=False, label_to_entity_type_index=_label_to_entity_type_index) train_ds.prepare(negative_multiplier=3) train_data = train_ds.collate_fn(train_ds.data) eval_ds.prepare(negative_multiplier=3) eval_data = eval_ds.collate_fn(eval_ds.data) def file_based_convert_examples_to_features(examples, output_file): """Convert a set of `InputExample`s to a TFRecord file.""" tf.io.gfile.makedirs(os.path.dirname(output_file)) writer = tf.io.TFRecordWriter(output_file) for ex_index in range(len(examples["input_ids"])): if ex_index % 10000 == 0: print(f"Writing example {ex_index} of {len(examples['input_ids'])}") print(examples["input_ids"][ex_index]) def create_int_feature(values): f = tf.train.Feature(int64_list=tf.train.Int64List(value=list(values))) return f features = collections.OrderedDict() features["input_ids"] = create_int_feature( examples["input_ids"][ex_index]) features["input_mask"] = create_int_feature( examples["attention_mask"][ex_index]) features["segment_ids"] = create_int_feature( [0] * len(examples["attention_mask"][ex_index])) features["is_entity_token_pos"] = create_int_feature( [examples["is_entity_token_pos"][ex_index]]) features["entity_type_token_pos"] = create_int_feature( [examples["entity_type_token_pos"][ex_index]]) features["is_entity_label"] = create_int_feature( [examples["is_entity_label"][ex_index]]) features["entity_type_label"] = create_int_feature( [examples["entity_type_label"][ex_index]]) features["example_id"] = create_int_feature([examples["id"][ex_index]]) features["sentence_id"] = create_int_feature( [examples["sentence_id"][ex_index]]) features["span_start"] = create_int_feature( [examples["span_start"][ex_index]]) features["span_end"] = create_int_feature( [examples["span_end"][ex_index]]) tf_example = tf.train.Example( features=tf.train.Features(feature=features)) writer.write(tf_example.SerializeToString()) writer.close() file_based_convert_examples_to_features( train_data, f"{dataset_name}_{train_file}.tf_record") file_based_convert_examples_to_features( eval_data, f"{dataset_name}_{test_file}.tf_record")