@@ -3,8 +3,8 @@ use std::{collections::HashSet, mem::discriminant};
3
3
use plc_ast:: control_statements:: ForLoopStatement ;
4
4
use plc_ast:: {
5
5
ast:: {
6
- flatten_expression_list, AstNode , AstStatement , DirectAccess , DirectAccessType , JumpStatement ,
7
- Operator , ReferenceAccess ,
6
+ flatten_expression_list, AstNode , AstStatement , BinaryExpression , CallStatement , DirectAccess ,
7
+ DirectAccessType , JumpStatement , Operator , ReferenceAccess , UnaryExpression ,
8
8
} ,
9
9
control_statements:: { AstControlStatement , ConditionalBlock } ,
10
10
literals:: { Array , AstLiteral , StringValue } ,
@@ -833,9 +833,8 @@ fn validate_assignment<T: AnnotationMap>(
833
833
location. clone ( ) ,
834
834
) ) ;
835
835
}
836
- } else if right. is_literal ( ) {
837
- // TODO: See https://github.com/PLC-lang/rusty/issues/857
838
- // validate_assignment_type_sizes(validator, left_type, right_type, location, context)
836
+ } else {
837
+ validate_assignment_type_sizes ( validator, left_type, right, context)
839
838
}
840
839
}
841
840
}
@@ -1343,26 +1342,106 @@ fn validate_type_nature<T: AnnotationMap>(
1343
1342
}
1344
1343
}
1345
1344
1346
- fn _validate_assignment_type_sizes < T : AnnotationMap > (
1345
+ fn validate_assignment_type_sizes < T : AnnotationMap > (
1347
1346
validator : & mut Validator ,
1348
1347
left : & DataType ,
1349
- right : & DataType ,
1350
- location : & SourceLocation ,
1348
+ right : & AstNode ,
1351
1349
context : & ValidationContext < T > ,
1352
1350
) {
1353
- if left. get_type_information ( ) . get_size ( context. index )
1354
- < right. get_type_information ( ) . get_size ( context. index )
1355
- {
1356
- validator. push_diagnostic (
1357
- Diagnostic :: new ( format ! (
1358
- "Potential loss of information due to assigning '{}' to variable of type '{}'." ,
1359
- left. get_name( ) ,
1360
- right. get_name( )
1361
- ) )
1362
- . with_error_code ( "E067" )
1363
- . with_location ( location. clone ( ) ) ,
1364
- )
1351
+ use std:: collections:: HashMap ;
1352
+ fn get_expression_types_and_locations < ' b , T : AnnotationMap > (
1353
+ expression : & AstNode ,
1354
+ context : & ' b ValidationContext < T > ,
1355
+ lhs_is_signed_int : bool ,
1356
+ is_builtin_call : bool ,
1357
+ ) -> HashMap < & ' b DataType , Vec < SourceLocation > > {
1358
+ let mut map: HashMap < & DataType , Vec < SourceLocation > > = HashMap :: new ( ) ;
1359
+ match expression. get_stmt_peeled ( ) {
1360
+ AstStatement :: BinaryExpression ( BinaryExpression { operator, left, right, .. } )
1361
+ if !operator. is_comparison_operator ( ) =>
1362
+ {
1363
+ get_expression_types_and_locations ( left, context, lhs_is_signed_int, false )
1364
+ . into_iter ( )
1365
+ . for_each ( |( k, v) | map. entry ( k) . or_default ( ) . extend ( v) ) ;
1366
+ // the RHS type in a MOD expression has no impact on the resulting value type
1367
+ if matches ! ( operator, Operator :: Modulo ) {
1368
+ return map
1369
+ } ;
1370
+ get_expression_types_and_locations ( right, context, lhs_is_signed_int, false )
1371
+ . into_iter ( )
1372
+ . for_each ( |( k, v) | map. entry ( k) . or_default ( ) . extend ( v) ) ;
1373
+ }
1374
+ AstStatement :: UnaryExpression ( UnaryExpression { operator, value } )
1375
+ if !operator. is_comparison_operator ( ) =>
1376
+ {
1377
+ get_expression_types_and_locations ( value, context, lhs_is_signed_int, false )
1378
+ . into_iter ( )
1379
+ . for_each ( |( k, v) | map. entry ( k) . or_default ( ) . extend ( v) ) ;
1380
+ }
1381
+ // `get_literal_actual_signed_type_name` will always return `LREAL` for FP literals, so they will be handled by the fall-through case according to their annotated type
1382
+ AstStatement :: Literal ( lit) if !matches ! ( lit, & AstLiteral :: Real ( _) ) => {
1383
+ if !lit. is_numerical ( ) {
1384
+ return map
1385
+ }
1386
+ if let Some ( dt) = get_literal_actual_signed_type_name ( lit, lhs_is_signed_int)
1387
+ . map ( |name| context. index . get_type ( name) . unwrap_or ( context. index . get_void_type ( ) ) )
1388
+ {
1389
+ map. entry ( dt) . or_default ( ) . push ( expression. get_location ( ) ) ;
1390
+ }
1391
+ }
1392
+ AstStatement :: CallStatement ( CallStatement { operator, parameters } )
1393
+ // special handling for builtin selector functions MUX and SEL
1394
+ if matches ! ( operator. get_flat_reference_name( ) . unwrap_or_default( ) , "MUX" | "SEL" ) =>
1395
+ {
1396
+ let Some ( args) = parameters else {
1397
+ return map
1398
+ } ;
1399
+ if let AstStatement :: ExpressionList ( list) = args. get_stmt_peeled ( ) {
1400
+ // skip the selector argument since it will never be assigned to the target type
1401
+ list. iter ( ) . skip ( 1 ) . flat_map ( |arg| {
1402
+ get_expression_types_and_locations ( arg, context, lhs_is_signed_int, true )
1403
+ } )
1404
+ . for_each ( |( k, v) | map. entry ( k) . or_default ( ) . extend ( v) ) ;
1405
+ } ;
1406
+ }
1407
+ _ => {
1408
+ if !( context. annotations . get_generic_nature ( expression) . is_none ( ) || is_builtin_call) {
1409
+ return map
1410
+ } ;
1411
+ if let Some ( dt) = context. annotations . get_type ( expression, context. index ) {
1412
+ map. entry ( dt) . or_default ( ) . push ( expression. get_location ( ) ) ;
1413
+ }
1414
+ }
1415
+ } ;
1416
+ map
1365
1417
}
1418
+
1419
+ let lhs = left. get_type_information ( ) ;
1420
+ let lhs_size = lhs. get_size ( context. index ) ;
1421
+ let results_in_truncation = |rhs : & DataType | {
1422
+ let rhs = rhs. get_type_information ( ) ;
1423
+ let rhs_size = rhs. get_size ( context. index ) ;
1424
+ lhs_size < rhs_size
1425
+ || ( lhs_size == rhs_size
1426
+ && ( ( lhs. is_signed_int ( ) && rhs. is_unsigned_int ( ) ) || ( lhs. is_int ( ) && rhs. is_float ( ) ) ) )
1427
+ } ;
1428
+
1429
+ get_expression_types_and_locations ( right, context, lhs. is_signed_int ( ) , false )
1430
+ . into_iter ( )
1431
+ . filter ( |( dt, _) | !dt. is_aggregate_type ( ) && results_in_truncation ( dt) )
1432
+ . for_each ( |( dt, location) | {
1433
+ location. into_iter ( ) . for_each ( |loc| {
1434
+ validator. push_diagnostic (
1435
+ Diagnostic :: new ( format ! (
1436
+ "Implicit downcast from '{}' to '{}'." ,
1437
+ get_datatype_name_or_slice( validator. context, dt) ,
1438
+ get_datatype_name_or_slice( validator. context, left)
1439
+ ) )
1440
+ . with_error_code ( "E067" )
1441
+ . with_location ( loc) ,
1442
+ ) ;
1443
+ } )
1444
+ } ) ;
1366
1445
}
1367
1446
1368
1447
mod helper {
0 commit comments