From bc88b746484257a6606ae9c30977132eef4cbf21 Mon Sep 17 00:00:00 2001 From: "Krzysztof R." Date: Sat, 6 Jan 2024 20:04:15 +0100 Subject: [PATCH] feat: split gpt chunks into smaller function and blacked it --- gpt_chunks.py | 359 ++++++++++++++++++++++++++++---------------------- 1 file changed, 202 insertions(+), 157 deletions(-) diff --git a/gpt_chunks.py b/gpt_chunks.py index d7f8c8ca..2ca41678 100644 --- a/gpt_chunks.py +++ b/gpt_chunks.py @@ -2,27 +2,30 @@ from openai import OpenAI import os import re + def sentence_loop(sentences): - # Process each sentence - chunked_sentences = [] - for sentence in sentences: - chunked_sentences = process_sentence(sentence, chunked_sentences) - if chunked_sentences == False: - return - return chunked_sentences + # Process each sentence + chunked_sentences = [] + for sentence in sentences: + chunked_sentences = process_sentence(sentence, chunked_sentences) + if chunked_sentences == False: + return + return chunked_sentences + def is_correct_format(s): # Regular expression to match sequences of '[ some words ]' - pattern = r'^(\[\s*[^]]+\s*\]\s*)+$' + pattern = r"^(\[\s*[^]]+\s*\]\s*)+$" # Check if the string matches the pattern return bool(re.match(pattern, s)) + def check_format(input_data): # Check if the input is a string, and if so, process it directly if isinstance(input_data, str): return is_correct_format(input_data) - + # If the input is a list, process each element in the list elif isinstance(input_data, list): return [is_correct_format(string) for string in input_data] @@ -31,283 +34,325 @@ def check_format(input_data): else: return None + +def remove_leading_numbering(string): + return re.sub(r"^\d+\.\s*", "", string) + + +def process_brackets(string): + parts = re.findall(r"\[[^\]]*\]", string) + if parts: + parts[-1] = handle_trailing_punctuation(string, parts[-1]) + return parts + + +def handle_trailing_punctuation(string, last_part): + trailing_punctuation = re.search(r"\]\s*([.!?])\s*$", string) + if trailing_punctuation: + return last_part[:-1] + trailing_punctuation.group(1) + "]" + return last_part + + +def process_non_bracketed_parts(string): + parts = re.split(r"\s*-\s*", string) + return [f"[ {part.strip()} ]" for part in parts if part.strip()] + + def reformat_list(strings): result = [] - for string in strings: - # Remove leading numbering like '1. ' - string = re.sub(r'^\d+\.\s*', '', string) - - # Check if the string contains partially formatted brackets - if '[' in string and ']' in string: - # Splitting the string into parts that are enclosed in square brackets - parts = re.findall(r'\[[^\]]*\]', string) - - # Check for trailing punctuation after the last bracket and add it to the last part - trailing_punctuation = re.search(r'\]\s*([.!?])\s*$', string) - if trailing_punctuation and parts: - parts[-1] = parts[-1][:-1] + trailing_punctuation.group(1) + ']' - - result.extend(parts) - - # Handle strings without square brackets + string = remove_leading_numbering(string) + if "[" in string and "]" in string: + result.extend(process_brackets(string)) else: - # Splitting based on the dashed pattern and removing the numbering - parts = re.split(r'\s*-\s*', string) - parts = [part.strip() for part in parts if part.strip()] + result.extend(process_non_bracketed_parts(string)) + return " ".join(result) - # Enclosing each part in square brackets - for part in parts: - result.append(f'[ {part} ]') - - return ' '.join(result) def reformat_square_not_full(input_string): # Split the string into parts that are either in brackets or not - parts = re.split(r'(\[[^\]]*\])', input_string) + parts = re.split(r"(\[[^\]]*\])", input_string) formatted_parts = [] for part in parts: if part: # Ignore empty strings # For parts not in brackets, remove trailing period, trim spaces, and enclose in brackets - if not part.startswith('['): - part = re.sub(r'\.\s*$', '', part).strip() + if not part.startswith("["): + part = re.sub(r"\.\s*$", "", part).strip() formatted_parts.append(f"[ {part} ]") else: # For parts already in brackets, add spaces inside - formatted_parts.append(re.sub(r'\[([^]]+)\]', r'[ \1 ]', part)) + formatted_parts.append(re.sub(r"\[([^]]+)\]", r"[ \1 ]", part)) + + return " ".join(formatted_parts) - return ' '.join(formatted_parts) def reformat_slash(input_string): # Split the string on slashes and strip spaces - parts = [part.strip() for part in input_string.split('/')] + parts = [part.strip() for part in input_string.split("/")] # Enclose each part in square brackets formatted_parts = [f"[ {part} ]" for part in parts] - return ' '.join(formatted_parts) + return " ".join(formatted_parts) + def reformat_pipe(input_string): # Split the string on ' | ' and strip spaces - parts = [part.strip() for part in input_string.split('|')] + parts = [part.strip() for part in input_string.split("|")] # Enclose each part in square brackets formatted_parts = [f"[ {part} ]" for part in parts] - return ' '.join(formatted_parts) + return " ".join(formatted_parts) + def is_pipe_format(s): # Regular expression to check for ' | ' pattern - pattern = r'^(?:[^|]+\|)+[^|]+$' + pattern = r"^(?:[^|]+\|)+[^|]+$" return bool(re.match(pattern, s)) + def is_numbered_line_format(s): # Regular expression to match lines starting with 'number. word/phrase \n' - pattern = r'^(?:\d+\.\s+.+\n?)+$' + pattern = r"^(?:\d+\.\s+.+\n?)+$" # Check if the input is a list of strings if isinstance(s, list): return [bool(re.match(pattern, string)) for string in s] - + # If the input is a single string, process it directly return bool(re.match(pattern, s)) + def reformat_ists(input_string): # Use regular expression to find all occurrences of the iSTS pattern and extract the text - parts = re.findall(r'\(iSTS \d+\)\s*([^)]+)', input_string) + parts = re.findall(r"\(iSTS \d+\)\s*([^)]+)", input_string) # Enclose each extracted part in square brackets formatted_parts = [f"[ {part.strip()} ]" for part in parts] - return ' '.join(formatted_parts) + return " ".join(formatted_parts) + def is_ists_format(s): # Regular expression to check for the iSTS pattern - pattern = r'^(?:\(iSTS \d+\)\s*[^)]+\s*)+$' + pattern = r"^(?:\(iSTS \d+\)\s*[^)]+\s*)+$" return bool(re.match(pattern, s)) + def reformat_with_sections(input_string): # Use regular expression to extract text following the section labels - parts = re.split(r'\s*\[\s*S\d+\s*\]\s*', input_string) + parts = re.split(r"\s*\[\s*S\d+\s*\]\s*", input_string) # Remove empty strings and enclose each part in square brackets formatted_parts = [f"[ {part.strip()} ]" for part in parts if part.strip()] - return ' '.join(formatted_parts) + return " ".join(formatted_parts) + def is_section_format(s): # Regular expression to check for the section label pattern - pattern = r'^(?:\s*\[\s*S\d+\s*\]\s*.+)+$' + pattern = r"^(?:\s*\[\s*S\d+\s*\]\s*.+)+$" return bool(re.match(pattern, s)) + def is_chunk_format(s): # Regular expression to check for both 'Chunk X: text' and '[Chunk X] text' patterns - pattern = r'^(Chunk \d+:\s*.+?|\[\s*Chunk \d+\s*\]\s*.+?)(\s+(Chunk \d+:\s*.+?|\[\s*Chunk \d+\s*\]\s*.+?))*$' + pattern = r"^(Chunk \d+:\s*.+?|\[\s*Chunk \d+\s*\]\s*.+?)(\s+(Chunk \d+:\s*.+?|\[\s*Chunk \d+\s*\]\s*.+?))*$" return bool(re.match(pattern, s)) + def reformat_chunks(input_string): # Find all occurrences of both chunk patterns and extract the text - parts = re.findall(r'Chunk \d+:\s*(.+?)(?=\s*Chunk \d+:|$)|\[\s*Chunk \d+\s*\]\s*(.+?)(?=\s*\[\s*Chunk \d+\s*\]|$)', input_string) - + parts = re.findall( + r"Chunk \d+:\s*(.+?)(?=\s*Chunk \d+:|$)|\[\s*Chunk \d+\s*\]\s*(.+?)(?=\s*\[\s*Chunk \d+\s*\]|$)", + input_string, + ) + # Flatten the list of tuples and remove empty strings flattened_parts = [item for sublist in parts for item in sublist if item] - + # Enclose each extracted part in square brackets formatted_parts = [f"[ {part.strip()} ]" for part in flattened_parts] - return ' '.join(formatted_parts) + return " ".join(formatted_parts) + def reformat_ists_markers(input_string): # Split the string at '[ iSTS ]', remove empty strings, and enclose each part in square brackets - parts = [part.strip() for part in re.split(r'\[\s*iSTS\s*\]', input_string) if part.strip()] + parts = [ + part.strip() + for part in re.split(r"\[\s*iSTS\s*\]", input_string) + if part.strip() + ] formatted_parts = [f"[ {part} ]" for part in parts] - return ' '.join(formatted_parts) + return " ".join(formatted_parts) + def is_ists_marker_format(s): # Regular expression to check for the pattern with '[ iSTS ]' - pattern = r'^(?:\[\s*iSTS\s*\].+)+$' + pattern = r"^(?:\[\s*iSTS\s*\].+)+$" return bool(re.match(pattern, s)) + def reformat_with_dashes(input_string): # Split the string on ' - ', remove empty strings, and enclose each part in square brackets - parts = [part.strip() for part in input_string.split('-') if part.strip()] + parts = [part.strip() for part in input_string.split("-") if part.strip()] formatted_parts = [f"[ {part} ]" for part in parts] - return ' '.join(formatted_parts) + return " ".join(formatted_parts) + def is_dash_format(s): # Regular expression to check for the pattern with ' - ' - pattern = r'^[^-]+(?:\s*-\s*[^-]+)+$' + pattern = r"^[^-]+(?:\s*-\s*[^-]+)+$" return bool(re.match(pattern, s)) + def preliminary_reformat(input_string): - input_string = input_string.strip("iSTS chunks:") - input_string = input_string.strip("Here are the iSTS chunks for the given sentence:") - return input_string + input_string = input_string.strip("iSTS chunks:") + input_string = input_string.strip( + "Here are the iSTS chunks for the given sentence:" + ) + return input_string + def reformat_brackets_and_text(input_string): # Remove the trailing dot if present - input_string = re.sub(r'\.$', '', input_string) + input_string = re.sub(r"\.$", "", input_string) # Split the string into parts that are either in brackets or not - parts = re.split(r'(\[[^\]]*\])', input_string) + parts = re.split(r"(\[[^\]]*\])", input_string) formatted_parts = [] for part in parts: if part: # Ignore empty strings # Remove commas and trim spaces for parts not in brackets - if not part.startswith('['): - part = re.sub(r',\s*', ' ', part).strip() + if not part.startswith("["): + part = re.sub(r",\s*", " ", part).strip() # Enclose the non-bracket part in brackets formatted_parts.append(f"[ {part} ]") else: # Directly append the part in brackets formatted_parts.append(part) - return ' '.join(formatted_parts) + return " ".join(formatted_parts) + def is_brackets_and_text_format(s): # Regular expression to check the pattern - pattern = r'^(\[[^\]]*\]|\s*[^[\]]+\s*)(,\s*|\s+|$)+$' + pattern = r"^(\[[^\]]*\]|\s*[^[\]]+\s*)(,\s*|\s+|$)+$" return bool(re.match(pattern, s)) + def reformat(input_string): - input_string = preliminary_reformat(input_string) - if check_format(input_string): - return input_string - # Check if the string is of the first type (numbered list) - if is_numbered_line_format(input_string): - print("formated by reformat_list", input_string) - return reformat_list(input_string) - elif is_brackets_and_text_format(input_string): - print("is_brackets_and_text_format") - return reformat_brackets_and_text(input_string) - # Check if the string is of the second type (already contains square brackets) - elif '[' in input_string and ']' in input_string: - print("formated by reformat_square_not_full") - return reformat_square_not_full(input_string) - # Check for the new type with slashes - elif '/' in input_string: - print("formated by reformat_slash") - return reformat_slash(input_string) - elif is_pipe_format(input_string): - print("formated by reformat_pipe") - return reformat_pipe(input_string) - elif is_ists_format(input_string): - print("formated by reformat_ists") - return reformat_ists(input_string) - elif is_section_format(input_string): - print("formated by is_section_format") - return reformat_with_sections(input_string) - elif is_chunk_format(input_string): - print("formated by reformat_chunks") - return reformat_chunks(input_string) - elif is_ists_marker_format(input_string): - print("formated by reformat ists markers") - return reformat_ists_markers(input_string) - # Return the original string if it does not match either type - elif is_dash_format(input_string): - print("formated by reformat with dashes") - return reformat_with_dashes(input_string) - print("ERROR! reformat did not recognize the format of string! ", input_string) - return False + input_string = preliminary_reformat(input_string) + if check_format(input_string): + return input_string + + format_checks = [ + (is_numbered_line_format, reformat_list), + (is_brackets_and_text_format, reformat_brackets_and_text), + (lambda s: "[" in s and "]" in s, reformat_square_not_full), + (lambda s: "/" in s, reformat_slash), + (is_pipe_format, reformat_pipe), + (is_ists_format, reformat_ists), + (is_section_format, reformat_with_sections), + (is_chunk_format, reformat_chunks), + (is_ists_marker_format, reformat_ists_markers), + (is_dash_format, reformat_with_dashes), + ] + + for check_func, format_func in format_checks: + result = check_and_reformat(input_string, check_func, format_func) + if result is not None: + return result + + print("ERROR! reformat did not recognize the format of string! ", input_string) + return False + def remove_empty_brackets(input_string): # Regular expression to match empty brackets (with any number of spaces) - pattern = r'\[\s*\]' - + pattern = r"\[\s*\]" + # Replace empty brackets with an empty string - cleaned_string = re.sub(pattern, '', input_string) + cleaned_string = re.sub(pattern, "", input_string) # Remove any extra spaces that might be left after removing brackets - cleaned_string = re.sub(r'\s{2,}', ' ', cleaned_string).strip() + cleaned_string = re.sub(r"\s{2,}", " ", cleaned_string).strip() return cleaned_string + def reformat_brackets(input_string): - formatted_string = re.sub(r'\[\s*(.*?)\s*\]', r'[ \1 ]', input_string) + formatted_string = re.sub(r"\[\s*(.*?)\s*\]", r"[ \1 ]", input_string) return formatted_string -def process_sentence(sentence, chunked_sentences): - # Define the instruction with line breaks to ensure each line is within 80 characters - instruction = ( - "Please divide the following sentence into iSTS chunks. Try to return chunks as a string in format [chunk 1] [chunk 2] [chunk 3] and so on... Here is the sentence: " + sentence - ) - # API call with the instruction variable - response = client.chat.completions.create( - messages=[{ - "role": "user", - "content": instruction, - }], - model="gpt-3.5-turbo", - ) +def generate_instruction(sentence): + return ( + "Please divide the following sentence into iSTS chunks. Try to return chunks as a string in format " + "[chunk 1] [chunk 2] [chunk 3] and so on... Here is the sentence: " + sentence + ) + + +def call_api(client, instruction): + return client.chat.completions.create( + messages=[ + { + "role": "user", + "content": instruction, + } + ], + model="gpt-3.5-turbo", + ) + + +def process_response(response): + chunked_sentence = response.choices[0].message.content.strip() + return chunked_sentence + + +def check_and_format_sentence(chunked_sentence): + reformated = reformat(chunked_sentence) + if reformated == False: + print("ERROR! failed to reformat! ", chunked_sentence, reformated) + return False + reformated = reformated.strip(",") + if not is_correct_format(reformated): + print("ERROR! wrong format ", chunked_sentence, reformated) + return False + return reformated + + +def finalize_formatting(reformated): + reformated = remove_empty_brackets(reformated) + reformated = reformat_brackets(reformated) + return reformated + + +def process_sentence(sentence, chunked_sentences, client): + instruction = generate_instruction(sentence) + response = call_api(client, instruction) + chunked_sentence = process_response(response) + reformated = check_and_format_sentence(chunked_sentence) + if not reformated: + return False + reformated = finalize_formatting(reformated) + chunked_sentences.append(reformated) + return chunked_sentences - chunked_sentence = response.choices[0].message.content - chunked_sentence = chunked_sentence.strip() - reformated = reformat(chunked_sentence) - if reformated == False: - print("ERROR! failed to reformat! ", chunked_sentence, reformated) - return False - reformated = reformated.strip(',') - if not is_correct_format(reformated): - print("ERROR! wrong format ", chunked_sentence, reformated) - return False - print("after reformation: ", reformated) - reformated = remove_empty_brackets(reformated) - reformated = reformat_brackets(reformated) - chunked_sentences.append(reformated) - return chunked_sentences def chunk_sentences(file_path, output_path): - # Read the sentences from the file - with open(file_path, 'r') as file: - sentences = file.readlines() + # Read the sentences from the file + with open(file_path, "r") as file: + sentences = file.readlines() - # Process each sentence - chunk_sentences = sentence_loop(sentences) - print(chunk_sentences) - # Write the chunked sentences to a new file - with open(output_path, 'w') as output_file: - for sentence in chunk_sentences: - output_file.write(sentence + '\n') + # Process each sentence + chunk_sentences = sentence_loop(sentences) + print(chunk_sentences) + # Write the chunked sentences to a new file + with open(output_path, "w") as output_file: + for sentence in chunk_sentences: + output_file.write(sentence + "\n") # Usage -file_path = 'test_goldStandard/student/STSint.testinput.answers-students.sent1.txt' -output_path = 'output.txt' +file_path = "test_goldStandard/student/STSint.testinput.answers-students.sent1.txt" +output_path = "output.txt" # Change me to os.environ['API_KEY'] -client = OpenAI(api_key='REDACTED_OPENAI_API_KEY') +client = OpenAI(api_key="REDACTED_OPENAI_API_KEY") chunk_sentences(file_path, output_path)