8.5. Inject logging
In this example, we use the `BaseRewriter` class to achieve something different: instead of removing nodes, we edit them to insert snippets of code. In particular, we insert into every function a call to a fictitious `log` function, passing the name of the function as its argument to make it a bit more interesting.
In essence, the `LoggingRewriter` class below works as follows:
- When traversing a function definition, we collect its name and store it in the `functionName` field of the rewriter. Then, we recurse into the function's children using the `rewriteChildren` function of `BaseRewriter`. Before returning, the name is cleared.
- When traversing the list of statements of a function (that is, while a function name is set), we start by parsing the new statement to inject: the call to the `log` function. Then, we construct the new statements of the function by prepending the injected code (note the `unshift`) to the children of the function statements.
- Since the `log` function has to come from somewhere, we also import it, appending its `import` at the end of the source unit members of the file, but only if at least one call was injected.
logging-rewriter.mts
import { BaseRewriter, Edge, NonterminalKind, EdgeLabel, Node, NonterminalNode } from "@nomicfoundation/slang/cst";
import { Parser } from "@nomicfoundation/slang/parser";
import { LanguageFacts } from "@nomicfoundation/slang/utils";
/**
 * Rewriter that prepends a call to a (fictitious) `log` function to the body of
 * every function definition, passing the function's name as a string literal.
 * When at least one call was injected into a file, it also appends
 * `import { log } from "__logging.sol";` after the file's source unit members.
 *
 * The per-file state (`injected`) is reset every time a new list of source unit
 * members is rewritten, so a single instance can be reused across several files.
 */
export class LoggingRewriter extends BaseRewriter {
  /** Name of the function currently being traversed, or `undefined` outside of any function. */
  functionName: string | undefined;
  /** Whether a `log` call was injected into the source unit currently being rewritten. */
  injected = false;
  /** Parser used to build the injected snippets, created for `LanguageFacts.latestVersion()`. */
  parser: Parser;

  constructor() {
    super();
    this.parser = Parser.create(LanguageFacts.latestVersion());
  }

  // collect the name of the function being traversed
  public override rewriteFunctionDefinition(node: NonterminalNode): Node | undefined {
    const name = node.children().find((edge) => edge.label === EdgeLabel.Name);
    if (!name) {
      return node;
    }

    // restore the previous value even if the recursion throws, so that a failed
    // rewrite does not leave a stale function name behind for later files
    const previousName = this.functionName;
    this.functionName = name.node.unparse().trim();
    try {
      // the recursion is where the injection of code is actually performed
      return this.rewriteChildren(node);
    } finally {
      this.functionName = previousName;
    }
  }

  // once in the statements of a function, inject a call to the `log` function.
  public override rewriteStatements(node: NonterminalNode): Node | undefined {
    if (!this.functionName) {
      // not inside a function (e.g. constructor or modifier bodies): leave untouched
      return node;
    }

    // the injected code
    const toInject = this.parser.parseNonterminal(
      NonterminalKind.ExpressionStatement,
      ` log("${this.functionName}");\n`,
    ).tree;

    // inject the node at the beginning of statements, and return the new node containing it
    const children = node.children();
    children.unshift(Edge.createWithNonterminal(EdgeLabel.Item, toInject));
    this.injected = true;
    return NonterminalNode.create(NonterminalKind.Statements, children);
  }

  // at the end of the file, inject the import of the `log` function.
  public override rewriteSourceUnitMembers(node: NonterminalNode): Node | undefined {
    // reset the per-file flag: without this, reusing the rewriter on a file with
    // no functions would still append the import injected for a previous file
    this.injected = false;
    const newNode = this.rewriteChildren(node);
    if (!this.injected) {
      // no call was injected in this file: no import needed
      return newNode;
    }

    const importMember = this.parser.parseNonterminal(
      NonterminalKind.SourceUnitMember,
      '\nimport { log } from "__logging.sol";\n',
    ).tree;
    const newChildren = newNode.children();
    newChildren.push(Edge.createWithNonterminal(EdgeLabel.Item, importMember));
    return NonterminalNode.create(NonterminalKind.SourceUnitMembers, newChildren);
  }
}
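Again, we test the functionality on the ongoing Solidity example from Section 4. Note how the `log` calls were inserted at the start of each function body, while the constructors and the modifier are left untouched because only function definitions are rewritten. Note also how the import was appended at the end of the file.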
test-logging-rewriter.mts
import assert from "node:assert";
import { CONTRACT_VFS } from "../../04-find-unused-definitions/examples/test-find-unused-definitions.test.mjs";
import { buildCompilationUnit } from "../../common/compilation-builder.mjs";
import { LoggingRewriter } from "./logging-rewriter.mjs";
// Expected output of the rewriter, keyed by file id: the `contract.sol` source from
// Section 4 with a `log("<function name>")` call prepended to each function body and
// the `log` import appended after the last source unit member. Constructors and the
// modifier have no injected call, since only function definitions are rewritten.
// NOTE(review): the template literal is compared with `unparse()` using strict
// equality, so its leading/trailing newlines and indentation must match exactly.
const EXPECTED_VFS = new Map<string, string>([
[
"contract.sol",
`
abstract contract Ownable {
address _owner;
constructor() {
_owner = msg.sender;
}
modifier onlyOwner() {
require(_owner == msg.sender);
_;
}
function checkOwner(address addr) internal returns (bool) {
log("checkOwner");
return _owner == addr;
}
}
contract Counter is Ownable {
uint _count;
uint _unused;
constructor(uint initialCount) {
_count = initialCount;
}
function count() public view returns (uint) {
log("count");
return _count;
}
function increment(uint delta, uint multiplier) public onlyOwner returns (uint) {
log("increment");
require(delta > 0, "Delta must be positive");
_count += delta;
return _count;
}
function unusedDecrement() private {
log("unusedDecrement");
require(checkOwner(msg.sender));
_count -= 1;
}
}
import { log } from "__logging.sol";
`,
],
]);
/**
 * Rewrites every file of the Section 4 compilation unit with `LoggingRewriter`
 * and compares the result with `EXPECTED_VFS`.
 *
 * Guards against vacuous passes: a file without an expected output, or a tree
 * removed by the rewriter, fails with an explicit message instead of comparing
 * `undefined` with `undefined`. The final check also ensures that every expected
 * file was actually visited.
 */
test("inject logging", async () => {
  const unit = await buildCompilationUnit(CONTRACT_VFS, "0.8.0", "contract.sol");
  const loggingRewriter = new LoggingRewriter();

  // ids of the files actually compared
  const checkedIds: string[] = [];
  for (const file of unit.files()) {
    const expected = EXPECTED_VFS.get(file.id);
    assert.ok(expected !== undefined, `no expected output for file "${file.id}"`);

    const newNode = loggingRewriter.rewriteNode(file.tree);
    assert.ok(newNode !== undefined, `rewriter removed the tree of file "${file.id}"`);
    assert.strictEqual(newNode.unparse(), expected);
    checkedIds.push(file.id);
  }

  // every expected file must have been rewritten and compared exactly once
  assert.deepStrictEqual(checkedIds.sort(), [...EXPECTED_VFS.keys()].sort());
});