Reducing Clojure Lambda Cold Starts Part 7 - More Realistic Workloads Java & JS

To summarize my previous posts: it's looking like my suspicion could be correct that ClojureScript Lambdas are much faster to initialize than Clojure Lambdas, but slower once initialized, at least for more computation-heavy workloads (or for my particular workload, anyway). This got me wondering whether the same is true for Java vs. JavaScript. Let's investigate.

Since neither Java nor JavaScript has built-in persistent (immutable) data structures, we'll need to change the Clojure/Script code to just sort the items rather than enqueue them in a priority queue. The Clojure/Script sort-by function does not sort in place (it returns a new sorted sequence), but hopefully its performance isn't much worse than the in-place sorting in Java and JavaScript:
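To make the in-place vs. copying distinction concrete, here is a small JavaScript sketch. The helper names sortInPlace and sortedCopy are hypothetical, not part of the project. Array.prototype.sort mutates and returns the same array, like Java's List.sort, while sorting a spread copy leaves the input untouched, which is closer to what Clojure's sort-by gives you. Only the array of references is copied; the item objects themselves are shared.

```javascript
// Hypothetical helpers contrasting the two sorting styles discussed above.
// Items are plain objects with numeric fields, sorted by `a`.
const compareByA = (left, right) => left.a - right.a;

// In place: mutates `items` and returns the same array object
// (like Java's List.sort or the JavaScript version in this post).
function sortInPlace(items) {
  return items.sort(compareByA);
}

// Copying: leaves `items` untouched and returns a new sorted array.
// Only the array of references is copied; the item objects are shared.
function sortedCopy(items) {
  return [...items].sort(compareByA);
}

const items = [{ a: 3 }, { a: 1 }, { a: 2 }];

const copy = sortedCopy(items);
console.log(copy.map((item) => item.a));  // [ 1, 2, 3 ]
console.log(items.map((item) => item.a)); // [ 3, 1, 2 ] (unchanged)

const same = sortInPlace(items);
console.log(same === items);              // true
console.log(items.map((item) => item.a)); // [ 1, 2, 3 ]
```

The copy costs one extra array allocation per sort, which is part of the overhead the Clojure/Script version pays here.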

src/cljc/tax/calcs.cljc:

(ns tax.calcs)

(defn calculate-aux [items]
  ;; mapv is eager, so time measures the actual calculation rather than just the creation of a lazy seq
  (mapv
   (fn [{:keys [a b c d] :as item}]
     (let [x (+ a b c d)
           y (/ x c)
           z (* y a b c d)]
       {:x x :y y :z z}))
   items))

(defn calculate [items]
  (prn "SORTING")
  (let [sorted (time (sort-by :a items))]
    (prn "CALCULATING")
    (time (calculate-aux sorted))))
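The mapv comment in calculate-aux matters because Clojure's map is lazy: wrapping a lazy map in time would mostly measure building an unrealized sequence, and the real work would happen later, whenever something consumed it. The following JavaScript sketch shows the same effect with a generator standing in for a lazy sequence. The lazyMap and countCallsDuring helpers are hypothetical, and a call counter is used instead of wall-clock time so the difference is deterministic.

```javascript
// Hypothetical illustration of why the Clojure code uses the eager mapv
// inside time instead of the lazy map.
let calls = 0;
const calculateItem = ({ a, b, c, d }) => {
  calls += 1;
  const x = a + b + c + d;
  const y = x / c;
  return { x, y, z: y * a * b * c * d };
};

// Lazy: calling the generator function does no work until it is iterated.
function* lazyMap(fn, items) {
  for (const item of items) {
    yield fn(item);
  }
}

// Runs `thunk` and reports how many items were calculated while it ran.
function countCallsDuring(thunk) {
  const before = calls;
  const result = thunk();
  return { result, callsDuring: calls - before };
}

const items = Array.from({ length: 1000 }, (_, i) => ({ a: i + 1, b: 2, c: 3, d: 4 }));

const lazy = countCallsDuring(() => lazyMap(calculateItem, items));
const eager = countCallsDuring(() => items.map(calculateItem));

console.log(lazy.callsDuring);  // 0: the work is deferred
console.log(eager.callsDuring); // 1000: the work happens inside the measurement

// Consuming the lazy sequence later performs the deferred work.
const realized = Array.from(lazy.result);
console.log(realized.length, calls); // 1000 2000
```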

Java

Now we'll create the roughly equivalent Java. First we need our data POJOs:

src/java/tax/Data.java:

package tax;

public class Data {

    private double a;
    private double b;
    private double c;
    private double d;

    //... getters/setters omitted
}

src/java/tax/Calc.java:

package tax;

public class Calc {

    private double x;
    private double y;
    private double z;

    public Calc(double x, double y, double z) {
        this.x = x;
        this.y = y;
        this.z = z;
    }

    //... getters/setters omitted
}

Then we need src/java/tax/Calcs.java:

package tax;


import java.util.Comparator;
import java.util.List;
import java.util.function.Supplier;

import java.util.stream.Collectors;

public class Calcs {

    private static Calc calculateItem(Data data) {
        double a = data.getA();
        double b = data.getB();
        double c = data.getC();
        double d = data.getD();
        double x = a + b + c + d;
        double y = x / c;
        double z = y * a * b * c * d;

        return new Calc(x, y, z);
    }

    public static <T> T time(Supplier<T> supplier) {
        long start = System.currentTimeMillis();
        T result = supplier.get();
        long end = System.currentTimeMillis();
        System.out.println(String.format("Elapsed time: %s msecs", end - start));
        return result;
    }

    public static List<Calc> calculate(List<Data> data) {
        System.out.println("SORTING");
        time(() -> { data.sort(Comparator.comparing(Data::getA)); return data; });

        System.out.println("CALCULATING");
        return time(() ->  data.stream().map(Calcs::calculateItem).collect(Collectors.toList()));
    }
}
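Since each version sorts with its own comparator (Comparator.comparing(Data::getA) here, sort-by :a in Clojure, and a subtraction comparator in JavaScript), it's worth noting that all three sorts are stable: the JDK's ArrayList.sort is a stable merge sort (TimSort), Clojure's sort-by docstring guarantees stability, and Array.prototype.sort has been required to be stable since ES2019. Items with equal a values therefore come out in the same order in every version, so the outputs stay comparable. A quick JavaScript check, using made-up ids:

```javascript
// Demonstrates that Array.prototype.sort is stable (required since ES2019,
// and the case in Node.js 12+): items with equal `a` keep their input order.
const compareItems = (left, right) => left.a - right.a;

const items = [
  { id: "first", a: 2 },
  { id: "second", a: 1 },
  { id: "third", a: 2 },
  { id: "fourth", a: 1 },
];

// Sort a copy so `items` keeps its original order for comparison.
const order = [...items].sort(compareItems).map((item) => item.id);
console.log(order.join(" ")); // second fourth first third
```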

And update our src/java/tax/core.java:

package tax;

import java.io.InputStream;
import java.io.IOException;

import java.nio.charset.StandardCharsets;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.core.JsonProcessingException;

import software.amazon.awssdk.services.s3.S3Client;

import software.amazon.awssdk.services.s3.model.GetObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.PutObjectResponse;

import software.amazon.awssdk.core.sync.RequestBody;

public class core {

    private static S3Client client = S3Client.builder().build();

    private static String outputBucket = System.getenv("CALCULATIONS_BUCKET");

    private static ObjectMapper mapper = new ObjectMapper();

    private PutObjectResponse putObject(String bucketName, String objectKey, String body) {
        PutObjectRequest req = PutObjectRequest.builder().bucket(bucketName).key(objectKey).build();
        return client.putObject(req, RequestBody.fromString(body));
    }

    private String getObjectAsString(String bucketName, String objectKey) {
        try {
            GetObjectRequest req = GetObjectRequest.builder().bucket(bucketName).key(objectKey).build();
            InputStream stream = client.getObjectAsBytes(req).asInputStream();
            return new String(stream.readAllBytes(), StandardCharsets.UTF_8);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    private static class Props {
        public String bucket;
        public String key;
    }

    private static Data readLine(String line) {
        try {
            return mapper.readValue(line, Data.class);
        } catch (JsonProcessingException e) {
            e.printStackTrace();
            return null;
        }
    }

    private static String writeItem(Calc item) {
        try {
            return mapper.writeValueAsString(item);
        } catch (JsonProcessingException e) {
            e.printStackTrace();
            return null;
        }
    }

    private static List<Data> toItems(String input) {
        return Stream.of(input.split("\n")).map(core::readLine).collect(Collectors.toList());           
    }

    private static String toJsonOutput(List<Calc> input) {
        return input.stream().map(core::writeItem).collect(Collectors.joining("\n"));
    }

    public Object calculationsHandler(Map<String, Object> event) throws JsonProcessingException {
        List records = (List)event.get("Records");
        Map<String, Object> record = (Map<String, Object>)records.get(0);

        String messageBody = (String)record.get("body");
        Props props = mapper.readValue(messageBody, Props.class);

        System.out.println("GETTING OBJECT");
        String input = Calcs.time(() -> getObjectAsString(props.bucket, props.key));

        System.out.println("PARSING INPUT");
        List<Data> inputLines = Calcs.time(() -> toItems(input));

        List<Calc> calculatedItems = Calcs.calculate(inputLines);

        System.out.println("CONVERTING TO OUTPUT");

        String outputString = Calcs.time(() -> toJsonOutput(calculatedItems));

        System.out.println("PUTTING TO OUTPUT");

        Calcs.time(() -> putObject(outputBucket, props.key, outputString));
        return event;
    }
}
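One behavioral difference worth knowing when porting toItems: Java's String.split discards trailing empty strings, so an input file that ends in a newline still yields only data lines, while JavaScript's String.prototype.split keeps the trailing empty string, and JSON.parse("") throws. I don't know whether the generated input files end with a newline; if they might, a JavaScript toItems along these lines (a sketch, not the code that was benchmarked) skips blank lines before parsing. It is synchronous, which still works with the post's time helper since that helper awaits whatever the function returns.

```javascript
// Sketch of a JavaScript toItems that tolerates a trailing newline
// (and any other blank lines) in the newline-delimited JSON input.
// Java's "a\nb\n".split("\n") yields ["a", "b"], but JavaScript's yields
// ["a", "b", ""], and JSON.parse("") throws a SyntaxError.
const toItems = (input) =>
  input
    .split("\n")
    .filter((line) => line.trim() !== "")
    .map((line) => JSON.parse(line));

const input = '{"a":1,"b":2,"c":3,"d":4}\n{"a":5,"b":6,"c":7,"d":8}\n';

console.log(input.split("\n").length); // 3 (the last element is "")
console.log(toItems(input).length);    // 2
```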

A few notes on the Java code here:

  • In real code, I would probably use the published Java SQS event type (SQSEvent from the aws-lambda-java-events library) rather than doing all the casting, but it added a bunch of transitive dependencies, and dependencies matter a lot when it comes to Lambda, so I decided to stick with the same dependencies as the Clojure version.
  • While Java is much more verbose than Clojure and took about five times as long to type out, I do get a somewhat cozier feeling from the strict static typing.
  • My exception handling is as bad as in the Clojure code I've been using, but production-grade error handling would just muddy the waters.
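For reference, the casting in calculationsHandler walks the standard SQS event shape that Lambda delivers: a Records array whose entries carry the message text in a body string. In this project, that body is itself JSON with bucket and key fields. The JavaScript sketch below pulls those out with basic validation. The parseCalculationRequests name and the sample event values are made up for illustration; only the Records, body, and eventSource fields come from the real event format.

```javascript
// Hypothetical helper: extracts the {bucket, key} requests from an SQS event.
// With BatchSize: 1 there is exactly one record, but this handles any number.
function parseCalculationRequests(event) {
  if (!event || !Array.isArray(event.Records)) {
    throw new Error("Expected an SQS event with a Records array");
  }
  return event.Records.map((record) => {
    const { bucket, key } = JSON.parse(record.body);
    if (typeof bucket !== "string" || typeof key !== "string") {
      throw new Error(`Malformed message body: ${record.body}`);
    }
    return { bucket, key };
  });
}

// Trimmed-down sample event; real SQS records carry more fields
// (messageId, receiptHandle, attributes, eventSourceARN, ...).
const sampleEvent = {
  Records: [
    {
      eventSource: "aws:sqs",
      body: JSON.stringify({ bucket: "example-transactions", key: "input-1.json" }),
    },
  ],
};

console.log(JSON.stringify(parseCalculationRequests(sampleEvent)));
// [{"bucket":"example-transactions","key":"input-1.json"}]
```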

JavaScript

Updating template.yml:

  RunCalculationsJS:
    Type: AWS::Serverless::Function
    Properties:
      FunctionName: !Sub "${AWS::StackName}-run-calcs-js"
      Handler: index.handler
      Runtime: nodejs14.x
      Timeout: 900
      MemorySize: 128
      Policies:
        - AWSLambdaBasicExecutionRole
        - S3ReadPolicy:
            BucketName: !Ref TransactionsBucket
        - S3WritePolicy:
            BucketName: !Ref CalculationsBucket
        - Version: '2012-10-17' 
          Statement:
            - Effect: Allow
              Action:
                - s3:ListAllMyBuckets
              Resource: 'arn:aws:s3:::*'
      Environment:
        Variables:
          TRANSACTIONS_BUCKET: !Ref TransactionsBucket
          CALCULATIONS_BUCKET: !Ref CalculationsBucket
      Events:
        SQSEvent:
          Type: SQS
          Properties:
            Queue: !GetAtt RunJavaScriptCalculationsQueue.Arn
            BatchSize: 1
      InlineCode: |
        const AWS = require('aws-sdk');
        const client = new AWS.S3();

        const outputBucket = process.env.CALCULATIONS_BUCKET;

        const getObjectAsString = async (Bucket, Key) => {
          const {Body} = await client.getObject({Bucket, Key}).promise();
          return Body.toString("utf-8");
        };

        const putObject = async (Bucket, Key, Body) => {
          return await client.putObject({Bucket, Key, Body}).promise();
        };

        const time = async(func) => {
          const start = Date.now();
          const result = await func();
          const end = Date.now();
          console.log(`Elapsed time: ${end - start} msecs`);
          return result;
        };

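        const toItems = async(input) => {
          // Wrap JSON.parse so map's extra (index, array) arguments aren't passed as a reviver.
          return input.split("\n").map((line) => JSON.parse(line));
        };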

        const compareItems = (a, b) => a.a - b.a;

        const calculateItem = ({a, b, c, d}) => {
          const x = a + b + c + d;
          const y = x / c;
          const z = y * a * b * c * d;
          return {x, y, z};
        };

        const calculate = async(inputLines) => {
          console.log("SORTING");
          // Array.prototype.sort sorts in place and returns the same array.
          const sortedLines = await time(() => inputLines.sort(compareItems));
          console.log("CALCULATING");
          return await time(() => sortedLines.map(calculateItem));
        };

        const toJsonOutput = async(items) => {
          // Wrap JSON.stringify so map's index and array aren't passed as replacer/space.
          return items.map((item) => JSON.stringify(item)).join("\n");
        };

        exports.handler = async function(event) {
          const {body} = event.Records[0];
          const {bucket, key} = JSON.parse(body);

          console.log("GETTING OBJECT");
          const input = await time(() => getObjectAsString(bucket, key));

          console.log("PARSING INPUT");
          const inputLines = await time(() => toItems(input));

          const calculatedItems =  await calculate(inputLines);

          console.log("CONVERTING TO OUTPUT");
          const outputString = await time(() => toJsonOutput(calculatedItems));

          console.log("PUTTING TO OUTPUT");
          return await time(() => putObject(outputBucket, key, outputString));
        }
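The time helpers in both the Java and JavaScript versions read the wall clock (System.currentTimeMillis() and Date.now()), which is fine for millisecond-scale phases but can be skewed if the system clock is adjusted mid-measurement. If monotonic or finer-grained timing is wanted, Java has System.nanoTime() and Node.js has process.hrtime.bigint(). Below is a sketch of the same async time helper using the latter; the timeMonotonic name is hypothetical.

```javascript
// Sketch of the post's async time helper using Node's monotonic
// high-resolution clock instead of Date.now().
const timeMonotonic = async (func) => {
  const start = process.hrtime.bigint(); // nanoseconds, as a BigInt
  const result = await func();
  const elapsedNs = process.hrtime.bigint() - start;
  // Convert to fractional milliseconds for a log line like the original.
  console.log(`Elapsed time: ${Number(elapsedNs) / 1e6} msecs`);
  return result;
};

// Usage: works with both synchronous and promise-returning functions.
(async () => {
  const sum = await timeMonotonic(() => [3, 1, 2].reduce((acc, n) => acc + n, 0));
  const later = await timeMonotonic(
    () => new Promise((resolve) => setTimeout(() => resolve("done"), 10))
  );
  console.log(sum, later); // 6 done
})();
```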

Running the SQS blaster to invoke each Lambda 1000 times, we get:

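| Lang | Avg Init (ms) | Avg Duration (ms) | Avg Warm Duration (ms) | Init Count | Invoke Count |
|------|---------------|-------------------|------------------------|------------|--------------|
| CLJ  | 5423.4405     | 1586.3698         | 1057.5869              | 22         | 1002         |
| CLJS | 497.2693      | 7218.8368         | 6985.2084              | 89         | 1003         |
| JAVA | 2532.0078     | 1090.0055         | 755.8679               | 18         | 1001         |
| JS   | 467.3094      | 961.8722          | 923.3941               | 32         | 1007         |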

Some things that stick out to me in these results:

  • The Java version was about 30% faster than the Clojure version for warm durations (755.87 ms vs. 1057.59 ms). That gap is actually smaller than I expected: I would have guessed Clojure carried a larger overhead, especially since the Clojure version uses persistent data structures and sorts into a new sequence while the Java version just sorts in place. There is also overhead in using a Clojure persistent map for each data item vs. a lightweight object. I might investigate using Clojure records and other optimizations to see how close to parity with the Java I can get the Clojure.
  • It's really surprising that the JavaScript version is so much faster than the ClojureScript version, about 7.6 times faster on warm invocations (923.39 ms vs. 6985.21 ms). I need to investigate what is going on there and perhaps try some optimizations.
  • I was surprised that the JavaScript version was so close to the Java one on warm duration, only about 22% slower (923.39 ms vs. 755.87 ms), and that once the init durations are averaged in, it actually came out slightly faster (961.87 ms vs. 1090.01 ms average duration). I would have assumed that sorting and computation would be a lot slower in JavaScript than in Java.
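Here is the arithmetic behind those comparisons, using the averages from the table. The amortized figure is my own rough way of folding init time into the per-invocation average (total init time spread over all invocations), an assumption for illustration rather than a metric Lambda reports.

```javascript
// Reported averages from the results table (milliseconds) and counts.
const results = {
  CLJ: { avgInit: 5423.4405, avgDuration: 1586.3698, avgWarm: 1057.5869, inits: 22, invokes: 1002 },
  CLJS: { avgInit: 497.2693, avgDuration: 7218.8368, avgWarm: 6985.2084, inits: 89, invokes: 1003 },
  JAVA: { avgInit: 2532.0078, avgDuration: 1090.0055, avgWarm: 755.8679, inits: 18, invokes: 1001 },
  JS: { avgInit: 467.3094, avgDuration: 961.8722, avgWarm: 923.3941, inits: 32, invokes: 1007 },
};

// Percentage by which `value` is lower (negative) or higher (positive) than `baseline`.
const percentDiff = (value, baseline) => ((value - baseline) / baseline) * 100;

// Rough per-invocation cost with total init time spread over every invocation.
// This is an assumption for illustration, not a metric Lambda reports.
const amortized = ({ avgInit, avgDuration, inits, invokes }) =>
  avgDuration + (avgInit * inits) / invokes;

console.log(percentDiff(results.JAVA.avgWarm, results.CLJ.avgWarm).toFixed(1)); // -28.5
console.log(percentDiff(results.JS.avgWarm, results.JAVA.avgWarm).toFixed(1)); // 22.2
console.log((results.CLJS.avgWarm / results.JS.avgWarm).toFixed(2)); // 7.56
console.log(percentDiff(results.JS.avgDuration, results.JAVA.avgDuration).toFixed(1)); // -11.8
console.log(amortized(results.JAVA).toFixed(1), amortized(results.JS).toFixed(1)); // 1135.5 976.7
```

Either way of averaging, JavaScript ends up ahead of Java overall in this run, even though Java wins on warm invocations.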

Conclusion

From these results, it seems that with Clojure/Script Lambdas we might be stuck with a tradeoff: either bad init times and decent run times, or good init times and bad run times. Of course, I need to gather more detailed metrics about where the computation time is spent in each version and see how much I can optimize each one. I'll investigate that in my next post.