From 9c7cb378fd93374d0aa3df4c4322d4f17a27c94e Mon Sep 17 00:00:00 2001
From: Ting-Wei Lan <lantw44@gmail.com>
Date: Sat, 2 Jan 2016 20:58:17 +0800
Subject: Generate code for return statements

---
 src/code-generation.c | 54 ++++++++++++++++++++++++++++++++++++---------------
 1 file changed, 38 insertions(+), 16 deletions(-)

(limited to 'src/code-generation.c')

diff --git a/src/code-generation.c b/src/code-generation.c
index b90d00b..a2e90ef 100644
--- a/src/code-generation.c
+++ b/src/code-generation.c
@@ -440,43 +440,48 @@ static void generate_expression(CcmmcAst *expr, CcmmcState *state,
     assert(false);
 }
 
-static void calc_and_save_expression_result(CcmmcAst *lvar, CcmmcAst *expr,
-    CcmmcState *state, uint64_t current_offset)
+static void calc_expression_result(CcmmcAst *expr, CcmmcAstValueType type,
+    CcmmcState *state, uint64_t current_offset, const char *result)
 {
 #define FPREG_TMP  "s16"
     CcmmcTmp *tmp1 = ccmmc_register_alloc(state->reg_pool, &current_offset);
     CcmmcTmp *tmp2 = ccmmc_register_alloc(state->reg_pool, &current_offset);
-    CcmmcTmp *tmp3 = ccmmc_register_alloc(state->reg_pool, &current_offset);
-    const char *result = ccmmc_register_lock(state->reg_pool, tmp1);
-    const char *op1 = ccmmc_register_lock(state->reg_pool, tmp2);
-    const char *op2 = ccmmc_register_lock(state->reg_pool, tmp3);
+    const char *op1 = ccmmc_register_lock(state->reg_pool, tmp1);
+    const char *op2 = ccmmc_register_lock(state->reg_pool, tmp2);
     generate_expression(expr, state, result, op1, op2);
-    if (lvar->type_value == CCMMC_AST_VALUE_INT &&
-        expr->type_value == CCMMC_AST_VALUE_FLOAT) {
+    if (expr->type_value == CCMMC_AST_VALUE_FLOAT &&
+        type == CCMMC_AST_VALUE_INT) {
         fprintf(state->asm_output,
             "\tfmov\t%s, %s\n"
             "\tfcvtas\t%s, %s\n",
             FPREG_TMP, result,
             result, FPREG_TMP);
-    } else if (
-        lvar->type_value == CCMMC_AST_VALUE_FLOAT &&
-        expr->type_value == CCMMC_AST_VALUE_INT) {
+    } else if (expr->type_value == CCMMC_AST_VALUE_INT &&
+        type == CCMMC_AST_VALUE_FLOAT) {
         fprintf(state->asm_output,
             "\tscvtf\t%s, %s\n"
             "\tfmov\t%s, %s\n",
             FPREG_TMP, result,
             result, FPREG_TMP);
     }
-    store_variable(lvar, state, result);
     ccmmc_register_unlock(state->reg_pool, tmp1);
     ccmmc_register_unlock(state->reg_pool, tmp2);
-    ccmmc_register_unlock(state->reg_pool, tmp3);
     ccmmc_register_free(state->reg_pool, tmp1, &current_offset);
     ccmmc_register_free(state->reg_pool, tmp2, &current_offset);
-    ccmmc_register_free(state->reg_pool, tmp3, &current_offset);
 #undef FPREG_TMP
 }
 
+static void calc_and_save_expression_result(CcmmcAst *lvar, CcmmcAst *expr,
+    CcmmcState *state, uint64_t current_offset)
+{
+    CcmmcTmp *tmp = ccmmc_register_alloc(state->reg_pool, &current_offset);
+    const char *result = ccmmc_register_lock(state->reg_pool, tmp);
+    calc_expression_result(expr, lvar->type_value, state, current_offset, result);
+    store_variable(lvar, state, result);
+    ccmmc_register_unlock(state->reg_pool, tmp);
+    ccmmc_register_free(state->reg_pool, tmp, &current_offset);
+}
+
 static void generate_block(
     CcmmcAst *block, CcmmcState *state, uint64_t current_offset);
 static void generate_statement(
@@ -529,10 +534,10 @@ static void generate_statement(
             } break;
         case CCMMC_KIND_STMT_FOR:
             break;
-        case CCMMC_KIND_STMT_ASSIGN: {
+        case CCMMC_KIND_STMT_ASSIGN:
             calc_and_save_expression_result(stmt->child,
                 stmt->child->right_sibling, state, current_offset);
-            } break;
+            break;
         case CCMMC_KIND_STMT_IF:
             generate_statement(stmt->child->right_sibling,
                 state, current_offset);
@@ -547,6 +552,21 @@ static void generate_statement(
             ccmmc_register_caller_load(state->reg_pool);
             break;
         case CCMMC_KIND_STMT_RETURN:
+            for (CcmmcAst *func = stmt->parent; ; func = func->parent) {
+                if (func->type_node == CCMMC_AST_NODE_DECL &&
+                    func->value_decl.kind == CCMMC_KIND_DECL_FUNCTION) {
+                    const char *func_name = func->child->right_sibling->value_id.name;
+                    CcmmcSymbol *func_sym =
+                        ccmmc_symbol_table_retrieve(state->table, func_name);
+                    CcmmcAstValueType func_type = func_sym->type.type_base;
+                    calc_expression_result(stmt->child, func_sym->type.type_base,
+                        state, current_offset, "w0");
+                    if (func_type == CCMMC_AST_VALUE_FLOAT)
+                        fputs("\tfmov\ts0, w0\n", state->asm_output);
+                    fprintf(state->asm_output, "\tb\t.LR_%s\n", func_name);
+                    break;
+                }
+            }
             break;
         default:
             assert(false);
@@ -655,10 +675,12 @@ static void generate_function(CcmmcAst *function, CcmmcState *state)
     CcmmcAst *block_node = param_node->right_sibling;
     generate_block(block_node, state, 0);
     fprintf(state->asm_output,
+        ".LR_%s:\n"
         "\tldp\tlr, fp, [sp], 16\n"
         "\tret\tlr\n"
         "\t.size\t%s, .-%s\n",
         function->child->right_sibling->value_id.name,
+        function->child->right_sibling->value_id.name,
         function->child->right_sibling->value_id.name);
 }
 
-- 
cgit