
8.5. Inject logging

In this example, we use the BaseRewriter class for a different purpose: instead of removing nodes, we edit them to insert snippets of code. Specifically, at the start of every function body we insert a call to a fictitious log function, passing the name of the function as its argument to make it a bit more interesting.

In essence, the LoggingRewriter class below works as follows:

  1. When traversing a function definition, we collect its name and store it in the functionName field of the rewriter. Then we recurse into the definition using the rewriteChildren method of BaseRewriter. Before returning, the name is cleared.
  2. When traversing the list of statements of a function, if the function name is set, we first parse the new statement to inject: the call to the log function. Then we build the new statements of the function by prepending the injected node (note the unshift) to the existing children of the function's statements.
  3. Since the log function must come from somewhere, we also import it, appending the import at the end of the source unit members of the file. This only happens if at least one function was instrumented.
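The three steps can be illustrated without Slang on a toy tree. The sketch below is a simplified model under stated assumptions: the node types and the `ToyLoggingRewriter` class are hypothetical and are not Slang's CST API, statements are plain strings rather than parsed nodes, and the tree is rebuilt with object spreads instead of `rewriteChildren`. It resets the `injected` flag each time a new file is entered, so the same instance can be reused across files.

```typescript
// Toy model only: these types are hypothetical and are NOT Slang's CST API.
interface SourceUnitMembers { kind: "SourceUnitMembers"; items: Tree[] }
interface Contract { kind: "Contract"; name: string; members: Tree[] }
interface FunctionDefinition { kind: "FunctionDefinition"; name: string; body: Statements }
interface Statements { kind: "Statements"; items: string[] }
interface Leaf { kind: "Leaf"; text: string }
type Tree = SourceUnitMembers | Contract | FunctionDefinition | Statements | Leaf;

class ToyLoggingRewriter {
  private functionName: string | undefined;
  private injected = false;

  // Dispatch on the node kind, loosely mirroring BaseRewriter's per-kind hooks.
  rewrite(node: Tree): Tree {
    switch (node.kind) {
      case "SourceUnitMembers":
        return this.rewriteSourceUnitMembers(node);
      case "Contract":
        return { ...node, members: node.members.map((m) => this.rewrite(m)) };
      case "FunctionDefinition":
        return this.rewriteFunctionDefinition(node);
      case "Statements":
        return this.rewriteStatements(node);
      default:
        return node;
    }
  }

  // Step 1: remember the function name, recurse into the body, then clear it.
  private rewriteFunctionDefinition(node: FunctionDefinition): FunctionDefinition {
    this.functionName = node.name;
    const body = this.rewriteStatements(node.body);
    this.functionName = undefined;
    return { ...node, body };
  }

  // Step 2: inside a function, prepend the log call to a copy of the statements.
  private rewriteStatements(node: Statements): Statements {
    if (this.functionName === undefined) return node;
    this.injected = true;
    const items = [...node.items];
    items.unshift(`log("${this.functionName}");`);
    return { ...node, items };
  }

  // Step 3: rewrite the members, then append the import only if something was injected.
  private rewriteSourceUnitMembers(node: SourceUnitMembers): SourceUnitMembers {
    this.injected = false; // reset per file so one instance can be reused
    const items = node.items.map((m) => this.rewrite(m));
    if (!this.injected) return node;
    items.push({ kind: "Leaf", text: 'import { log } from "__logging.sol";' });
    return { ...node, items };
  }
}

// Usage: one contract with a state variable and a single function.
const input: Tree = {
  kind: "SourceUnitMembers",
  items: [
    {
      kind: "Contract",
      name: "Counter",
      members: [
        { kind: "Leaf", text: "uint _count;" },
        { kind: "FunctionDefinition", name: "count", body: { kind: "Statements", items: ["return _count;"] } },
      ],
    },
  ],
};
const rewriter = new ToyLoggingRewriter();
const output = rewriter.rewrite(input) as SourceUnitMembers;
```

After the rewrite, the body of `count` starts with `log("count");` and the import is the last member of the file, while `input` itself is left unmodified because every changed level is copied rather than mutated.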
logging-rewriter.mts
import { BaseRewriter, Edge, NonterminalKind, EdgeLabel, Node, NonterminalNode } from "@nomicfoundation/slang/cst";
import { Parser } from "@nomicfoundation/slang/parser";
import { LanguageFacts } from "@nomicfoundation/slang/utils";

export class LoggingRewriter extends BaseRewriter {
  functionName: string | undefined;
  injected = false;
  parser: Parser;

  constructor() {
    super();
    this.parser = Parser.create(LanguageFacts.latestVersion());
  }

  // collect the name of the function being traversed
  public override rewriteFunctionDefinition(node: NonterminalNode): Node | undefined {
    const name = node.children().find((edge) => edge.label === EdgeLabel.Name);
    if (!name) {
      return node;
    }

    this.functionName = name.node.unparse().trim();
    // the injection of code is actually performed during the recursion
    const recurse = this.rewriteChildren(node);
    this.functionName = undefined;
    return recurse;
  }

  // once in the statements of a function, inject a call to the `log` function.
  public override rewriteStatements(node: NonterminalNode): Node | undefined {
    if (this.functionName) {
      this.injected = true;

      // the injected code
      const toInject = this.parser.parseNonterminal(
        NonterminalKind.ExpressionStatement,
        `    log("${this.functionName}");\n`,
      ).tree;

      // inject the node at the beginning of statements, and return the new node containing it
      const children = node.children();
      children.unshift(Edge.createWithNonterminal(EdgeLabel.Item, toInject));
      return NonterminalNode.create(NonterminalKind.Statements, children);
    }
    return node;
  }

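  // after rewriting the members of the file, append the import of the `log` function
  public override rewriteSourceUnitMembers(node: NonterminalNode): Node | undefined {
    // reset the flag, since the same rewriter may be reused for several files
    this.injected = false;
    const newNode = this.rewriteChildren(node);

    if (!this.injected) {
      // no function body was instrumented: leave the file unchanged
      return node;
    }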

    const importMember = this.parser.parseNonterminal(
      NonterminalKind.SourceUnitMember,
      '\nimport { log } from "__logging.sol";\n',
    ).tree;
    const newChildren = newNode.children();
    newChildren.push(Edge.createWithNonterminal(EdgeLabel.Item, importMember));
    return NonterminalNode.create(NonterminalKind.SourceUnitMembers, newChildren);
  }
}

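Again, we test the functionality on the running Solidity example from Section 4. Note how the log call is inserted at the top of each function body, while the constructors and the onlyOwner modifier are left untouched, since they are not function definitions. The import of log is appended at the end of the file.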

test-logging-rewriter.mts
import assert from "node:assert";
import { CONTRACT_VFS } from "../../04-find-unused-definitions/examples/test-find-unused-definitions.test.mjs";
import { buildCompilationUnit } from "../../common/compilation-builder.mjs";
import { LoggingRewriter } from "./logging-rewriter.mjs";

const EXPECTED_VFS = new Map<string, string>([
  [
    "contract.sol",
    `
abstract contract Ownable {
  address _owner;
  constructor() {
    _owner = msg.sender;
  }
  modifier onlyOwner() {
    require(_owner == msg.sender);
    _;
  }
  function checkOwner(address addr) internal returns (bool) {
    log("checkOwner");
    return _owner == addr;
  }
}

contract Counter is Ownable {
  uint _count;
  uint _unused;
  constructor(uint initialCount) {
    _count = initialCount;
  }
  function count() public view returns (uint) {
    log("count");
    return _count;
  }
  function increment(uint delta, uint multiplier) public onlyOwner returns (uint) {
    log("increment");
    require(delta > 0, "Delta must be positive");
    _count += delta;
    return _count;
  }
  function unusedDecrement() private {
    log("unusedDecrement");
    require(checkOwner(msg.sender));
    _count -= 1;
  }
}

import { log } from "__logging.sol";
    `,
  ],
]);

test("inject logging", async () => {
  const unit = await buildCompilationUnit(CONTRACT_VFS, "0.8.0", "contract.sol");

  const loggingRewriter = new LoggingRewriter();
  for (const file of unit.files()) {
    const newNode = loggingRewriter.rewriteNode(file.tree);
    assert.strictEqual(newNode?.unparse(), EXPECTED_VFS.get(file.id));
  }
});
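The test above compares the whole unparsed file with an expected string. A looser check, which may help when adapting the rewriter to other inputs, is to verify textually that every `function` body starts with the matching `log` call and that the import is the last member of the file. The helpers `functionsMissingLog` and `endsWithLogImport` below are hypothetical and regex-based; they only handle simple, well-formatted sources like this example (for instance, parameter lists containing nested parentheses, or the word `function` inside comments, would confuse them).

```typescript
// Hypothetical helper, not part of Slang: list the functions whose body does
// not start with a matching log("<name>") call.
function functionsMissingLog(source: string): string[] {
  const missing: string[] = [];
  // matches "function <name>(<params>) <modifiers/returns> {"
  const header = /function\s+(\w+)\s*\([^)]*\)[^{]*\{/g;
  let match: RegExpExecArray | null;
  while ((match = header.exec(source)) !== null) {
    const name = match[1];
    const bodyStart = match.index + match[0].length;
    if (!source.slice(bodyStart).trimStart().startsWith(`log("${name}");`)) {
      missing.push(name);
    }
  }
  return missing;
}

// Hypothetical helper: is the import of `log` the last member of the file?
function endsWithLogImport(source: string): boolean {
  return source.trimEnd().endsWith('import { log } from "__logging.sol";');
}

// Usage on a small instrumented snippet.
const instrumented = `
contract C {
  function f(uint a) public returns (uint) {
    log("f");
    return a;
  }
}

import { log } from "__logging.sol";
`;
const missingInInstrumented = functionsMissingLog(instrumented);
```

Because constructors and modifiers do not start with the `function` keyword, they are not inspected, which matches the behavior of the rewriter.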