题目链接

小红的能量校准

题目描述

小红和小紫在探险中遇到了“能量平衡门”。门上有一个包含变量 的线性方程。

  • 公式包含变量 、非负整数、运算符 +*()
  • 存在隐式乘法规则:数字、变量 或括号紧密相邻时表示乘法。例如:2(x+1)x3(x+1)22x
  • 方程中 的最高次幂为 1,且 在等号左侧恰好出现一次。
  • 等号右侧是一个常数(可能为负)。
  • 求解整数

输入:

  • 一行字符串 ,长度在 5 到 1000 之间。

输出:

  • 一个整数

解题思路

由于 在方程中只出现一次且最高次数为 1,等式左边的表达式可以抽象为一次函数形式 。我们的目标是求出系数 和常数项 ,然后解方程 ,即

  1. 预处理隐式乘法: 遍历字符串,在满足隐式乘法条件的位置插入显式的 * 运算符:

    • 数字后紧跟 x(
    • x 后紧跟数字或 (
    • ) 后紧跟数字、x(
  2. 表达式求值: 使用两个栈(操作数栈和操作符栈)进行中缀表达式求值。

    • 操作数栈存储结构体 Node {a, b},代表
    • 数字 转化为 Node {0, n}
    • 变量 转化为 Node {1, 0}
    • 加法
    • 乘法。由于 只出现一次,则 至少有一个为 0。
      • 若均为 0:
  3. 求解: 将左侧表达式化简为 后,代入右侧常数 ,计算

代码

#include <iostream>
#include <string>
#include <vector>
#include <stack>
#include <cctype>

using namespace std;

typedef long long LL;

struct Node {
    LL a, b; // ax + b
};

Node add(Node n1, Node n2) {
    return {n1.a + n2.a, n1.b + n2.b};
}

Node mul(Node n1, Node n2) {
    if (n1.a != 0) return {n1.a * n2.b, n1.b * n2.b};
    if (n2.a != 0) return {n2.a * n1.b, n1.b * n2.b};
    return {0, n1.b * n2.b};
}

int priority(char op) {
    if (op == '+') return 1;
    if (op == '*') return 2;
    return 0;
}

void compute(stack<Node>& nums, stack<char>& ops) {
    Node n2 = nums.top(); nums.pop();
    Node n1 = nums.top(); nums.pop();
    char op = ops.top(); ops.pop();
    if (op == '+') nums.push(add(n1, n2));
    else nums.push(mul(n1, n2));
}

string preprocess(string s) {
    string res = "";
    for (int i = 0; i < s.length(); ++i) {
        res += s[i];
        if (i + 1 < s.length()) {
            char curr = s[i], next = s[i + 1];
            bool c_d = isdigit(curr), n_d = isdigit(next);
            bool c_x = (curr == 'x'), n_x = (next == 'x');
            bool c_open = (curr == '('), n_open = (next == '(');
            bool c_close = (curr == ')'), n_close = (next == ')');
            
            if ((c_d && (n_x || n_open)) || (c_x && (n_d || n_open)) || (c_close && (n_d || n_x || n_open))) {
                res += '*';
            }
        }
    }
    return res;
}

int main() {
    string s;
    cin >> s;
    size_t eq_pos = s.find('=');
    string left = preprocess(s.substr(0, eq_pos));
    LL target = stoll(s.substr(eq_pos + 1));

    stack<Node> nums;
    stack<char> ops;

    for (int i = 0; i < left.length(); ++i) {
        if (isdigit(left[i])) {
            LL val = 0;
            while (i < left.length() && isdigit(left[i])) {
                val = val * 10 + (left[i++] - '0');
            }
            i--;
            nums.push({0, val});
        } else if (left[i] == 'x') {
            nums.push({1, 0});
        } else if (left[i] == '(') {
            ops.push('(');
        } else if (left[i] == ')') {
            while (ops.top() != '(') compute(nums, ops);
            ops.pop();
        } else {
            while (!ops.empty() && priority(ops.top()) >= priority(left[i])) compute(nums, ops);
            ops.push(left[i]);
        }
    }
    while (!ops.empty()) compute(nums, ops);

    Node res = nums.top();
    cout << (target - res.b) / res.a << endl;

    return 0;
}
import java.util.*;

public class Main {
    static class Node {
        long a, b;
        Node(long a, long b) { this.a = a; this.b = b; }
    }

    static int priority(char op) {
        if (op == '+') return 1;
        if (op == '*') return 2;
        return 0;
    }

    static void compute(Stack<Node> nums, Stack<Character> ops) {
        Node n2 = nums.pop();
        Node n1 = nums.pop();
        char op = ops.pop();
        if (op == '+') nums.push(new Node(n1.a + n2.a, n1.b + n2.b));
        else {
            if (n1.a != 0) nums.push(new Node(n1.a * n2.b, n1.b * n2.b));
            else if (n2.a != 0) nums.push(new Node(n2.a * n1.b, n1.b * n2.b));
            else nums.push(new Node(0, n1.b * n2.b));
        }
    }

    static String preprocess(String s) {
        StringBuilder res = new StringBuilder();
        for (int i = 0; i < s.length(); i++) {
            char curr = s.charAt(i);
            res.append(curr);
            if (i + 1 < s.length()) {
                char next = s.charAt(i + 1);
                boolean c_d = Character.isDigit(curr), n_d = Character.isDigit(next);
                boolean c_x = curr == 'x', n_x = next == 'x';
                boolean c_open = curr == '(', n_open = next == '(';
                boolean c_close = curr == ')', n_close = next == ')';
                if ((c_d && (n_x || n_open)) || (c_x && (n_d || n_open)) || (c_close && (n_d || n_x || n_open))) {
                    res.append('*');
                }
            }
        }
        return res.toString();
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        String s = sc.next();
        String[] parts = s.split("=");
        String left = preprocess(parts[0]);
        long target = Long.parseLong(parts[1]);

        Stack<Node> nums = new Stack<>();
        Stack<Character> ops = new Stack<>();

        for (int i = 0; i < left.length(); i++) {
            char c = left.charAt(i);
            if (Character.isDigit(c)) {
                long val = 0;
                while (i < left.length() && Character.isDigit(left.charAt(i))) {
                    val = val * 10 + (left.charAt(i++) - '0');
                }
                i--;
                nums.push(new Node(0, val));
            } else if (c == 'x') {
                nums.push(new Node(1, 0));
            } else if (c == '(') {
                ops.push('(');
            } else if (c == ')') {
                while (ops.peek() != '(') compute(nums, ops);
                ops.pop();
            } else {
                while (!ops.isEmpty() && priority(ops.peek()) >= priority(c)) compute(nums, ops);
                ops.push(c);
            }
        }
        while (!ops.isEmpty()) compute(nums, ops);
        Node res = nums.pop();
        System.out.println((target - res.b) / res.a);
    }
}
def solve():
    s = input().strip()
    left_raw, right_raw = s.split('=')
    target = int(right_raw)

    # 预处理隐式乘法
    left = ""
    for i in range(len(left_raw)):
        curr = left_raw[i]
        left += curr
        if i + 1 < len(left_raw):
            nxt = left_raw[i+1]
            if (curr.isdigit() and (nxt == 'x' or nxt == '(')) or \
               (curr == 'x' and (nxt.isdigit() or nxt == '(')) or \
               (curr == ')' and (nxt.isdigit() or nxt == 'x' or nxt == '(')):
                left += '*'

    def priority(op):
        return 2 if op == '*' else 1 if op == '+' else 0

    def compute(nums, ops):
        a2, b2 = nums.pop()
        a1, b1 = nums.pop()
        op = ops.pop()
        if op == '+':
            nums.append((a1 + a2, b1 + b2))
        else:
            if a1 != 0:
                nums.append((a1 * b2, b1 * b2))
            elif a2 != 0:
                nums.append((a2 * b1, b1 * b2))
            else:
                nums.append((0, b1 * b2))

    nums = []
    ops = []
    i = 0
    while i < len(left):
        if left[i].isdigit():
            val = 0
            while i < len(left) and left[i].isdigit():
                val = val * 10 + int(left[i])
                i += 1
            nums.append((0, val))
            continue
        elif left[i] == 'x':
            nums.append((1, 0))
        elif left[i] == '(':
            ops.append('(')
        elif left[i] == ')':
            while ops[-1] != '(':
                compute(nums, ops)
            ops.pop()
        else:
            while ops and priority(ops[-1]) >= priority(left[i]):
                compute(nums, ops)
            ops.append(left[i])
        i += 1

    while ops:
        compute(nums, ops)
    
    a, b = nums[0]
    print((target - b) // a)

if __name__ == "__main__":
    solve()

算法及复杂度

  • 算法:表达式解析(中缀表达式求值)。通过预处理将隐式乘法显式化,再利用双栈算法化简方程。
  • 时间复杂度:。其中 为字符串长度,每个字符仅被扫描和入栈出栈有限次数。
  • 空间复杂度:。用于存储预处理后的字符串及操作数、操作符栈。