In [1]:
from typing import List, Optional, Any
import torch
import torch.utils.data
import numpy as np
import random

import datetime

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import T5Tokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

from peft import LoraConfig
from trl import SFTTrainer
import datasets
import time

import json
import os
import cqlcmp

In [2]:
torch.manual_seed(21)
torch.cuda.manual_seed_all(21)
np.random.seed(21)
random.seed(21)
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False

In [3]:
device = torch.device("cuda")

In [4]:
class DatasetNatural2CQL(torch.utils.data.Dataset):
    def __init__(self, path: Optional[str] = None) -> None:
        self.sentence_freq = []
        self.cql2nl = []
        self.nl2cql = []
        self.natural_language_rulebased = []
        self.cql = []
        self.natural_language = []
        self.enabled_natural_language = []

        if path is not None:
            self.load_tsv(path)

    def enable_cql(self, cqls):
        self.enabled_natural_language = []
        for cql in cqls:
            for p in cql:
                self.enabled_natural_language.append(p)

    def dump_json(self, filepath: str) -> None:
        with open(filepath, "w") as file:
            for i in range(len(self)):
                data = json.dumps(self[i])
                file.write(data)
                file.write("\n")

    def add_translation(self, freq: int, cql: str, natural_language_rulebased: str, natural_language: List[str]) -> None:
        cql_index = len(self.sentence_freq)
        self.sentence_freq.append(freq)
        self.cql.append(cql)
        self.natural_language_rulebased.append(natural_language_rulebased)
        self.cql2nl.append([])

        for sentence in natural_language:
            self.nl2cql.append(cql_index)
            self.cql2nl[-1].append(len(self.natural_language))
            self.natural_language.append(sentence)

    def load_tsv(self, path: str) -> None:
        with open(path, "r") as file_data:
            for line in file_data:
                line = line.strip()
                line = line.split("\t")
                texts_json = json.loads(line[4])
                texts_extracted = texts_json["data"][0]["content"][0]["text"]["value"].split("\n")
                self.add_translation(int(line[0]), line[2], line[3], texts_extracted)

    def __len__(self):
        return len(self.enabled_natural_language)

    def __getitem__(self, idx):
        if idx < len(self.nl2cql):
            return {"text": self.natural_language[self.enabled_natural_language[idx]], "cql": self.cql[self.nl2cql[self.enabled_natural_language[idx]]]}
        return None

In [5]:
dataset = DatasetNatural2CQL("expand_natural_texts_0004.res.tsv")

In [6]:
validation_cqls = []
with open("test_ids.json", "r") as file:
    validation_cqls = json.load(file)

In [7]:
dataset.enable_cql(validation_cqls)

In [8]:
len(dataset)

4280

In [9]:
model_name = "google-t5/t5-base"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name, device_map="auto")

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [10]:
import os

path = 'models/google-t5_t5-base'

# List all folders (directories) in the given path
folders = [name for name in os.listdir(path) if not os.path.isdir(os.path.join(path, name))]

In [11]:
def f_in(data):
    full_input = "translate: " + data
    return full_input

In [12]:
#random.shuffle(dataset.enabled_natural_language)

i = 0
while i < len(dataset):
    dataset.natural_language[dataset.enabled_natural_language[i]] = f_in(dataset.natural_language[dataset.enabled_natural_language[i]])
    i += 1

In [13]:
def get_text(text):
    text = text.split("translate: ", 1)[1]
    return text

