In [1]:
from typing import List, Optional, Any
import torch
import torch.utils.data
import numpy as np
import random

import datetime

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from peft import LoraConfig
from trl import SFTTrainer
import datasets
import time

import json
import os
import cqlcmp

In [2]:
torch.manual_seed(21)
torch.cuda.manual_seed_all(21)
np.random.seed(21)
random.seed(21)
#torch.backends.cudnn.deterministic = True
#torch.backends.cudnn.benchmark = False

In [3]:
device = torch.device("cuda")

In [4]:
class DatasetNatural2CQL(torch.utils.data.Dataset):
    def __init__(self, path: Optional[str] = None) -> None:
        self.sentence_freq = []
        self.cql2nl = []
        self.nl2cql = []
        self.natural_language_rulebased = []
        self.cql = []
        self.natural_language = []
        self.enabled_natural_language = []

        if path is not None:
            self.load_tsv(path)

    def enable_cql(self, cqls):
        self.enabled_natural_language = []
        for cql in cqls:
            for p in cql:
                self.enabled_natural_language.append(p)

    def dump_json(self, filepath: str) -> None:
        with open(filepath, "w") as file:
            for i in range(len(self)):
                data = json.dumps(self[i])
                file.write(data)
                file.write("\n")

    def add_translation(self, freq: int, cql: str, natural_language_rulebased: str, natural_language: List[str]) -> None:
        cql_index = len(self.sentence_freq)
        self.sentence_freq.append(freq)
        self.cql.append(cql)
        self.natural_language_rulebased.append(natural_language_rulebased)
        self.cql2nl.append([])

        for sentence in natural_language:
            self.nl2cql.append(cql_index)
            self.cql2nl[-1].append(len(self.natural_language))
            self.natural_language.append(sentence)

    def load_tsv(self, path: str) -> None:
        with open(path, "r") as file_data:
            for line in file_data:
                line = line.strip()
                line = line.split("\t")
                texts_json = json.loads(line[4])
                texts_extracted = texts_json["data"][0]["content"][0]["text"]["value"].split("\n")
                self.add_translation(int(line[0]), line[2], line[3], texts_extracted)

    def __len__(self):
        return len(self.enabled_natural_language)

    def __getitem__(self, idx):
        if idx < len(self.nl2cql):
            return {"text": self.natural_language[self.enabled_natural_language[idx]], "cql": self.cql[self.nl2cql[self.enabled_natural_language[idx]]]}
        return None

In [5]:
dataset = DatasetNatural2CQL("expand_natural_texts_0004.res.tsv")

In [6]:
validation_cqls = []
with open("test_ids.json", "r") as file:
    validation_cqls = json.load(file)

In [7]:
dataset.enable_cql(validation_cqls)

In [8]:
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model_name = "google/gemma-2-2b"
adapter_model_name = "outputs/checkpoint-102000"

model_orig = AutoModelForCausalLM.from_pretrained(base_model_name, quantization_config=bnb_config, device_map={"":0})
model = PeftModel.from_pretrained(model_orig, adapter_model_name)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(base_model_name)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
import os

path = 'outputs'

# List all folders (directories) in the given path
folders = [name for name in os.listdir(path) if os.path.isdir(os.path.join(path, name))]

In [10]:
def f_in(data):
    full_input = """Translate Natural Language into CQL Queries like:

Word dog followed by lemma run, and then followed by a noun. CQL:
```
[word="dog"][lemma="run"][pos="NN"]
```

Lemma "be" optionally followed by the word "not". CQL:
```
[lemma="be"][word="not"]? 
```

""" + data + " CQL:\n```\n"
    return full_input

In [11]:
# random.shuffle(dataset.enabled_natural_language)

i = 0
while i < len(dataset):
    dataset.natural_language[dataset.enabled_natural_language[i]] = f_in(dataset.natural_language[dataset.enabled_natural_language[i]])
    i += 1

In [12]:
tokenizer.pad_token = tokenizer.eos_token

In [13]:
def get_cql(text):
    text = text.split("<eos>")[0]
    text = text.split("CQL:")[3]
    text = text.split("```")[1]
    return text.strip()

In [14]:
def get_text(text):
    text = text = text.split("CQL:")[2]
    text = text = text.split("```")[-1]
    return text.strip()

