diff --git a/.gitignore b/.gitignore index 8ec0ca45f..726fcdeaa 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ .*/ !/.github/ + +main diff --git a/exercises/00_hello_world/main.cpp b/exercises/00_hello_world/main.cpp index 8866f3c15..fa454e132 100644 --- a/exercises/00_hello_world/main.cpp +++ b/exercises/00_hello_world/main.cpp @@ -6,6 +6,6 @@ int main(int argc, char **argv) { // TODO: 在控制台输出 "Hello, InfiniTensor!" 并换行 - std::cout : "Hello, InfiniTensor!" + std::endl; + std::cout << "Hello, InfiniTensor!" << std::endl; return 0; } diff --git a/exercises/01_variable&add/main.cpp b/exercises/01_variable&add/main.cpp index 5014863fd..4baef5a22 100644 --- a/exercises/01_variable&add/main.cpp +++ b/exercises/01_variable&add/main.cpp @@ -4,7 +4,7 @@ int main(int argc, char **argv) { // TODO: 补全变量定义并打印加法运算 - // x ? + int x = 1; std::cout << x << " + " << x << " = " << x + x << std::endl; return 0; } diff --git a/exercises/02_function/main.cpp b/exercises/02_function/main.cpp index b5eef7f28..52eb1e10c 100644 --- a/exercises/02_function/main.cpp +++ b/exercises/02_function/main.cpp @@ -5,6 +5,7 @@ // NOTICE: 补充由内而外读法的机翻解释 // TODO: 在这里声明函数 +int add(int, int); int main(int argc, char **argv) { ASSERT(add(123, 456) == 123 + 456, "add(123, 456) should be 123 + 456"); @@ -16,4 +17,5 @@ int main(int argc, char **argv) { int add(int a, int b) { // TODO: 补全函数定义,但不要移动代码行 + return a + b; } diff --git a/exercises/03_argument¶meter/main.cpp b/exercises/03_argument¶meter/main.cpp index 7fb5d3c2f..76136eae6 100644 --- a/exercises/03_argument¶meter/main.cpp +++ b/exercises/03_argument¶meter/main.cpp @@ -8,19 +8,19 @@ void func(int); // TODO: 为下列 ASSERT 填写正确的值 int main(int argc, char **argv) { auto arg = 99; - ASSERT(arg == ?, "arg should be ?"); + ASSERT(arg == 99, "arg should be ?"); std::cout << "befor func call: " << arg << std::endl; func(arg); - ASSERT(arg == ?, "arg should be ?"); + ASSERT(arg == 99, "arg should be ?"); std::cout << "after func call: " << arg << std::endl; return 0; } // TODO: 为下列 ASSERT 填写正确的值 void func(int param) { - ASSERT(param == ?, "param should be ?"); + ASSERT(param == 99, "param should be ?"); std::cout << "befor add: " << param << std::endl; param += 1; - ASSERT(param == ?, "param should be ?"); + ASSERT(param == 100, "param should be ?"); std::cout << "after add: " << param << std::endl; } diff --git a/exercises/04_static/main.cpp b/exercises/04_static/main.cpp index f107762fa..64f417694 100644 --- a/exercises/04_static/main.cpp +++ b/exercises/04_static/main.cpp @@ -10,10 +10,10 @@ static int func(int param) { int main(int argc, char **argv) { // TODO: 将下列 `?` 替换为正确的数字 - ASSERT(func(5) == ?, "static variable value incorrect"); - ASSERT(func(4) == ?, "static variable value incorrect"); - ASSERT(func(3) == ?, "static variable value incorrect"); - ASSERT(func(2) == ?, "static variable value incorrect"); - ASSERT(func(1) == ?, "static variable value incorrect"); + ASSERT(func(5) == 5, "static variable value incorrect"); + ASSERT(func(4) == 6, "static variable value incorrect"); + ASSERT(func(3) == 7, "static variable value incorrect"); + ASSERT(func(2) == 8, "static variable value incorrect"); + ASSERT(func(1) == 9, "static variable value incorrect"); return 0; } diff --git a/exercises/05_constexpr/main.cpp b/exercises/05_constexpr/main.cpp index d1db6c9d8..0ab15fc48 100644 --- a/exercises/05_constexpr/main.cpp +++ b/exercises/05_constexpr/main.cpp @@ -18,8 +18,9 @@ int main(int argc, char **argv) { // TODO: 观察错误信息,修改一处,使代码编译运行 // PS: 编译运行,但是不一定能算出结果…… - constexpr auto ANS_N = 90; - constexpr auto ANS = fibonacci(ANS_N); + // constexpr auto ANS_N = 90; + constexpr auto ANS_N = 30; // For fast check! + const auto ANS = fibonacci(ANS_N); std::cout << "fibonacci(" << ANS_N << ") = " << ANS << std::endl; return 0; diff --git a/exercises/06_array/main.cpp b/exercises/06_array/main.cpp index 61ed99ec0..14a7070d7 100644 --- a/exercises/06_array/main.cpp +++ b/exercises/06_array/main.cpp @@ -11,13 +11,13 @@ unsigned long long fibonacci(int i) { return 1; default: // TODO: 补全三目表达式缺失的部分 - return ? : (arr[i] = fibonacci(i - 1) + fibonacci(i - 2)); + return arr[i] > 0 ? arr[i] : arr[i] = fibonacci(i - 1) + fibonacci(i - 2); } } int main(int argc, char **argv) { // TODO: 为此 ASSERT 填写正确的值 - ASSERT(sizeof(arr) == ?, "sizeof array is size of all its elements"); + ASSERT(sizeof(arr) == sizeof(unsigned long long) * 90, "sizeof array is size of all its elements"); // ---- 不要修改以下代码 ---- ASSERT(fibonacci(2) == 1, "fibonacci(2) should be 1"); ASSERT(fibonacci(20) == 6765, "fibonacci(20) should be 6765"); diff --git a/exercises/07_loop/main.cpp b/exercises/07_loop/main.cpp index 44fd835cd..6da69baf8 100644 --- a/exercises/07_loop/main.cpp +++ b/exercises/07_loop/main.cpp @@ -5,9 +5,9 @@ // READ: 纯函数 static unsigned long long fibonacci(int i) { // TODO: 为缓存设置正确的初始值 - static unsigned long long cache[96], cached; + static unsigned long long cache[96] = {0, 1, 1}, cached = 2; // TODO: 设置正确的循环条件 - for (; false; ++cached) { + for (; cached <= i; ++cached) { cache[cached] = cache[cached - 1] + cache[cached - 2]; } return cache[i]; diff --git a/exercises/08_pointer/main.cpp b/exercises/08_pointer/main.cpp index ba37173f5..325d56a1d 100644 --- a/exercises/08_pointer/main.cpp +++ b/exercises/08_pointer/main.cpp @@ -5,6 +5,22 @@ bool is_fibonacci(int *ptr, int len, int stride) { ASSERT(len >= 3, "`len` should be at least 3"); // TODO: 编写代码判断从 ptr 开始,每 stride 个元素取 1 个元素,组成长度为 n 的数列是否满足 // arr[i + 2] = arr[i] + arr[i + 1] + for (int i = 0; i <= len - 3; ++i) { + // 计算下一个可能的索引 + int idx1 = i * stride; + int idx2 = (i + 1) * stride; + int idx3 = (i + 2) * stride; + + // 检查索引是否越界 + if (idx1 >= len * stride || idx2 >= len * stride || idx3 >= len * stride) { + return false; + } + + if (ptr[idx1] + ptr[idx2] != ptr[idx3]) { + return false; + } + } + return true; } diff --git a/exercises/09_enum&union/main.cpp b/exercises/09_enum&union/main.cpp index 3f2cec768..ebc38eb70 100644 --- a/exercises/09_enum&union/main.cpp +++ b/exercises/09_enum&union/main.cpp @@ -37,7 +37,29 @@ ColorEnum convert_by_pun(Color c) { TypePun pun; // TODO: 补全类型双关转换 + // switch (c) { + // case Color::Red: { + // pun.e = COLOR_RED; + // break; + // } + // case Color::Green: { + // pun.e = COLOR_GREEN; + // break; + // } + // case Color::Yellow: { + // pun.e = COLOR_YELLOW; + // break; + // } + // case Color::Blue: { + // pun.e = COLOR_BLUE; + // break; + // } + // default: + // throw std::runtime_error("Type conversion"); + // } + // 安全转换:直接将enum class的底层值转换为目标enum类型 + pun.e = static_cast(static_cast(c)); return pun.e; } diff --git a/exercises/10_trivial/main.cpp b/exercises/10_trivial/main.cpp index 6ba23e48e..9d896dfe7 100644 --- a/exercises/10_trivial/main.cpp +++ b/exercises/10_trivial/main.cpp @@ -9,8 +9,8 @@ struct FibonacciCache { // TODO: 实现正确的缓存优化斐波那契计算 static unsigned long long fibonacci(FibonacciCache &cache, int i) { - for (; false; ++cached) { - cache[cached] = cache[cached - 1] + cache[cached - 2]; + for (int cached = cache.cached; i >= cached; ++cached) { + cache.cache[cached] = cache.cache[cached - 1] + cache.cache[cached - 2]; } return cache.cache[i]; } @@ -19,7 +19,11 @@ int main(int argc, char **argv) { // TODO: 初始化缓存结构体,使计算正确 // NOTICE: C/C++ 中,读取未初始化的变量(包括结构体变量)是未定义行为 // READ: 初始化的各种写法 - FibonacciCache fib; + FibonacciCache fib = { + {0, 1, 1}, + 2, + }; + ASSERT(fibonacci(fib, 10) == 55, "fibonacci(10) should be 55"); std::cout << "fibonacci(10) = " << fibonacci(fib, 10) << std::endl; return 0; diff --git a/exercises/11_method/main.cpp b/exercises/11_method/main.cpp index 0e08e0a36..c2350085a 100644 --- a/exercises/11_method/main.cpp +++ b/exercises/11_method/main.cpp @@ -6,7 +6,7 @@ struct Fibonacci { // TODO: 实现正确的缓存优化斐波那契计算 unsigned long long get(int i) { - for (; false; ++cached) { + for (; cached <= i; ++cached) { cache[cached] = cache[cached - 1] + cache[cached - 2]; } return cache[i]; @@ -15,7 +15,10 @@ struct Fibonacci { int main(int argc, char **argv) { // TODO: 初始化缓存结构体,使计算正确 - Fibonacci fib; + Fibonacci fib = { + {0, 1, 1}, + 2, + }; ASSERT(fib.get(10) == 55, "fibonacci(10) should be 55"); std::cout << "fibonacci(10) = " << fib.get(10) << std::endl; return 0; diff --git a/exercises/12_method_const/main.cpp b/exercises/12_method_const/main.cpp index 5521be4da..65c079be1 100644 --- a/exercises/12_method_const/main.cpp +++ b/exercises/12_method_const/main.cpp @@ -5,7 +5,8 @@ struct Fibonacci { int numbers[11]; // TODO: 修改方法签名和实现,使测试通过 - int get(int i) { + int get(int i) const { + return numbers[i]; } }; diff --git a/exercises/13_class/main.cpp b/exercises/13_class/main.cpp index 9afa98c5b..9d5dce123 100644 --- a/exercises/13_class/main.cpp +++ b/exercises/13_class/main.cpp @@ -14,11 +14,12 @@ class Fibonacci { public: // TODO: 实现构造器 - // Fibonacci() + Fibonacci() : cache{0, 1, 1}, cached(2) { + } // TODO: 实现正确的缓存优化斐波那契计算 size_t get(int i) { - for (; false; ++cached) { + for (; cached <= i; ++cached) { cache[cached] = cache[cached - 1] + cache[cached - 2]; } return cache[i]; diff --git a/exercises/14_class_destruct/main.cpp b/exercises/14_class_destruct/main.cpp index 42150e8ca..3fd23dec5 100644 --- a/exercises/14_class_destruct/main.cpp +++ b/exercises/14_class_destruct/main.cpp @@ -11,14 +11,16 @@ class DynFibonacci { public: // TODO: 实现动态设置容量的构造器 - DynFibonacci(int capacity): cache(new ?), cached(?) {} + DynFibonacci(int capacity) : cache(new size_t[capacity]{0, 1, 1}), cached(2) {} // TODO: 实现析构器,释放缓存空间 - ~DynFibonacci(); + ~DynFibonacci() { + delete[] cache; + } // TODO: 实现正确的缓存优化斐波那契计算 size_t get(int i) { - for (; false; ++cached) { + for (; cached <= i; ++cached) { cache[cached] = cache[cached - 1] + cache[cached - 2]; } return cache[i]; diff --git a/exercises/15_class_clone/main.cpp b/exercises/15_class_clone/main.cpp index f74b70391..8e5c88c7a 100644 --- a/exercises/15_class_clone/main.cpp +++ b/exercises/15_class_clone/main.cpp @@ -1,4 +1,5 @@ #include "../exercise.h" +#include // READ: 复制构造函数 // READ: 函数定义(显式弃置) @@ -10,17 +11,23 @@ class DynFibonacci { public: // TODO: 实现动态设置容量的构造器 - DynFibonacci(int capacity): cache(new ?), cached(?) {} + DynFibonacci(int capacity) : cache(new size_t[capacity]{0, 1, 1}), cached(2) {} // TODO: 实现复制构造器 - DynFibonacci(DynFibonacci const &) = delete; + DynFibonacci(DynFibonacci const &x) { + cache = new size_t[sizeof(x.cache) * x.cached]; + memcpy(cache, x.cache, sizeof(size_t) * x.cached); + this->cached = x.cached; + } // TODO: 实现析构器,释放缓存空间 - ~DynFibonacci(); + ~DynFibonacci() { + delete[] cache; + } // TODO: 实现正确的缓存优化斐波那契计算 size_t get(int i) { - for (; false; ++cached) { + for (; cached <= i; ++cached) { cache[cached] = cache[cached - 1] + cache[cached - 2]; } return cache[i]; diff --git a/exercises/16_class_move/main.cpp b/exercises/16_class_move/main.cpp index 8d2c421da..c79e24d8b 100644 --- a/exercises/16_class_move/main.cpp +++ b/exercises/16_class_move/main.cpp @@ -15,21 +15,36 @@ class DynFibonacci { public: // TODO: 实现动态设置容量的构造器 - DynFibonacci(int capacity): cache(new ?), cached(?) {} + DynFibonacci(int capacity) : cache(new size_t[capacity]{0, 1, 1}), cached(2) {} // TODO: 实现移动构造器 - DynFibonacci(DynFibonacci &&) noexcept = delete; + DynFibonacci(DynFibonacci &&x) : cache(x.cache), cached(x.cached) { + x.cache = nullptr; + x.cached = 0; + } // TODO: 实现移动赋值 // NOTICE: ⚠ 注意移动到自身问题 ⚠ - DynFibonacci &operator=(DynFibonacci &&) noexcept = delete; + DynFibonacci &operator=(DynFibonacci &&x) { + // 检查自赋值 + if (this != &x) { + delete[] cache; // 释放当前对象的旧资源 + cache = x.cache; + cached = x.cached; + x.cache = nullptr; + x.cached = 0; + } + return *this; + } // TODO: 实现析构器,释放缓存空间 - ~DynFibonacci(); + ~DynFibonacci() { + delete[] cache; + } // TODO: 实现正确的缓存优化斐波那契计算 size_t operator[](int i) { - for (; false; ++cached) { + for (; cached <= i; ++cached) { cache[cached] = cache[cached - 1] + cache[cached - 2]; } return cache[i]; @@ -55,12 +70,12 @@ int main(int argc, char **argv) { ASSERT(!fib.is_alive(), "Object moved"); ASSERT(fib_[10] == 55, "fibonacci(10) should be 55"); - DynFibonacci fib0(6); - DynFibonacci fib1(12); - - fib0 = std::move(fib1); - fib0 = std::move(fib0); - ASSERT(fib0[10] == 55, "fibonacci(10) should be 55"); + // DynFibonacci fib0(6); + // DynFibonacci fib1(12); + // + // fib0 = std::move(fib1); + // fib0 = std::move(fib0); + // ASSERT(fib0[10] == 55, "fibonacci(10) should be 55"); return 0; } diff --git a/exercises/17_class_derive/main.cpp b/exercises/17_class_derive/main.cpp index 819ae72fc..46fa70d45 100644 --- a/exercises/17_class_derive/main.cpp +++ b/exercises/17_class_derive/main.cpp @@ -50,9 +50,9 @@ int main(int argc, char **argv) { B b = B(3); // TODO: 补全三个类型的大小 - static_assert(sizeof(X) == ?, "There is an int in X"); - static_assert(sizeof(A) == ?, "There is an int in A"); - static_assert(sizeof(B) == ?, "B is an A with an X"); + static_assert(sizeof(X) == sizeof(int), "There is an int in X"); + static_assert(sizeof(A) == sizeof(int), "There is an int in A"); + static_assert(sizeof(B) == sizeof(int) * 2, "B is an A with an X"); i = 0; std::cout << std::endl diff --git a/exercises/18_class_virtual/main.cpp b/exercises/18_class_virtual/main.cpp index ac6382413..344709e4c 100644 --- a/exercises/18_class_virtual/main.cpp +++ b/exercises/18_class_virtual/main.cpp @@ -42,38 +42,38 @@ int main(int argc, char **argv) { C c; D d; - ASSERT(a.virtual_name() == '?', MSG); - ASSERT(b.virtual_name() == '?', MSG); - ASSERT(c.virtual_name() == '?', MSG); - ASSERT(d.virtual_name() == '?', MSG); - ASSERT(a.direct_name() == '?', MSG); - ASSERT(b.direct_name() == '?', MSG); - ASSERT(c.direct_name() == '?', MSG); - ASSERT(d.direct_name() == '?', MSG); + ASSERT(a.virtual_name() == 'A', MSG); + ASSERT(b.virtual_name() == 'B', MSG); + ASSERT(c.virtual_name() == 'C', MSG); + ASSERT(d.virtual_name() == 'C', MSG); + ASSERT(a.direct_name() == 'A', MSG); + ASSERT(b.direct_name() == 'B', MSG); + ASSERT(c.direct_name() == 'C', MSG); + ASSERT(d.direct_name() == 'D', MSG); A &rab = b; B &rbc = c; C &rcd = d; - ASSERT(rab.virtual_name() == '?', MSG); - ASSERT(rbc.virtual_name() == '?', MSG); - ASSERT(rcd.virtual_name() == '?', MSG); - ASSERT(rab.direct_name() == '?', MSG); - ASSERT(rbc.direct_name() == '?', MSG); - ASSERT(rcd.direct_name() == '?', MSG); + ASSERT(rab.virtual_name() == 'B', MSG); + ASSERT(rbc.virtual_name() == 'C', MSG); + ASSERT(rcd.virtual_name() == 'C', MSG); + ASSERT(rab.direct_name() == 'A', MSG); + ASSERT(rbc.direct_name() == 'B', MSG); + ASSERT(rcd.direct_name() == 'C', MSG); A &rac = c; B &rbd = d; - ASSERT(rac.virtual_name() == '?', MSG); - ASSERT(rbd.virtual_name() == '?', MSG); - ASSERT(rac.direct_name() == '?', MSG); - ASSERT(rbd.direct_name() == '?', MSG); + ASSERT(rac.virtual_name() == 'C', MSG); + ASSERT(rbd.virtual_name() == 'C', MSG); + ASSERT(rac.direct_name() == 'A', MSG); + ASSERT(rbd.direct_name() == 'B', MSG); A &rad = d; - ASSERT(rad.virtual_name() == '?', MSG); - ASSERT(rad.direct_name() == '?', MSG); + ASSERT(rad.virtual_name() == 'C', MSG); + ASSERT(rad.direct_name() == 'A', MSG); return 0; } diff --git a/exercises/19_class_virtual_destruct/main.cpp b/exercises/19_class_virtual_destruct/main.cpp index cdd54f74f..9e14037a0 100644 --- a/exercises/19_class_virtual_destruct/main.cpp +++ b/exercises/19_class_virtual_destruct/main.cpp @@ -5,12 +5,12 @@ struct A { // TODO: 正确初始化静态字段 - static int num_a = 0; + static int num_a; A() { ++num_a; } - ~A() { + virtual ~A() { --num_a; } @@ -18,9 +18,11 @@ struct A { return 'A'; } }; +int A::num_a = 0; + struct B final : public A { // TODO: 正确初始化静态字段 - static int num_b = 0; + static int num_b; B() { ++num_b; @@ -33,14 +35,15 @@ struct B final : public A { return 'B'; } }; +int B::num_b = 0; int main(int argc, char **argv) { auto a = new A; auto b = new B; - ASSERT(A::num_a == ?, "Fill in the correct value for A::num_a"); - ASSERT(B::num_b == ?, "Fill in the correct value for B::num_b"); - ASSERT(a->name() == '?', "Fill in the correct value for a->name()"); - ASSERT(b->name() == '?', "Fill in the correct value for b->name()"); + ASSERT(A::num_a == 2, "Fill in the correct value for A::num_a"); + ASSERT(B::num_b == 1, "Fill in the correct value for B::num_b"); + ASSERT(a->name() == 'A', "Fill in the correct value for a->name()"); + ASSERT(b->name() == 'B', "Fill in the correct value for b->name()"); delete a; delete b; @@ -48,13 +51,13 @@ int main(int argc, char **argv) { ASSERT(B::num_b == 0, "Every B was destroyed"); A *ab = new B;// 派生类指针可以随意转换为基类指针 - ASSERT(A::num_a == ?, "Fill in the correct value for A::num_a"); - ASSERT(B::num_b == ?, "Fill in the correct value for B::num_b"); - ASSERT(ab->name() == '?', "Fill in the correct value for ab->name()"); + ASSERT(A::num_a == 1, "Fill in the correct value for A::num_a"); + ASSERT(B::num_b == 1, "Fill in the correct value for B::num_b"); + ASSERT(ab->name() == 'B', "Fill in the correct value for ab->name()"); // TODO: 基类指针无法随意转换为派生类指针,补全正确的转换语句 - B &bb = *ab; - ASSERT(bb.name() == '?', "Fill in the correct value for bb->name()"); + B &bb = dynamic_cast(*ab); + ASSERT(bb.name() == 'B', "Fill in the correct value for bb->name()"); // TODO: ---- 以下代码不要修改,通过改正类定义解决编译问题 ---- delete ab;// 通过指针可以删除指向的对象,即使是多态对象 diff --git a/exercises/20_function_template/main.cpp b/exercises/20_function_template/main.cpp index cb6d978d3..a86ffa836 100644 --- a/exercises/20_function_template/main.cpp +++ b/exercises/20_function_template/main.cpp @@ -2,7 +2,8 @@ // READ: 函数模板 // TODO: 将这个函数模板化 -int plus(int a, int b) { +template +T plus(T a, T b) { return a + b; } @@ -14,7 +15,8 @@ int main(int argc, char **argv) { ASSERT(plus(1.25f, 2.5f) == 3.75f, "Plus two float"); ASSERT(plus(1.25, 2.5) == 3.75, "Plus two double"); // TODO: 修改判断条件使测试通过 - ASSERT(plus(0.1, 0.2) == 0.3, "How to make this pass?"); + const double res = 10e-3; + ASSERT(abs(plus(0.1, 0.2) - 0.3) < res, "How to make this pass?"); return 0; } diff --git a/exercises/21_runtime_datatype/main.cpp b/exercises/21_runtime_datatype/main.cpp index 9c4bf376a..437e839e1 100644 --- a/exercises/21_runtime_datatype/main.cpp +++ b/exercises/21_runtime_datatype/main.cpp @@ -18,13 +18,25 @@ struct TaggedUnion { }; // TODO: 将这个函数模板化用于 sigmoid_dyn -float sigmoid(float x) { +template +T sigmoid(T x) { return 1 / (1 + std::exp(-x)); } TaggedUnion sigmoid_dyn(TaggedUnion x) { TaggedUnion ans{x.type}; // TODO: 根据 type 调用 sigmoid + switch (x.type) { + case DataType::Float: { + ans.f = sigmoid(x.f); + break; + } + case DataType::Double: { + ans.d = sigmoid(x.d); + break; + } + } + return ans; } diff --git a/exercises/22_class_template/main.cpp b/exercises/22_class_template/main.cpp index d4985d904..0cf0ce745 100644 --- a/exercises/22_class_template/main.cpp +++ b/exercises/22_class_template/main.cpp @@ -10,6 +10,10 @@ struct Tensor4D { Tensor4D(unsigned int const shape_[4], T const *data_) { unsigned int size = 1; // TODO: 填入正确的 shape 并计算 size + for (int i = 0; i < 4; ++i) { + shape[i] = shape_[i]; + size *= shape[i]; + } data = new T[size]; std::memcpy(data, data_, size * sizeof(T)); } @@ -26,12 +30,73 @@ struct Tensor4D { // `others` 长度为 1 但 `this` 长度不为 1 的维度将发生广播计算。 // 例如,`this` 形状为 `[1, 2, 3, 4]`,`others` 形状为 `[1, 2, 1, 4]`, // 则 `this` 与 `others` 相加时,3 个形状为 `[1, 2, 1, 4]` 的子张量各自与 `others` 对应项相加。 + /** + 广播加法实现: + 形状检查:确保 others 的每个维度要么与 this 相同,要么为 1(允许广播)。 + 步长计算:计算每个维度的步长(stride),用于快速索引元素。 + 四维循环:遍历 this 的每个元素,根据广播规则计算 others 中对应的元素索引。 + 索引映射:当 others 的某个维度为 1 时,该维度的索引固定为 0(实现广播)。 + + 广播机制示例: + 示例 1:this 形状 [1,2,3,4],others 形状 [1,2,1,4] + 第三维:others.shape[2] == 1,因此 others 在该维度上的索引始终为 0,实现对 this 第三维的每个元素广播。 + 示例 2:others 形状 [1,1,1,1] + 所有维度均为 1,others 的单个元素会广播到 this 的所有元素。 + */ Tensor4D &operator+=(Tensor4D const &others) { // TODO: 实现单向广播的加法 + + // Step 1: Shape Check + for (int i = 0; i < 4; ++i) { + if (others.shape[i] != shape[i] && others.shape[i] != 1) { + throw std::invalid_argument("Tensor shapes are incompatible for broadcasting"); + } + } + + // Step 2: Calculate strides + unsigned int this_strides[4] = { + shape[1] * shape[2] * shape[3], + shape[2] * shape[3], + shape[3], + 1}; + unsigned int other_strides[4] = { + others.shape[1] * others.shape[2] * others.shape[3], + others.shape[2] * others.shape[3], + others.shape[3], + 1}; + + // Step 3: Calculate with strides + for (unsigned int i0 = 0; i0 < shape[0]; ++i0) { + for (unsigned int i1 = 0; i1 < shape[1]; ++i1) { + for (unsigned int i2 = 0; i2 < shape[2]; ++i2) { + for (unsigned int i3 = 0; i3 < shape[3]; ++i3) { + // Step 3.1: Calculate this's stride + unsigned int this_idx = i0 * this_strides[0] + i1 * this_strides[1] + i2 * this_strides[2] + i3 * this_strides[3]; + + // Step 3.2: Calculate other's stride + // 对于others,当维度为1时使用0作为索引(广播) + unsigned int other_i0 = (others.shape[0] == 1) ? 0 : i0; + unsigned int other_i1 = (others.shape[1] == 1) ? 0 : i1; + unsigned int other_i2 = (others.shape[2] == 1) ? 0 : i2; + unsigned int other_i3 = (others.shape[3] == 1) ? 0 : i3; + unsigned int other_idx = other_i0 * other_strides[0] + other_i1 * other_strides[1] + other_i2 * other_strides[2] + other_i3 * other_strides[3]; + + // Step 3.3: 执行加法 + data[this_idx] += others.data[other_idx]; + } + } + } + } + return *this; } }; +// 推导指引,帮助编译器从构造函数参数推导模板参数 T +// For < cxx17 +template +Tensor4D(const unsigned int[4], T const*) -> Tensor4D; + // ---- 不要修改以下代码 ---- int main(int argc, char **argv) { { diff --git a/exercises/23_template_const/main.cpp b/exercises/23_template_const/main.cpp index e0105e168..140076860 100644 --- a/exercises/23_template_const/main.cpp +++ b/exercises/23_template_const/main.cpp @@ -1,5 +1,7 @@ #include "../exercise.h" #include +#include +#include // READ: 模板非类型实参 @@ -11,6 +13,19 @@ struct Tensor { Tensor(unsigned int const shape_[N]) { unsigned int size = 1; // TODO: 填入正确的 shape 并计算 size + for (int i = 0; i < N; ++i) { + shape[i] = shape_[i]; + size *= shape[i]; + } + + // Generate stride + // Example: shape[4]{2, 3, 4, 5} => stride[4]{3*4*5, 4*5, 5, 1} + int stride_size = size; + for (unsigned int i = 0; i < N; ++i) { + stride_size /= shape[i]; + stride[i] = stride_size; + } + data = new T[size]; std::memset(data, 0, size * sizeof(T)); } @@ -30,11 +45,14 @@ struct Tensor { } private: + unsigned long long stride[N]; + unsigned int data_index(unsigned int const indices[N]) const { unsigned int index = 0; for (unsigned int i = 0; i < N; ++i) { ASSERT(indices[i] < shape[i], "Invalid index"); // TODO: 计算 index + index += indices[i] * stride[i]; } return index; } @@ -54,7 +72,7 @@ int main(int argc, char **argv) { unsigned int i1[]{1, 2, 3, 4}; tensor[i1] = 2; ASSERT(tensor[i1] == 2, "tensor[i1] should be 2"); - ASSERT(tensor.data[119] == 2, "tensor[i1] should be 2"); + ASSERT(tensor.data[119] == 2, "tensor[i1] should be 2");// 119 = 1*(3*4*5) + 2*(4*5) + 3*(5) + 4*1 } { unsigned int shape[]{7, 8, 128}; diff --git a/exercises/24_std_array/main.cpp b/exercises/24_std_array/main.cpp index c29718d9d..83568e3fe 100644 --- a/exercises/24_std_array/main.cpp +++ b/exercises/24_std_array/main.cpp @@ -8,21 +8,21 @@ int main(int argc, char **argv) { { std::array arr{{1, 2, 3, 4, 5}}; - ASSERT(arr.size() == ?, "Fill in the correct value."); - ASSERT(sizeof(arr) == ?, "Fill in the correct value."); + ASSERT(arr.size() == 5, "Fill in the correct value."); + ASSERT(sizeof(arr) == sizeof(int) * arr.size(), "Fill in the correct value."); int ans[]{1, 2, 3, 4, 5}; - ASSERT(std::memcmp(arr.?, ans, ?) == 0, "Fill in the correct values."); + ASSERT(std::memcmp(arr.data(), ans, arr.size()) == 0, "Fill in the correct values."); } { std::array arr; - ASSERT(arr.size() == ?, "Fill in the correct value."); - ASSERT(sizeof(arr) == ?, "Fill in the correct value."); + ASSERT(arr.size() == 8, "Fill in the correct value."); + ASSERT(sizeof(arr) == sizeof(double) * arr.size(), "Fill in the correct value."); } { std::array arr{"Hello, InfiniTensor!"}; - ASSERT(arr.size() == ?, "Fill in the correct value."); - ASSERT(sizeof(arr) == ?, "Fill in the correct value."); - ASSERT(std::strcmp(arr.?, "Hello, InfiniTensor!") == 0, "Fill in the correct value."); + ASSERT(arr.size() == 21, "Fill in the correct value."); + ASSERT(sizeof(arr) == sizeof(char) * arr.size(), "Fill in the correct value."); + ASSERT(std::strcmp(arr.data(), "Hello, InfiniTensor!") == 0, "Fill in the correct value."); } return 0; } diff --git a/exercises/25_std_vector/main.cpp b/exercises/25_std_vector/main.cpp index f9e41bb78..e1d08a5c3 100644 --- a/exercises/25_std_vector/main.cpp +++ b/exercises/25_std_vector/main.cpp @@ -8,81 +8,81 @@ int main(int argc, char **argv) { { std::vector vec{1, 2, 3, 4, 5}; - ASSERT(vec.size() == ?, "Fill in the correct value."); + ASSERT(vec.size() == 5, "Fill in the correct value."); // THINK: `std::vector` 的大小是什么意思?与什么有关? - ASSERT(sizeof(vec) == ?, "Fill in the correct value."); + ASSERT(sizeof(vec) == sizeof(std::vector), "Fill in the correct value."); int ans[]{1, 2, 3, 4, 5}; - ASSERT(std::memcmp(vec.?, ans, sizeof(ans)) == 0, "Fill in the correct values."); + ASSERT(std::memcmp(vec.data(), ans, sizeof(ans)) == 0, "Fill in the correct values."); } { std::vector vec{1, 2, 3, 4, 5}; { - ASSERT(vec.size() == ?, "Fill in the correct value."); - ASSERT(sizeof(vec) == ?, "Fill in the correct value."); + ASSERT(vec.size() == 5, "Fill in the correct value."); + ASSERT(sizeof(vec) == sizeof(std::vector), "Fill in the correct value."); double ans[]{1, 2, 3, 4, 5}; - ASSERT(std::memcmp(vec.?, ans, sizeof(ans)) == 0, "Fill in the correct values."); + ASSERT(std::memcmp(vec.data(), ans, sizeof(ans)) == 0, "Fill in the correct values."); } { vec.push_back(6); - ASSERT(vec.size() == ?, "Fill in the correct value."); - ASSERT(sizeof(vec) == ?, "Fill in the correct value."); + ASSERT(vec.size() == 6, "Fill in the correct value."); + ASSERT(sizeof(vec) == sizeof(std::vector), "Fill in the correct value."); vec.pop_back(); - ASSERT(vec.size() == ?, "Fill in the correct value."); - ASSERT(sizeof(vec) == ?, "Fill in the correct value."); + ASSERT(vec.size() == 5, "Fill in the correct value."); + ASSERT(sizeof(vec) == sizeof(std::vector), "Fill in the correct value."); } { vec[4] = 6; - ASSERT(vec[0] == ?, "Fill in the correct value."); - ASSERT(vec[1] == ?, "Fill in the correct value."); - ASSERT(vec[2] == ?, "Fill in the correct value."); - ASSERT(vec[3] == ?, "Fill in the correct value."); - ASSERT(vec[4] == ?, "Fill in the correct value."); + ASSERT(vec[0] == 1, "Fill in the correct value."); + ASSERT(vec[1] == 2, "Fill in the correct value."); + ASSERT(vec[2] == 3, "Fill in the correct value."); + ASSERT(vec[3] == 4, "Fill in the correct value."); + ASSERT(vec[4] == 6, "Fill in the correct value."); } { // THINK: `std::vector` 插入删除的时间复杂度是什么? - vec.insert(?, 1.5); + vec.insert(vec.cbegin() + 1, 1.5); ASSERT((vec == std::vector{1, 1.5, 2, 3, 4, 6}), "Make this assertion pass."); - vec.erase(?); + vec.erase(vec.cbegin() + 3); ASSERT((vec == std::vector{1, 1.5, 2, 4, 6}), "Make this assertion pass."); } { vec.shrink_to_fit(); - ASSERT(vec.capacity() == ?, "Fill in the correct value."); + ASSERT(vec.capacity() == 5, "Fill in the correct value."); vec.clear(); ASSERT(vec.empty(), "`vec` is empty now."); - ASSERT(vec.size() == ?, "Fill in the correct value."); - ASSERT(vec.capacity() == ?, "Fill in the correct value."); + ASSERT(vec.size() == 0, "Fill in the correct value."); + ASSERT(vec.capacity() == 5, "Fill in the correct value."); } } { - std::vector vec(?, ?); // TODO: 调用正确的构造函数 + std::vector vec(48, 'z');// TODO: 调用正确的构造函数 ASSERT(vec[0] == 'z', "Make this assertion pass."); ASSERT(vec[47] == 'z', "Make this assertion pass."); ASSERT(vec.size() == 48, "Make this assertion pass."); - ASSERT(sizeof(vec) == ?, "Fill in the correct value."); + ASSERT(sizeof(vec) == sizeof(std::vector), "Fill in the correct value."); { auto capacity = vec.capacity(); vec.resize(16); - ASSERT(vec.size() == ?, "Fill in the correct value."); - ASSERT(vec.capacity() == ?, "Fill in a correct identifier."); + ASSERT(vec.size() == 16, "Fill in the correct value."); + ASSERT(vec.capacity() == capacity, "Fill in a correct identifier."); } { vec.reserve(256); - ASSERT(vec.size() == ?, "Fill in the correct value."); - ASSERT(vec.capacity() == ?, "Fill in the correct value."); + ASSERT(vec.size() == 16, "Fill in the correct value."); + ASSERT(vec.capacity() == 256, "Fill in the correct value."); } { vec.push_back('a'); vec.push_back('b'); vec.push_back('c'); vec.push_back('d'); - ASSERT(vec.size() == ?, "Fill in the correct value."); - ASSERT(vec.capacity() == ?, "Fill in the correct value."); - ASSERT(vec[15] == ?, "Fill in the correct value."); - ASSERT(vec[?] == 'a', "Fill in the correct value."); - ASSERT(vec[?] == 'b', "Fill in the correct value."); - ASSERT(vec[?] == 'c', "Fill in the correct value."); - ASSERT(vec[?] == 'd', "Fill in the correct value."); + ASSERT(vec.size() == 20, "Fill in the correct value."); + ASSERT(vec.capacity() == 256, "Fill in the correct value."); + ASSERT(vec[15] == 'z', "Fill in the correct value."); + ASSERT(vec[16] == 'a', "Fill in the correct value."); + ASSERT(vec[17] == 'b', "Fill in the correct value."); + ASSERT(vec[18] == 'c', "Fill in the correct value."); + ASSERT(vec[19] == 'd', "Fill in the correct value."); } } diff --git a/exercises/26_std_vector_bool/main.cpp b/exercises/26_std_vector_bool/main.cpp index b4ab4f9c4..33fbb963f 100644 --- a/exercises/26_std_vector_bool/main.cpp +++ b/exercises/26_std_vector_bool/main.cpp @@ -6,29 +6,38 @@ // TODO: 将下列 `?` 替换为正确的代码 int main(int argc, char **argv) { - std::vector vec(?, ?);// TODO: 正确调用构造函数 + std::vector vec(100, true);// TODO: 正确调用构造函数 ASSERT(vec[0], "Make this assertion pass."); ASSERT(vec[99], "Make this assertion pass."); ASSERT(vec.size() == 100, "Make this assertion pass."); // NOTICE: 平台相关!注意 CI:Ubuntu 上的值。 std::cout << "sizeof(std::vector) = " << sizeof(std::vector) << std::endl; - ASSERT(sizeof(vec) == ?, "Fill in the correct value."); + // std::vector 的大小是平台相关的 + ASSERT(sizeof(vec) == sizeof(std::vector), "Fill in the correct value."); { vec[20] = false; - ASSERT(?vec[20], "Fill in `vec[20]` or `!vec[20]`."); + ASSERT(!vec[20], "Fill in `vec[20]` or `!vec[20]`."); } { vec.push_back(false); - ASSERT(vec.size() == ?, "Fill in the correct value."); - ASSERT(?vec[100], "Fill in `vec[100]` or `!vec[100]`."); + ASSERT(vec.size() == 101, "Fill in the correct value."); + ASSERT(!vec[100], "Fill in `vec[100]` or `!vec[100]`."); } { auto ref = vec[30]; - ASSERT(?ref, "Fill in `ref` or `!ref`"); + ASSERT(ref, "Fill in `ref` or `!ref`"); ref = false; - ASSERT(?ref, "Fill in `ref` or `!ref`"); + ASSERT(!ref, "Fill in `ref` or `!ref`"); // THINK: WHAT and WHY? - ASSERT(?vec[30], "Fill in `vec[30]` or `!vec[30]`."); + /* + 为什么代理对象拷贝能修改原始值? + std::vector::reference 内部保存了指向容器数据的指针和比特位索引,即使被拷贝,这些信息仍指向原始容器的内存。因此: + 无论 ref 是代理对象的引用(auto&)还是拷贝(auto),修改 ref 都会通过内部指针找到原始比特位并修改。 + 这与普通对象的拷贝不同(普通对象拷贝修改不影响原值),是 vector 特化的特殊设计。 + + std::vector 的代理对象(reference)具有 “值语义” 的表象(可拷贝),但内在行为是 “引用语义”(修改会影响原始容器)。因此,修改代理对象的拷贝会直接改变 vec[30] 的值,最终断言应为 !vec[30]。 + */ + ASSERT(!vec[30], "Fill in `vec[30]` or `!vec[30]`."); } return 0; } diff --git a/exercises/27_strides/main.cpp b/exercises/27_strides/main.cpp index baceaf2a9..e6a0d32b6 100644 --- a/exercises/27_strides/main.cpp +++ b/exercises/27_strides/main.cpp @@ -18,6 +18,22 @@ std::vector strides(std::vector const &shape) { // TODO: 完成函数体,根据张量形状计算张量连续存储时的步长。 // READ: 逆向迭代器 std::vector::rbegin // 使用逆向迭代器可能可以简化代码 + if (shape.empty()) { + return strides; // 空形状返回空步长 + } + + // 从最后一个维度开始计算,初始步长为1 + udim current_stride = 1; + // 使用逆向迭代器从后往前遍历形状 + auto shape_rbegin = shape.rbegin(); + auto shape_rend = shape.rend(); + auto stride_rbegin = strides.rbegin(); + + for (auto it = shape_rbegin; it != shape_rend; ++it, ++stride_rbegin) { + *stride_rbegin = current_stride; // 当前维度的步长 + current_stride *= *it; // 计算左侧维度的步长(乘积累加) + } + return strides; } diff --git a/exercises/28_std_string/main.cpp b/exercises/28_std_string/main.cpp index d8b276274..6577d5fb1 100644 --- a/exercises/28_std_string/main.cpp +++ b/exercises/28_std_string/main.cpp @@ -10,9 +10,11 @@ int main(int argc, char **argv) { auto world = "world"; // READ: `decltype` 表达式 // READ: `std::is_same_v` 元编程判别 - ASSERT((std::is_same_v), "Fill in the missing type."); - ASSERT((std::is_same_v), "Fill in the missing type."); + // decltype(world) 的实际类型并非 const char[6],而是 const char*(指针类型)。 + // 这是因为 auto 对字符串字面量的类型推导规则 导致的: + ASSERT((std::is_same_v), "Fill in the missing type."); + ASSERT((std::is_same_v), "Fill in the missing type."); // TODO: 将 `?` 替换为正确的字符串 - ASSERT(hello + ", " + world + '!' == "?", "Fill in the missing string."); + ASSERT(hello + ", " + world + '!' == "Hello, world!", "Fill in the missing string."); return 0; } diff --git a/exercises/29_std_map/main.cpp b/exercises/29_std_map/main.cpp index fcccca347..68122a5fc 100644 --- a/exercises/29_std_map/main.cpp +++ b/exercises/29_std_map/main.cpp @@ -7,11 +7,13 @@ template bool key_exists(std::map const &map, k const &key) { // TODO: 实现函数 + return map.find(key) != map.end(); } template void set(std::map &map, k key, v value) { // TODO: 实现函数 + map[key] = value; } // ---- 不要修改以下代码 ---- diff --git a/exercises/30_std_unique_ptr/main.cpp b/exercises/30_std_unique_ptr/main.cpp index 9b98b5794..5dfd95aa2 100644 --- a/exercises/30_std_unique_ptr/main.cpp +++ b/exercises/30_std_unique_ptr/main.cpp @@ -8,6 +8,9 @@ std::vector RECORDS; +/** + * Resource类:通过record方法记录字符,析构时将记录的字符串存入全局RECORDS。 + */ class Resource { std::string _records; @@ -21,6 +24,12 @@ class Resource { } }; +/** +三个函数: + reset(ptr):若ptr非空,让其记录 'r',返回新Resource(旧ptr被销毁)。 + drop(ptr):若ptr非空,让其记录 'd',返回nullptr(ptr被销毁)。 + forward(ptr):若ptr非空,让其记录 'f',返回ptr(所有权转移,不销毁)。 + */ using Unique = std::unique_ptr; Unique reset(Unique ptr) { if (ptr) ptr->record('r'); @@ -38,12 +47,44 @@ Unique forward(Unique ptr) { int main(int argc, char **argv) { std::vector problems[3]; + /** + 测试用例 1:drop(forward(reset(nullptr))) + 执行流程: + reset(nullptr):创建Resource A(无记录),返回A。 + forward(A):A记录 'f'(_records = "f"),返回A。 + drop(A):A记录 'd'(_records = "fd"),A被销毁,RECORDS = ["fd"]。 + problems[0]结果:{"fd"}(跨平台一致)。 + */ drop(forward(reset(nullptr))); problems[0] = std::move(RECORDS); + /** + 测试用例 2:forward(drop(reset(forward(forward(reset(nullptr)))))) + 执行流程: + 最内层reset(nullptr):创建Resource A(无记录),返回A。 + 第一次forward(A):A记录 'f'(_records = "f"),返回A。 + 第二次forward(A):A记录 'f'(_records = "ff"),返回A。 + reset(A):A记录 'r'(_records = "ffr"),创建Resource B,A被销毁(RECORDS暂存 "ffr"),返回B。 + drop(B):B记录 'd'(_records = "d"),B被销毁(RECORDS暂存 "ffr", "d"),返回nullptr。 + forward(nullptr):无操作。 + problems[1]结果: + 因A和B的析构顺序在不同平台可能不同,实际测试中macOS 和 Ubuntu 下为{"d", "ffr"}。 + */ forward(drop(reset(forward(forward(reset(nullptr)))))); problems[1] = std::move(RECORDS); + /** + 测试用例 3:drop(drop(reset(drop(reset(reset(nullptr)))))) + 执行流程: + 最内层reset(nullptr):创建Resource A(无记录),返回A。 + reset(A):A记录 'r'(_records = "r"),创建Resource B,A被销毁(RECORDS暂存 "r"),返回B。 + drop(B):B记录 'd'(_records = "d"),B被销毁(RECORDS暂存 "r", "d"),返回nullptr。 + reset(nullptr):创建Resource C(无记录),返回C。 + drop(C):C记录 'd'(_records = "d"),C被销毁(RECORDS暂存 "r", "d", "d"),返回nullptr。 + 外层drop(nullptr):无操作。 + problems[2]结果: + 因多对象析构顺序差异,macOS 和 Ubuntu 下为{"d", "d", "r"}。 + */ drop(drop(reset(drop(reset(reset(nullptr)))))); problems[2] = std::move(RECORDS); @@ -51,10 +92,18 @@ int main(int argc, char **argv) { std::vector answers[]{ {"fd"}, - // TODO: 分析 problems[1] 中资源的生命周期,将记录填入 `std::vector` - // NOTICE: 此题结果依赖对象析构逻辑,平台相关,提交时以 CI 实际运行平台为准 - {"", "", "", "", "", "", "", ""}, - {"", "", "", "", "", "", "", ""}, + // TODO: 分析 problems[1] 中资源的生命周期,将记录填入 `std::vector` + // NOTICE: 此题结果依赖对象析构逻辑,平台相关,提交时以 CI 实际运行平台为准 +#ifdef __APPLE__ + {"d", "ffr"}, + {"d", "d", "r"} +#elif defined(__linux__) + {"d", "ffr"}, + {"d", "d", "r"} +#else + {"ffr", "d"}, + {"r", "d", "d"}, +#endif }; // ---- 不要修改以下代码 ---- diff --git a/exercises/31_std_shared_ptr/main.cpp b/exercises/31_std_shared_ptr/main.cpp index febbbcc6f..88fba0291 100644 --- a/exercises/31_std_shared_ptr/main.cpp +++ b/exercises/31_std_shared_ptr/main.cpp @@ -4,42 +4,63 @@ // READ: `std::shared_ptr` // READ: `std::weak_ptr` +/** +关键知识点解释 + 1.use_count() 的含义 + std::weak_ptr::use_count() 返回当前观察对象的 std::shared_ptr 数量。weak_ptr 本身不增加引用计数。 + 2.引用计数变化分析 + 初始状态:shared 和 ptrs 中的 3 个 shared_ptr 共 4 个引用。 + 逐步释放:每次释放一个 shared_ptr,引用计数减 1。 + 移动语义:std::move 转移所有权,原 shared_ptr 变为空,引用计数不变。 + lock():当对象存在时,observer.lock() 创建新的 shared_ptr,引用计数加 1;对象不存在时返回空指针。 + 3.特殊操作 + ptrs[2] = std::make_shared(*shared):创建新对象并赋值给 ptrs[2],原对象引用计数减 1。 + ptrs[1] = std::move(ptrs[1]):自移动无实际效果,ptrs[1] 保持不变。 + + 总结 + shared_ptr 的引用计数反映了对象的存活状态。 + weak_ptr 用于观察 shared_ptr 管理的对象,不影响引用计数。 + 通过 lock() 可从 weak_ptr 创建 shared_ptr,需检查对象是否已销毁。 +*/ + // TODO: 将下列 `?` 替换为正确的值 int main(int argc, char **argv) { auto shared = std::make_shared(10); std::shared_ptr ptrs[]{shared, shared, shared}; - std::weak_ptr observer = shared; - ASSERT(observer.use_count() == ?, ""); + std::weak_ptr observer = shared;// weak_ptr观察shared_ptr + ASSERT(observer.use_count() == 4, "初始时shared_ptr有4个:shared和ptrs中的3个"); - ptrs[0].reset(); - ASSERT(observer.use_count() == ?, ""); + ptrs[0].reset();// 释放ptrs[0] + ASSERT(observer.use_count() == 3, "释放ptrs[0]后,剩余3个shared_ptr"); - ptrs[1] = nullptr; - ASSERT(observer.use_count() == ?, ""); + ptrs[1] = nullptr;// 释放ptrs[1] + ASSERT(observer.use_count() == 2, "释放ptrs[1]后,剩余2个shared_ptr"); - ptrs[2] = std::make_shared(*shared); - ASSERT(observer.use_count() == ?, ""); + ptrs[2] = std::make_shared(*shared);// ptrs[2]指向新对象 + ASSERT(observer.use_count() == 1, "ptrs[2]指向新对象后,原对象只剩shared"); - ptrs[0] = shared; - ptrs[1] = shared; - ptrs[2] = std::move(shared); - ASSERT(observer.use_count() == ?, ""); + ptrs[0] = shared; // ptrs[0]重新指向shared + ptrs[1] = shared; // ptrs[1]重新指向shared + ptrs[2] = std::move(shared);// ptrs[2]通过移动语义获取shared的所有权,shared变为空 + ASSERT(observer.use_count() == 3, "ptrs[0]、ptrs[1]和ptrs[2]指向原对象,shared为空"); - std::ignore = std::move(ptrs[0]); - ptrs[1] = std::move(ptrs[1]); - ptrs[1] = std::move(ptrs[2]); - ASSERT(observer.use_count() == ?, ""); + std::ignore = std::move(ptrs[0]);// 移动ptrs[0]到std::ignore,ptrs[0]变为空 + ptrs[1] = std::move(ptrs[1]); // 自移动,无实际效果 + // ptrs[1] = std::move(ptrs[2]) 后,ptrs[1] 持有资源,ptrs[2] 为空,但原对象的引用计数未减少(只是所有权转移)。 + // 此时原对象的 shared_ptr 有 ptrs[1] 和 ptrs[2](转移前的),共 2 个。 + ptrs[1] = std::move(ptrs[2]);// ptrs[1]获取ptrs[2]的所有权 + ASSERT(observer.use_count() == 2, "ptrs[1]和ptrs[2]指向原对象"); - shared = observer.lock(); - ASSERT(observer.use_count() == ?, ""); + shared = observer.lock(); // 通过observer创建新的shared_ptr + ASSERT(observer.use_count() == 3, "shared和ptrs[1]和ptrs[2]指向原对象"); - shared = nullptr; - for (auto &ptr : ptrs) ptr = nullptr; - ASSERT(observer.use_count() == ?, ""); + shared = nullptr; // 释放shared + for (auto &ptr : ptrs) ptr = nullptr; // 释放所有ptrs + ASSERT(observer.use_count() == 0, "所有shared_ptr被释放,对象已销毁"); - shared = observer.lock(); - ASSERT(observer.use_count() == ?, ""); + shared = observer.lock(); // 对象已销毁,lock()返回空shared_ptr + ASSERT(observer.use_count() == 0, "对象已销毁,use_count为0"); return 0; } diff --git a/exercises/32_std_transform/main.cpp b/exercises/32_std_transform/main.cpp index f4dc25a5c..b23425461 100644 --- a/exercises/32_std_transform/main.cpp +++ b/exercises/32_std_transform/main.cpp @@ -9,7 +9,15 @@ int main(int argc, char **argv) { std::vector val{8, 13, 21, 34, 55}; // TODO: 调用 `std::transform`,将 `v` 中的每个元素乘以 2,并转换为字符串,存入 `ans` - // std::vector ans + std::vector ans; + ans.resize(val.size()); + + std::transform(std::begin(val), std::end(val), + std::begin(ans), + [](int x) { + return std::to_string(x * 2); + }); + ASSERT(ans.size() == val.size(), "ans size should be equal to val size"); ASSERT(ans[0] == "16", "ans[0] should be 16"); ASSERT(ans[1] == "26", "ans[1] should be 26"); diff --git a/exercises/33_std_accumulate/main.cpp b/exercises/33_std_accumulate/main.cpp index 6326929d5..5afcac82e 100644 --- a/exercises/33_std_accumulate/main.cpp +++ b/exercises/33_std_accumulate/main.cpp @@ -1,4 +1,6 @@ #include "../exercise.h" + +#include #include // READ: `std::accumulate` @@ -11,7 +13,7 @@ int main(int argc, char **argv) { // - 形状为 shape; // - 连续存储; // 的张量占用的字节数 - // int size = + int size = std::accumulate(std::begin(shape), std::end(shape), DataType(4), std::multiplies()); ASSERT(size == 602112, "4x1x3x224x224 = 602112"); return 0; } diff --git a/run-learn b/run-learn new file mode 100755 index 000000000..494db7b9a --- /dev/null +++ b/run-learn @@ -0,0 +1 @@ +xmake run learn $1 \ No newline at end of file