In [14]:
with open("test_results.tsv", "w") as test_tsv:
    with open("test_log.txt", "w") as log:
        for folder in folders:
            start_time = int(time.time())
            adapter_model_name = 'models/google-t5_t5-base/' + folder

            if folder != "model_train_data_20_final.pt":
                continue
            
            model.load_state_dict(torch.load(adapter_model_name, weights_only=True))
            model.eval()
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False)
            blue_sum = 0
            blue_size = 0
            batch_id = 0
            for batch in dataloader:
                batch_id += 1
                inputs = tokenizer(batch["text"], truncation=True, max_length=1024, return_tensors="pt", padding=True).to(device)
                outputs = model.generate(**inputs, max_new_tokens=100)
                for i, output in enumerate(outputs):
                    cql = tokenizer.decode(output, skip_special_tokens=True).replace("<extra_id_0>", "")
                    cql = "[" + cql
                    cql_ids = cqlcmp.cql_tokenizer(cql)
                    gold_cql_ids = cqlcmp.cql_tokenizer(batch["cql"][i])
                    bleu = cqlcmp.sentence_bleu([gold_cql_ids], cql_ids, weights=(0.25, 0.25, 0.25, 0.25))
                    blue_sum += bleu
                    blue_size += 1
                    test_result = [cql.replace("\n", " "), batch["cql"][i], get_text(batch["text"][i])]
                    print("\t".join(test_result))
                    print("\t".join(test_result), file=test_tsv)
                to_log = ""
                to_log += "Working on: " + adapter_model_name + " | "
                to_log += str(blue_size) + "/" + str(len(dataset)) + " | "
                to_log += "AVG Bleu: " + str(blue_sum/blue_size) + " | "
                to_log += "Time: " + str(int(time.time())-start_time) + " | "
                try:
                    to_log += "ETA: " + str(int((int(time.time())-start_time)/(blue_size/len(dataset)*(1-(blue_size/len(dataset)))))) + " sec" + " | "
                except:
                    to_log += "ETA: " + "0" + " sec" + " | "
                
                print(to_log, file=log)
                log.flush()
                test_tsv.flush()
                print(to_log)
            

[!tag="DT"&!tag="J.*"][tag="CD"]?[tag="N.*"][lemma="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	All tokens that are not determiners, adjectives, cardinal numbers, or nouns followed by the word "man."
[!tag="DT"][tag="J.*"][tag!="CD"][tag="N.*"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	Any token that does not match the tags DT, JJ, CD, or is not a noun, followed by the word "man."
[!tag="DT"&!tag="J.*"][!tag="N.*"][word="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	Non-determiner, non-adjective, non-number, and non-noun tokens that precede the word "man."
[!tag="DT"&!tag="J.*"|!tag="N.*"][word="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	Tokens excluding determiners, adjectives, numbers, and nouns followed by "man."
[!tag="DT"&!tag="J.*"][!tag="CD"][word="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	Any token that is not tagged as DT, JJ, CD, o

The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


[word="Want|want"][word="to"][word="be"]	[word="want|Want"][word="to"][word="be"]	The sequence where "want" or "Want" is followed by "to," and then "be."  
[word="Want|want"][word="to"][word="be"]	[word="want|Want"][word="to"][word="be"]	Occurrences of "want" or "Want" leading to "to" and culminating in "be."  
[word=","]	[word=","]	Token containing a comma.
[word=","]	[word=","]	Only the comma as a word.
[word=","]	[word=","]	Word that matches a comma.
[word=","]	[word=","]	All instances of the comma.
[word=","]	[word=","]	Examples of the word comma.
[word=","]	[word=","]	Any token that is a comma.
[word=","]	[word=","]	All words that are commas.
[tag="IN"]	[tag=="IN"]	Token with the tag IN.  
[tag = "IN"]	[tag=="IN"]	A token having the tag equal to IN.  
[tag="IN"]	[tag=="IN"]	Instances of the tag IN.  
[tag = "IN"]	[tag=="IN"]	All tokens tagged as IN.  
[tag="IN"]	[tag=="IN"]	Tokens categorized under the IN tag.  
[tag="IN"]	[tag=="IN"]	Only tokens that have the tag IN.  
[tag="IN"]

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


[lemma="get"] [word="away"]	[lemma="get"] [word="away"]	Any use of the lemma "get" followed by the word "away."
[lemma="get"] [word="away"]	[lemma="get"] [word="away"]	Lemmas that include "get" coming before the word "away."
[lemma="get"] [lemma="away"]	[lemma="get"] [word="away"]	Expressions that involve the lemma "get" followed by "away."
[lemma="get"] [lemma="away"]	[lemma="get"] [word="away"]	Getting away.
[lemma="get"] [lemma="away"]	[lemma="get"] [word="away"]	Tokens that show "get" and then "away." 
[lemma="get"] [lemma="away"]	[lemma="get"] [word="away"]	The token "get" followed by "away." 
[lemma="get"] [lemma="away"]	[lemma="get"] [word="away"]	Example phrases with "get" leading to "away."
[word="hedging"]	[word="hedging"][tag="RB.?"]	The word hedging.
[word="hedging"]	[word="hedging"][tag="RB.?"]	A token that contains the word hedging.
[lemma="hedging"] [tag="RB.?"]	[word="hedging"][tag="RB.?"]	Token hedging followed by an adverb.
[lemma="hedging"] [tag="RB.?"]	[word="hedgin