Sunday, 19 January 2020

Writeup: Out of the sbox - Insomnihack teaser 2020

After playing the Insomnihack teaser 2020 CTF yesterday, here is my writeup on how I solved the
"Out of the sbox" crypto challenge.

Spoiler: No Sboxes were harmed during this challenge.

TL;DR:
Instead of attacking the SBoxes (as probably intended by the author), I reduced the keyspace from 6 bytes to 4 bytes and built a distinguisher that needs only 12 SBox lookups and a few XORs (instead of the 28 SBox lookups of the full cipher). Then I brute-forced those 4 bytes.



Now let's explain how this works.


Challenge overview:

When unpacking the zip one will find 3 files:
  • main.py
  • des.py
  • params.py

main.py:

def welcome():
    # Welcome client
    print(WELCOME, flush=True)
    sleep(randint(0, 3))

def challenge():
    # Generate challenge
    key = int.from_bytes(urandom(6), 'big')
    keys = key_schedule(key)
    data = []
    for i in range(50000):
        plaintext = int.from_bytes(urandom(8), 'big')
        ciphertext = enc(plaintext, keys, sboxs, perms)
        data.append((plaintext, ciphertext))

    # Validate data
    inv_keys = reorder_keys(keys)
    inv_perms = reorder_perms(perms)
    inv_sboxs = reorder_sboxs(sboxs)
    for p, c in data:
        assert(p == enc(c, inv_keys, inv_sboxs, inv_perms))

    # Print challenge
    print(str(data) + '\n', flush=True)

    return key

def reward(start, key):

    # Wait for response
    print(TASK, flush=True)
    response = input()

    # Verify response
    if time() - start < 1337 and \
       len(response.strip()) == 12 and \
       search(r'([0-9A-F]{12})', response) and \
       int(response, 16) == key:

        reward = ''.join(open('/home/ctf/flag.txt', 'r').readlines())
        print(reward, flush=True)
    else:
        print(WRONG, flush=True)

if __name__ == '__main__':
    if pow():
        start = time()
        welcome()
        key = challenge()
        reward(start, key)


So after solving the proof of work, the start time is saved, a welcome screen is displayed (followed by a random sleep of 0 to 3 seconds), and a random 6-byte key is generated.
Next, 50000 random plaintexts are generated, encrypted and sent to us together with their ciphertexts.
From the moment the timer starts, we have 1337 seconds (~22 min) to send back the encryption key.
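The server is strict about the answer: exactly 12 uppercase hex digits that parse back to the key, sent within 1337 seconds of the timer start (which already includes the random sleep in welcome() and the server-side generation and validation of the 50000 pairs). A minimal Python sketch of formatting the key and replaying the checks from reward() locally; the helper names are made up for illustration and the time limit is left out:

```python
import re

def format_answer(key):
    """Render a 48-bit key as 12 uppercase hex digits (leading zeros kept)."""
    return format(key, "012X")

def server_would_accept(response, key):
    """Local replay of the checks in reward(), without the 1337-second limit."""
    return (len(response.strip()) == 12
            and re.search(r"([0-9A-F]{12})", response) is not None
            and int(response, 16) == key)

answer = format_answer(0x554433221100)
print(answer, server_would_accept(answer, 0x554433221100))
```

The C++ exploit further down does the same formatting with "%012llX".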


Next up: des.py

Even though the name suggests the Data Encryption Standard, DES isn't the cipher used here.



N = 256
ROUNDS = 7

def key_schedule(key):
    keys = []
    for i in range(ROUNDS):
        left = (key >> (16*((i+1) % 3))) & 0x0000FFFF;
        right = (key >> (16*(i%3))) & 0x0000FFFF;

        keys.append((left << 16) | right);

    return keys

def apply_perm32(perm, x):
    out = 0
    for i in range(32):
        out |= ((x & (0x00000001 << i)) >> i) << (perm[i])

    return out

def fbox(x, key, sboxs, perm):
    tmp = x ^ key

    tmp = int(sboxs[0, 0x000000FF & tmp]) | int(sboxs[1, (0x0000FF00 & tmp) >> 8]) << 8 | int(sboxs[2, (0x00FF0000 & tmp) >> 16]) << 16 | int(sboxs[3, (0xFF000000 & tmp) >> 24]) << 24
    out = apply_perm32(perm, tmp);

    return out;

def enc(plaintext, keys, sboxs, perms):
    A = plaintext & 0xFFFFFFFF
    B = plaintext >> 32

    for i in range(ROUNDS):
        y = fbox(A, keys[i], sboxs[i], perms[i])
        tmp = B ^ y
        B = A
        A = tmp

    return A << 32 | B

From a quick glance at enc() we can see that the encryption is a 7-round Feistel cipher.

The function fbox is really simple:




Basically it's just:
  • Key addition (XOR with the round key)
  • SBox lookup (one SBox per byte)
  • Rotate left by one byte (the permutation applied by apply_perm32)
One note here: there are 7 rounds with 4 SBox lookups per round, but every lookup uses a different SBox. There are 28 SBoxes in total, and each of them is used exactly once during an encryption.
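A quick way to convince yourself that "rotate left by one byte" is what apply_perm32 does: feed it a permutation table that sends bit i to bit (i + 8) mod 32 and compare with a plain rotation. This is a sketch; it assumes the per-round permutation in params.py is equivalent to that table (the C++ encrypt() further down hard-codes the same rotation), and the name ROTL8_PERM is hypothetical:

```python
import random

def apply_perm32(perm, x):
    """Same bit permutation as des.py: bit i of x moves to bit perm[i]."""
    out = 0
    for i in range(32):
        out |= ((x >> i) & 1) << perm[i]
    return out

# Assumed permutation table: bit i -> bit (i + 8) mod 32.
ROTL8_PERM = [(i + 8) % 32 for i in range(32)]

def rotl8(x):
    """Rotate a 32-bit word left by one byte."""
    return ((x << 8) | (x >> 24)) & 0xFFFFFFFF

rng = random.Random(1337)
samples = [rng.getrandbits(32) for _ in range(1000)]
print(all(apply_perm32(ROTL8_PERM, x) == rotl8(x) for x in samples))
```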


For the sake of completeness, let's quickly mention params.py:
It stores all 28 SBoxes, as well as an overcomplicated notation of the permutation used in every round.


First attempts:

At this point it should be clear that the author wants us to attack the SBoxes with cryptanalysis.
The first things that come to my mind are differential and linear attacks.
And I indeed computed a difference distribution table (DDT) for every SBox.
Every SBox, except for the last four SBoxes, had DDT entries of at least 10 (some of them even 14).
This means that a certain input difference yields a certain output difference with a probability of at least 10/256.
With 50000 plaintext/ciphertext pairs it is probably possible to construct a probabilistic attack by building a differential trail and a verifier and doing fancy crypto stuff.
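For reference, a DDT entry ddt[dx][dy] counts the inputs x for which S[x] xor S[x xor dx] == dy. A minimal Python sketch of computing it and reading off the highest non-trivial entry; the SBox here is a seeded random permutation standing in for the real ones from params.py, so the printed number says nothing about the challenge SBoxes:

```python
import random

def ddt(sbox):
    """Difference distribution table: table[dx][dy] = #{x : S[x] ^ S[x ^ dx] == dy}."""
    table = [[0] * 256 for _ in range(256)]
    for x in range(256):
        for dx in range(256):
            table[dx][sbox[x] ^ sbox[x ^ dx]] += 1
    return table

def max_entry(table):
    """Highest entry, ignoring the trivial row dx == 0 (always 256 at dy == 0)."""
    return max(max(row) for row in table[1:])

# Hypothetical SBox: a seeded random permutation, NOT one of the 28 from params.py.
rng = random.Random(2020)
sbox = list(range(256))
rng.shuffle(sbox)

table = ddt(sbox)
best = max_entry(table)
print("highest DDT entry: %d -> probability %d/256" % (best, best))
```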

You may have noticed that I say "probably". That is because at this point I realized that I was asleep most of the time during the lecture where I should have learned this. Except for a few buzzwords I didn't remember much, and I had no clue how to actually do "fancy crypto stuff" to break the cipher and get the key.

So sorry to disappoint you if you expected fancy crypto SBox attacks, but this is not the writeup you are looking for.


Key schedule:

Let's take another look at the cipher.
One of the first things that got my attention was the key_schedule() function.
You input a 6-byte key and get seven 4-byte round keys.
However, on closer inspection you really only get 3 unique round keys, which then repeat.

Setting the master key to 0x554433221100, you get the following round keys:
  • Round 0: 0x33221100
  • Round 1: 0x55443322
  • Round 2: 0x11005544
  • Round 3: 0x33221100
  • Round 4: 0x55443322
  • Round 5: 0x11005544
  • Round 6: 0x33221100
Even worse than the keys repeating is that the round keys overlap: the upper two bytes of one round key are always the lower two bytes of the next one.
So one should really look at the round keys 2-byte-wise:
  • Round 0: 3322 - 1100
  • Round 1: 5544 - 3322
  • Round 2: 1100 - 5544
  • Round 3: 3322 - 1100
  • Round 4: 5544 - 3322
  • Round 5: 1100 - 5544
  • Round 6: 3322 - 1100
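The key schedule can be replayed directly in Python. The helper key_byte_map() (hypothetical, not part of the challenge) uses a probe master key whose byte n has the value n, so reading back the round keys shows which master-key byte ends up in which round-key byte:

```python
ROUNDS = 7

def key_schedule(key):
    """Same as key_schedule() in des.py (without the trailing semicolons)."""
    keys = []
    for i in range(ROUNDS):
        left = (key >> (16 * ((i + 1) % 3))) & 0x0000FFFF
        right = (key >> (16 * (i % 3))) & 0x0000FFFF
        keys.append((left << 16) | right)
    return keys

def key_byte_map():
    """For every round, the master-key byte index (K0..K5) behind round-key bytes 0..3."""
    probe = 0x050403020100  # master-key byte n has the value n
    return [[(rk >> (8 * j)) & 0xFF for j in range(4)] for rk in key_schedule(probe)]

for rnd, (rk, src) in enumerate(zip(key_schedule(0x554433221100), key_byte_map())):
    names = " ".join("K%d" % n for n in src)
    print("round %d: 0x%08x  bytes 0..3 come from %s" % (rnd, rk, names))
```

Round keys 0/3/6 use K0..K3 in byte order, rounds 1/4 use K2..K5, and rounds 2/5 use K4, K5, K0, K1. This is why, for example, round key 2 has K0 in byte 2 and round key 4 has K4 in byte 2.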

The solution:

Since I'm an engineer and not a math person, my final solution was a (very optimized) brute force.

The general idea is that with only 4 key bytes it is possible to run the encryption and the decryption partially, far enough that you know certain state bytes in the middle rounds.
By combining the partial state you get through encryption with the partial state you get through decryption, it is possible to compute yet another state byte, which is otherwise unobtainable using encryption or decryption alone.
With the help of this extra byte (further called "blue"), it is possible to set up an equation for yet another key byte (further called "k4"). This makes a fifth key byte a function of the four guessed key bytes.
By recovering k4 it is then possible to compute even more state bytes (further called "yellow").
At this point we have one particular state byte computed in two ways: through partial decryption alone, and through a combination of partial decryption and partial encryption (yellow).

So here we can compare the byte we got through partial decryption to yellow, and we have a distinguisher!

For the correct key bytes, yellow is always equal to the byte obtained from decryption alone. So if they are not equal, we know that the 4-byte key guess was incorrect.

After verifying with our distinguisher that our 4-byte key guess is a good candidate (note that there are a lot of false positives here, thus only "candidate"), we can use blue once again to set up and solve an equation for yet another key byte ("k3").

Finally, by combining partial encryption and partial decryption, the remaining key bytes can be derived from the guessed ones: we brute-force 4 bytes and compute the matching other 2 key bytes.
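In the C++ tool further down, the 2^32 guesses are enumerated with a single 32-bit counter that gets spread over K0, K1, K2 and K5 (each of the 4 worker threads takes one quarter of the counter range), and the two derived bytes K3 and K4 are ORed in afterwards. A Python sketch of that bit layout; the function names are made up, and counter_for_key() is only there to check the mapping with the example key from the commented-out debug code:

```python
def candidate_from_counter(i):
    """Spread a 32-bit counter over the guessed bytes: bits 31..8 -> K2 K1 K0, bits 7..0 -> K5."""
    return (i >> 8) | ((i & 0xFF) << 40)

def full_key(candidate, k3, k4):
    """OR the two derived bytes into their positions (K3 = byte 3, K4 = byte 4)."""
    return candidate | (k4 << 32) | (k3 << 24)

def counter_for_key(key):
    """Inverse of candidate_from_counter() for the guessed bytes of a known key."""
    return ((key & 0xFFFFFF) << 8) | ((key >> 40) & 0xFF)

key = 0x5A4142022E09        # example key from the commented-out debug code
i = counter_for_key(key)
cand = candidate_from_counter(i)
k3, k4 = (key >> 24) & 0xFF, (key >> 32) & 0xFF
print(hex(i), hex(cand), hex(full_key(cand, k3, k4)))
```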

Here is a simple graphic to illustrate the attack:




It is best if you open the graphic in another window and look at it while reading along, as I will now explain what can be seen.

First, ignore every color and just look at the general overview drawn in black.
This is the fully unrolled Feistel cipher.
You can see that every round consists of 2 parts (A and B).
Part B' is just A copied over to B in the next round.
Part A' is computed as A'=F(A) xor B.
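One Feistel round and its inverse, as a Python sketch. toy_f() is a hypothetical stand-in for fbox(); the point is that undoing a round never needs to invert F, only to evaluate it again:

```python
import random

MASK32 = 0xFFFFFFFF

def toy_f(a, k):
    """Hypothetical stand-in for fbox(): any 32-bit function works here."""
    return ((a ^ k) * 0x9E3779B1 + 0x7F4A7C15) & MASK32

def round_enc(a, b, k):
    """A' = F(A) xor B and B' = A, the same update as in enc()."""
    return toy_f(a, k) ^ b, a

def round_dec(a_next, b_next, k):
    """Undo one round: A = B', then B = A' xor F(A)."""
    a = b_next
    return a, a_next ^ toy_f(a, k)

rng = random.Random(7)
a, b, k = rng.getrandbits(32), rng.getrandbits(32), rng.getrandbits(32)
print(round_dec(*round_enc(a, b, k), k) == (a, b))
```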

Now let's bring in some color. Note the green "filled boxes".
This is what we get "for free" without any computation, since we know the plaintext and the ciphertext.

Now we bring in some state.
Check the brown numbers below the input. We define the state as [7][6][5][4][3][2][1][0], where bytes 0-3 are A and bytes 4-7 are B.
We also split the state into 8 bytes in every round (where necessary).

Next up, we add the key to every round.
The key bytes are written in red on the right side of the F function, in between the states of each round.
These are the bytes of the master key needed in that round.
Notice that some of them have a green line below them.
These are the bytes we are going to guess initially,
namely K0, K1, K2 and K5.

Now we can do partial state transitions.
These are the red dots in the rounds.
If a red dot "floats on top" of a square indicating a byte of the state, we acquired this byte by encrypting the plaintext.
If the red dot is at the bottom of the square, we got the byte by decrypting the ciphertext.

Since the F function is very simple, you can basically just follow along by looking at the state.
Before I show you an example of how to place the red dots, let's first define how we address the state.
If I say enc_state_1_0, I mean the line in the picture prefixed by "1" and byte index "0", counting from right to left. So enc_state_1_0 is the rightmost byte in the line prefixed with "1". To index the leftmost byte in the same line, I say enc_state_1_7.
The "enc" part of this notation is not relevant for this blog post, but if you decide to read the final code, it tells you whether we reached the state by encrypting (going from top to bottom) or by decrypting (going from bottom to top).

Now an example of how to place the dots.
Let's take a look at the red dot at enc_state_0_1, which is the second byte from the right in the line prefixed with "0".
When encrypting, to set a byte in A we need to know:
  • The byte one position further right in A of the previous round, because F rotates its output left by one byte (here enc_state_plain_0)
  • The key byte below the byte from the previous bullet point (here K0)
    • Note: we know K0, because it has a green line under it
  • The byte at the same position in the (4-byte) word, but in B of the previous round
    (here enc_state_plain_5)
    • This is because of the xor
To set a byte in B we just copy A from the previous round (this is for encryption).
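Because F works byte by byte and then rotates by exactly one byte, each byte of A' depends on only three bytes. A Python sketch of this rule, checked against a full round; the four SBoxes are random permutations, not the ones from params.py, and the helper names are hypothetical. Bytes of A and B are numbered 0-3 here, so B[j] corresponds to state byte j+4 (for enc_state_0_1 the rule gives sboxs[0][0][P(0) ^ K(0)] ^ P(5), as in the C++ code):

```python
import random

rng = random.Random(28)
# Hypothetical SBoxes for one round: four random byte permutations, NOT the ones from params.py.
SBOXES = [rng.sample(range(256), 256) for _ in range(4)]

def byte(x, j):
    """Byte j of a 32-bit word (byte 0 is the least significant one)."""
    return (x >> (8 * j)) & 0xFF

def fbox(x, key):
    """Key addition, one SBox per byte, then rotate left by one byte."""
    t = x ^ key
    t = sum(SBOXES[j][byte(t, j)] << (8 * j) for j in range(4))
    return ((t << 8) | (t >> 24)) & 0xFFFFFFFF

def enc_byte(a, b, key, j):
    """Byte j of A' = F(A) xor B from three bytes: A[j-1], key[j-1] and B[j] (indices mod 4)."""
    src = (j - 1) % 4
    return SBOXES[src][byte(a, src) ^ byte(key, src)] ^ byte(b, j)

# Compare the byte rule with a full round for random inputs.
ok = True
for _ in range(1000):
    a, b, k = rng.getrandbits(32), rng.getrandbits(32), rng.getrandbits(32)
    full = fbox(a, k) ^ b
    ok = ok and all(enc_byte(a, b, k, j) == byte(full, j) for j in range(4))
print("byte rule matches full round:", ok)
```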


For decryption it is quite similar:
To set a byte in A we copy B from the next round (now going from bottom to top, but still counting rounds from top to bottom).

Here is an example of how to get to dec_state_5_5.
To set a byte in B we need to know:
  • The byte at the same position in the (4-byte) word, but in A of the next round
    (here dec_state_6_1)
  • The byte one position further right, in A of the same round (here dec_state_5_0)
    • Note: you compute A first, by copying it from B of the next round
  • The key byte below the byte from the previous bullet point (here K0 again)
Now we can set all the red dots we can reach from encryption and decryption with only the key bytes we know.
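The decryption rule as a Python sketch, again with hypothetical random SBoxes and helper names, checked against one encryption round. For dec_state_5_5 it reads C(1) xor sboxs[6][0][C(4) xor K0], which is what the C++ check_candidate() computes:

```python
import random

rng = random.Random(6)
# Hypothetical SBoxes for one round, NOT the ones from params.py.
SBOXES = [rng.sample(range(256), 256) for _ in range(4)]

def byte(x, j):
    """Byte j of a 32-bit word (byte 0 is the least significant one)."""
    return (x >> (8 * j)) & 0xFF

def fbox(x, key):
    """Key addition, one SBox per byte, then rotate left by one byte."""
    t = x ^ key
    t = sum(SBOXES[j][byte(t, j)] << (8 * j) for j in range(4))
    return ((t << 8) | (t >> 24)) & 0xFFFFFFFF

def dec_b_byte(a_next, b_next, key, j):
    """Byte j of the previous B: B[j] = A_next[j] xor S_{j-1}[A[j-1] xor key[j-1]].

    A itself is known because decryption first copies A = B_next.
    """
    src = (j - 1) % 4
    a_src = byte(b_next, src)
    return byte(a_next, j) ^ SBOXES[src][a_src ^ byte(key, src)]

ok = True
for _ in range(1000):
    a, b, k = rng.getrandbits(32), rng.getrandbits(32), rng.getrandbits(32)
    a_next, b_next = fbox(a, k) ^ b, a          # one encryption round
    ok = ok and all(dec_b_byte(a_next, b_next, k, j) == byte(b, j) for j in range(4))
print("decryption byte rule holds:", ok)
```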


Finally it's time to get blue!

Having computed the red dots from both sides, we see that we have:
  • dec_state_2_6 //also called r2
    • Which is the same as dec_state_1_2
  • enc_state_1_7 //also called r1
This together with K0 is what we need to compute blue (which is state_2_3).
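Written out as an equation (S_{r,j} is shorthand introduced here for sboxs[r][j], the j-th SBox of round r; round key 2 is the 0x11005544-style key whose byte 2 is K0), blue is:

$$\mathrm{blue} = \mathrm{state}_{2,3} = S_{2,2}\left[\mathrm{state}_{1,2} \oplus K_0\right] \oplus \mathrm{state}_{1,7} = S_{2,2}\left[r_2 \oplus K_0\right] \oplus r_1$$

This is the line computing blau in check_candidate().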


Time to get K4 

If we copy over blue (state_2_3) to state_3_7, we can compute K4.

Normally our formula is state_4_3 = SBox[ dec_state_3_2 xor K4 ] xor blue.
But because the SBox here is invertible, we can write
K4 = inv_SBox[ dec_state_4_3 xor blue ] xor dec_state_3_2
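Inverting an 8-bit SBox is just a lookup table, and solving for the key byte then costs a lookup and two XORs. A minimal Python sketch; the SBox is a random stand-in for sboxs[4][2] and recover_key_byte() is a hypothetical helper name:

```python
import random

rng = random.Random(4)
sbox = rng.sample(range(256), 256)   # random stand-in for sboxs[4][2], NOT the real one

# Inverse table, built the same way as inverted_sbox1 in main.cpp.
inv_sbox = [0] * 256
for x, y in enumerate(sbox):
    inv_sbox[y] = x

def recover_key_byte(inp, out, other):
    """Solve out == sbox[inp ^ k] ^ other for the key byte k."""
    return inv_sbox[out ^ other] ^ inp

# Round trip: pick a key byte, compute the output byte, recover the key byte.
k, inp, other = 0x41, rng.randrange(256), rng.randrange(256)
out = sbox[inp ^ k] ^ other
print(hex(recover_key_byte(inp, out, other)))
```

With inp = dec_state_3_2, out = dec_state_4_3 and other = blue this is the K4 formula above; K3 further down works the same way with sboxs[3][3].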

You can see the K4 in round 3 has a blue line below it, because we got it through blue.


Adding some orange
Now that we know K4, we can add an orange line below all K4 bytes and compute the orange dots the same way we computed the red dots.

In round 2 we see that we have 2 dots in one square (at state_2_1).
The orange dot we computed is called foo and the red dot we computed is called bar.
If foo != bar then we discard the key candidate.
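How many wrong guesses survive this check? Assuming that for a wrong 4-byte guess foo behaves like a uniformly random byte (an assumption, not something measured in this writeup), a wrong guess passes with probability 2^-8, which leaves roughly

$$2^{32} \cdot 2^{-8} = 2^{24} \approx 1.68 \times 10^{7}$$

false candidates. Each of them costs one extra full encryption of the first pair in the C++ tool, which is only a small fraction of the work spent on the 2^32 partial evaluations.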

Next we compute K3 similar to how we computed K4.
The equation is  
K3 = inv_SBox[dec_state_3_0 xor enc_state_2_4] xor blue


Last words

Finally, we write a C++ tool that brute-forces 4 bytes and, for every valid key candidate, computes the missing 2 bytes of the key, then verifies the key by encrypting the plaintexts of 100 plaintext/ciphertext pairs and comparing the results.
Add some communication with the gameserver and print the flag. Yay!

Attack takes just a few minutes with 4 threads running on a MacBook Pro 2015.

This challenge kept me busy all day, hope you enjoyed my solution.
I'm curious what the intended way to solve this would have been.

Cheers,
tihmstar



PS: Check out my exploit code below,
that is, if you manage to compile it, of course ;)



//
//  main.cpp
//  insomni-crypto
//
//  Created by tihmstar on 19.01.20.
//  Copyright © 2020 tihmstar. All rights reserved.
//

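#include <iostream>
#include "sboxes.h"   // the 28 SBoxes from params.py, indexed as sboxs[round][sbox][input]
#include <stdio.h>
#include <stdlib.h>   // malloc/calloc, exit

#include <sys/socket.h>
#include <arpa/inet.h>
#include <unistd.h>
#include <string.h>
#include <vector>
#include <future>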

using namespace std;

uint8_t inverted_sbox1[0x100];
uint8_t inverted_sbox2[0x100];

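// K(n): byte n of the (candidate) master key, i.e. K0..K5
#define K(byte) ((candidate >> (8*(byte))) & 0xff)
// P(n): byte n of the plaintext state (bytes 0..3 = A, 4..7 = B)
#define P(byte) ((plain >> (8*(byte))) & 0xff)
// C(n): byte n of the final state; the ciphertext is A<<32|B, so the halves are swapped
#define C(byte) ((cipher >> (8*(((byte)+4)%8))) & 0xff)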


uint64_t encrypt(uint64_t key, uint64_t plain){
    uint32_t A = plain & 0xFFFFFFFF;
    uint32_t B = plain >> 32;
  
    uint32_t keys[7];
    keys[6] = keys[3] = keys[0] = (key & 0xffffffff);
    keys[4] = keys[1] = ((key >> 16) & 0xffffffff);
    keys[5] = keys[2] = (uint32_t)(((key >> 32) & 0xffff) | ((key & 0xffff) << 16));

    for (int i=0; i<7; i++) {
        uint32_t y;
        {
            uint32_t tmp = A ^ keys[i];
            tmp = sboxs[i][0][0x000000FF & tmp] | sboxs[i][1][(0x0000FF00 & tmp) >> 8] << 8 | sboxs[i][2][(0x00FF0000 & tmp) >> 16] << 16 | sboxs[i][3][(0xFF000000 & tmp) >> 24] << 24;
            y = (tmp << 8) | (tmp >> 24);
        }
        uint32_t tmp = B ^ y;
        B = A;
        A = tmp;
    }

    return ((uint64_t)A << 32) | B;
}


uint64_t check_candidate(uint64_t candidate, uint64_t plain, uint64_t cipher){
    // Partial encryption from the plaintext side (uses K0, K2, K5)
    uint8_t enc_state_0_3 = sboxs[0][2][P(2) ^ K(2)] ^ P(7);
    uint8_t enc_state_1_7 = enc_state_0_3; //r1
    uint8_t enc_state_1_5 = sboxs[0][0][P(0) ^ K(0)] ^ P(5); //r6
    uint8_t enc_state_1_0 = sboxs[1][3][enc_state_0_3 ^ K(5)] ^ P(0); //r5
    uint8_t enc_state_2_4 = enc_state_1_0; //r7

    // Partial decryption from the ciphertext side (uses K0, K1, K2, K5)
    // Undoing round 6: input byte 2 with round-6 key byte 2 (= K2) goes through sboxs[6][2]
    uint8_t dec_state_5_7 = sboxs[6][2][C(6) ^ K(2)] ^ C(3);
    uint8_t dec_state_4_3 = dec_state_5_7; //r3

    uint8_t dec_state_5_5 = sboxs[6][0][C(4) ^ K(0)] ^ C(1);
    uint8_t dec_state_4_1 = dec_state_5_5;
    uint8_t dec_state_4_6 = sboxs[5][1][dec_state_4_1 ^ K(5)] ^ C(6);
    uint8_t dec_state_3_2 = dec_state_4_6; //r4

    uint8_t dec_state_4_4 = sboxs[5][3][dec_state_4_3 ^ K(1)] ^ C(4);
    uint8_t dec_state_3_0 = dec_state_4_4;
    // Undoing round 4: input byte 0 with round-4 key byte 0 (= K2) goes through sboxs[4][0]
    uint8_t dec_state_3_5 = sboxs[4][0][dec_state_3_0 ^ K(2)] ^ dec_state_4_1;
    uint8_t dec_state_2_1 = dec_state_3_5;
    uint8_t dec_state_2_6 = sboxs[3][1][dec_state_2_1 ^ K(1)] ^ dec_state_3_2;//r2

    // blue = state_2_3, only reachable by combining both sides
    uint8_t blau = sboxs[2][2][dec_state_2_6/*r2*/ ^ K(0)] ^ enc_state_1_7 /*r1*/;

    // Solve dec_state_4_3 = sboxs[4][2][dec_state_3_2 ^ K4] ^ blue for K4
    uint8_t k4 = inverted_sbox1[dec_state_4_3/*r3*/ ^ blau] ^ dec_state_3_2/*r4*/;

    // Distinguisher: state_2_1 computed forwards (foo) and backwards (bar)
    uint8_t foo = sboxs[2][0][enc_state_1_0/*r5*/ ^ k4] ^ enc_state_1_5/*r6*/;
    uint8_t bar = dec_state_2_1;

    if (foo != bar) return 0;

    // Solve dec_state_3_0 = sboxs[3][3][blue ^ K3] ^ enc_state_2_4 for K3
    uint8_t k3 = inverted_sbox2[dec_state_3_0/*r8*/ ^ enc_state_2_4/*r7*/] ^ blau;

    return candidate | ((uint64_t)k4 << 32) | ((uint64_t)k3 << 24);
}


#define PORT 1337
//#define HOST "127.0.0.1"
#define HOST "34.65.198.15"

int main(int argc, const char * argv[]) {
    printf("start\n");

    for (int i=0; i<0x100; i++) {
        inverted_sbox1[sboxs[4][2][i]] = i;
        inverted_sbox2[sboxs[3][3][i]] = i;
    }
  
    int sock = 0;
    size_t bufferSize = 0x100 * 50000;
    // zero-initialized, so that the first read below is always NUL-terminated
    char *buffer = (char *)calloc(bufferSize, 1);
    struct sockaddr_in serv_addr;
    if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0){
        printf("\n Socket creation error \n");
        return -1;
    }
    serv_addr.sin_family = AF_INET;
    serv_addr.sin_port = htons(PORT);
      
    if(inet_pton(AF_INET, HOST, &serv_addr.sin_addr)<=0){
        printf("\nInvalid address/ Address not supported \n");
        return -1;
    }
  
    if (connect(sock, (struct sockaddr *)&serv_addr, sizeof(serv_addr)) < 0){
        printf("\nConnection Failed \n");
        return -1;
    }
  
    read(sock, buffer, bufferSize-1);
  
    string chal = buffer;
    chal = chal.substr(chal.find("\"")+1);
    chal = chal.substr(0,chal.rfind("\""));
    printf("got chal='%s'\n",chal.c_str());
  
//#warning DEBUG
  
    std::string userinput;
    cin >> userinput;
  
    userinput += "\n";
  
    printf("got userinput: %s\n",userinput.c_str());
  
    write(sock, userinput.c_str(), userinput.size());
//    write(sock, "DEBUG\n", sizeof("DEBUG\n")-1);

    read(sock, buffer, bufferSize-1); //"Brace yourself, your favorite oracle has a challenge for you...\n"
  
    printf("reading input...\n");

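    // Read until the whole list of pairs and the final prompt have arrived.
    memset(buffer, 0, bufferSize);
    size_t didRead = 0;
    while (!strstr(buffer, "Can you recover the key?")) {
        ssize_t r = read(sock, buffer+didRead, bufferSize-1-didRead);
        if (r <= 0) {
            printf("connection closed or we read too much!\n");
            exit(1);
        }
        didRead += r;
        buffer[didRead] = '\0'; // read() does not NUL-terminate
        printf("reading more...\n");
    }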
  
    std::string input = buffer;

  
    printf("got input!\n");
  
  
  
    std::vector<pair<uint64_t, uint64_t>> pairs;
  
    ssize_t cpos = -1;
    while (pairs.size() < 100) {
      
        cpos = input.find("(",cpos+1);
        ssize_t ccomma = input.find(",",cpos);
        std::string plain = input.substr(cpos+1, ccomma-cpos-1);
        ssize_t closing = input.find(")",ccomma);
        std::string cipher = input.substr(ccomma+1, closing-ccomma-1);
      
        uint64_t pp = 0;
        uint64_t cc = 0;
      
        char *z = NULL;
      
        pp = strtoull(plain.c_str(),&z,10);
        cc = strtoull(cipher.c_str(),&z,10);

        pairs.push_back({pp,cc});
    }

    ssize_t newline = input.find("Can you recover the key?");
    input = input.substr(newline);
    printf("input=%s\n",input.c_str());
  
    printf("got 100 pairs!\n");
  
    bool keepRunning = true;
  
    int threadsCnt = 4;
  
    std::vector<future<uint64_t>> threads;
  
    for (int i=0; i<threadsCnt; i++){
        // each worker scans its own quarter of the 2^32 guesses for K0, K1, K2, K5
        threads.push_back(std::async(std::launch::async,[threadsCnt, &pairs, &keepRunning](int worker)->uint64_t{
            uint64_t iter = 0x100000000 / threadsCnt;
            uint64_t witer = iter*worker;

            uint64_t plain = pairs.at(0).first;
            uint64_t cipher = pairs.at(0).second;

            for (uint64_t z = 0; keepRunning && z<iter; z++) {
                uint64_t i = z + witer;

                if (z % 0x1000000 == 0) {
                    printf("[%d] cand=0x%012llx\n",worker,i);
                }

                // bits 31..8 of i -> K2 K1 K0, bits 7..0 -> K5
                uint64_t cand = (i >> 8) | ((i & 0xff) << 40);
                if (uint64_t candidate = check_candidate(cand, plain, cipher)){
//                    printf("key candidate=0x%012llx\n\n",candidate);

                    // the distinguisher has many false positives: confirm on all pairs
                    bool bad_key = false;
                    for (auto &p: pairs) {
                        if (encrypt(candidate, p.first) != p.second) {
                            bad_key = true;
                            break;
                        }
                    }

                    if (!bad_key) {
                        printf("GOT KEY 0x%012llx\n\n",candidate);
                        keepRunning = false;
                        return candidate;
                    }
                }
            }

            return 0;
        },i));
    }
  
    printf("waiting for workers to finish...\n");
    uint64_t key = 0;
    for (int i=0; i<threadsCnt; i++){
        threads[i].wait();
        uint64_t lk = threads[i].get();
        printf("lk[%d]=0x%012llx\n",i,lk);
        if (!key){
            if (lk) {
                key = lk;
                printf("FINALLY GOT KEY 0x%012llx\n\n",key);
            }
        }
    }
    printf("all workers finished!\n");

    memset(buffer, 0, bufferSize);
  
    snprintf(buffer, bufferSize, "%012llX\n",key);
    printf("sending: %s",buffer);
  
    write(sock, buffer, strlen(buffer));
  
    printf("response:\n");
    for (int i=0; i< 5; i++) {
        ssize_t s = read(sock, buffer, bufferSize-1);
        if (s <=0) {
            break;
        }
        buffer[s] = '\0'; // read() does not NUL-terminate
        printf("%s",buffer);
    }
  
      
//    /////// --- --- -- - -- - -
//    uint64_t key = 0x5a4142022e09;
//    uint64_t plain = 0x85d8888f97c18284;
//    uint64_t cipher = 0xf50f82a8f699904e;
//    printf("key=0x%012llx\n\n",key);
//
//
//    //begin
//    uint64_t i = ((key & 0xffffff) << 8) | ((key >> 40) & 0xff);
//    uint64_t cand = (i >> 8) | ((i & 0xff) << 40);
//
//    if (uint64_t candidate = check_candidate(cand, plain, cipher)){
//        printf("key candidate=0x%012llx\n\n",candidate);
//        uint64_t e = encrypt(candidate, plain);
//    }
  
    printf("done!\n");
    return 0;
}