PlaidCTF Compression
29 Apr 2013 in Security

PlaidCTF 2013 had a level called "Compression". Here's the provided code for this level:
#!/usr/bin/python
import os
import struct
import SocketServer
import zlib
from Crypto.Cipher import AES
from Crypto.Util import Counter

# Not the real keys!
ENCRYPT_KEY = '0000000000000000000000000000000000000000000000000000000000000000'.decode('hex')
# Determine this key.
# Character set: lowercase letters and underscore
PROBLEM_KEY = 'XXXXXXXXXXXXXXXXXXXX'

def encrypt(data, ctr):
    aes = AES.new(ENCRYPT_KEY, AES.MODE_CTR, counter=ctr)
    return aes.encrypt(zlib.compress(data))

class ProblemHandler(SocketServer.StreamRequestHandler):
    def handle(self):
        nonce = os.urandom(8)
        self.wfile.write(nonce)
        ctr = Counter.new(64, prefix=nonce)
        while True:
            data = self.rfile.read(4)
            if not data:
                break

            try:
                length = struct.unpack('I', data)[0]
                if length > (1<<20):
                    break
                data = self.rfile.read(length)
                data += PROBLEM_KEY
                ciphertext = encrypt(data, ctr)
                self.wfile.write(struct.pack('I', len(ciphertext)))
                self.wfile.write(ciphertext)
            except:
                break

class ReusableTCPServer(SocketServer.ForkingMixIn, SocketServer.TCPServer):
    allow_reuse_address = True

if __name__ == '__main__':
    HOST = '0.0.0.0'
    PORT = 4433
    SocketServer.TCPServer.allow_reuse_address = True
    server = ReusableTCPServer((HOST, PORT), ProblemHandler)
    server.serve_forever()
So there are a few interesting things of note here:
- They take user-supplied input, append the flag to it, and then encrypt and return the result.
- Input is limited to 1MB (1<<20 bytes).
- We're compressing with zlib (DEFLATE) and then encrypting with AES in CTR mode.
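- The counter block is 128 bits: 8 bytes from os.urandom (a per-connection nonce), followed by a 64-bit counter. The counter object is created once per connection, so it keeps advancing across messages and no keystream is reused within a connection.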
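For reference, the counter-block layout from the last bullet can be sketched in a few lines. The helper name `ctr_counter_blocks` is hypothetical, and the sketch assumes that pycrypto's `Counter.new(64, prefix=nonce)` counts big-endian starting from 1 by default; that is my understanding of its defaults, not something the challenge code states.

```python
import os
import struct

def ctr_counter_blocks(nonce, count, initial_value=1):
    """Return the first `count` 16-byte counter blocks for an 8-byte nonce.

    Assumption: mirrors pycrypto's Counter.new(64, prefix=nonce) with a
    big-endian counter starting at 1 (hypothetical helper, for illustration).
    """
    if len(nonce) != 8:
        raise ValueError('expected an 8-byte nonce')
    # nonce (64 bits) || big-endian counter (64 bits) = one 128-bit AES block
    return [nonce + struct.pack('>Q', i)
            for i in range(initial_value, initial_value + count)]

nonce = os.urandom(8)
blocks = ctr_counter_blocks(nonce, 3)
for b in blocks:
    print(b.hex())
```

Each block is encrypted under the key to produce 16 bytes of keystream; only the low 64 bits change from one block to the next.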
So I started thinking about the fact that AES (or any block cipher) in CTR mode is really a stream cipher: it encrypts the counter with the key to produce a keystream, then XORs that keystream with the plaintext to get the ciphertext. In particular, CTR mode needs no padding, so pycrypto guarantees that len(input) == len(output). Since the plaintext here is the compressed data, the ciphertext length is exactly the compressed length. Given the name of the level (Compression), I started thinking about approaches to get information out of the ciphertext length.
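To make the length leak concrete, here is a small local sketch of compress-then-encrypt. Python's standard library has no AES, so it swaps in a SHA-256-based keystream built from `key || nonce || counter`. That substitution is purely an assumption for illustration and is not the challenge's cipher. Any CTR-style XOR keystream preserves length, which is the only property that matters here. The helper names are hypothetical.

```python
import hashlib
import os
import struct
import zlib

def toy_keystream(key, nonce, nbytes):
    # Stand-in for AES-CTR (NOT the real cipher): hash key || nonce || counter
    # and keep 16 bytes per block, mimicking a 128-bit block keystream.
    out = b''
    counter = 1
    while len(out) < nbytes:
        block = nonce + struct.pack('>Q', counter)
        out += hashlib.sha256(key + block).digest()[:16]
        counter += 1
    return out[:nbytes]

def compress_then_encrypt(data, key, nonce):
    compressed = zlib.compress(data)
    ks = toy_keystream(key, nonce, len(compressed))
    # XOR with the keystream: output length == compressed length
    return bytes(a ^ b for a, b in zip(compressed, ks))

key = os.urandom(32)
repetitive = compress_then_encrypt(b'A' * 64, key, os.urandom(8))
noisy = compress_then_encrypt(os.urandom(64), key, os.urandom(8))
print(len(repetitive), len(noisy))
```

The ciphertext bytes look random either way, but the repetitive input produces a much shorter ciphertext. An attacker only needs that length to learn how compressible the plaintext was.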
At this point, it's worth revisiting the design of the DEFLATE algorithm (used by the zlib.compress call in the compression.py program). DEFLATE combines Huffman coding with LZ77-style duplicate string elimination. In this context, I believe the duplicate string elimination plays the bigger role: it takes repeated sections of the input and, for the second and later instances, emits a pointer back to the first instance that is shorter than the original string. For our purposes, that means if we provide input that contains substrings of the unknown key, we will get a shorter response than if our string is completely different from the flag. To test my theory, I fired off 27 tries to the server, each containing one of the valid ([a-z_]) characters repeated 3 times. All responses, save one, were the same length (30 bytes). Only the repeated 'c' string came back at 29 bytes. This led me to believe that 'c' probably started the flag. (If merely appearing somewhere in the flag were enough, more than one character would likely have returned a different length.)
I put together a script to go through the flag character by character and look for lengths that were shorter than the rest. During my first couple of runs, it would frequently get stuck, until I hit upon the idea of repeating the test string several times, which makes it more likely that duplicate string elimination will match the entire thing. Eventually, I had a few candidate flags, but from glancing at them, it was clear what the answer was...
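import struct
import socket
import sys
import collections

HOST = 'ip.add.res.s'
PORT = 4433

def recv_exact(sock, n):
    # recv() can return fewer bytes than requested, so loop until we have n
    buf = ''
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise IOError('connection closed early')
        buf += chunk
    return buf

def try_val(val):
    # One connection per probe: read the 8-byte nonce, send a length-prefixed
    # payload, then read back the length-prefixed ciphertext
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((HOST, PORT))
    nonce = recv_exact(sock, 8)
    sock.sendall(struct.pack('I', len(val)))
    sock.sendall(val)
    recv_len = struct.unpack('I', recv_exact(sock, 4))[0]
    data = recv_exact(sock, recv_len)
    sock.close()
    return (nonce, data)

def get_candidate_len(can):
    nonce, data = try_val(can)
    return len(data)

def mode(l):
    c = collections.Counter(l)
    return c.most_common(1)[0][0]

def try_layer(prefix):
    # 20 characters, matching the length of the PROBLEM_KEY placeholder
    if len(prefix) == 20:
        print "Found candidate %s!" % prefix
        return
    candidates = 'abcdefghijklmnopqrstuvwxyz_'
    print "Trying %s" % prefix
    sys.stdout.flush()
    samples = {}
    for c in candidates:
        val = prefix + c
        # Repeat the guess so DEFLATE is more likely to back-reference all of it
        samples[val] = get_candidate_len(val*2 if len(val)>9 else val*5)
    # Most guesses share the same length; recurse on any that differ
    m = mode(samples.values())
    for k in samples:
        if samples[k] == m:
            continue
        try_layer(k)

try_layer('')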
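Before pointing a search like this at a live server, one layer of it can be sanity-checked locally. This Python 3 sketch reuses the solver's repetition rule (`val*2 if len(val)>9 else val*5`) and its mode-based outlier test, but runs against a simulated oracle with a made-up flag. `next_char_candidates` and the flag value are hypothetical, for illustration only.

```python
import collections
import zlib

HYPOTHETICAL_FLAG = b'crime_style_demo_key'  # made up; not the real flag
CHARSET = 'abcdefghijklmnopqrstuvwxyz_'

def oracle_len(probe):
    # Simulated server: ciphertext length == compressed length under CTR
    return len(zlib.compress(probe + HYPOTHETICAL_FLAG))

def next_char_candidates(prefix):
    samples = {}
    for c in CHARSET:
        guess = (prefix + c).encode()
        # Same repetition rule as the solver above
        probe = guess * 2 if len(guess) > 9 else guess * 5
        samples[prefix + c] = oracle_len(probe)
    # The most common length means "no extra match"; keep anything that differs
    typical = collections.Counter(samples.values()).most_common(1)[0][0]
    return sorted(k for k, v in samples.items() if v != typical)

print(next_char_candidates('c'))
```

The returned list can hold more than one guess per layer. That is consistent with the full search ending in a few candidate flags rather than exactly one.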