In [15]:
with open("test_results.tsv", "w") as test_tsv:
    with open("test_log.txt", "w") as log:
        for folder in folders:
            start_time = int(time.time())
            adapter_model_name = "outputs/" + folder
    
            if folder not in ["checkpoint-202000"]:
                continue
            
            model = PeftModel.from_pretrained(model_orig, adapter_model_name)
            dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False)
            blue_sum = 0
            blue_size = 0
            batch_id = 0
            for batch in dataloader:
                batch_id += 1
                inputs = tokenizer(batch["text"], truncation=True, max_length=1024, return_tensors="pt", padding=True).to(device)
                outputs = model.generate(**inputs, max_new_tokens=100)
                for i, output in enumerate(outputs):
                    cql = get_cql(tokenizer.decode(output, skip_special_tokens=True))
                    cql_ids = cqlcmp.cql_tokenizer(cql)
                    gold_cql_ids = cqlcmp.cql_tokenizer(batch["cql"][i])
                    bleu = cqlcmp.sentence_bleu([gold_cql_ids], cql_ids, weights=(0.25, 0.25, 0.25, 0.25))
                    blue_sum += bleu
                    blue_size += 1
                    test_result = [cql, batch["cql"][i], get_text(batch["text"][i])]
                    print("\t".join(test_result))
                    print("\t".join(test_result), file=test_tsv)
                to_log = ""
                to_log += "Working on: " + adapter_model_name + " | "
                to_log += str(blue_size) + "/" + str(len(dataset)) + " | "
                to_log += "AVG Bleu: " + str(blue_sum/blue_size) + " | "
                to_log += "Time: " + str(int(time.time())-start_time) + " | "
                to_log += "ETA: " + str(int((int(time.time())-start_time)/(blue_size/len(dataset)*(1-(blue_size/len(dataset)))))) + " sec" + " | "
                
                print(to_log, file=log)
                log.flush()
                test_tsv.flush()
                print(to_log)
            



[tag!="DT|JJ|CD|N.*"][word="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	All tokens that are not determiners, adjectives, cardinal numbers, or nouns followed by the word "man."
[tag!="DT|JJ|CD|N.*"] [word="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	Any token that does not match the tags DT, JJ, CD, or is not a noun, followed by the word "man."
[tag!="DT"& tag!="JJ.*" & tag!="N.*"] [word="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	Non-determiner, non-adjective, non-number, and non-noun tokens that precede the word "man."
[tag!="DT"& tag!="J.*" & tag!="CD" & tag!="N.*"] [word="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	Tokens excluding determiners, adjectives, numbers, and nouns followed by "man."
[tag!="DT"& tag!="JJ.*" & tag!="CD" & tag="N.*"][word="man"]	[tag !="DT" & tag !="JJ.*" & tag !="CD" & tag !="N.*"] [word="man"]	Any token that is not tagged as DT, JJ, CD, 

The hypothesis contains 0 counts of 4-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()
The hypothesis contains 0 counts of 3-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


[lempos="determine-v"]	[lempos="determine-v"]	Forms of the verb determine.
[lempos="determine-v"]	[lempos="determine-v"]	Only verb forms determine.
[lempos="determine-v"]	[lempos="determine-v"]	All lexemes for the verb 'determine'.
[lempos="determine-v"]	[lempos="determine-v"]	Any verb forms of determine.
[lempos="determine-v"]	[lempos="determine-v"]	Lexical items with the verb lemma determine.
[lempos="determine-v"]	[lempos="determine-v"]	All occurrences of the verb lemma as determine.
[lempos="determine-v"]	[lempos="determine-v"]	Lexical forms labeled as determine in verb category.
[word="s?pr?l?ay"]	[word= "s?pr?l?ay"]	Any word that matches the pattern s?pr?l?ay.
[word= "s?pr?ay"]	[word= "s?pr?l?ay"]	Words that contain variations of "spray."
<s>[word=".*"]	[word= "s?pr?l?ay"]	Words before or after a possible 's', 'pr', 'l', or 'ay'.
[word= "s?pr?ay"]	[word= "s?pr?l?ay"]	Words resembling the spelling of "spray".
[word= "s?pr?ay"]	[word= "s?pr?l?ay"]	All words that include 'spray' wit

The hypothesis contains 0 counts of 2-gram overlaps.
Therefore the BLEU score evaluates to 0, independently of
how many N-gram overlaps of lower order it contains.
Consider using lower n-gram order or use SmoothingFunction()


[word="good"]	[word="good"]	Word "good."
[word="good"]	[word="good"]	Example of the word good.
[word="good"]	[word="good"]	Token containing the word good.
[word="good"]	[word="good"]	Only the word good.
[word="good"]	[word="good"]	Any occurrence of good.
[word="good"]	[word="good"]	Single token with the word good.
[word="good"]	[word="good"]	Instances of the word good.
[word="good"]	[word="good"]	Only good as a word.
[word="good"]	[word="good"]	All occurrences of good.
[word="good"]	[word="good"]	A match for the word good.
[word="a.l{1,3}"]	[word="a.?l{1,3}"]	Words starting with 'a' followed by any character and then 'l' appearing one to three times.
[word="a.l{1,3}"]	[word="a.?l{1,3}"]	Any word commencing with 'a' followed by a character and then one to three instances of 'l'.
[word="a.l{1,3}"]	[word="a.?l{1,3}"]	All words beginning with the letter 'a', having a character and one to three l's following.
[word="a.l{1,3}"]	[word="a.?l{1,3}"]	Words that start with 'a', contain one charac

ZeroDivisionError: float division by zero